|
28 | 28 | )
|
29 | 29 | from a2a.types import (
|
30 | 30 | InternalError,
|
| 31 | + InvalidParamsError, |
31 | 32 | Message,
|
32 | 33 | MessageSendConfiguration,
|
33 | 34 | MessageSendParams,
|
@@ -1137,3 +1138,136 @@ async def consume_stream():
|
1137 | 1138 |
|
1138 | 1139 | texts = [p.root.text for e in events for p in e.status.message.parts]
|
1139 | 1140 | assert texts == ['Event 0', 'Event 1', 'Event 2']
|
| 1141 | + |
| 1142 | +TERMINAL_TASK_STATES = { |
| 1143 | + TaskState.completed, |
| 1144 | + TaskState.canceled, |
| 1145 | + TaskState.failed, |
| 1146 | + TaskState.rejected, |
| 1147 | +} |
| 1148 | + |
| 1149 | +@pytest.mark.asyncio |
| 1150 | +@pytest.mark.parametrize('terminal_state', TERMINAL_TASK_STATES) |
| 1151 | +async def test_on_message_send_task_in_terminal_state(terminal_state): |
| 1152 | + """Test on_message_send when task is already in a terminal state.""" |
| 1153 | + task_id = f'terminal_task_{terminal_state.value}' |
| 1154 | + terminal_task = create_sample_task( |
| 1155 | + task_id=task_id, status_state=terminal_state |
| 1156 | + ) |
| 1157 | + |
| 1158 | + mock_task_store = AsyncMock(spec=TaskStore) |
| 1159 | + # The get method of TaskManager calls task_store.get. |
| 1160 | + # We mock TaskManager.get_task which is an async method. |
| 1161 | + # So we should patch that instead. |
| 1162 | + |
| 1163 | + request_handler = DefaultRequestHandler( |
| 1164 | + agent_executor=DummyAgentExecutor(), task_store=mock_task_store |
| 1165 | + ) |
| 1166 | + |
| 1167 | + params = MessageSendParams( |
| 1168 | + message=Message( |
| 1169 | + role=Role.user, |
| 1170 | + messageId='msg_terminal', |
| 1171 | + parts=[], |
| 1172 | + taskId=task_id, |
| 1173 | + ) |
| 1174 | + ) |
| 1175 | + |
| 1176 | + from a2a.utils.errors import ServerError |
| 1177 | + |
| 1178 | + # Patch the TaskManager's get_task method to return our terminal task |
| 1179 | + with patch( |
| 1180 | + 'a2a.server.request_handlers.default_request_handler.TaskManager.get_task', |
| 1181 | + return_value=terminal_task, |
| 1182 | + ): |
| 1183 | + with pytest.raises(ServerError) as exc_info: |
| 1184 | + await request_handler.on_message_send( |
| 1185 | + params, create_server_call_context() |
| 1186 | + ) |
| 1187 | + |
| 1188 | + assert isinstance(exc_info.value.error, InvalidParamsError) |
| 1189 | + assert exc_info.value.error.message |
| 1190 | + assert ( |
| 1191 | + f'Task {task_id} is in terminal state: {terminal_state.value}' |
| 1192 | + in exc_info.value.error.message |
| 1193 | + ) |
| 1194 | + |
| 1195 | + |
| 1196 | +@pytest.mark.asyncio |
| 1197 | +@pytest.mark.parametrize('terminal_state', TERMINAL_TASK_STATES) |
| 1198 | +async def test_on_message_send_stream_task_in_terminal_state(terminal_state): |
| 1199 | + """Test on_message_send_stream when task is already in a terminal state.""" |
| 1200 | + task_id = f'terminal_stream_task_{terminal_state.value}' |
| 1201 | + terminal_task = create_sample_task( |
| 1202 | + task_id=task_id, status_state=terminal_state |
| 1203 | + ) |
| 1204 | + |
| 1205 | + mock_task_store = AsyncMock(spec=TaskStore) |
| 1206 | + |
| 1207 | + request_handler = DefaultRequestHandler( |
| 1208 | + agent_executor=DummyAgentExecutor(), task_store=mock_task_store |
| 1209 | + ) |
| 1210 | + |
| 1211 | + params = MessageSendParams( |
| 1212 | + message=Message( |
| 1213 | + role=Role.user, |
| 1214 | + messageId='msg_terminal_stream', |
| 1215 | + parts=[], |
| 1216 | + taskId=task_id, |
| 1217 | + ) |
| 1218 | + ) |
| 1219 | + |
| 1220 | + from a2a.utils.errors import ServerError |
| 1221 | + |
| 1222 | + with patch( |
| 1223 | + 'a2a.server.request_handlers.default_request_handler.TaskManager.get_task', |
| 1224 | + return_value=terminal_task, |
| 1225 | + ): |
| 1226 | + with pytest.raises(ServerError) as exc_info: |
| 1227 | + async for _ in request_handler.on_message_send_stream( |
| 1228 | + params, create_server_call_context() |
| 1229 | + ): |
| 1230 | + pass # pragma: no cover |
| 1231 | + |
| 1232 | + assert isinstance(exc_info.value.error, InvalidParamsError) |
| 1233 | + assert exc_info.value.error.message |
| 1234 | + assert ( |
| 1235 | + f'Task {task_id} is in terminal state: {terminal_state.value}' |
| 1236 | + in exc_info.value.error.message |
| 1237 | + ) |
| 1238 | + |
| 1239 | + |
| 1240 | +@pytest.mark.asyncio |
| 1241 | +@pytest.mark.parametrize('terminal_state', TERMINAL_TASK_STATES) |
| 1242 | +async def test_on_resubscribe_to_task_in_terminal_state(terminal_state): |
| 1243 | + """Test on_resubscribe_to_task when task is in a terminal state.""" |
| 1244 | + task_id = f'resub_terminal_task_{terminal_state.value}' |
| 1245 | + terminal_task = create_sample_task( |
| 1246 | + task_id=task_id, status_state=terminal_state |
| 1247 | + ) |
| 1248 | + |
| 1249 | + mock_task_store = AsyncMock(spec=TaskStore) |
| 1250 | + mock_task_store.get.return_value = terminal_task |
| 1251 | + |
| 1252 | + request_handler = DefaultRequestHandler( |
| 1253 | + agent_executor=DummyAgentExecutor(), |
| 1254 | + task_store=mock_task_store, |
| 1255 | + queue_manager=AsyncMock(spec=QueueManager), |
| 1256 | + ) |
| 1257 | + params = TaskIdParams(id=task_id) |
| 1258 | + |
| 1259 | + from a2a.utils.errors import ServerError |
| 1260 | + |
| 1261 | + with pytest.raises(ServerError) as exc_info: |
| 1262 | + async for _ in request_handler.on_resubscribe_to_task( |
| 1263 | + params, create_server_call_context() |
| 1264 | + ): |
| 1265 | + pass # pragma: no cover |
| 1266 | + |
| 1267 | + assert isinstance(exc_info.value.error, InvalidParamsError) |
| 1268 | + assert exc_info.value.error.message |
| 1269 | + assert ( |
| 1270 | + f'Task {task_id} is in terminal state: {terminal_state.value}' |
| 1271 | + in exc_info.value.error.message |
| 1272 | + ) |
| 1273 | + mock_task_store.get.assert_awaited_once_with(task_id) |
0 commit comments