Skip to content

Commit 7ed40e3

Browse files
committed
clean up handling of v1/v2 branching
1 parent 4ede224 commit 7ed40e3

File tree

1 file changed

+53
-68
lines changed

1 file changed

+53
-68
lines changed

pydantic2ts/cli/script.py

Lines changed: 53 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -6,34 +6,29 @@
66
import os
77
import shutil
88
import sys
9+
from contextlib import contextmanager
910
from importlib.util import module_from_spec, spec_from_file_location
10-
from pathlib import Path
1111
from tempfile import mkdtemp
1212
from types import ModuleType
1313
from typing import Any, Dict, List, Tuple, Type, TypeVar
1414
from uuid import uuid4
1515

1616
try:
17-
from pydantic import BaseModel as BaseModelV2
18-
from pydantic import create_model as create_model_v2
17+
from pydantic import BaseModel as BaseModelV2, create_model as create_model_v2
1918
from pydantic.v1 import (
2019
BaseModel as BaseModelV1,
21-
)
22-
from pydantic.v1 import (
2320
create_model as create_model_v1,
2421
)
2522

2623
BaseModelType = TypeVar("BaseModelType", Type[BaseModelV1], Type[BaseModelV2])
2724
except ImportError:
25+
BaseModelV2 = None
26+
create_model_v2 = None
2827
from pydantic import (
2928
BaseModel as BaseModelV1,
30-
)
31-
from pydantic import (
3229
create_model as create_model_v1,
3330
)
3431

35-
BaseModelV2 = None
36-
create_model_v2 = None
3732
BaseModelType = TypeVar("BaseModelType", Type[BaseModelV1])
3833

3934
try:
@@ -56,7 +51,7 @@ def _import_module(path: str) -> ModuleType:
5651
definition exist in sys.modules under that name.
5752
"""
5853
try:
59-
if Path(path).exists():
54+
if os.path.exists(path):
6055
name = uuid4().hex
6156
spec = spec_from_file_location(name, path, submodule_search_locations=[])
6257
module = module_from_spec(spec)
@@ -102,7 +97,7 @@ def _is_pydantic_v2_model(obj: Any) -> bool:
10297
)
10398

10499

105-
def _is_concrete_pydantic_model(obj: Any) -> bool:
100+
def _is_pydantic_model(obj: Any) -> bool:
106101
"""
107102
Return true if an object is a concrete subclass of pydantic's BaseModel.
108103
'concrete' meaning that it's not a generic model.
@@ -117,7 +112,7 @@ def _extract_pydantic_models(module: ModuleType) -> List[BaseModelType]:
117112
models = []
118113
module_name = module.__name__
119114

120-
for _, model in inspect.getmembers(module, _is_concrete_pydantic_model):
115+
for _, model in inspect.getmembers(module, _is_pydantic_model):
121116
models.append(model)
122117

123118
for _, submodule in inspect.getmembers(
@@ -179,79 +174,70 @@ def _clean_schema(schema: Dict[str, Any]) -> None:
179174
if "enum" in schema and schema.get("description") == "An enumeration.":
180175
del schema["description"]
181176

177+
# TODO: add check for if it is truly pydantic v1. If so, fix nullable fields. Do the thing to add "null" to union.
178+
# https://github.com/pydantic/pydantic/issues/1270#issuecomment-729555558
182179

183-
def _generate_json_schema_v1(models: List[Type[BaseModelV1]]) -> str:
180+
181+
def _generate_json_schema(models: List[BaseModelType]) -> str:
184182
"""
185183
Create a top-level '_Master_' model with references to each of the actual models.
186184
Generate the schema for this model, which will include the schemas for all the
187185
nested models. Then clean up the schema.
188-
189-
One weird thing we do is we temporarily override the 'extra' setting in models,
190-
changing it to 'forbid' UNLESS it was explicitly set to 'allow'. This prevents
191-
'[k: string]: any' from being added to every interface. This change is reverted
192-
once the schema has been generated.
193186
"""
194-
model_extras = [getattr(m.Config, "extra", None) for m in models]
187+
with _forbid_extras(models):
188+
v1 = any(issubclass(m, BaseModelV1) for m in models)
195189

196-
try:
197-
for m in models:
198-
if getattr(m.Config, "extra", None) != "allow":
199-
m.Config.extra = "forbid"
200-
201-
master_model = create_model_v1(
190+
master_model = (create_model_v1 if v1 else create_model_v2)(
202191
"_Master_", **{m.__name__: (m, ...) for m in models}
203192
)
204-
master_model.Config.extra = "forbid"
205-
master_model.Config.schema_extra = staticmethod(_clean_schema)
206193

207-
schema = json.loads(master_model.schema_json())
194+
if v1:
195+
master_model.Config.extra = "forbid"
196+
master_model.Config.schema_extra = staticmethod(_clean_schema)
197+
else:
198+
master_model.model_config["extra"] = "forbid"
199+
master_model.model_config["json_schema_extra"] = staticmethod(_clean_schema)
208200

209-
for d in schema.get("definitions", {}).values():
201+
schema = (
202+
json.loads(master_model.schema_json())
203+
if v1
204+
else master_model.model_json_schema(mode="serialization")
205+
)
206+
207+
for d in schema.get("definitions" if v1 else "$defs", {}).values():
210208
_clean_schema(d)
211209

212210
return json.dumps(schema, indent=2)
213211

214-
finally:
215-
for m, x in zip(models, model_extras):
216-
if x is not None:
217-
m.Config.extra = x
218212

219-
220-
def _generate_json_schema_v2(models: List[Type[BaseModelV2]]) -> str:
213+
@contextmanager
214+
def _forbid_extras(models: List[BaseModelType]) -> None:
221215
"""
222-
Create a top-level '_Master_' model with references to each of the actual models.
223-
Generate the schema for this model, which will include the schemas for all the
224-
nested models. Then clean up the schema.
216+
Temporarily override the 'extra' setting in models,
217+
changing it to 'forbid' UNLESS it was explicitly set to 'allow'.
225218
226-
One weird thing we do is we temporarily override the 'extra' setting in models,
227-
changing it to 'forbid' UNLESS it was explicitly set to 'allow'. This prevents
228-
'[k: string]: any' from being added to every interface. This change is reverted
229-
once the schema has been generated.
219+
This prevents '[k: string]: any' from being added to every interface.
220+
This change is reverted once the schema has been generated.
230221
"""
231-
model_extras = [m.model_config.get("extra") for m in models]
232-
222+
v1 = any(issubclass(m, BaseModelV1) for m in models)
223+
extras = [
224+
getattr(m.Config, "extra", None) if v1 else m.model_config.get("extra")
225+
for m in models
226+
]
233227
try:
234228
for m in models:
235-
if m.model_config.get("extra") != "allow":
229+
if v1:
230+
m.Config.extra = "forbid"
231+
else:
236232
m.model_config["extra"] = "forbid"
237-
238-
master_model = create_model_v2(
239-
"_Master_", **{m.__name__: (m, ...) for m in models}
240-
)
241-
master_model.model_config["extra"] = "forbid"
242-
master_model.model_config["json_schema_extra"] = staticmethod(_clean_schema)
243-
244-
schema: dict = master_model.model_json_schema(mode="serialization")
245-
246-
for d in schema.get("$defs", {}).values():
247-
_clean_schema(d)
248-
249-
return json.dumps(schema, indent=2)
250-
233+
yield
251234
finally:
252-
for m, x in zip(models, model_extras):
235+
for m, x in zip(models, extras):
253236
if x is not None:
254-
m.model_config["extra"] = x
237+
if v1:
238+
m.Config.extra = x
239+
else:
240+
m.model_config["extra"] = x
255241

256242

257243
def generate_typescript_defs(
@@ -277,20 +263,19 @@ def generate_typescript_defs(
277263
models = _extract_pydantic_models(_import_module(module))
278264

279265
if exclude:
280-
models = [m for m in models if m.__name__ not in exclude]
266+
models = [
267+
m
268+
for m in models
269+
if (m.__name__ not in exclude and m.__qualname__ not in exclude)
270+
]
281271

282272
if not models:
283273
logger.info("No pydantic models found, exiting.")
284274
return
285275

286276
logger.info("Generating JSON schema from pydantic models...")
287277

288-
schema = (
289-
_generate_json_schema_v1(models)
290-
if any(issubclass(m, BaseModelV1) for m in models)
291-
else _generate_json_schema_v2(models)
292-
)
293-
278+
schema = _generate_json_schema(models)
294279
schema_dir = mkdtemp()
295280
schema_file_path = os.path.join(schema_dir, "schema.json")
296281

0 commit comments

Comments
 (0)