Skip to content

Commit 4eb2292

Browse files
authored
Merge pull request #3 from guaycuru/generate-enums
Generate Typescript enums instead of string types
2 parents 8fa1d2a + 5f5449a commit 4eb2292

File tree

24 files changed

+550
-31
lines changed

24 files changed

+550
-31
lines changed

.github/workflows/cicd.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ jobs:
2020
matrix:
2121
os: [ubuntu-latest, windows-latest, macOS-latest]
2222
python-version: ["3.9", "3.10", "3.11", "3.12"]
23+
pydantic-version: ["<2.0.0", ">=2.0.0"]
2324
steps:
2425
- name: Check out repo
2526
uses: actions/checkout@v4
@@ -37,6 +38,7 @@ jobs:
3738
- name: Install python dependencies
3839
run: |
3940
python -m pip install -U pip wheel pytest pytest-cov coverage
41+
python -m pip install -U "pydantic${{ matrix.pydantic-version }}"
4042
python -m pip install -U .
4143
- name: Run tests
4244
run: |

pydantic2ts/cli/script.py

Lines changed: 128 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,21 @@
66
import os
77
import shutil
88
import sys
9+
from enum import Enum
910
from importlib.util import module_from_spec, spec_from_file_location
1011
from pathlib import Path
1112
from tempfile import mkdtemp
1213
from types import ModuleType
13-
from typing import Any, Dict, List, Tuple, Type
14+
from typing import Any, Dict, List, Tuple, Type, get_origin, get_args, Set, cast
1415
from uuid import uuid4
1516

1617
from pydantic import VERSION, BaseModel, create_model
1718

19+
try:
20+
from types import UnionType
21+
except ImportError:
22+
UnionType = None
23+
1824
V2 = True if VERSION.startswith("2") else False
1925

2026
if not V2:
@@ -23,10 +29,16 @@
2329
except ImportError:
2430
GenericModel = None
2531

26-
logger = logging.getLogger("pydantic2ts")
27-
32+
if V2:
33+
try:
34+
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue
35+
from pydantic_core import core_schema
36+
except ImportError:
37+
GenerateJsonSchema = None
38+
JsonSchemaValue = None
39+
core_schema = None
2840

29-
DEBUG = os.environ.get("DEBUG", False)
41+
logger = logging.getLogger("pydantic2ts")
3042

3143

3244
def import_module(path: str) -> ModuleType:
@@ -63,7 +75,7 @@ def is_submodule(obj, module_name: str) -> bool:
6375
)
6476

6577

66-
def is_concrete_pydantic_model(obj) -> bool:
78+
def is_concrete_pydantic_model(obj: type) -> bool:
6779
"""
6880
Return true if an object is a concrete subclass of pydantic's BaseModel.
6981
'concrete' meaning that it's not a GenericModel.
@@ -81,24 +93,80 @@ def is_concrete_pydantic_model(obj) -> bool:
8193
return issubclass(obj, BaseModel)
8294

8395

96+
def is_enum(obj) -> bool:
97+
"""
98+
Return true if an object is an Enum.
99+
"""
100+
return inspect.isclass(obj) and issubclass(obj, Enum)
101+
102+
103+
def flatten_types(field_type: type) -> Set[type]:
104+
types = set()
105+
106+
origin = get_origin(field_type)
107+
if origin is None:
108+
types.add(field_type)
109+
else:
110+
args = get_args(field_type)
111+
for arg in args:
112+
types.update(flatten_types(arg))
113+
114+
return types
115+
116+
117+
def get_model_fields(model: Type[BaseModel]) -> Dict[str, Any]:
118+
if V2:
119+
return model.model_fields
120+
else:
121+
return model.__fields__
122+
123+
124+
def extract_pydantic_models_from_model(
125+
model: Type[BaseModel], all_models: List[Type[BaseModel]]
126+
) -> None:
127+
"""
128+
Given a pydantic model, add the pydantic models contained within it to all_models.
129+
"""
130+
if model in all_models:
131+
return
132+
133+
all_models.append(model)
134+
135+
for field, field_type in get_model_fields(model).items():
136+
flattened_types = flatten_types(field_type.annotation)
137+
for inner_type in flattened_types:
138+
if is_concrete_pydantic_model(inner_type):
139+
extract_pydantic_models_from_model(inner_type, all_models)
140+
141+
84142
def extract_pydantic_models(module: ModuleType) -> List[Type[BaseModel]]:
85143
"""
86144
Given a module, return a list of the pydantic models contained within it.
87145
"""
88146
models = []
89-
module_name = module.__name__
90147

91148
for _, model in inspect.getmembers(module, is_concrete_pydantic_model):
92-
models.append(model)
93-
94-
for _, submodule in inspect.getmembers(
95-
module, lambda obj: is_submodule(obj, module_name)
96-
):
97-
models.extend(extract_pydantic_models(submodule))
149+
extract_pydantic_models_from_model(model, models)
98150

99151
return models
100152

101153

154+
def extract_enum_models(models: List[Type[BaseModel]]) -> List[Type[Enum]]:
155+
"""
156+
Given a list of pydantic models, return a list of the Enum classes used as fields within those models.
157+
"""
158+
enums = []
159+
160+
for model in models:
161+
for field_type in get_model_fields(model).values():
162+
flattened_types = flatten_types(field_type.annotation)
163+
for inner_type in flattened_types:
164+
if is_enum(inner_type):
165+
enums.append(cast(Type[Enum], inner_type))
166+
167+
return enums
168+
169+
102170
def clean_output_file(output_filename: str) -> None:
103171
"""
104172
Clean up the output file typescript definitions were written to by:
@@ -151,7 +219,34 @@ def clean_schema(schema: Dict[str, Any]) -> None:
151219
del schema["description"]
152220

153221

154-
def generate_json_schema_v1(models: List[Type[BaseModel]]) -> str:
222+
def add_enum_names_v1(model: Type[Enum]) -> None:
223+
@classmethod
224+
def __modify_schema__(cls, field_schema: Dict[str, Any]):
225+
if len(model.__members__.keys()) == len(field_schema["enum"]):
226+
field_schema.update(tsEnumNames=list(model.__members__.keys()))
227+
for name, value in zip(field_schema["tsEnumNames"], field_schema["enum"]):
228+
assert cls[name].value == value
229+
230+
setattr(model, "__modify_schema__", __modify_schema__)
231+
232+
233+
if V2:
234+
235+
class CustomGenerateJsonSchema(GenerateJsonSchema):
236+
def enum_schema(self, schema: core_schema.EnumSchema) -> JsonSchemaValue:
237+
# Call the original method
238+
result = super().enum_schema(schema)
239+
240+
# Add tsEnumNames property
241+
if len(schema["members"]) > 0:
242+
result["tsEnumNames"] = [v.name for v in schema["members"]]
243+
244+
return result
245+
246+
247+
def generate_json_schema_v1(
248+
models: List[Type[BaseModel]], enums: List[Type[Enum]]
249+
) -> str:
155250
"""
156251
Create a top-level '_Master_' model with references to each of the actual models.
157252
Generate the schema for this model, which will include the schemas for all the
@@ -162,18 +257,21 @@ def generate_json_schema_v1(models: List[Type[BaseModel]]) -> str:
162257
'[k: string]: any' from being added to every interface. This change is reverted
163258
once the schema has been generated.
164259
"""
165-
model_extras = [m.model_config.get("extra", None) for m in models]
260+
model_extras = [getattr(m.Config, "extra", None) for m in models]
166261

167262
try:
168263
for m in models:
169-
if m.model_config.get("extra", None) != "allow":
170-
m.model_config["extra"] = "forbid"
264+
if getattr(m.Config, "extra", None) != "allow":
265+
m.Config.extra = "forbid"
266+
267+
for e in enums:
268+
add_enum_names_v1(e)
171269

172270
master_model = create_model(
173-
"_Master_", **{m.__name__: (m, ...) for m in models}, __base__=m
271+
"_Master_", **{m.__name__: (m, ...) for m in models}
174272
)
175-
master_model.model_config["extra"] = "forbid"
176-
master_model.model_config["schema_extra"] = staticmethod(clean_schema)
273+
master_model.Config.extra = "forbid"
274+
master_model.Config.schema_extra = staticmethod(clean_schema)
177275

178276
schema = json.loads(master_model.schema_json())
179277

@@ -185,7 +283,7 @@ def generate_json_schema_v1(models: List[Type[BaseModel]]) -> str:
185283
finally:
186284
for m, x in zip(models, model_extras):
187285
if x is not None:
188-
m.model_config["extra"] = x
286+
m.Config.extra = x
189287

190288

191289
def generate_json_schema_v2(models: List[Type[BaseModel]]) -> str:
@@ -212,7 +310,9 @@ def generate_json_schema_v2(models: List[Type[BaseModel]]) -> str:
212310
master_model.model_config["extra"] = "forbid"
213311
master_model.model_config["json_schema_extra"] = staticmethod(clean_schema)
214312

215-
schema: dict = master_model.model_json_schema(mode="serialization")
313+
schema: dict = master_model.model_json_schema(
314+
schema_generator=CustomGenerateJsonSchema, mode="serialization"
315+
)
216316

217317
for d in schema.get("$defs", {}).values():
218318
clean_schema(d)
@@ -252,14 +352,20 @@ def generate_typescript_defs(
252352

253353
logger.info("Generating JSON schema from pydantic models...")
254354

255-
schema = generate_json_schema_v2(models) if V2 else generate_json_schema_v1(models)
355+
if V2:
356+
schema = generate_json_schema_v2(models)
357+
else:
358+
enums = extract_enum_models(models)
359+
schema = generate_json_schema_v1(models, enums)
256360

257361
schema_dir = mkdtemp()
258362
schema_file_path = os.path.join(schema_dir, "schema.json")
259363

260364
with open(schema_file_path, "w") as f:
261365
f.write(schema)
262366

367+
DEBUG = os.environ.get("DEBUG", False)
368+
263369
if DEBUG:
264370
debug_schema_file_path = Path(module).parent / "schema_debug.json"
265371
# raise ValueError(module)

tests/expected_results/__init__.py

Whitespace-only changes.

tests/expected_results/enums/__init__.py

Whitespace-only changes.

tests/expected_results/enums/v1/__init__.py

Whitespace-only changes.
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import os
2+
import sys
3+
from enum import Enum
4+
from typing import List, Optional, Literal
5+
6+
from pydantic import BaseModel
7+
8+
# Make absolute imports work
9+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__))))
10+
11+
from .schemas.schema_one import ModelOne # noqa: F401
12+
from .schemas.schema_two import ModelTwo # noqa: F401
13+
from schemas.sub_model import SubModel # this tests absolute imports
14+
from schemas.complex import LevelOne # this tests absolute imports in multiple layers
15+
16+
17+
class CatBreed(str, Enum):
18+
domestic_shorthair = "domestic shorthair"
19+
bengal = "bengal"
20+
persian = "persian"
21+
siamese = "siamese"
22+
23+
24+
class Cat(BaseModel):
25+
name: str
26+
age: int
27+
declawed: bool
28+
breed: CatBreed
29+
30+
31+
class DogBreed(str, Enum):
32+
mutt = "mutt"
33+
labrador = "labrador"
34+
golden_retriever = "golden retriever"
35+
36+
37+
class Dog(BaseModel):
38+
name: str
39+
age: int
40+
breed: DogBreed
41+
42+
43+
class AnimalShelter(BaseModel):
44+
address: str
45+
cats: List[Cat]
46+
dogs: List[Dog]
47+
owner: Optional[Dog]
48+
master: Cat
49+
50+
51+
class Standalone(Enum):
52+
something = "something"
53+
anything = "anything"
54+
55+
56+
class ImportedSubModule(BaseModel):
57+
sub: SubModel
58+
59+
60+
class ComplexModelTree(BaseModel):
61+
one: LevelOne

0 commit comments

Comments
 (0)