diff --git a/haystack/components/generators/chat/hugging_face_api.py b/haystack/components/generators/chat/hugging_face_api.py index f88afdd4f3..5df577985f 100644 --- a/haystack/components/generators/chat/hugging_face_api.py +++ b/haystack/components/generators/chat/hugging_face_api.py @@ -7,7 +7,7 @@ from typing import Any, AsyncIterable, Dict, Iterable, List, Optional, Union from haystack import component, default_from_dict, default_to_dict, logging -from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall, select_streaming_callback +from haystack.dataclasses import ChatMessage, ComponentInfo, StreamingChunk, ToolCall, select_streaming_callback from haystack.dataclasses.streaming_chunk import StreamingCallbackT from haystack.lazy_imports import LazyImport from haystack.tools import ( @@ -409,6 +409,10 @@ def _run_streaming( usage = None meta: Dict[str, Any] = {} + # get the component name and type + component_info = ComponentInfo.from_component(self) + + # Set up streaming handler for chunk in api_output: # The chunk with usage returns an empty array for choices if len(chunk.choices) > 0: @@ -423,7 +427,7 @@ def _run_streaming( if choice.finish_reason: finish_reason = choice.finish_reason - stream_chunk = StreamingChunk(text, meta) + stream_chunk = StreamingChunk(content=text, meta=meta, component_info=component_info) streaming_callback(stream_chunk) if chunk.usage: @@ -505,6 +509,9 @@ async def _run_streaming_async( usage = None meta: Dict[str, Any] = {} + # get the component name and type + component_info = ComponentInfo.from_component(self) + async for chunk in api_output: # The chunk with usage returns an empty array for choices if len(chunk.choices) > 0: @@ -516,10 +523,7 @@ async def _run_streaming_async( text = choice.delta.content or "" generated_text += text - if choice.finish_reason: - finish_reason = choice.finish_reason - - stream_chunk = StreamingChunk(text, meta) + stream_chunk = StreamingChunk(content=text, meta=meta, component_info=component_info) await streaming_callback(stream_chunk) # type: ignore if chunk.usage: diff --git a/haystack/components/generators/chat/hugging_face_local.py b/haystack/components/generators/chat/hugging_face_local.py index f95c53a347..71a6989b82 100644 --- a/haystack/components/generators/chat/hugging_face_local.py +++ b/haystack/components/generators/chat/hugging_face_local.py @@ -10,7 +10,7 @@ from typing import Any, Callable, Dict, List, Literal, Optional, Union, cast from haystack import component, default_from_dict, default_to_dict, logging -from haystack.dataclasses import ChatMessage, StreamingCallbackT, ToolCall, select_streaming_callback +from haystack.dataclasses import ChatMessage, ComponentInfo, StreamingCallbackT, ToolCall, select_streaming_callback from haystack.lazy_imports import LazyImport from haystack.tools import ( Tool, @@ -384,8 +384,13 @@ def run( ) logger.warning(msg, num_responses=num_responses) generation_kwargs["num_return_sequences"] = 1 + + # Get component name and type + component_info = ComponentInfo.from_component(self) # streamer parameter hooks into HF streaming, HFTokenStreamingHandler is an adapter to our streaming - generation_kwargs["streamer"] = HFTokenStreamingHandler(tokenizer, streaming_callback, stop_words) + generation_kwargs["streamer"] = HFTokenStreamingHandler( + tokenizer, streaming_callback, stop_words, component_info + ) # convert messages to HF format hf_messages = [convert_message_to_hf_format(message) for message in messages] @@ -573,8 +578,11 @@ async def _run_streaming_async( # pylint: disable=too-many-positional-arguments generation_kwargs.get("pad_token_id", tokenizer.pad_token_id) or tokenizer.eos_token_id ) - # Set up streaming handler - generation_kwargs["streamer"] = HFTokenStreamingHandler(tokenizer, streaming_callback, stop_words) + # get the component name and type + component_info = ComponentInfo.from_component(self) + generation_kwargs["streamer"] = HFTokenStreamingHandler( + tokenizer, streaming_callback, stop_words, component_info + ) # Generate responses asynchronously output = await asyncio.get_running_loop().run_in_executor( diff --git a/haystack/components/generators/chat/openai.py b/haystack/components/generators/chat/openai.py index 8f0d99cad4..c84f68afa4 100644 --- a/haystack/components/generators/chat/openai.py +++ b/haystack/components/generators/chat/openai.py @@ -16,6 +16,7 @@ from haystack.dataclasses import ( AsyncStreamingCallbackT, ChatMessage, + ComponentInfo, StreamingCallbackT, StreamingChunk, SyncStreamingCallbackT, @@ -570,14 +571,23 @@ def _convert_chat_completion_chunk_to_streaming_chunk(self, chunk: ChatCompletio :returns: The StreamingChunk. """ + + # get the component name and type + component_info = ComponentInfo.from_component(self) + + # we stream the content of the chunk if it's not a tool or function call # if there are no choices, return an empty chunk if len(chunk.choices) == 0: - return StreamingChunk(content="", meta={"model": chunk.model, "received_at": datetime.now().isoformat()}) + return StreamingChunk( + content="", + meta={"model": chunk.model, "received_at": datetime.now().isoformat()}, + component_info=component_info, + ) - # we stream the content of the chunk if it's not a tool or function call choice: ChunkChoice = chunk.choices[0] content = choice.delta.content or "" - chunk_message = StreamingChunk(content) + chunk_message = StreamingChunk(content, component_info=component_info) + # but save the tool calls and function call in the meta if they are present # and then connect the chunks in the _convert_streaming_chunks_to_chat_message method chunk_message.meta.update( diff --git a/haystack/core/pipeline/base.py b/haystack/core/pipeline/base.py index 5dfeec3caa..a44cb5f724 100644 --- a/haystack/core/pipeline/base.py +++ b/haystack/core/pipeline/base.py @@ -366,6 +366,7 @@ def add_component(self, name: str, instance: Component) -> None: raise PipelineError(msg) setattr(instance, "__haystack_added_to_pipeline__", self) + setattr(instance, "__component_name__", name) # Add component to the graph, disconnected logger.debug("Adding component '{component_name}' ({component})", component_name=name, component=instance) diff --git a/haystack/dataclasses/__init__.py b/haystack/dataclasses/__init__.py index 15da1ae84a..48c5544f49 100644 --- a/haystack/dataclasses/__init__.py +++ b/haystack/dataclasses/__init__.py @@ -20,6 +20,7 @@ "StreamingCallbackT", "SyncStreamingCallbackT", "select_streaming_callback", + "ComponentInfo", ], } @@ -32,6 +33,7 @@ from .state import State from .streaming_chunk import ( AsyncStreamingCallbackT, + ComponentInfo, StreamingCallbackT, StreamingChunk, SyncStreamingCallbackT, diff --git a/haystack/dataclasses/streaming_chunk.py b/haystack/dataclasses/streaming_chunk.py index b11ef5f7f2..b7d1e91d65 100644 --- a/haystack/dataclasses/streaming_chunk.py +++ b/haystack/dataclasses/streaming_chunk.py @@ -5,22 +5,54 @@ from dataclasses import dataclass, field from typing import Any, Awaitable, Callable, Dict, Optional, Union +from haystack.core.component import Component from haystack.utils.asynchronous import is_callable_async_compatible +@dataclass +class ComponentInfo: + """ + The `ComponentInfo` class encapsulates information about a component. + + :param type: The type of the component. + :param name: The name of the component assigned when adding it to a pipeline. + + """ + + type: str + name: Optional[str] = field(default=None) + + @classmethod + def from_component(cls, component: Component) -> "ComponentInfo": + """ + Create a `ComponentInfo` object from a `Component` instance. + + :param component: + The `Component` instance. + :returns: + The `ComponentInfo` object with the type and name of the given component. + """ + component_type = f"{component.__class__.__module__}.{component.__class__.__name__}" + component_name = getattr(component, "__component_name__", None) + return cls(type=component_type, name=component_name) + + @dataclass class StreamingChunk: """ - The StreamingChunk class encapsulates a segment of streamed content along with associated metadata. + The `StreamingChunk` class encapsulates a segment of streamed content along with associated metadata. This structure facilitates the handling and processing of streamed data in a systematic manner. :param content: The content of the message chunk as a string. :param meta: A dictionary containing metadata related to the message chunk. + :param component_info: A `ComponentInfo` object containing information about the component that generated the chunk, + such as the component name and type. """ content: str meta: Dict[str, Any] = field(default_factory=dict, hash=False) + component_info: Optional[ComponentInfo] = field(default=None, hash=False) SyncStreamingCallbackT = Callable[[StreamingChunk], None] diff --git a/haystack/utils/hf.py b/haystack/utils/hf.py index 946bcaeadc..a787265874 100644 --- a/haystack/utils/hf.py +++ b/haystack/utils/hf.py @@ -7,7 +7,7 @@ from typing import Any, Dict, List, Optional, Union from haystack import logging -from haystack.dataclasses import ChatMessage, StreamingCallbackT, StreamingChunk +from haystack.dataclasses import ChatMessage, ComponentInfo, StreamingCallbackT, StreamingChunk from haystack.lazy_imports import LazyImport from haystack.utils.auth import Secret from haystack.utils.device import ComponentDevice @@ -359,13 +359,15 @@ def __init__( tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], stream_handler: StreamingCallbackT, stop_words: Optional[List[str]] = None, + component_info: Optional[ComponentInfo] = None, ): super().__init__(tokenizer=tokenizer, skip_prompt=True) # type: ignore self.token_handler = stream_handler self.stop_words = stop_words or [] + self.component_info = component_info def on_finalized_text(self, word: str, stream_end: bool = False) -> None: """Callback function for handling the generated text.""" word_to_send = word + "\n" if stream_end else word if word_to_send.strip() not in self.stop_words: - self.token_handler(StreamingChunk(content=word_to_send)) + self.token_handler(StreamingChunk(content=word_to_send, component_info=self.component_info)) diff --git a/releasenotes/notes/add-component-info-dataclass-be115dee2fa50abd.yaml b/releasenotes/notes/add-component-info-dataclass-be115dee2fa50abd.yaml new file mode 100644 index 0000000000..feaef293b2 --- /dev/null +++ b/releasenotes/notes/add-component-info-dataclass-be115dee2fa50abd.yaml @@ -0,0 +1,7 @@ +--- +features: + - | + - Add a `ComponentInfo` dataclass to the `haystack.dataclasses` module. + This dataclass is used to store information about the component. We pass it to `StreamingChunk` so we can tell from which component a stream is coming from. + + - Pass the `component_info` to the `StreamingChunk` in the `OpenAIChatGenerator`, `AzureOpenAIChatGenerator`, `HuggingFaceAPIChatGenerator` and `HuggingFaceLocalChatGenerator`. diff --git a/test/components/generators/chat/test_openai.py b/test/components/generators/chat/test_openai.py index 7ebdbe2091..dcd4aa2659 100644 --- a/test/components/generators/chat/test_openai.py +++ b/test/components/generators/chat/test_openai.py @@ -20,7 +20,7 @@ from haystack import component from haystack.components.generators.utils import print_streaming_chunk -from haystack.dataclasses import StreamingChunk +from haystack.dataclasses import StreamingChunk, ComponentInfo from haystack.utils.auth import Secret from haystack.dataclasses import ChatMessage, ToolCall from haystack.tools import ComponentTool, Tool @@ -625,6 +625,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(self): "finish_reason": None, "received_at": "2025-02-19T16:02:55.910076", }, + component_info=ComponentInfo(name="test", type="test"), ), StreamingChunk( content="", @@ -644,6 +645,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(self): "finish_reason": None, "received_at": "2025-02-19T16:02:55.913919", }, + component_info=ComponentInfo(name="test", type="test"), ), StreamingChunk( content="", @@ -661,6 +663,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(self): "finish_reason": None, "received_at": "2025-02-19T16:02:55.914439", }, + component_info=ComponentInfo(name="test", type="test"), ), StreamingChunk( content="", @@ -678,6 +681,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(self): "finish_reason": None, "received_at": "2025-02-19T16:02:55.924146", }, + component_info=ComponentInfo(name="test", type="test"), ), StreamingChunk( content="", @@ -695,6 +699,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(self): "finish_reason": None, "received_at": "2025-02-19T16:02:55.924420", }, + component_info=ComponentInfo(name="test", type="test"), ), StreamingChunk( content="", @@ -712,6 +717,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(self): "finish_reason": None, "received_at": "2025-02-19T16:02:55.944398", }, + component_info=ComponentInfo(name="test", type="test"), ), StreamingChunk( content="", @@ -729,6 +735,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(self): "finish_reason": None, "received_at": "2025-02-19T16:02:55.944958", }, + component_info=ComponentInfo(name="test", type="test"), ), StreamingChunk( content="", @@ -763,6 +770,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(self): "finish_reason": None, "received_at": "2025-02-19T16:02:55.946018", }, + component_info=ComponentInfo(name="test", type="test"), ), StreamingChunk( content="", @@ -782,6 +790,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(self): "finish_reason": None, "received_at": "2025-02-19T16:02:55.946578", }, + component_info=ComponentInfo(name="test", type="test"), ), StreamingChunk( content="", @@ -799,6 +808,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(self): "finish_reason": None, "received_at": "2025-02-19T16:02:55.946981", }, + component_info=ComponentInfo(name="test", type="test"), ), StreamingChunk( content="", @@ -816,6 +826,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(self): "finish_reason": None, "received_at": "2025-02-19T16:02:55.947411", }, + component_info=ComponentInfo(name="test", type="test"), ), StreamingChunk( content="", @@ -833,6 +844,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(self): "finish_reason": None, "received_at": "2025-02-19T16:02:55.947643", }, + component_info=ComponentInfo(name="test", type="test"), ), StreamingChunk( content="", @@ -850,6 +862,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(self): "finish_reason": None, "received_at": "2025-02-19T16:02:55.947939", }, + component_info=ComponentInfo(name="test", type="test"), ), StreamingChunk( content="", @@ -860,6 +873,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(self): "finish_reason": "tool_calls", "received_at": "2025-02-19T16:02:55.948772", }, + component_info=ComponentInfo(name="test", type="test"), ), ] diff --git a/test/dataclasses/test_streaming_chunk.py b/test/dataclasses/test_streaming_chunk.py index 6f161d87ef..a12372cedc 100644 --- a/test/dataclasses/test_streaming_chunk.py +++ b/test/dataclasses/test_streaming_chunk.py @@ -2,9 +2,21 @@ # # SPDX-License-Identifier: Apache-2.0 -import pytest -from haystack.dataclasses import StreamingChunk +from haystack.dataclasses import StreamingChunk, ComponentInfo +from unittest.mock import Mock +from haystack.core.component import Component +from haystack import component +from haystack import Pipeline + + +@component +class TestComponent: + def __init__(self): + self.name = "test_component" + + def run(self) -> str: + return "Test content" def test_create_chunk_with_content_and_metadata(): @@ -30,3 +42,27 @@ def test_create_chunk_with_empty_content(): chunk = StreamingChunk(content="") assert chunk.content == "" assert chunk.meta == {} + + +def test_create_chunk_with_all_fields(): + component_info = ComponentInfo(type="test.component", name="test_component") + chunk = StreamingChunk(content="Test content", meta={"key": "value"}, component_info=component_info) + + assert chunk.content == "Test content" + assert chunk.meta == {"key": "value"} + assert chunk.component_info == component_info + + +def test_component_info_from_component(): + component = TestComponent() + component_info = ComponentInfo.from_component(component) + assert component_info.type == "test_streaming_chunk.TestComponent" + + +def test_component_info_from_component_with_name_from_pipeline(): + pipeline = Pipeline() + component = TestComponent() + pipeline.add_component("pipeline_component", component) + component_info = ComponentInfo.from_component(component) + assert component_info.type == "test_streaming_chunk.TestComponent" + assert component_info.name == "pipeline_component"