
Commit 64def6d

Amnah199 and sjrl authored
feat: add component name and type to StreamingChunk (#9426)
* Stream component name in openai
* Fix type
* PR comments
* Update huggingface gen
* Typing fix
* Update huggingfacelocal gen
* Fix errors
* Remove model changes
* Fix minor errors
* Update releasenotes/notes/add-component-info-dataclass-be115dee2fa50abd.yaml

  Co-authored-by: Sebastian Husch Lee <[email protected]>

* PR comments
* update annotation
* Update hf files
* Fix linting
* Add a from_component method
* use add_component

---------

Co-authored-by: Sebastian Husch Lee <[email protected]>
1 parent 085c3ad commit 64def6d

10 files changed: +135 −19 lines changed

haystack/components/generators/chat/hugging_face_api.py
Lines changed: 10 additions & 6 deletions

@@ -7,7 +7,7 @@
 from typing import Any, AsyncIterable, Dict, Iterable, List, Optional, Union
 
 from haystack import component, default_from_dict, default_to_dict, logging
-from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall, select_streaming_callback
+from haystack.dataclasses import ChatMessage, ComponentInfo, StreamingChunk, ToolCall, select_streaming_callback
 from haystack.dataclasses.streaming_chunk import StreamingCallbackT
 from haystack.lazy_imports import LazyImport
 from haystack.tools import (
@@ -409,6 +409,10 @@ def _run_streaming(
         usage = None
         meta: Dict[str, Any] = {}
 
+        # get the component name and type
+        component_info = ComponentInfo.from_component(self)
+
+        # Set up streaming handler
         for chunk in api_output:
             # The chunk with usage returns an empty array for choices
             if len(chunk.choices) > 0:
@@ -423,7 +427,7 @@ def _run_streaming(
                 if choice.finish_reason:
                     finish_reason = choice.finish_reason
 
-                stream_chunk = StreamingChunk(text, meta)
+                stream_chunk = StreamingChunk(content=text, meta=meta, component_info=component_info)
                 streaming_callback(stream_chunk)
 
             if chunk.usage:
@@ -505,6 +509,9 @@ async def _run_streaming_async(
         usage = None
         meta: Dict[str, Any] = {}
 
+        # get the component name and type
+        component_info = ComponentInfo.from_component(self)
+
         async for chunk in api_output:
             # The chunk with usage returns an empty array for choices
             if len(chunk.choices) > 0:
@@ -516,10 +523,7 @@ async def _run_streaming_async(
                 text = choice.delta.content or ""
                 generated_text += text
 
-                if choice.finish_reason:
-                    finish_reason = choice.finish_reason
-
-                stream_chunk = StreamingChunk(text, meta)
+                stream_chunk = StreamingChunk(content=text, meta=meta, component_info=component_info)
                 await streaming_callback(stream_chunk)  # type: ignore
 
             if chunk.usage:
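
With this change, every chunk the Hugging Face API generator hands to its streaming callback carries a `ComponentInfo` describing where it came from. A minimal sketch of reading it from the callback, assuming a Serverless Inference API endpoint, an illustrative model name, and an HF token available in the environment (none of which are part of this commit):

from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
from haystack.dataclasses import ChatMessage, StreamingChunk

def show_origin(chunk: StreamingChunk) -> None:
    # component_info may be None for components that do not set it, so guard the access
    origin = chunk.component_info
    label = f"{origin.type} (name={origin.name})" if origin else "unknown"
    print(f"[{label}] {chunk.content}", end="")

generator = HuggingFaceAPIChatGenerator(
    api_type="serverless_inference_api",                   # assumed endpoint type
    api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},   # illustrative model
    streaming_callback=show_origin,
)
generator.run(messages=[ChatMessage.from_user("Hello!")])

Run standalone like this, `origin.name` is None; it is only populated once the generator has been added to a pipeline (see the base.py change below).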

haystack/components/generators/chat/hugging_face_local.py
Lines changed: 12 additions & 4 deletions

@@ -10,7 +10,7 @@
 from typing import Any, Callable, Dict, List, Literal, Optional, Union, cast
 
 from haystack import component, default_from_dict, default_to_dict, logging
-from haystack.dataclasses import ChatMessage, StreamingCallbackT, ToolCall, select_streaming_callback
+from haystack.dataclasses import ChatMessage, ComponentInfo, StreamingCallbackT, ToolCall, select_streaming_callback
 from haystack.lazy_imports import LazyImport
 from haystack.tools import (
     Tool,
@@ -384,8 +384,13 @@ def run(
                 )
                 logger.warning(msg, num_responses=num_responses)
                 generation_kwargs["num_return_sequences"] = 1
+
+            # Get component name and type
+            component_info = ComponentInfo.from_component(self)
             # streamer parameter hooks into HF streaming, HFTokenStreamingHandler is an adapter to our streaming
-            generation_kwargs["streamer"] = HFTokenStreamingHandler(tokenizer, streaming_callback, stop_words)
+            generation_kwargs["streamer"] = HFTokenStreamingHandler(
+                tokenizer, streaming_callback, stop_words, component_info
+            )
 
         # convert messages to HF format
         hf_messages = [convert_message_to_hf_format(message) for message in messages]
@@ -573,8 +578,11 @@ async def _run_streaming_async(  # pylint: disable=too-many-positional-arguments
             generation_kwargs.get("pad_token_id", tokenizer.pad_token_id) or tokenizer.eos_token_id
         )
 
-        # Set up streaming handler
-        generation_kwargs["streamer"] = HFTokenStreamingHandler(tokenizer, streaming_callback, stop_words)
+        # get the component name and type
+        component_info = ComponentInfo.from_component(self)
+        generation_kwargs["streamer"] = HFTokenStreamingHandler(
+            tokenizer, streaming_callback, stop_words, component_info
+        )
 
         # Generate responses asynchronously
         output = await asyncio.get_running_loop().run_in_executor(

haystack/components/generators/chat/openai.py
Lines changed: 13 additions & 3 deletions

@@ -16,6 +16,7 @@
 from haystack.dataclasses import (
     AsyncStreamingCallbackT,
     ChatMessage,
+    ComponentInfo,
     StreamingCallbackT,
     StreamingChunk,
     SyncStreamingCallbackT,
@@ -570,14 +571,23 @@ def _convert_chat_completion_chunk_to_streaming_chunk(self, chunk: ChatCompletio
         :returns:
             The StreamingChunk.
         """
+
+        # get the component name and type
+        component_info = ComponentInfo.from_component(self)
+
+        # we stream the content of the chunk if it's not a tool or function call
         # if there are no choices, return an empty chunk
         if len(chunk.choices) == 0:
-            return StreamingChunk(content="", meta={"model": chunk.model, "received_at": datetime.now().isoformat()})
+            return StreamingChunk(
+                content="",
+                meta={"model": chunk.model, "received_at": datetime.now().isoformat()},
+                component_info=component_info,
+            )
 
-        # we stream the content of the chunk if it's not a tool or function call
         choice: ChunkChoice = chunk.choices[0]
         content = choice.delta.content or ""
-        chunk_message = StreamingChunk(content)
+        chunk_message = StreamingChunk(content, component_info=component_info)
+
         # but save the tool calls and function call in the meta if they are present
         # and then connect the chunks in the _convert_streaming_chunks_to_chat_message method
         chunk_message.meta.update(
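
In the OpenAI generator every chunk, including the empty chunk emitted when a completion chunk has no choices, now carries the same `ComponentInfo`. A short sketch of what a callback can rely on, assuming an OPENAI_API_KEY in the environment and an illustrative model name:

from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.dataclasses import ChatMessage, StreamingChunk

def on_chunk(chunk: StreamingChunk) -> None:
    info = chunk.component_info
    # type is the module-qualified class name; name stays None outside a pipeline
    assert info is not None and info.type.endswith("OpenAIChatGenerator")
    assert info.name is None
    print(chunk.content, end="", flush=True)

llm = OpenAIChatGenerator(model="gpt-4o-mini", streaming_callback=on_chunk)  # illustrative model
llm.run(messages=[ChatMessage.from_user("Say hi.")])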

haystack/core/pipeline/base.py
Lines changed: 1 addition & 0 deletions

@@ -366,6 +366,7 @@ def add_component(self, name: str, instance: Component) -> None:
             raise PipelineError(msg)
 
         setattr(instance, "__haystack_added_to_pipeline__", self)
+        setattr(instance, "__component_name__", name)
 
         # Add component to the graph, disconnected
         logger.debug("Adding component '{component_name}' ({component})", component_name=name, component=instance)
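
This one-line change is what lets `ComponentInfo.from_component` report a pipeline-assigned name: `add_component` now stamps the instance with `__component_name__`. A small sketch with a made-up `Hello` component (the component itself is illustrative, not part of the commit):

from haystack import Pipeline, component
from haystack.dataclasses import ComponentInfo

@component
class Hello:
    @component.output_types(greeting=str)
    def run(self, name: str):
        return {"greeting": f"Hello, {name}!"}

pipe = Pipeline()
hello = Hello()
pipe.add_component("greeter", hello)

# add_component set the attribute that ComponentInfo.from_component reads
assert getattr(hello, "__component_name__") == "greeter"

info = ComponentInfo.from_component(hello)
print(info.type)  # module-qualified class name, e.g. "__main__.Hello"
print(info.name)  # "greeter"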

haystack/dataclasses/__init__.py
Lines changed: 2 additions & 0 deletions

@@ -20,6 +20,7 @@
         "StreamingCallbackT",
         "SyncStreamingCallbackT",
         "select_streaming_callback",
+        "ComponentInfo",
     ],
 }
 
@@ -32,6 +33,7 @@
     from .state import State
     from .streaming_chunk import (
         AsyncStreamingCallbackT,
+        ComponentInfo,
         StreamingCallbackT,
         StreamingChunk,
         SyncStreamingCallbackT,

haystack/dataclasses/streaming_chunk.py
Lines changed: 33 additions & 1 deletion

@@ -5,22 +5,54 @@
 from dataclasses import dataclass, field
 from typing import Any, Awaitable, Callable, Dict, Optional, Union
 
+from haystack.core.component import Component
 from haystack.utils.asynchronous import is_callable_async_compatible
 
 
+@dataclass
+class ComponentInfo:
+    """
+    The `ComponentInfo` class encapsulates information about a component.
+
+    :param type: The type of the component.
+    :param name: The name of the component assigned when adding it to a pipeline.
+
+    """
+
+    type: str
+    name: Optional[str] = field(default=None)
+
+    @classmethod
+    def from_component(cls, component: Component) -> "ComponentInfo":
+        """
+        Create a `ComponentInfo` object from a `Component` instance.
+
+        :param component:
+            The `Component` instance.
+        :returns:
+            The `ComponentInfo` object with the type and name of the given component.
+        """
+        component_type = f"{component.__class__.__module__}.{component.__class__.__name__}"
+        component_name = getattr(component, "__component_name__", None)
+        return cls(type=component_type, name=component_name)
+
+
 @dataclass
 class StreamingChunk:
     """
-    The StreamingChunk class encapsulates a segment of streamed content along with associated metadata.
+    The `StreamingChunk` class encapsulates a segment of streamed content along with associated metadata.
 
     This structure facilitates the handling and processing of streamed data in a systematic manner.
 
     :param content: The content of the message chunk as a string.
     :param meta: A dictionary containing metadata related to the message chunk.
+    :param component_info: A `ComponentInfo` object containing information about the component that generated the chunk,
+        such as the component name and type.
     """
 
     content: str
     meta: Dict[str, Any] = field(default_factory=dict, hash=False)
+    component_info: Optional[ComponentInfo] = field(default=None, hash=False)
 
 
 SyncStreamingCallbackT = Callable[[StreamingChunk], None]
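
`ComponentInfo` can be derived from a live component or built by hand, which is how the updated OpenAI tests use it. A quick sketch of both paths; the `PromptBuilder` here is just a convenient keyless component to demonstrate `from_component`:

from haystack.components.builders import PromptBuilder
from haystack.dataclasses import ComponentInfo, StreamingChunk

# Manual construction, as in the updated tests
chunk = StreamingChunk(content="hello", component_info=ComponentInfo(name="test", type="test"))
assert chunk.component_info.name == "test"

# Derived from a component instance; name is None until it is added to a pipeline
builder = PromptBuilder(template="Answer: {{ question }}")
info = ComponentInfo.from_component(builder)
print(info.type)  # "haystack.components.builders.prompt_builder.PromptBuilder"
print(info.name)  # None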

haystack/utils/hf.py
Lines changed: 4 additions & 2 deletions

@@ -7,7 +7,7 @@
 from typing import Any, Dict, List, Optional, Union
 
 from haystack import logging
-from haystack.dataclasses import ChatMessage, StreamingCallbackT, StreamingChunk
+from haystack.dataclasses import ChatMessage, ComponentInfo, StreamingCallbackT, StreamingChunk
 from haystack.lazy_imports import LazyImport
 from haystack.utils.auth import Secret
 from haystack.utils.device import ComponentDevice
@@ -359,13 +359,15 @@ def __init__(
         tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
         stream_handler: StreamingCallbackT,
         stop_words: Optional[List[str]] = None,
+        component_info: Optional[ComponentInfo] = None,
     ):
         super().__init__(tokenizer=tokenizer, skip_prompt=True)  # type: ignore
         self.token_handler = stream_handler
         self.stop_words = stop_words or []
+        self.component_info = component_info
 
     def on_finalized_text(self, word: str, stream_end: bool = False) -> None:
         """Callback function for handling the generated text."""
         word_to_send = word + "\n" if stream_end else word
         if word_to_send.strip() not in self.stop_words:
-            self.token_handler(StreamingChunk(content=word_to_send))
+            self.token_handler(StreamingChunk(content=word_to_send, component_info=self.component_info))
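
Both Hugging Face chat generators pass the `ComponentInfo` they build into `HFTokenStreamingHandler`, so tokens surfaced through `on_finalized_text` are tagged the same way as the OpenAI chunks. A reduced sketch of that wiring, assuming transformers is installed and using an arbitrary tokenizer plus a hand-built `ComponentInfo`:

from transformers import AutoTokenizer
from haystack.dataclasses import ComponentInfo, StreamingChunk
from haystack.utils.hf import HFTokenStreamingHandler

def callback(chunk: StreamingChunk) -> None:
    print(chunk.component_info, chunk.content)

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any HF tokenizer; downloads on first use
info = ComponentInfo(type="haystack.components.generators.chat.hugging_face_local.HuggingFaceLocalChatGenerator")
handler = HFTokenStreamingHandler(tokenizer, callback, stop_words=None, component_info=info)

# transformers calls this as text is finalized; the emitted chunk now carries component_info
handler.on_finalized_text("Hello", stream_end=False)
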
releasenotes/notes/add-component-info-dataclass-be115dee2fa50abd.yaml
Lines changed: 7 additions & 0 deletions

@@ -0,0 +1,7 @@
+---
+features:
+  - |
+    - Add a `ComponentInfo` dataclass to the `haystack.dataclasses` module.
+      This dataclass is used to store information about the component. We pass it to `StreamingChunk` so we can tell from which component a stream is coming from.
+
+    - Pass the `component_info` to the `StreamingChunk` in the `OpenAIChatGenerator`, `AzureOpenAIChatGenerator`, `HuggingFaceAPIChatGenerator` and `HuggingFaceLocalChatGenerator`.
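
The release note's point is being able to tell which component a stream comes from. A pipeline-level sketch, with illustrative component names and model, assuming an OpenAI API key in the environment:

from haystack import Pipeline
from haystack.components.builders import ChatPromptBuilder
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.dataclasses import ChatMessage, StreamingChunk

def route_chunk(chunk: StreamingChunk) -> None:
    # After this commit, chunks report the pipeline name given in add_component ("llm" here)
    origin = chunk.component_info.name if chunk.component_info else "unknown"
    print(f"{origin}: {chunk.content}", end="")

pipe = Pipeline()
pipe.add_component("prompt_builder", ChatPromptBuilder(template=[ChatMessage.from_user("Summarize: {{ text }}")]))
pipe.add_component("llm", OpenAIChatGenerator(model="gpt-4o-mini", streaming_callback=route_chunk))
pipe.connect("prompt_builder.prompt", "llm.messages")

pipe.run({"prompt_builder": {"text": "Streaming chunks now carry component info."}})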

test/components/generators/chat/test_openai.py
Lines changed: 15 additions & 1 deletion

@@ -20,7 +20,7 @@
 
 from haystack import component
 from haystack.components.generators.utils import print_streaming_chunk
-from haystack.dataclasses import StreamingChunk
+from haystack.dataclasses import StreamingChunk, ComponentInfo
 from haystack.utils.auth import Secret
 from haystack.dataclasses import ChatMessage, ToolCall
 from haystack.tools import ComponentTool, Tool
@@ -625,6 +625,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(self):
                     "finish_reason": None,
                     "received_at": "2025-02-19T16:02:55.910076",
                 },
+                component_info=ComponentInfo(name="test", type="test"),
             ),
             StreamingChunk(
                 content="",
@@ -644,6 +645,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(self):
                     "finish_reason": None,
                     "received_at": "2025-02-19T16:02:55.913919",
                 },
+                component_info=ComponentInfo(name="test", type="test"),
             ),
             StreamingChunk(
                 content="",
@@ -661,6 +663,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(self):
                     "finish_reason": None,
                     "received_at": "2025-02-19T16:02:55.914439",
                 },
+                component_info=ComponentInfo(name="test", type="test"),
             ),
             StreamingChunk(
                 content="",
@@ -678,6 +681,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(self):
                     "finish_reason": None,
                     "received_at": "2025-02-19T16:02:55.924146",
                 },
+                component_info=ComponentInfo(name="test", type="test"),
             ),
             StreamingChunk(
                 content="",
@@ -695,6 +699,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(self):
                     "finish_reason": None,
                     "received_at": "2025-02-19T16:02:55.924420",
                 },
+                component_info=ComponentInfo(name="test", type="test"),
             ),
             StreamingChunk(
                 content="",
@@ -712,6 +717,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(self):
                     "finish_reason": None,
                     "received_at": "2025-02-19T16:02:55.944398",
                 },
+                component_info=ComponentInfo(name="test", type="test"),
             ),
             StreamingChunk(
                 content="",
@@ -729,6 +735,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(self):
                     "finish_reason": None,
                     "received_at": "2025-02-19T16:02:55.944958",
                 },
+                component_info=ComponentInfo(name="test", type="test"),
             ),
             StreamingChunk(
                 content="",
@@ -763,6 +770,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(self):
                     "finish_reason": None,
                     "received_at": "2025-02-19T16:02:55.946018",
                 },
+                component_info=ComponentInfo(name="test", type="test"),
             ),
             StreamingChunk(
                 content="",
@@ -782,6 +790,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(self):
                     "finish_reason": None,
                     "received_at": "2025-02-19T16:02:55.946578",
                 },
+                component_info=ComponentInfo(name="test", type="test"),
             ),
             StreamingChunk(
                 content="",
@@ -799,6 +808,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(self):
                     "finish_reason": None,
                     "received_at": "2025-02-19T16:02:55.946981",
                 },
+                component_info=ComponentInfo(name="test", type="test"),
             ),
             StreamingChunk(
                 content="",
@@ -816,6 +826,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(self):
                     "finish_reason": None,
                     "received_at": "2025-02-19T16:02:55.947411",
                 },
+                component_info=ComponentInfo(name="test", type="test"),
             ),
             StreamingChunk(
                 content="",
@@ -833,6 +844,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(self):
                     "finish_reason": None,
                     "received_at": "2025-02-19T16:02:55.947643",
                 },
+                component_info=ComponentInfo(name="test", type="test"),
             ),
             StreamingChunk(
                 content="",
@@ -850,6 +862,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(self):
                     "finish_reason": None,
                     "received_at": "2025-02-19T16:02:55.947939",
                 },
+                component_info=ComponentInfo(name="test", type="test"),
             ),
             StreamingChunk(
                 content="",
@@ -860,6 +873,7 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk(self):
                     "finish_reason": "tool_calls",
                     "received_at": "2025-02-19T16:02:55.948772",
                 },
+                component_info=ComponentInfo(name="test", type="test"),
             ),
         ]
 
