Aggregate llm streaming response into non partial result #109

baptmont merged 9 commits into google:main from
Conversation
I'm applying the wrapper when exporting the Gemini model instead of when it is added to the agent. Some models may use different logic, and this way we can apply the wrapper only to the models that need it.
llm/gemini/gemini_test.go
Outdated
	t.Errorf("Model.GenerateStream() error = %v, wantErr %v", err, tt.wantErr)
	return
}
if diff := cmp.Diff(tt.want, gotNonPartial); diff != "" {
Curious about the checks in this test: why do we call readResponsePartial then readResponseNonPartial and expect both results to equal tt.want?
Right now there's one simple test case returning one word, so this works.
Will the test break if another test case is added which returns e.g. 2 responses: partial + final?
Since readResponsePartial concatenates the text in the response parts, the result should be the same.
The current test returns 2 responses: the first containing "Paris" and the second containing the "\n".
The existing test was using the test data in TestModel_GenerateStream_ok.httprr; this test data represents an example returning responses via SSE. It returned the following 2 SSE events:
data: {"candidates": [{"content": {"parts": [{"text": "Paris"}],"role": "model"}}],"usageMetadata": {"promptTokenCount": 11,"totalTokenCount": 11,"promptTokensDetails": [{"modality": "TEXT","tokenCount": 11}]},"modelVersion": "gemini-2.0-flash","responseId": "wzCjaPa4As7shMIP2Mei0AI"}
data: {"candidates": [{"content": {"parts": [{"text": "\n"}],"role": "model"},"finishReason": "STOP"}],"usageMetadata": {"promptTokenCount": 10,"candidatesTokenCount": 2,"totalTokenCount": 12,"promptTokensDetails": [{"modality": "TEXT","tokenCount": 10}],"candidatesTokensDetails": [{"modality": "TEXT","tokenCount": 2}]},"modelVersion": "gemini-2.0-flash","responseId": "wzCjaPa4As7shMIP2Mei0AI"}
The test was using the func readResponse(s iter.Seq2[*llm.Response, error]) (string, error) to convert that stream into a single string "Paris\n".
After adding the aggregate logic, a new event appears in the stream making the result "Paris\nParis\n".
I replaced readResponse with readResponsePartial and readResponseNonPartial.
In this way the concatenation of all partial events will be the original "Paris\n", and the concatenation of the non-partial events (the aggregated event) will also be "Paris\n".
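The invariant described above — the concatenation of all partial events equals the single aggregated non-partial event — can be sketched as follows. Part, Response, and aggregate here are simplified stand-ins I'm assuming for illustration, not the real genai/llm types:

```go
package main

import (
	"fmt"
	"strings"
)

// Part and Response are simplified stand-ins (assumptions) for the
// genai.Part and llm.Response types used in the real test.
type Part struct{ Text string }

type Response struct {
	Parts   []Part
	Partial bool
}

// aggregate mimics the streaming aggregator: it passes through every
// partial event while concatenating its text, then appends one final
// non-partial event carrying the combined text.
func aggregate(partials []Response) []Response {
	var b strings.Builder
	out := make([]Response, 0, len(partials)+1)
	for _, r := range partials {
		for _, p := range r.Parts {
			b.WriteString(p.Text)
		}
		out = append(out, r)
	}
	out = append(out, Response{Parts: []Part{{Text: b.String()}}, Partial: false})
	return out
}

func main() {
	stream := []Response{
		{Parts: []Part{{Text: "Paris"}}, Partial: true},
		{Parts: []Part{{Text: "\n"}}, Partial: true},
	}
	var partialText, nonPartialText string
	for _, r := range aggregate(stream) {
		for _, p := range r.Parts {
			if r.Partial {
				partialText += p.Text
			} else {
				nonPartialText += p.Text
			}
		}
	}
	fmt.Printf("partial=%q nonPartial=%q\n", partialText, nonPartialText)
}
```

Both concatenations come out as "Paris\n", which is what the swapped test helpers check.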
Makes sense, thanks!
Could I ask, if possible, to add a short comment explaining readResponsePartial and readResponseNonPartial?
Without context it may be tricky for other readers to understand at first glance.
llm/gemini/gemini.go
Outdated
GroundingMetadata: candidate.GroundingMetadata,
Partial: !complete,
UsageMetadata: resp.UsageMetadata,
Partial: true,
I have mixed feelings about always marking the Response as partial (and relying on the wrapper to aggregate), even if it's one full response only.
WDYT if the stream aggregator process has the following prototype (types.GenerateContentResponse -> llm.Response), same as python? So it's not an external wrapper, but rather a processor here within GenerateStream?
This would also make it easier to reason about and keep up this logic in sync across adks.
I changed the implementation to use the same logic as Python, removing the wrapper model and instead using the streamingResponseAggregator directly.
Let me know WDYT of this approach. It requires a bit of logic duplication, since every model needs to call the close method at the end of processing the stream, but it is more flexible.
Partial should have been set to false; when processing, the aggregator was already modifying the value.
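A minimal sketch of the processor approach discussed here: each raw chunk is converted in place and a close call emits the aggregate at the end. The type and method names (genContentResponse, llmResponse, ProcessResponse, close) are assumptions modeled on this thread, not the actual ADK Go API:

```go
package main

import (
	"fmt"
	"strings"
)

// Simplified stand-ins (assumptions) for the genai / llm types.
type genContentResponse struct {
	Text   string
	Finish bool
}

type llmResponse struct {
	Text    string
	Partial bool
}

// streamingResponseAggregator mirrors the processor approach: each raw
// model chunk is converted to an llmResponse in place, and close() emits
// the aggregated non-partial event once the stream is exhausted.
type streamingResponseAggregator struct {
	text strings.Builder
}

// ProcessResponse converts one raw chunk into a partial llmResponse
// while accumulating its text for the final aggregate.
func (s *streamingResponseAggregator) ProcessResponse(r genContentResponse) llmResponse {
	s.text.WriteString(r.Text)
	return llmResponse{Text: r.Text, Partial: true}
}

// close builds the final non-partial event; every model's GenerateStream
// needs to call it at the end of processing the stream.
func (s *streamingResponseAggregator) close() llmResponse {
	out := llmResponse{Text: s.text.String(), Partial: false}
	s.text.Reset()
	return out
}

func main() {
	agg := &streamingResponseAggregator{}
	for _, chunk := range []genContentResponse{{Text: "Paris"}, {Text: "\n", Finish: true}} {
		fmt.Printf("partial: %q\n", agg.ProcessResponse(chunk).Text)
	}
	final := agg.close()
	fmt.Printf("final: %q partial=%v\n", final.Text, final.Partial)
}
```

The duplication mentioned above is the per-model close call; in exchange, the aggregator stays a plain processor inside GenerateStream rather than an external wrapper model.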
	return nil
}

func (s *streamingResponseAggregator) Clear() {
Do we need Clear() logic?
Yes, but Close is not a good name for the method since it is also used to generate intermediate events.
I'll change that and make clear private.
response := &llm.Response{
	Content:      &genai.Content{Parts: parts, Role: s.role},
	ErrorCode:    s.response.ErrorCode,
	ErrorMessage: s.response.ErrorMessage,
Curious why ErrorMessage is not handled like in Python?
error_message=None if candidate.finish_reason == types.FinishReason.STOP else candidate.finish_message,
Same for ErrorCode (it's an int right now, we will update it to string soon). If it makes sense to update to string in this PR, would be nice, otherwise a TODO is good.
The Python implementation stores a GenerateContentResponse and uses it for those fields when creating the aggregate event.
We were just storing the next event and copying its ErrorCode and ErrorMessage.
However, we are also missing its logic for creating the LlmResponse: https://github.com/google/adk-python/blob/83fd0457188decdabeae58b4e8be25daa89f2943/src/google/adk/models/llm_response.py#L136.
This is also missing in the llm model Generate method, so we never set the error message fields.
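The Python rule referenced above — error fields set only when the candidate did not finish with STOP — could be sketched in Go like this. The types here are simplified stand-ins, not the real genai ones, and withErrorFields is a hypothetical helper:

```go
package main

import "fmt"

// Simplified stand-ins (assumptions) for the genai candidate fields.
type candidate struct {
	FinishReason  string
	FinishMessage string
}

type llmResponse struct {
	Text         string
	ErrorCode    string
	ErrorMessage string
}

// withErrorFields mirrors the Python rule: error fields are populated
// only when the finish reason is something other than STOP.
func withErrorFields(c candidate, text string) llmResponse {
	r := llmResponse{Text: text}
	if c.FinishReason != "STOP" {
		r.ErrorCode = c.FinishReason
		r.ErrorMessage = c.FinishMessage
	}
	return r
}

func main() {
	ok := withErrorFields(candidate{FinishReason: "STOP"}, "Paris\n")
	bad := withErrorFields(candidate{FinishReason: "MAX_TOKENS", FinishMessage: "truncated"}, "Par")
	fmt.Printf("ok err=%q bad err=%q\n", ok.ErrorMessage, bad.ErrorMessage)
}
```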
	return s.Close()
}

return nil
Is it intentional that there's a difference with the Python logic?
If llmResponse.Content.Parts[0].InlineData == nil, Python yields one extra LLMResponse http://google3/third_party/py/google/adk/utils/streaming_utils.py;l=76;rcl=800521404
in addition to
http://google3/third_party/py/google/adk/utils/streaming_utils.py;l=82;rcl=800521404
The logic here is just split into ProcessResponse and aggregateResponse.
The second yield still happens in ProcessResponse if aggregateResponse doesn't return nil.
llm/gemini/gemini_test.go
Outdated
func CreateResponse(res *genai.GenerateContentResponse) *Response {
	usageMetadata := res.UsageMetadata
	if len(res.Candidates) > 0 && res.Candidates[0] != nil {
nit: let's reduce nesting? We have if condition -> return
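A sketch of what the guard-clause version might look like. The types below are simplified stand-ins (the real function works on genai.GenerateContentResponse), so treat the names as assumptions:

```go
package main

import "fmt"

// Minimal stand-ins (assumptions) for the genai types involved.
type usage struct{ TotalTokenCount int }

type candidate struct{ Text string }

type genResponse struct {
	UsageMetadata *usage
	Candidates    []*candidate
}

type response struct {
	Text          string
	UsageMetadata *usage
}

// createResponse uses a guard clause so the main path stays flat:
// bail out early when there is no usable candidate.
func createResponse(res *genResponse) *response {
	if len(res.Candidates) == 0 || res.Candidates[0] == nil {
		return &response{UsageMetadata: res.UsageMetadata}
	}
	return &response{Text: res.Candidates[0].Text, UsageMetadata: res.UsageMetadata}
}

func main() {
	r := createResponse(&genResponse{Candidates: []*candidate{{Text: "Paris"}}})
	fmt.Println(r.Text)
}
```

Inverting the condition and returning early removes the nesting while keeping both branches explicit.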
llm/gemini/gemini_test.go
Outdated
@@ -141,7 +159,25 @@ func readResponse(s iter.Seq2[*llm.Response, error]) (string, error) {
	if resp.Content == nil || len(resp.Content.Parts) == 0 {
		return answer, fmt.Errorf("encountered an empty response: %v", resp)
	}
	answer += resp.Content.Parts[0].Text
	if resp.Partial {
		answer += resp.Content.Parts[0].Text
	}
}
return answer, nil
}

func readResponseNonPartial(s iter.Seq2[*llm.Response, error]) (string, error) {
	var answer string
	for resp, err := range s {
		if err != nil {
			return answer, err
		}
		if resp.Content == nil || len(resp.Content.Parts) == 0 {
			return answer, fmt.Errorf("encountered an empty response: %v", resp)
		}
		if !resp.Partial {
			answer += resp.Content.Parts[0].Text
		}
	}
nit: maybe merge to one func? readResponse(readPartial bool, s iter.Seq2[*llm.Response, error])
if len(genResp.Candidates) == 0 {
	// shouldn't happen?
	yield(nil, fmt.Errorf("empty response"))
	return
}
candidate := genResp.Candidates[0]
Does it make sense to remove this now, as we have llm.CreateResponse(genResp) handling that case?
* Add StreamAggregator model proxy
* Add stream aggregator tests
* Modify the implementation to apply the aggregator directly to the model.
* Fix set_test to use llm.Model.
* Add llmResponse create function using GenerateContentResponse
* Add test for create llm request
* Change gemini_test readResponse to use custom struct
When using StreamingModeSSE the model returns the response as a stream of partial events.
Partial events are not saved in the session service (in runner.go's Run method), and as such are not sent in the following LLMRequest.
Currently only the last event of the stream is considered non-partial (in gemini.go's GenerateStream method), but it is still only a partial response.
So in the following LLMRequests, only that partial response will be included, causing the LLM to lose parts of the conversation. For example:
llmRequest1: [{user:'write me an email to Andrew'}]
llmResponse events: [{model:'Hi Andrew how are you'}, {model:'doing?'}]
llmRequest2: [{user:'write me an email'}, {model:'doing?'}, {user:'can you change the greeting?'}]
The strategy in ADK Python is to use an aggregator that collects the partial responses and, at the end, generates an aggregated event that will be stored and used in the following LLMRequests.
https://github.com/google/adk-python/blob/f7bd3c111c211e880d7c1954dd4508b952704c68/src/google/adk/models/google_llm.py#L143
Added a modelWithStreamAggregator proxy that, when applied to a model, aggregates the partial streaming responses into a final complete one.