22 changes: 16 additions & 6 deletions haystack/components/generators/chat/hugging_face_api.py
@@ -7,7 +7,7 @@
from typing import Any, AsyncIterable, Dict, Iterable, List, Optional, Union

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall, select_streaming_callback
from haystack.dataclasses import ChatMessage, ComponentInfo, StreamingChunk, ToolCall, select_streaming_callback
from haystack.dataclasses.streaming_chunk import StreamingCallbackT
from haystack.lazy_imports import LazyImport
from haystack.tools import (
@@ -171,6 +171,9 @@ class HuggingFaceAPIChatGenerator:
```
"""

# Type annotation for the component name
__component_name__: str

def __init__( # pylint: disable=too-many-positional-arguments
self,
api_type: Union[HFGenerationAPIType, str],
@@ -409,6 +412,12 @@ def _run_streaming(
usage = None
meta: Dict[str, Any] = {}

# get component name and type
component_name = self.__component_name__ if hasattr(self, "__component_name__") else None
component_type = self.__class__.__module__ + "." + self.__class__.__name__
component_info = ComponentInfo(name=component_name, type=component_type)

# Set up streaming handler
for chunk in api_output:
# The chunk with usage returns an empty array for choices
if len(chunk.choices) > 0:
@@ -423,7 +432,7 @@
if choice.finish_reason:
finish_reason = choice.finish_reason

stream_chunk = StreamingChunk(text, meta)
stream_chunk = StreamingChunk(text, meta, component_info)
streaming_callback(stream_chunk)

if chunk.usage:
@@ -505,6 +514,10 @@ async def _run_streaming_async(
usage = None
meta: Dict[str, Any] = {}

component_name = self.__component_name__ if hasattr(self, "__component_name__") else None
component_type = self.__class__.__module__ + "." + self.__class__.__name__
component_info = ComponentInfo(name=component_name, type=component_type)

async for chunk in api_output:
# The chunk with usage returns an empty array for choices
if len(chunk.choices) > 0:
@@ -516,10 +529,7 @@
text = choice.delta.content or ""
generated_text += text

if choice.finish_reason:
finish_reason = choice.finish_reason

stream_chunk = StreamingChunk(text, meta)
stream_chunk = StreamingChunk(text, meta, component_info)
await streaming_callback(stream_chunk) # type: ignore

if chunk.usage:
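The name/type lookup added here appears in both the sync and async streaming paths, and again with small variations in `hugging_face_local.py` and `openai.py` below. As a standalone sketch, the shared pattern boils down to the following; `_get_component_info` is a hypothetical helper name, not something this diff introduces:

```python
from typing import Optional

from haystack.dataclasses import ComponentInfo


def _get_component_info(instance: object) -> ComponentInfo:
    # __component_name__ is set by Pipeline.add_component (see base.py below),
    # so it is absent when a component runs outside a pipeline; hence the guard.
    name: Optional[str] = getattr(instance, "__component_name__", None)
    component_type = f"{type(instance).__module__}.{type(instance).__name__}"
    return ComponentInfo(name=name, type=component_type)
```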
23 changes: 19 additions & 4 deletions haystack/components/generators/chat/hugging_face_local.py
@@ -10,7 +10,7 @@
from typing import Any, Callable, Dict, List, Literal, Optional, Union, cast

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall, select_streaming_callback
from haystack.dataclasses import ChatMessage, ComponentInfo, StreamingChunk, ToolCall, select_streaming_callback
from haystack.lazy_imports import LazyImport
from haystack.tools import (
Tool,
@@ -120,6 +120,9 @@ class HuggingFaceLocalChatGenerator:
```
"""

# Type annotation for the component name
__component_name__: str

def __init__( # pylint: disable=too-many-positional-arguments
self,
model: str = "HuggingFaceH4/zephyr-7b-beta",
@@ -381,8 +384,15 @@ def run(
)
logger.warning(msg, num_responses=num_responses)
generation_kwargs["num_return_sequences"] = 1

# Get component name and type
component_name = self.__component_name__ if hasattr(self, "__component_name__") else None
component_type = self.__class__.__module__ + "." + self.__class__.__name__
component_info = ComponentInfo(name=component_name, type=component_type)
# streamer parameter hooks into HF streaming, HFTokenStreamingHandler is an adapter to our streaming
generation_kwargs["streamer"] = HFTokenStreamingHandler(tokenizer, streaming_callback, stop_words)
generation_kwargs["streamer"] = HFTokenStreamingHandler(
tokenizer, streaming_callback, stop_words, component_info
)

# convert messages to HF format
hf_messages = [convert_message_to_hf_format(message) for message in messages]
@@ -565,8 +575,13 @@ async def _run_streaming_async( # pylint: disable=too-many-positional-arguments
generation_kwargs.get("pad_token_id", tokenizer.pad_token_id) or tokenizer.eos_token_id
)

# Set up streaming handler
generation_kwargs["streamer"] = HFTokenStreamingHandler(tokenizer, streaming_callback, stop_words)
# Get component name and type
component_name = self.__component_name__ if hasattr(self, "__component_name__") else None
component_type = self.__class__.__module__ + "." + self.__class__.__name__
component_info = ComponentInfo(name=component_name, type=component_type)
generation_kwargs["streamer"] = HFTokenStreamingHandler(
tokenizer, streaming_callback, stop_words, component_info
)

# Generate responses asynchronously
output = await asyncio.get_running_loop().run_in_executor(
23 changes: 20 additions & 3 deletions haystack/components/generators/chat/openai.py
@@ -16,6 +16,7 @@
from haystack.dataclasses import (
AsyncStreamingCallbackT,
ChatMessage,
ComponentInfo,
StreamingCallbackT,
StreamingChunk,
SyncStreamingCallbackT,
@@ -79,6 +80,9 @@ class OpenAIChatGenerator:
```
"""

# Type annotation for the component name
__component_name__: str

def __init__( # pylint: disable=too-many-positional-arguments
self,
api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
@@ -570,14 +574,27 @@ def _convert_chat_completion_chunk_to_streaming_chunk(self, chunk: ChatCompletio
:returns:
The StreamingChunk.
"""

# get the component name and type
component_info = ComponentInfo()
component_info.name = (
str(self.__component_name__) if getattr(self, "__component_name__", None) is not None else None
)
component_info.type = f"{self.__class__.__module__}.{self.__class__.__name__}"

# we stream the content of the chunk if it's not a tool or function call
# if there are no choices, return an empty chunk
if len(chunk.choices) == 0:
return StreamingChunk(content="", meta={"model": chunk.model, "received_at": datetime.now().isoformat()})
return StreamingChunk(
content="",
meta={"model": chunk.model, "received_at": datetime.now().isoformat()},
component_info=component_info,
)

# we stream the content of the chunk if it's not a tool or function call
choice: ChunkChoice = chunk.choices[0]
content = choice.delta.content or ""
chunk_message = StreamingChunk(content)
chunk_message = StreamingChunk(content, component_info=component_info)

# but save the tool calls and function call in the meta if they are present
# and then connect the chunks in the _convert_streaming_chunks_to_chat_message method
chunk_message.meta.update(
1 change: 1 addition & 0 deletions haystack/core/pipeline/base.py
@@ -360,6 +360,7 @@ def add_component(self, name: str, instance: Component) -> None:
raise PipelineError(msg)

setattr(instance, "__haystack_added_to_pipeline__", self)
setattr(instance, "__component_name__", name)

# Add component to the graph, disconnected
logger.debug("Adding component '{component_name}' ({component})", component_name=name, component=instance)
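With this one-line change, every component picks up a `__component_name__` attribute the moment it is added to a pipeline, which is what the generators above read back. A quick illustration, assuming `OPENAI_API_KEY` is set in the environment:

```python
from haystack import Pipeline
from haystack.components.generators.chat import OpenAIChatGenerator

pipe = Pipeline()
llm = OpenAIChatGenerator()  # assumes OPENAI_API_KEY is set
pipe.add_component("llm", llm)

# Set by the setattr added above; streaming chunks emitted by this component
# will now carry ComponentInfo(name="llm", ...).
assert llm.__component_name__ == "llm"
```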
2 changes: 2 additions & 0 deletions haystack/dataclasses/__init__.py
@@ -20,6 +20,7 @@
"StreamingCallbackT",
"SyncStreamingCallbackT",
"select_streaming_callback",
"ComponentInfo",
],
}

@@ -32,6 +33,7 @@
from .state import State
from .streaming_chunk import (
AsyncStreamingCallbackT,
ComponentInfo,
StreamingCallbackT,
StreamingChunk,
SyncStreamingCallbackT,
16 changes: 16 additions & 0 deletions haystack/dataclasses/streaming_chunk.py
@@ -8,6 +8,19 @@
from haystack.utils.asynchronous import is_callable_async_compatible


@dataclass
class ComponentInfo:
"""
The ComponentInfo class encapsulates information about a component.

:param name: The name of the component assigned while adding it to the pipeline.
:param type: The type of the component.
"""

name: Optional[str] = field(default=None)
type: str = field(default="")


@dataclass
class StreamingChunk:
"""
@@ -17,10 +30,13 @@ class StreamingChunk:

:param content: The content of the message chunk as a string.
:param meta: A dictionary containing metadata related to the message chunk.
:param component_info: Information about the component that generated the chunk,
such as the component name and type.
"""

content: str
meta: Dict[str, Any] = field(default_factory=dict, hash=False)
component_info: ComponentInfo = field(default_factory=ComponentInfo, hash=False)


SyncStreamingCallbackT = Callable[[StreamingChunk], None]
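Downstream, a streaming callback can now tell which component emitted each chunk. A minimal sketch of such a callback (the function name is hypothetical); note that `component_info` defaults to an empty `ComponentInfo`, so `name` is `None` for components that were never added to a pipeline:

```python
from haystack.dataclasses import StreamingChunk


def print_chunk_with_origin(chunk: StreamingChunk) -> None:
    # component_info is always present thanks to the default_factory,
    # but name/type may be empty outside a pipeline.
    origin = chunk.component_info.name or chunk.component_info.type or "unknown"
    print(f"[{origin}] {chunk.content}", end="", flush=True)
```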
6 changes: 4 additions & 2 deletions haystack/utils/hf.py
@@ -7,7 +7,7 @@
from typing import Any, Callable, Dict, List, Optional, Union

from haystack import logging
from haystack.dataclasses import ChatMessage, StreamingChunk
from haystack.dataclasses import ChatMessage, ComponentInfo, StreamingChunk
from haystack.lazy_imports import LazyImport
from haystack.utils.auth import Secret
from haystack.utils.device import ComponentDevice
@@ -359,13 +359,15 @@ def __init__(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
stream_handler: Callable[[StreamingChunk], None],
stop_words: Optional[List[str]] = None,
component_info: ComponentInfo = ComponentInfo(),
):
super().__init__(tokenizer=tokenizer, skip_prompt=True) # type: ignore
self.token_handler = stream_handler
self.stop_words = stop_words or []
self.component_info = component_info

def on_finalized_text(self, word: str, stream_end: bool = False):
"""Callback function for handling the generated text."""
word_to_send = word + "\n" if stream_end else word
if word_to_send.strip() not in self.stop_words:
self.token_handler(StreamingChunk(content=word_to_send))
self.token_handler(StreamingChunk(content=word_to_send, component_info=self.component_info))
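For reference, constructing the handler directly with the new parameter looks like the sketch below; normally the local generator above does this for you, and the `ComponentInfo` values here are illustrative:

```python
from haystack.dataclasses import ComponentInfo
from haystack.utils.hf import HFTokenStreamingHandler
from transformers import AutoTokenizer  # assumes transformers is installed

tokenizer = AutoTokenizer.from_pretrained("gpt2")
handler = HFTokenStreamingHandler(
    tokenizer=tokenizer,
    stream_handler=lambda chunk: print(chunk.content, end=""),
    component_info=ComponentInfo(name="llm", type="my.module.MyGenerator"),  # illustrative values
)
```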
@@ -0,0 +1,7 @@
---
features:
  - |
    Add a `ComponentInfo` dataclass to the `haystack.dataclasses` module.
    This dataclass is used to store information about the component in the `StreamingChunk`.
  - |
    Pass the `component_info` to the `StreamingChunk` in the `OpenAIChatGenerator`, `AzureOpenAIChatGenerator`, `HuggingFaceAPIChatGenerator` and `HuggingFaceLocalChatGenerator`.
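Taken together, a rough end-to-end sketch of what this PR enables; it assumes `OPENAI_API_KEY` is set, and the callback name and printed output are illustrative:

```python
from haystack import Pipeline
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.dataclasses import ChatMessage, StreamingChunk


def show_origin(chunk: StreamingChunk) -> None:
    print(f"[{chunk.component_info.name}] {chunk.content}", end="", flush=True)


pipe = Pipeline()
pipe.add_component("llm", OpenAIChatGenerator(streaming_callback=show_origin))
pipe.run({"llm": {"messages": [ChatMessage.from_user("Say hello.")]}})
# prints e.g.: [llm] Hello [llm] ! ...
```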
16 changes: 15 additions & 1 deletion test/components/generators/chat/test_openai.py
@@ -19,7 +19,7 @@

from haystack import component
from haystack.components.generators.utils import print_streaming_chunk
from haystack.dataclasses import StreamingChunk
from haystack.dataclasses import StreamingChunk, ComponentInfo
from haystack.utils.auth import Secret
from haystack.dataclasses import ChatMessage, ToolCall
from haystack.tools import ComponentTool, Tool
@@ -624,6 +624,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(self):
"finish_reason": None,
"received_at": "2025-02-19T16:02:55.910076",
},
component_info=ComponentInfo(name="test", type="test"),
),
StreamingChunk(
content="",
@@ -643,6 +644,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(self):
"finish_reason": None,
"received_at": "2025-02-19T16:02:55.913919",
},
component_info=ComponentInfo(name="test", type="test"),
),
StreamingChunk(
content="",
@@ -660,6 +662,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(self):
"finish_reason": None,
"received_at": "2025-02-19T16:02:55.914439",
},
component_info=ComponentInfo(name="test", type="test"),
),
StreamingChunk(
content="",
@@ -677,6 +680,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(self):
"finish_reason": None,
"received_at": "2025-02-19T16:02:55.924146",
},
component_info=ComponentInfo(name="test", type="test"),
),
StreamingChunk(
content="",
@@ -694,6 +698,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(self):
"finish_reason": None,
"received_at": "2025-02-19T16:02:55.924420",
},
component_info=ComponentInfo(name="test", type="test"),
),
StreamingChunk(
content="",
@@ -711,6 +716,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(self):
"finish_reason": None,
"received_at": "2025-02-19T16:02:55.944398",
},
component_info=ComponentInfo(name="test", type="test"),
),
StreamingChunk(
content="",
@@ -728,6 +734,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(self):
"finish_reason": None,
"received_at": "2025-02-19T16:02:55.944958",
},
component_info=ComponentInfo(name="test", type="test"),
),
StreamingChunk(
content="",
@@ -762,6 +769,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(self):
"finish_reason": None,
"received_at": "2025-02-19T16:02:55.946018",
},
component_info=ComponentInfo(name="test", type="test"),
),
StreamingChunk(
content="",
@@ -781,6 +789,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(self):
"finish_reason": None,
"received_at": "2025-02-19T16:02:55.946578",
},
component_info=ComponentInfo(name="test", type="test"),
),
StreamingChunk(
content="",
@@ -798,6 +807,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(self):
"finish_reason": None,
"received_at": "2025-02-19T16:02:55.946981",
},
component_info=ComponentInfo(name="test", type="test"),
),
StreamingChunk(
content="",
@@ -815,6 +825,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(self):
"finish_reason": None,
"received_at": "2025-02-19T16:02:55.947411",
},
component_info=ComponentInfo(name="test", type="test"),
),
StreamingChunk(
content="",
@@ -832,6 +843,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(self):
"finish_reason": None,
"received_at": "2025-02-19T16:02:55.947643",
},
component_info=ComponentInfo(name="test", type="test"),
),
StreamingChunk(
content="",
@@ -849,6 +861,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(self):
"finish_reason": None,
"received_at": "2025-02-19T16:02:55.947939",
},
component_info=ComponentInfo(name="test", type="test"),
),
StreamingChunk(
content="",
@@ -859,6 +872,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(self):
"finish_reason": "tool_calls",
"received_at": "2025-02-19T16:02:55.948772",
},
component_info=ComponentInfo(name="test", type="test"),
),
]
