Skip to content

Commit 9e109df

Browse files
authored
Add Structure.to_(conventional|primitive|cell) methods (#3384)
* add Structure.to_(conventional|primitive|cell) methods * core/test_structure.py add add test_to_primitive and test_to_conventional
1 parent f415cc3 commit 9e109df

File tree

4 files changed

+100
-96
lines changed

4 files changed

+100
-96
lines changed

pymatgen/analysis/chemenv/connectivity/connected_components.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,8 @@ def draw_network(env_graph, pos, ax, sg=None, periodicity_vectors=None):
6060
np.array([0, 0, 1], float),
6161
)
6262
vv /= np.linalg.norm(vv)
63-
midarc = midpoint + rad * dist * np.array([vv[0], vv[1]], float)
64-
xytext_offset = 0.1 * dist * np.array([vv[0], vv[1]], float)
63+
mid_arc = midpoint + rad * dist * np.array([vv[0], vv[1]], float)
64+
xy_text_offset = 0.1 * dist * np.array([vv[0], vv[1]], float)
6565

6666
if periodicity_vectors is not None and len(periodicity_vectors) == 1:
6767
if np.all(np.array(delta) == np.array(periodicity_vectors[0])) or np.all(
@@ -109,11 +109,11 @@ def draw_network(env_graph, pos, ax, sg=None, periodicity_vectors=None):
109109
)
110110
ax.annotate(
111111
delta,
112-
midarc,
112+
mid_arc,
113113
ha="center",
114114
va="center",
115115
xycoords="data",
116-
xytext=xytext_offset,
116+
xytext=xy_text_offset,
117117
textcoords="offset points",
118118
)
119119
seen[(u, v)] = rad

pymatgen/core/structure.py

Lines changed: 54 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from fnmatch import fnmatch
2222
from inspect import isclass
2323
from io import StringIO
24-
from typing import TYPE_CHECKING, Any, Callable, Literal, SupportsIndex, cast
24+
from typing import TYPE_CHECKING, Any, Callable, Literal, SupportsIndex, cast, get_args
2525

2626
import numpy as np
2727
from monty.dev import deprecated
@@ -472,7 +472,7 @@ def add_site_property(self, property_name: str, values: Sequence | np.ndarray):
472472
"""
473473
if len(values) != len(self):
474474
raise ValueError(f"Values has length {len(values)} but there are {len(self)} sites! Must be same length.")
475-
for site, val in zip(self.sites, values):
475+
for site, val in zip(self, values):
476476
site.properties[property_name] = val
477477

478478
def remove_site_property(self, property_name: str):
@@ -551,7 +551,7 @@ def add_oxidation_state_by_site(self, oxidation_states: list[float]) -> None:
551551
f"Oxidation states of all sites must be specified, expected {len(self)} values, "
552552
f"got {len(oxidation_states)}"
553553
)
554-
for site, ox in zip(self.sites, oxidation_states):
554+
for site, ox in zip(self, oxidation_states):
555555
new_sp = {}
556556
for el, occu in site.species.items():
557557
sym = el.symbol
@@ -1243,7 +1243,7 @@ def is_3d_periodic(self) -> bool:
12431243
"""True if the Lattice is periodic in all directions."""
12441244
return self._lattice.is_3d_periodic
12451245

1246-
def get_space_group_info(self, symprec=1e-2, angle_tolerance=5.0) -> tuple[str, int]:
1246+
def get_space_group_info(self, symprec: float = 1e-2, angle_tolerance: float = 5.0) -> tuple[str, int]:
12471247
"""Convenience method to quickly get the spacegroup of a structure.
12481248
12491249
Args:
@@ -2035,7 +2035,7 @@ def get_reduced_structure(self, reduction_algo: Literal["niggli", "LLL"] = "nigg
20352035
elif reduction_algo == "LLL":
20362036
reduced_latt = self._lattice.get_lll_reduced_lattice()
20372037
else:
2038-
raise ValueError(f"Invalid reduction algo : {reduction_algo}")
2038+
raise ValueError(f"Invalid {reduction_algo=}")
20392039

20402040
if reduced_latt != self.lattice:
20412041
return self.__class__(
@@ -2874,6 +2874,55 @@ def from_file(cls, filename, primitive=False, sort=False, merge_tol=0.0, **kwarg
28742874
struct.__class__ = cls
28752875
return struct
28762876

2877+
CellType = Literal["primitive", "conventional"]
2878+
2879+
def to_cell(self, cell_type: IStructure.CellType, **kwargs) -> Structure:
2880+
"""Returns a cell based on the current structure.
2881+
2882+
Args:
2883+
cell_type ("primitive" | "conventional"): Whether to return a primitive or conventional cell.
2884+
kwargs: Any keyword supported by pymatgen.symmetry.analyzer.SpacegroupAnalyzer such as
2885+
symprec=0.01, angle_tolerance=5, international_monoclinic=True and keep_site_properties=False.
2886+
2887+
Returns:
2888+
Structure: with the requested cell type.
2889+
"""
2890+
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
2891+
2892+
valid_cell_types = get_args(IStructure.CellType)
2893+
if cell_type not in valid_cell_types:
2894+
raise ValueError(f"Invalid {cell_type=}, valid values are {valid_cell_types}")
2895+
2896+
method_keys = ["international_monoclinic", "keep_site_properties"]
2897+
method_kwargs = {key: kwargs.pop(key) for key in method_keys if key in kwargs}
2898+
2899+
sga = SpacegroupAnalyzer(self, **kwargs)
2900+
return getattr(sga, f"get_{cell_type}_standard_structure")(**method_kwargs)
2901+
2902+
def to_primitive(self, **kwargs) -> Structure:
2903+
"""Returns a primitive cell based on the current structure.
2904+
2905+
Args:
2906+
kwargs: Any keyword supported by pymatgen.symmetry.analyzer.SpacegroupAnalyzer such as
2907+
symprec=0.01, angle_tolerance=5, international_monoclinic=True and keep_site_properties=False.
2908+
2909+
Returns:
2910+
Structure: with the requested cell type.
2911+
"""
2912+
return self.to_cell("primitive", **kwargs)
2913+
2914+
def to_conventional(self, **kwargs) -> Structure:
2915+
"""Returns a conventional cell based on the current structure.
2916+
2917+
Args:
2918+
kwargs: Any keyword supported by pymatgen.symmetry.analyzer.SpacegroupAnalyzer such as
2919+
symprec=0.01, angle_tolerance=5, international_monoclinic=True and keep_site_properties=False.
2920+
2921+
Returns:
2922+
Structure: with the requested cell type.
2923+
"""
2924+
return self.to_cell("conventional", **kwargs)
2925+
28772926

28782927
class IMolecule(SiteCollection, MSONable):
28792928
"""Basic immutable Molecule object without periodicity. Essentially a

tests/core/test_structure.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
)
2929
from pymatgen.electronic_structure.core import Magmom
3030
from pymatgen.io.ase import AseAtomsAdaptor
31+
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
3132
from pymatgen.util.testing import TEST_FILES_DIR, PymatgenTest
3233

3334
try:
@@ -196,9 +197,7 @@ def test_as_dict(self):
196197
assert "sites" in struct.as_dict()
197198
d = self.propertied_structure.as_dict()
198199
assert d["sites"][0]["properties"]["magmom"] == 5
199-
coords = []
200-
coords.append([0, 0, 0])
201-
coords.append([0.75, 0.5, 0.75])
200+
coords = [[0, 0, 0], [0.75, 0.5, 0.75]]
202201
struct = IStructure(
203202
self.lattice,
204203
[
@@ -1349,7 +1348,7 @@ def test_merge_sites(self):
13491348
navs2.insert(0, "Na", coords[0], properties={"prop1": 100.0})
13501349
navs2.merge_sites(mode="a")
13511350
assert len(navs2) == 12
1352-
assert 51.5 in [itr.properties["prop1"] for itr in navs2.sites]
1351+
assert 51.5 in [itr.properties["prop1"] for itr in navs2]
13531352

13541353
def test_properties(self):
13551354
assert self.struct.num_sites == len(self.struct)
@@ -1633,15 +1632,35 @@ def test_from_prototype(self):
16331632
struct = Structure.from_prototype(prototype, ["Cs", "Cl"], a=5)
16341633
assert struct.lattice.is_orthogonal
16351634

1635+
def test_to_primitive(self):
1636+
struct = Structure.from_file(f"{TEST_FILES_DIR}/orci_1010.cif")
1637+
primitive = struct.to_primitive()
1638+
1639+
assert struct != primitive
1640+
sga = SpacegroupAnalyzer(struct)
1641+
assert primitive == sga.get_primitive_standard_structure()
1642+
assert struct.formula == "Mn1 B4"
1643+
assert primitive.formula == "Mn1 B4"
1644+
1645+
def test_to_conventional(self):
1646+
struct = Structure.from_file(f"{TEST_FILES_DIR}/bcc_1927.cif")
1647+
conventional = struct.to_conventional()
1648+
1649+
assert struct != conventional
1650+
sga = SpacegroupAnalyzer(struct)
1651+
assert conventional == sga.get_conventional_standard_structure()
1652+
assert struct.formula == "Dy8 Sb6"
1653+
assert conventional.formula == "Dy16 Sb12"
1654+
16361655

16371656
class TestIMolecule(PymatgenTest):
16381657
def setUp(self):
16391658
coords = [
16401659
[0, 0, 0],
1641-
[0, 0, 1.089000],
1642-
[1.026719, 0, -0.363000],
1643-
[-0.513360, -0.889165, -0.363000],
1644-
[-0.513360, 0.889165, -0.363000],
1660+
[0, 0, 1.089],
1661+
[1.026719, 0, -0.363],
1662+
[-0.513360, -0.889165, -0.363],
1663+
[-0.513360, 0.889165, -0.363],
16451664
]
16461665
self.coords = coords
16471666
self.mol = Molecule(["C", "H", "H", "H", "H"], coords)

tests/symmetry/test_analyzer.py

Lines changed: 15 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -340,85 +340,21 @@ def test_get_conventional_standard_structure(self):
340340
assert conv.site_properties.get("magmom") is None
341341

342342
def test_get_primitive_standard_structure(self):
343-
structure = Structure.from_file(f"{TEST_FILES_DIR}/bcc_1927.cif")
344-
spga = SpacegroupAnalyzer(structure, symprec=1e-2)
345-
prim = spga.get_primitive_standard_structure()
346-
assert prim.lattice.alpha == approx(109.47122063400001)
347-
assert prim.lattice.beta == approx(109.47122063400001)
348-
assert prim.lattice.gamma == approx(109.47122063400001)
349-
assert prim.lattice.a == approx(7.9657251015812145)
350-
assert prim.lattice.b == approx(7.9657251015812145)
351-
assert prim.lattice.c == approx(7.9657251015812145)
352-
353-
structure = Structure.from_file(f"{TEST_FILES_DIR}/btet_1915.cif")
354-
spga = SpacegroupAnalyzer(structure, symprec=1e-2)
355-
prim = spga.get_primitive_standard_structure()
356-
assert prim.lattice.alpha == approx(105.015053349)
357-
assert prim.lattice.beta == approx(105.015053349)
358-
assert prim.lattice.gamma == approx(118.80658411899999)
359-
assert prim.lattice.a == approx(4.1579321075608791)
360-
assert prim.lattice.b == approx(4.1579321075608791)
361-
assert prim.lattice.c == approx(4.1579321075608791)
362-
363-
structure = Structure.from_file(f"{TEST_FILES_DIR}/orci_1010.cif")
364-
spga = SpacegroupAnalyzer(structure, symprec=1e-2)
365-
prim = spga.get_primitive_standard_structure()
366-
assert prim.lattice.alpha == approx(134.78923546600001)
367-
assert prim.lattice.beta == approx(105.856239333)
368-
assert prim.lattice.gamma == approx(91.276341676000001)
369-
assert prim.lattice.a == approx(3.8428217771014852)
370-
assert prim.lattice.b == approx(3.8428217771014852)
371-
assert prim.lattice.c == approx(3.8428217771014852)
372-
373-
structure = Structure.from_file(f"{TEST_FILES_DIR}/orcc_1003.cif")
374-
spga = SpacegroupAnalyzer(structure, symprec=1e-2)
375-
prim = spga.get_primitive_standard_structure()
376-
assert prim.lattice.alpha == approx(90)
377-
assert prim.lattice.beta == approx(90)
378-
assert prim.lattice.gamma == approx(164.985257335)
379-
assert prim.lattice.a == approx(15.854897098324196)
380-
assert prim.lattice.b == approx(15.854897098324196)
381-
assert prim.lattice.c == approx(3.99648651)
382-
383-
structure = Structure.from_file(f"{TEST_FILES_DIR}/orac_632475.cif")
384-
spga = SpacegroupAnalyzer(structure, symprec=1e-2)
385-
prim = spga.get_primitive_standard_structure()
386-
assert prim.lattice.alpha == approx(90)
387-
assert prim.lattice.beta == approx(90)
388-
assert prim.lattice.gamma == approx(144.40557588533386)
389-
assert prim.lattice.a == approx(5.2005185662155391)
390-
assert prim.lattice.b == approx(5.2005185662155391)
391-
assert prim.lattice.c == approx(3.5372412099999999)
392-
393-
structure = Structure.from_file(f"{TEST_FILES_DIR}/monoc_1028.cif")
394-
spga = SpacegroupAnalyzer(structure, symprec=1e-2)
395-
prim = spga.get_primitive_standard_structure()
396-
assert prim.lattice.alpha == approx(63.579155761999999)
397-
assert prim.lattice.beta == approx(116.42084423747779)
398-
assert prim.lattice.gamma == approx(148.47965136208569)
399-
assert prim.lattice.a == approx(7.2908007159612325)
400-
assert prim.lattice.b == approx(7.2908007159612325)
401-
assert prim.lattice.c == approx(6.8743926325200002)
402-
403-
structure = Structure.from_file(f"{TEST_FILES_DIR}/hex_1170.cif")
404-
spga = SpacegroupAnalyzer(structure, symprec=1e-2)
405-
prim = spga.get_primitive_standard_structure()
406-
assert prim.lattice.alpha == approx(90)
407-
assert prim.lattice.beta == approx(90)
408-
assert prim.lattice.gamma == approx(120)
409-
assert prim.lattice.a == approx(3.699919902005897)
410-
assert prim.lattice.b == approx(3.699919902005897)
411-
assert prim.lattice.c == approx(6.9779585500000003)
412-
413-
structure = Structure.from_file(f"{TEST_FILES_DIR}/rhomb_3478_conv.cif")
414-
spga = SpacegroupAnalyzer(structure, symprec=1e-2)
415-
prim = spga.get_primitive_standard_structure()
416-
assert prim.lattice.alpha == approx(28.049186140546812)
417-
assert prim.lattice.beta == approx(28.049186140546812)
418-
assert prim.lattice.gamma == approx(28.049186140546812)
419-
assert prim.lattice.a == approx(5.9352627428399982)
420-
assert prim.lattice.b == approx(5.9352627428399982)
421-
assert prim.lattice.c == approx(5.9352627428399982)
343+
for file_name, expected_angles, expected_abc in [
344+
("bcc_1927.cif", [109.47122063400001] * 3, [7.9657251015812145] * 3),
345+
("btet_1915.cif", [105.015053349, 105.015053349, 118.80658411899999], [4.1579321075608791] * 3),
346+
("orci_1010.cif", [134.78923546600001, 105.856239333, 91.276341676000001], [3.8428217771014852] * 3),
347+
("orcc_1003.cif", [90, 90, 164.985257335], [15.854897098324196, 15.854897098324196, 3.99648651]),
348+
("orac_632475.cif", [90, 90, 144.40557588533386], [5.2005185662, 5.20051856621, 3.53724120999]),
349+
("monoc_1028.cif", [63.579155761, 116.420844, 148.479651], [7.2908007159, 7.29080071, 6.87439263]),
350+
("hex_1170.cif", [90, 90, 120], [3.699919902005897, 3.699919902005897, 6.9779585500000003]),
351+
("rhomb_3478_conv.cif", [28.0491861, 28.049186140, 28.049186140], [5.93526274, 5.9352627428, 5.9352627428]),
352+
]:
353+
structure = Structure.from_file(f"{TEST_FILES_DIR}/{file_name}")
354+
spga = SpacegroupAnalyzer(structure, symprec=1e-2)
355+
prim = spga.get_primitive_standard_structure()
356+
assert prim.lattice.angles == approx(expected_angles)
357+
assert prim.lattice.abc == approx(expected_abc)
422358

423359
structure = Structure.from_file(f"{TEST_FILES_DIR}/rhomb_3478_conv.cif")
424360
structure.add_site_property("magmom", [1.0] * len(structure))

0 commit comments

Comments
 (0)