Skip to content

Commit 8adc284

Browse files
committed
Fix test to work like the reproducer
1 parent b504f21 commit 8adc284

2 files changed

Lines changed: 98 additions & 64 deletions

File tree

tests/server-common/src/test/java/io/a2a/server/apps/common/AbstractA2AServerTest.java

Lines changed: 85 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -885,34 +885,81 @@ public void testSubscribeExistingTaskSuccessWithClientConsumers() throws Excepti
885885
* Interrupted states are NOT terminal - the stream should remain open to deliver future state updates.
886886
* <p>
887887
* This test addresses issue #754: Stream was incorrectly closing immediately for INPUT_REQUIRED state.
888+
* The bug had two parts:
889+
* 1. isStreamTerminatingTask() incorrectly treated INPUT_REQUIRED as terminating
890+
* 2. Grace period logic closed queue after agent completion, even for interrupted states
888891
*/
889892
@Test
890893
@Timeout(value = 3, unit = TimeUnit.MINUTES)
891894
public void testSubscribeToTaskWithInterruptedStateKeepsStreamOpen() throws Exception {
892-
// Create task in INPUT_REQUIRED state (interrupted, not terminal)
893-
Task inputRequiredTask = Task.builder()
894-
.id("task-input-required-" + UUID.randomUUID())
895-
.contextId("session-xyz")
896-
.status(new TaskStatus(TaskState.TASK_STATE_INPUT_REQUIRED))
897-
.build();
895+
// Use a taskId with the pattern the test agent recognizes
896+
// When we send a message with a taskId to a non-existent task, it creates
897+
// a new task with that ID, and context.getTask() is still null on first invocation
898+
String taskId = "input-required-test-" + UUID.randomUUID();
898899

899-
saveTaskInTaskStore(inputRequiredTask);
900900
try {
901-
ensureQueueForTask(inputRequiredTask.id());
901+
// Create initial message with the special taskId pattern
902+
// Use non-streaming client so agent can emit INPUT_REQUIRED and return immediately
903+
// This ensures context.getTask() == null on first agent invocation
904+
Message message = Message.builder(MESSAGE)
905+
.taskId(taskId)
906+
.contextId("test-context")
907+
.parts(new TextPart("Trigger INPUT_REQUIRED"))
908+
.build();
909+
910+
// Send message with non-streaming client - agent will emit INPUT_REQUIRED and complete
911+
AtomicReference<TaskState> finalStateRef = new AtomicReference<>();
912+
AtomicReference<Throwable> sendErrorRef = new AtomicReference<>();
913+
CountDownLatch sendLatch = new CountDownLatch(1);
914+
915+
getNonStreamingClient().sendMessage(message, List.of((event, agentCard) -> {
916+
if (event instanceof TaskEvent te) {
917+
finalStateRef.set(te.getTask().status().state());
918+
sendLatch.countDown();
919+
} else if (event instanceof TaskUpdateEvent tue) {
920+
if (tue.getUpdateEvent() instanceof TaskStatusUpdateEvent statusUpdate) {
921+
finalStateRef.set(statusUpdate.status().state());
922+
}
923+
}
924+
}), error -> {
925+
if (!isStreamClosedError(error)) {
926+
sendErrorRef.set(error);
927+
}
928+
sendLatch.countDown();
929+
});
902930

903-
// Track events received through the stream
931+
assertTrue(sendLatch.await(15, TimeUnit.SECONDS), "SendMessage should complete");
932+
assertNull(sendErrorRef.get(), "SendMessage should not error");
933+
TaskState finalState = finalStateRef.get();
934+
assertNotNull(finalState, "Final state should be captured");
935+
assertEquals(TaskState.TASK_STATE_INPUT_REQUIRED, finalState,
936+
"Task should be in INPUT_REQUIRED state after agent completes");
937+
938+
// CRITICAL: At this point the agent has completed with INPUT_REQUIRED state
939+
// The grace period logic should NOT close the queue because INPUT_REQUIRED
940+
// is an interrupted state, not a terminal state
941+
942+
// Wait 2 seconds - longer than the grace period (1.5 seconds)
943+
// Before fix: queue would close after grace period
944+
// After fix: queue stays open because task is in interrupted state
945+
Thread.sleep(2000);
946+
947+
// Track events received through subscription stream
904948
CopyOnWriteArrayList<io.a2a.spec.UpdateEvent> receivedEvents = new CopyOnWriteArrayList<>();
905949
AtomicBoolean receivedInitialTask = new AtomicBoolean(false);
906950
AtomicBoolean streamClosedPrematurely = new AtomicBoolean(false);
907-
AtomicReference<Throwable> errorRef = new AtomicReference<>();
951+
AtomicReference<Throwable> subscribeErrorRef = new AtomicReference<>();
908952
CountDownLatch completionLatch = new CountDownLatch(1);
909953

910-
// Consumer to track all events
954+
// Consumer to track all events from subscription
911955
BiConsumer<ClientEvent, AgentCard> consumer = (event, agentCard) -> {
912956
if (event instanceof TaskEvent taskEvent) {
913957
if (!receivedInitialTask.get()) {
914958
receivedInitialTask.set(true);
915-
// First event should be the initial task snapshot
959+
// First event should be the initial task snapshot in INPUT_REQUIRED state
960+
assertEquals(TaskState.TASK_STATE_INPUT_REQUIRED,
961+
taskEvent.getTask().status().state(),
962+
"Initial task should be in INPUT_REQUIRED state");
916963
return;
917964
}
918965
} else if (event instanceof TaskUpdateEvent taskUpdateEvent) {
@@ -929,7 +976,7 @@ public void testSubscribeToTaskWithInterruptedStateKeepsStreamOpen() throws Exce
929976
// Error handler to detect premature stream closure
930977
Consumer<Throwable> errorHandler = error -> {
931978
if (!isStreamClosedError(error)) {
932-
errorRef.set(error);
979+
subscribeErrorRef.set(error);
933980
}
934981
// If completion latch hasn't been counted down yet, stream closed prematurely
935982
if (completionLatch.getCount() > 0) {
@@ -938,59 +985,50 @@ public void testSubscribeToTaskWithInterruptedStateKeepsStreamOpen() throws Exce
938985
completionLatch.countDown();
939986
};
940987

941-
// Subscribe to the task
988+
// Subscribe to the task - this is AFTER agent completed with INPUT_REQUIRED
942989
CountDownLatch subscriptionLatch = new CountDownLatch(1);
943990
awaitStreamingSubscription()
944991
.whenComplete((unused, throwable) -> subscriptionLatch.countDown());
945992

946-
getClient().subscribeToTask(new TaskIdParams(inputRequiredTask.id()), List.of(consumer), errorHandler);
993+
getClient().subscribeToTask(new TaskIdParams(taskId), List.of(consumer), errorHandler);
947994

948995
// Wait for subscription to be established
949996
assertTrue(subscriptionLatch.await(15, TimeUnit.SECONDS), "Subscription should be established");
950997

951-
// Wait a bit to ensure stream doesn't close prematurely
952-
Thread.sleep(500);
953-
954-
// Verify stream is still open (no premature closure)
998+
// Verify stream received initial task and is still open
999+
assertTrue(receivedInitialTask.get(), "Should receive initial task snapshot");
9551000
assertFalse(streamClosedPrematurely.get(),
9561001
"Stream should NOT close for INPUT_REQUIRED state (interrupted, not terminal)");
9571002

958-
// Send status update to WORKING (still non-terminal)
959-
enqueueEventOnServer(TaskStatusUpdateEvent.builder()
960-
.taskId(inputRequiredTask.id())
961-
.contextId(inputRequiredTask.contextId())
962-
.status(new TaskStatus(TaskState.TASK_STATE_WORKING))
963-
.build());
964-
965-
// Wait a bit and verify stream is still open
966-
Thread.sleep(500);
967-
assertFalse(streamClosedPrematurely.get(),
968-
"Stream should remain open after transitioning to WORKING");
1003+
// Send a follow-up message to provide the required input
1004+
// This will trigger the agent again, which will emit COMPLETED
1005+
Message followUpMessage = Message.builder()
1006+
.messageId("input-response-" + UUID.randomUUID())
1007+
.role(Message.Role.ROLE_USER)
1008+
.parts(new TextPart("User input"))
1009+
.taskId(taskId)
1010+
.build();
9691011

970-
// Send terminal status update to COMPLETED
971-
enqueueEventOnServer(TaskStatusUpdateEvent.builder()
972-
.taskId(inputRequiredTask.id())
973-
.contextId(inputRequiredTask.contextId())
974-
.status(new TaskStatus(TaskState.TASK_STATE_COMPLETED))
975-
.build());
1012+
getClient().sendMessage(followUpMessage, List.of(), error -> {});
9761013

977-
// Now stream should close
1014+
// Stream should now close after receiving COMPLETED event
9781015
assertTrue(completionLatch.await(30, TimeUnit.SECONDS),
9791016
"Stream should close after terminal state");
9801017

981-
// Verify we received both updates before stream closed
982-
assertEquals(2, receivedEvents.size(),
983-
"Should receive both status updates before stream closes");
984-
985-
TaskStatusUpdateEvent firstUpdate = (TaskStatusUpdateEvent) receivedEvents.get(0);
986-
assertEquals(TaskState.TASK_STATE_WORKING, firstUpdate.status().state());
1018+
// Verify we received the COMPLETED update
1019+
assertTrue(receivedEvents.size() >= 1,
1020+
"Should receive at least COMPLETED status update");
9871021

988-
TaskStatusUpdateEvent secondUpdate = (TaskStatusUpdateEvent) receivedEvents.get(1);
989-
assertEquals(TaskState.TASK_STATE_COMPLETED, secondUpdate.status().state());
1022+
// Find the COMPLETED event
1023+
boolean foundCompleted = receivedEvents.stream()
1024+
.filter(e -> e instanceof TaskStatusUpdateEvent)
1025+
.map(e -> (TaskStatusUpdateEvent) e)
1026+
.anyMatch(tue -> tue.status().state() == TaskState.TASK_STATE_COMPLETED);
1027+
assertTrue(foundCompleted, "Should receive COMPLETED status update");
9901028

991-
assertNull(errorRef.get(), "Should not have any errors");
1029+
assertNull(subscribeErrorRef.get(), "Should not have any errors");
9921030
} finally {
993-
deleteTaskInTaskStore(inputRequiredTask.id());
1031+
deleteTaskInTaskStore(taskId);
9941032
}
9951033
}
9961034

tests/server-common/src/test/java/io/a2a/server/apps/common/AgentExecutorProducer.java

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -74,27 +74,23 @@ public void execute(RequestContext context, AgentEmitter agentEmitter) throws A2
7474

7575
// Special handling for input-required test
7676
if (taskId != null && taskId.startsWith("input-required-test")) {
77-
// First call: context.getTask() == null (new task)
78-
if (context.getTask() == null) {
79-
// Go directly to INPUT_REQUIRED without intermediate WORKING state
80-
// This avoids race condition where blocking call interrupts on WORKING
81-
// before INPUT_REQUIRED is persisted to TaskStore
82-
agentEmitter.requiresInput(agentEmitter.newAgentMessage(
83-
List.of(new TextPart("Please provide additional information")),
84-
context.getMessage().metadata()));
85-
// Return immediately - queue stays open because task is in INPUT_REQUIRED state
86-
return;
87-
} else {
88-
String input = extractTextFromMessage(context.getMessage());
89-
if(! "User input".equals(input)) {
90-
throw new InvalidParamsError("We didn't get the expected input");
91-
}
92-
// Second call: context.getTask() != null (input provided)
77+
String input = extractTextFromMessage(context.getMessage());
78+
// Second call: user provided the required input - complete the task
79+
if ("User input".equals(input)) {
9380
// Go directly to COMPLETED without intermediate WORKING state
94-
// This avoids the same race condition as the first call
81+
// This avoids race condition where blocking call interrupts on WORKING
9582
agentEmitter.complete();
9683
return;
9784
}
85+
// First call: any other message - emit INPUT_REQUIRED
86+
// Go directly to INPUT_REQUIRED without intermediate WORKING state
87+
// This avoids race condition where blocking call interrupts on WORKING
88+
// before INPUT_REQUIRED is persisted to TaskStore
89+
agentEmitter.requiresInput(agentEmitter.newAgentMessage(
90+
List.of(new TextPart("Please provide additional information")),
91+
context.getMessage().metadata()));
92+
// Return immediately - queue stays open because task is in INPUT_REQUIRED state
93+
return;
9894
}
9995

10096
// Special handling for auth-required test

0 commit comments

Comments
 (0)