Skip to content

Commit a0bf13b

Browse files
authored
feat: raise error for tasks in terminal states (#215)
- make tasks to be not restartable after they reach a terminal state. - Based on discussion in a2aproject/A2A#723
1 parent b88ca85 commit a0bf13b

File tree

2 files changed

+163
-1
lines changed

2 files changed

+163
-1
lines changed

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from a2a.types import (
3030
GetTaskPushNotificationConfigParams,
3131
InternalError,
32+
InvalidParamsError,
3233
Message,
3334
MessageSendConfiguration,
3435
MessageSendParams,
@@ -38,6 +39,7 @@
3839
TaskNotFoundError,
3940
TaskPushNotificationConfig,
4041
TaskQueryParams,
42+
TaskState,
4143
UnsupportedOperationError,
4244
)
4345
from a2a.utils.errors import ServerError
@@ -46,6 +48,12 @@
4648

4749
logger = logging.getLogger(__name__)
4850

51+
TERMINAL_TASK_STATES = {
52+
TaskState.completed,
53+
TaskState.canceled,
54+
TaskState.failed,
55+
TaskState.rejected,
56+
}
4957

5058
@trace_class(kind=SpanKind.SERVER)
5159
class DefaultRequestHandler(RequestHandler):
@@ -178,6 +186,13 @@ async def on_message_send(
178186
)
179187
task: Task | None = await task_manager.get_task()
180188
if task:
189+
if task.status.state in TERMINAL_TASK_STATES:
190+
raise ServerError(
191+
error=InvalidParamsError(
192+
message=f'Task {task.id} is in terminal state: {task.status.state}'
193+
)
194+
)
195+
181196
task = task_manager.update_with_message(params.message, task)
182197
if self.should_add_push_info(params):
183198
assert isinstance(self._push_notifier, PushNotifier)
@@ -264,8 +279,14 @@ async def on_message_send_stream(
264279
task: Task | None = await task_manager.get_task()
265280

266281
if task:
267-
task = task_manager.update_with_message(params.message, task)
282+
if task.status.state in TERMINAL_TASK_STATES:
283+
raise ServerError(
284+
error=InvalidParamsError(
285+
message=f'Task {task.id} is in terminal state: {task.status.state}'
286+
)
287+
)
268288

289+
task = task_manager.update_with_message(params.message, task)
269290
if self.should_add_push_info(params):
270291
assert isinstance(self._push_notifier, PushNotifier)
271292
assert isinstance(
@@ -413,6 +434,13 @@ async def on_resubscribe_to_task(
413434
if not task:
414435
raise ServerError(error=TaskNotFoundError())
415436

437+
if task.status.state in TERMINAL_TASK_STATES:
438+
raise ServerError(
439+
error=InvalidParamsError(
440+
message=f'Task {task.id} is in terminal state: {task.status.state}'
441+
)
442+
)
443+
416444
task_manager = TaskManager(
417445
task_id=task.id,
418446
context_id=task.contextId,

tests/server/request_handlers/test_default_request_handler.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
)
2929
from a2a.types import (
3030
InternalError,
31+
InvalidParamsError,
3132
Message,
3233
MessageSendConfiguration,
3334
MessageSendParams,
@@ -1137,3 +1138,136 @@ async def consume_stream():
11371138

11381139
texts = [p.root.text for e in events for p in e.status.message.parts]
11391140
assert texts == ['Event 0', 'Event 1', 'Event 2']
1141+
1142+
TERMINAL_TASK_STATES = {
1143+
TaskState.completed,
1144+
TaskState.canceled,
1145+
TaskState.failed,
1146+
TaskState.rejected,
1147+
}
1148+
1149+
@pytest.mark.asyncio
1150+
@pytest.mark.parametrize('terminal_state', TERMINAL_TASK_STATES)
1151+
async def test_on_message_send_task_in_terminal_state(terminal_state):
1152+
"""Test on_message_send when task is already in a terminal state."""
1153+
task_id = f'terminal_task_{terminal_state.value}'
1154+
terminal_task = create_sample_task(
1155+
task_id=task_id, status_state=terminal_state
1156+
)
1157+
1158+
mock_task_store = AsyncMock(spec=TaskStore)
1159+
# The get method of TaskManager calls task_store.get.
1160+
# We mock TaskManager.get_task which is an async method.
1161+
# So we should patch that instead.
1162+
1163+
request_handler = DefaultRequestHandler(
1164+
agent_executor=DummyAgentExecutor(), task_store=mock_task_store
1165+
)
1166+
1167+
params = MessageSendParams(
1168+
message=Message(
1169+
role=Role.user,
1170+
messageId='msg_terminal',
1171+
parts=[],
1172+
taskId=task_id,
1173+
)
1174+
)
1175+
1176+
from a2a.utils.errors import ServerError
1177+
1178+
# Patch the TaskManager's get_task method to return our terminal task
1179+
with patch(
1180+
'a2a.server.request_handlers.default_request_handler.TaskManager.get_task',
1181+
return_value=terminal_task,
1182+
):
1183+
with pytest.raises(ServerError) as exc_info:
1184+
await request_handler.on_message_send(
1185+
params, create_server_call_context()
1186+
)
1187+
1188+
assert isinstance(exc_info.value.error, InvalidParamsError)
1189+
assert exc_info.value.error.message
1190+
assert (
1191+
f'Task {task_id} is in terminal state: {terminal_state.value}'
1192+
in exc_info.value.error.message
1193+
)
1194+
1195+
1196+
@pytest.mark.asyncio
1197+
@pytest.mark.parametrize('terminal_state', TERMINAL_TASK_STATES)
1198+
async def test_on_message_send_stream_task_in_terminal_state(terminal_state):
1199+
"""Test on_message_send_stream when task is already in a terminal state."""
1200+
task_id = f'terminal_stream_task_{terminal_state.value}'
1201+
terminal_task = create_sample_task(
1202+
task_id=task_id, status_state=terminal_state
1203+
)
1204+
1205+
mock_task_store = AsyncMock(spec=TaskStore)
1206+
1207+
request_handler = DefaultRequestHandler(
1208+
agent_executor=DummyAgentExecutor(), task_store=mock_task_store
1209+
)
1210+
1211+
params = MessageSendParams(
1212+
message=Message(
1213+
role=Role.user,
1214+
messageId='msg_terminal_stream',
1215+
parts=[],
1216+
taskId=task_id,
1217+
)
1218+
)
1219+
1220+
from a2a.utils.errors import ServerError
1221+
1222+
with patch(
1223+
'a2a.server.request_handlers.default_request_handler.TaskManager.get_task',
1224+
return_value=terminal_task,
1225+
):
1226+
with pytest.raises(ServerError) as exc_info:
1227+
async for _ in request_handler.on_message_send_stream(
1228+
params, create_server_call_context()
1229+
):
1230+
pass # pragma: no cover
1231+
1232+
assert isinstance(exc_info.value.error, InvalidParamsError)
1233+
assert exc_info.value.error.message
1234+
assert (
1235+
f'Task {task_id} is in terminal state: {terminal_state.value}'
1236+
in exc_info.value.error.message
1237+
)
1238+
1239+
1240+
@pytest.mark.asyncio
1241+
@pytest.mark.parametrize('terminal_state', TERMINAL_TASK_STATES)
1242+
async def test_on_resubscribe_to_task_in_terminal_state(terminal_state):
1243+
"""Test on_resubscribe_to_task when task is in a terminal state."""
1244+
task_id = f'resub_terminal_task_{terminal_state.value}'
1245+
terminal_task = create_sample_task(
1246+
task_id=task_id, status_state=terminal_state
1247+
)
1248+
1249+
mock_task_store = AsyncMock(spec=TaskStore)
1250+
mock_task_store.get.return_value = terminal_task
1251+
1252+
request_handler = DefaultRequestHandler(
1253+
agent_executor=DummyAgentExecutor(),
1254+
task_store=mock_task_store,
1255+
queue_manager=AsyncMock(spec=QueueManager),
1256+
)
1257+
params = TaskIdParams(id=task_id)
1258+
1259+
from a2a.utils.errors import ServerError
1260+
1261+
with pytest.raises(ServerError) as exc_info:
1262+
async for _ in request_handler.on_resubscribe_to_task(
1263+
params, create_server_call_context()
1264+
):
1265+
pass # pragma: no cover
1266+
1267+
assert isinstance(exc_info.value.error, InvalidParamsError)
1268+
assert exc_info.value.error.message
1269+
assert (
1270+
f'Task {task_id} is in terminal state: {terminal_state.value}'
1271+
in exc_info.value.error.message
1272+
)
1273+
mock_task_store.get.assert_awaited_once_with(task_id)

0 commit comments

Comments
 (0)