Skip to content

Commit 9b4623d

Browse files
authored
Fix CI failure for io.exciting.inputs owing to scipy constant update (#4244)
* add some types for io.exciting inputs * fix scale check
1 parent 3cb5ea4 commit 9b4623d

File tree

2 files changed

+76
-67
lines changed

2 files changed

+76
-67
lines changed

src/pymatgen/io/exciting/inputs.py

Lines changed: 65 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
import itertools
6+
import warnings
67
import xml.etree.ElementTree as ET
78
from typing import TYPE_CHECKING
89

@@ -16,10 +17,13 @@
1617
from pymatgen.symmetry.bandstructure import HighSymmKpath
1718

1819
if TYPE_CHECKING:
19-
from pathlib import Path
20+
from typing import ClassVar, Literal
2021

22+
from numpy.typing import ArrayLike, NDArray
2123
from typing_extensions import Self
2224

25+
from pymatgen.util.typing import PathLike
26+
2327
__author__ = "Christian Vorwerk"
2428
__copyright__ = "Copyright 2016"
2529
__version__ = "1.0"
@@ -37,10 +41,18 @@ class ExcitingInput(MSONable):
3741
Attributes:
3842
structure (Structure): Associated Structure.
3943
title (str): Optional title string.
40-
lockxyz (numpy.ndarray): Lockxyz attribute for each site if available. A Nx3 array of booleans.
44+
lockxyz (Nx3 NDArray of booleans): Lockxyz attribute for each site if available.
4145
"""
4246

43-
def __init__(self, structure: Structure, title=None, lockxyz=None):
47+
# Conversion factor between Bohr radius and Angstrom
48+
bohr2ang: ClassVar[float] = const.value("Bohr radius") / const.value("Angstrom star")
49+
50+
def __init__(
51+
self,
52+
structure: Structure,
53+
title: str | None = None,
54+
lockxyz: ArrayLike | None = None,
55+
) -> None:
4456
"""
4557
Args:
4658
structure (Structure): Structure object.
@@ -52,22 +64,19 @@ def __init__(self, structure: Structure, title=None, lockxyz=None):
5264
if structure.is_ordered:
5365
site_properties = {}
5466
if lockxyz:
55-
site_properties["selective_dynamics"] = lockxyz
67+
site_properties["selective_dynamics"] = np.asarray(lockxyz)
5668
self.structure = structure.copy(site_properties=site_properties)
5769
self.title = structure.formula if title is None else title
5870
else:
5971
raise ValueError("Structure with partial occupancies cannot be converted into exciting input!")
6072

61-
# define conversion factor between Bohr radius and Angstrom
62-
bohr2ang = const.value("Bohr radius") / const.value("Angstrom star")
63-
6473
@property
65-
def lockxyz(self):
74+
def lockxyz(self) -> NDArray:
6675
"""Selective dynamics site properties."""
6776
return self.structure.site_properties.get("selective_dynamics")
6877

6978
@lockxyz.setter
70-
def lockxyz(self, lockxyz):
79+
def lockxyz(self, lockxyz: ArrayLike) -> NDArray:
7180
self.structure.add_site_property("selective_dynamics", lockxyz)
7281

7382
@classmethod
@@ -164,7 +173,7 @@ def from_str(cls, data: str) -> Self:
164173
return cls(structure_in, title_in, lockxyz)
165174

166175
@classmethod
167-
def from_file(cls, filename: str | Path) -> Self:
176+
def from_file(cls, filename: PathLike) -> Self:
168177
"""
169178
Args:
170179
filename: Filename.
@@ -178,32 +187,27 @@ def from_file(cls, filename: str | Path) -> Self:
178187

179188
def write_etree(
180189
self,
181-
celltype,
182-
cartesian=False,
183-
bandstr=False,
190+
celltype: Literal["unchanged", "conventional", "primitive"],
191+
cartesian: bool = False,
192+
bandstr: bool = False,
184193
symprec: float = 0.4,
185-
angle_tolerance=5,
194+
angle_tolerance: float = 5,
186195
**kwargs,
187-
):
188-
"""Write the exciting input parameters to an XML object.
196+
) -> ET.Element:
197+
"""Convert the exciting input parameters to an XML Element object.
189198
190199
Args:
191200
celltype (str): Choice of unit cell. Can be either the unit cell
192201
from self.structure ("unchanged"), the conventional cell
193202
("conventional"), or the primitive unit cell ("primitive").
194-
195203
cartesian (bool): Whether the atomic positions are provided in
196204
Cartesian or unit-cell coordinates. Default is False.
197-
198205
bandstr (bool): Whether the bandstructure path along the
199206
HighSymmKpath is included in the input file. Only supported if the
200207
celltype is set to "primitive". Default is False.
201-
202208
symprec (float): Tolerance for the symmetry finding. Default is 0.4.
203-
204209
angle_tolerance (float): Angle tolerance for the symmetry finding.
205-
Default is 5.
206-
210+
Default is 5.
207211
**kwargs: Additional parameters for the input file.
208212
209213
Returns:
@@ -297,77 +301,68 @@ def write_etree(
297301

298302
def write_string(
299303
self,
300-
celltype,
301-
cartesian=False,
302-
bandstr=False,
304+
celltype: Literal["unchanged", "conventional", "primitive"],
305+
cartesian: bool = False,
306+
bandstr: bool = False,
303307
symprec: float = 0.4,
304-
angle_tolerance=5,
308+
angle_tolerance: float = 5,
305309
**kwargs,
306-
):
307-
"""Write exciting input.xml as a string.
310+
) -> str:
311+
"""Convert exciting input to a string.
308312
309313
Args:
310314
celltype (str): Choice of unit cell. Can be either the unit cell
311-
from self.structure ("unchanged"), the conventional cell
312-
("conventional"), or the primitive unit cell ("primitive").
313-
315+
from self.structure ("unchanged"), the conventional cell
316+
("conventional"), or the primitive unit cell ("primitive").
314317
cartesian (bool): Whether the atomic positions are provided in
315-
Cartesian or unit-cell coordinates. Default is False.
316-
318+
Cartesian or unit-cell coordinates. Default is False.
317319
bandstr (bool): Whether the bandstructure path along the
318-
HighSymmKpath is included in the input file. Only supported if the
319-
celltype is set to "primitive". Default is False.
320-
320+
HighSymmKpath is included in the input file. Only supported if the
321+
celltype is set to "primitive". Default is False.
321322
symprec (float): Tolerance for the symmetry finding. Default is 0.4.
322-
323323
angle_tolerance (float): Angle tolerance for the symmetry finding.
324-
Default is 5.
325-
324+
Default is 5.
326325
**kwargs: Additional parameters for the input file.
327326
328327
Returns:
329-
String
328+
str
330329
"""
331330
try:
332331
root = self.write_etree(celltype, cartesian, bandstr, symprec, angle_tolerance, **kwargs)
333332
self._indent(root)
334333
# output should be a string not a bytes object
335334
string = ET.tostring(root).decode("UTF-8")
335+
336336
except Exception:
337337
raise ValueError("Incorrect celltype!")
338+
338339
return string
339340

340341
def write_file(
341342
self,
342-
celltype,
343-
filename,
344-
cartesian=False,
345-
bandstr=False,
343+
celltype: Literal["unchanged", "conventional", "primitive"],
344+
filename: PathLike,
345+
cartesian: bool = False,
346+
bandstr: bool = False,
346347
symprec: float = 0.4,
347-
angle_tolerance=5,
348+
angle_tolerance: float = 5,
348349
**kwargs,
349-
):
350+
) -> None:
350351
"""Write exciting input file.
351352
352353
Args:
353354
celltype (str): Choice of unit cell. Can be either the unit cell
354-
from self.structure ("unchanged"), the conventional cell
355-
("conventional"), or the primitive unit cell ("primitive").
356-
357-
filename (str): Filename for exciting input.
358-
355+
from self.structure ("unchanged"), the conventional cell
356+
("conventional"), or the primitive unit cell ("primitive").
357+
filename (PathLike): Filename for exciting input.
359358
cartesian (bool): Whether the atomic positions are provided in
360-
Cartesian or unit-cell coordinates. Default is False.
361-
359+
Cartesian or unit-cell coordinates. Default is False.
362360
bandstr (bool): Whether the bandstructure path along the
363-
HighSymmKpath is included in the input file. Only supported if the
364-
celltype is set to "primitive". Default is False.
365-
361+
HighSymmKpath is included in the input file. Only supported if the
362+
celltype is set to "primitive". Default is False.
366363
symprec (float): Tolerance for the symmetry finding. Default is 0.4.
367-
368364
angle_tolerance (float): Angle tolerance for the symmetry finding.
369-
Default is 5.
370-
365+
Default is 5.
371366
**kwargs: Additional parameters for the input file.
372367
"""
373368
try:
@@ -378,15 +373,15 @@ def write_file(
378373
except Exception:
379374
raise ValueError("Incorrect celltype!")
380375

381-
# Missing PrettyPrint option in the current version of xml.etree.cElementTree
382376
@staticmethod
383-
def _indent(elem, level=0):
377+
def _indent(elem: ET.Element, level: int = 0) -> None:
384378
"""
385-
Helper method to indent elements.
379+
Helper method to indent elements, as missing PrettyPrint option
380+
in the current version of xml.etree.cElementTree.
386381
387382
Args:
388-
elem:
389-
level:
383+
elem (ET.Element): The Element to process.
384+
level (int): The indentation level.
390385
"""
391386
i = "\n" + level * " "
392387
if len(elem):
@@ -401,19 +396,23 @@ def _indent(elem, level=0):
401396
elif level and (not elem.tail or not elem.tail.strip()):
402397
elem.tail = i
403398

404-
def _dicttoxml(self, paramdict_, element):
399+
def _dicttoxml(self, paramdict_: dict, element: ET.Element) -> None:
405400
for key, value in paramdict_.items():
406401
if isinstance(value, str) and key == "text()":
407402
element.text = value
403+
408404
elif isinstance(value, str):
409405
element.attrib[key] = value
406+
410407
elif isinstance(value, list):
411408
for item in value:
412409
self._dicttoxml(item, ET.SubElement(element, key))
410+
413411
elif isinstance(value, dict):
414412
if element.findall(key) == []:
415413
self._dicttoxml(value, ET.SubElement(element, key))
416414
else:
417415
self._dicttoxml(value, element.findall(key)[0])
416+
418417
else:
419-
print("cannot deal with", key, "=", value)
418+
warnings.warn(f"cannot deal with {key} = {value}", stacklevel=2)

tests/io/exciting/test_inputs.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
import math
4+
import re
35
from xml.etree import ElementTree as ET
46

57
from numpy.testing import assert_allclose
@@ -165,4 +167,12 @@ def test_param_dict(self):
165167
root = tree.getroot()
166168
ref_str = ET.tostring(root, encoding="unicode")
167169

168-
assert ref_str.strip() == test_str.strip()
170+
ref_list = ref_str.strip().split()
171+
test_list = test_str.strip().split()
172+
173+
# "scale" is float, direct compare might give surprising results
174+
ref_scale = float(re.search(r'scale="([-+]?\d*\.\d+|\d+)"', ref_list.pop(7))[1])
175+
test_scale = float(re.search(r'scale="([-+]?\d*\.\d+|\d+)"', test_list.pop(7))[1])
176+
177+
assert ref_list == test_list
178+
assert math.isclose(ref_scale, test_scale)

0 commit comments

Comments
 (0)