Skip to content

Commit 529eceb

Browse files
authored
Equalize Phonon(Dos|BS)Plotter colors, allow custom plot settings per-DOS (#3514)
* make temp optional to allow falling back to t if temp not passed * allow passing arbitrary kwargs into PhononDosPlotter.add_dos for use in e.g. color customization * change default line colors of PhononDosPlotter and PhononBSPlotter to tab:10 tab:blue and tab:orange in particular * fix overlapping an non-symbol band struct x-labels label.replace("GAMMA", "Γ").replace("DELTA", "Δ") * change colors from tab10 back to regular red/blue * plot_compare add keyword other_kwargs to customize 2nd set of band lines
1 parent d860b0a commit 529eceb

File tree

2 files changed

+44
-58
lines changed

2 files changed

+44
-58
lines changed

pymatgen/phonon/dos.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def _positive_densities(self) -> np.ndarray:
161161
"""Numpy array containing the list of densities corresponding to positive frequencies."""
162162
return self.densities[self.ind_zero_freq :]
163163

164-
def cv(self, temp: float, structure: Structure | None = None, **kwargs) -> float:
164+
def cv(self, temp: float | None = None, structure: Structure | None = None, **kwargs) -> float:
165165
"""Constant volume specific heat C_v at temperature T obtained from the integration of the DOS.
166166
Only positive frequencies will be used.
167167
Result in J/(K*mol-c). A mol-c is the abbreviation of a mole-cell, that is, the number
@@ -198,7 +198,7 @@ def csch2(x):
198198

199199
return cv
200200

201-
def entropy(self, temp: float, structure: Structure | None = None, **kwargs) -> float:
201+
def entropy(self, temp: float | None = None, structure: Structure | None = None, **kwargs) -> float:
202202
"""Vibrational entropy at temperature T obtained from the integration of the DOS.
203203
Only positive frequencies will be used.
204204
Result in J/(K*mol-c). A mol-c is the abbreviation of a mole-cell, that is, the number
@@ -233,7 +233,7 @@ def entropy(self, temp: float, structure: Structure | None = None, **kwargs) ->
233233

234234
return entropy
235235

236-
def internal_energy(self, temp: float, structure: Structure | None = None, **kwargs) -> float:
236+
def internal_energy(self, temp: float | None = None, structure: Structure | None = None, **kwargs) -> float:
237237
"""Phonon contribution to the internal energy at temperature T obtained from the integration of the DOS.
238238
Only positive frequencies will be used.
239239
Result in J/mol-c. A mol-c is the abbreviation of a mole-cell, that is, the number
@@ -268,7 +268,7 @@ def internal_energy(self, temp: float, structure: Structure | None = None, **kwa
268268

269269
return e_phonon
270270

271-
def helmholtz_free_energy(self, temp: float, structure: Structure | None = None, **kwargs) -> float:
271+
def helmholtz_free_energy(self, temp: float | None = None, structure: Structure | None = None, **kwargs) -> float:
272272
"""Phonon contribution to the Helmholtz free energy at temperature T obtained from the integration of the DOS.
273273
Only positive frequencies will be used.
274274
Result in J/mol-c. A mol-c is the abbreviation of a mole-cell, that is, the number

pymatgen/phonon/plotter.py

Lines changed: 40 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
import matplotlib.pyplot as plt
1010
import numpy as np
11-
import palettable
1211
import scipy.constants as const
1312
from matplotlib.collections import LineCollection
1413
from monty.json import jsanitize
@@ -95,19 +94,18 @@ def __init__(self, stack: bool = False, sigma: float | None = None) -> None:
9594
)
9695
self.stack = stack
9796
self.sigma = sigma
98-
self._doses: dict[str, dict[Literal["frequencies", "densities"], np.ndarray]] = {}
97+
self._doses: dict[str, dict[str, np.ndarray]] = {}
9998

100-
def add_dos(self, label: str, dos: PhononDos) -> None:
99+
def add_dos(self, label: str, dos: PhononDos, **kwargs: Any) -> None:
101100
"""Adds a dos for plotting.
102101
103102
Args:
104-
label:
105-
label for the DOS. Must be unique.
106-
dos:
107-
PhononDos object
103+
label (str): label for the DOS. Must be unique.
104+
dos (PhononDos): DOS object
105+
**kwargs: kwargs supported by matplotlib.pyplot.plot
108106
"""
109107
densities = dos.get_smeared_densities(self.sigma) if self.sigma else dos.densities
110-
self._doses[label] = {"frequencies": dos.frequencies, "densities": densities}
108+
self._doses[label] = {"frequencies": dos.frequencies, "densities": densities, **kwargs}
111109

112110
def add_dos_dict(self, dos_dict: dict, key_sort_func=None) -> None:
113111
"""Add a dictionary of doses, with an optional sorting function for the
@@ -160,8 +158,6 @@ def get_plot(
160158
n_colors = max(3, len(self._doses))
161159
n_colors = min(9, n_colors)
162160

163-
colors = palettable.colorbrewer.qualitative.Set1_9.mpl_colors
164-
165161
y = None
166162
all_densities = []
167163
all_frequencies = []
@@ -186,18 +182,14 @@ def get_plot(
186182
all_densities.reverse()
187183
all_frequencies.reverse()
188184
all_pts = []
185+
colors = ("blue", "red", "green", "orange", "purple", "brown", "pink", "gray", "olive")
189186
for idx, (key, frequencies, densities) in enumerate(zip(keys, all_frequencies, all_densities)):
187+
color = self._doses[key].get("color", colors[idx % n_colors])
190188
all_pts.extend(list(zip(frequencies, densities)))
191189
if self.stack:
192-
ax.fill(frequencies, densities, color=colors[idx % n_colors], label=str(key))
190+
ax.fill(frequencies, densities, color=color, label=str(key))
193191
else:
194-
ax.plot(
195-
frequencies,
196-
densities,
197-
color=colors[idx % n_colors],
198-
label=str(key),
199-
linewidth=3,
200-
)
192+
ax.plot(frequencies, densities, color=color, label=str(key), linewidth=3)
201193

202194
if xlim:
203195
ax.set_xlim(xlim)
@@ -297,13 +289,9 @@ def _make_ticks(self, ax: Axes) -> Axes:
297289
ax.set_xticks(uniq_d)
298290
ax.set_xticklabels(uniq_l)
299291

300-
for idx in range(len(ticks["label"])):
301-
if ticks["label"][idx] is not None:
302-
# don't print the same label twice
303-
if idx != 0:
304-
ax.axvline(ticks["distance"][idx], color="k")
305-
else:
306-
ax.axvline(ticks["distance"][idx], color="k")
292+
for idx, label in enumerate(ticks["label"]):
293+
if label is not None:
294+
ax.axvline(ticks["distance"][idx], color="k")
307295
return ax
308296

309297
def bs_plot_data(self) -> dict[str, Any]:
@@ -356,14 +344,11 @@ def get_plot(
356344
ax = pretty_plot(12, 8)
357345

358346
data = self.bs_plot_data()
359-
for d in range(len(data["distances"])):
347+
kwargs.setdefault("color", "blue")
348+
for dists, freqs in zip(data["distances"], data["frequency"]):
360349
for idx in range(self._nb_bands):
361-
ax.plot(
362-
data["distances"][d],
363-
[data["frequency"][d][idx][j] * u.factor for j in range(len(data["distances"][d]))],
364-
"b-",
365-
**kwargs,
366-
)
350+
ys = [freqs[idx][j] * u.factor for j in range(len(dists))]
351+
ax.plot(dists, ys, **kwargs)
367352

368353
self._make_ticks(ax)
369354

@@ -598,15 +583,15 @@ def get_ticks(self) -> dict[str, list]:
598583
label0 = f"${label0}$"
599584
tick_labels.pop()
600585
tick_distance.pop()
601-
tick_labels.append(f"{label0}$\\mid${label1}")
586+
tick_labels.append(f"{label0}|{label1}")
602587
elif point.label.startswith("\\") or point.label.find("_") != -1:
603588
tick_labels.append(f"${point.label}$")
604589
else:
605-
# map atomate2 all-upper-case point.labels to pretty LaTeX
606-
label = dict(GAMMA=r"$\Gamma$", DELTA=r"$\Delta$").get(point.label, point.label)
607-
tick_labels.append(label)
590+
tick_labels.append(point.label)
608591
previous_label = point.label
609592
previous_branch = this_branch
593+
# map atomate2 all-upper-case labels like GAMMA/DELTA to pretty symbols
594+
tick_labels = [label.replace("GAMMA", "Γ").replace("DELTA", "Δ").replace("SIGMA", "Σ") for label in tick_labels]
610595
return {"distance": tick_distance, "label": tick_labels}
611596

612597
def plot_compare(
@@ -616,6 +601,7 @@ def plot_compare(
616601
labels: tuple[str, str] | None = None,
617602
legend_kwargs: dict | None = None,
618603
on_incompatible: Literal["raise", "warn", "ignore"] = "raise",
604+
other_kwargs: dict | None = None,
619605
**kwargs,
620606
) -> Axes:
621607
"""Plot two band structure for comparison. One is in red the other in blue.
@@ -634,14 +620,16 @@ def plot_compare(
634620
legend_kwargs: dict[str, Any]: kwargs passed to ax.legend().
635621
on_incompatible ('raise' | 'warn' | 'ignore'): What to do if the two band structures are not compatible.
636622
Defaults to 'raise'.
623+
other_kwargs: dict[str, Any]: kwargs passed to other_plotter ax.plot().
637624
**kwargs: passed to ax.plot().
638625
639626
Returns:
640627
a matplotlib object with both band structures
641628
"""
642629
unit = freq_units(units)
643630
legend_kwargs = legend_kwargs or {}
644-
legend_kwargs.setdefault("fontsize", 22)
631+
other_kwargs = other_kwargs or {}
632+
legend_kwargs.setdefault("fontsize", 20)
645633

646634
data_orig = self.bs_plot_data()
647635
data = other_plotter.bs_plot_data()
@@ -656,24 +644,22 @@ def plot_compare(
656644
line_width = kwargs.setdefault("linewidth", 1)
657645

658646
ax = self.get_plot(units=units, **kwargs)
659-
for band_idx in range(other_plotter._nb_bands):
660-
for dist_idx in range(len(data_orig["distances"])):
661-
ax.plot(
662-
data_orig["distances"][dist_idx],
663-
[
664-
data["frequency"][dist_idx][band_idx][j] * unit.factor
665-
for j in range(len(data_orig["distances"][dist_idx]))
666-
],
667-
"r-",
668-
**kwargs,
669-
)
670647

671-
# add legend showing which color correspond to which band structure
672-
if labels is None and self._label and other_plotter._label:
673-
labels = (self._label, other_plotter._label)
674-
if labels:
675-
ax.plot([], [], "b-", label=labels[0], linewidth=3 * line_width)
676-
ax.plot([], [], "r-", label=labels[1], linewidth=3 * line_width)
648+
kwargs.setdefault("color", "red") # don't move this line up! it would mess up self.get_plot color
649+
650+
for band_idx in range(other_plotter._nb_bands):
651+
for dist_idx, dists in enumerate(data_orig["distances"]):
652+
xs = dists
653+
ys = [data["frequency"][dist_idx][band_idx][j] * unit.factor for j in range(len(dists))]
654+
ax.plot(xs, ys, **(kwargs | other_kwargs))
655+
656+
# add legend showing which color corresponds to which band structure
657+
if labels or (self._label and other_plotter._label):
658+
color_self, color_other = ax.lines[0].get_color(), ax.lines[-1].get_color()
659+
label_self, label_other = labels or (self._label, other_plotter._label)
660+
ax.plot([], [], label=label_self, linewidth=2 * line_width, color=color_self)
661+
linestyle = other_kwargs.get("linestyle", "-")
662+
ax.plot([], [], label=label_other, linewidth=2 * line_width, color=color_other, linestyle=linestyle)
677663
ax.legend(**legend_kwargs)
678664

679665
return ax

0 commit comments

Comments
 (0)