Support functions as output_type, as well as lists of functions and other types #1785
Generally looks good, I think we should just confirm the changes to FunctionSchema are okay with Samuel (he'll probably just rubber stamp but still), and I'd like to get @Kludex's take on the crazy OutputType types but I'm okay with it if he is.
Oh, we also need to add some typing-related tests.
…checked, not executed
# Conflicts:
#   pydantic_ai_slim/pydantic_ai/_output.py
#   pydantic_ai_slim/pydantic_ai/tools.py
Force-pushed from a1c793e to 56e196d, then from 56e196d to 3ff6e74.
…er to rerun when the length of the example changes
PR Change Summary: Enhanced output handling to support functions as output types, improving flexibility in agent responses.
One thing I'd also like to be able to do here is a single LLM call with my Agent. Ideally it'd have the same behavior as a direct model call (https://ai.pydantic.dev/direct/), except also expose the MCP server and (optionally) force a Pydantic model response in a single call. Am I missing something, or does that interface not exist in the Agent yet?
What I've done here to make […] Specifically, none of the three support […]:

```python
from __future__ import annotations

from dataclasses import dataclass

from typing_extensions import (
    Generic,
    Sequence,
    TypeVar,
    assert_type,
)

T = TypeVar("T")


@dataclass
class Agent(Generic[T]):
    output_type: Sequence[type[T]]


class Foo:
    pass


class Bar:
    pass


# pyright - works
# mypy - error: Expression is of type "Agent[object]", not "Agent[Foo | Bar]" [assert-type]
# pyrefly - assert_type(Agent[Foo], Agent[Bar | Foo]) failed + Argument `list[type[Bar] | type[Foo]]` is not assignable to parameter `output_type` with type `Sequence[type[Foo]]` in function `Agent.__init__`
# ty - `Agent[Foo | Bar]` and `Agent[Unknown]` are not equivalent types
assert_type(Agent([Foo, Bar]), Agent[Foo | Bar])

# pyright - works
# mypy - error: Expression is of type "Agent[Never]", not "Agent[int | str]" [assert-type]
# pyrefly - assert_type(Agent[int], Agent[str | int]) failed + Argument `list[type[str] | type[int]]` is not assignable to parameter `output_type` with type `Sequence[type[int]]` in function `Agent.__init__`
# ty - `Agent[int | str]` and `Agent[Unknown]` are not equivalent types
assert_type(Agent([int, str]), Agent[int | str])

# works
assert_type(Agent[Foo | Bar]([Foo, Bar]), Agent[Foo | Bar])

# works
assert_type(Agent[int | str]([int, str]), Agent[int | str])
```

Ty doesn't support […]:

```python
from dataclasses import dataclass

from typing_extensions import (
    Callable,
    Generic,
    TypeVar,
    assert_type,
)

T = TypeVar("T")


@dataclass
class Agent(Generic[T]):
    output_type: Callable[..., T]


def func() -> int:
    return 1


# pyright, mypy, pyrefly - works
# ty - `Agent[int]` and `Agent[Unknown]` are not equivalent types + Expected `((...) -> T) | ((...) -> Awaitable[T])`, found `def func() -> int`
assert_type(Agent(func), Agent[int])

# works
assert_type(Agent[int](func), Agent[int])
```

And mypy (and ty, because of the above issue) don't support […]:

```python
from dataclasses import dataclass

from typing_extensions import (
    Awaitable,
    Callable,
    Generic,
    TypeVar,
    assert_type,
)

T = TypeVar("T")


@dataclass
class Agent(Generic[T]):
    output_type: Callable[..., T] | Callable[..., Awaitable[T]]


async def coro() -> bool:
    return True


def func() -> int:
    return 1


# pyright, mypy, pyrefly - works
# ty - `Agent[int]` and `Agent[Unknown]` are not equivalent types + Expected `((...) -> T) | ((...) -> Awaitable[T])`, found `def func() -> int`
assert_type(Agent(func), Agent[int])

# mypy - error: Argument 1 to "Agent" has incompatible type "Callable[[], Coroutine[Any, Any, bool]]"; expected "Callable[..., Never] | Callable[..., Awaitable[Never]]" [arg-type]
coro_agent = Agent(coro)

# pyright, pyrefly - works
# mypy - error: Expression is of type "Agent[Any]", not "Agent[bool]"
# ty - `Agent[bool]` and `Agent[Unknown]` are not equivalent types
assert_type(coro_agent, Agent[bool])

# works
assert_type(Agent[bool](coro), Agent[bool])
```

The issue with […]:

```python
@dataclass
class Output(Generic[T]):
    output_type: type[T]

    def or_[S](self, output_type: type[S]) -> Output[T | S]:
        return Output(self.output_type | output_type)  # type: ignore


# or

def output_type[T](*args: type[T]) -> Output[T]:
    raise NotImplementedError
```

The issue with […]

We could also decide we're fine with all of this, because it works correctly on all type checkers when you explicitly specify the generic parameters with […].

I've filed some issues with the type checkers to see what their teams say.
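For what the builder idea above would do at runtime, here is a self-contained toy version (my own sketch, not part of this PR): instead of building a real union type, it just accumulates the registered types in a tuple while `or_` threads the type parameter. The field name `output_types` is illustrative.

```python
from dataclasses import dataclass
from typing import Generic, TypeVar

T = TypeVar("T")
S = TypeVar("S")


@dataclass
class Output(Generic[T]):
    # Runtime stand-in: keep the accumulated output types in a tuple.
    output_types: tuple[type, ...]

    def or_(self, output_type: type[S]) -> "Output[T | S]":
        # Widen the type parameter to T | S while appending the new type.
        return Output(self.output_types + (output_type,))


out = Output[int]((int,)).or_(str)
print(out.output_types)  # (<class 'int'>, <class 'str'>)
```

A type checker would infer `out` as `Output[int | str]` here without the caller ever writing the union out by hand, which is the appeal of the builder shape.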
In discussion with @dmontagu we decided to merge this as is, with […]
When would this be released?
@DouweM in the case of using output_type for agent delegation like in the router scenario (as described in the docs here and here), we'd expect it to operate like a directed graph but with a "decision point". The final […]

Is there a way to add the results of the delegated agent back into the main message history, so it's something more like […]?
```python
from dotenv import load_dotenv

from pydantic_ai import Agent, RunContext

load_dotenv()

maths_agent = Agent(
    model="google-gla:gemini-2.0-flash",
    instructions="You are a maths tutor. Given a question, you will provide a step by step solution.",
)


async def hand_off_to_maths_agent(ctx: RunContext, query: str) -> str:
    res = await maths_agent.run(query)
    ctx.messages += res.new_messages()
    return res.output


poet_agent = Agent(
    model="google-gla:gemini-2.0-flash",
    instructions="You are a poet. Given a topic, you will provide a poem.",
)


async def hand_off_to_poet_agent(ctx: RunContext, query: str) -> str:
    res = await poet_agent.run(query)
    ctx.messages += res.new_messages()
    return res.output


router_agent = Agent(
    model="google-gla:gemini-2.0-flash",
    instructions="You are a router. Given a user query, you will route it to the appropriate agent.",
    output_type=[hand_off_to_maths_agent, hand_off_to_poet_agent],
)


async def main():
    query = "Calculate 10 + 10"
    result = await router_agent.run(query)
    for message in result.all_messages():
        print(message, "\n")
    print(result.output)


if __name__ == "__main__":
    import asyncio

    asyncio.run(main())
```
@DouweM @HamzaFarhan great, thank you! After implementing this, there seems to be an error if we want to maintain state across multiple runs of the router in a case like: […]

This raises a ModelHTTPError for both Gemini and OpenAI models (I haven't tried any others). It looks like a validator maybe isn't looking ahead? The messages definitely contain a matching […]
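For context on why a provider might reject such a history: chat APIs generally require every tool call in a model response to be paired with a tool result later in the conversation. Here is a toy check of that invariant using stand-in dataclasses (the class and function names are illustrative, not pydantic-ai's real message types):

```python
from dataclasses import dataclass


@dataclass
class ToolCall:
    """Stand-in for a tool-call part in a model response."""
    call_id: str


@dataclass
class ToolReturn:
    """Stand-in for a tool-return part in a request."""
    call_id: str


def unmatched_calls(parts: list) -> set[str]:
    """Return ids of tool calls that never got a matching return."""
    calls = {p.call_id for p in parts if isinstance(p, ToolCall)}
    returns = {p.call_id for p in parts if isinstance(p, ToolReturn)}
    return calls - returns


history = [ToolCall("1"), ToolReturn("1"), ToolCall("2")]
print(unmatched_calls(history))  # {'2'} — call "2" has no return
```

If filtering or truncating a history leaves any id in that set, providers will typically refuse the request, which matches the symptom described above.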
Ooh, right.
For what it's worth, here's a hack:

```python
from dataclasses import replace

from dotenv import load_dotenv

from pydantic_ai import Agent, RunContext
from pydantic_ai.messages import ModelMessage, ModelRequest, ToolCallPart, ToolReturnPart

load_dotenv()

maths_agent = Agent(
    model="google-gla:gemini-2.0-flash",
    instructions="You are a maths tutor. Given a question, you will provide a step by step solution.",
)


async def hand_off_to_maths_agent(ctx: RunContext, query: str) -> str:
    res = await maths_agent.run(query)
    ctx.messages += res.new_messages()
    return res.output


poet_agent = Agent(
    model="google-gla:gemini-2.0-flash",
    instructions="You are a poet. Given a topic, you will provide a poem.",
)


async def hand_off_to_poet_agent(ctx: RunContext, query: str) -> str:
    res = await poet_agent.run(query)
    ctx.messages += res.new_messages()
    return res.output


router_agent = Agent(
    model="google-gla:gemini-2.0-flash",
    instructions="You are a router. Given a user query, you will route it to the appropriate agent.",
    output_type=[hand_off_to_maths_agent, hand_off_to_poet_agent],
)


def filter_tool_parts(messages: list[ModelMessage], filter_str: str) -> list[ModelMessage]:
    filtered_messages: list[ModelMessage] = []
    for message in messages:
        if isinstance(message, ModelRequest):
            filtered_parts = [
                part
                for part in message.parts
                if not (isinstance(part, ToolReturnPart) and filter_str in part.tool_name)
            ]
            if filtered_parts:
                filtered_messages.append(replace(message, parts=filtered_parts))
        else:
            filtered_parts = [
                part
                for part in message.parts
                if not (isinstance(part, ToolCallPart) and filter_str in part.tool_name)
            ]
            if filtered_parts:
                filtered_messages.append(replace(message, parts=filtered_parts))
    return filtered_messages


async def main():
    query = "Calculate 10 + 10"
    result = await router_agent.run(query)
    message_history = filter_tool_parts(result.all_messages(), "hand_off")
    sep = "\n" + "-" * 100 + "\n"
    print(sep)
    for message in message_history:
        print(message, "\n")
    print(f"{sep}{result.output}{sep}")
    query = "Write a poem about your answer"
    result = await router_agent.run(query, message_history=message_history)
    print(result.output)


if __name__ == "__main__":
    import asyncio

    asyncio.run(main())
```
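The core of the hack, dropping the hand-off tool call/return parts so the next run sees a clean history, can be exercised in isolation. Here is a toy version with stand-in part classes rather than pydantic-ai's real message types (all names below are illustrative):

```python
from dataclasses import dataclass


@dataclass
class TextPart:
    """Stand-in for a plain text part."""
    content: str


@dataclass
class ToolCallPart:
    """Stand-in for a tool-call part."""
    tool_name: str


@dataclass
class ToolReturnPart:
    """Stand-in for a tool-return part."""
    tool_name: str


def filter_parts(parts: list, filter_str: str) -> list:
    """Drop tool-related parts whose tool_name contains filter_str."""
    return [
        p
        for p in parts
        if not (isinstance(p, (ToolCallPart, ToolReturnPart)) and filter_str in p.tool_name)
    ]


parts = [
    TextPart("route it"),
    ToolCallPart("hand_off_to_maths_agent"),
    ToolReturnPart("hand_off_to_maths_agent"),
    ToolCallPart("final_result"),
]
kept = filter_parts(parts, "hand_off")
print([type(p).__name__ for p in kept])  # ['TextPart', 'ToolCallPart']
```

Note that this drops calls and returns in matched pairs, which is what keeps the filtered history valid for the provider.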
Consider reviewing the first commit by itself and then the rest as one diff!
Commit 1 brings in some output handling refactoring borrowed from #1628, to make sure we don't hard-code this against the tool-call output mode (as the original PR did). It also makes this less of a rebase hell for me :) This commit does not change any behavior.
Example:
To do:
- RunContext […]; this value should be injected, not obtained from the model.
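To make the feature's semantics concrete, here is a toy sketch (my own, not this PR's implementation) of the dispatch step: the model picks one of the registered output functions by name, and the agent calls it with the model-provided arguments, awaiting the result if the function is a coroutine. All names here are illustrative.

```python
import asyncio
import inspect


def to_poem(topic: str) -> str:
    return f"A poem about {topic}."


async def to_maths(query: str) -> str:
    return f"Step-by-step solution to {query}."


async def dispatch(output_functions, chosen_name, **kwargs):
    """Call the output function the model selected, awaiting coroutines."""
    func = {f.__name__: f for f in output_functions}[chosen_name]
    result = func(**kwargs)
    if inspect.isawaitable(result):
        result = await result
    return result


# Pretend the model chose the async maths function with these args:
out = asyncio.run(dispatch([to_poem, to_maths], "to_maths", query="10 + 10"))
print(out)  # Step-by-step solution to 10 + 10.
```

Supporting both sync and async callables behind one `output_type` parameter is exactly what drives the `Callable[..., T] | Callable[..., Awaitable[T]]` typing discussed earlier in the thread.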