Skip to content

Commit f20bfe7

Browse files
authored
feat: Add request context builder with referenceTasks (#56)
1 parent 429c422 commit f20bfe7

File tree

7 files changed

+109
-16
lines changed

7 files changed

+109
-16
lines changed

src/a2a/server/agent_execution/__init__.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,17 @@
22

33
from a2a.server.agent_execution.agent_executor import AgentExecutor
44
from a2a.server.agent_execution.context import RequestContext
5+
from a2a.server.agent_execution.request_context_builder import (
6+
RequestContextBuilder,
7+
)
8+
from a2a.server.agent_execution.simple_request_context_builder import (
9+
SimpleRequestContextBuilder,
10+
)
511

612

7-
__all__ = ['AgentExecutor', 'RequestContext']
13+
__all__ = [
14+
'AgentExecutor',
15+
'RequestContext',
16+
'RequestContextBuilder',
17+
'SimpleRequestContextBuilder',
18+
]
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from abc import ABC, abstractmethod
2+
3+
from a2a.server.agent_execution import RequestContext
4+
from a2a.types import MessageSendParams, Task
5+
6+
7+
class RequestContextBuilder(ABC):
8+
"""Builds request context to be supplied to agent executor"""
9+
10+
@abstractmethod
11+
async def build(
12+
self,
13+
params: MessageSendParams | None = None,
14+
task_id: str | None = None,
15+
context_id: str | None = None,
16+
task: Task | None = None,
17+
) -> RequestContext:
18+
pass
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import asyncio
2+
3+
from a2a.server.agent_execution import RequestContext, RequestContextBuilder
4+
from a2a.server.tasks import TaskStore
5+
from a2a.types import MessageSendParams, Task
6+
7+
8+
class SimpleRequestContextBuilder(RequestContextBuilder):
9+
"""Builds request context and populates referred tasks"""
10+
11+
def __init__(
12+
self,
13+
should_populate_referred_tasks: bool = False,
14+
task_store: TaskStore | None = None,
15+
) -> None:
16+
self._task_store = task_store
17+
self._should_populate_referred_tasks = should_populate_referred_tasks
18+
19+
async def build(
20+
self,
21+
params: MessageSendParams | None = None,
22+
task_id: str | None = None,
23+
context_id: str | None = None,
24+
task: Task | None = None,
25+
) -> RequestContext:
26+
related_tasks: list[Task] | None = None
27+
28+
if (
29+
self._task_store
30+
and self._should_populate_referred_tasks
31+
and params
32+
and params.message.referenceTaskIds
33+
):
34+
tasks = await asyncio.gather(
35+
*[
36+
self._task_store.get(task_id)
37+
for task_id in params.message.referenceTaskIds
38+
]
39+
)
40+
related_tasks = [x for x in tasks if x is not None]
41+
42+
return RequestContext(
43+
request=params,
44+
task_id=task_id,
45+
context_id=context_id,
46+
task=task,
47+
related_tasks=related_tasks,
48+
)

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,12 @@
44
from collections.abc import AsyncGenerator
55
from typing import cast
66

7-
from a2a.server.agent_execution import AgentExecutor, RequestContext
7+
from a2a.server.agent_execution import (
8+
AgentExecutor,
9+
RequestContext,
10+
RequestContextBuilder,
11+
SimpleRequestContextBuilder,
12+
)
813
from a2a.server.events import (
914
Event,
1015
EventConsumer,
@@ -57,6 +62,7 @@ def __init__(
5762
task_store: TaskStore,
5863
queue_manager: QueueManager | None = None,
5964
push_notifier: PushNotifier | None = None,
65+
request_context_builder: RequestContextBuilder | None = None,
6066
) -> None:
6167
"""Initializes the DefaultRequestHandler.
6268
@@ -70,6 +76,12 @@ def __init__(
7076
self.task_store = task_store
7177
self._queue_manager = queue_manager or InMemoryQueueManager()
7278
self._push_notifier = push_notifier
79+
self._request_context_builder = (
80+
request_context_builder
81+
or SimpleRequestContextBuilder(
82+
should_populate_referred_tasks=False, task_store=self.task_store
83+
)
84+
)
7385
# TODO: Likely want an interface for managing this, like AgentExecutionManager.
7486
self._running_agents = {}
7587
self._running_agents_lock = asyncio.Lock()
@@ -167,12 +179,13 @@ async def on_message_send(
167179
await self._push_notifier.set_info(
168180
task.id, params.configuration.pushNotificationConfig
169181
)
170-
request_context = RequestContext(
171-
params,
172-
task.id if task else None,
173-
task.contextId if task else None,
174-
task,
182+
request_context = await self._request_context_builder.build(
183+
params=params,
184+
task_id=task.id if task else None,
185+
context_id=params.message.contextId,
186+
task=task,
175187
)
188+
176189
task_id = cast(str, request_context.task_id)
177190
# Always assign a task ID. We may not actually upgrade to a task, but
178191
# dictating the task ID at this layer is useful for tracking running
@@ -244,12 +257,13 @@ async def on_message_send_stream(
244257
else:
245258
queue = EventQueue()
246259
result_aggregator = ResultAggregator(task_manager)
247-
request_context = RequestContext(
248-
params,
249-
task.id if task else None,
250-
task.contextId if task else None,
251-
task,
260+
request_context = await self._request_context_builder.build(
261+
params=params,
262+
task_id=task.id if task else None,
263+
context_id=params.message.contextId,
264+
task=task,
252265
)
266+
253267
task_id = cast(str, request_context.task_id)
254268
queue = await self._queue_manager.create_or_tap(task_id)
255269
producer_task = asyncio.create_task(

src/a2a/server/tasks/task_updater.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,6 @@ def start_work(self, message: Message | None = None):
114114
def new_agent_message(
115115
self,
116116
parts: list[Part],
117-
final: bool | None = None,
118117
metadata: dict[str, Any] | None = None,
119118
) -> Message:
120119
"""Creates a new message object sent by the agent for this task/context.
@@ -136,6 +135,5 @@ def new_agent_message(
136135
contextId=self.context_id,
137136
messageId=str(uuid.uuid4()),
138137
metadata=metadata,
139-
final=final,
140138
parts=parts,
141139
)

src/a2a/types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -613,6 +613,10 @@ class Message(BaseModel):
613613
"""Extension metadata."""
614614
parts: list[Part]
615615
"""Message content."""
616+
referenceTaskIds: list[str] | None = None
617+
"""
618+
list of tasks referenced as context by this message.
619+
"""
616620
role: Role
617621
"""Message sender's role."""
618622
taskId: str | None = None

tests/server/tasks/test_task_updater.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ def test_new_agent_message(self, task_updater, sample_parts):
212212
assert message.parts == sample_parts
213213
assert message.metadata is None
214214

215-
def test_new_agent_message_with_metadata_and_final(
215+
def test_new_agent_message_with_metadata(
216216
self, task_updater, sample_parts
217217
):
218218
"""Test creating a new agent message with metadata and final=True."""
@@ -223,7 +223,7 @@ def test_new_agent_message_with_metadata_and_final(
223223
return_value=uuid.UUID('12345678-1234-5678-1234-567812345678'),
224224
):
225225
message = task_updater.new_agent_message(
226-
parts=sample_parts, final=True, metadata=metadata
226+
parts=sample_parts, metadata=metadata
227227
)
228228

229229
assert message.role == Role.agent

0 commit comments

Comments
 (0)