8
8
9
9
import matplotlib .pyplot as plt
10
10
import numpy as np
11
- import palettable
12
11
import scipy .constants as const
13
12
from matplotlib .collections import LineCollection
14
13
from monty .json import jsanitize
@@ -95,19 +94,18 @@ def __init__(self, stack: bool = False, sigma: float | None = None) -> None:
95
94
)
96
95
self .stack = stack
97
96
self .sigma = sigma
98
- self ._doses : dict [str , dict [Literal [ "frequencies" , "densities" ] , np .ndarray ]] = {}
97
+ self ._doses : dict [str , dict [str , np .ndarray ]] = {}
99
98
100
- def add_dos (self , label : str , dos : PhononDos ) -> None :
99
+ def add_dos (self , label : str , dos : PhononDos , ** kwargs : Any ) -> None :
101
100
"""Adds a dos for plotting.
102
101
103
102
Args:
104
- label:
105
- label for the DOS. Must be unique.
106
- dos:
107
- PhononDos object
103
+ label (str): label for the DOS. Must be unique.
104
+ dos (PhononDos): DOS object
105
+ **kwargs: kwargs supported by matplotlib.pyplot.plot
108
106
"""
109
107
densities = dos .get_smeared_densities (self .sigma ) if self .sigma else dos .densities
110
- self ._doses [label ] = {"frequencies" : dos .frequencies , "densities" : densities }
108
+ self ._doses [label ] = {"frequencies" : dos .frequencies , "densities" : densities , ** kwargs }
111
109
112
110
def add_dos_dict (self , dos_dict : dict , key_sort_func = None ) -> None :
113
111
"""Add a dictionary of doses, with an optional sorting function for the
@@ -160,8 +158,6 @@ def get_plot(
160
158
n_colors = max (3 , len (self ._doses ))
161
159
n_colors = min (9 , n_colors )
162
160
163
- colors = palettable .colorbrewer .qualitative .Set1_9 .mpl_colors
164
-
165
161
y = None
166
162
all_densities = []
167
163
all_frequencies = []
@@ -186,18 +182,14 @@ def get_plot(
186
182
all_densities .reverse ()
187
183
all_frequencies .reverse ()
188
184
all_pts = []
185
+ colors = ("blue" , "red" , "green" , "orange" , "purple" , "brown" , "pink" , "gray" , "olive" )
189
186
for idx , (key , frequencies , densities ) in enumerate (zip (keys , all_frequencies , all_densities )):
187
+ color = self ._doses [key ].get ("color" , colors [idx % n_colors ])
190
188
all_pts .extend (list (zip (frequencies , densities )))
191
189
if self .stack :
192
- ax .fill (frequencies , densities , color = colors [ idx % n_colors ] , label = str (key ))
190
+ ax .fill (frequencies , densities , color = color , label = str (key ))
193
191
else :
194
- ax .plot (
195
- frequencies ,
196
- densities ,
197
- color = colors [idx % n_colors ],
198
- label = str (key ),
199
- linewidth = 3 ,
200
- )
192
+ ax .plot (frequencies , densities , color = color , label = str (key ), linewidth = 3 )
201
193
202
194
if xlim :
203
195
ax .set_xlim (xlim )
@@ -297,13 +289,9 @@ def _make_ticks(self, ax: Axes) -> Axes:
297
289
ax .set_xticks (uniq_d )
298
290
ax .set_xticklabels (uniq_l )
299
291
300
- for idx in range (len (ticks ["label" ])):
301
- if ticks ["label" ][idx ] is not None :
302
- # don't print the same label twice
303
- if idx != 0 :
304
- ax .axvline (ticks ["distance" ][idx ], color = "k" )
305
- else :
306
- ax .axvline (ticks ["distance" ][idx ], color = "k" )
292
+ for idx , label in enumerate (ticks ["label" ]):
293
+ if label is not None :
294
+ ax .axvline (ticks ["distance" ][idx ], color = "k" )
307
295
return ax
308
296
309
297
def bs_plot_data (self ) -> dict [str , Any ]:
@@ -356,14 +344,11 @@ def get_plot(
356
344
ax = pretty_plot (12 , 8 )
357
345
358
346
data = self .bs_plot_data ()
359
- for d in range (len (data ["distances" ])):
347
+ kwargs .setdefault ("color" , "blue" )
348
+ for dists , freqs in zip (data ["distances" ], data ["frequency" ]):
360
349
for idx in range (self ._nb_bands ):
361
- ax .plot (
362
- data ["distances" ][d ],
363
- [data ["frequency" ][d ][idx ][j ] * u .factor for j in range (len (data ["distances" ][d ]))],
364
- "b-" ,
365
- ** kwargs ,
366
- )
350
+ ys = [freqs [idx ][j ] * u .factor for j in range (len (dists ))]
351
+ ax .plot (dists , ys , ** kwargs )
367
352
368
353
self ._make_ticks (ax )
369
354
@@ -598,15 +583,15 @@ def get_ticks(self) -> dict[str, list]:
598
583
label0 = f"${ label0 } $"
599
584
tick_labels .pop ()
600
585
tick_distance .pop ()
601
- tick_labels .append (f"{ label0 } $ \\ mid$ { label1 } " )
586
+ tick_labels .append (f"{ label0 } | { label1 } " )
602
587
elif point .label .startswith ("\\ " ) or point .label .find ("_" ) != - 1 :
603
588
tick_labels .append (f"${ point .label } $" )
604
589
else :
605
- # map atomate2 all-upper-case point.labels to pretty LaTeX
606
- label = dict (GAMMA = r"$\Gamma$" , DELTA = r"$\Delta$" ).get (point .label , point .label )
607
- tick_labels .append (label )
590
+ tick_labels .append (point .label )
608
591
previous_label = point .label
609
592
previous_branch = this_branch
593
+ # map atomate2 all-upper-case labels like GAMMA/DELTA to pretty symbols
594
+ tick_labels = [label .replace ("GAMMA" , "Γ" ).replace ("DELTA" , "Δ" ).replace ("SIGMA" , "Σ" ) for label in tick_labels ]
610
595
return {"distance" : tick_distance , "label" : tick_labels }
611
596
612
597
def plot_compare (
@@ -616,6 +601,7 @@ def plot_compare(
616
601
labels : tuple [str , str ] | None = None ,
617
602
legend_kwargs : dict | None = None ,
618
603
on_incompatible : Literal ["raise" , "warn" , "ignore" ] = "raise" ,
604
+ other_kwargs : dict | None = None ,
619
605
** kwargs ,
620
606
) -> Axes :
621
607
"""Plot two band structure for comparison. One is in red the other in blue.
@@ -634,14 +620,16 @@ def plot_compare(
634
620
legend_kwargs: dict[str, Any]: kwargs passed to ax.legend().
635
621
on_incompatible ('raise' | 'warn' | 'ignore'): What to do if the two band structures are not compatible.
636
622
Defaults to 'raise'.
623
+ other_kwargs: dict[str, Any]: kwargs passed to other_plotter ax.plot().
637
624
**kwargs: passed to ax.plot().
638
625
639
626
Returns:
640
627
a matplotlib object with both band structures
641
628
"""
642
629
unit = freq_units (units )
643
630
legend_kwargs = legend_kwargs or {}
644
- legend_kwargs .setdefault ("fontsize" , 22 )
631
+ other_kwargs = other_kwargs or {}
632
+ legend_kwargs .setdefault ("fontsize" , 20 )
645
633
646
634
data_orig = self .bs_plot_data ()
647
635
data = other_plotter .bs_plot_data ()
@@ -656,24 +644,22 @@ def plot_compare(
656
644
line_width = kwargs .setdefault ("linewidth" , 1 )
657
645
658
646
ax = self .get_plot (units = units , ** kwargs )
659
- for band_idx in range (other_plotter ._nb_bands ):
660
- for dist_idx in range (len (data_orig ["distances" ])):
661
- ax .plot (
662
- data_orig ["distances" ][dist_idx ],
663
- [
664
- data ["frequency" ][dist_idx ][band_idx ][j ] * unit .factor
665
- for j in range (len (data_orig ["distances" ][dist_idx ]))
666
- ],
667
- "r-" ,
668
- ** kwargs ,
669
- )
670
647
671
- # add legend showing which color correspond to which band structure
672
- if labels is None and self ._label and other_plotter ._label :
673
- labels = (self ._label , other_plotter ._label )
674
- if labels :
675
- ax .plot ([], [], "b-" , label = labels [0 ], linewidth = 3 * line_width )
676
- ax .plot ([], [], "r-" , label = labels [1 ], linewidth = 3 * line_width )
648
+ kwargs .setdefault ("color" , "red" ) # don't move this line up! it would mess up self.get_plot color
649
+
650
+ for band_idx in range (other_plotter ._nb_bands ):
651
+ for dist_idx , dists in enumerate (data_orig ["distances" ]):
652
+ xs = dists
653
+ ys = [data ["frequency" ][dist_idx ][band_idx ][j ] * unit .factor for j in range (len (dists ))]
654
+ ax .plot (xs , ys , ** (kwargs | other_kwargs ))
655
+
656
+ # add legend showing which color corresponds to which band structure
657
+ if labels or (self ._label and other_plotter ._label ):
658
+ color_self , color_other = ax .lines [0 ].get_color (), ax .lines [- 1 ].get_color ()
659
+ label_self , label_other = labels or (self ._label , other_plotter ._label )
660
+ ax .plot ([], [], label = label_self , linewidth = 2 * line_width , color = color_self )
661
+ linestyle = other_kwargs .get ("linestyle" , "-" )
662
+ ax .plot ([], [], label = label_other , linewidth = 2 * line_width , color = color_other , linestyle = linestyle )
677
663
ax .legend (** legend_kwargs )
678
664
679
665
return ax
0 commit comments