6
6
import os
7
7
import shutil
8
8
import sys
9
+ from contextlib import contextmanager
9
10
from importlib .util import module_from_spec , spec_from_file_location
10
- from pathlib import Path
11
11
from tempfile import mkdtemp
12
12
from types import ModuleType
13
13
from typing import Any , Dict , List , Tuple , Type , TypeVar
14
14
from uuid import uuid4
15
15
16
16
try :
17
- from pydantic import BaseModel as BaseModelV2
18
- from pydantic import create_model as create_model_v2
17
+ from pydantic import BaseModel as BaseModelV2 , create_model as create_model_v2
19
18
from pydantic .v1 import (
20
19
BaseModel as BaseModelV1 ,
21
- )
22
- from pydantic .v1 import (
23
20
create_model as create_model_v1 ,
24
21
)
25
22
26
23
BaseModelType = TypeVar ("BaseModelType" , Type [BaseModelV1 ], Type [BaseModelV2 ])
27
24
except ImportError :
25
+ BaseModelV2 = None
26
+ create_model_v2 = None
28
27
from pydantic import (
29
28
BaseModel as BaseModelV1 ,
30
- )
31
- from pydantic import (
32
29
create_model as create_model_v1 ,
33
30
)
34
31
35
- BaseModelV2 = None
36
- create_model_v2 = None
37
32
BaseModelType = TypeVar ("BaseModelType" , Type [BaseModelV1 ])
38
33
39
34
try :
@@ -56,7 +51,7 @@ def _import_module(path: str) -> ModuleType:
56
51
definition exist in sys.modules under that name.
57
52
"""
58
53
try :
59
- if Path ( path ) .exists ():
54
+ if os . path .exists (path ):
60
55
name = uuid4 ().hex
61
56
spec = spec_from_file_location (name , path , submodule_search_locations = [])
62
57
module = module_from_spec (spec )
@@ -102,7 +97,7 @@ def _is_pydantic_v2_model(obj: Any) -> bool:
102
97
)
103
98
104
99
105
- def _is_concrete_pydantic_model (obj : Any ) -> bool :
100
+ def _is_pydantic_model (obj : Any ) -> bool :
106
101
"""
107
102
Return true if an object is a concrete subclass of pydantic's BaseModel.
108
103
'concrete' meaning that it's not a generic model.
@@ -117,7 +112,7 @@ def _extract_pydantic_models(module: ModuleType) -> List[BaseModelType]:
117
112
models = []
118
113
module_name = module .__name__
119
114
120
- for _ , model in inspect .getmembers (module , _is_concrete_pydantic_model ):
115
+ for _ , model in inspect .getmembers (module , _is_pydantic_model ):
121
116
models .append (model )
122
117
123
118
for _ , submodule in inspect .getmembers (
@@ -179,79 +174,70 @@ def _clean_schema(schema: Dict[str, Any]) -> None:
179
174
if "enum" in schema and schema .get ("description" ) == "An enumeration." :
180
175
del schema ["description" ]
181
176
177
+ # TODO: add check for if it is truly pydantic v1. If so, fix nullable fields. Do the thing to add "null" to union.
178
+ # https://github.com/pydantic/pydantic/issues/1270#issuecomment-729555558
182
179
183
- def _generate_json_schema_v1 (models : List [Type [BaseModelV1 ]]) -> str :
180
+
181
+ def _generate_json_schema (models : List [BaseModelType ]) -> str :
184
182
"""
185
183
Create a top-level '_Master_' model with references to each of the actual models.
186
184
Generate the schema for this model, which will include the schemas for all the
187
185
nested models. Then clean up the schema.
188
-
189
- One weird thing we do is we temporarily override the 'extra' setting in models,
190
- changing it to 'forbid' UNLESS it was explicitly set to 'allow'. This prevents
191
- '[k: string]: any' from being added to every interface. This change is reverted
192
- once the schema has been generated.
193
186
"""
194
- model_extras = [getattr (m .Config , "extra" , None ) for m in models ]
187
+ with _forbid_extras (models ):
188
+ v1 = any (issubclass (m , BaseModelV1 ) for m in models )
195
189
196
- try :
197
- for m in models :
198
- if getattr (m .Config , "extra" , None ) != "allow" :
199
- m .Config .extra = "forbid"
200
-
201
- master_model = create_model_v1 (
190
+ master_model = (create_model_v1 if v1 else create_model_v2 )(
202
191
"_Master_" , ** {m .__name__ : (m , ...) for m in models }
203
192
)
204
- master_model .Config .extra = "forbid"
205
- master_model .Config .schema_extra = staticmethod (_clean_schema )
206
193
207
- schema = json .loads (master_model .schema_json ())
194
+ if v1 :
195
+ master_model .Config .extra = "forbid"
196
+ master_model .Config .schema_extra = staticmethod (_clean_schema )
197
+ else :
198
+ master_model .model_config ["extra" ] = "forbid"
199
+ master_model .model_config ["json_schema_extra" ] = staticmethod (_clean_schema )
208
200
209
- for d in schema .get ("definitions" , {}).values ():
201
+ schema = (
202
+ json .loads (master_model .schema_json ())
203
+ if v1
204
+ else master_model .model_json_schema (mode = "serialization" )
205
+ )
206
+
207
+ for d in schema .get ("definitions" if v1 else "$defs" , {}).values ():
210
208
_clean_schema (d )
211
209
212
210
return json .dumps (schema , indent = 2 )
213
211
214
- finally :
215
- for m , x in zip (models , model_extras ):
216
- if x is not None :
217
- m .Config .extra = x
218
212
219
-
220
- def _generate_json_schema_v2 (models : List [Type [ BaseModelV2 ]] ) -> str :
213
+ @ contextmanager
214
+ def _forbid_extras (models : List [BaseModelType ] ) -> None :
221
215
"""
222
- Create a top-level '_Master_' model with references to each of the actual models.
223
- Generate the schema for this model, which will include the schemas for all the
224
- nested models. Then clean up the schema.
216
+ Temporarily override the 'extra' setting in models,
217
+ changing it to 'forbid' UNLESS it was explicitly set to 'allow'.
225
218
226
- One weird thing we do is we temporarily override the 'extra' setting in models,
227
- changing it to 'forbid' UNLESS it was explicitly set to 'allow'. This prevents
228
- '[k: string]: any' from being added to every interface. This change is reverted
229
- once the schema has been generated.
219
+ This prevents '[k: string]: any' from being added to every interface.
220
+ This change is reverted once the schema has been generated.
230
221
"""
231
- model_extras = [m .model_config .get ("extra" ) for m in models ]
232
-
222
+ v1 = any (issubclass (m , BaseModelV1 ) for m in models )
223
+ extras = [
224
+ getattr (m .Config , "extra" , None ) if v1 else m .model_config .get ("extra" )
225
+ for m in models
226
+ ]
233
227
try :
234
228
for m in models :
235
- if m .model_config .get ("extra" ) != "allow" :
229
+ if v1 :
230
+ m .Config .extra = "forbid"
231
+ else :
236
232
m .model_config ["extra" ] = "forbid"
237
-
238
- master_model = create_model_v2 (
239
- "_Master_" , ** {m .__name__ : (m , ...) for m in models }
240
- )
241
- master_model .model_config ["extra" ] = "forbid"
242
- master_model .model_config ["json_schema_extra" ] = staticmethod (_clean_schema )
243
-
244
- schema : dict = master_model .model_json_schema (mode = "serialization" )
245
-
246
- for d in schema .get ("$defs" , {}).values ():
247
- _clean_schema (d )
248
-
249
- return json .dumps (schema , indent = 2 )
250
-
233
+ yield
251
234
finally :
252
- for m , x in zip (models , model_extras ):
235
+ for m , x in zip (models , extras ):
253
236
if x is not None :
254
- m .model_config ["extra" ] = x
237
+ if v1 :
238
+ m .Config .extra = x
239
+ else :
240
+ m .model_config ["extra" ] = x
255
241
256
242
257
243
def generate_typescript_defs (
@@ -277,20 +263,19 @@ def generate_typescript_defs(
277
263
models = _extract_pydantic_models (_import_module (module ))
278
264
279
265
if exclude :
280
- models = [m for m in models if m .__name__ not in exclude ]
266
+ models = [
267
+ m
268
+ for m in models
269
+ if (m .__name__ not in exclude and m .__qualname__ not in exclude )
270
+ ]
281
271
282
272
if not models :
283
273
logger .info ("No pydantic models found, exiting." )
284
274
return
285
275
286
276
logger .info ("Generating JSON schema from pydantic models..." )
287
277
288
- schema = (
289
- _generate_json_schema_v1 (models )
290
- if any (issubclass (m , BaseModelV1 ) for m in models )
291
- else _generate_json_schema_v2 (models )
292
- )
293
-
278
+ schema = _generate_json_schema (models )
294
279
schema_dir = mkdtemp ()
295
280
schema_file_path = os .path .join (schema_dir , "schema.json" )
296
281
0 commit comments