Skip to content

Commit 56e196d

Browse files
committed
Treat str in an output_type list the same as in a union
1 parent 3aad6fc commit 56e196d

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

pydantic_ai_slim/pydantic_ai/_output.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,17 +173,25 @@ def build(
173173
return None
174174

175175
multiple = False
176+
allow_text_output = False
176177

177178
output_types_or_markers: Sequence[SimpleOutputTypeOrMarker[OutputDataT]]
178179
if isinstance(output_type, Sequence):
179180
output_types_or_markers = output_type
180-
multiple = True
181+
if str in output_types_or_markers:
182+
allow_text_output = True
183+
output_types_or_markers = [t for t in output_types_or_markers if t is not str]
184+
if len(output_types_or_markers) > 1:
185+
multiple = True
181186
else:
182187
output_types_or_markers = [output_type]
183188

184-
allow_text_output = False
185189
tools: dict[str, OutputTool[OutputDataT]] = {}
186190
for output_type_or_marker in output_types_or_markers:
191+
if output_type_or_marker is str:
192+
allow_text_output = True
193+
continue
194+
187195
tool_name = name
188196
tool_description = description
189197
tool_strict = strict
@@ -293,6 +301,7 @@ def __init__(
293301
self.function_schema = _function_schema.function_schema(output_type, GenerateToolJsonSchema)
294302
self.validator = self.function_schema.validator
295303
json_schema = self.function_schema.json_schema
304+
json_schema['description'] = self.function_schema.description
296305
else:
297306
type_adapter: TypeAdapter[Any]
298307
if _utils.is_model_like(output_type):

tests/test_agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -390,8 +390,8 @@ def test_response_tuple():
390390

391391
@pytest.mark.parametrize(
392392
'input_union_callable',
393-
[lambda: Union[str, Foo], lambda: Union[Foo, str], lambda: str | Foo, lambda: Foo | str],
394-
ids=['Union[str, Foo]', 'Union[Foo, str]', 'str | Foo', 'Foo | str'],
393+
[lambda: Union[str, Foo], lambda: Union[Foo, str], lambda: str | Foo, lambda: Foo | str, lambda: [Foo, str]],
394+
ids=['Union[str, Foo]', 'Union[Foo, str]', 'str | Foo', 'Foo | str', '[Foo, str]'],
395395
)
396396
def test_response_union_allow_str(input_union_callable: Callable[[], Any]):
397397
try:

0 commit comments

Comments
 (0)