 import pytz
 from haystack.components.generators.utils import print_streaming_chunk
 from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk, ToolCall
-from haystack.tools import Tool
+from haystack.tools import Tool, Toolset
 from openai.types.chat import ChatCompletion, ChatCompletionMessage
 from openai.types.chat.chat_completion import Choice

@@ -25,6 +25,11 @@ def weather(city: str):
     return f"The weather in {city} is sunny and 32°C"


+def population(city: str):
+    """Get population for a given city."""
+    return f"The population of {city} is 2.2 million"
+
+
 @pytest.fixture
 def tools():
     tool_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
@@ -38,6 +43,25 @@ def tools():
     return [tool]


+@pytest.fixture
+def mixed_tools():
+    """Fixture that returns a mixed list of Tool and Toolset."""
+    weather_tool = Tool(
+        name="weather",
+        description="useful to determine the weather in a given location",
+        parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
+        function=weather,
+    )
+    population_tool = Tool(
+        name="population",
+        description="useful to determine the population of a given location",
+        parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
+        function=population,
+    )
+    toolset = Toolset([population_tool])
+    return [weather_tool, toolset]
+
+
 @pytest.fixture
 def mock_chat_completion():
     """
@@ -46,7 +70,7 @@ def mock_chat_completion():
     with patch("openai.resources.chat.completions.Completions.create") as mock_chat_completion_create:
         completion = ChatCompletion(
             id="foo",
-            model="llama3.2:3b",
+            model="ollama/llama3.2:3b",
             object="chat.completion",
             choices=[
                 Choice(
@@ -66,27 +90,27 @@ def mock_chat_completion():

 class TestLlamaStackChatGenerator:
     def test_init_default(self):
-        component = LlamaStackChatGenerator(model="llama3.2:3b")
-        assert component.model == "llama3.2:3b"
+        component = LlamaStackChatGenerator(model="ollama/llama3.2:3b")
+        assert component.model == "ollama/llama3.2:3b"
         assert component.api_base_url == "http://localhost:8321/v1/openai/v1"
         assert component.streaming_callback is None
         assert not component.generation_kwargs

     def test_init_with_parameters(self):
         component = LlamaStackChatGenerator(
-            model="llama3.2:3b",
+            model="ollama/llama3.2:3b",
             streaming_callback=print_streaming_chunk,
             api_base_url="test-base-url",
             generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
         )
-        assert component.model == "llama3.2:3b"
+        assert component.model == "ollama/llama3.2:3b"
         assert component.streaming_callback is print_streaming_chunk
         assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}

     def test_to_dict_default(
         self,
     ):
-        component = LlamaStackChatGenerator(model="llama3.2:3b")
+        component = LlamaStackChatGenerator(model="ollama/llama3.2:3b")
         data = component.to_dict()

         assert (
@@ -95,7 +119,7 @@ def test_to_dict_default(
         )

         expected_params = {
-            "model": "llama3.2:3b",
+            "model": "ollama/llama3.2:3b",
             "streaming_callback": None,
             "api_base_url": "http://localhost:8321/v1/openai/v1",
             "generation_kwargs": {},
@@ -113,7 +137,7 @@ def test_to_dict_with_parameters(
         self,
     ):
         component = LlamaStackChatGenerator(
-            model="llama3.2:3b",
+            model="ollama/llama3.2:3b",
             streaming_callback=print_streaming_chunk,
             api_base_url="test-base-url",
             generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
@@ -131,7 +155,7 @@ def test_to_dict_with_parameters(
         )

         expected_params = {
-            "model": "llama3.2:3b",
+            "model": "ollama/llama3.2:3b",
             "api_base_url": "test-base-url",
             "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
             "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
@@ -153,7 +177,7 @@ def test_from_dict(
153177 "haystack_integrations.components.generators.llama_stack.chat.chat_generator.LlamaStackChatGenerator"
154178 ),
155179 "init_parameters" : {
156- "model" : "llama3.2:3b" ,
180+ "model" : "ollama/ llama3.2:3b" ,
157181 "api_base_url" : "test-base-url" ,
158182 "streaming_callback" : "haystack.components.generators.utils.print_streaming_chunk" ,
159183 "generation_kwargs" : {"max_tokens" : 10 , "some_test_param" : "test-params" },
@@ -165,7 +189,7 @@ def test_from_dict(
             },
         }
         component = LlamaStackChatGenerator.from_dict(data)
-        assert component.model == "llama3.2:3b"
+        assert component.model == "ollama/llama3.2:3b"
         assert component.streaming_callback is print_streaming_chunk
         assert component.api_base_url == "test-base-url"
         assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
@@ -175,8 +199,33 @@ def test_from_dict(
         assert component.max_retries == 10
         assert not component.tools_strict

+    def test_init_with_mixed_tools(self):
+        def tool_fn(city: str) -> str:
+            return city
+
+        weather_tool = Tool(
+            name="weather",
+            description="Weather lookup",
+            parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
+            function=tool_fn,
+        )
+        population_tool = Tool(
+            name="population",
+            description="Population lookup",
+            parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
+            function=tool_fn,
+        )
+        toolset = Toolset([population_tool])
+
+        generator = LlamaStackChatGenerator(
+            model="ollama/llama3.2:3b",
+            tools=[weather_tool, toolset],
+        )
+
+        assert generator.tools == [weather_tool, toolset]
+
     def test_run(self, chat_messages, mock_chat_completion):  # noqa: ARG002
-        component = LlamaStackChatGenerator(model="llama3.2:3b")
+        component = LlamaStackChatGenerator(model="ollama/llama3.2:3b")
         response = component.run(chat_messages)

         # check that the component returns the correct ChatMessage response
@@ -188,7 +237,7 @@ def test_run(self, chat_messages, mock_chat_completion): # noqa: ARG002

     def test_run_with_params(self, chat_messages, mock_chat_completion):
         component = LlamaStackChatGenerator(
-            model="llama3.2:3b",
+            model="ollama/llama3.2:3b",
             generation_kwargs={"max_tokens": 10, "temperature": 0.5},
         )
         response = component.run(chat_messages)
@@ -208,7 +257,7 @@ def test_run_with_params(self, chat_messages, mock_chat_completion):
     @pytest.mark.integration
     def test_live_run(self):
         chat_messages = [ChatMessage.from_user("What's the capital of France")]
-        component = LlamaStackChatGenerator(model="llama3.2:3b")
+        component = LlamaStackChatGenerator(model="ollama/llama3.2:3b")
         results = component.run(chat_messages)
         assert len(results["replies"]) == 1
         message: ChatMessage = results["replies"][0]
@@ -228,7 +277,7 @@ def __call__(self, chunk: StreamingChunk) -> None:
                 self.responses += chunk.content if chunk.content else ""

         callback = Callback()
-        component = LlamaStackChatGenerator(model="llama3.2:3b", streaming_callback=callback)
+        component = LlamaStackChatGenerator(model="ollama/llama3.2:3b", streaming_callback=callback)
         results = component.run([ChatMessage.from_user("What's the capital of France?")])

         assert len(results["replies"]) == 1
@@ -244,7 +293,7 @@ def __call__(self, chunk: StreamingChunk) -> None:
     @pytest.mark.integration
     def test_live_run_with_tools(self, tools):
         chat_messages = [ChatMessage.from_user("What's the weather like in Paris?")]
-        component = LlamaStackChatGenerator(model="llama3.2:3b", tools=tools)
+        component = LlamaStackChatGenerator(model="ollama/llama3.2:3b", tools=tools)
         results = component.run(chat_messages)
         assert len(results["replies"]) == 1
         message = results["replies"][0]
@@ -263,7 +312,7 @@ def test_live_run_with_tools_and_response(self, tools):
         Integration test that the LlamaStackChatGenerator component can run with tools and get a response.
         """
         initial_messages = [ChatMessage.from_user("What's the weather like in Paris?")]
-        component = LlamaStackChatGenerator(model="llama3.2:3b", tools=tools)
+        component = LlamaStackChatGenerator(model="ollama/llama3.2:3b", tools=tools)
         results = component.run(messages=initial_messages, generation_kwargs={"tool_choice": "auto"})

         assert len(results["replies"]) > 0, "No replies received"
@@ -298,3 +347,63 @@ def test_live_run_with_tools_and_response(self, tools):
         assert not final_message.tool_call
         assert len(final_message.text) > 0
         assert "paris" in final_message.text.lower()
+
+    @pytest.mark.integration
+    def test_live_run_with_mixed_tools(self, mixed_tools):
+        """
+        Integration test that verifies LlamaStackChatGenerator works with mixed Tool and Toolset.
+        This tests that the LLM can correctly invoke tools from both a standalone Tool and a Toolset.
+        """
+        initial_messages = [
+            ChatMessage.from_user("What's the weather like in Paris and what is the population of Berlin?")
+        ]
+        component = LlamaStackChatGenerator(model="ollama/llama3.2:3b", tools=mixed_tools)
+        results = component.run(messages=initial_messages)
+
+        assert len(results["replies"]) > 0, "No replies received"
+
+        # Find the message with tool calls
+        tool_call_message = None
+        for message in results["replies"]:
+            if message.tool_calls:
+                tool_call_message = message
+                break
+
+        assert tool_call_message is not None, "No message with tool call found"
+        assert isinstance(tool_call_message, ChatMessage), "Tool message is not a ChatMessage instance"
+        assert ChatMessage.is_from(tool_call_message, ChatRole.ASSISTANT), "Tool message is not from the assistant"
+
+        tool_calls = tool_call_message.tool_calls
+        assert len(tool_calls) == 2, f"Expected 2 tool calls, got {len(tool_calls)}"
+
+        # Verify we got calls to both weather and population tools
+        tool_names = {tc.tool_name for tc in tool_calls}
+        assert "weather" in tool_names, "Expected 'weather' tool call"
+        assert "population" in tool_names, "Expected 'population' tool call"
+
+        # Verify tool call details
+        for tool_call in tool_calls:
+            assert tool_call.id, "Tool call does not contain value for 'id' key"
+            assert tool_call.tool_name in ["weather", "population"]
+            assert "city" in tool_call.arguments
+            assert tool_call.arguments["city"] in ["Paris", "Berlin"]
+        assert tool_call_message.meta["finish_reason"] == "tool_calls"
+
+        # Mock the response we'd get from ToolInvoker
+        tool_result_messages = []
+        for tool_call in tool_calls:
+            if tool_call.tool_name == "weather":
+                result = "The weather in Paris is sunny and 32°C"
+            else:  # population
+                result = "The population of Berlin is 2.2 million"
+            tool_result_messages.append(ChatMessage.from_tool(tool_result=result, origin=tool_call))
+
+        new_messages = [*initial_messages, tool_call_message, *tool_result_messages]
+        results = component.run(new_messages)
+
+        assert len(results["replies"]) == 1
+        final_message = results["replies"][0]
+        assert not final_message.tool_call
+        assert len(final_message.text) > 0
+        assert "paris" in final_message.text.lower()
+        assert "berlin" in final_message.text.lower()
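
Outside the diff, here is a minimal usage sketch of the mixed Tool/Toolset support these tests exercise. It assumes a Llama Stack server serving ollama/llama3.2:3b at the default base URL, and that LlamaStackChatGenerator is importable from the integration's public package path (the serialized path in the tests is haystack_integrations.components.generators.llama_stack.chat.chat_generator); the returned tool calls are executed by hand here, where a pipeline would normally hand them to Haystack's ToolInvoker.

from haystack.dataclasses import ChatMessage
from haystack.tools import Tool, Toolset

# Assumed public import path for the integration package.
from haystack_integrations.components.generators.llama_stack import LlamaStackChatGenerator


def weather(city: str):
    return f"The weather in {city} is sunny and 32°C"


def population(city: str):
    return f"The population of {city} is 2.2 million"


city_params = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
weather_tool = Tool(name="weather", description="Weather lookup", parameters=city_params, function=weather)
population_tool = Tool(name="population", description="Population lookup", parameters=city_params, function=population)

# A standalone Tool and a Toolset can be mixed in the same tools list.
generator = LlamaStackChatGenerator(
    model="ollama/llama3.2:3b",
    tools=[weather_tool, Toolset([population_tool])],
)

replies = generator.run([ChatMessage.from_user("Weather in Paris? Population of Berlin?")])["replies"]

# Execute any returned tool calls directly; in a pipeline, ToolInvoker does this step.
functions = {"weather": weather, "population": population}
for tool_call in replies[0].tool_calls or []:
    print(functions[tool_call.tool_name](**tool_call.arguments))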