Skip to content

Commit a4a57d4

Browse files
committed
Improve adding tsEnumNames
1 parent 8caba44 commit a4a57d4

File tree

1 file changed

+62
-40
lines changed

1 file changed

+62
-40
lines changed

pydantic2ts/cli/script.py

Lines changed: 62 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@
1212
from pathlib import Path
1313
from tempfile import mkdtemp
1414
from types import ModuleType
15-
from typing import Any, Dict, List, Tuple, Type, Union
16-
from typing_extensions import get_args, get_origin
15+
from typing import Any, Dict, List, Tuple, Type
16+
17+
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue
18+
from pydantic_core import core_schema
1719
from uuid import uuid4
1820

1921
from pydantic import VERSION, BaseModel, create_model
@@ -89,6 +91,13 @@ def is_concrete_pydantic_model(obj) -> bool:
8991
return issubclass(obj, BaseModel)
9092

9193

94+
def is_enum(obj) -> bool:
95+
"""
96+
Return true if an object is an Enum.
97+
"""
98+
return inspect.isclass(obj) and issubclass(obj, Enum)
99+
100+
92101
def extract_pydantic_models(module: ModuleType) -> List[Type[BaseModel]]:
93102
"""
94103
Given a module, return a list of the pydantic models contained within it.
@@ -107,6 +116,24 @@ def extract_pydantic_models(module: ModuleType) -> List[Type[BaseModel]]:
107116
return models
108117

109118

119+
def extract_enum_models(module: ModuleType) -> List[Type[Enum]]:
120+
"""
121+
Given a module, return a list of the Enum classes contained within it.
122+
"""
123+
enums = []
124+
module_name = module.__name__
125+
126+
for _, enum in inspect.getmembers(module, is_enum):
127+
enums.append(enum)
128+
129+
for _, submodule in inspect.getmembers(
130+
module, lambda obj: is_submodule(obj, module_name)
131+
):
132+
enums.extend(extract_enum_models(submodule))
133+
134+
return enums
135+
136+
110137
def clean_output_file(output_filename: str) -> None:
111138
"""
112139
Clean up the output file typescript definitions were written to by:
@@ -159,8 +186,14 @@ def clean_schema(schema: Dict[str, Any]) -> None:
159186
del schema["description"]
160187

161188

162-
def add_ts_enum_names(schema: Dict[str, Any], enum_class: Type[Enum]) -> None:
163-
schema["tsEnumNames"] = [name for name, member in enum_class.__members__.items()]
189+
def add_enum_names_v1(model: Type[Enum]) -> None:
190+
@classmethod
191+
def __modify_schema__(cls, field_schema: Dict[str, Any]):
192+
field_schema.update(tsEnumNames=list(model.__members__.keys()))
193+
for name, value in zip(field_schema["tsEnumNames"], field_schema["enum"]):
194+
assert cls[name].value == value
195+
196+
setattr(model, "__modify_schema__", __modify_schema__)
164197

165198

166199
def is_matching_enum(prop_type: Any, schema_title: str, schema_enum: list[str]) -> bool:
@@ -175,36 +208,18 @@ def is_matching_enum(prop_type: Any, schema_title: str, schema_enum: list[str])
175208
)
176209

177210

178-
def extend_enum_definitions(
179-
schema: Dict[str, Any], models: List[Type[BaseModel]]
180-
) -> None:
181-
"""
182-
Extend the 'enum' property of a schema with the tsEnumNames property
183-
for any Enum fields in the models so that the generated TypeScript
184-
definitions will include enums instead of plain strings.
185-
"""
186-
if ("enum" in schema) and (not "tsEnumNames" in schema):
187-
for model in models:
188-
for prop, prop_type in model.__annotations__.items():
189-
origin = get_origin(prop_type)
190-
if is_matching_enum(prop_type, schema["title"], schema["enum"]):
191-
add_ts_enum_names(schema, prop_type)
192-
break
193-
elif origin is list:
194-
inner_type = get_args(prop_type)[0]
195-
if is_matching_enum(inner_type, schema["title"], schema["enum"]):
196-
add_ts_enum_names(schema, inner_type)
197-
break
198-
elif (UnionType and origin is UnionType) or origin is Union:
199-
for inner_type in get_args(prop_type):
200-
if is_matching_enum(
201-
inner_type, schema["title"], schema["enum"]
202-
):
203-
add_ts_enum_names(schema, inner_type)
204-
break
205-
206-
207-
def generate_json_schema_v1(models: List[Type[BaseModel]]) -> str:
211+
class CustomGenerateJsonSchema(GenerateJsonSchema):
212+
def enum_schema(self, schema: core_schema.EnumSchema) -> JsonSchemaValue:
213+
# Call the original method
214+
result = super().enum_schema(schema)
215+
216+
# Add tsEnumNames property
217+
result['tsEnumNames'] = [v.name for v in schema['members']]
218+
219+
return result
220+
221+
222+
def generate_json_schema_v1(models: List[Type[BaseModel]], enums: List[Type[Enum]]) -> str:
208223
"""
209224
Create a top-level '_Master_' model with references to each of the actual models.
210225
Generate the schema for this model, which will include the schemas for all the
@@ -222,8 +237,12 @@ def generate_json_schema_v1(models: List[Type[BaseModel]]) -> str:
222237
if getattr(m.Config, "extra", None) != "allow":
223238
m.Config.extra = "forbid"
224239

240+
for e in enums:
241+
add_enum_names_v1(e)
242+
243+
all_models = models + enums
225244
master_model = create_model(
226-
"_Master_", **{m.__name__: (m, ...) for m in models}
245+
"_Master_", **{m.__name__: (m, ...) for m in all_models}
227246
)
228247
master_model.Config.extra = "forbid"
229248
master_model.Config.schema_extra = staticmethod(clean_schema)
@@ -232,7 +251,6 @@ def generate_json_schema_v1(models: List[Type[BaseModel]]) -> str:
232251

233252
for d in schema.get("definitions", {}).values():
234253
clean_schema(d)
235-
extend_enum_definitions(d, models)
236254

237255
return json.dumps(schema, indent=2)
238256

@@ -266,11 +284,10 @@ def generate_json_schema_v2(models: List[Type[BaseModel]]) -> str:
266284
master_model.model_config["extra"] = "forbid"
267285
master_model.model_config["json_schema_extra"] = staticmethod(clean_schema)
268286

269-
schema: dict = master_model.model_json_schema(mode="serialization")
287+
schema: dict = master_model.model_json_schema(schema_generator=CustomGenerateJsonSchema, mode="serialization")
270288

271289
for d in schema.get("$defs", {}).values():
272290
clean_schema(d)
273-
extend_enum_definitions(d, models)
274291

275292
return json.dumps(schema, indent=2)
276293

@@ -300,14 +317,19 @@ def generate_typescript_defs(
300317

301318
logger.info("Finding pydantic models...")
302319

303-
models = extract_pydantic_models(import_module(module))
320+
import_result = import_module(module)
321+
models = extract_pydantic_models(import_result)
304322

305323
if exclude:
306324
models = [m for m in models if m.__name__ not in exclude]
307325

308326
logger.info("Generating JSON schema from pydantic models...")
309327

310-
schema = generate_json_schema_v2(models) if V2 else generate_json_schema_v1(models)
328+
if V2:
329+
schema = generate_json_schema_v2(models)
330+
else:
331+
enums = extract_enum_models(import_result)
332+
schema = generate_json_schema_v1(models, enums)
311333

312334
schema_dir = mkdtemp()
313335
schema_file_path = os.path.join(schema_dir, "schema.json")

0 commit comments

Comments
 (0)