Skip to content

Generate Typescript enums instead of string types #3

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 25 commits into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/cicd.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ jobs:
matrix:
os: [ubuntu-latest, windows-latest, macOS-latest]
python-version: ["3.9", "3.10", "3.11", "3.12"]
pydantic-version: ["<2.0.0", ">=2.0.0"]
steps:
- name: Check out repo
uses: actions/checkout@v4
Expand All @@ -37,6 +38,7 @@ jobs:
- name: Install python dependencies
run: |
python -m pip install -U pip wheel pytest pytest-cov coverage
python -m pip install -U "pydantic${{ matrix.pydantic-version }}"
python -m pip install -U .
- name: Run tests
run: |
Expand Down
150 changes: 128 additions & 22 deletions pydantic2ts/cli/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,21 @@
import os
import shutil
import sys
from enum import Enum
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path
from tempfile import mkdtemp
from types import ModuleType
from typing import Any, Dict, List, Tuple, Type
from typing import Any, Dict, List, Tuple, Type, get_origin, get_args, Set, cast
from uuid import uuid4

from pydantic import VERSION, BaseModel, create_model

try:
from types import UnionType
except ImportError:
UnionType = None

V2 = True if VERSION.startswith("2") else False

if not V2:
Expand All @@ -23,10 +29,16 @@
except ImportError:
GenericModel = None

logger = logging.getLogger("pydantic2ts")

if V2:
try:
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue
from pydantic_core import core_schema
except ImportError:
GenerateJsonSchema = None
JsonSchemaValue = None
core_schema = None

DEBUG = os.environ.get("DEBUG", False)
logger = logging.getLogger("pydantic2ts")


def import_module(path: str) -> ModuleType:
Expand Down Expand Up @@ -63,7 +75,7 @@ def is_submodule(obj, module_name: str) -> bool:
)


def is_concrete_pydantic_model(obj) -> bool:
def is_concrete_pydantic_model(obj: type) -> bool:
"""
Return true if an object is a concrete subclass of pydantic's BaseModel.
'concrete' meaning that it's not a GenericModel.
Expand All @@ -81,24 +93,80 @@ def is_concrete_pydantic_model(obj) -> bool:
return issubclass(obj, BaseModel)


def is_enum(obj) -> bool:
"""
Return true if an object is an Enum.
"""
return inspect.isclass(obj) and issubclass(obj, Enum)


def flatten_types(field_type: type) -> Set[type]:
types = set()

origin = get_origin(field_type)
if origin is None:
types.add(field_type)
else:
args = get_args(field_type)
for arg in args:
types.update(flatten_types(arg))

return types


def get_model_fields(model: Type[BaseModel]) -> Dict[str, Any]:
if V2:
return model.model_fields
else:
return model.__fields__


def extract_pydantic_models_from_model(
model: Type[BaseModel], all_models: List[Type[BaseModel]]
) -> None:
"""
Given a pydantic model, add the pydantic models contained within it to all_models.
"""
if model in all_models:
return

all_models.append(model)

for field, field_type in get_model_fields(model).items():
flattened_types = flatten_types(field_type.annotation)
for inner_type in flattened_types:
if is_concrete_pydantic_model(inner_type):
extract_pydantic_models_from_model(inner_type, all_models)


def extract_pydantic_models(module: ModuleType) -> List[Type[BaseModel]]:
"""
Given a module, return a list of the pydantic models contained within it.
"""
models = []
module_name = module.__name__

for _, model in inspect.getmembers(module, is_concrete_pydantic_model):
models.append(model)

for _, submodule in inspect.getmembers(
module, lambda obj: is_submodule(obj, module_name)
):
models.extend(extract_pydantic_models(submodule))
extract_pydantic_models_from_model(model, models)

return models


def extract_enum_models(models: List[Type[BaseModel]]) -> List[Type[Enum]]:
"""
Given a list of pydantic models, return a list of the Enum classes used as fields within those models.
"""
enums = []

for model in models:
for field_type in get_model_fields(model).values():
flattened_types = flatten_types(field_type.annotation)
for inner_type in flattened_types:
if is_enum(inner_type):
enums.append(cast(Type[Enum], inner_type))

return enums


def clean_output_file(output_filename: str) -> None:
"""
Clean up the output file typescript definitions were written to by:
Expand Down Expand Up @@ -151,7 +219,34 @@ def clean_schema(schema: Dict[str, Any]) -> None:
del schema["description"]


def generate_json_schema_v1(models: List[Type[BaseModel]]) -> str:
def add_enum_names_v1(model: Type[Enum]) -> None:
@classmethod
def __modify_schema__(cls, field_schema: Dict[str, Any]):
if len(model.__members__.keys()) == len(field_schema["enum"]):
field_schema.update(tsEnumNames=list(model.__members__.keys()))
for name, value in zip(field_schema["tsEnumNames"], field_schema["enum"]):
assert cls[name].value == value

setattr(model, "__modify_schema__", __modify_schema__)


if V2:

class CustomGenerateJsonSchema(GenerateJsonSchema):
def enum_schema(self, schema: core_schema.EnumSchema) -> JsonSchemaValue:
# Call the original method
result = super().enum_schema(schema)

# Add tsEnumNames property
if len(schema["members"]) > 0:
result["tsEnumNames"] = [v.name for v in schema["members"]]

return result


def generate_json_schema_v1(
models: List[Type[BaseModel]], enums: List[Type[Enum]]
) -> str:
"""
Create a top-level '_Master_' model with references to each of the actual models.
Generate the schema for this model, which will include the schemas for all the
Expand All @@ -162,18 +257,21 @@ def generate_json_schema_v1(models: List[Type[BaseModel]]) -> str:
'[k: string]: any' from being added to every interface. This change is reverted
once the schema has been generated.
"""
model_extras = [m.model_config.get("extra", None) for m in models]
model_extras = [getattr(m.Config, "extra", None) for m in models]

try:
for m in models:
if m.model_config.get("extra", None) != "allow":
m.model_config["extra"] = "forbid"
if getattr(m.Config, "extra", None) != "allow":
m.Config.extra = "forbid"

for e in enums:
add_enum_names_v1(e)

master_model = create_model(
"_Master_", **{m.__name__: (m, ...) for m in models}, __base__=m
"_Master_", **{m.__name__: (m, ...) for m in models}
)
master_model.model_config["extra"] = "forbid"
master_model.model_config["schema_extra"] = staticmethod(clean_schema)
master_model.Config.extra = "forbid"
master_model.Config.schema_extra = staticmethod(clean_schema)

schema = json.loads(master_model.schema_json())

Expand All @@ -185,7 +283,7 @@ def generate_json_schema_v1(models: List[Type[BaseModel]]) -> str:
finally:
for m, x in zip(models, model_extras):
if x is not None:
m.model_config["extra"] = x
m.Config.extra = x


def generate_json_schema_v2(models: List[Type[BaseModel]]) -> str:
Expand All @@ -212,7 +310,9 @@ def generate_json_schema_v2(models: List[Type[BaseModel]]) -> str:
master_model.model_config["extra"] = "forbid"
master_model.model_config["json_schema_extra"] = staticmethod(clean_schema)

schema: dict = master_model.model_json_schema(mode="serialization")
schema: dict = master_model.model_json_schema(
schema_generator=CustomGenerateJsonSchema, mode="serialization"
)

for d in schema.get("$defs", {}).values():
clean_schema(d)
Expand Down Expand Up @@ -252,14 +352,20 @@ def generate_typescript_defs(

logger.info("Generating JSON schema from pydantic models...")

schema = generate_json_schema_v2(models) if V2 else generate_json_schema_v1(models)
if V2:
schema = generate_json_schema_v2(models)
else:
enums = extract_enum_models(models)
schema = generate_json_schema_v1(models, enums)

schema_dir = mkdtemp()
schema_file_path = os.path.join(schema_dir, "schema.json")

with open(schema_file_path, "w") as f:
f.write(schema)

DEBUG = os.environ.get("DEBUG", False)

if DEBUG:
debug_schema_file_path = Path(module).parent / "schema_debug.json"
# raise ValueError(module)
Expand Down
Empty file.
Empty file.
Empty file.
61 changes: 61 additions & 0 deletions tests/expected_results/enums/v1/input.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import os
import sys
from enum import Enum
from typing import List, Optional, Literal

from pydantic import BaseModel

# Make absolute imports work
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__))))

from .schemas.schema_one import ModelOne # noqa: F401
from .schemas.schema_two import ModelTwo # noqa: F401
from schemas.sub_model import SubModel # this tests absolute imports
from schemas.complex import LevelOne # this tests absolute imports in multiple layers


class CatBreed(str, Enum):
domestic_shorthair = "domestic shorthair"
bengal = "bengal"
persian = "persian"
siamese = "siamese"


class Cat(BaseModel):
name: str
age: int
declawed: bool
breed: CatBreed


class DogBreed(str, Enum):
mutt = "mutt"
labrador = "labrador"
golden_retriever = "golden retriever"


class Dog(BaseModel):
name: str
age: int
breed: DogBreed


class AnimalShelter(BaseModel):
address: str
cats: List[Cat]
dogs: List[Dog]
owner: Optional[Dog]
master: Cat


class Standalone(Enum):
something = "something"
anything = "anything"


class ImportedSubModule(BaseModel):
sub: SubModel


class ComplexModelTree(BaseModel):
one: LevelOne
Loading
Loading