Skip to content

Commit 90ff956

Browse files
Merge pull request #1680 from basetenlabs/deepakn/BT-14764-support-partial-annotations
BT-14764: Truss should be permissive if either input or output doesn't match Pydantic expectations
2 parents 3e91470 + da94153 commit 90ff956

File tree

3 files changed

+70
-16
lines changed

3 files changed

+70
-16
lines changed

truss/templates/server/common/schema.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,28 +24,30 @@ class OutputType(BaseModel):
2424

2525

2626
class TrussSchema(BaseModel):
27-
input_type: Type[BaseModel]
27+
input_type: Optional[Type[BaseModel]]
2828
output_type: Optional[Type[BaseModel]]
29-
supports_streaming: bool
29+
supports_streaming: Optional[bool]
3030

3131
@classmethod
3232
def from_signature(
33-
cls, input_parameters: MappingProxyType, output_annotation: Any
33+
cls, input_parameters: Optional[MappingProxyType], output_annotation: Any
3434
) -> Optional["TrussSchema"]:
3535
"""
3636
Create a TrussSchema from a function signature if annotated, else returns None
3737
"""
3838

39-
input_type = _parse_input_type(input_parameters)
40-
output_type = _parse_output_type(output_annotation)
39+
input_type = _parse_input_type(input_parameters) if input_parameters else None
40+
output_type = (
41+
_parse_output_type(output_annotation) if output_annotation else None
42+
)
4143

42-
if not input_type or not output_type:
44+
if not input_type and not output_type:
4345
return None
4446

4547
return cls(
4648
input_type=input_type,
47-
output_type=output_type.type,
48-
supports_streaming=output_type.supports_streaming,
49+
output_type=output_type.type if output_type else None,
50+
supports_streaming=output_type.supports_streaming if output_type else None,
4951
)
5052

5153
def serialize(self) -> dict:
@@ -54,7 +56,7 @@ def serialize(self) -> dict:
5456
generating an OpenAPI spec for this Truss.
5557
"""
5658
return {
57-
"input_schema": self.input_type.schema(),
59+
"input_schema": self.input_type.schema() if self.input_type else None,
5860
"output_schema": self.output_type.schema()
5961
if self.output_type is not None
6062
else None,

truss/templates/server/truss_server.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ async def _parse_body(
145145
if self.is_binary(request):
146146
with tracing.section_as_event(span, "binary-deserialize"):
147147
inputs = serialization.truss_msgpack_deserialize(body_raw)
148-
if truss_schema:
148+
if truss_schema and truss_schema.input_type:
149149
try:
150150
with tracing.section_as_event(span, "parse-pydantic"):
151151
inputs = truss_schema.input_type.parse_obj(inputs)
@@ -154,7 +154,7 @@ async def _parse_body(
154154
errors.format_pydantic_validation_error(e)
155155
) from e
156156
else:
157-
if truss_schema:
157+
if truss_schema and truss_schema.input_type:
158158
try:
159159
with tracing.section_as_event(span, "parse-pydantic"):
160160
inputs = truss_schema.input_type.parse_raw(body_raw)

truss/tests/templates/server/test_schema.py

Lines changed: 57 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,53 @@ def predict(self, request: ModelInput) -> ModelOutput:
4747
assert not schema.supports_streaming
4848

4949

50+
def test_truss_schema_pydantic_empty_input():
51+
class Model:
52+
def predict(self) -> ModelOutput:
53+
return ModelOutput(output="hello")
54+
55+
model = Model()
56+
57+
input_signature = inspect.signature(model.predict).parameters
58+
output_signature = inspect.signature(model.predict).return_annotation
59+
60+
schema = TrussSchema.from_signature(input_signature, output_signature)
61+
62+
assert schema.input_type is None
63+
assert schema.output_type == ModelOutput
64+
65+
66+
def test_truss_schema_pydantic_empty_output():
67+
class Model:
68+
def predict(self, _: ModelInput) -> None:
69+
return None
70+
71+
model = Model()
72+
73+
input_signature = inspect.signature(model.predict).parameters
74+
output_signature = inspect.signature(model.predict).return_annotation
75+
76+
schema = TrussSchema.from_signature(input_signature, output_signature)
77+
78+
assert schema.input_type == ModelInput
79+
assert schema.output_type is None
80+
81+
82+
def test_truss_schema_pydantic_empty_input_and_output():
83+
class Model:
84+
def predict(self) -> None:
85+
return None
86+
87+
model = Model()
88+
89+
input_signature = inspect.signature(model.predict).parameters
90+
output_signature = inspect.signature(model.predict).return_annotation
91+
92+
schema = TrussSchema.from_signature(input_signature, output_signature)
93+
94+
assert schema is None
95+
96+
5097
def test_truss_schema_non_pydantic_input():
5198
class Model:
5299
def predict(self, request: str) -> ModelOutput:
@@ -59,7 +106,8 @@ def predict(self, request: str) -> ModelOutput:
59106

60107
schema = TrussSchema.from_signature(input_signature, output_signature)
61108

62-
assert schema is None
109+
assert schema.input_type is None
110+
assert schema.output_type == ModelOutput
63111

64112

65113
def test_truss_schema_non_pydantic_output():
@@ -74,7 +122,8 @@ def predict(self, request: ModelInput) -> str:
74122

75123
schema = TrussSchema.from_signature(input_signature, output_signature)
76124

77-
assert schema is None
125+
assert schema.input_type == ModelInput
126+
assert schema.output_type is None
78127

79128

80129
def test_truss_schema_list_types():
@@ -218,7 +267,8 @@ async def predict(
218267
output_signature = inspect.signature(model.predict).return_annotation
219268

220269
schema = TrussSchema.from_signature(input_signature, output_signature)
221-
assert schema is None
270+
assert schema.input_type == ModelInput
271+
assert schema.output_type is None
222272

223273

224274
def test_truss_schema_union_non_pydantic():
@@ -233,7 +283,8 @@ def predict(self, request: ModelInput) -> Union[str, int]:
233283

234284
schema = TrussSchema.from_signature(input_signature, output_signature)
235285

236-
assert schema is None
286+
assert schema.input_type == ModelInput
287+
assert schema.output_type is None
237288

238289

239290
def test_truss_schema_async_non_pydantic():
@@ -269,4 +320,5 @@ def predict(
269320

270321
schema = TrussSchema.from_signature(input_signature, output_signature)
271322

272-
assert schema is None
323+
assert schema.input_type == ModelInput
324+
assert schema.output_type is None

0 commit comments

Comments
 (0)