Skip to content

Commit 29138fa

Browse files
committed
fix: pydantic 2.12 compatibility.
1 parent 3fecbed commit 29138fa

File tree

4 files changed

+105
-98
lines changed

4 files changed

+105
-98
lines changed

pydantic_xml/fields.py

Lines changed: 91 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import dataclasses as dc
22
import typing
3-
from typing import Any, Callable, Optional, Union
3+
from typing import Any, Callable, Dict, Optional, Union
44

55
import pydantic as pd
66
import pydantic_core as pdc
@@ -17,6 +17,7 @@
1717
'computed_element',
1818
'computed_entity',
1919
'element',
20+
'extract_field_xml_entity_info',
2021
'wrapped',
2122
'xml_field_serializer',
2223
'xml_field_validator',
@@ -37,83 +38,79 @@ class XmlEntityInfoP(typing.Protocol):
3738
wrapped: Optional['XmlEntityInfoP']
3839

3940

40-
class XmlEntityInfo(pd.fields.FieldInfo, XmlEntityInfoP):
41+
@dc.dataclass(frozen=True)
42+
class XmlEntityInfo(XmlEntityInfoP):
4143
"""
4244
Field xml meta-information.
4345
"""
4446

45-
__slots__ = ('location', 'path', 'ns', 'nsmap', 'nillable', 'wrapped')
47+
location: Optional[EntityLocation]
48+
path: Optional[str] = None
49+
ns: Optional[str] = None
50+
nsmap: Optional[NsMap] = None
51+
nillable: Optional[bool] = None
52+
wrapped: Optional[XmlEntityInfoP] = None
53+
54+
def __post_init__(self) -> None:
55+
if config.REGISTER_NS_PREFIXES and self.nsmap:
56+
utils.register_nsmap(self.nsmap)
4657

4758
@staticmethod
48-
def merge_field_infos(*field_infos: pd.fields.FieldInfo, **overrides: Any) -> pd.fields.FieldInfo:
49-
location, path, ns, nsmap, nillable, wrapped = None, None, None, None, None, None
50-
51-
for field_info in field_infos:
52-
if isinstance(field_info, XmlEntityInfo):
53-
location = field_info.location if field_info.location is not None else location
54-
path = field_info.path if field_info.path is not None else path
55-
ns = field_info.ns if field_info.ns is not None else ns
56-
nsmap = field_info.nsmap if field_info.nsmap is not None else nsmap
57-
nillable = field_info.nillable if field_info.nillable is not None else nillable
58-
wrapped = field_info.wrapped if field_info.wrapped is not None else wrapped
59-
60-
field_info = pd.fields.FieldInfo.merge_field_infos(*field_infos, **overrides)
61-
62-
xml_entity_info = XmlEntityInfo(
63-
location,
59+
def merge(*entity_infos: XmlEntityInfoP) -> 'XmlEntityInfo':
60+
location: Optional[EntityLocation] = None
61+
path: Optional[str] = None
62+
ns: Optional[str] = None
63+
nsmap: Optional[NsMap] = None
64+
nillable: Optional[bool] = None
65+
wrapped: Optional[XmlEntityInfoP] = None
66+
67+
for entity_info in entity_infos:
68+
if entity_info.location is not None:
69+
location = entity_info.location
70+
if entity_info.wrapped is not None:
71+
wrapped = entity_info.wrapped
72+
if entity_info.path is not None:
73+
path = entity_info.path
74+
if entity_info.ns is not None:
75+
ns = entity_info.ns
76+
if entity_info.nsmap is not None:
77+
nsmap = utils.merge_nsmaps(entity_info.nsmap, nsmap)
78+
if entity_info.nillable is not None:
79+
nillable = entity_info.nillable
80+
81+
return XmlEntityInfo(
82+
location=location,
6483
path=path,
6584
ns=ns,
6685
nsmap=nsmap,
6786
nillable=nillable,
68-
wrapped=wrapped if isinstance(wrapped, XmlEntityInfo) else None,
69-
**field_info._attributes_set,
87+
wrapped=wrapped,
7088
)
71-
xml_entity_info.metadata = field_info.metadata
72-
73-
return xml_entity_info
74-
75-
def __init__(
76-
self,
77-
location: Optional[EntityLocation],
78-
/,
79-
path: Optional[str] = None,
80-
ns: Optional[str] = None,
81-
nsmap: Optional[NsMap] = None,
82-
nillable: Optional[bool] = None,
83-
wrapped: Optional[pd.fields.FieldInfo] = None,
84-
**kwargs: Any,
85-
):
86-
wrapped_metadata: list[Any] = []
87-
if wrapped is not None:
88-
# copy arguments from the wrapped entity to let pydantic know how to process the field
89-
for entity_field_name in utils.get_slots(wrapped):
90-
if entity_field_name in pd.fields._FIELD_ARG_NAMES:
91-
kwargs[entity_field_name] = getattr(wrapped, entity_field_name)
92-
wrapped_metadata = wrapped.metadata
93-
94-
if kwargs.get('serialization_alias') is None:
95-
kwargs['serialization_alias'] = kwargs.get('alias')
96-
97-
if kwargs.get('validation_alias') is None:
98-
kwargs['validation_alias'] = kwargs.get('alias')
99-
100-
super().__init__(**kwargs)
101-
self.metadata.extend(wrapped_metadata)
102-
103-
self.location = location
104-
self.path = path
105-
self.ns = ns
106-
self.nsmap = nsmap
107-
self.nillable = nillable
108-
self.wrapped: Optional[XmlEntityInfoP] = wrapped if isinstance(wrapped, XmlEntityInfo) else None
109-
110-
if config.REGISTER_NS_PREFIXES and nsmap:
111-
utils.register_nsmap(nsmap)
89+
90+
91+
def extract_field_xml_entity_info(field_info: pd.fields.FieldInfo) -> Optional[XmlEntityInfoP]:
92+
entity_info_list = list(filter(lambda meta: isinstance(meta, XmlEntityInfo), field_info.metadata))
93+
if entity_info_list:
94+
entity_info = XmlEntityInfo.merge(*entity_info_list)
95+
else:
96+
entity_info = None
97+
98+
return entity_info
11299

113100

114101
_Unset: Any = pdc.PydanticUndefined
115102

116103

104+
def prepare_field_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]:
105+
if kwargs.get('serialization_alias') in (None, pdc.PydanticUndefined):
106+
kwargs['serialization_alias'] = kwargs.get('alias')
107+
108+
if kwargs.get('validation_alias') in (None, pdc.PydanticUndefined):
109+
kwargs['validation_alias'] = kwargs.get('alias')
110+
111+
return kwargs
112+
113+
117114
def attr(
118115
name: Optional[str] = None,
119116
ns: Optional[str] = None,
@@ -132,12 +129,15 @@ def attr(
132129
:param kwargs: pydantic field arguments. See :py:class:`pydantic.Field`
133130
"""
134131

135-
return XmlEntityInfo(
136-
EntityLocation.ATTRIBUTE,
137-
path=name, ns=ns, default=default, default_factory=default_factory,
138-
**kwargs,
132+
kwargs = prepare_field_kwargs(kwargs)
133+
134+
field_info = pd.fields.FieldInfo(default=default, default_factory=default_factory, **kwargs)
135+
field_info.metadata.append(
136+
XmlEntityInfo(EntityLocation.ATTRIBUTE, path=name, ns=ns),
139137
)
140138

139+
return field_info
140+
141141

142142
def element(
143143
tag: Optional[str] = None,
@@ -161,12 +161,15 @@ def element(
161161
:param kwargs: pydantic field arguments. See :py:class:`pydantic.Field`
162162
"""
163163

164-
return XmlEntityInfo(
165-
EntityLocation.ELEMENT,
166-
path=tag, ns=ns, nsmap=nsmap, nillable=nillable, default=default, default_factory=default_factory,
167-
**kwargs,
164+
kwargs = prepare_field_kwargs(kwargs)
165+
166+
field_info = pd.fields.FieldInfo(default=default, default_factory=default_factory, **kwargs)
167+
field_info.metadata.append(
168+
XmlEntityInfo(EntityLocation.ELEMENT, path=tag, ns=ns, nsmap=nsmap, nillable=nillable),
168169
)
169170

171+
return field_info
172+
170173

171174
def wrapped(
172175
path: str,
@@ -190,12 +193,24 @@ def wrapped(
190193
:param kwargs: pydantic field arguments. See :py:class:`pydantic.Field`
191194
"""
192195

193-
return XmlEntityInfo(
194-
EntityLocation.WRAPPED,
195-
path=path, ns=ns, nsmap=nsmap, wrapped=entity, default=default, default_factory=default_factory,
196-
**kwargs,
196+
if entity is None:
197+
wrapped_entity_info = None
198+
field_info = pd.fields.FieldInfo(default=default, default_factory=default_factory, **kwargs)
199+
else:
200+
wrapped_entity_info = extract_field_xml_entity_info(entity)
201+
field_info = pd.fields.FieldInfo._construct( # type: ignore[attr-defined]
202+
[
203+
pd.fields.FieldInfo(default=default, default_factory=default_factory, **kwargs),
204+
entity,
205+
],
206+
)
207+
208+
field_info.metadata.append(
209+
XmlEntityInfo(EntityLocation.WRAPPED, path=path, ns=ns, nsmap=nsmap, wrapped=wrapped_entity_info),
197210
)
198211

212+
return field_info
213+
199214

200215
@dc.dataclass
201216
class ComputedXmlEntityInfo(pd.fields.ComputedFieldInfo, XmlEntityInfoP):
@@ -293,7 +308,7 @@ def computed_element(
293308

294309

295310
def xml_field_validator(
296-
field: str, /, *fields: str
311+
field: str, /, *fields: str,
297312
) -> 'Callable[[model.ValidatorFuncT[model.ModelT]], model.ValidatorFuncT[model.ModelT]]':
298313
"""
299314
Marks the method as a field xml validator.
@@ -312,7 +327,7 @@ def wrapper(func: model.ValidatorFuncT[model.ModelT]) -> model.ValidatorFuncT[mo
312327

313328

314329
def xml_field_serializer(
315-
field: str, /, *fields: str
330+
field: str, /, *fields: str,
316331
) -> 'Callable[[model.SerializerFuncT[model.ModelT]], model.SerializerFuncT[model.ModelT]]':
317332
"""
318333
Marks the method as a field xml serializer.

pydantic_xml/serializers/factories/model.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import pydantic_xml as pxml
1010
from pydantic_xml import errors, utils
1111
from pydantic_xml.element import XmlElementReader, XmlElementWriter, is_element_nill, make_element_nill
12-
from pydantic_xml.fields import ComputedXmlEntityInfo, XmlEntityInfoP
12+
from pydantic_xml.fields import ComputedXmlEntityInfo, XmlEntityInfoP, extract_field_xml_entity_info
1313
from pydantic_xml.serializers.serializer import SearchMode, Serializer
1414
from pydantic_xml.typedefs import EntityLocation, Location, NsMap
1515
from pydantic_xml.utils import QName, merge_nsmaps, select_ns
@@ -79,15 +79,10 @@ def from_core_schema(cls, schema: pcs.ModelSchema, ctx: Serializer.Context) -> '
7979
fields_validation_aliases[field_name] = validation_alias
8080

8181
field_info = model_cls.model_fields[field_name]
82-
if isinstance(field_info, pxml.model.XmlEntityInfo):
83-
entity_info = field_info
84-
else:
85-
entity_info = None
86-
8782
field_ctx = ctx.child(
8883
field_name=field_name,
8984
field_alias=field_alias,
90-
entity_info=entity_info,
85+
entity_info=extract_field_xml_entity_info(field_info),
9186
)
9287
fields_serializers[field_name] = Serializer.parse_core_schema(model_field['schema'], field_ctx)
9388

@@ -234,16 +229,10 @@ def from_core_schema(cls, schema: pcs.ModelSchema, ctx: Serializer.Context) -> '
234229

235230
assert issubclass(model_cls, pxml.BaseXmlModel), "model class must be a BaseXmlModel subclass"
236231

237-
entity_info: Optional[XmlEntityInfoP]
238232
field_info = model_cls.model_fields['root']
239-
if isinstance(field_info, pxml.model.XmlEntityInfo):
240-
entity_info = field_info
241-
else:
242-
entity_info = None
243-
244233
field_ctx = ctx.child(
245234
field_name=None,
246-
entity_info=entity_info,
235+
entity_info=extract_field_xml_entity_info(field_info),
247236
)
248237
root_serializer = Serializer.parse_core_schema(root_schema, field_ctx)
249238

tests/test_encoder.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -282,9 +282,9 @@ def validate_model_before(cls, data: Dict[str, Any]) -> 'TestModel':
282282
}
283283

284284
@model_validator(mode='after')
285-
def validate_model_after(cls, obj: 'TestModel') -> 'TestModel':
286-
obj.field1 = obj.field1.replace(tzinfo=dt.timezone.utc)
287-
return obj
285+
def validate_model_after(self) -> 'TestModel':
286+
self.field1 = self.field1.replace(tzinfo=dt.timezone.utc)
287+
return self
288288

289289
@model_validator(mode='wrap')
290290
def validate_model_wrap(cls, obj: 'TestModel', handler: Callable) -> 'TestModel':

tests/test_misc.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from helpers import assert_xml_equal
77

88
from pydantic_xml import BaseXmlModel, RootXmlModel, attr, element, errors, wrapped
9-
from pydantic_xml.fields import XmlEntityInfo
9+
from pydantic_xml.fields import EntityLocation, XmlEntityInfo
1010

1111

1212
def test_xml_declaration():
@@ -396,11 +396,14 @@ class TestModel(BaseXmlModel, tag='root'):
396396
] = element(tag='elm', lt=10)
397397

398398
field_info = TestModel.model_fields['element1']
399-
assert isinstance(field_info, XmlEntityInfo)
400-
assert field_info.metadata == [Ge(ge=0), Lt(lt=10)]
399+
assert field_info.metadata == [
400+
Lt(lt=10),
401+
XmlEntityInfo(EntityLocation.ELEMENT, path='elm'),
402+
Ge(ge=0),
403+
Lt(lt=100),
404+
XmlEntityInfo(EntityLocation.ELEMENT, nillable=True),
405+
]
401406
assert field_info.default == 0
402-
assert field_info.nillable == True
403-
assert field_info.path == 'elm'
404407

405408
TestModel.from_xml("<root><elm>0</elm></root>")
406409

0 commit comments

Comments
 (0)