6
6
import os
7
7
import shutil
8
8
import sys
9
+ from enum import Enum
9
10
from importlib .util import module_from_spec , spec_from_file_location
10
11
from pathlib import Path
11
12
from tempfile import mkdtemp
12
13
from types import ModuleType
13
- from typing import Any , Dict , List , Tuple , Type
14
+ from typing import Any , Dict , List , Tuple , Type , get_origin , get_args , Set , cast
14
15
from uuid import uuid4
15
16
16
17
from pydantic import VERSION , BaseModel , create_model
17
18
19
+ try :
20
+ from types import UnionType
21
+ except ImportError :
22
+ UnionType = None
23
+
18
24
V2 = True if VERSION .startswith ("2" ) else False
19
25
20
26
if not V2 :
23
29
except ImportError :
24
30
GenericModel = None
25
31
26
- logger = logging .getLogger ("pydantic2ts" )
27
-
32
+ if V2 :
33
+ try :
34
+ from pydantic .json_schema import GenerateJsonSchema , JsonSchemaValue
35
+ from pydantic_core import core_schema
36
+ except ImportError :
37
+ GenerateJsonSchema = None
38
+ JsonSchemaValue = None
39
+ core_schema = None
28
40
29
- DEBUG = os . environ . get ( "DEBUG" , False )
41
+ logger = logging . getLogger ( "pydantic2ts" )
30
42
31
43
32
44
def import_module (path : str ) -> ModuleType :
@@ -63,7 +75,7 @@ def is_submodule(obj, module_name: str) -> bool:
63
75
)
64
76
65
77
66
- def is_concrete_pydantic_model (obj ) -> bool :
78
+ def is_concrete_pydantic_model (obj : type ) -> bool :
67
79
"""
68
80
Return true if an object is a concrete subclass of pydantic's BaseModel.
69
81
'concrete' meaning that it's not a GenericModel.
@@ -81,24 +93,80 @@ def is_concrete_pydantic_model(obj) -> bool:
81
93
return issubclass (obj , BaseModel )
82
94
83
95
96
+ def is_enum (obj ) -> bool :
97
+ """
98
+ Return true if an object is an Enum.
99
+ """
100
+ return inspect .isclass (obj ) and issubclass (obj , Enum )
101
+
102
+
103
+ def flatten_types (field_type : type ) -> Set [type ]:
104
+ types = set ()
105
+
106
+ origin = get_origin (field_type )
107
+ if origin is None :
108
+ types .add (field_type )
109
+ else :
110
+ args = get_args (field_type )
111
+ for arg in args :
112
+ types .update (flatten_types (arg ))
113
+
114
+ return types
115
+
116
+
117
+ def get_model_fields (model : Type [BaseModel ]) -> Dict [str , Any ]:
118
+ if V2 :
119
+ return model .model_fields
120
+ else :
121
+ return model .__fields__
122
+
123
+
124
+ def extract_pydantic_models_from_model (
125
+ model : Type [BaseModel ], all_models : List [Type [BaseModel ]]
126
+ ) -> None :
127
+ """
128
+ Given a pydantic model, add the pydantic models contained within it to all_models.
129
+ """
130
+ if model in all_models :
131
+ return
132
+
133
+ all_models .append (model )
134
+
135
+ for field , field_type in get_model_fields (model ).items ():
136
+ flattened_types = flatten_types (field_type .annotation )
137
+ for inner_type in flattened_types :
138
+ if is_concrete_pydantic_model (inner_type ):
139
+ extract_pydantic_models_from_model (inner_type , all_models )
140
+
141
+
84
142
def extract_pydantic_models (module : ModuleType ) -> List [Type [BaseModel ]]:
85
143
"""
86
144
Given a module, return a list of the pydantic models contained within it.
87
145
"""
88
146
models = []
89
- module_name = module .__name__
90
147
91
148
for _ , model in inspect .getmembers (module , is_concrete_pydantic_model ):
92
- models .append (model )
93
-
94
- for _ , submodule in inspect .getmembers (
95
- module , lambda obj : is_submodule (obj , module_name )
96
- ):
97
- models .extend (extract_pydantic_models (submodule ))
149
+ extract_pydantic_models_from_model (model , models )
98
150
99
151
return models
100
152
101
153
154
+ def extract_enum_models (models : List [Type [BaseModel ]]) -> List [Type [Enum ]]:
155
+ """
156
+ Given a list of pydantic models, return a list of the Enum classes used as fields within those models.
157
+ """
158
+ enums = []
159
+
160
+ for model in models :
161
+ for field_type in get_model_fields (model ).values ():
162
+ flattened_types = flatten_types (field_type .annotation )
163
+ for inner_type in flattened_types :
164
+ if is_enum (inner_type ):
165
+ enums .append (cast (Type [Enum ], inner_type ))
166
+
167
+ return enums
168
+
169
+
102
170
def clean_output_file (output_filename : str ) -> None :
103
171
"""
104
172
Clean up the output file typescript definitions were written to by:
@@ -151,7 +219,34 @@ def clean_schema(schema: Dict[str, Any]) -> None:
151
219
del schema ["description" ]
152
220
153
221
154
- def generate_json_schema_v1 (models : List [Type [BaseModel ]]) -> str :
222
+ def add_enum_names_v1 (model : Type [Enum ]) -> None :
223
+ @classmethod
224
+ def __modify_schema__ (cls , field_schema : Dict [str , Any ]):
225
+ if len (model .__members__ .keys ()) == len (field_schema ["enum" ]):
226
+ field_schema .update (tsEnumNames = list (model .__members__ .keys ()))
227
+ for name , value in zip (field_schema ["tsEnumNames" ], field_schema ["enum" ]):
228
+ assert cls [name ].value == value
229
+
230
+ setattr (model , "__modify_schema__" , __modify_schema__ )
231
+
232
+
233
+ if V2 :
234
+
235
+ class CustomGenerateJsonSchema (GenerateJsonSchema ):
236
+ def enum_schema (self , schema : core_schema .EnumSchema ) -> JsonSchemaValue :
237
+ # Call the original method
238
+ result = super ().enum_schema (schema )
239
+
240
+ # Add tsEnumNames property
241
+ if len (schema ["members" ]) > 0 :
242
+ result ["tsEnumNames" ] = [v .name for v in schema ["members" ]]
243
+
244
+ return result
245
+
246
+
247
+ def generate_json_schema_v1 (
248
+ models : List [Type [BaseModel ]], enums : List [Type [Enum ]]
249
+ ) -> str :
155
250
"""
156
251
Create a top-level '_Master_' model with references to each of the actual models.
157
252
Generate the schema for this model, which will include the schemas for all the
@@ -162,18 +257,21 @@ def generate_json_schema_v1(models: List[Type[BaseModel]]) -> str:
162
257
'[k: string]: any' from being added to every interface. This change is reverted
163
258
once the schema has been generated.
164
259
"""
165
- model_extras = [m . model_config . get ( "extra" , None ) for m in models ]
260
+ model_extras = [getattr ( m . Config , "extra" , None ) for m in models ]
166
261
167
262
try :
168
263
for m in models :
169
- if m .model_config .get ("extra" , None ) != "allow" :
170
- m .model_config ["extra" ] = "forbid"
264
+ if getattr (m .Config , "extra" , None ) != "allow" :
265
+ m .Config .extra = "forbid"
266
+
267
+ for e in enums :
268
+ add_enum_names_v1 (e )
171
269
172
270
master_model = create_model (
173
- "_Master_" , ** {m .__name__ : (m , ...) for m in models }, __base__ = m
271
+ "_Master_" , ** {m .__name__ : (m , ...) for m in models }
174
272
)
175
- master_model .model_config [ " extra" ] = "forbid"
176
- master_model .model_config [ " schema_extra" ] = staticmethod (clean_schema )
273
+ master_model .Config . extra = "forbid"
274
+ master_model .Config . schema_extra = staticmethod (clean_schema )
177
275
178
276
schema = json .loads (master_model .schema_json ())
179
277
@@ -185,7 +283,7 @@ def generate_json_schema_v1(models: List[Type[BaseModel]]) -> str:
185
283
finally :
186
284
for m , x in zip (models , model_extras ):
187
285
if x is not None :
188
- m .model_config [ " extra" ] = x
286
+ m .Config . extra = x
189
287
190
288
191
289
def generate_json_schema_v2 (models : List [Type [BaseModel ]]) -> str :
@@ -212,7 +310,9 @@ def generate_json_schema_v2(models: List[Type[BaseModel]]) -> str:
212
310
master_model .model_config ["extra" ] = "forbid"
213
311
master_model .model_config ["json_schema_extra" ] = staticmethod (clean_schema )
214
312
215
- schema : dict = master_model .model_json_schema (mode = "serialization" )
313
+ schema : dict = master_model .model_json_schema (
314
+ schema_generator = CustomGenerateJsonSchema , mode = "serialization"
315
+ )
216
316
217
317
for d in schema .get ("$defs" , {}).values ():
218
318
clean_schema (d )
@@ -252,14 +352,20 @@ def generate_typescript_defs(
252
352
253
353
logger .info ("Generating JSON schema from pydantic models..." )
254
354
255
- schema = generate_json_schema_v2 (models ) if V2 else generate_json_schema_v1 (models )
355
+ if V2 :
356
+ schema = generate_json_schema_v2 (models )
357
+ else :
358
+ enums = extract_enum_models (models )
359
+ schema = generate_json_schema_v1 (models , enums )
256
360
257
361
schema_dir = mkdtemp ()
258
362
schema_file_path = os .path .join (schema_dir , "schema.json" )
259
363
260
364
with open (schema_file_path , "w" ) as f :
261
365
f .write (schema )
262
366
367
+ DEBUG = os .environ .get ("DEBUG" , False )
368
+
263
369
if DEBUG :
264
370
debug_schema_file_path = Path (module ).parent / "schema_debug.json"
265
371
# raise ValueError(module)
0 commit comments