feat: Update streaming chunk #9424

Draft
wants to merge 11 commits into base: main

36 changes: 22 additions & 14 deletions haystack/components/generators/chat/hugging_face_api.py
@@ -407,7 +407,6 @@ def _run_streaming(
first_chunk_time = None
finish_reason = None
usage = None
meta: Dict[str, Any] = {}

for chunk in api_output:
# The chunk with usage returns an empty array for choices
@@ -423,7 +422,13 @@ def _run_streaming(
if choice.finish_reason:
finish_reason = choice.finish_reason

stream_chunk = StreamingChunk(text, meta)
stream_chunk = StreamingChunk(
content=text,
index=choice.index,
# TODO Correctly evaluate start
start=None,
meta={"model": chunk.model, "finish_reason": choice.finish_reason},
)
streaming_callback(stream_chunk)

if chunk.usage:
@@ -437,17 +442,16 @@ def _run_streaming(
else:
usage_dict = {"prompt_tokens": 0, "completion_tokens": 0}

meta.update(
{
message = ChatMessage.from_assistant(
text=generated_text,
meta={
"model": self._client.model,
"index": 0,
"finish_reason": finish_reason,
"usage": usage_dict,
"completion_start_time": first_chunk_time,
}
},
)

message = ChatMessage.from_assistant(text=generated_text, meta=meta)
return {"replies": [message]}

def _run_non_streaming(
@@ -503,7 +507,6 @@ async def _run_streaming_async(
first_chunk_time = None
finish_reason = None
usage = None
meta: Dict[str, Any] = {}

async for chunk in api_output:
# The chunk with usage returns an empty array for choices
@@ -519,7 +522,13 @@ async def _run_streaming_async(
if choice.finish_reason:
finish_reason = choice.finish_reason

stream_chunk = StreamingChunk(text, meta)
stream_chunk = StreamingChunk(
content=text,
index=choice.index,
# TODO Correctly evaluate start
start=None,
meta={"model": chunk.model, "finish_reason": choice.finish_reason},
)
await streaming_callback(stream_chunk) # type: ignore

if chunk.usage:
@@ -533,17 +542,16 @@ async def _run_streaming_async(
else:
usage_dict = {"prompt_tokens": 0, "completion_tokens": 0}

meta.update(
{
message = ChatMessage.from_assistant(
text=generated_text,
meta={
"model": self._async_client.model,
"index": 0,
"finish_reason": finish_reason,
"usage": usage_dict,
"completion_start_time": first_chunk_time,
}
},
)

message = ChatMessage.from_assistant(text=generated_text, meta=meta)
return {"replies": [message]}

async def _run_non_streaming_async(
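For reviewers, here is a minimal sketch of what a `streaming_callback` now receives from these HF streaming paths: each chunk carries its choice index and its own per-chunk meta (model, finish_reason) instead of a single shared, mutated `meta` dict. The sample chunks are hand-built stand-ins rather than real Inference API output, and they assume the extended `StreamingChunk` fields introduced in this PR.

```python
from haystack.dataclasses import StreamingChunk

collected = []

def on_chunk(chunk: StreamingChunk) -> None:
    # Each chunk now carries its own meta; finish_reason arrives on the last content chunk.
    if chunk.content:
        collected.append(chunk.content)
    if chunk.meta.get("finish_reason"):
        print(f"\nfinish_reason: {chunk.meta['finish_reason']}")

# Hand-built stand-ins mirroring the meta written in _run_streaming above.
for piece, reason in [("Hel", None), ("lo!", "stop")]:
    on_chunk(StreamingChunk(content=piece, index=0, meta={"model": "some-model", "finish_reason": reason}))

print("".join(collected))  # -> Hello!
```
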
110 changes: 84 additions & 26 deletions haystack/components/generators/chat/openai.py
@@ -20,6 +20,7 @@
StreamingChunk,
SyncStreamingCallbackT,
ToolCall,
ToolCallDelta,
select_streaming_callback,
)
from haystack.tools import (
@@ -418,30 +419,40 @@ def _prepare_api_call( # noqa: PLR0913
}

def _handle_stream_response(self, chat_completion: Stream, callback: SyncStreamingCallbackT) -> List[ChatMessage]:
openai_chunks: List[ChatCompletionChunk] = []
chunks: List[StreamingChunk] = []
chunk = None
chunk_deltas: List[StreamingChunk]
chunk_delta: StreamingChunk

for chunk in chat_completion: # pylint: disable=not-an-iterable
assert len(chunk.choices) <= 1, "Streaming responses should have at most one choice."
chunk_delta = self._convert_chat_completion_chunk_to_streaming_chunk(chunk)
chunks.append(chunk_delta)
callback(chunk_delta)
return [self._convert_streaming_chunks_to_chat_message(chunk, chunks)]
chunk_deltas = self._convert_chat_completion_chunk_to_streaming_chunk(
chunk=chunk, previous_chunks=openai_chunks
)
for chunk_delta in chunk_deltas:
chunks.append(chunk_delta)
callback(chunk_delta)
openai_chunks.append(chunk)
return [self._convert_streaming_chunks_to_chat_message(chunks=chunks)]

async def _handle_async_stream_response(
self, chat_completion: AsyncStream, callback: AsyncStreamingCallbackT
) -> List[ChatMessage]:
openai_chunks: List[ChatCompletionChunk] = []
chunks: List[StreamingChunk] = []
chunk = None
chunk_deltas: List[StreamingChunk]
chunk_delta: StreamingChunk

async for chunk in chat_completion: # pylint: disable=not-an-iterable
assert len(chunk.choices) <= 1, "Streaming responses should have at most one choice."
chunk_delta = self._convert_chat_completion_chunk_to_streaming_chunk(chunk)
chunks.append(chunk_delta)
await callback(chunk_delta)
return [self._convert_streaming_chunks_to_chat_message(chunk, chunks)]
chunk_deltas = self._convert_chat_completion_chunk_to_streaming_chunk(
chunk=chunk, previous_chunks=openai_chunks
)
for chunk_delta in chunk_deltas:
chunks.append(chunk_delta)
await callback(chunk_delta)
openai_chunks.append(chunk)
return [self._convert_streaming_chunks_to_chat_message(chunks=chunks)]

def _check_finish_reason(self, meta: Dict[str, Any]) -> None:
if meta["finish_reason"] == "length":
@@ -458,13 +469,10 @@ def _check_finish_reason(self, meta: Dict[str, Any]) -> None:
finish_reason=meta["finish_reason"],
)

def _convert_streaming_chunks_to_chat_message(
self, last_chunk: ChatCompletionChunk, chunks: List[StreamingChunk]
) -> ChatMessage:
def _convert_streaming_chunks_to_chat_message(self, chunks: List[StreamingChunk]) -> ChatMessage:
"""
Connects the streaming chunks into a single ChatMessage.

:param last_chunk: The last chunk returned by the OpenAI API.
:param chunks: The list of all `StreamingChunk` objects.

:returns: The ChatMessage.
@@ -514,11 +522,11 @@ def _convert_streaming_chunks_to_chat_message(
finish_reason = finish_reasons[-1] if finish_reasons else None

meta = {
"model": last_chunk.model,
"model": chunks[-1].meta.get("model"),
"index": 0,
"finish_reason": finish_reason,
"completion_start_time": chunks[0].meta.get("received_at"), # first chunk received
"usage": self._serialize_usage(last_chunk.usage), # last chunk has the final usage data if available
"usage": chunks[-1].meta.get("usage"), # last chunk has the final usage data if available
}

return ChatMessage.from_assistant(text=text or None, tool_calls=tool_calls, meta=meta)
@@ -561,35 +569,85 @@ def _convert_chat_completion_to_chat_message(self, completion: ChatCompletion, c
)
return chat_message

def _convert_chat_completion_chunk_to_streaming_chunk(self, chunk: ChatCompletionChunk) -> StreamingChunk:
def _convert_chat_completion_chunk_to_streaming_chunk(
self, chunk: ChatCompletionChunk, previous_chunks: List[ChatCompletionChunk]
) -> List[StreamingChunk]:
"""
Converts the streaming response chunk from the OpenAI API to a StreamingChunk.

:param chunk: The chunk returned by the OpenAI API.
:param previous_chunks: The previous chunks received from the OpenAI API.

:returns:
The StreamingChunk.
"""
# if there are no choices, return an empty chunk
# Choices is empty on the very first chunk which provides role information (e.g. "assistant").
# It is also empty if include_usage is set to True where the usage information is returned.
if len(chunk.choices) == 0:
return StreamingChunk(content="", meta={"model": chunk.model, "received_at": datetime.now().isoformat()})
return [
StreamingChunk(
content="",
# Index is None since it's only used when a content block is present
index=None,
meta={
"model": chunk.model,
"received_at": datetime.now().isoformat(),
"usage": self._serialize_usage(chunk.usage),
},
)
]

# we stream the content of the chunk if it's not a tool or function call
choice: ChunkChoice = chunk.choices[0]
content = choice.delta.content or ""
chunk_message = StreamingChunk(content)
# but save the tool calls and function call in the meta if they are present
# and then connect the chunks in the _convert_streaming_chunks_to_chat_message method
chunk_message.meta.update(
{

# create a list of ToolCallDelta objects from the tool calls
if choice.delta.tool_calls:
chunk_messages = []
for tool_call in choice.delta.tool_calls:
function = tool_call.function
chunk_message = StreamingChunk(
content=content,
# We adopt the tool_call.index as the index of the chunk
index=tool_call.index,
tool_call=ToolCallDelta(
id=tool_call.id,
tool_name=function.name if function else None,
arguments=function.arguments if function and function.arguments else None,
),
start=function.name is not None if function else None,
meta={
"model": chunk.model,
"index": choice.index,
"tool_calls": choice.delta.tool_calls,
"finish_reason": choice.finish_reason,
"received_at": datetime.now().isoformat(),
"usage": self._serialize_usage(chunk.usage),
},
)
chunk_messages.append(chunk_message)
return chunk_messages

# If we reach here content should not be empty
chunk_message = StreamingChunk(
content=content,
# We set the index to be 0 since if text content is being streamed then no tool calls are being streamed
# NOTE: We may need to revisit this if OpenAI allows planning/thinking content before tool calls like
# Anthropic/Bedrock
index=0,
# The first chunk is always a start message chunk, so if we reach here and previous_chunks is length 1
# then this is the start of text content
start=len(previous_chunks) == 1 or None,
meta={
"model": chunk.model,
"index": choice.index,
"tool_calls": choice.delta.tool_calls,
"finish_reason": choice.finish_reason,
"received_at": datetime.now().isoformat(),
}
"usage": self._serialize_usage(chunk.usage),
},
)
return chunk_message
return [chunk_message]

def _serialize_usage(self, usage):
"""Convert OpenAI usage object to serializable dict recursively"""
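To make the new tool-call streaming easier to review: each OpenAI tool-call delta now becomes its own `StreamingChunk` carrying a `ToolCallDelta`, and the fragments are later grouped by index and concatenated back into complete tool calls. The snippet below is a simplified, self-contained reconstruction of that grouping idea using a stub dataclass; it is not the exact logic of `_convert_streaming_chunks_to_chat_message`.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

@dataclass
class ToolCallDeltaStub:
    # Stand-in for haystack.dataclasses.ToolCallDelta, used only for illustration.
    id: Optional[str] = None
    tool_name: Optional[str] = None
    arguments: Optional[str] = None

def merge_tool_call_deltas(deltas: List[Tuple[int, ToolCallDeltaStub]]) -> Dict[int, ToolCallDeltaStub]:
    """Group fragments by tool-call index and concatenate their argument strings."""
    merged: Dict[int, ToolCallDeltaStub] = {}
    for index, delta in deltas:
        slot = merged.setdefault(index, ToolCallDeltaStub())
        slot.id = slot.id or delta.id
        slot.tool_name = slot.tool_name or delta.tool_name
        # Argument fragments are partial JSON strings streamed in order.
        slot.arguments = (slot.arguments or "") + (delta.arguments or "")
    return merged

stream = [
    (0, ToolCallDeltaStub(id="call_1", tool_name="weather")),
    (0, ToolCallDeltaStub(arguments='{"city": ')),
    (0, ToolCallDeltaStub(arguments='"Berlin"}')),
]
print(merge_tool_call_deltas(stream)[0].arguments)  # -> {"city": "Berlin"}
```
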
3 changes: 2 additions & 1 deletion haystack/components/generators/hugging_face_api.py
@@ -227,7 +227,8 @@ def _stream_and_build_response(
if first_chunk_time is None:
first_chunk_time = datetime.now().isoformat()

stream_chunk = StreamingChunk(token.text, chunk_metadata)
# TODO Consider adding start
stream_chunk = StreamingChunk(content=token.text, meta=chunk_metadata)
chunks.append(stream_chunk)
streaming_callback(stream_chunk)

9 changes: 5 additions & 4 deletions haystack/components/generators/openai.py
@@ -320,14 +320,15 @@ def _build_chunk(chunk: Any) -> StreamingChunk:
"""
choice = chunk.choices[0]
content = choice.delta.content or ""
chunk_message = StreamingChunk(content)
chunk_message.meta.update(
{
# TODO Consider adding start
chunk_message = StreamingChunk(
content=content,
meta={
"model": chunk.model,
"index": choice.index,
"finish_reason": choice.finish_reason,
"received_at": datetime.now().isoformat(),
}
},
)
return chunk_message

47 changes: 29 additions & 18 deletions haystack/components/generators/utils.py
@@ -2,8 +2,6 @@
#
# SPDX-License-Identifier: Apache-2.0

from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall

from haystack.dataclasses import StreamingChunk


@@ -22,24 +20,37 @@ def print_streaming_chunk(chunk: StreamingChunk) -> None:
:param chunk: A chunk of streaming data containing content and optional metadata, such as tool calls and
tool results.
"""
# Print tool call metadata if available (from ChatGenerator)
if chunk.meta.get("tool_calls"):
for tool_call in chunk.meta["tool_calls"]:
if isinstance(tool_call, ChoiceDeltaToolCall) and tool_call.function:
# print the tool name
if tool_call.function.name and not tool_call.function.arguments:
print("[TOOL CALL]\n", flush=True, end="")
print(f"Tool: {tool_call.function.name} ", flush=True, end="")
print("\nArguments: ", flush=True, end="")

# print the tool arguments
if tool_call.function.arguments:
print(tool_call.function.arguments, flush=True, end="")

if chunk.start and chunk.index and chunk.index > 0:
# If this is not the first content block of the message, add two new lines
print("\n\n", flush=True, end="")

## Tool Call streaming
if chunk.tool_call:
# Presence of tool_name indicates beginning of a tool call
# or chunk.tool_call.name: would be equivalent here
if chunk.start:
print("[TOOL CALL]\n", flush=True, end="")
print(f"Tool: {chunk.tool_call.tool_name} ", flush=True, end="")
print("\nArguments: ", flush=True, end="")

# print the tool arguments
if chunk.tool_call.arguments:
print(chunk.tool_call.arguments, flush=True, end="")

## Tool Call Result streaming
# Print tool call results if available (from ToolInvoker)
if chunk.meta.get("tool_result"):
print(f"\n\n[TOOL RESULT]\n{chunk.meta['tool_result']}\n\n", flush=True, end="")
if chunk.tool_call_result:
# Tool Call Result is fully formed so delta accumulation is not needed
print(f"[TOOL RESULT]\n{chunk.tool_call_result.result}", flush=True, end="")

## Normal content streaming
# Print the main content of the chunk (from ChatGenerator)
if chunk.content:
if chunk.start:
print("[ASSISTANT]\n", flush=True, end="")
print(chunk.content, flush=True, end="")

# End of LLM assistant message so we add two new lines
# This ensures spacing between multiple LLM messages (e.g. Agent) or Tool Call Result
if chunk.meta.get("finish_reason") is not None:
print("\n\n", flush=True, end="")
9 changes: 9 additions & 0 deletions haystack/components/tools/tool_invoker.py
@@ -503,10 +503,17 @@ def run(
streaming_callback(
StreamingChunk(
content="",
index=len(tool_messages) - 1,
tool_call_result=tool_messages[-1].tool_call_results[0],
start=True,
meta={"tool_result": tool_messages[-1].tool_call_results[0].result, "tool_call": tool_call},
)
)

# We stream one more chunk that contains a finish_reason
if streaming_callback is not None:
streaming_callback(StreamingChunk(content="", meta={"finish_reason": "tool_call_results"}))

return {"tool_messages": tool_messages, "state": state}

@component.output_types(tool_messages=List[ChatMessage], state=State)
@@ -604,6 +611,8 @@ async def run_async(
await streaming_callback(
StreamingChunk(
content="",
tool_call_result=tool_messages[-1].tool_call_results[0],
start=True,
meta={"tool_result": tool_messages[-1].tool_call_results[0].result, "tool_call": tool_call},
)
) # type: ignore[misc] # we have checked that streaming_callback is not None and async
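A small, illustrative callback showing how a consumer can tell the ToolInvoker's new per-result chunks apart from the trailing `finish_reason` chunk added here. The field names match the chunks built in `run()` above; such a callback would be passed as `ToolInvoker(tools=..., streaming_callback=collect)`.

```python
from haystack.dataclasses import StreamingChunk

def collect(chunk: StreamingChunk) -> None:
    # One chunk per finished tool call; start=True marks the beginning of each result block.
    if chunk.tool_call_result is not None:
        print(f"[tool result] {chunk.tool_call_result.result}")
    # The extra chunk streamed after the loop signals that all tool call results were sent.
    elif chunk.meta.get("finish_reason") == "tool_call_results":
        print("[done] all tool call results streamed")
```
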