Skip to content

Commit f668b77

Browse files
authored
fix: Keep stream open on interrupted state changes (#756)
Fixes #754
1 parent fd33b3d commit f668b77

5 files changed

Lines changed: 261 additions & 47 deletions

File tree

server-common/src/main/java/io/a2a/server/events/EventConsumer.java

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ public class EventConsumer {
2323
private volatile boolean cancelled = false;
2424
private volatile boolean agentCompleted = false;
2525
private volatile int pollTimeoutsAfterAgentCompleted = 0;
26+
private volatile @Nullable TaskState lastSeenTaskState = null;
2627

2728
private static final String ERROR_MSG = "Agent did not return any response";
2829
private static final int NO_WAIT = -1;
@@ -89,7 +90,12 @@ public Flow.Publisher<EventQueueItem> consumeAll() {
8990
//
9091
// IMPORTANT: In replicated scenarios, remote events may arrive AFTER local agent completes!
9192
// Use grace period to allow for Kafka replication delays (can be 400-500ms)
92-
if (agentCompleted && queueSize == 0) {
93+
//
94+
// CRITICAL: Do NOT close if task is in interrupted state (INPUT_REQUIRED, AUTH_REQUIRED)
95+
// Per A2A spec, interrupted states are NOT terminal - the stream must stay open
96+
// for future state updates even after agent completes (agent will be re-invoked later).
97+
boolean isInterruptedState = lastSeenTaskState != null && lastSeenTaskState.isInterrupted();
98+
if (agentCompleted && queueSize == 0 && !isInterruptedState) {
9399
pollTimeoutsAfterAgentCompleted++;
94100
if (pollTimeoutsAfterAgentCompleted >= MAX_POLL_TIMEOUTS_AFTER_AGENT_COMPLETED) {
95101
LOGGER.debug("Agent completed with {} consecutive poll timeouts and empty queue, closing for graceful completion (queue={})",
@@ -102,6 +108,10 @@ public Flow.Publisher<EventQueueItem> consumeAll() {
102108
LOGGER.debug("Agent completed but grace period active ({}/{} timeouts), continuing to poll (queue={})",
103109
pollTimeoutsAfterAgentCompleted, MAX_POLL_TIMEOUTS_AFTER_AGENT_COMPLETED, System.identityHashCode(queue));
104110
}
111+
} else if (agentCompleted && isInterruptedState) {
112+
LOGGER.debug("Agent completed but task is in interrupted state ({}), stream must remain open (queue={})",
113+
lastSeenTaskState, System.identityHashCode(queue));
114+
pollTimeoutsAfterAgentCompleted = 0; // Reset counter
105115
} else if (agentCompleted && queueSize > 0) {
106116
LOGGER.debug("Agent completed but queue has {} pending events, resetting timeout counter and continuing to poll (queue={})",
107117
queueSize, System.identityHashCode(queue));
@@ -115,6 +125,13 @@ public Flow.Publisher<EventQueueItem> consumeAll() {
115125
LOGGER.debug("EventConsumer received event: {} (queue={})",
116126
event.getClass().getSimpleName(), System.identityHashCode(queue));
117127

128+
// Track the latest task state for grace period logic
129+
if (event instanceof Task task) {
130+
lastSeenTaskState = task.status().state();
131+
} else if (event instanceof TaskStatusUpdateEvent tue) {
132+
lastSeenTaskState = tue.status().state();
133+
}
134+
118135
// Defensive logging for error handling
119136
if (event instanceof Throwable thr) {
120137
LOGGER.debug("EventConsumer detected Throwable event: {} - triggering tube.fail()",
@@ -195,17 +212,21 @@ public Flow.Publisher<EventQueueItem> consumeAll() {
195212

196213
/**
197214
* Determines whether a task is in a state that terminates the stream.
198-
* <p>A task is terminating if:</p>
199-
* <ul>
200-
* <li>Its state is final (e.g., completed, canceled, rejected, failed), OR</li>
201-
* <li>Its state is interrupted (e.g., input-required)</li>
202-
* </ul>
215+
* <p>
216+
* Per A2A Protocol Specification 3.1.6 (SubscribeToTask):
217+
* "The stream MUST terminate when the task reaches a terminal state
218+
* (completed, failed, canceled, or rejected)."
219+
* <p>
220+
* Interrupted states (INPUT_REQUIRED, AUTH_REQUIRED) are NOT terminal.
221+
* The stream should remain open to deliver future state updates when
222+
* the task resumes after receiving the required input or authorization.
223+
*
203224
* @param task the task to check
204-
* @return true if the task has a final state or an interrupted state, false otherwise
225+
* @return true if the task has a terminal/final state, false otherwise
205226
*/
206227
private boolean isStreamTerminatingTask(Task task) {
207228
TaskState state = task.status().state();
208-
return state.isFinal() || state == TaskState.TASK_STATE_INPUT_REQUIRED;
229+
return state.isFinal();
209230
}
210231

211232
public EnhancedRunnable.DoneCallback createAgentRunnableDoneCallback() {

server-common/src/test/java/io/a2a/server/events/EventConsumerTest.java

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -238,26 +238,32 @@ public void testConsumeMessageEvents() throws Exception {
238238

239239
@Test
240240
public void testConsumeTaskInputRequired() {
241+
// Per A2A Protocol Specification 3.1.6 (SubscribeToTask):
242+
// "The stream MUST terminate when the task reaches a terminal state
243+
// (completed, failed, canceled, or rejected)."
244+
//
245+
// INPUT_REQUIRED is an interrupted state, NOT a terminal state.
246+
// The stream should remain open to deliver future state updates.
241247
Task task = Task.builder()
242248
.id(TASK_ID)
243249
.contextId("session-xyz")
244250
.status(new TaskStatus(TaskState.TASK_STATE_INPUT_REQUIRED))
245251
.build();
246-
List<Event> events = List.of(
247-
task,
248-
TaskArtifactUpdateEvent.builder()
252+
TaskArtifactUpdateEvent artifactEvent = TaskArtifactUpdateEvent.builder()
249253
.taskId(TASK_ID)
250254
.contextId("session-xyz")
251255
.artifact(Artifact.builder()
252256
.artifactId("11")
253257
.parts(new TextPart("text"))
254258
.build())
255-
.build(),
256-
TaskStatusUpdateEvent.builder()
259+
.build();
260+
TaskStatusUpdateEvent completedEvent = TaskStatusUpdateEvent.builder()
257261
.taskId(TASK_ID)
258262
.contextId("session-xyz")
259263
.status(new TaskStatus(TaskState.TASK_STATE_COMPLETED))
260-
.build());
264+
.build();
265+
List<Event> events = List.of(task, artifactEvent, completedEvent);
266+
261267
for (Event event : events) {
262268
eventQueue.enqueueEvent(event);
263269
}
@@ -269,9 +275,12 @@ public void testConsumeTaskInputRequired() {
269275
publisher.subscribe(getSubscriber(receivedEvents, error));
270276

271277
assertNull(error.get());
272-
// The stream is closed after the input_required task
273-
assertEquals(1, receivedEvents.size());
278+
// Stream should remain open for INPUT_REQUIRED and deliver all events
279+
// until the terminal COMPLETED state is reached
280+
assertEquals(3, receivedEvents.size());
274281
assertSame(task, receivedEvents.get(0));
282+
assertSame(artifactEvent, receivedEvents.get(1));
283+
assertSame(completedEvent, receivedEvents.get(2));
275284
}
276285

277286
private Flow.Subscriber<EventQueueItem> getSubscriber(List<Event> receivedEvents, AtomicReference<Throwable> error) {

spec/src/main/java/io/a2a/spec/TaskState.java

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,17 @@
66
* TaskState represents the discrete states a task can be in during its execution lifecycle.
77
* States are categorized as either transitional (non-final) or terminal (final), where
88
* terminal states indicate that the task has reached its end state and will not transition further.
9+
* A subset of transitional states is also marked as interrupted, indicating that task execution
10+
* has paused and requires external action before proceeding.
911
* <p>
10-
* <b>Transitional States:</b>
12+
* <b>Active Transitional States:</b>
1113
* <ul>
1214
* <li><b>TASK_STATE_SUBMITTED:</b> Task has been received by the agent and is queued for processing</li>
1315
* <li><b>TASK_STATE_WORKING:</b> Agent is actively processing the task and may produce incremental results</li>
16+
* </ul>
17+
* <p>
18+
* <b>Interrupted States:</b>
19+
* <ul>
1420
* <li><b>TASK_STATE_INPUT_REQUIRED:</b> Agent needs additional input from the user to continue</li>
1521
* <li><b>TASK_STATE_AUTH_REQUIRED:</b> Agent requires authentication or authorization before proceeding</li>
1622
* </ul>
@@ -25,44 +31,47 @@
2531
* </ul>
2632
* <p>
2733
* The {@link #isFinal()} method can be used to determine if a state is terminal, which is
28-
* important for event queue management and client polling logic.
34+
* important for event queue management and client polling logic. The {@link #isInterrupted()}
35+
* method identifies states where the task is paused awaiting external action.
2936
*
3037
* @see TaskStatus
3138
* @see Task
3239
* @see <a href="https://a2a-protocol.org/latest/">A2A Protocol Specification</a>
3340
*/
3441
public enum TaskState {
3542
/** Task has been received and is queued for processing (transitional state). */
36-
TASK_STATE_SUBMITTED(false),
43+
TASK_STATE_SUBMITTED(false, false),
3744

3845
/** Agent is actively processing the task (transitional state). */
39-
TASK_STATE_WORKING(false),
46+
TASK_STATE_WORKING(false, false),
4047

41-
/** Agent requires additional input from the user to continue (transitional state). */
42-
TASK_STATE_INPUT_REQUIRED(false),
48+
/** Agent requires additional input from the user to continue (interrupted state). */
49+
TASK_STATE_INPUT_REQUIRED(false, true),
4350

44-
/** Agent requires authentication or authorization to proceed (transitional state). */
45-
TASK_STATE_AUTH_REQUIRED(false),
51+
/** Agent requires authentication or authorization to proceed (interrupted state). */
52+
TASK_STATE_AUTH_REQUIRED(false, true),
4653

4754
/** Task completed successfully (terminal state). */
48-
TASK_STATE_COMPLETED(true),
55+
TASK_STATE_COMPLETED(true, false),
4956

5057
/** Task was canceled by user or system (terminal state). */
51-
TASK_STATE_CANCELED(true),
58+
TASK_STATE_CANCELED(true, false),
5259

5360
/** Task failed due to an error (terminal state). */
54-
TASK_STATE_FAILED(true),
61+
TASK_STATE_FAILED(true, false),
5562

5663
/** Task was rejected by the agent (terminal state). */
57-
TASK_STATE_REJECTED(true),
64+
TASK_STATE_REJECTED(true, false),
5865

5966
/** Task state is unknown or cannot be determined (terminal state). */
60-
UNRECOGNIZED(true);
67+
UNRECOGNIZED(true, false);
6168

6269
private final boolean isFinal;
70+
private final boolean isInterrupted;
6371

64-
TaskState(boolean isFinal) {
72+
TaskState(boolean isFinal, boolean isInterrupted) {
6573
this.isFinal = isFinal;
74+
this.isInterrupted = isInterrupted;
6675
}
6776

6877
/**
@@ -71,10 +80,32 @@ public enum TaskState {
7180
* Terminal states indicate that the task has completed its lifecycle and will
7281
* not transition to any other state. This is used by the event queue system
7382
* to determine when to close queues and by clients to know when to stop polling.
83+
* <p>
84+
* Terminal states: COMPLETED, FAILED, CANCELED, REJECTED, UNRECOGNIZED.
7485
*
7586
* @return {@code true} if this is a terminal state, {@code false} otherwise.
7687
*/
7788
public boolean isFinal(){
7889
return isFinal;
7990
}
91+
92+
/**
93+
* Determines whether this state is an interrupted state.
94+
* <p>
95+
* Interrupted states indicate that the task execution has paused and requires
96+
* external action before proceeding. The task may resume after the required
97+
* action is provided. Interrupted states are NOT terminal - streams should
98+
* remain open to deliver state updates.
99+
* <p>
100+
* Interrupted states: INPUT_REQUIRED, AUTH_REQUIRED.
101+
* <p>
102+
* Per A2A Protocol Specification 4.1.3 (TaskState):
103+
* "TASK_STATE_INPUT_REQUIRED: This is an interrupted state."
104+
* "TASK_STATE_AUTH_REQUIRED: This is an interrupted state."
105+
*
106+
* @return {@code true} if this is an interrupted state, {@code false} otherwise.
107+
*/
108+
public boolean isInterrupted() {
109+
return isInterrupted;
110+
}
80111
}

0 commit comments

Comments
 (0)