From 0446fe5d1957f9fc814fee7c5d9c7fce107bebaa Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Thu, 15 May 2025 11:54:11 +0200 Subject: [PATCH 01/11] Start expanding StreamingChunk --- haystack/components/generators/chat/openai.py | 2 +- .../components/generators/hugging_face_api.py | 2 +- haystack/components/generators/openai.py | 2 +- haystack/dataclasses/chat_message.py | 15 +++++++++++++++ haystack/dataclasses/streaming_chunk.py | 9 +++++++++ 5 files changed, 27 insertions(+), 3 deletions(-) diff --git a/haystack/components/generators/chat/openai.py b/haystack/components/generators/chat/openai.py index 8f0d99cad4..67ff695a7b 100644 --- a/haystack/components/generators/chat/openai.py +++ b/haystack/components/generators/chat/openai.py @@ -577,7 +577,7 @@ def _convert_chat_completion_chunk_to_streaming_chunk(self, chunk: ChatCompletio # we stream the content of the chunk if it's not a tool or function call choice: ChunkChoice = chunk.choices[0] content = choice.delta.content or "" - chunk_message = StreamingChunk(content) + chunk_message = StreamingChunk(content=content) # but save the tool calls and function call in the meta if they are present # and then connect the chunks in the _convert_streaming_chunks_to_chat_message method chunk_message.meta.update( diff --git a/haystack/components/generators/hugging_face_api.py b/haystack/components/generators/hugging_face_api.py index f55db685d0..6d62670f3d 100644 --- a/haystack/components/generators/hugging_face_api.py +++ b/haystack/components/generators/hugging_face_api.py @@ -227,7 +227,7 @@ def _stream_and_build_response( if first_chunk_time is None: first_chunk_time = datetime.now().isoformat() - stream_chunk = StreamingChunk(token.text, chunk_metadata) + stream_chunk = StreamingChunk(content=token.text, meta=chunk_metadata) chunks.append(stream_chunk) streaming_callback(stream_chunk) diff --git a/haystack/components/generators/openai.py b/haystack/components/generators/openai.py index 72455a1b92..901b06eb2d 100644 --- a/haystack/components/generators/openai.py +++ b/haystack/components/generators/openai.py @@ -320,7 +320,7 @@ def _build_chunk(chunk: Any) -> StreamingChunk: """ choice = chunk.choices[0] content = choice.delta.content or "" - chunk_message = StreamingChunk(content) + chunk_message = StreamingChunk(content=content) chunk_message.meta.update( { "model": chunk.model, diff --git a/haystack/dataclasses/chat_message.py b/haystack/dataclasses/chat_message.py index 691a10ece4..e193065563 100644 --- a/haystack/dataclasses/chat_message.py +++ b/haystack/dataclasses/chat_message.py @@ -75,6 +75,21 @@ class ToolCallResult: error: bool +@dataclass +class ToolCallDelta: + """ + Represents a Tool call prepared by the model, usually contained in an assistant message. + + :param id: The ID of the Tool call. + :param tool_name: The name of the Tool to call. 
+ :param arguments_delta: + """ + + tool_name: Optional[str] = None + arguments_delta: Optional[str] = None + id: Optional[str] = None # noqa: A003 + + @dataclass class TextContent: """ diff --git a/haystack/dataclasses/streaming_chunk.py b/haystack/dataclasses/streaming_chunk.py index b11ef5f7f2..e824cfb0f4 100644 --- a/haystack/dataclasses/streaming_chunk.py +++ b/haystack/dataclasses/streaming_chunk.py @@ -5,6 +5,7 @@ from dataclasses import dataclass, field from typing import Any, Awaitable, Callable, Dict, Optional, Union +from haystack.dataclasses.chat_message import ToolCallDelta, ToolCallResult from haystack.utils.asynchronous import is_callable_async_compatible @@ -16,11 +17,19 @@ class StreamingChunk: This structure facilitates the handling and processing of streamed data in a systematic manner. :param content: The content of the message chunk as a string. + :param tool_call: An optional ToolCallDelta object representing a tool call associated with the message chunk. + :param tool_call_result: An optional ToolCallResult object representing the result of a tool call. + :param start: A boolean indicating whether this chunk marks the start of a message. + :param end: A boolean indicating whether this chunk marks the end of a message. :param meta: A dictionary containing metadata related to the message chunk. """ content: str meta: Dict[str, Any] = field(default_factory=dict, hash=False) + tool_call: Optional[ToolCallDelta] = None + tool_call_result: Optional[ToolCallResult] = None + start: Optional[bool] = None + end: Optional[bool] = None SyncStreamingCallbackT = Callable[[StreamingChunk], None] From f2ddbff1a6595fec779ee7e204b6495481ef6328 Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Thu, 22 May 2025 08:44:51 +0200 Subject: [PATCH 02/11] First pass at expanding Streaming Chunk --- haystack/components/generators/chat/openai.py | 63 +- haystack/components/generators/utils.py | 42 +- haystack/dataclasses/__init__.py | 2 + haystack/dataclasses/chat_message.py | 15 - haystack/dataclasses/streaming_chunk.py | 45 +- .../components/generators/chat/test_openai.py | 619 ++++++++++-------- 6 files changed, 467 insertions(+), 319 deletions(-) diff --git a/haystack/components/generators/chat/openai.py b/haystack/components/generators/chat/openai.py index 67ff695a7b..200a0546b0 100644 --- a/haystack/components/generators/chat/openai.py +++ b/haystack/components/generators/chat/openai.py @@ -20,6 +20,8 @@ StreamingChunk, SyncStreamingCallbackT, ToolCall, + ToolCallDelta, + ToolCallResult, select_streaming_callback, ) from haystack.tools import ( @@ -418,22 +420,22 @@ def _prepare_api_call( # noqa: PLR0913 } def _handle_stream_response(self, chat_completion: Stream, callback: SyncStreamingCallbackT) -> List[ChatMessage]: + openai_chunks = [] chunks: List[StreamingChunk] = [] - chunk = None chunk_delta: StreamingChunk for chunk in chat_completion: # pylint: disable=not-an-iterable assert len(chunk.choices) <= 1, "Streaming responses should have at most one choice." 
+ openai_chunks.append(chunk) chunk_delta = self._convert_chat_completion_chunk_to_streaming_chunk(chunk) chunks.append(chunk_delta) callback(chunk_delta) - return [self._convert_streaming_chunks_to_chat_message(chunk, chunks)] + return [self._convert_streaming_chunks_to_chat_message(chunks=chunks)] async def _handle_async_stream_response( self, chat_completion: AsyncStream, callback: AsyncStreamingCallbackT ) -> List[ChatMessage]: chunks: List[StreamingChunk] = [] - chunk = None chunk_delta: StreamingChunk async for chunk in chat_completion: # pylint: disable=not-an-iterable @@ -441,7 +443,7 @@ async def _handle_async_stream_response( chunk_delta = self._convert_chat_completion_chunk_to_streaming_chunk(chunk) chunks.append(chunk_delta) await callback(chunk_delta) - return [self._convert_streaming_chunks_to_chat_message(chunk, chunks)] + return [self._convert_streaming_chunks_to_chat_message(chunks=chunks)] def _check_finish_reason(self, meta: Dict[str, Any]) -> None: if meta["finish_reason"] == "length": @@ -458,13 +460,10 @@ def _check_finish_reason(self, meta: Dict[str, Any]) -> None: finish_reason=meta["finish_reason"], ) - def _convert_streaming_chunks_to_chat_message( - self, last_chunk: ChatCompletionChunk, chunks: List[StreamingChunk] - ) -> ChatMessage: + def _convert_streaming_chunks_to_chat_message(self, chunks: List[StreamingChunk]) -> ChatMessage: """ Connects the streaming chunks into a single ChatMessage. - :param last_chunk: The last chunk returned by the OpenAI API. :param chunks: The list of all `StreamingChunk` objects. :returns: The ChatMessage. @@ -514,11 +513,11 @@ def _convert_streaming_chunks_to_chat_message( finish_reason = finish_reasons[-1] if finish_reasons else None meta = { - "model": last_chunk.model, + "model": chunks[-1].meta.get("model"), "index": 0, "finish_reason": finish_reason, "completion_start_time": chunks[0].meta.get("received_at"), # first chunk received - "usage": self._serialize_usage(last_chunk.usage), # last chunk has the final usage data if available + "usage": chunks[-1].meta.get("usage"), # last chunk has the final usage data if available } return ChatMessage.from_assistant(text=text or None, tool_calls=tool_calls, meta=meta) @@ -570,25 +569,53 @@ def _convert_chat_completion_chunk_to_streaming_chunk(self, chunk: ChatCompletio :returns: The StreamingChunk. """ - # if there are no choices, return an empty chunk + # Choices is empty on the very first chunk which provides role information (e.g. "assistant"). + # It is also empty if include_usage is set to True where the usage information is returned. 
if len(chunk.choices) == 0: - return StreamingChunk(content="", meta={"model": chunk.model, "received_at": datetime.now().isoformat()}) + return StreamingChunk( + content="", + meta={ + "model": chunk.model, + "received_at": datetime.now().isoformat(), + "usage": self._serialize_usage(chunk.usage), + }, + ) # we stream the content of the chunk if it's not a tool or function call choice: ChunkChoice = chunk.choices[0] content = choice.delta.content or "" - chunk_message = StreamingChunk(content=content) - # but save the tool calls and function call in the meta if they are present - # and then connect the chunks in the _convert_streaming_chunks_to_chat_message method - chunk_message.meta.update( - { + # create a list of ToolCallDelta objects from the tool calls + tool_call_deltas = [] + if choice.delta.tool_calls: + for tool_call in choice.delta.tool_calls: + tool_call_deltas.append( + ToolCallDelta( + index=tool_call.index, + id=tool_call.id, + name=tool_call.function.name, + arguments=tool_call.function.arguments or None, + ) + ) + # determine if this is the start of a content block (e.g. new tool call, completion, etc.) + start = None + if tool_call_deltas: + start = any(tc.name is not None for tc in tool_call_deltas) + # TODO Need to add check for when the start of a normal content stream is + chunk_message = StreamingChunk( + content=content, + tool_calls=tool_call_deltas if tool_call_deltas else None, + start=start, + meta={ "model": chunk.model, "index": choice.index, "tool_calls": choice.delta.tool_calls, "finish_reason": choice.finish_reason, "received_at": datetime.now().isoformat(), - } + "usage": self._serialize_usage(chunk.usage), + }, ) + # but save the tool calls and function call in the meta if they are present + # and then connect the chunks in the _convert_streaming_chunks_to_chat_message method return chunk_message def _serialize_usage(self, usage): diff --git a/haystack/components/generators/utils.py b/haystack/components/generators/utils.py index 4a8346b7b3..6130fe835e 100644 --- a/haystack/components/generators/utils.py +++ b/haystack/components/generators/utils.py @@ -2,8 +2,6 @@ # # SPDX-License-Identifier: Apache-2.0 -from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall - from haystack.dataclasses import StreamingChunk @@ -22,24 +20,38 @@ def print_streaming_chunk(chunk: StreamingChunk) -> None: :param chunk: A chunk of streaming data containing content and optional metadata, such as tool calls and tool results. 
""" - # Print tool call metadata if available (from ChatGenerator) - if chunk.meta.get("tool_calls"): - for tool_call in chunk.meta["tool_calls"]: - if isinstance(tool_call, ChoiceDeltaToolCall) and tool_call.function: - # print the tool name - if tool_call.function.name and not tool_call.function.arguments: - print("[TOOL CALL]\n", flush=True, end="") - print(f"Tool: {tool_call.function.name} ", flush=True, end="") - print("\nArguments: ", flush=True, end="") + ## Tool Call streaming + if chunk.tool_calls: + for tool_call in chunk.tool_calls: + # Presence of tool_name indicates beginning of a tool call + # or chunk.start: would be equivalent here + if tool_call.name: + # If this is not the first tool call, we add two new lines + if tool_call.index > 0: + print("\n\n", flush=True, end="") + + print("[TOOL CALL]\n", flush=True, end="") + print(f"Tool: {tool_call.name} ", flush=True, end="") + print("\nArguments: ", flush=True, end="") # print the tool arguments - if tool_call.function.arguments: - print(tool_call.function.arguments, flush=True, end="") + if tool_call.arguments: + print(tool_call.arguments, flush=True, end="") + ## Tool Call Result streaming # Print tool call results if available (from ToolInvoker) - if chunk.meta.get("tool_result"): - print(f"\n\n[TOOL RESULT]\n{chunk.meta['tool_result']}\n\n", flush=True, end="") + if chunk.tool_call_results: + # Tool Call Result is fully formed so delta accumulation is not needed + print(f"[TOOL RESULT]\n{chunk.tool_call_results[0]}\n\n", flush=True, end="") + ## Normal content streaming # Print the main content of the chunk (from ChatGenerator) if chunk.content: + if chunk.start: + print("[ASSISTANT]\n", flush=True, end="") print(chunk.content, flush=True, end="") + + # End of LLM assistant message so we add two new lines + # This ensures spacing between LLM message and next LLM message or Tool Call Result + if chunk.meta.get("finish_reason") is not None: + print("\n\n", flush=True, end="") diff --git a/haystack/dataclasses/__init__.py b/haystack/dataclasses/__init__.py index 15da1ae84a..8692772572 100644 --- a/haystack/dataclasses/__init__.py +++ b/haystack/dataclasses/__init__.py @@ -19,6 +19,7 @@ "AsyncStreamingCallbackT", "StreamingCallbackT", "SyncStreamingCallbackT", + "ToolCallDelta", "select_streaming_callback", ], } @@ -35,6 +36,7 @@ StreamingCallbackT, StreamingChunk, SyncStreamingCallbackT, + ToolCallDelta, select_streaming_callback, ) else: diff --git a/haystack/dataclasses/chat_message.py b/haystack/dataclasses/chat_message.py index e193065563..691a10ece4 100644 --- a/haystack/dataclasses/chat_message.py +++ b/haystack/dataclasses/chat_message.py @@ -75,21 +75,6 @@ class ToolCallResult: error: bool -@dataclass -class ToolCallDelta: - """ - Represents a Tool call prepared by the model, usually contained in an assistant message. - - :param id: The ID of the Tool call. - :param tool_name: The name of the Tool to call. 
- :param arguments_delta: - """ - - tool_name: Optional[str] = None - arguments_delta: Optional[str] = None - id: Optional[str] = None # noqa: A003 - - @dataclass class TextContent: """ diff --git a/haystack/dataclasses/streaming_chunk.py b/haystack/dataclasses/streaming_chunk.py index e824cfb0f4..fcec91da2a 100644 --- a/haystack/dataclasses/streaming_chunk.py +++ b/haystack/dataclasses/streaming_chunk.py @@ -3,12 +3,35 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass, field -from typing import Any, Awaitable, Callable, Dict, Optional, Union +from typing import Any, Awaitable, Callable, Dict, List, Optional, Union -from haystack.dataclasses.chat_message import ToolCallDelta, ToolCallResult +from haystack.dataclasses.chat_message import ToolCallResult from haystack.utils.asynchronous import is_callable_async_compatible +# Similar to ChoiceDeltaToolCall from OpenAI +@dataclass(kw_only=True) +class ToolCallDelta: + """ + Represents a Tool call prepared by the model, usually contained in an assistant message. + + :param id: The ID of the Tool call. + :param name: The name of the Tool to call. + :param arguments: + """ + + index: int + id: Optional[str] = None # noqa: A003 + name: Optional[str] = None + arguments: Optional[str] = None + + def __post_init__(self): + if self.name is None and self.arguments is None: + raise ValueError("At least one of tool_name or arguments must be provided.") + if self.name is not None and self.arguments is not None: + raise ValueError("Only one of tool_name or arguments can be provided.") + + @dataclass class StreamingChunk: """ @@ -17,19 +40,25 @@ class StreamingChunk: This structure facilitates the handling and processing of streamed data in a systematic manner. :param content: The content of the message chunk as a string. - :param tool_call: An optional ToolCallDelta object representing a tool call associated with the message chunk. - :param tool_call_result: An optional ToolCallResult object representing the result of a tool call. + :param tool_calls: An optional ToolCallDelta object representing a tool call associated with the message chunk. + :param tool_call_results: An optional ToolCallResult object representing the result of a tool call. :param start: A boolean indicating whether this chunk marks the start of a message. - :param end: A boolean indicating whether this chunk marks the end of a message. :param meta: A dictionary containing metadata related to the message chunk. 
""" content: str meta: Dict[str, Any] = field(default_factory=dict, hash=False) - tool_call: Optional[ToolCallDelta] = None - tool_call_result: Optional[ToolCallResult] = None + tool_calls: Optional[List[ToolCallDelta]] = None + tool_call_results: Optional[List[ToolCallResult]] = None start: Optional[bool] = None - end: Optional[bool] = None + + def __post_init__(self): + if self.tool_calls and self.content: + raise ValueError("A StreamingChunk should not have both content and tool calls.") + if self.tool_call_results and self.content: + raise ValueError("A StreamingChunk should not have both content and tool call results.") + if self.tool_calls and self.tool_call_results: + raise ValueError("A StreamingChunk should not have both tool calls and tool call results.") SyncStreamingCallbackT = Callable[[StreamingChunk], None] diff --git a/test/components/generators/chat/test_openai.py b/test/components/generators/chat/test_openai.py index d393e33a1e..14c6dbcac6 100644 --- a/test/components/generators/chat/test_openai.py +++ b/test/components/generators/chat/test_openai.py @@ -13,6 +13,7 @@ from openai import OpenAIError from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage, ChatCompletionMessageToolCall from openai.types.chat.chat_completion import Choice +from openai.types.chat.chat_completion_chunk import ChoiceDelta, ChoiceDeltaToolCall, ChoiceDeltaToolCallFunction from openai.types.completion_usage import CompletionTokensDetails, CompletionUsage, PromptTokensDetails from openai.types.chat.chat_completion_message_tool_call import Function from openai.types.chat import chat_completion_chunk @@ -508,7 +509,7 @@ def test_run_with_tools(self, tools): assert not message.text assert message.tool_calls - tool_call = message.tool_call + tool_call = message.tool_calls assert isinstance(tool_call, ToolCall) assert tool_call.tool_name == "weather" assert tool_call.arguments == {"city": "Paris"} @@ -541,7 +542,7 @@ def streaming_callback(chunk: StreamingChunk) -> None: message = response["replies"][0] assert message.tool_calls - tool_call = message.tool_call + tool_call = message.tool_calls assert isinstance(tool_call, ToolCall) assert tool_call.tool_name == "weather" assert tool_call.arguments == {"city": "Paris"} @@ -593,277 +594,365 @@ def test_invalid_tool_call_json(self, tools, caplog): assert message.meta["finish_reason"] == "tool_calls" assert message.meta["usage"]["completion_tokens"] == 47 - def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(self): - component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key")) - chunk = chat_completion_chunk.ChatCompletionChunk( - id="chatcmpl-B2g1XYv1WzALulC5c8uLtJgvEB48I", - choices=[ - chat_completion_chunk.Choice( - delta=chat_completion_chunk.ChoiceDelta( - content=None, function_call=None, refusal=None, role=None, tool_calls=None - ), - finish_reason="tool_calls", - index=0, - logprobs=None, - ) - ], - created=1739977895, - model="gpt-4o-mini-2024-07-18", - object="chat.completion.chunk", - service_tier="default", - system_fingerprint="fp_00428b782a", - usage=None, - ) - chunks = [ - StreamingChunk( - content="", - meta={ - "model": "gpt-4o-mini-2024-07-18", - "index": 0, - "tool_calls": None, - "finish_reason": None, - "received_at": "2025-02-19T16:02:55.910076", - }, - ), - StreamingChunk( - content="", - meta={ - "model": "gpt-4o-mini-2024-07-18", - "index": 0, - "tool_calls": [ - chat_completion_chunk.ChoiceDeltaToolCall( - index=0, - id="call_ZOj5l67zhZOx6jqjg7ATQwb6", - 
function=chat_completion_chunk.ChoiceDeltaToolCallFunction( - arguments="", name="rag_pipeline_tool" - ), - type="function", - ) - ], - "finish_reason": None, - "received_at": "2025-02-19T16:02:55.913919", - }, - ), - StreamingChunk( - content="", - meta={ - "model": "gpt-4o-mini-2024-07-18", - "index": 0, - "tool_calls": [ - chat_completion_chunk.ChoiceDeltaToolCall( - index=0, - id=None, - function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments='{"qu', name=None), - type=None, - ) - ], - "finish_reason": None, - "received_at": "2025-02-19T16:02:55.914439", - }, + def test_handle_stream_response(self): + openai_chunks = [ + ChatCompletionChunk( + id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F", + choices=[ + chat_completion_chunk.Choice( + delta=ChoiceDelta( + content=None, function_call=None, refusal=None, role="assistant", tool_calls=None + ), + finish_reason=None, + index=0, + logprobs=None, + ) + ], + created=1747834733, + model="gpt-4o-mini-2024-07-18", + object="chat.completion.chunk", + service_tier="default", + system_fingerprint="fp_54eb4bd693", + usage=None, ), - StreamingChunk( - content="", - meta={ - "model": "gpt-4o-mini-2024-07-18", - "index": 0, - "tool_calls": [ - chat_completion_chunk.ChoiceDeltaToolCall( - index=0, - id=None, - function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments='ery":', name=None), - type=None, - ) - ], - "finish_reason": None, - "received_at": "2025-02-19T16:02:55.924146", - }, + ChatCompletionChunk( + id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F", + choices=[ + chat_completion_chunk.Choice( + delta=ChoiceDelta( + content=None, + function_call=None, + refusal=None, + role=None, + tool_calls=[ + ChoiceDeltaToolCall( + index=0, + id="call_zcvlnVaTeJWRjLAFfYxX69z4", + function=ChoiceDeltaToolCallFunction(arguments="", name="weather"), + type="function", + ) + ], + ), + finish_reason=None, + index=0, + logprobs=None, + ) + ], + created=1747834733, + model="gpt-4o-mini-2024-07-18", + object="chat.completion.chunk", + service_tier="default", + system_fingerprint="fp_54eb4bd693", + usage=None, ), - StreamingChunk( - content="", - meta={ - "model": "gpt-4o-mini-2024-07-18", - "index": 0, - "tool_calls": [ - chat_completion_chunk.ChoiceDeltaToolCall( - index=0, - id=None, - function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments=' "Wher', name=None), - type=None, - ) - ], - "finish_reason": None, - "received_at": "2025-02-19T16:02:55.924420", - }, + ChatCompletionChunk( + id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F", + choices=[ + chat_completion_chunk.Choice( + delta=ChoiceDelta( + content=None, + function_call=None, + refusal=None, + role=None, + tool_calls=[ + ChoiceDeltaToolCall( + index=0, + id=None, + function=ChoiceDeltaToolCallFunction(arguments='{"ci', name=None), + type=None, + ) + ], + ), + finish_reason=None, + index=0, + logprobs=None, + ) + ], + created=1747834733, + model="gpt-4o-mini-2024-07-18", + object="chat.completion.chunk", + service_tier="default", + system_fingerprint="fp_54eb4bd693", + usage=None, ), - StreamingChunk( - content="", - meta={ - "model": "gpt-4o-mini-2024-07-18", - "index": 0, - "tool_calls": [ - chat_completion_chunk.ChoiceDeltaToolCall( - index=0, - id=None, - function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments="e do", name=None), - type=None, - ) - ], - "finish_reason": None, - "received_at": "2025-02-19T16:02:55.944398", - }, + ChatCompletionChunk( + id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F", + choices=[ + chat_completion_chunk.Choice( + delta=ChoiceDelta( + 
content=None, + function_call=None, + refusal=None, + role=None, + tool_calls=[ + ChoiceDeltaToolCall( + index=0, + id=None, + function=ChoiceDeltaToolCallFunction(arguments='ty": ', name=None), + type=None, + ) + ], + ), + finish_reason=None, + index=0, + logprobs=None, + ) + ], + created=1747834733, + model="gpt-4o-mini-2024-07-18", + object="chat.completion.chunk", + service_tier="default", + system_fingerprint="fp_54eb4bd693", + usage=None, ), - StreamingChunk( - content="", - meta={ - "model": "gpt-4o-mini-2024-07-18", - "index": 0, - "tool_calls": [ - chat_completion_chunk.ChoiceDeltaToolCall( - index=0, - id=None, - function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments="es Ma", name=None), - type=None, - ) - ], - "finish_reason": None, - "received_at": "2025-02-19T16:02:55.944958", - }, + ChatCompletionChunk( + id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F", + choices=[ + chat_completion_chunk.Choice( + delta=ChoiceDelta( + content=None, + function_call=None, + refusal=None, + role=None, + tool_calls=[ + ChoiceDeltaToolCall( + index=0, + id=None, + function=ChoiceDeltaToolCallFunction(arguments='"Paris', name=None), + type=None, + ) + ], + ), + finish_reason=None, + index=0, + logprobs=None, + ) + ], + created=1747834733, + model="gpt-4o-mini-2024-07-18", + object="chat.completion.chunk", + service_tier="default", + system_fingerprint="fp_54eb4bd693", + usage=None, ), - StreamingChunk( - content="", - meta={ - "model": "gpt-4o-mini-2024-07-18", - "index": 0, - "tool_calls": [ - chat_completion_chunk.ChoiceDeltaToolCall( - index=0, - id=None, - function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments="rk liv", name=None), - type=None, - ) - ], - "finish_reason": None, - "received_at": "2025-02-19T16:02:55.945507", - }, + ChatCompletionChunk( + id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F", + choices=[ + chat_completion_chunk.Choice( + delta=ChoiceDelta( + content=None, + function_call=None, + refusal=None, + role=None, + tool_calls=[ + ChoiceDeltaToolCall( + index=0, + id=None, + function=ChoiceDeltaToolCallFunction(arguments='"}', name=None), + type=None, + ) + ], + ), + finish_reason=None, + index=0, + logprobs=None, + ) + ], + created=1747834733, + model="gpt-4o-mini-2024-07-18", + object="chat.completion.chunk", + service_tier="default", + system_fingerprint="fp_54eb4bd693", + usage=None, ), - StreamingChunk( - content="", - meta={ - "model": "gpt-4o-mini-2024-07-18", - "index": 0, - "tool_calls": [ - chat_completion_chunk.ChoiceDeltaToolCall( - index=0, - id=None, - function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments='e?"}', name=None), - type=None, - ) - ], - "finish_reason": None, - "received_at": "2025-02-19T16:02:55.946018", - }, + ChatCompletionChunk( + id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F", + choices=[ + chat_completion_chunk.Choice( + delta=ChoiceDelta( + content=None, + function_call=None, + refusal=None, + role=None, + tool_calls=[ + ChoiceDeltaToolCall( + index=1, + id="call_C88m67V16CrETq6jbNXjdZI9", + function=ChoiceDeltaToolCallFunction(arguments="", name="weather"), + type="function", + ) + ], + ), + finish_reason=None, + index=0, + logprobs=None, + ) + ], + created=1747834733, + model="gpt-4o-mini-2024-07-18", + object="chat.completion.chunk", + service_tier="default", + system_fingerprint="fp_54eb4bd693", + usage=None, ), - StreamingChunk( - content="", - meta={ - "model": "gpt-4o-mini-2024-07-18", - "index": 0, - "tool_calls": [ - chat_completion_chunk.ChoiceDeltaToolCall( - index=1, - 
id="call_STxsYY69wVOvxWqopAt3uWTB", - function=chat_completion_chunk.ChoiceDeltaToolCallFunction( - arguments="", name="get_weather" - ), - type="function", - ) - ], - "finish_reason": None, - "received_at": "2025-02-19T16:02:55.946578", - }, + ChatCompletionChunk( + id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F", + choices=[ + chat_completion_chunk.Choice( + delta=ChoiceDelta( + content=None, + function_call=None, + refusal=None, + role=None, + tool_calls=[ + ChoiceDeltaToolCall( + index=1, + id=None, + function=ChoiceDeltaToolCallFunction(arguments='{"ci', name=None), + type=None, + ) + ], + ), + finish_reason=None, + index=0, + logprobs=None, + ) + ], + created=1747834733, + model="gpt-4o-mini-2024-07-18", + object="chat.completion.chunk", + service_tier="default", + system_fingerprint="fp_54eb4bd693", + usage=None, ), - StreamingChunk( - content="", - meta={ - "model": "gpt-4o-mini-2024-07-18", - "index": 0, - "tool_calls": [ - chat_completion_chunk.ChoiceDeltaToolCall( - index=1, - id=None, - function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments='{"ci', name=None), - type=None, - ) - ], - "finish_reason": None, - "received_at": "2025-02-19T16:02:55.946981", - }, + ChatCompletionChunk( + id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F", + choices=[ + chat_completion_chunk.Choice( + delta=ChoiceDelta( + content=None, + function_call=None, + refusal=None, + role=None, + tool_calls=[ + ChoiceDeltaToolCall( + index=1, + id=None, + function=ChoiceDeltaToolCallFunction(arguments='ty": ', name=None), + type=None, + ) + ], + ), + finish_reason=None, + index=0, + logprobs=None, + ) + ], + created=1747834733, + model="gpt-4o-mini-2024-07-18", + object="chat.completion.chunk", + service_tier="default", + system_fingerprint="fp_54eb4bd693", + usage=None, ), - StreamingChunk( - content="", - meta={ - "model": "gpt-4o-mini-2024-07-18", - "index": 0, - "tool_calls": [ - chat_completion_chunk.ChoiceDeltaToolCall( - index=1, - id=None, - function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments='ty": ', name=None), - type=None, - ) - ], - "finish_reason": None, - "received_at": "2025-02-19T16:02:55.947411", - }, + ChatCompletionChunk( + id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F", + choices=[ + chat_completion_chunk.Choice( + delta=ChoiceDelta( + content=None, + function_call=None, + refusal=None, + role=None, + tool_calls=[ + ChoiceDeltaToolCall( + index=1, + id=None, + function=ChoiceDeltaToolCallFunction(arguments='"Berli', name=None), + type=None, + ) + ], + ), + finish_reason=None, + index=0, + logprobs=None, + ) + ], + created=1747834733, + model="gpt-4o-mini-2024-07-18", + object="chat.completion.chunk", + service_tier="default", + system_fingerprint="fp_54eb4bd693", + usage=None, ), - StreamingChunk( - content="", - meta={ - "model": "gpt-4o-mini-2024-07-18", - "index": 0, - "tool_calls": [ - chat_completion_chunk.ChoiceDeltaToolCall( - index=1, - id=None, - function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments='"Berli', name=None), - type=None, - ) - ], - "finish_reason": None, - "received_at": "2025-02-19T16:02:55.947643", - }, + ChatCompletionChunk( + id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F", + choices=[ + chat_completion_chunk.Choice( + delta=ChoiceDelta( + content=None, + function_call=None, + refusal=None, + role=None, + tool_calls=[ + ChoiceDeltaToolCall( + index=1, + id=None, + function=ChoiceDeltaToolCallFunction(arguments='n"}', name=None), + type=None, + ) + ], + ), + finish_reason=None, + index=0, + logprobs=None, + ) + ], + created=1747834733, + 
model="gpt-4o-mini-2024-07-18", + object="chat.completion.chunk", + service_tier="default", + system_fingerprint="fp_54eb4bd693", + usage=None, ), - StreamingChunk( - content="", - meta={ - "model": "gpt-4o-mini-2024-07-18", - "index": 0, - "tool_calls": [ - chat_completion_chunk.ChoiceDeltaToolCall( - index=1, - id=None, - function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments='n"}', name=None), - type=None, - ) - ], - "finish_reason": None, - "received_at": "2025-02-19T16:02:55.947939", - }, + ChatCompletionChunk( + id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F", + choices=[ + chat_completion_chunk.Choice( + delta=ChoiceDelta(content=None, function_call=None, refusal=None, role=None, tool_calls=None), + finish_reason="tool_calls", + index=0, + logprobs=None, + ) + ], + created=1747834733, + model="gpt-4o-mini-2024-07-18", + object="chat.completion.chunk", + service_tier="default", + system_fingerprint="fp_54eb4bd693", + usage=None, ), - StreamingChunk( - content="", - meta={ - "model": "gpt-4o-mini-2024-07-18", - "index": 0, - "tool_calls": None, - "finish_reason": "tool_calls", - "received_at": "2025-02-19T16:02:55.948772", - }, + ChatCompletionChunk( + id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F", + choices=[], + created=1747834733, + model="gpt-4o-mini-2024-07-18", + object="chat.completion.chunk", + service_tier="default", + system_fingerprint="fp_54eb4bd693", + usage=CompletionUsage( + completion_tokens=42, + prompt_tokens=282, + total_tokens=324, + completion_tokens_details=CompletionTokensDetails( + accepted_prediction_tokens=0, audio_tokens=0, reasoning_tokens=0, rejected_prediction_tokens=0 + ), + prompt_tokens_details=PromptTokensDetails(audio_tokens=0, cached_tokens=0), + ), ), ] - - # Convert chunks to a chat message - result = component._convert_streaming_chunks_to_chat_message(chunk, chunks) + component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key")) + result = component._handle_stream_response(openai_chunks, callback=lambda chunk: None)[0] # type: ignore assert not result.texts assert not result.text @@ -988,8 +1077,12 @@ def __call__(self, chunk: StreamingChunk) -> None: ) @pytest.mark.integration def test_live_run_with_tools_streaming(self, tools): - chat_messages = [ChatMessage.from_user("What's the weather like in Paris?")] - component = OpenAIChatGenerator(tools=tools, streaming_callback=print_streaming_chunk) + chat_messages = [ChatMessage.from_user("What's the weather like in Paris and Berlin?")] + component = OpenAIChatGenerator( + tools=tools, + streaming_callback=print_streaming_chunk, + generation_kwargs={"stream_options": {"include_usage": True}}, + ) results = component.run(chat_messages) assert len(results["replies"]) == 1 message = results["replies"][0] @@ -997,7 +1090,7 @@ def test_live_run_with_tools_streaming(self, tools): assert not message.texts assert not message.text assert message.tool_calls - tool_call = message.tool_call + tool_call = message.tool_calls assert isinstance(tool_call, ToolCall) assert tool_call.tool_name == "weather" assert tool_call.arguments == {"city": "Paris"} @@ -1039,7 +1132,7 @@ def test_live_run_with_toolset(self, tools): assert not message.texts assert not message.text assert message.tool_calls - tool_call = message.tool_call + tool_call = message.tool_calls assert isinstance(tool_call, ToolCall) assert tool_call.tool_name == "weather" assert tool_call.arguments == {"city": "Paris"} From 6cb7a310174b0a3987aa2114195abc16dd10754d Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Thu, 22 May 
2025 12:00:52 +0200 Subject: [PATCH 03/11] Working version! --- haystack/components/generators/chat/openai.py | 99 ++++++++++++------- haystack/components/generators/utils.py | 39 ++++---- haystack/components/tools/tool_invoker.py | 5 + haystack/dataclasses/streaming_chunk.py | 32 +++--- 4 files changed, 106 insertions(+), 69 deletions(-) diff --git a/haystack/components/generators/chat/openai.py b/haystack/components/generators/chat/openai.py index 200a0546b0..38062940f0 100644 --- a/haystack/components/generators/chat/openai.py +++ b/haystack/components/generators/chat/openai.py @@ -21,7 +21,6 @@ SyncStreamingCallbackT, ToolCall, ToolCallDelta, - ToolCallResult, select_streaming_callback, ) from haystack.tools import ( @@ -420,29 +419,39 @@ def _prepare_api_call( # noqa: PLR0913 } def _handle_stream_response(self, chat_completion: Stream, callback: SyncStreamingCallbackT) -> List[ChatMessage]: - openai_chunks = [] + openai_chunks: List[ChatCompletionChunk] = [] chunks: List[StreamingChunk] = [] + chunk_deltas: List[StreamingChunk] chunk_delta: StreamingChunk for chunk in chat_completion: # pylint: disable=not-an-iterable assert len(chunk.choices) <= 1, "Streaming responses should have at most one choice." + chunk_deltas = self._convert_chat_completion_chunk_to_streaming_chunk( + chunk=chunk, previous_chunks=openai_chunks + ) + for chunk_delta in chunk_deltas: + chunks.append(chunk_delta) + callback(chunk_delta) openai_chunks.append(chunk) - chunk_delta = self._convert_chat_completion_chunk_to_streaming_chunk(chunk) - chunks.append(chunk_delta) - callback(chunk_delta) return [self._convert_streaming_chunks_to_chat_message(chunks=chunks)] async def _handle_async_stream_response( self, chat_completion: AsyncStream, callback: AsyncStreamingCallbackT ) -> List[ChatMessage]: + openai_chunks: List[ChatCompletionChunk] = [] chunks: List[StreamingChunk] = [] + chunk_deltas: List[StreamingChunk] chunk_delta: StreamingChunk async for chunk in chat_completion: # pylint: disable=not-an-iterable assert len(chunk.choices) <= 1, "Streaming responses should have at most one choice." - chunk_delta = self._convert_chat_completion_chunk_to_streaming_chunk(chunk) - chunks.append(chunk_delta) - await callback(chunk_delta) + chunk_deltas = self._convert_chat_completion_chunk_to_streaming_chunk( + chunk=chunk, previous_chunks=openai_chunks + ) + for chunk_delta in chunk_deltas: + chunks.append(chunk_delta) + await callback(chunk_delta) + openai_chunks.append(chunk) return [self._convert_streaming_chunks_to_chat_message(chunks=chunks)] def _check_finish_reason(self, meta: Dict[str, Any]) -> None: @@ -560,11 +569,14 @@ def _convert_chat_completion_to_chat_message(self, completion: ChatCompletion, c ) return chat_message - def _convert_chat_completion_chunk_to_streaming_chunk(self, chunk: ChatCompletionChunk) -> StreamingChunk: + def _convert_chat_completion_chunk_to_streaming_chunk( + self, chunk: ChatCompletionChunk, previous_chunks: List[ChatCompletionChunk] + ) -> List[StreamingChunk]: """ Converts the streaming response chunk from the OpenAI API to a StreamingChunk. :param chunk: The chunk returned by the OpenAI API. + :param previous_chunks: The previous chunks received from the OpenAI API. :returns: The StreamingChunk. @@ -572,39 +584,58 @@ def _convert_chat_completion_chunk_to_streaming_chunk(self, chunk: ChatCompletio # Choices is empty on the very first chunk which provides role information (e.g. "assistant"). # It is also empty if include_usage is set to True where the usage information is returned. 
if len(chunk.choices) == 0: - return StreamingChunk( - content="", - meta={ - "model": chunk.model, - "received_at": datetime.now().isoformat(), - "usage": self._serialize_usage(chunk.usage), - }, - ) + return [ + StreamingChunk( + content="", + # Index is None since it's only used when a content block is present + index=None, + meta={ + "model": chunk.model, + "received_at": datetime.now().isoformat(), + "usage": self._serialize_usage(chunk.usage), + }, + ) + ] # we stream the content of the chunk if it's not a tool or function call choice: ChunkChoice = chunk.choices[0] content = choice.delta.content or "" + # create a list of ToolCallDelta objects from the tool calls - tool_call_deltas = [] if choice.delta.tool_calls: + chunk_messages = [] for tool_call in choice.delta.tool_calls: - tool_call_deltas.append( - ToolCallDelta( - index=tool_call.index, - id=tool_call.id, - name=tool_call.function.name, - arguments=tool_call.function.arguments or None, - ) + chunk_message = StreamingChunk( + content=content, + # We adopt the tool_call.index as the index of the chunk + index=tool_call.index, + tool_call=ToolCallDelta( + id=tool_call.id, name=tool_call.function.name, arguments=tool_call.function.arguments or None + ), + start=tool_call.function.name is not None, + meta={ + "model": chunk.model, + "index": choice.index, + "tool_calls": choice.delta.tool_calls, + "finish_reason": choice.finish_reason, + "received_at": datetime.now().isoformat(), + "usage": self._serialize_usage(chunk.usage), + }, ) - # determine if this is the start of a content block (e.g. new tool call, completion, etc.) - start = None - if tool_call_deltas: - start = any(tc.name is not None for tc in tool_call_deltas) - # TODO Need to add check for when the start of a normal content stream is + chunk_messages.append(chunk_message) + return chunk_messages + + # If we reach here content should not be empty chunk_message = StreamingChunk( content=content, - tool_calls=tool_call_deltas if tool_call_deltas else None, - start=start, + # We set the index to be 0 since if text content is being streamed then no tool calls are being streamed + # NOTE: We may need to revisit this if OpenAI allows planning/thinking content before tool calls like + # Anthropic/Bedrock + index=0, + tool_call=None, + # The first chunk is always a start message chunk, so if we reach here and previous_chunks is length 1 + # then this is the start of text content + start=len(previous_chunks) == 1 or None, meta={ "model": chunk.model, "index": choice.index, @@ -614,9 +645,7 @@ def _convert_chat_completion_chunk_to_streaming_chunk(self, chunk: ChatCompletio "usage": self._serialize_usage(chunk.usage), }, ) - # but save the tool calls and function call in the meta if they are present - # and then connect the chunks in the _convert_streaming_chunks_to_chat_message method - return chunk_message + return [chunk_message] def _serialize_usage(self, usage): """Convert OpenAI usage object to serializable dict recursively""" diff --git a/haystack/components/generators/utils.py b/haystack/components/generators/utils.py index 6130fe835e..34d00236cd 100644 --- a/haystack/components/generators/utils.py +++ b/haystack/components/generators/utils.py @@ -21,37 +21,38 @@ def print_streaming_chunk(chunk: StreamingChunk) -> None: tool results. 
""" ## Tool Call streaming - if chunk.tool_calls: - for tool_call in chunk.tool_calls: - # Presence of tool_name indicates beginning of a tool call - # or chunk.start: would be equivalent here - if tool_call.name: - # If this is not the first tool call, we add two new lines - if tool_call.index > 0: - print("\n\n", flush=True, end="") - - print("[TOOL CALL]\n", flush=True, end="") - print(f"Tool: {tool_call.name} ", flush=True, end="") - print("\nArguments: ", flush=True, end="") - - # print the tool arguments - if tool_call.arguments: - print(tool_call.arguments, flush=True, end="") + if chunk.tool_call: + # Presence of tool_name indicates beginning of a tool call + # or chunk.tool_call.name: would be equivalent here + if chunk.start: + # If this is not the first content block of the message, add two new lines + if chunk.index > 0: + print("\n\n", flush=True, end="") + print("[TOOL CALL]\n", flush=True, end="") + print(f"Tool: {chunk.tool_call.name} ", flush=True, end="") + print("\nArguments: ", flush=True, end="") + + # print the tool arguments + if chunk.tool_call.arguments: + print(chunk.tool_call.arguments, flush=True, end="") ## Tool Call Result streaming # Print tool call results if available (from ToolInvoker) - if chunk.tool_call_results: + if chunk.tool_call_result: # Tool Call Result is fully formed so delta accumulation is not needed - print(f"[TOOL RESULT]\n{chunk.tool_call_results[0]}\n\n", flush=True, end="") + print(f"[TOOL RESULT]\n{chunk.tool_call_result}\n\n", flush=True, end="") ## Normal content streaming # Print the main content of the chunk (from ChatGenerator) if chunk.content: if chunk.start: + # If this is not the first content block of the message, add two new lines + if chunk.index > 0: + print("\n\n", flush=True, end="") print("[ASSISTANT]\n", flush=True, end="") print(chunk.content, flush=True, end="") # End of LLM assistant message so we add two new lines - # This ensures spacing between LLM message and next LLM message or Tool Call Result + # This ensures spacing between multiple LLM messages (e.g. Agent) or Tool Call Result if chunk.meta.get("finish_reason") is not None: print("\n\n", flush=True, end="") diff --git a/haystack/components/tools/tool_invoker.py b/haystack/components/tools/tool_invoker.py index d351bae0b6..21ff39b7dd 100644 --- a/haystack/components/tools/tool_invoker.py +++ b/haystack/components/tools/tool_invoker.py @@ -503,6 +503,9 @@ def run( streaming_callback( StreamingChunk( content="", + index=len(tool_messages) - 1, + tool_call_result=tool_messages[-1].tool_call_results[0], + start=True, meta={"tool_result": tool_messages[-1].tool_call_results[0].result, "tool_call": tool_call}, ) ) @@ -604,6 +607,8 @@ async def run_async( await streaming_callback( StreamingChunk( content="", + tool_call_result=tool_messages[-1].tool_call_results[0], + start=True, meta={"tool_result": tool_messages[-1].tool_call_results[0].result, "tool_call": tool_call}, ) ) # type: ignore[misc] # we have checked that streaming_callback is not None and async diff --git a/haystack/dataclasses/streaming_chunk.py b/haystack/dataclasses/streaming_chunk.py index fcec91da2a..0338454732 100644 --- a/haystack/dataclasses/streaming_chunk.py +++ b/haystack/dataclasses/streaming_chunk.py @@ -17,10 +17,9 @@ class ToolCallDelta: :param id: The ID of the Tool call. :param name: The name of the Tool to call. - :param arguments: + :param arguments: Either the full arguments in JSON format or a delta of the arguments. 
""" - index: int id: Optional[str] = None # noqa: A003 name: Optional[str] = None arguments: Optional[str] = None @@ -28,8 +27,8 @@ class ToolCallDelta: def __post_init__(self): if self.name is None and self.arguments is None: raise ValueError("At least one of tool_name or arguments must be provided.") - if self.name is not None and self.arguments is not None: - raise ValueError("Only one of tool_name or arguments can be provided.") + # NOTE: We allow for name and arguments to both be present because some providers like Mistral provide the + # name and full arguments in one chunk @dataclass @@ -40,25 +39,28 @@ class StreamingChunk: This structure facilitates the handling and processing of streamed data in a systematic manner. :param content: The content of the message chunk as a string. - :param tool_calls: An optional ToolCallDelta object representing a tool call associated with the message chunk. - :param tool_call_results: An optional ToolCallResult object representing the result of a tool call. - :param start: A boolean indicating whether this chunk marks the start of a message. :param meta: A dictionary containing metadata related to the message chunk. + :param index: An optional integer index representing which content block this chunk belongs to. + :param tool_call: An optional ToolCallDelta object representing a tool call associated with the message chunk. + :param tool_call_result: An optional ToolCallResult object representing the result of a tool call. + :param start: A boolean indicating whether this chunk marks the start of a content block. """ content: str meta: Dict[str, Any] = field(default_factory=dict, hash=False) - tool_calls: Optional[List[ToolCallDelta]] = None - tool_call_results: Optional[List[ToolCallResult]] = None + index: Optional[int] = None + tool_call: Optional[ToolCallDelta] = None + tool_call_result: Optional[ToolCallResult] = None start: Optional[bool] = None def __post_init__(self): - if self.tool_calls and self.content: - raise ValueError("A StreamingChunk should not have both content and tool calls.") - if self.tool_call_results and self.content: - raise ValueError("A StreamingChunk should not have both content and tool call results.") - if self.tool_calls and self.tool_call_results: - raise ValueError("A StreamingChunk should not have both tool calls and tool call results.") + fields_set = sum(bool(x) for x in (self.content, self.tool_call, self.tool_call_result)) + if fields_set > 1: + raise ValueError( + "Only one of `content`, `tool_call`, or `tool_call_result` may be set in a StreamingChunk. 
" + f"Got content: '{self.content}', tool_call: '{self.tool_call}', " + f"tool_call_result: '{self.tool_call_result}'" + ) SyncStreamingCallbackT = Callable[[StreamingChunk], None] From 005ef6934cd8c11f18244ced471bfa9371dffe38 Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Thu, 22 May 2025 12:08:25 +0200 Subject: [PATCH 04/11] Some tweaks and also make ToolInvoker stream a chunk with a finish reason --- haystack/components/generators/utils.py | 12 +++++------- haystack/components/tools/tool_invoker.py | 4 ++++ 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/haystack/components/generators/utils.py b/haystack/components/generators/utils.py index 34d00236cd..8b539b9fe5 100644 --- a/haystack/components/generators/utils.py +++ b/haystack/components/generators/utils.py @@ -20,14 +20,15 @@ def print_streaming_chunk(chunk: StreamingChunk) -> None: :param chunk: A chunk of streaming data containing content and optional metadata, such as tool calls and tool results. """ + if chunk.start and chunk.index > 0: + # If this is not the first content block of the message, add two new lines + print("\n\n", flush=True, end="") + ## Tool Call streaming if chunk.tool_call: # Presence of tool_name indicates beginning of a tool call # or chunk.tool_call.name: would be equivalent here if chunk.start: - # If this is not the first content block of the message, add two new lines - if chunk.index > 0: - print("\n\n", flush=True, end="") print("[TOOL CALL]\n", flush=True, end="") print(f"Tool: {chunk.tool_call.name} ", flush=True, end="") print("\nArguments: ", flush=True, end="") @@ -40,15 +41,12 @@ def print_streaming_chunk(chunk: StreamingChunk) -> None: # Print tool call results if available (from ToolInvoker) if chunk.tool_call_result: # Tool Call Result is fully formed so delta accumulation is not needed - print(f"[TOOL RESULT]\n{chunk.tool_call_result}\n\n", flush=True, end="") + print(f"[TOOL RESULT]\n{chunk.tool_call_result}", flush=True, end="") ## Normal content streaming # Print the main content of the chunk (from ChatGenerator) if chunk.content: if chunk.start: - # If this is not the first content block of the message, add two new lines - if chunk.index > 0: - print("\n\n", flush=True, end="") print("[ASSISTANT]\n", flush=True, end="") print(chunk.content, flush=True, end="") diff --git a/haystack/components/tools/tool_invoker.py b/haystack/components/tools/tool_invoker.py index 21ff39b7dd..33ddfd4df5 100644 --- a/haystack/components/tools/tool_invoker.py +++ b/haystack/components/tools/tool_invoker.py @@ -510,6 +510,10 @@ def run( ) ) + # We stream one more chunk that contains a finish_reason + if streaming_callback is not None: + streaming_callback(StreamingChunk(content="", meta={"finish_reason": "tool_call_results"})) + return {"tool_messages": tool_messages, "state": state} @component.output_types(tool_messages=List[ChatMessage], state=State) From e29b6f2f56459d192b248ad8ead2e1a6c963af35 Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Thu, 22 May 2025 12:13:04 +0200 Subject: [PATCH 05/11] Properly update test --- test/components/generators/chat/test_openai.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/test/components/generators/chat/test_openai.py b/test/components/generators/chat/test_openai.py index 14c6dbcac6..2e3bfdb46f 100644 --- a/test/components/generators/chat/test_openai.py +++ b/test/components/generators/chat/test_openai.py @@ -1090,10 +1090,11 @@ def test_live_run_with_tools_streaming(self, tools): assert not message.texts 
assert not message.text assert message.tool_calls - tool_call = message.tool_calls - assert isinstance(tool_call, ToolCall) - assert tool_call.tool_name == "weather" - assert tool_call.arguments == {"city": "Paris"} + tool_calls = message.tool_calls + for tool_call in tool_calls: + assert isinstance(tool_call, ToolCall) + assert tool_call.tool_name == "weather" + assert tool_call.arguments == {"city": "Paris"} or tool_call.arguments == {"city": "Berlin"} assert message.meta["finish_reason"] == "tool_calls" def test_openai_chat_generator_with_toolset_initialization(self, tools, monkeypatch): From 5914d5bcc921a486e52aa92cbe764c4b03de6a08 Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Thu, 22 May 2025 12:33:51 +0200 Subject: [PATCH 06/11] Change to tool_name, remove kw_only since its python 3.10 only and update HuggingFaceAPIChatGenerator to start following new StreamingChunk --- .../generators/chat/hugging_face_api.py | 36 +++++++++++-------- haystack/components/generators/chat/openai.py | 4 ++- haystack/dataclasses/streaming_chunk.py | 8 ++--- 3 files changed, 28 insertions(+), 20 deletions(-) diff --git a/haystack/components/generators/chat/hugging_face_api.py b/haystack/components/generators/chat/hugging_face_api.py index f88afdd4f3..93349ac099 100644 --- a/haystack/components/generators/chat/hugging_face_api.py +++ b/haystack/components/generators/chat/hugging_face_api.py @@ -407,7 +407,6 @@ def _run_streaming( first_chunk_time = None finish_reason = None usage = None - meta: Dict[str, Any] = {} for chunk in api_output: # The chunk with usage returns an empty array for choices @@ -423,7 +422,13 @@ def _run_streaming( if choice.finish_reason: finish_reason = choice.finish_reason - stream_chunk = StreamingChunk(text, meta) + stream_chunk = StreamingChunk( + content=text, + index=choice.index, + # TODO Correctly evaluate start + start=None, + meta={"model": chunk.model, "finish_reason": choice.finish_reason}, + ) streaming_callback(stream_chunk) if chunk.usage: @@ -437,17 +442,16 @@ def _run_streaming( else: usage_dict = {"prompt_tokens": 0, "completion_tokens": 0} - meta.update( - { + message = ChatMessage.from_assistant( + text=generated_text, + meta={ "model": self._client.model, "index": 0, "finish_reason": finish_reason, "usage": usage_dict, "completion_start_time": first_chunk_time, - } + }, ) - - message = ChatMessage.from_assistant(text=generated_text, meta=meta) return {"replies": [message]} def _run_non_streaming( @@ -503,7 +507,6 @@ async def _run_streaming_async( first_chunk_time = None finish_reason = None usage = None - meta: Dict[str, Any] = {} async for chunk in api_output: # The chunk with usage returns an empty array for choices @@ -519,7 +522,13 @@ async def _run_streaming_async( if choice.finish_reason: finish_reason = choice.finish_reason - stream_chunk = StreamingChunk(text, meta) + stream_chunk = StreamingChunk( + content=text, + index=choice.index, + # TODO Correctly evaluate start + start=None, + meta={"model": chunk.model, "finish_reason": choice.finish_reason}, + ) await streaming_callback(stream_chunk) # type: ignore if chunk.usage: @@ -533,17 +542,16 @@ async def _run_streaming_async( else: usage_dict = {"prompt_tokens": 0, "completion_tokens": 0} - meta.update( - { + message = ChatMessage.from_assistant( + text=generated_text, + meta={ "model": self._async_client.model, "index": 0, "finish_reason": finish_reason, "usage": usage_dict, "completion_start_time": first_chunk_time, - } + }, ) - - message = ChatMessage.from_assistant(text=generated_text, 
meta=meta) return {"replies": [message]} async def _run_non_streaming_async( diff --git a/haystack/components/generators/chat/openai.py b/haystack/components/generators/chat/openai.py index 38062940f0..f4b322611e 100644 --- a/haystack/components/generators/chat/openai.py +++ b/haystack/components/generators/chat/openai.py @@ -610,7 +610,9 @@ def _convert_chat_completion_chunk_to_streaming_chunk( # We adopt the tool_call.index as the index of the chunk index=tool_call.index, tool_call=ToolCallDelta( - id=tool_call.id, name=tool_call.function.name, arguments=tool_call.function.arguments or None + id=tool_call.id, + tool_name=tool_call.function.name, + arguments=tool_call.function.arguments or None, ), start=tool_call.function.name is not None, meta={ diff --git a/haystack/dataclasses/streaming_chunk.py b/haystack/dataclasses/streaming_chunk.py index 0338454732..a611fc8cf5 100644 --- a/haystack/dataclasses/streaming_chunk.py +++ b/haystack/dataclasses/streaming_chunk.py @@ -9,23 +9,21 @@ from haystack.utils.asynchronous import is_callable_async_compatible -# Similar to ChoiceDeltaToolCall from OpenAI -@dataclass(kw_only=True) class ToolCallDelta: """ Represents a Tool call prepared by the model, usually contained in an assistant message. :param id: The ID of the Tool call. - :param name: The name of the Tool to call. + :param tool_name: The name of the Tool to call. :param arguments: Either the full arguments in JSON format or a delta of the arguments. """ id: Optional[str] = None # noqa: A003 - name: Optional[str] = None + tool_name: Optional[str] = None arguments: Optional[str] = None def __post_init__(self): - if self.name is None and self.arguments is None: + if self.tool_name is None and self.arguments is None: raise ValueError("At least one of tool_name or arguments must be provided.") # NOTE: We allow for name and arguments to both be present because some providers like Mistral provide the # name and full arguments in one chunk From d141b47ed91dd51f48f5e427e6620b2c617cb733 Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Thu, 22 May 2025 13:30:54 +0200 Subject: [PATCH 07/11] Add reno --- ...update-streaming-chunk-more-info-9008e05b21eef349.yaml | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 releasenotes/notes/update-streaming-chunk-more-info-9008e05b21eef349.yaml diff --git a/releasenotes/notes/update-streaming-chunk-more-info-9008e05b21eef349.yaml b/releasenotes/notes/update-streaming-chunk-more-info-9008e05b21eef349.yaml new file mode 100644 index 0000000000..e2adc5cf7d --- /dev/null +++ b/releasenotes/notes/update-streaming-chunk-more-info-9008e05b21eef349.yaml @@ -0,0 +1,8 @@ +--- +features: + - | + Updated StreamingChunk to add the fields `tool_call`, `tool_call_result`, `index`, and `start` to make it easier to format the stream in a streaming callback. + - Added new dataclass ToolCallDelta for the `StreamingChunk.tool_call` field to reflect that the arguments can be a string delta. + - Updated `print_streaming_chunk` utility method to use these new fields. This especially improves the formatting when using this with Agent. + - Updated `OpenAIChatGenerator` and `HuggingFaceAPIChatGenerator` to follow the new dataclass. + - Updated `ToolInvoker` to follow the new format and also it now returns a final StreamingChunk that contains `finish_reason="tool_call_result"` in its metadata. 
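For illustration, a minimal sketch of a custom streaming callback built on the new StreamingChunk fields described in the release note above. The callback name and the sample chunks are hypothetical; only StreamingChunk and ToolCallDelta from haystack.dataclasses, as added in these patches, are assumed.

from haystack.dataclasses import StreamingChunk, ToolCallDelta

def my_streaming_callback(chunk: StreamingChunk) -> None:
    # Hypothetical callback: route each chunk by the new typed fields instead of parsing chunk.meta.
    if chunk.tool_call:
        if chunk.start:
            # The first delta of a tool call carries the tool name; later deltas stream the arguments.
            print(f"\n[tool call] {chunk.tool_call.tool_name}", flush=True)
        if chunk.tool_call.arguments:
            print(chunk.tool_call.arguments, end="", flush=True)
    elif chunk.tool_call_result:
        # Tool call results arrive fully formed, so no delta accumulation is needed.
        print(f"\n[tool result] {chunk.tool_call_result.result}", flush=True)
    elif chunk.content:
        print(chunk.content, end="", flush=True)
    # A finish_reason in the metadata marks the end of the current message.
    if chunk.meta.get("finish_reason") is not None:
        print(flush=True)

# Hypothetical chunks, shaped like those the updated generators emit:
my_streaming_callback(
    StreamingChunk(content="", index=0, start=True,
                   tool_call=ToolCallDelta(id="call_1", tool_name="weather"))
)
my_streaming_callback(
    StreamingChunk(content="", index=0,
                   tool_call=ToolCallDelta(arguments='{"city": "Paris"}'))
)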
From ac51918616ae9d1de2b58718c9389c859ad978dd Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Thu, 22 May 2025 13:31:09 +0200 Subject: [PATCH 08/11] Some cleanup --- haystack/components/generators/chat/openai.py | 1 - haystack/components/generators/hugging_face_api.py | 1 + haystack/components/generators/openai.py | 9 +++++---- haystack/components/generators/utils.py | 6 +++--- haystack/dataclasses/streaming_chunk.py | 1 + 5 files changed, 10 insertions(+), 8 deletions(-) diff --git a/haystack/components/generators/chat/openai.py b/haystack/components/generators/chat/openai.py index f4b322611e..be318fa459 100644 --- a/haystack/components/generators/chat/openai.py +++ b/haystack/components/generators/chat/openai.py @@ -634,7 +634,6 @@ def _convert_chat_completion_chunk_to_streaming_chunk( # NOTE: We may need to revisit this if OpenAI allows planning/thinking content before tool calls like # Anthropic/Bedrock index=0, - tool_call=None, # The first chunk is always a start message chunk, so if we reach here and previous_chunks is length 1 # then this is the start of text content start=len(previous_chunks) == 1 or None, diff --git a/haystack/components/generators/hugging_face_api.py b/haystack/components/generators/hugging_face_api.py index 6d62670f3d..5402afc0af 100644 --- a/haystack/components/generators/hugging_face_api.py +++ b/haystack/components/generators/hugging_face_api.py @@ -227,6 +227,7 @@ def _stream_and_build_response( if first_chunk_time is None: first_chunk_time = datetime.now().isoformat() + # TODO Consider adding start stream_chunk = StreamingChunk(content=token.text, meta=chunk_metadata) chunks.append(stream_chunk) streaming_callback(stream_chunk) diff --git a/haystack/components/generators/openai.py b/haystack/components/generators/openai.py index 901b06eb2d..838670ab10 100644 --- a/haystack/components/generators/openai.py +++ b/haystack/components/generators/openai.py @@ -320,14 +320,15 @@ def _build_chunk(chunk: Any) -> StreamingChunk: """ choice = chunk.choices[0] content = choice.delta.content or "" - chunk_message = StreamingChunk(content=content) - chunk_message.meta.update( - { + # TODO Consider adding start + chunk_message = StreamingChunk( + content=content, + meta={ "model": chunk.model, "index": choice.index, "finish_reason": choice.finish_reason, "received_at": datetime.now().isoformat(), - } + }, ) return chunk_message diff --git a/haystack/components/generators/utils.py b/haystack/components/generators/utils.py index 8b539b9fe5..7eee5a8d2c 100644 --- a/haystack/components/generators/utils.py +++ b/haystack/components/generators/utils.py @@ -20,7 +20,7 @@ def print_streaming_chunk(chunk: StreamingChunk) -> None: :param chunk: A chunk of streaming data containing content and optional metadata, such as tool calls and tool results. 
""" - if chunk.start and chunk.index > 0: + if chunk.start and chunk.index and chunk.index > 0: # If this is not the first content block of the message, add two new lines print("\n\n", flush=True, end="") @@ -30,7 +30,7 @@ def print_streaming_chunk(chunk: StreamingChunk) -> None: # or chunk.tool_call.name: would be equivalent here if chunk.start: print("[TOOL CALL]\n", flush=True, end="") - print(f"Tool: {chunk.tool_call.name} ", flush=True, end="") + print(f"Tool: {chunk.tool_call.tool_name} ", flush=True, end="") print("\nArguments: ", flush=True, end="") # print the tool arguments @@ -41,7 +41,7 @@ def print_streaming_chunk(chunk: StreamingChunk) -> None: # Print tool call results if available (from ToolInvoker) if chunk.tool_call_result: # Tool Call Result is fully formed so delta accumulation is not needed - print(f"[TOOL RESULT]\n{chunk.tool_call_result}", flush=True, end="") + print(f"[TOOL RESULT]\n{chunk.tool_call_result.result}", flush=True, end="") ## Normal content streaming # Print the main content of the chunk (from ChatGenerator) diff --git a/haystack/dataclasses/streaming_chunk.py b/haystack/dataclasses/streaming_chunk.py index a611fc8cf5..e7330277c5 100644 --- a/haystack/dataclasses/streaming_chunk.py +++ b/haystack/dataclasses/streaming_chunk.py @@ -9,6 +9,7 @@ from haystack.utils.asynchronous import is_callable_async_compatible +@dataclass class ToolCallDelta: """ Represents a Tool call prepared by the model, usually contained in an assistant message. From 012c0bb425c5eb81b6936c3c1785a8d2e82f50ec Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Thu, 22 May 2025 14:01:19 +0200 Subject: [PATCH 09/11] Fix unit tests --- .../components/generators/chat/test_openai.py | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/test/components/generators/chat/test_openai.py b/test/components/generators/chat/test_openai.py index 2e3bfdb46f..711f310e0d 100644 --- a/test/components/generators/chat/test_openai.py +++ b/test/components/generators/chat/test_openai.py @@ -509,7 +509,7 @@ def test_run_with_tools(self, tools): assert not message.text assert message.tool_calls - tool_call = message.tool_calls + tool_call = message.tool_call assert isinstance(tool_call, ToolCall) assert tool_call.tool_name == "weather" assert tool_call.arguments == {"city": "Paris"} @@ -542,7 +542,7 @@ def streaming_callback(chunk: StreamingChunk) -> None: message = response["replies"][0] assert message.tool_calls - tool_call = message.tool_calls + tool_call = message.tool_call assert isinstance(tool_call, ToolCall) assert tool_call.tool_name == "weather" assert tool_call.arguments == {"city": "Paris"} @@ -959,18 +959,18 @@ def test_handle_stream_response(self): # Verify both tool calls were found and processed assert len(result.tool_calls) == 2 - assert result.tool_calls[0].id == "call_ZOj5l67zhZOx6jqjg7ATQwb6" - assert result.tool_calls[0].tool_name == "rag_pipeline_tool" - assert result.tool_calls[0].arguments == {"query": "Where does Mark live?"} - assert result.tool_calls[1].id == "call_STxsYY69wVOvxWqopAt3uWTB" - assert result.tool_calls[1].tool_name == "get_weather" + assert result.tool_calls[0].id == "call_zcvlnVaTeJWRjLAFfYxX69z4" + assert result.tool_calls[0].tool_name == "weather" + assert result.tool_calls[0].arguments == {"city": "Paris"} + assert result.tool_calls[1].id == "call_C88m67V16CrETq6jbNXjdZI9" + assert result.tool_calls[1].tool_name == "weather" assert result.tool_calls[1].arguments == {"city": "Berlin"} # Verify meta information assert 
result.meta["model"] == "gpt-4o-mini-2024-07-18" assert result.meta["finish_reason"] == "tool_calls" assert result.meta["index"] == 0 - assert result.meta["completion_start_time"] == "2025-02-19T16:02:55.910076" + assert result.meta["completion_start_time"] is not None def test_convert_usage_chunk_to_streaming_chunk(self): component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key")) @@ -992,8 +992,11 @@ def test_convert_usage_chunk_to_streaming_chunk(self): prompt_tokens_details=PromptTokensDetails(audio_tokens=0, cached_tokens=0), ), ) - result = component._convert_chat_completion_chunk_to_streaming_chunk(chunk) + result = component._convert_chat_completion_chunk_to_streaming_chunk(chunk=chunk, previous_chunks=[])[0] assert result.content == "" + assert result.start is None + assert result.tool_call is None + assert result.tool_call_result is None assert result.meta["model"] == "gpt-4o-mini-2024-07-18" assert result.meta["received_at"] is not None From 604832897c4952389d06eee2bbc4961b4c662863 Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Thu, 22 May 2025 14:44:14 +0200 Subject: [PATCH 10/11] Fix mypy and integration test --- haystack/components/generators/chat/openai.py | 7 ++++--- test/components/generators/chat/test_openai.py | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/haystack/components/generators/chat/openai.py b/haystack/components/generators/chat/openai.py index be318fa459..cfac0183a5 100644 --- a/haystack/components/generators/chat/openai.py +++ b/haystack/components/generators/chat/openai.py @@ -605,16 +605,17 @@ def _convert_chat_completion_chunk_to_streaming_chunk( if choice.delta.tool_calls: chunk_messages = [] for tool_call in choice.delta.tool_calls: + function = tool_call.function chunk_message = StreamingChunk( content=content, # We adopt the tool_call.index as the index of the chunk index=tool_call.index, tool_call=ToolCallDelta( id=tool_call.id, - tool_name=tool_call.function.name, - arguments=tool_call.function.arguments or None, + tool_name=function.name if function else None, + arguments=function.arguments if function and function.arguments else None, ), - start=tool_call.function.name is not None, + start=function.name is not None if function else None, meta={ "model": chunk.model, "index": choice.index, diff --git a/test/components/generators/chat/test_openai.py b/test/components/generators/chat/test_openai.py index 711f310e0d..eda4d9b934 100644 --- a/test/components/generators/chat/test_openai.py +++ b/test/components/generators/chat/test_openai.py @@ -1136,7 +1136,7 @@ def test_live_run_with_toolset(self, tools): assert not message.texts assert not message.text assert message.tool_calls - tool_call = message.tool_calls + tool_call = message.tool_call assert isinstance(tool_call, ToolCall) assert tool_call.tool_name == "weather" assert tool_call.arguments == {"city": "Paris"} From 010c037c7b8d3fd717692058db9ee22a478acc9b Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Thu, 22 May 2025 17:52:13 +0200 Subject: [PATCH 11/11] Fix pylint --- haystack/dataclasses/streaming_chunk.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/haystack/dataclasses/streaming_chunk.py b/haystack/dataclasses/streaming_chunk.py index e7330277c5..588c067e52 100644 --- a/haystack/dataclasses/streaming_chunk.py +++ b/haystack/dataclasses/streaming_chunk.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass, field -from typing import Any, Awaitable, Callable, Dict, List, Optional, 
Union +from typing import Any, Awaitable, Callable, Dict, Optional, Union from haystack.dataclasses.chat_message import ToolCallResult from haystack.utils.asynchronous import is_callable_async_compatible
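For reference, the interplay between `ToolCallDelta` and the new `StreamingChunk` fields can be sketched as follows. The sample chunks and the `join_tool_call_arguments` helper are illustrative only and not part of this change; they show how a generator emits the tool name in a `start` chunk and streams the arguments as deltas that downstream code can reassemble, assuming both classes are importable from `haystack.dataclasses`.

```python
from typing import List

from haystack.dataclasses import StreamingChunk, ToolCallDelta

# Chunks roughly as a chat generator would emit them for a single tool call:
# the first chunk carries the tool name (start=True), later chunks carry argument deltas.
chunks = [
    StreamingChunk(content="", index=0, start=True, tool_call=ToolCallDelta(id="call_1", tool_name="weather")),
    StreamingChunk(content="", index=0, tool_call=ToolCallDelta(arguments='{"city": ')),
    StreamingChunk(content="", index=0, tool_call=ToolCallDelta(arguments='"Paris"}')),
]


def join_tool_call_arguments(tool_call_chunks: List[StreamingChunk]) -> str:
    # Concatenate the argument deltas of one tool call back into the full JSON arguments string.
    return "".join(c.tool_call.arguments for c in tool_call_chunks if c.tool_call and c.tool_call.arguments)


assert join_tool_call_arguments(chunks) == '{"city": "Paris"}'
```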