Skip to content

Commit a1b926f

Browse files
authored
Merge pull request #240 from dapper91/raw-deserialization
raw element deserialization bug fixed
2 parents bb1b697 + 0df84b8 commit a1b926f

File tree

5 files changed

+123
-15
lines changed

5 files changed

+123
-15
lines changed

pydantic_xml/element/element.py

Lines changed: 64 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import abc
22
from enum import Enum
3-
from typing import Any, Callable, Dict, Generic, List, Optional, Sequence, Tuple, TypeVar
3+
from typing import Any, Callable, Dict, Generic, Iterable, List, Optional, Sequence, Tuple, TypeVar
44

55
from pydantic_xml.typedefs import NsMap
66

@@ -21,6 +21,13 @@ def tag(self) -> str:
2121
Xml element tag.
2222
"""
2323

24+
@property
25+
@abc.abstractmethod
26+
def nsmap(self) -> Optional[NsMap]:
27+
"""
28+
Xml element namespace map.
29+
"""
30+
2431
@abc.abstractmethod
2532
def is_empty(self) -> bool:
2633
"""
@@ -64,6 +71,15 @@ def pop_text(self) -> Optional[str]:
6471
:return: element text
6572
"""
6673

74+
@abc.abstractmethod
75+
def pop_tail(self) -> Optional[str]:
76+
"""
77+
Extracts the tail from the xml element.
78+
All subsequent calls return `None`.
79+
80+
:return: element tail
81+
"""
82+
6783
@abc.abstractmethod
6884
def pop_attrib(self, name: str) -> Optional[str]:
6985
"""
@@ -83,12 +99,22 @@ def pop_attributes(self) -> Optional[Dict[str, str]]:
8399
"""
84100

85101
@abc.abstractmethod
86-
def pop_element(self, tag: str, search_mode: 'SearchMode') -> Optional['XmlElementReader']:
102+
def pop_elements(self) -> Tuple['XmlElementReader', ...]:
103+
"""
104+
Extracts all sub-elements from the xml element.
105+
All subsequent calls return empty list.
106+
107+
:return: element sub-elements
108+
"""
109+
110+
@abc.abstractmethod
111+
def pop_element(self, tag: str, search_mode: 'SearchMode', remove: bool = False) -> Optional['XmlElementReader']:
87112
"""
88113
Extracts a sub-element from the xml element matching `tag`.
89114
90115
:param tag: element tag
91116
:param search_mode: element search mode
117+
:param remove: remove all entities from the element
92118
:return: sub-element
93119
"""
94120

@@ -280,7 +306,7 @@ def __init__(
280306
text: Optional[str] = None,
281307
tail: Optional[str] = None,
282308
attributes: Optional[Dict[str, str]] = None,
283-
elements: Optional[List['XmlElement[NativeElement]']] = None,
309+
elements: Optional[Iterable['XmlElement[NativeElement]']] = None,
284310
nsmap: Optional[NsMap] = None,
285311
sourceline: int = -1,
286312
):
@@ -290,7 +316,7 @@ def __init__(
290316
text=text,
291317
tail=tail,
292318
attrib=dict(attributes) if attributes is not None else None,
293-
elements=elements or [],
319+
elements=list(elements) if elements is not None else [],
294320
next_element_idx=0,
295321
)
296322
self._sourceline = sourceline
@@ -303,6 +329,10 @@ def get_sourceline(self) -> int:
303329
def tag(self) -> str:
304330
return self._tag
305331

332+
@property
333+
def nsmap(self) -> Optional[NsMap]:
334+
return self._nsmap
335+
306336
def create_snapshot(self) -> 'XmlElement[NativeElement]':
307337
element = self.__class__(
308338
tag=self._tag,
@@ -359,6 +389,11 @@ def pop_text(self) -> Optional[str]:
359389

360390
return result
361391

392+
def pop_tail(self) -> Optional[str]:
393+
result, self._state.tail = self._state.tail, None
394+
395+
return result
396+
362397
def pop_attrib(self, name: str) -> Optional[str]:
363398
return self._state.attrib.pop(name, None) if self._state.attrib else None
364399

@@ -367,10 +402,33 @@ def pop_attributes(self) -> Optional[Dict[str, str]]:
367402

368403
return result
369404

370-
def pop_element(self, tag: str, search_mode: 'SearchMode') -> Optional['XmlElement[NativeElement]']:
405+
def pop_elements(self) -> Tuple['XmlElement[NativeElement]', ...]:
406+
elements, self._state.elements = self._state.elements, []
407+
self._state.next_element_idx = 0
408+
409+
return tuple(elements)
410+
411+
def pop_element(
412+
self,
413+
tag: str,
414+
search_mode: 'SearchMode',
415+
remove: bool = False,
416+
) -> Optional['XmlElement[NativeElement]']:
371417
searcher: Searcher[NativeElement] = get_searcher(search_mode)
372418

373-
return searcher(self._state, tag, False, True)
419+
element = searcher(self._state, tag, False, True)
420+
if element is not None and remove:
421+
return self.__class__(
422+
tag=element.tag,
423+
nsmap=element.nsmap,
424+
text=element.pop_text(),
425+
tail=element.pop_tail(),
426+
attributes=element.pop_attributes(),
427+
elements=element.pop_elements(),
428+
sourceline=element.get_sourceline(),
429+
)
430+
431+
return element
374432

375433
def find_sub_element(self, path: Sequence[str], search_mode: 'SearchMode') -> PathT['XmlElement[NativeElement]']:
376434
assert len(path) > 0, "path can't be empty"

pydantic_xml/mypy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@ def get_metaclass_hook(self, fullname: str) -> Optional[Callable[[ClassDefContex
2121
return self._pydantic_model_metaclass_marker_callback
2222
return super().get_metaclass_hook(fullname)
2323

24-
def _pydantic_model_class_maker_callback(self, ctx: ClassDefContext) -> bool:
24+
def _pydantic_model_class_maker_callback(self, ctx: ClassDefContext) -> None:
2525
transformer = PydanticXmlModelTransformer(ctx.cls, ctx.reason, ctx.api, self.plugin_config)
26-
return transformer.transform()
26+
transformer.transform()
2727

2828

2929
class PydanticXmlModelTransformer(PydanticModelTransformer):

pydantic_xml/serializers/factories/raw.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def deserialize(
6363
if element is None:
6464
return None
6565

66-
if (sub_element := element.pop_element(self._element_name, self._search_mode)) is not None:
66+
if (sub_element := element.pop_element(self._element_name, self._search_mode, remove=True)) is not None:
6767
sourcemap[loc] = sub_element.get_sourceline()
6868
return sub_element.to_native()
6969
else:

tests/test_extra.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
from typing import Dict
1+
from typing import Dict, Optional
22

33
import pydantic as pd
44
import pytest
55

66
from pydantic_xml import BaseXmlModel, attr, element, wrapped
7+
from pydantic_xml.element.native import ElementT
78
from tests.helpers import fmt_sourceline
89

910

@@ -230,3 +231,52 @@ class TestModel(BaseXmlModel, tag='model', extra='forbid', search_mode=search_mo
230231
},
231232
},
232233
]
234+
235+
236+
@pytest.mark.parametrize('search_mode', ['strict', 'ordered', 'unordered'])
237+
def test_raw_extra_forbid(search_mode: str):
238+
class TestModel(
239+
BaseXmlModel,
240+
tag='model',
241+
extra='forbid',
242+
arbitrary_types_allowed=True,
243+
search_mode=search_mode,
244+
):
245+
field1: ElementT = element("field1")
246+
field2: Optional[ElementT] = element("field2", default=None)
247+
248+
xml = '''
249+
<model>
250+
<field1>field value 1<nested>nested element field</nested></field1>
251+
<field2>field value 2</field2>
252+
<extra>undefined field<nested>nested undefined field</nested></extra>
253+
</model>
254+
'''
255+
with pytest.raises(pd.ValidationError) as exc:
256+
TestModel.from_xml(xml)
257+
258+
err = exc.value
259+
assert err.title == 'TestModel'
260+
assert err.error_count() == 2
261+
assert err.errors() == [
262+
{
263+
'input': 'undefined field',
264+
'loc': ('extra',),
265+
'msg': f'[line {fmt_sourceline(5)}]: Extra inputs are not permitted',
266+
'type': 'extra_forbidden',
267+
'ctx': {
268+
'orig': 'Extra inputs are not permitted',
269+
'sourceline': fmt_sourceline(5),
270+
},
271+
},
272+
{
273+
'input': 'nested undefined field',
274+
'loc': ('extra', 'nested'),
275+
'msg': f'[line {fmt_sourceline(5)}]: Extra inputs are not permitted',
276+
'type': 'extra_forbidden',
277+
'ctx': {
278+
'orig': 'Extra inputs are not permitted',
279+
'sourceline': fmt_sourceline(5),
280+
},
281+
},
282+
]

tests/test_raw.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88

99
def test_raw_primitive_element_serialization():
10-
class TestModel(BaseXmlModel, tag='model', arbitrary_types_allowed=True):
10+
class TestModel(BaseXmlModel, tag='model', arbitrary_types_allowed=True, extra='forbid'):
1111
element1: ElementT = element()
1212
element2: ElementT = element()
1313

@@ -43,7 +43,7 @@ class TestModel(BaseXmlModel, tag='model', arbitrary_types_allowed=True):
4343

4444

4545
def test_optional_raw_primitive_element_serialization():
46-
class TestModel(BaseXmlModel, tag='model', arbitrary_types_allowed=True):
46+
class TestModel(BaseXmlModel, tag='model', arbitrary_types_allowed=True, extra='forbid'):
4747
element1: Optional[ElementT] = element(default=None)
4848
element2: ElementT = element()
4949

@@ -66,7 +66,7 @@ class TestModel(BaseXmlModel, tag='model', arbitrary_types_allowed=True):
6666

6767

6868
def test_raw_element_homogeneous_collection_serialization():
69-
class TestModel(BaseXmlModel, tag='model', arbitrary_types_allowed=True):
69+
class TestModel(BaseXmlModel, tag='model', arbitrary_types_allowed=True, extra='forbid'):
7070
field1: List[ElementT] = element(tag="element1")
7171

7272
xml = '''
@@ -97,7 +97,7 @@ class TestModel(BaseXmlModel, tag='model', arbitrary_types_allowed=True):
9797

9898

9999
def test_raw_element_heterogeneous_collection_serialization():
100-
class TestModel(BaseXmlModel, tag='model', arbitrary_types_allowed=True):
100+
class TestModel(BaseXmlModel, tag='model', arbitrary_types_allowed=True, extra='forbid'):
101101
field1: Tuple[ElementT, ElementT] = element(tag="element1")
102102

103103
xml = '''
@@ -128,7 +128,7 @@ class TestModel(BaseXmlModel, tag='model', arbitrary_types_allowed=True):
128128

129129

130130
def test_wrapped_raw_element_serialization():
131-
class TestModel(BaseXmlModel, tag='model', arbitrary_types_allowed=True):
131+
class TestModel(BaseXmlModel, tag='model', arbitrary_types_allowed=True, extra='forbid'):
132132
field1: ElementT = wrapped('wrapper', element(tag="element1"))
133133

134134
xml = '''

0 commit comments

Comments
 (0)