12
12
from pathlib import Path
13
13
from tempfile import mkdtemp
14
14
from types import ModuleType
15
- from typing import Any , Dict , List , Tuple , Type , Union
16
- from typing_extensions import get_args , get_origin
15
+ from typing import Any , Dict , List , Tuple , Type
16
+
17
+ from pydantic .json_schema import GenerateJsonSchema , JsonSchemaValue
18
+ from pydantic_core import core_schema
17
19
from uuid import uuid4
18
20
19
21
from pydantic import VERSION , BaseModel , create_model
@@ -89,6 +91,13 @@ def is_concrete_pydantic_model(obj) -> bool:
89
91
return issubclass (obj , BaseModel )
90
92
91
93
94
+ def is_enum (obj ) -> bool :
95
+ """
96
+ Return true if an object is an Enum.
97
+ """
98
+ return inspect .isclass (obj ) and issubclass (obj , Enum )
99
+
100
+
92
101
def extract_pydantic_models (module : ModuleType ) -> List [Type [BaseModel ]]:
93
102
"""
94
103
Given a module, return a list of the pydantic models contained within it.
@@ -107,6 +116,24 @@ def extract_pydantic_models(module: ModuleType) -> List[Type[BaseModel]]:
107
116
return models
108
117
109
118
119
+ def extract_enum_models (module : ModuleType ) -> List [Type [Enum ]]:
120
+ """
121
+ Given a module, return a list of the Enum classes contained within it.
122
+ """
123
+ enums = []
124
+ module_name = module .__name__
125
+
126
+ for _ , enum in inspect .getmembers (module , is_enum ):
127
+ enums .append (enum )
128
+
129
+ for _ , submodule in inspect .getmembers (
130
+ module , lambda obj : is_submodule (obj , module_name )
131
+ ):
132
+ enums .extend (extract_enum_models (submodule ))
133
+
134
+ return enums
135
+
136
+
110
137
def clean_output_file (output_filename : str ) -> None :
111
138
"""
112
139
Clean up the output file typescript definitions were written to by:
@@ -159,8 +186,14 @@ def clean_schema(schema: Dict[str, Any]) -> None:
159
186
del schema ["description" ]
160
187
161
188
162
- def add_ts_enum_names (schema : Dict [str , Any ], enum_class : Type [Enum ]) -> None :
163
- schema ["tsEnumNames" ] = [name for name , member in enum_class .__members__ .items ()]
189
+ def add_enum_names_v1 (model : Type [Enum ]) -> None :
190
+ @classmethod
191
+ def __modify_schema__ (cls , field_schema : Dict [str , Any ]):
192
+ field_schema .update (tsEnumNames = list (model .__members__ .keys ()))
193
+ for name , value in zip (field_schema ["tsEnumNames" ], field_schema ["enum" ]):
194
+ assert cls [name ].value == value
195
+
196
+ setattr (model , "__modify_schema__" , __modify_schema__ )
164
197
165
198
166
199
def is_matching_enum (prop_type : Any , schema_title : str , schema_enum : list [str ]) -> bool :
@@ -175,36 +208,18 @@ def is_matching_enum(prop_type: Any, schema_title: str, schema_enum: list[str])
175
208
)
176
209
177
210
178
- def extend_enum_definitions (
179
- schema : Dict [str , Any ], models : List [Type [BaseModel ]]
180
- ) -> None :
181
- """
182
- Extend the 'enum' property of a schema with the tsEnumNames property
183
- for any Enum fields in the models so that the generated TypeScript
184
- definitions will include enums instead of plain strings.
185
- """
186
- if ("enum" in schema ) and (not "tsEnumNames" in schema ):
187
- for model in models :
188
- for prop , prop_type in model .__annotations__ .items ():
189
- origin = get_origin (prop_type )
190
- if is_matching_enum (prop_type , schema ["title" ], schema ["enum" ]):
191
- add_ts_enum_names (schema , prop_type )
192
- break
193
- elif origin is list :
194
- inner_type = get_args (prop_type )[0 ]
195
- if is_matching_enum (inner_type , schema ["title" ], schema ["enum" ]):
196
- add_ts_enum_names (schema , inner_type )
197
- break
198
- elif (UnionType and origin is UnionType ) or origin is Union :
199
- for inner_type in get_args (prop_type ):
200
- if is_matching_enum (
201
- inner_type , schema ["title" ], schema ["enum" ]
202
- ):
203
- add_ts_enum_names (schema , inner_type )
204
- break
205
-
206
-
207
- def generate_json_schema_v1 (models : List [Type [BaseModel ]]) -> str :
211
+ class CustomGenerateJsonSchema (GenerateJsonSchema ):
212
+ def enum_schema (self , schema : core_schema .EnumSchema ) -> JsonSchemaValue :
213
+ # Call the original method
214
+ result = super ().enum_schema (schema )
215
+
216
+ # Add tsEnumNames property
217
+ result ['tsEnumNames' ] = [v .name for v in schema ['members' ]]
218
+
219
+ return result
220
+
221
+
222
+ def generate_json_schema_v1 (models : List [Type [BaseModel ]], enums : List [Type [Enum ]]) -> str :
208
223
"""
209
224
Create a top-level '_Master_' model with references to each of the actual models.
210
225
Generate the schema for this model, which will include the schemas for all the
@@ -222,8 +237,12 @@ def generate_json_schema_v1(models: List[Type[BaseModel]]) -> str:
222
237
if getattr (m .Config , "extra" , None ) != "allow" :
223
238
m .Config .extra = "forbid"
224
239
240
+ for e in enums :
241
+ add_enum_names_v1 (e )
242
+
243
+ all_models = models + enums
225
244
master_model = create_model (
226
- "_Master_" , ** {m .__name__ : (m , ...) for m in models }
245
+ "_Master_" , ** {m .__name__ : (m , ...) for m in all_models }
227
246
)
228
247
master_model .Config .extra = "forbid"
229
248
master_model .Config .schema_extra = staticmethod (clean_schema )
@@ -232,7 +251,6 @@ def generate_json_schema_v1(models: List[Type[BaseModel]]) -> str:
232
251
233
252
for d in schema .get ("definitions" , {}).values ():
234
253
clean_schema (d )
235
- extend_enum_definitions (d , models )
236
254
237
255
return json .dumps (schema , indent = 2 )
238
256
@@ -266,11 +284,10 @@ def generate_json_schema_v2(models: List[Type[BaseModel]]) -> str:
266
284
master_model .model_config ["extra" ] = "forbid"
267
285
master_model .model_config ["json_schema_extra" ] = staticmethod (clean_schema )
268
286
269
- schema : dict = master_model .model_json_schema (mode = "serialization" )
287
+ schema : dict = master_model .model_json_schema (schema_generator = CustomGenerateJsonSchema , mode = "serialization" )
270
288
271
289
for d in schema .get ("$defs" , {}).values ():
272
290
clean_schema (d )
273
- extend_enum_definitions (d , models )
274
291
275
292
return json .dumps (schema , indent = 2 )
276
293
@@ -300,14 +317,19 @@ def generate_typescript_defs(
300
317
301
318
logger .info ("Finding pydantic models..." )
302
319
303
- models = extract_pydantic_models (import_module (module ))
320
+ import_result = import_module (module )
321
+ models = extract_pydantic_models (import_result )
304
322
305
323
if exclude :
306
324
models = [m for m in models if m .__name__ not in exclude ]
307
325
308
326
logger .info ("Generating JSON schema from pydantic models..." )
309
327
310
- schema = generate_json_schema_v2 (models ) if V2 else generate_json_schema_v1 (models )
328
+ if V2 :
329
+ schema = generate_json_schema_v2 (models )
330
+ else :
331
+ enums = extract_enum_models (import_result )
332
+ schema = generate_json_schema_v1 (models , enums )
311
333
312
334
schema_dir = mkdtemp ()
313
335
schema_file_path = os .path .join (schema_dir , "schema.json" )
0 commit comments