@@ -47,6 +47,53 @@ def predict(self, request: ModelInput) -> ModelOutput:
47
47
assert not schema .supports_streaming
48
48
49
49
50
+ def test_truss_schema_pydantic_empty_input ():
51
+ class Model :
52
+ def predict (self ) -> ModelOutput :
53
+ return ModelOutput (output = "hello" )
54
+
55
+ model = Model ()
56
+
57
+ input_signature = inspect .signature (model .predict ).parameters
58
+ output_signature = inspect .signature (model .predict ).return_annotation
59
+
60
+ schema = TrussSchema .from_signature (input_signature , output_signature )
61
+
62
+ assert schema .input_type is None
63
+ assert schema .output_type == ModelOutput
64
+
65
+
66
+ def test_truss_schema_pydantic_empty_output ():
67
+ class Model :
68
+ def predict (self , _ : ModelInput ) -> None :
69
+ return None
70
+
71
+ model = Model ()
72
+
73
+ input_signature = inspect .signature (model .predict ).parameters
74
+ output_signature = inspect .signature (model .predict ).return_annotation
75
+
76
+ schema = TrussSchema .from_signature (input_signature , output_signature )
77
+
78
+ assert schema .input_type == ModelInput
79
+ assert schema .output_type is None
80
+
81
+
82
+ def test_truss_schema_pydantic_empty_input_and_output ():
83
+ class Model :
84
+ def predict (self ) -> None :
85
+ return None
86
+
87
+ model = Model ()
88
+
89
+ input_signature = inspect .signature (model .predict ).parameters
90
+ output_signature = inspect .signature (model .predict ).return_annotation
91
+
92
+ schema = TrussSchema .from_signature (input_signature , output_signature )
93
+
94
+ assert schema is None
95
+
96
+
50
97
def test_truss_schema_non_pydantic_input ():
51
98
class Model :
52
99
def predict (self , request : str ) -> ModelOutput :
@@ -59,7 +106,8 @@ def predict(self, request: str) -> ModelOutput:
59
106
60
107
schema = TrussSchema .from_signature (input_signature , output_signature )
61
108
62
- assert schema is None
109
+ assert schema .input_type is None
110
+ assert schema .output_type == ModelOutput
63
111
64
112
65
113
def test_truss_schema_non_pydantic_output ():
@@ -74,7 +122,8 @@ def predict(self, request: ModelInput) -> str:
74
122
75
123
schema = TrussSchema .from_signature (input_signature , output_signature )
76
124
77
- assert schema is None
125
+ assert schema .input_type == ModelInput
126
+ assert schema .output_type is None
78
127
79
128
80
129
def test_truss_schema_list_types ():
@@ -218,7 +267,8 @@ async def predict(
218
267
output_signature = inspect .signature (model .predict ).return_annotation
219
268
220
269
schema = TrussSchema .from_signature (input_signature , output_signature )
221
- assert schema is None
270
+ assert schema .input_type == ModelInput
271
+ assert schema .output_type is None
222
272
223
273
224
274
def test_truss_schema_union_non_pydantic ():
@@ -233,7 +283,8 @@ def predict(self, request: ModelInput) -> Union[str, int]:
233
283
234
284
schema = TrussSchema .from_signature (input_signature , output_signature )
235
285
236
- assert schema is None
286
+ assert schema .input_type == ModelInput
287
+ assert schema .output_type is None
237
288
238
289
239
290
def test_truss_schema_async_non_pydantic ():
@@ -269,4 +320,5 @@ def predict(
269
320
270
321
schema = TrussSchema .from_signature (input_signature , output_signature )
271
322
272
- assert schema is None
323
+ assert schema .input_type == ModelInput
324
+ assert schema .output_type is None
0 commit comments