diff --git a/haystack/components/generators/chat/hugging_face_api.py b/haystack/components/generators/chat/hugging_face_api.py index 28dc674b38..fd0884969a 100644 --- a/haystack/components/generators/chat/hugging_face_api.py +++ b/haystack/components/generators/chat/hugging_face_api.py @@ -103,7 +103,9 @@ def _convert_tools_to_hfapi_tools( def _convert_chat_completion_stream_output_to_streaming_chunk( - chunk: "ChatCompletionStreamOutput", component_info: Optional[ComponentInfo] = None + chunk: "ChatCompletionStreamOutput", + previous_chunks: List[StreamingChunk], + component_info: Optional[ComponentInfo] = None, ) -> StreamingChunk: """ Converts the Hugging Face API ChatCompletionStreamOutput to a StreamingChunk. @@ -127,6 +129,8 @@ def _convert_chat_completion_stream_output_to_streaming_chunk( content=choice.delta.content or "", meta={"model": chunk.model, "received_at": datetime.now().isoformat(), "finish_reason": choice.finish_reason}, component_info=component_info, + index=0 if choice.finish_reason is None else None, + start=True if len(previous_chunks) == 0 else None, ) return stream_chunk @@ -434,10 +438,10 @@ def _run_streaming( ) component_info = ComponentInfo.from_component(self) - streaming_chunks = [] + streaming_chunks: List[StreamingChunk] = [] for chunk in api_output: streaming_chunk = _convert_chat_completion_stream_output_to_streaming_chunk( - chunk=chunk, component_info=component_info + chunk=chunk, previous_chunks=streaming_chunks, component_info=component_info ) streaming_chunks.append(streaming_chunk) streaming_callback(streaming_chunk) @@ -498,10 +502,10 @@ async def _run_streaming_async( ) component_info = ComponentInfo.from_component(self) - streaming_chunks = [] + streaming_chunks: List[StreamingChunk] = [] async for chunk in api_output: stream_chunk = _convert_chat_completion_stream_output_to_streaming_chunk( - chunk=chunk, component_info=component_info + chunk=chunk, previous_chunks=streaming_chunks, component_info=component_info ) streaming_chunks.append(stream_chunk) await streaming_callback(stream_chunk) # type: ignore diff --git a/haystack/components/generators/chat/openai.py b/haystack/components/generators/chat/openai.py index 9b8a374842..3a8d50de5a 100644 --- a/haystack/components/generators/chat/openai.py +++ b/haystack/components/generators/chat/openai.py @@ -22,6 +22,7 @@ StreamingChunk, SyncStreamingCallbackT, ToolCall, + ToolCallDelta, select_streaming_callback, ) from haystack.tools import ( @@ -422,9 +423,12 @@ def _handle_stream_response(self, chat_completion: Stream, callback: SyncStreami chunks: List[StreamingChunk] = [] for chunk in chat_completion: # pylint: disable=not-an-iterable assert len(chunk.choices) <= 1, "Streaming responses should have at most one choice." - chunk_delta = _convert_chat_completion_chunk_to_streaming_chunk(chunk=chunk, component_info=component_info) - chunks.append(chunk_delta) - callback(chunk_delta) + chunk_deltas = _convert_chat_completion_chunk_to_streaming_chunk( + chunk=chunk, previous_chunks=chunks, component_info=component_info + ) + for chunk_delta in chunk_deltas: + chunks.append(chunk_delta) + callback(chunk_delta) return [_convert_streaming_chunks_to_chat_message(chunks=chunks)] async def _handle_async_stream_response( @@ -434,9 +438,12 @@ async def _handle_async_stream_response( chunks: List[StreamingChunk] = [] async for chunk in chat_completion: # pylint: disable=not-an-iterable assert len(chunk.choices) <= 1, "Streaming responses should have at most one choice." 
-            chunk_delta = _convert_chat_completion_chunk_to_streaming_chunk(chunk=chunk, component_info=component_info)
-            chunks.append(chunk_delta)
-            await callback(chunk_delta)
+            chunk_deltas = _convert_chat_completion_chunk_to_streaming_chunk(
+                chunk=chunk, previous_chunks=chunks, component_info=component_info
+            )
+            for chunk_delta in chunk_deltas:
+                chunks.append(chunk_delta)
+                await callback(chunk_delta)
         return [_convert_streaming_chunks_to_chat_message(chunks=chunks)]
@@ -497,12 +504,13 @@ def _convert_chat_completion_to_chat_message(completion: ChatCompletion, choice:


 def _convert_chat_completion_chunk_to_streaming_chunk(
-    chunk: ChatCompletionChunk, component_info: Optional[ComponentInfo] = None
-) -> StreamingChunk:
+    chunk: ChatCompletionChunk, previous_chunks: List[StreamingChunk], component_info: Optional[ComponentInfo] = None
+) -> List[StreamingChunk]:
     """
     Converts the streaming response chunk from the OpenAI API to a StreamingChunk.

     :param chunk: The chunk returned by the OpenAI API.
+    :param previous_chunks: The previous chunks processed from the OpenAI API.

-    :returns: The StreamingChunk.
+    :returns: A list of StreamingChunk objects.
@@ -510,21 +518,62 @@ def _convert_chat_completion_chunk_to_streaming_chunk(
     # Choices is empty on the very first chunk which provides role information (e.g. "assistant").
     # It is also empty if include_usage is set to True where the usage information is returned.
     if len(chunk.choices) == 0:
-        return StreamingChunk(
-            content="",
-            meta={
-                "model": chunk.model,
-                "received_at": datetime.now().isoformat(),
-                "usage": _serialize_usage(chunk.usage),
-            },
-            component_info=component_info,
-        )
+        return [
+            StreamingChunk(
+                content="",
+                component_info=component_info,
+                # Index is None since it's only set to an int when a content block is present
+                index=None,
+                meta={
+                    "model": chunk.model,
+                    "received_at": datetime.now().isoformat(),
+                    "usage": _serialize_usage(chunk.usage),
+                },
+            )
+        ]

     choice: ChunkChoice = chunk.choices[0]
     content = choice.delta.content or ""

+    # Create a list of ToolCallDelta objects from the tool calls
+    if choice.delta.tool_calls:
+        chunk_messages = []
+        for tool_call in choice.delta.tool_calls:
+            function = tool_call.function
+            chunk_message = StreamingChunk(
+                content=content,
+                # We adopt the tool_call.index as the index of the chunk
+                component_info=component_info,
+                index=tool_call.index,
+                tool_call=ToolCallDelta(
+                    id=tool_call.id,
+                    tool_name=function.name if function else None,
+                    arguments=function.arguments if function and function.arguments else None,
+                ),
+                start=function.name is not None if function else None,
+                meta={
+                    "model": chunk.model,
+                    "index": choice.index,
+                    "tool_calls": choice.delta.tool_calls,
+                    "finish_reason": choice.finish_reason,
+                    "received_at": datetime.now().isoformat(),
+                    "usage": _serialize_usage(chunk.usage),
+                },
+            )
+            chunk_messages.append(chunk_message)
+        return chunk_messages
+
     chunk_message = StreamingChunk(
         content=content,
+        component_info=component_info,
+        # We set the index to 0 since, if text content is being streamed, no tool calls are being streamed
+        # NOTE: We may need to revisit this if OpenAI allows planning/thinking content before tool calls like
+        # Anthropic Claude
+        index=0,
+        # The first chunk is always a start message chunk, so if we reach here and previous_chunks has length 1,
+        # this is the start of text content. The previous length should be 1 since the first chunk always contains
+        # the role information.
+ start=len(previous_chunks) == 1 or None, meta={ "model": chunk.model, "index": choice.index, @@ -533,9 +582,8 @@ def _convert_chat_completion_chunk_to_streaming_chunk( "received_at": datetime.now().isoformat(), "usage": _serialize_usage(chunk.usage), }, - component_info=component_info, ) - return chunk_message + return [chunk_message] def _serialize_usage(usage): diff --git a/haystack/components/generators/hugging_face_api.py b/haystack/components/generators/hugging_face_api.py index f30d37ce2a..b2ed507788 100644 --- a/haystack/components/generators/hugging_face_api.py +++ b/haystack/components/generators/hugging_face_api.py @@ -230,7 +230,13 @@ def _stream_and_build_response( if first_chunk_time is None: first_chunk_time = datetime.now().isoformat() - stream_chunk = StreamingChunk(content=token.text, meta=chunk_metadata, component_info=component_info) + stream_chunk = StreamingChunk( + content=token.text, + meta=chunk_metadata, + component_info=component_info, + index=0, + start=True if len(chunks) == 0 else None, + ) chunks.append(stream_chunk) streaming_callback(stream_chunk) diff --git a/haystack/components/generators/openai.py b/haystack/components/generators/openai.py index e83ce6b0a3..40e36a9ad3 100644 --- a/haystack/components/generators/openai.py +++ b/haystack/components/generators/openai.py @@ -247,8 +247,9 @@ def run( for chunk in completion: chunk_delta: StreamingChunk = _convert_chat_completion_chunk_to_streaming_chunk( chunk=chunk, # type: ignore + previous_chunks=chunks, component_info=component_info, - ) + )[0] chunks.append(chunk_delta) streaming_callback(chunk_delta) diff --git a/haystack/components/generators/utils.py b/haystack/components/generators/utils.py index 33fd3cb5be..19c33aec37 100644 --- a/haystack/components/generators/utils.py +++ b/haystack/components/generators/utils.py @@ -3,9 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 import json -from typing import Any, Dict, List - -from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall +from typing import Dict, List from haystack import logging from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall @@ -28,33 +26,38 @@ def print_streaming_chunk(chunk: StreamingChunk) -> None: :param chunk: A chunk of streaming data containing content and optional metadata, such as tool calls and tool results. 
""" - # Print tool call metadata if available (from ChatGenerator) - if tool_calls := chunk.meta.get("tool_calls"): - for tool_call in tool_calls: - # Convert to dict if tool_call is a ChoiceDeltaToolCall - tool_call_dict: Dict[str, Any] = ( - tool_call.to_dict() if isinstance(tool_call, ChoiceDeltaToolCall) else tool_call - ) + if chunk.start and chunk.index and chunk.index > 0: + # If this is the start of a new content block but not the first content block, print two new lines + print("\n\n", flush=True, end="") - if function := tool_call_dict.get("function"): - if name := function.get("name"): - print("\n\n[TOOL CALL]\n", flush=True, end="") - print(f"Tool: {name} ", flush=True, end="") - print("\nArguments: ", flush=True, end="") + ## Tool Call streaming + if chunk.tool_call: + # If chunk.start is True indicates beginning of a tool call + # Also presence of chunk.tool_call.name indicates the start of a tool call too + if chunk.start: + print("[TOOL CALL]\n", flush=True, end="") + print(f"Tool: {chunk.tool_call.tool_name} ", flush=True, end="") + print("\nArguments: ", flush=True, end="") - if arguments := function.get("arguments"): - print(arguments, flush=True, end="") + # print the tool arguments + if chunk.tool_call.arguments: + print(chunk.tool_call.arguments, flush=True, end="") + ## Tool Call Result streaming # Print tool call results if available (from ToolInvoker) - if tool_result := chunk.meta.get("tool_result"): - print(f"\n\n[TOOL RESULT]\n{tool_result}", flush=True, end="") + if chunk.tool_call_result: + # Tool Call Result is fully formed so delta accumulation is not needed + print(f"[TOOL RESULT]\n{chunk.tool_call_result.result}", flush=True, end="") + ## Normal content streaming # Print the main content of the chunk (from ChatGenerator) - if content := chunk.content: - print(content, flush=True, end="") + if chunk.content: + if chunk.start: + print("[ASSISTANT]\n", flush=True, end="") + print(chunk.content, flush=True, end="") # End of LLM assistant message so we add two new lines - # This ensures spacing between multiple LLM messages (e.g. Agent) + # This ensures spacing between multiple LLM messages (e.g. 
Agent) or multiple Tool Call Results if chunk.meta.get("finish_reason") is not None: print("\n\n", flush=True, end="") @@ -71,38 +74,41 @@ def _convert_streaming_chunks_to_chat_message(chunks: List[StreamingChunk]) -> C tool_calls = [] # Process tool calls if present in any chunk - tool_call_data: Dict[str, Dict[str, str]] = {} # Track tool calls by index - for chunk_payload in chunks: - tool_calls_meta = chunk_payload.meta.get("tool_calls") - if tool_calls_meta is not None: - for delta in tool_calls_meta: - # We use the index of the tool call to track it across chunks since the ID is not always provided - if delta.index not in tool_call_data: - tool_call_data[delta.index] = {"id": "", "name": "", "arguments": ""} - - # Save the ID if present - if delta.id is not None: - tool_call_data[delta.index]["id"] = delta.id - - if delta.function is not None: - if delta.function.name is not None: - tool_call_data[delta.index]["name"] += delta.function.name - if delta.function.arguments is not None: - tool_call_data[delta.index]["arguments"] += delta.function.arguments + tool_call_data: Dict[int, Dict[str, str]] = {} # Track tool calls by index + for chunk in chunks: + if chunk.tool_call: + # We do this to make mypy is happy, but we enforce index is not None in the StreamingChunk dataclass if + # tool_call is present + assert chunk.index is not None + + # We use the index of the chunk track the tool call across chunks since the ID is not always provided + if chunk.index not in tool_call_data: + tool_call_data[chunk.index] = {"id": "", "name": "", "arguments": ""} + + # Save the ID if present + if chunk.tool_call.id is not None: + tool_call_data[chunk.index]["id"] = chunk.tool_call.id + + if chunk.tool_call.tool_name is not None: + tool_call_data[chunk.index]["name"] += chunk.tool_call.tool_name + if chunk.tool_call.arguments is not None: + tool_call_data[chunk.index]["arguments"] += chunk.tool_call.arguments # Convert accumulated tool call data into ToolCall objects - for call_data in tool_call_data.values(): + sorted_keys = sorted(tool_call_data.keys()) + for key in sorted_keys: + tool_call = tool_call_data[key] try: - arguments = json.loads(call_data["arguments"]) - tool_calls.append(ToolCall(id=call_data["id"], tool_name=call_data["name"], arguments=arguments)) + arguments = json.loads(tool_call["arguments"]) + tool_calls.append(ToolCall(id=tool_call["id"], tool_name=tool_call["name"], arguments=arguments)) except json.JSONDecodeError: logger.warning( "OpenAI returned a malformed JSON string for tool call arguments. This tool call " "will be skipped. To always generate a valid JSON, set `tools_strict` to `True`. 
" "Tool call ID: {_id}, Tool name: {_name}, Arguments: {_arguments}", - _id=call_data["id"], - _name=call_data["name"], - _arguments=call_data["arguments"], + _id=tool_call["id"], + _name=tool_call["name"], + _arguments=tool_call["arguments"], ) # finish_reason can appear in different places so we look for the last one diff --git a/haystack/components/tools/tool_invoker.py b/haystack/components/tools/tool_invoker.py index 32c08a3018..ffdc43311d 100644 --- a/haystack/components/tools/tool_invoker.py +++ b/haystack/components/tools/tool_invoker.py @@ -504,6 +504,9 @@ def run( streaming_callback( StreamingChunk( content="", + index=len(tool_messages) - 1, + tool_call_result=tool_messages[-1].tool_call_results[0], + start=True, meta={"tool_result": tool_messages[-1].tool_call_results[0].result, "tool_call": tool_call}, ) ) @@ -609,6 +612,9 @@ async def run_async( await streaming_callback( StreamingChunk( content="", + index=len(tool_messages) - 1, + tool_call_result=tool_messages[-1].tool_call_results[0], + start=True, meta={"tool_result": tool_messages[-1].tool_call_results[0].result, "tool_call": tool_call}, ) ) # type: ignore[misc] # we have checked that streaming_callback is not None and async diff --git a/haystack/dataclasses/__init__.py b/haystack/dataclasses/__init__.py index 48c5544f49..f2adde2c66 100644 --- a/haystack/dataclasses/__init__.py +++ b/haystack/dataclasses/__init__.py @@ -19,6 +19,7 @@ "AsyncStreamingCallbackT", "StreamingCallbackT", "SyncStreamingCallbackT", + "ToolCallDelta", "select_streaming_callback", "ComponentInfo", ], @@ -37,6 +38,7 @@ StreamingCallbackT, StreamingChunk, SyncStreamingCallbackT, + ToolCallDelta, select_streaming_callback, ) else: diff --git a/haystack/dataclasses/streaming_chunk.py b/haystack/dataclasses/streaming_chunk.py index b7d1e91d65..85c2f2c8c3 100644 --- a/haystack/dataclasses/streaming_chunk.py +++ b/haystack/dataclasses/streaming_chunk.py @@ -6,9 +6,31 @@ from typing import Any, Awaitable, Callable, Dict, Optional, Union from haystack.core.component import Component +from haystack.dataclasses.chat_message import ToolCallResult from haystack.utils.asynchronous import is_callable_async_compatible +@dataclass +class ToolCallDelta: + """ + Represents a Tool call prepared by the model, usually contained in an assistant message. + + :param id: The ID of the Tool call. + :param tool_name: The name of the Tool to call. + :param arguments: Either the full arguments in JSON format or a delta of the arguments. + """ + + id: Optional[str] = field(default=None) # noqa: A003 + tool_name: Optional[str] = field(default=None) + arguments: Optional[str] = field(default=None) + + def __post_init__(self): + if self.tool_name is None and self.arguments is None: + raise ValueError("At least one of tool_name or arguments must be provided.") + # NOTE: We allow for name and arguments to both be present because some providers like Mistral provide the + # name and full arguments in one chunk + + @dataclass class ComponentInfo: """ @@ -48,11 +70,32 @@ class StreamingChunk: :param meta: A dictionary containing metadata related to the message chunk. :param component_info: A `ComponentInfo` object containing information about the component that generated the chunk, such as the component name and type. + :param index: An optional integer index representing which content block this chunk belongs to. + :param tool_call: An optional ToolCallDelta object representing a tool call associated with the message chunk. 
+ :param tool_call_result: An optional ToolCallResult object representing the result of a tool call. + :param start: A boolean indicating whether this chunk marks the start of a content block. """ content: str meta: Dict[str, Any] = field(default_factory=dict, hash=False) - component_info: Optional[ComponentInfo] = field(default=None, hash=False) + component_info: Optional[ComponentInfo] = field(default=None) + index: Optional[int] = field(default=None) + tool_call: Optional[ToolCallDelta] = field(default=None) + tool_call_result: Optional[ToolCallResult] = field(default=None) + start: Optional[bool] = field(default=None) + + def __post_init__(self): + fields_set = sum(bool(x) for x in (self.content, self.tool_call, self.tool_call_result)) + if fields_set > 1: + raise ValueError( + "Only one of `content`, `tool_call`, or `tool_call_result` may be set in a StreamingChunk. " + f"Got content: '{self.content}', tool_call: '{self.tool_call}', " + f"tool_call_result: '{self.tool_call_result}'" + ) + + # NOTE: We don't enforce this for self.content otherwise it would be a breaking change + if (self.tool_call or self.tool_call_result) and self.index is None: + raise ValueError("If `tool_call`, or `tool_call_result` is set, `index` must also be set.") SyncStreamingCallbackT = Callable[[StreamingChunk], None] diff --git a/haystack/utils/hf.py b/haystack/utils/hf.py index a787265874..a222c65eeb 100644 --- a/haystack/utils/hf.py +++ b/haystack/utils/hf.py @@ -365,9 +365,18 @@ def __init__( self.token_handler = stream_handler self.stop_words = stop_words or [] self.component_info = component_info + self._call_counter = 0 def on_finalized_text(self, word: str, stream_end: bool = False) -> None: """Callback function for handling the generated text.""" + self._call_counter += 1 word_to_send = word + "\n" if stream_end else word if word_to_send.strip() not in self.stop_words: - self.token_handler(StreamingChunk(content=word_to_send, component_info=self.component_info)) + self.token_handler( + StreamingChunk( + content=word_to_send, + index=0, + start=True if self._call_counter == 1 else None, + component_info=self.component_info, + ) + ) diff --git a/releasenotes/notes/update-streaming-chunk-more-info-9008e05b21eef349.yaml b/releasenotes/notes/update-streaming-chunk-more-info-9008e05b21eef349.yaml new file mode 100644 index 0000000000..e2adc5cf7d --- /dev/null +++ b/releasenotes/notes/update-streaming-chunk-more-info-9008e05b21eef349.yaml @@ -0,0 +1,8 @@ +--- +features: + - | + Updated StreamingChunk to add the fields `tool_call`, `tool_call_result`, `index`, and `start` to make it easier to format the stream in a streaming callback. + - Added new dataclass ToolCallDelta for the `StreamingChunk.tool_call` field to reflect that the arguments can be a string delta. + - Updated `print_streaming_chunk` utility method to use these new fields. This especially improves the formatting when using this with Agent. + - Updated `OpenAIChatGenerator` and `HuggingFaceAPIChatGenerator` to follow the new dataclass. + - Updated `ToolInvoker` to follow the new format and also it now returns a final StreamingChunk that contains `finish_reason="tool_call_result"` in its metadata. 
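Reviewer note (illustrative, not part of the diff): with the new fields, a custom streaming callback no longer needs to dig through `chunk.meta`. A minimal sketch of a consumer, assuming only the `StreamingChunk` fields introduced above (`start`, `index`, `tool_call`, `tool_call_result`); the callback name is hypothetical:

```python
from haystack.dataclasses import StreamingChunk


def my_streaming_callback(chunk: StreamingChunk) -> None:
    # start=True marks a new content block; index identifies which block the chunk belongs to
    if chunk.start and chunk.index is not None and chunk.index > 0:
        print()  # visually separate consecutive content blocks
    if chunk.tool_call:
        # tool_name arrives with the first delta of a tool call; arguments stream as string deltas
        if chunk.start:
            print(f"[TOOL CALL] {chunk.tool_call.tool_name} ", end="", flush=True)
        if chunk.tool_call.arguments:
            print(chunk.tool_call.arguments, end="", flush=True)
    elif chunk.tool_call_result:
        # Tool results arrive fully formed, so no delta accumulation is needed
        print(f"[TOOL RESULT] {chunk.tool_call_result.result}", flush=True)
    elif chunk.content:
        print(chunk.content, end="", flush=True)
```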
diff --git a/test/components/generators/chat/test_hugging_face_api.py b/test/components/generators/chat/test_hugging_face_api.py index bd8eda1e1c..049f01c08b 100644 --- a/test/components/generators/chat/test_hugging_face_api.py +++ b/test/components/generators/chat/test_hugging_face_api.py @@ -664,7 +664,7 @@ def test_convert_hfapi_tool_calls_invalid_type_arguments(self): assert len(tool_calls) == 0 @pytest.mark.parametrize( - "hf_stream_output, expected_stream_chunk", + "hf_stream_output, expected_stream_chunk, dummy_previous_chunks", [ ( ChatCompletionStreamOutput( @@ -685,7 +685,10 @@ def test_convert_hfapi_tool_calls_invalid_type_arguments(self): "model": "microsoft/Phi-3.5-mini-instruct", "finish_reason": None, }, + index=0, + start=True, ), + [], ), ( ChatCompletionStreamOutput( @@ -709,6 +712,7 @@ def test_convert_hfapi_tool_calls_invalid_type_arguments(self): "finish_reason": "stop", }, ), + [0], ), ( ChatCompletionStreamOutput( @@ -727,11 +731,16 @@ def test_convert_hfapi_tool_calls_invalid_type_arguments(self): "usage": {"completion_tokens": 2, "prompt_tokens": 21}, }, ), + [0, 1], ), ], ) - def test_convert_chat_completion_stream_output_to_streaming_chunk(self, hf_stream_output, expected_stream_chunk): - converted_stream_chunk = _convert_chat_completion_stream_output_to_streaming_chunk(chunk=hf_stream_output) + def test_convert_chat_completion_stream_output_to_streaming_chunk( + self, hf_stream_output, expected_stream_chunk, dummy_previous_chunks + ): + converted_stream_chunk = _convert_chat_completion_stream_output_to_streaming_chunk( + chunk=hf_stream_output, previous_chunks=dummy_previous_chunks + ) # Remove timestamp from comparison since it's always the current time converted_stream_chunk.meta.pop("received_at", None) expected_stream_chunk.meta.pop("received_at", None) diff --git a/test/components/generators/chat/test_openai.py b/test/components/generators/chat/test_openai.py index 6639f8674b..e53d867e80 100644 --- a/test/components/generators/chat/test_openai.py +++ b/test/components/generators/chat/test_openai.py @@ -14,6 +14,7 @@ from openai import OpenAIError from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage, ChatCompletionMessageToolCall from openai.types.chat.chat_completion import Choice +from openai.types.chat.chat_completion_chunk import ChoiceDelta, ChoiceDeltaToolCall, ChoiceDeltaToolCallFunction from openai.types.completion_usage import CompletionTokensDetails, CompletionUsage, PromptTokensDetails from openai.types.chat.chat_completion_message_tool_call import Function from openai.types.chat import chat_completion_chunk @@ -597,6 +598,268 @@ def test_invalid_tool_call_json(self, tools, caplog): assert message.meta["finish_reason"] == "tool_calls" assert message.meta["usage"]["completion_tokens"] == 47 + def test_handle_stream_response(self): + openai_chunks = [ + ChatCompletionChunk( + id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F", + choices=[chat_completion_chunk.Choice(delta=ChoiceDelta(role="assistant"), index=0)], + created=1747834733, + model="gpt-4o-mini-2024-07-18", + object="chat.completion.chunk", + service_tier="default", + system_fingerprint="fp_54eb4bd693", + ), + ChatCompletionChunk( + id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F", + choices=[ + chat_completion_chunk.Choice( + delta=ChoiceDelta( + tool_calls=[ + ChoiceDeltaToolCall( + index=0, + id="call_zcvlnVaTeJWRjLAFfYxX69z4", + function=ChoiceDeltaToolCallFunction(arguments="", name="weather"), + type="function", + ) + ] + ), + index=0, + ) + ], + 
created=1747834733, + model="gpt-4o-mini-2024-07-18", + object="chat.completion.chunk", + service_tier="default", + system_fingerprint="fp_54eb4bd693", + ), + ChatCompletionChunk( + id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F", + choices=[ + chat_completion_chunk.Choice( + delta=ChoiceDelta( + tool_calls=[ + ChoiceDeltaToolCall(index=0, function=ChoiceDeltaToolCallFunction(arguments='{"ci')) + ] + ), + index=0, + ) + ], + created=1747834733, + model="gpt-4o-mini-2024-07-18", + object="chat.completion.chunk", + service_tier="default", + system_fingerprint="fp_54eb4bd693", + ), + ChatCompletionChunk( + id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F", + choices=[ + chat_completion_chunk.Choice( + delta=ChoiceDelta( + tool_calls=[ + ChoiceDeltaToolCall(index=0, function=ChoiceDeltaToolCallFunction(arguments='ty": ')) + ] + ), + index=0, + ) + ], + created=1747834733, + model="gpt-4o-mini-2024-07-18", + object="chat.completion.chunk", + service_tier="default", + system_fingerprint="fp_54eb4bd693", + ), + ChatCompletionChunk( + id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F", + choices=[ + chat_completion_chunk.Choice( + delta=ChoiceDelta( + tool_calls=[ + ChoiceDeltaToolCall(index=0, function=ChoiceDeltaToolCallFunction(arguments='"Paris')) + ] + ), + index=0, + ) + ], + created=1747834733, + model="gpt-4o-mini-2024-07-18", + object="chat.completion.chunk", + service_tier="default", + system_fingerprint="fp_54eb4bd693", + ), + ChatCompletionChunk( + id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F", + choices=[ + chat_completion_chunk.Choice( + delta=ChoiceDelta( + tool_calls=[ + ChoiceDeltaToolCall(index=0, function=ChoiceDeltaToolCallFunction(arguments='"}')) + ] + ), + index=0, + ) + ], + created=1747834733, + model="gpt-4o-mini-2024-07-18", + object="chat.completion.chunk", + service_tier="default", + system_fingerprint="fp_54eb4bd693", + ), + ChatCompletionChunk( + id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F", + choices=[ + chat_completion_chunk.Choice( + delta=ChoiceDelta( + tool_calls=[ + ChoiceDeltaToolCall( + index=1, + id="call_C88m67V16CrETq6jbNXjdZI9", + function=ChoiceDeltaToolCallFunction(arguments="", name="weather"), + type="function", + ) + ] + ), + index=0, + ) + ], + created=1747834733, + model="gpt-4o-mini-2024-07-18", + object="chat.completion.chunk", + service_tier="default", + system_fingerprint="fp_54eb4bd693", + ), + ChatCompletionChunk( + id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F", + choices=[ + chat_completion_chunk.Choice( + delta=ChoiceDelta( + tool_calls=[ + ChoiceDeltaToolCall(index=1, function=ChoiceDeltaToolCallFunction(arguments='{"ci')) + ] + ), + index=0, + ) + ], + created=1747834733, + model="gpt-4o-mini-2024-07-18", + object="chat.completion.chunk", + service_tier="default", + system_fingerprint="fp_54eb4bd693", + ), + ChatCompletionChunk( + id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F", + choices=[ + chat_completion_chunk.Choice( + delta=ChoiceDelta( + tool_calls=[ + ChoiceDeltaToolCall(index=1, function=ChoiceDeltaToolCallFunction(arguments='ty": ')) + ] + ), + index=0, + ) + ], + created=1747834733, + model="gpt-4o-mini-2024-07-18", + object="chat.completion.chunk", + service_tier="default", + system_fingerprint="fp_54eb4bd693", + ), + ChatCompletionChunk( + id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F", + choices=[ + chat_completion_chunk.Choice( + delta=ChoiceDelta( + tool_calls=[ + ChoiceDeltaToolCall(index=1, function=ChoiceDeltaToolCallFunction(arguments='"Berli')) + ] + ), + index=0, + ) + ], + created=1747834733, + model="gpt-4o-mini-2024-07-18", + 
object="chat.completion.chunk", + service_tier="default", + system_fingerprint="fp_54eb4bd693", + ), + ChatCompletionChunk( + id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F", + choices=[ + chat_completion_chunk.Choice( + delta=ChoiceDelta( + tool_calls=[ + ChoiceDeltaToolCall(index=1, function=ChoiceDeltaToolCallFunction(arguments='n"}')) + ] + ), + index=0, + ) + ], + created=1747834733, + model="gpt-4o-mini-2024-07-18", + object="chat.completion.chunk", + service_tier="default", + system_fingerprint="fp_54eb4bd693", + ), + ChatCompletionChunk( + id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F", + choices=[chat_completion_chunk.Choice(delta=ChoiceDelta(), finish_reason="tool_calls", index=0)], + created=1747834733, + model="gpt-4o-mini-2024-07-18", + object="chat.completion.chunk", + service_tier="default", + system_fingerprint="fp_54eb4bd693", + ), + ChatCompletionChunk( + id="chatcmpl-BZdwjFecdcaQfCf7bn319vRp6fY8F", + choices=[], + created=1747834733, + model="gpt-4o-mini-2024-07-18", + object="chat.completion.chunk", + service_tier="default", + system_fingerprint="fp_54eb4bd693", + usage=CompletionUsage( + completion_tokens=42, + prompt_tokens=282, + total_tokens=324, + completion_tokens_details=CompletionTokensDetails( + accepted_prediction_tokens=0, audio_tokens=0, reasoning_tokens=0, rejected_prediction_tokens=0 + ), + prompt_tokens_details=PromptTokensDetails(audio_tokens=0, cached_tokens=0), + ), + ), + ] + component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key")) + result = component._handle_stream_response(openai_chunks, callback=lambda chunk: None)[0] # type: ignore + + assert not result.texts + assert not result.text + + # Verify both tool calls were found and processed + assert len(result.tool_calls) == 2 + assert result.tool_calls[0].id == "call_zcvlnVaTeJWRjLAFfYxX69z4" + assert result.tool_calls[0].tool_name == "weather" + assert result.tool_calls[0].arguments == {"city": "Paris"} + assert result.tool_calls[1].id == "call_C88m67V16CrETq6jbNXjdZI9" + assert result.tool_calls[1].tool_name == "weather" + assert result.tool_calls[1].arguments == {"city": "Berlin"} + + # Verify meta information + assert result.meta["model"] == "gpt-4o-mini-2024-07-18" + assert result.meta["finish_reason"] == "tool_calls" + assert result.meta["index"] == 0 + assert result.meta["completion_start_time"] is not None + assert result.meta["usage"] == { + "completion_tokens": 42, + "prompt_tokens": 282, + "total_tokens": 324, + "completion_tokens_details": { + "accepted_prediction_tokens": 0, + "audio_tokens": 0, + "reasoning_tokens": 0, + "rejected_prediction_tokens": 0, + }, + "prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0}, + } + def test_convert_usage_chunk_to_streaming_chunk(self): chunk = ChatCompletionChunk( id="chatcmpl-BC1y4wqIhe17R8sv3lgLcWlB4tXCw", @@ -616,8 +879,11 @@ def test_convert_usage_chunk_to_streaming_chunk(self): prompt_tokens_details=PromptTokensDetails(audio_tokens=0, cached_tokens=0), ), ) - result = _convert_chat_completion_chunk_to_streaming_chunk(chunk) + result = _convert_chat_completion_chunk_to_streaming_chunk(chunk=chunk, previous_chunks=[])[0] assert result.content == "" + assert result.start is None + assert result.tool_call is None + assert result.tool_call_result is None assert result.meta["model"] == "gpt-4o-mini-2024-07-18" assert result.meta["received_at"] is not None @@ -701,8 +967,12 @@ def __call__(self, chunk: StreamingChunk) -> None: ) @pytest.mark.integration def test_live_run_with_tools_streaming(self, tools): - chat_messages = 
[ChatMessage.from_user("What's the weather like in Paris?")] - component = OpenAIChatGenerator(tools=tools, streaming_callback=print_streaming_chunk) + chat_messages = [ChatMessage.from_user("What's the weather like in Paris and Berlin?")] + component = OpenAIChatGenerator( + tools=tools, + streaming_callback=print_streaming_chunk, + generation_kwargs={"stream_options": {"include_usage": True}}, + ) results = component.run(chat_messages) assert len(results["replies"]) == 1 message = results["replies"][0] @@ -710,10 +980,11 @@ def test_live_run_with_tools_streaming(self, tools): assert not message.texts assert not message.text assert message.tool_calls - tool_call = message.tool_call - assert isinstance(tool_call, ToolCall) - assert tool_call.tool_name == "weather" - assert tool_call.arguments == {"city": "Paris"} + tool_calls = message.tool_calls + for tool_call in tool_calls: + assert isinstance(tool_call, ToolCall) + assert tool_call.tool_name == "weather" + assert tool_call.arguments == {"city": "Paris"} or tool_call.arguments == {"city": "Berlin"} assert message.meta["finish_reason"] == "tool_calls" def test_openai_chat_generator_with_toolset_initialization(self, tools, monkeypatch): diff --git a/test/components/generators/test_openai.py b/test/components/generators/test_openai.py index 3e34487294..7ea0fcc37f 100644 --- a/test/components/generators/test_openai.py +++ b/test/components/generators/test_openai.py @@ -2,9 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 -import logging import os -from typing import List from datetime import datetime import pytest from openai import OpenAIError @@ -13,7 +11,7 @@ from haystack.components.generators import OpenAIGenerator from haystack.components.generators.utils import print_streaming_chunk -from haystack.dataclasses import ChatMessage, StreamingChunk +from haystack.dataclasses import StreamingChunk from haystack.utils.auth import Secret diff --git a/test/components/generators/test_utils.py b/test/components/generators/test_utils.py index 0208c7702c..cc6d6edfdf 100644 --- a/test/components/generators/test_utils.py +++ b/test/components/generators/test_utils.py @@ -5,7 +5,7 @@ from openai.types.chat import chat_completion_chunk from haystack.components.generators.utils import _convert_streaming_chunks_to_chat_message -from haystack.dataclasses import ComponentInfo, StreamingChunk +from haystack.dataclasses import ComponentInfo, StreamingChunk, ToolCallDelta def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(): @@ -40,6 +40,9 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(): "received_at": "2025-02-19T16:02:55.913919", }, component_info=ComponentInfo(name="test", type="test"), + index=0, + start=True, + tool_call=ToolCallDelta(id="call_ZOj5l67zhZOx6jqjg7ATQwb6", tool_name="rag_pipeline_tool", arguments=""), ), StreamingChunk( content="", @@ -48,16 +51,15 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(): "index": 0, "tool_calls": [ chat_completion_chunk.ChoiceDeltaToolCall( - index=0, - id=None, - function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments='{"qu', name=None), - type=None, + index=0, function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments='{"qu') ) ], "finish_reason": None, "received_at": "2025-02-19T16:02:55.914439", }, component_info=ComponentInfo(name="test", type="test"), + index=0, + tool_call=ToolCallDelta(arguments='{"qu'), ), StreamingChunk( content="", @@ -66,16 +68,15 @@ def 
test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(): "index": 0, "tool_calls": [ chat_completion_chunk.ChoiceDeltaToolCall( - index=0, - id=None, - function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments='ery":', name=None), - type=None, + index=0, function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments='ery":') ) ], "finish_reason": None, "received_at": "2025-02-19T16:02:55.924146", }, component_info=ComponentInfo(name="test", type="test"), + index=0, + tool_call=ToolCallDelta(arguments='ery":'), ), StreamingChunk( content="", @@ -84,16 +85,15 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(): "index": 0, "tool_calls": [ chat_completion_chunk.ChoiceDeltaToolCall( - index=0, - id=None, - function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments=' "Wher', name=None), - type=None, + index=0, function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments=' "Wher') ) ], "finish_reason": None, "received_at": "2025-02-19T16:02:55.924420", }, component_info=ComponentInfo(name="test", type="test"), + index=0, + tool_call=ToolCallDelta(arguments=' "Wher'), ), StreamingChunk( content="", @@ -102,16 +102,15 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(): "index": 0, "tool_calls": [ chat_completion_chunk.ChoiceDeltaToolCall( - index=0, - id=None, - function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments="e do", name=None), - type=None, + index=0, function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments="e do") ) ], "finish_reason": None, "received_at": "2025-02-19T16:02:55.944398", }, component_info=ComponentInfo(name="test", type="test"), + index=0, + tool_call=ToolCallDelta(arguments="e do"), ), StreamingChunk( content="", @@ -120,16 +119,15 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(): "index": 0, "tool_calls": [ chat_completion_chunk.ChoiceDeltaToolCall( - index=0, - id=None, - function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments="es Ma", name=None), - type=None, + index=0, function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments="es Ma") ) ], "finish_reason": None, "received_at": "2025-02-19T16:02:55.944958", }, component_info=ComponentInfo(name="test", type="test"), + index=0, + tool_call=ToolCallDelta(arguments="es Ma"), ), StreamingChunk( content="", @@ -138,15 +136,15 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(): "index": 0, "tool_calls": [ chat_completion_chunk.ChoiceDeltaToolCall( - index=0, - id=None, - function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments="rk liv", name=None), - type=None, + index=0, function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments="rk liv") ) ], "finish_reason": None, "received_at": "2025-02-19T16:02:55.945507", }, + component_info=ComponentInfo(name="test", type="test"), + index=0, + tool_call=ToolCallDelta(arguments="rk liv"), ), StreamingChunk( content="", @@ -155,16 +153,15 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(): "index": 0, "tool_calls": [ chat_completion_chunk.ChoiceDeltaToolCall( - index=0, - id=None, - function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments='e?"}', name=None), - type=None, + index=0, function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments='e?"}') ) ], "finish_reason": None, "received_at": "2025-02-19T16:02:55.946018", }, component_info=ComponentInfo(name="test", type="test"), + index=0, + 
tool_call=ToolCallDelta(arguments='e?"}'), ), StreamingChunk( content="", @@ -183,6 +180,9 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(): "received_at": "2025-02-19T16:02:55.946578", }, component_info=ComponentInfo(name="test", type="test"), + index=1, + start=True, + tool_call=ToolCallDelta(id="call_STxsYY69wVOvxWqopAt3uWTB", tool_name="get_weather", arguments=""), ), StreamingChunk( content="", @@ -191,16 +191,15 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(): "index": 0, "tool_calls": [ chat_completion_chunk.ChoiceDeltaToolCall( - index=1, - id=None, - function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments='{"ci', name=None), - type=None, + index=1, function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments='{"ci') ) ], "finish_reason": None, "received_at": "2025-02-19T16:02:55.946981", }, component_info=ComponentInfo(name="test", type="test"), + index=1, + tool_call=ToolCallDelta(arguments='{"ci'), ), StreamingChunk( content="", @@ -209,16 +208,15 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(): "index": 0, "tool_calls": [ chat_completion_chunk.ChoiceDeltaToolCall( - index=1, - id=None, - function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments='ty": ', name=None), - type=None, + index=1, function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments='ty": ') ) ], "finish_reason": None, "received_at": "2025-02-19T16:02:55.947411", }, component_info=ComponentInfo(name="test", type="test"), + index=1, + tool_call=ToolCallDelta(arguments='ty": '), ), StreamingChunk( content="", @@ -227,16 +225,15 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(): "index": 0, "tool_calls": [ chat_completion_chunk.ChoiceDeltaToolCall( - index=1, - id=None, - function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments='"Berli', name=None), - type=None, + index=1, function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments='"Berli') ) ], "finish_reason": None, "received_at": "2025-02-19T16:02:55.947643", }, component_info=ComponentInfo(name="test", type="test"), + index=1, + tool_call=ToolCallDelta(arguments='"Berli'), ), StreamingChunk( content="", @@ -245,16 +242,15 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(): "index": 0, "tool_calls": [ chat_completion_chunk.ChoiceDeltaToolCall( - index=1, - id=None, - function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments='n"}', name=None), - type=None, + index=1, function=chat_completion_chunk.ChoiceDeltaToolCallFunction(arguments='n"}') ) ], "finish_reason": None, "received_at": "2025-02-19T16:02:55.947939", }, component_info=ComponentInfo(name="test", type="test"), + index=1, + tool_call=ToolCallDelta(arguments='n"}'), ), StreamingChunk( content="", @@ -267,6 +263,29 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(): }, component_info=ComponentInfo(name="test", type="test"), ), + StreamingChunk( + content="", + meta={ + "model": "gpt-4o-mini-2024-07-18", + "index": 0, + "tool_calls": None, + "finish_reason": None, + "received_at": "2025-02-19T16:02:55.948772", + "usage": { + "completion_tokens": 42, + "prompt_tokens": 282, + "total_tokens": 324, + "completion_tokens_details": { + "accepted_prediction_tokens": 0, + "audio_tokens": 0, + "reasoning_tokens": 0, + "rejected_prediction_tokens": 0, + }, + "prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0}, + }, + }, + 
component_info=ComponentInfo(name="test", type="test"), + ), ] # Convert chunks to a chat message @@ -289,3 +308,15 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(): assert result.meta["finish_reason"] == "tool_calls" assert result.meta["index"] == 0 assert result.meta["completion_start_time"] == "2025-02-19T16:02:55.910076" + assert result.meta["usage"] == { + "completion_tokens": 42, + "prompt_tokens": 282, + "total_tokens": 324, + "completion_tokens_details": { + "accepted_prediction_tokens": 0, + "audio_tokens": 0, + "reasoning_tokens": 0, + "rejected_prediction_tokens": 0, + }, + "prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0}, + } diff --git a/test/dataclasses/test_streaming_chunk.py b/test/dataclasses/test_streaming_chunk.py index a12372cedc..695d155483 100644 --- a/test/dataclasses/test_streaming_chunk.py +++ b/test/dataclasses/test_streaming_chunk.py @@ -2,16 +2,15 @@ # # SPDX-License-Identifier: Apache-2.0 +import pytest -from haystack.dataclasses import StreamingChunk, ComponentInfo -from unittest.mock import Mock -from haystack.core.component import Component +from haystack.dataclasses import StreamingChunk, ComponentInfo, ToolCallDelta, ToolCallResult, ToolCall from haystack import component from haystack import Pipeline @component -class TestComponent: +class ExampleComponent: def __init__(self): self.name = "test_component" @@ -53,16 +52,52 @@ def test_create_chunk_with_all_fields(): assert chunk.component_info == component_info +def test_create_chunk_with_content_and_tool_call(): + with pytest.raises(ValueError): + # Can't have content + tool_call at the same time + StreamingChunk( + content="Test content", + meta={"key": "value"}, + tool_call=ToolCallDelta(id="123", tool_name="test_tool", arguments='{"arg1": "value1"}'), + ) + + +def test_create_chunk_with_content_and_tool_call_result(): + with pytest.raises(ValueError): + # Can't have content + tool_call_result at the same time + StreamingChunk( + content="Test content", + meta={"key": "value"}, + tool_call_result=ToolCallResult( + result="output", + origin=ToolCall(id="123", tool_name="test_tool", arguments={"arg1": "value1"}), + error=False, + ), + ) + + def test_component_info_from_component(): - component = TestComponent() + component = ExampleComponent() component_info = ComponentInfo.from_component(component) - assert component_info.type == "test_streaming_chunk.TestComponent" + assert component_info.type == "test_streaming_chunk.ExampleComponent" def test_component_info_from_component_with_name_from_pipeline(): pipeline = Pipeline() - component = TestComponent() + component = ExampleComponent() pipeline.add_component("pipeline_component", component) component_info = ComponentInfo.from_component(component) - assert component_info.type == "test_streaming_chunk.TestComponent" + assert component_info.type == "test_streaming_chunk.ExampleComponent" assert component_info.name == "pipeline_component" + + +def test_tool_call_delta(): + tool_call = ToolCallDelta(id="123", tool_name="test_tool", arguments='{"arg1": "value1"}') + assert tool_call.id == "123" + assert tool_call.tool_name == "test_tool" + assert tool_call.arguments == '{"arg1": "value1"}' + + +def test_tool_call_delta_with_missing_fields(): + with pytest.raises(ValueError): + _ = ToolCallDelta(id="123")
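
To make the accumulation contract the tests above exercise explicit: deltas sharing the same `chunk.index` belong to one tool call, and their `arguments` strings only form valid JSON once concatenated. A minimal sketch mirroring `_convert_streaming_chunks_to_chat_message` (the helper name is hypothetical; error handling for malformed JSON is omitted):

```python
import json
from typing import Dict, List

from haystack.dataclasses import StreamingChunk, ToolCall


def accumulate_tool_calls(chunks: List[StreamingChunk]) -> List[ToolCall]:
    # Group deltas by chunk.index, since the tool call ID is not present on every chunk
    data: Dict[int, Dict[str, str]] = {}
    for chunk in chunks:
        if chunk.tool_call is None:
            continue
        assert chunk.index is not None  # enforced by StreamingChunk.__post_init__
        entry = data.setdefault(chunk.index, {"id": "", "name": "", "arguments": ""})
        if chunk.tool_call.id:
            entry["id"] = chunk.tool_call.id
        if chunk.tool_call.tool_name:
            entry["name"] += chunk.tool_call.tool_name
        if chunk.tool_call.arguments:
            entry["arguments"] += chunk.tool_call.arguments
    # Arguments only parse as JSON after all deltas for a tool call have been concatenated
    return [
        ToolCall(id=entry["id"], tool_name=entry["name"], arguments=json.loads(entry["arguments"]))
        for _, entry in sorted(data.items())
    ]
```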