feat: Update streaming chunk #9424

Draft · sjrl wants to merge 25 commits into main

Commits (25): the diff below shows changes from 4 commits

0446fe5
Start expanding StreamingChunk
sjrl May 15, 2025
f2ddbff
First pass at expanding Streaming Chunk
sjrl May 22, 2025
6cb7a31
Working version!
sjrl May 22, 2025
005ef69
Some tweaks and also make ToolInvoker stream a chunk with a finish re…
sjrl May 22, 2025
e29b6f2
Properly update test
sjrl May 22, 2025
5914d5b
Change to tool_name, remove kw_only since its python 3.10 only and up…
sjrl May 22, 2025
d141b47
Add reno
sjrl May 22, 2025
ac51918
Some cleanup
sjrl May 22, 2025
012c0bb
Fix unit tests
sjrl May 22, 2025
6048328
Fix mypy and integration test
sjrl May 22, 2025
010c037
Fix pylint
sjrl May 22, 2025
93758fd
Merge branch 'main' of github.com:deepset-ai/haystack into update-str…
sjrl May 26, 2025
f43477d
Start refactoring huggingface local api
sjrl May 26, 2025
a907d9e
Refactor openai generator and chat generator to reuse util methods
sjrl May 26, 2025
ced8fd8
Did some reorg
sjrl May 26, 2025
22314b8
Reuse utility method in HuggingFaceAPI
sjrl May 26, 2025
bc306d3
Merge branch 'main' of github.com:deepset-ai/haystack into update-str…
sjrl May 27, 2025
b625395
Merge branch 'main' of github.com:deepset-ai/haystack into update-str…
sjrl May 28, 2025
8cbefeb
Get rid of unneeded default values in tests
sjrl May 28, 2025
7cac572
Update conversion of streaming chunks to chat message to not rely on …
sjrl May 28, 2025
4bfbe58
Fix tests and loosen check in StreamingChunk post_init
sjrl May 28, 2025
3f8f661
Fixes
sjrl May 28, 2025
51c8440
Fix license header
sjrl May 28, 2025
658b47b
Add start and index to HFAPIGenerator
sjrl May 28, 2025
27ca068
Fix mypy
sjrl May 28, 2025
Files changed
108 changes: 82 additions & 26 deletions haystack/components/generators/chat/openai.py
@@ -20,6 +20,7 @@
StreamingChunk,
SyncStreamingCallbackT,
ToolCall,
ToolCallDelta,
select_streaming_callback,
)
from haystack.tools import (
@@ -418,30 +419,40 @@ def _prepare_api_call( # noqa: PLR0913
}

def _handle_stream_response(self, chat_completion: Stream, callback: SyncStreamingCallbackT) -> List[ChatMessage]:
openai_chunks: List[ChatCompletionChunk] = []
chunks: List[StreamingChunk] = []
chunk = None
chunk_deltas: List[StreamingChunk]
chunk_delta: StreamingChunk

for chunk in chat_completion: # pylint: disable=not-an-iterable
assert len(chunk.choices) <= 1, "Streaming responses should have at most one choice."
chunk_delta = self._convert_chat_completion_chunk_to_streaming_chunk(chunk)
chunks.append(chunk_delta)
callback(chunk_delta)
return [self._convert_streaming_chunks_to_chat_message(chunk, chunks)]
chunk_deltas = self._convert_chat_completion_chunk_to_streaming_chunk(
chunk=chunk, previous_chunks=openai_chunks
)
for chunk_delta in chunk_deltas:
chunks.append(chunk_delta)
callback(chunk_delta)
openai_chunks.append(chunk)
return [self._convert_streaming_chunks_to_chat_message(chunks=chunks)]

async def _handle_async_stream_response(
self, chat_completion: AsyncStream, callback: AsyncStreamingCallbackT
) -> List[ChatMessage]:
openai_chunks: List[ChatCompletionChunk] = []
chunks: List[StreamingChunk] = []
chunk = None
chunk_deltas: List[StreamingChunk]
chunk_delta: StreamingChunk

async for chunk in chat_completion: # pylint: disable=not-an-iterable
assert len(chunk.choices) <= 1, "Streaming responses should have at most one choice."
chunk_delta = self._convert_chat_completion_chunk_to_streaming_chunk(chunk)
chunks.append(chunk_delta)
await callback(chunk_delta)
return [self._convert_streaming_chunks_to_chat_message(chunk, chunks)]
chunk_deltas = self._convert_chat_completion_chunk_to_streaming_chunk(
chunk=chunk, previous_chunks=openai_chunks
)
for chunk_delta in chunk_deltas:
chunks.append(chunk_delta)
await callback(chunk_delta)
openai_chunks.append(chunk)
return [self._convert_streaming_chunks_to_chat_message(chunks=chunks)]

def _check_finish_reason(self, meta: Dict[str, Any]) -> None:
if meta["finish_reason"] == "length":
@@ -458,13 +469,10 @@ def _check_finish_reason(self, meta: Dict[str, Any]) -> None:
finish_reason=meta["finish_reason"],
)

def _convert_streaming_chunks_to_chat_message(
self, last_chunk: ChatCompletionChunk, chunks: List[StreamingChunk]
) -> ChatMessage:
def _convert_streaming_chunks_to_chat_message(self, chunks: List[StreamingChunk]) -> ChatMessage:
"""
Connects the streaming chunks into a single ChatMessage.

:param last_chunk: The last chunk returned by the OpenAI API.
:param chunks: The list of all `StreamingChunk` objects.

:returns: The ChatMessage.
@@ -514,11 +522,11 @@ def _convert_streaming_chunks_to_chat_message(
finish_reason = finish_reasons[-1] if finish_reasons else None

meta = {
"model": last_chunk.model,
"model": chunks[-1].meta.get("model"),
"index": 0,
"finish_reason": finish_reason,
"completion_start_time": chunks[0].meta.get("received_at"), # first chunk received
"usage": self._serialize_usage(last_chunk.usage), # last chunk has the final usage data if available
"usage": chunks[-1].meta.get("usage"), # last chunk has the final usage data if available
}

return ChatMessage.from_assistant(text=text or None, tool_calls=tool_calls, meta=meta)
@@ -561,35 +569,83 @@ def _convert_chat_completion_to_chat_message(self, completion: ChatCompletion, c
)
return chat_message

def _convert_chat_completion_chunk_to_streaming_chunk(self, chunk: ChatCompletionChunk) -> StreamingChunk:
def _convert_chat_completion_chunk_to_streaming_chunk(
self, chunk: ChatCompletionChunk, previous_chunks: List[ChatCompletionChunk]
) -> List[StreamingChunk]:
"""
Converts the streaming response chunk from the OpenAI API to a StreamingChunk.

:param chunk: The chunk returned by the OpenAI API.
:param previous_chunks: The previous chunks received from the OpenAI API.

:returns:
A list of StreamingChunk objects.
"""
# If there are no choices, return an empty chunk.
# Choices is empty on the very first chunk, which only provides role information (e.g. "assistant").
# It is also empty on the final chunk when include_usage is set to True, where the usage information is returned.
if len(chunk.choices) == 0:
return StreamingChunk(content="", meta={"model": chunk.model, "received_at": datetime.now().isoformat()})
return [
StreamingChunk(
content="",
# Index is None since it's only used when a content block is present
index=None,
meta={
"model": chunk.model,
"received_at": datetime.now().isoformat(),
"usage": self._serialize_usage(chunk.usage),
},
)
]

# we stream the content of the chunk if it's not a tool or function call
choice: ChunkChoice = chunk.choices[0]
content = choice.delta.content or ""
chunk_message = StreamingChunk(content)
# but save the tool calls and function call in the meta if they are present
# and then connect the chunks in the _convert_streaming_chunks_to_chat_message method
chunk_message.meta.update(
{

# create a list of ToolCallDelta objects from the tool calls
if choice.delta.tool_calls:
chunk_messages = []
for tool_call in choice.delta.tool_calls:
chunk_message = StreamingChunk(
content=content,
# We adopt the tool_call.index as the index of the chunk
index=tool_call.index,
tool_call=ToolCallDelta(
id=tool_call.id, name=tool_call.function.name, arguments=tool_call.function.arguments or None
),
start=tool_call.function.name is not None,
meta={
"model": chunk.model,
"index": choice.index,
"tool_calls": choice.delta.tool_calls,
"finish_reason": choice.finish_reason,
"received_at": datetime.now().isoformat(),
"usage": self._serialize_usage(chunk.usage),
},
)
chunk_messages.append(chunk_message)
return chunk_messages

# If we reach here, content should not be empty
chunk_message = StreamingChunk(
content=content,
# We set the index to be 0 since if text content is being streamed then no tool calls are being streamed
# NOTE: We may need to revisit this if OpenAI allows planning/thinking content before tool calls like
# Anthropic/Bedrock
index=0,
tool_call=None,
# The first chunk is always a start message chunk, so if we reach here and previous_chunks is length 1
# then this is the start of text content
start=len(previous_chunks) == 1 or None,
meta={
"model": chunk.model,
"index": choice.index,
"tool_calls": choice.delta.tool_calls,
"finish_reason": choice.finish_reason,
"received_at": datetime.now().isoformat(),
}
"usage": self._serialize_usage(chunk.usage),
},
)
return chunk_message
return [chunk_message]

def _serialize_usage(self, usage):
"""Convert OpenAI usage object to serializable dict recursively"""
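To make the new fan-out behavior concrete, here is a minimal sketch, assuming only the dataclasses from this diff: one OpenAI chunk that announces several tool calls now yields one StreamingChunk per ToolCallDelta. The fan_out_tool_calls helper and the plain-dict deltas are hypothetical stand-ins for OpenAI's ChoiceDeltaToolCall objects.

from typing import List

from haystack.dataclasses import StreamingChunk, ToolCallDelta

def fan_out_tool_calls(deltas: List[dict], model: str) -> List[StreamingChunk]:
    # One StreamingChunk per tool-call delta, mirroring the loop in
    # _convert_chat_completion_chunk_to_streaming_chunk above.
    chunks = []
    for tc in deltas:
        chunks.append(
            StreamingChunk(
                content="",
                index=tc["index"],  # adopt the tool_call index as the chunk index
                tool_call=ToolCallDelta(id=tc.get("id"), name=tc.get("name"), arguments=tc.get("arguments")),
                start=tc.get("name") is not None,  # a name signals the start of a new tool call
                meta={"model": model},
            )
        )
    return chunks

# Two tool calls announced in one chunk fan out into two StreamingChunks:
deltas = [
    {"index": 0, "id": "call_1", "name": "get_weather", "arguments": None},
    {"index": 1, "id": "call_2", "name": "web_search", "arguments": None},
]
assert len(fan_out_tool_calls(deltas, "gpt-4o-mini")) == 2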
2 changes: 1 addition & 1 deletion haystack/components/generators/hugging_face_api.py
@@ -227,7 +227,7 @@ def _stream_and_build_response(
if first_chunk_time is None:
first_chunk_time = datetime.now().isoformat()

stream_chunk = StreamingChunk(token.text, chunk_metadata)
stream_chunk = StreamingChunk(content=token.text, meta=chunk_metadata)
chunks.append(stream_chunk)
streaming_callback(stream_chunk)

2 changes: 1 addition & 1 deletion haystack/components/generators/openai.py
@@ -320,7 +320,7 @@ def _build_chunk(chunk: Any) -> StreamingChunk:
"""
choice = chunk.choices[0]
content = choice.delta.content or ""
chunk_message = StreamingChunk(content)
chunk_message = StreamingChunk(content=content)
chunk_message.meta.update(
{
"model": chunk.model,
47 changes: 29 additions & 18 deletions haystack/components/generators/utils.py
@@ -2,8 +2,6 @@
#
# SPDX-License-Identifier: Apache-2.0

from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall

from haystack.dataclasses import StreamingChunk


@@ -22,24 +20,37 @@ def print_streaming_chunk(chunk: StreamingChunk) -> None:
:param chunk: A chunk of streaming data containing content and optional metadata, such as tool calls and
tool results.
"""
# Print tool call metadata if available (from ChatGenerator)
if chunk.meta.get("tool_calls"):
for tool_call in chunk.meta["tool_calls"]:
if isinstance(tool_call, ChoiceDeltaToolCall) and tool_call.function:
# print the tool name
if tool_call.function.name and not tool_call.function.arguments:
print("[TOOL CALL]\n", flush=True, end="")
print(f"Tool: {tool_call.function.name} ", flush=True, end="")
print("\nArguments: ", flush=True, end="")

# print the tool arguments
if tool_call.function.arguments:
print(tool_call.function.arguments, flush=True, end="")

if chunk.start and chunk.index > 0:
# If this is not the first content block of the message, add two new lines
print("\n\n", flush=True, end="")

## Tool Call streaming
if chunk.tool_call:
# chunk.start indicates the beginning of a tool call
# (checking chunk.tool_call.name would be equivalent here)
if chunk.start:
print("[TOOL CALL]\n", flush=True, end="")
print(f"Tool: {chunk.tool_call.name} ", flush=True, end="")
print("\nArguments: ", flush=True, end="")

# print the tool arguments
if chunk.tool_call.arguments:
print(chunk.tool_call.arguments, flush=True, end="")

## Tool Call Result streaming
# Print tool call results if available (from ToolInvoker)
if chunk.meta.get("tool_result"):
print(f"\n\n[TOOL RESULT]\n{chunk.meta['tool_result']}\n\n", flush=True, end="")
if chunk.tool_call_result:
# Tool Call Result is fully formed so delta accumulation is not needed
print(f"[TOOL RESULT]\n{chunk.tool_call_result}", flush=True, end="")

## Normal content streaming
# Print the main content of the chunk (from ChatGenerator)
if chunk.content:
if chunk.start:
print("[ASSISTANT]\n", flush=True, end="")
print(chunk.content, flush=True, end="")

# End of LLM assistant message so we add two new lines
# This ensures spacing between multiple LLM messages (e.g. Agent) or Tool Call Result
if chunk.meta.get("finish_reason") is not None:
print("\n\n", flush=True, end="")
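As a usage illustration, here is a small invented stream run through the updated print_streaming_chunk; the chunk values are made up, not captured from a real model.

from haystack.components.generators.utils import print_streaming_chunk
from haystack.dataclasses import StreamingChunk

# Invented chunks simulating a short assistant reply.
stream = [
    StreamingChunk(content="Hello", index=0, start=True),
    StreamingChunk(content=", world"),
    StreamingChunk(content="", meta={"finish_reason": "stop"}),
]
for chunk in stream:
    print_streaming_chunk(chunk)
# Prints "[ASSISTANT]" once at the start of the content block, then the
# streamed text, then a blank line when finish_reason arrives.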
9 changes: 9 additions & 0 deletions haystack/components/tools/tool_invoker.py
@@ -503,10 +503,17 @@ def run(
streaming_callback(
StreamingChunk(
content="",
index=len(tool_messages) - 1,
tool_call_result=tool_messages[-1].tool_call_results[0],
start=True,
meta={"tool_result": tool_messages[-1].tool_call_results[0].result, "tool_call": tool_call},
)
)

# We stream one more chunk that contains a finish_reason
if streaming_callback is not None:
streaming_callback(StreamingChunk(content="", meta={"finish_reason": "tool_call_results"}))

return {"tool_messages": tool_messages, "state": state}

@component.output_types(tool_messages=List[ChatMessage], state=State)
@@ -604,6 +611,8 @@ async def run_async(
await streaming_callback(
StreamingChunk(
content="",
tool_call_result=tool_messages[-1].tool_call_results[0],
start=True,
meta={"tool_result": tool_messages[-1].tool_call_results[0].result, "tool_call": tool_call},
)
) # type: ignore[misc] # we have checked that streaming_callback is not None and async
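A sketch of what a streaming_callback now receives per executed tool, assuming the ToolCall and ToolCallResult dataclasses from haystack.dataclasses.chat_message; the weather values are invented.

from haystack.dataclasses import StreamingChunk
from haystack.dataclasses.chat_message import ToolCall, ToolCallResult

received = []

def callback(chunk: StreamingChunk) -> None:
    received.append(chunk)

# Illustrative values; in practice ToolInvoker builds these from the tool run.
tool_call = ToolCall(tool_name="get_weather", arguments={"city": "Berlin"})
tool_result = ToolCallResult(result="22 C and sunny", origin=tool_call, error=False)

# Chunk emitted after each tool executes:
callback(
    StreamingChunk(
        content="",
        index=0,
        tool_call_result=tool_result,
        start=True,
        meta={"tool_result": tool_result.result, "tool_call": tool_call},
    )
)
# Final chunk signalling that all tool results have been streamed:
callback(StreamingChunk(content="", meta={"finish_reason": "tool_call_results"}))

assert received[-1].meta["finish_reason"] == "tool_call_results"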
2 changes: 2 additions & 0 deletions haystack/dataclasses/__init__.py
@@ -19,6 +19,7 @@
"AsyncStreamingCallbackT",
"StreamingCallbackT",
"SyncStreamingCallbackT",
"ToolCallDelta",
"select_streaming_callback",
],
}
@@ -35,6 +36,7 @@
StreamingCallbackT,
StreamingChunk,
SyncStreamingCallbackT,
ToolCallDelta,
select_streaming_callback,
)
else:
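With the export in place, the new symbol resolves through the same lazy-import machinery as the existing ones; a quick, assumption-light sanity check:

# The new dataclass is importable from the package root alongside StreamingChunk.
from haystack.dataclasses import StreamingChunk, ToolCallDelta

delta = ToolCallDelta(name="get_weather")
chunk = StreamingChunk(content="", tool_call=delta)
print(chunk.tool_call.name)  # get_weather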
42 changes: 41 additions & 1 deletion haystack/dataclasses/streaming_chunk.py
@@ -3,11 +3,34 @@
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass, field
from typing import Any, Awaitable, Callable, Dict, Optional, Union
from typing import Any, Awaitable, Callable, Dict, List, Optional, Union

from haystack.dataclasses.chat_message import ToolCallResult
from haystack.utils.asynchronous import is_callable_async_compatible


# Similar to ChoiceDeltaToolCall from OpenAI
@dataclass(kw_only=True)
class ToolCallDelta:
"""
Represents a partial Tool call streamed by the model, usually contained in an assistant message.

:param id: The ID of the Tool call.
:param name: The name of the Tool to call.
:param arguments: Either the full arguments in JSON format or a delta of the arguments.
"""

id: Optional[str] = None # noqa: A003
name: Optional[str] = None
arguments: Optional[str] = None

def __post_init__(self):
if self.name is None and self.arguments is None:
raise ValueError("At least one of name or arguments must be provided.")
# NOTE: We allow for name and arguments to both be present because some providers like Mistral provide the
# name and full arguments in one chunk


@dataclass
class StreamingChunk:
"""
@@ -17,10 +40,27 @@ class StreamingChunk:

:param content: The content of the message chunk as a string.
:param meta: A dictionary containing metadata related to the message chunk.
:param index: An optional integer index representing which content block this chunk belongs to.
:param tool_call: An optional ToolCallDelta object representing a tool call associated with the message chunk.
:param tool_call_result: An optional ToolCallResult object representing the result of a tool call.
:param start: A boolean indicating whether this chunk marks the start of a content block.
"""

content: str
meta: Dict[str, Any] = field(default_factory=dict, hash=False)
index: Optional[int] = None
tool_call: Optional[ToolCallDelta] = None
tool_call_result: Optional[ToolCallResult] = None
Inline comment from the PR author:
I thought about making

content = Union[str, ToolCallDelta, ToolCallResult]

but this would be a breaking change because users expect content to always be a string. It would also break StreamingChunk-to-ChatMessage implementations (mostly in private methods).

start: Optional[bool] = None

def __post_init__(self):
fields_set = sum(bool(x) for x in (self.content, self.tool_call, self.tool_call_result))
if fields_set > 1:
raise ValueError(
"Only one of `content`, `tool_call`, or `tool_call_result` may be set in a StreamingChunk. "
f"Got content: '{self.content}', tool_call: '{self.tool_call}', "
f"tool_call_result: '{self.tool_call_result}'"
)


SyncStreamingCallbackT = Callable[[StreamingChunk], None]
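A short sketch of the validation these __post_init__ checks enforce, using only the field names defined in this diff:

from haystack.dataclasses import StreamingChunk, ToolCallDelta

ToolCallDelta(name="get_weather")                  # ok: name only
ToolCallDelta(arguments='{"city": "Berlin"}')      # ok: arguments only (a delta)
try:
    ToolCallDelta()                                # neither set
except ValueError as err:
    print(err)  # At least one of name or arguments must be provided.

StreamingChunk(content="Hi")                                          # ok
StreamingChunk(content="", tool_call=ToolCallDelta(name="search"))    # ok: empty content is falsy
try:
    StreamingChunk(content="Hi", tool_call=ToolCallDelta(name="search"))
except ValueError as err:
    print(f"rejected: {err}")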