Skip to content

Commit

Permalink
Added option to export bandsplot data
Browse files Browse the repository at this point in the history
  • Loading branch information
lllangWV committed Aug 14, 2024
1 parent fce9c10 commit acc721c
Show file tree
Hide file tree
Showing 2 changed files with 207 additions and 15 deletions.
163 changes: 154 additions & 9 deletions pyprocar/plotter/ebs_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@

import os
import yaml
import json
from typing import List

import numpy as np

import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
import matplotlib as mpl
Expand Down Expand Up @@ -52,6 +53,9 @@ def __init__(self,
self.kpath = kpath
self.spins = spins
self.kdirect=kdirect

self.values_dict={}

if self.spins is None:
self.spins = range(self.ebs.nspins)
self.nspins = len(self.spins)
Expand Down Expand Up @@ -128,7 +132,7 @@ def plot_bands(self):
None.
"""

values_dict={}

for ispin in self.spins:
if len(self.spins)==1:
Expand All @@ -147,12 +151,31 @@ def plot_bands(self):
)
self.handles.append(handle)




band_name=f'band-{iband}_spinChannel-{str(ispin)}'
values_dict[f'bands_{band_name}']=self.ebs.bands[:, iband, ispin]

values_dict['kpath_values']=self.x
tick_names=[]
for i,x in enumerate(self.x):
tick_name=''
for i_tick, tick_position in enumerate(self.kpath.tick_positions):
if i == tick_position:
tick_name=self.kpath.tick_names[i_tick]
tick_names.append(tick_name)
values_dict['kpath_tick_names']=tick_names
self.values_dict=values_dict


def plot_scatter(self,
width_mask:np.ndarray=None,
color_mask:np.ndarray=None,
spins:List[int]=None,
width_weights:np.ndarray=None,
color_weights:np.ndarray=None,
labels=None,
):
"""A method to plot a scatter plot
Expand All @@ -169,6 +192,8 @@ def plot_scatter(self,
color_weights : np.ndarray, optional
The color weights at each point, by default None
"""
values_dict={}

if spins is None:
spins = range(self.ebs.nspins)
if self.ebs.is_non_collinear:
Expand Down Expand Up @@ -214,8 +239,7 @@ def plot_scatter(self,
self.x,
mbands[:, iband, ispin],
c=color,
s=width_weights[:, iband, ispin].round(
2)*markersize[ispin],
s=width_weights[:, iband, ispin]*markersize[ispin],
# edgecolors="none",
linewidths=self.config.linewidth[ispin],
cmap=self.config.cmap,
Expand All @@ -228,8 +252,8 @@ def plot_scatter(self,
sc = self.ax.scatter(
self.x,
mbands[:, iband, ispin],
c=color_weights[:, iband, ispin].round(2),
s=width_weights[:, iband, ispin].round(2)*markersize[ispin],
c=color_weights[:, iband, ispin],
s=width_weights[:, iband, ispin]*markersize[ispin],
# edgecolors="none",
linewidths=self.config.linewidth[ispin],
cmap=self.config.cmap,
Expand All @@ -238,17 +262,40 @@ def plot_scatter(self,
marker=self.config.marker[ispin],
alpha=self.config.opacity[ispin],
)

band_name=f'band-{iband}_spinChannel-{str(ispin)}'
values_dict[f'bands__{band_name}']=self.ebs.bands[:, iband, ispin]
projection_name=labels[0]
if color_weights is not None:
values_dict[f'projections__{projection_name}__{band_name}']=color_weights[:, iband, ispin]



if self.config.plot_color_bar and color_weights is not None:
self.cb = self.fig.colorbar(sc, ax=self.ax)


values_dict['kpath_values']=self.x
tick_names=[]
for i,x in enumerate(self.x):
tick_name=''
for i_tick, tick_position in enumerate(self.kpath.tick_positions):
if i == tick_position:
tick_name=self.kpath.tick_names[i_tick]
tick_names.append(tick_name)
values_dict['kpath_tick_names']=tick_names

self.values_dict=values_dict

def plot_parameteric(
self,
spins:List[int]=None,
width_mask:np.ndarray=None,
color_mask:np.ndarray=None,
width_weights:np.ndarray=None,
color_weights:np.ndarray=None,
elimit:List[float]=None
elimit:List[float]=None,
labels=None
):
"""A method to plot a scatter plot
Expand All @@ -267,6 +314,7 @@ def plot_parameteric(
elimit : List[float], optional
Energy range to plot. Only useful if the band index is written
"""
values_dict={}

# if there is only a single k-point the method for atomic
# levels will be called to fake another kpoint and then
Expand Down Expand Up @@ -338,6 +386,12 @@ def plot_parameteric(
width_weights[:, iband, ispin]*linewidth[ispin])
lc.set_linestyle(self.config.linestyle[ispin])
handle = self.ax.add_collection(lc)

band_name=f'band-{iband}_spinChannel-{str(ispin)}'
projection_name=labels[0]
values_dict[F'bands__{band_name}']=self.ebs.bands[:, iband, ispin]
if color_weights is not None:
values_dict[F'projections__{projection_name}__{band_name}']=color_weights[:, iband, ispin]
# if color_weights is not None:
# handle.set_color(color_map[iweight][:-1].lower())
handle.set_linewidth(linewidth)
Expand All @@ -346,9 +400,21 @@ def plot_parameteric(
if self.config.plot_color_bar and color_weights is not None:
self.cb = self.fig.colorbar(lc, ax=self.ax)

values_dict['kpath_values']=self.x
tick_names=[]
for i,x in enumerate(self.x):
tick_name=''
for i_tick, tick_position in enumerate(self.kpath.tick_positions):
if i == tick_position:
tick_name=self.kpath.tick_names[i_tick]
tick_names.append(tick_name)
values_dict['kpath_tick_names']=tick_names
self.values_dict=values_dict

def plot_parameteric_overlay(self,
spins:List[int]=None,
weights:np.ndarray=None,
labels:str=None,
):
"""A method to plot the parametric overlay
Expand All @@ -359,6 +425,7 @@ def plot_parameteric_overlay(self,
weights : np.ndarray, optional
The weights of each point, by default None
"""
values_dict={}

linewidth = [l*7 for l in self.config.linewidth]
if type(self.config.cmap) is str:
Expand Down Expand Up @@ -395,20 +462,40 @@ def plot_parameteric_overlay(self,
lc.set_array(weight[:, iband, ispin])
lc.set_linewidth(weight[:, iband, ispin]*linewidth[ispin])
handle = self.ax.add_collection(lc)


band_name=f'band-{iband}_spinChannel-{str(ispin)}'
projection_name=labels[iweight]
values_dict[f'bands__{band_name}']=self.ebs.bands[:, iband, ispin]
if weights is not None:
values_dict[f'projections__{projection_name}__{band_name}']=weight[:, iband, ispin]

handle.set_color(color_map[iweight][:-1].lower())
handle.set_linewidth(linewidth)
self.handles.append(handle)

if self.config.plot_color_bar:
self.cb = self.fig.colorbar(lc, ax=self.ax)

values_dict['kpath_values']=self.x
tick_names=[]
for i,x in enumerate(self.x):
tick_name=''
for i_tick, tick_position in enumerate(self.kpath.tick_positions):
if i == tick_position:
tick_name=self.kpath.tick_names[i_tick]
tick_names.append(tick_name)
values_dict['kpath_tick_names']=tick_names
self.values_dict=values_dict

def plot_atomic_levels(self,
spins:List[int]=None,
width_mask:np.ndarray=None,
color_mask:np.ndarray=None,
width_weights:np.ndarray=None,
color_weights:np.ndarray=None,
elimit:List[float]=None
elimit:List[float]=None,
labels=None
):
"""A method to plot a scatter plot
Expand Down Expand Up @@ -500,7 +587,8 @@ def plot_atomic_levels(self,
width_weights=width_weights,
color_mask=color_mask,
width_mask=width_mask,
spins=spins)
spins=spins,
labels=labels)

def set_xticks(self,
tick_positions:List[int]=None,
Expand Down Expand Up @@ -726,4 +814,61 @@ def save(self, filename:str='bands.pdf'):
plt.savefig(filename, dpi=self.config.dpi, bbox_inches="tight")
plt.clf()

def export_data(self,filename):
"""
This method will export the data to a csv file
Parameters
----------
filename : str
The file name to export the data to
Returns
-------
None
None
"""
possible_file_types=['csv','txt','json','dat']
file_type=filename.split('.')[-1]
if file_type not in possible_file_types:
raise ValueError(f"The file type must be {possible_file_types}")
if self.values_dict is None:
raise ValueError("The data has not been plotted yet")

column_names=list(self.values_dict.keys())
sorted_column_names=[None]*len(column_names)
index=0
for column_name in column_names:
if 'kpath_values' == column_name:
sorted_column_names[index]=column_name
index+=1
if 'kpath_tick_names' == column_name:
sorted_column_names[index]=column_name
index+=1
for ispin in range(2):
for column_name in column_names:

if 'spinChannel-0' in column_name.split('_')[-1] and ispin==0:
sorted_column_names[index]=column_name
index+=1
if 'spinChannel-1' in column_name.split('_')[-1] and ispin==1:
sorted_column_names[index]=column_name
index+=1

column_names.sort()
if file_type=='csv':
df=pd.DataFrame(self.values_dict)
df.to_csv(filename, columns=sorted_column_names, index=False)
elif file_type=='txt':
df=pd.DataFrame(self.values_dict)
df.to_csv(filename, columns=sorted_column_names, sep='\t', index=False)
elif file_type=='json':
with open(filename, 'w') as outfile:
for key,value in self.values_dict.items():
self.values_dict[key]=value.tolist()
json.dump(self.values_dict, outfile)
elif file_type=='dat':
df=pd.DataFrame(self.values_dict)
df.to_csv(filename, columns=sorted_column_names, sep=' ', index=False)


Loading

0 comments on commit acc721c

Please sign in to comment.