
Commit a2a2ceb

feat: LlamaStackChatGenerator update tools param to ToolsType (#2436)
* Update tools param to ToolsType
* Fix llama stack build, use run instead
* Test naming scheme
* llama3.2:3b -> ollama/llama3.2:3b
* Small fix
* More small fixes
* Update pydocs and examples
* Lint
1 parent a3574a6 commit a2a2ceb

5 files changed: +160 lines added, -34 lines removed


.github/workflows/llama_stack.yml

Lines changed: 24 additions & 7 deletions
@@ -82,13 +82,30 @@ jobs:
           OLLAMA_URL: http://localhost:11434
         shell: bash
         run: |
-          pip install uv
-          uv run --with llama-stack llama stack build --distro starter --image-type venv --run < /dev/null > server.log 2>&1 &
-          sleep 120
-          # Verify it's running
-          curl -f http://localhost:8321/v1/models || { cat server.log; exit 1; }
-
-          echo "Llama Stack Server started successfully."
+          set -euo pipefail
+          pip install -q uv
+
+          # Install the starter distro's deps into the uv environment
+          uv run --with llama-stack bash -lc 'llama stack list-deps starter | xargs -L1 uv pip install'
+
+          # Start Llama Stack (no more --image-type flag)
+          uv run --with llama-stack llama stack run starter > server.log 2>&1 &
+          SERVER_PID=$!
+
+          # Wait up to ~120s for health; fail fast if process dies
+          for i in {1..60}; do
+            if curl -fsS http://localhost:8321/v1/models >/dev/null; then
+              echo "Llama Stack Server started successfully."
+              break
+            fi
+            if ! kill -0 "$SERVER_PID" 2>/dev/null; then
+              echo "Server exited early. Logs:"; cat server.log; exit 1
+            fi
+            sleep 2
+          done
+
+          # Final health check
+          curl -fsS http://localhost:8321/v1/models || { echo "Health check failed. Logs:"; cat server.log; exit 1; }
 
     - name: Install Hatch
       run: pip install --upgrade hatch
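Outside CI, the same readiness check can be reproduced from Python when debugging a local Llama Stack server. This is a minimal sketch under the assumption that the server was started separately on the default port 8321; it only mirrors the polling logic of the workflow step above and is not part of this commit.

import time
import urllib.error
import urllib.request


def wait_for_llama_stack(url: str = "http://localhost:8321/v1/models", timeout_s: int = 120) -> None:
    """Poll the models endpoint until the server answers, or raise after timeout_s seconds."""
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        try:
            with urllib.request.urlopen(url, timeout=5) as response:
                if response.status == 200:
                    print("Llama Stack Server started successfully.")
                    return
        except (urllib.error.URLError, OSError):
            pass  # server not accepting connections yet
        time.sleep(2)
    raise TimeoutError(f"Llama Stack server did not become healthy at {url}")


if __name__ == "__main__":
    wait_for_llama_stack()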

integrations/llama_stack/examples/llama-stack-with-tools.py

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@ def weather(city: str):
 tool_invoker = ToolInvoker(tools=[weather_tool])
 
 client = LlamaStackChatGenerator(
-    model="llama3.2:3b",  # model name varies depending on the inference provider used for the Llama Stack Server.
+    model="ollama/llama3.2:3b",  # model depends on the inference provider used for the Llama Stack Server.
     api_base_url="http://localhost:8321/v1/openai/v1",
 )
 messages = [ChatMessage.from_user("What's the weather in Tokyo?")]

integrations/llama_stack/pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ classifiers = [
     "Programming Language :: Python :: Implementation :: CPython",
     "Programming Language :: Python :: Implementation :: PyPy",
 ]
-dependencies = ["haystack-ai>=2.14", "llama-stack>=0.2.17"]
+dependencies = ["haystack-ai>=2.19.0", "llama-stack>=0.2.17"]
 
 [project.urls]
 Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/llama-stack#readme"

integrations/llama_stack/src/haystack_integrations/components/generators/llama_stack/chat/chat_generator.py

Lines changed: 7 additions & 7 deletions
@@ -2,12 +2,12 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, Optional
 
 from haystack import component, default_from_dict, default_to_dict, logging
 from haystack.components.generators.chat import OpenAIChatGenerator
 from haystack.dataclasses import StreamingCallbackT
-from haystack.tools import Tool, Toolset, deserialize_tools_or_toolset_inplace, serialize_tools_or_toolset
+from haystack.tools import ToolsType, deserialize_tools_or_toolset_inplace, serialize_tools_or_toolset
 from haystack.utils import deserialize_callable, serialize_callable
 from haystack.utils.auth import Secret
 
@@ -41,15 +41,15 @@ class LlamaStackChatGenerator(OpenAIChatGenerator):
 
     messages = [ChatMessage.from_user("What's Natural Language Processing?")]
 
-    client = LlamaStackChatGenerator(model="llama3.2:3b")
+    client = LlamaStackChatGenerator(model="ollama/llama3.2:3b")
     response = client.run(messages)
     print(response)
 
     >>{'replies': [ChatMessage(_content=[TextContent(text='Natural Language Processing (NLP)
     is a branch of artificial intelligence
     >>that focuses on enabling computers to understand, interpret, and generate human language in a way that is
     >>meaningful and useful.')], _role=<ChatRole.ASSISTANT: 'assistant'>, _name=None,
-    >>_meta={'model': 'llama3.2:3b', 'index': 0, 'finish_reason': 'stop',
+    >>_meta={'model': 'ollama/llama3.2:3b', 'index': 0, 'finish_reason': 'stop',
     >>'usage': {'prompt_tokens': 15, 'completion_tokens': 36, 'total_tokens': 51}})]}
     """
 
@@ -62,7 +62,7 @@ def __init__(
         streaming_callback: Optional[StreamingCallbackT] = None,
         generation_kwargs: Optional[Dict[str, Any]] = None,
         timeout: Optional[int] = None,
-        tools: Optional[Union[List[Tool], Toolset]] = None,
+        tools: Optional[ToolsType] = None,
         tools_strict: bool = False,
         max_retries: Optional[int] = None,
         http_client_kwargs: Optional[Dict[str, Any]] = None,
@@ -98,8 +98,8 @@ def __init__(
            Timeout for client calls using OpenAI API. If not set, it defaults to either the
            `OPENAI_TIMEOUT` environment variable, or 30 seconds.
        :param tools:
-           A list of tools or a Toolset for which the model can prepare calls. This parameter can accept either a
-           list of `Tool` objects or a `Toolset` instance.
+           A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
+           Each tool should have a unique name.
        :param tools_strict:
            Whether to enable strict schema adherence for tool calls. If set to `True`, the model will follow exactly
            the schema provided in the `parameters` field of the tool definition, but this may increase latency.
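With the tools parameter widened from Optional[Union[List[Tool], Toolset]] to Optional[ToolsType], the generator now accepts a list of Tool objects, a single Toolset, or a mix of both, as the updated docstring states. A minimal sketch, assuming a running Llama Stack server with the ollama provider used elsewhere in this commit; the import path follows the package layout referenced in the tests.

from haystack.tools import Tool, Toolset
from haystack_integrations.components.generators.llama_stack import LlamaStackChatGenerator


def weather(city: str) -> str:
    return f"The weather in {city} is sunny and 32°C"


def population(city: str) -> str:
    return f"The population of {city} is 2.2 million"


city_schema = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}

weather_tool = Tool(name="weather", description="Weather lookup", parameters=city_schema, function=weather)
population_tool = Tool(name="population", description="Population lookup", parameters=city_schema, function=population)

# A standalone Tool and a Toolset can now be combined in one list.
generator = LlamaStackChatGenerator(
    model="ollama/llama3.2:3b",
    tools=[weather_tool, Toolset([population_tool])],
)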

integrations/llama_stack/tests/test_llama_stack_chat_generator.py

Lines changed: 127 additions & 18 deletions
@@ -5,7 +5,7 @@
 import pytz
 from haystack.components.generators.utils import print_streaming_chunk
 from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk, ToolCall
-from haystack.tools import Tool
+from haystack.tools import Tool, Toolset
 from openai.types.chat import ChatCompletion, ChatCompletionMessage
 from openai.types.chat.chat_completion import Choice
 
@@ -25,6 +25,11 @@ def weather(city: str):
     return f"The weather in {city} is sunny and 32°C"
 
 
+def population(city: str):
+    """Get population for a given city."""
+    return f"The population of {city} is 2.2 million"
+
+
 @pytest.fixture
 def tools():
     tool_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
@@ -38,6 +43,25 @@ def tools():
     return [tool]
 
 
+@pytest.fixture
+def mixed_tools():
+    """Fixture that returns a mixed list of Tool and Toolset."""
+    weather_tool = Tool(
+        name="weather",
+        description="useful to determine the weather in a given location",
+        parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
+        function=weather,
+    )
+    population_tool = Tool(
+        name="population",
+        description="useful to determine the population of a given location",
+        parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
+        function=population,
+    )
+    toolset = Toolset([population_tool])
+    return [weather_tool, toolset]
+
+
 @pytest.fixture
 def mock_chat_completion():
     """
@@ -46,7 +70,7 @@ def mock_chat_completion():
     with patch("openai.resources.chat.completions.Completions.create") as mock_chat_completion_create:
         completion = ChatCompletion(
             id="foo",
-            model="llama3.2:3b",
+            model="ollama/llama3.2:3b",
             object="chat.completion",
             choices=[
                 Choice(
@@ -66,27 +90,27 @@
 
 class TestLlamaStackChatGenerator:
     def test_init_default(self):
-        component = LlamaStackChatGenerator(model="llama3.2:3b")
-        assert component.model == "llama3.2:3b"
+        component = LlamaStackChatGenerator(model="ollama/llama3.2:3b")
+        assert component.model == "ollama/llama3.2:3b"
         assert component.api_base_url == "http://localhost:8321/v1/openai/v1"
         assert component.streaming_callback is None
         assert not component.generation_kwargs
 
     def test_init_with_parameters(self):
         component = LlamaStackChatGenerator(
-            model="llama3.2:3b",
+            model="ollama/llama3.2:3b",
             streaming_callback=print_streaming_chunk,
             api_base_url="test-base-url",
             generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
         )
-        assert component.model == "llama3.2:3b"
+        assert component.model == "ollama/llama3.2:3b"
         assert component.streaming_callback is print_streaming_chunk
         assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
 
     def test_to_dict_default(
         self,
     ):
-        component = LlamaStackChatGenerator(model="llama3.2:3b")
+        component = LlamaStackChatGenerator(model="ollama/llama3.2:3b")
         data = component.to_dict()
 
         assert (
@@ -95,7 +119,7 @@ def test_to_dict_default(
         )
 
         expected_params = {
-            "model": "llama3.2:3b",
+            "model": "ollama/llama3.2:3b",
             "streaming_callback": None,
             "api_base_url": "http://localhost:8321/v1/openai/v1",
             "generation_kwargs": {},
@@ -113,7 +137,7 @@ def test_to_dict_with_parameters(
         self,
     ):
         component = LlamaStackChatGenerator(
-            model="llama3.2:3b",
+            model="ollama/llama3.2:3b",
             streaming_callback=print_streaming_chunk,
             api_base_url="test-base-url",
             generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
@@ -131,7 +155,7 @@ def test_to_dict_with_parameters(
         )
 
         expected_params = {
-            "model": "llama3.2:3b",
+            "model": "ollama/llama3.2:3b",
             "api_base_url": "test-base-url",
             "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
             "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
@@ -153,7 +177,7 @@ def test_from_dict(
                 "haystack_integrations.components.generators.llama_stack.chat.chat_generator.LlamaStackChatGenerator"
             ),
             "init_parameters": {
-                "model": "llama3.2:3b",
+                "model": "ollama/llama3.2:3b",
                 "api_base_url": "test-base-url",
                 "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
                 "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
@@ -165,7 +189,7 @@ def test_from_dict(
             },
         }
         component = LlamaStackChatGenerator.from_dict(data)
-        assert component.model == "llama3.2:3b"
+        assert component.model == "ollama/llama3.2:3b"
         assert component.streaming_callback is print_streaming_chunk
         assert component.api_base_url == "test-base-url"
         assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
@@ -175,8 +199,33 @@ def test_from_dict(
         assert component.max_retries == 10
         assert not component.tools_strict
 
+    def test_init_with_mixed_tools(self):
+        def tool_fn(city: str) -> str:
+            return city
+
+        weather_tool = Tool(
+            name="weather",
+            description="Weather lookup",
+            parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
+            function=tool_fn,
+        )
+        population_tool = Tool(
+            name="population",
+            description="Population lookup",
+            parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
+            function=tool_fn,
+        )
+        toolset = Toolset([population_tool])
+
+        generator = LlamaStackChatGenerator(
+            model="ollama/llama3.2:3b",
+            tools=[weather_tool, toolset],
+        )
+
+        assert generator.tools == [weather_tool, toolset]
+
     def test_run(self, chat_messages, mock_chat_completion):  # noqa: ARG002
-        component = LlamaStackChatGenerator(model="llama3.2:3b")
+        component = LlamaStackChatGenerator(model="ollama/llama3.2:3b")
         response = component.run(chat_messages)
 
         # check that the component returns the correct ChatMessage response
@@ -188,7 +237,7 @@ def test_run(self, chat_messages, mock_chat_completion): # noqa: ARG002
 
     def test_run_with_params(self, chat_messages, mock_chat_completion):
         component = LlamaStackChatGenerator(
-            model="llama3.2:3b",
+            model="ollama/llama3.2:3b",
             generation_kwargs={"max_tokens": 10, "temperature": 0.5},
         )
         response = component.run(chat_messages)
@@ -208,7 +257,7 @@ def test_run_with_params(self, chat_messages, mock_chat_completion):
     @pytest.mark.integration
     def test_live_run(self):
         chat_messages = [ChatMessage.from_user("What's the capital of France")]
-        component = LlamaStackChatGenerator(model="llama3.2:3b")
+        component = LlamaStackChatGenerator(model="ollama/llama3.2:3b")
         results = component.run(chat_messages)
         assert len(results["replies"]) == 1
         message: ChatMessage = results["replies"][0]
@@ -228,7 +277,7 @@ def __call__(self, chunk: StreamingChunk) -> None:
             self.responses += chunk.content if chunk.content else ""
 
         callback = Callback()
-        component = LlamaStackChatGenerator(model="llama3.2:3b", streaming_callback=callback)
+        component = LlamaStackChatGenerator(model="ollama/llama3.2:3b", streaming_callback=callback)
         results = component.run([ChatMessage.from_user("What's the capital of France?")])
 
         assert len(results["replies"]) == 1
@@ -244,7 +293,7 @@
     @pytest.mark.integration
     def test_live_run_with_tools(self, tools):
         chat_messages = [ChatMessage.from_user("What's the weather like in Paris?")]
-        component = LlamaStackChatGenerator(model="llama3.2:3b", tools=tools)
+        component = LlamaStackChatGenerator(model="ollama/llama3.2:3b", tools=tools)
         results = component.run(chat_messages)
         assert len(results["replies"]) == 1
         message = results["replies"][0]
@@ -263,7 +312,7 @@ def test_live_run_with_tools_and_response(self, tools):
         Integration test that the LlamaStackChatGenerator component can run with tools and get a response.
         """
         initial_messages = [ChatMessage.from_user("What's the weather like in Paris?")]
-        component = LlamaStackChatGenerator(model="llama3.2:3b", tools=tools)
+        component = LlamaStackChatGenerator(model="ollama/llama3.2:3b", tools=tools)
         results = component.run(messages=initial_messages, generation_kwargs={"tool_choice": "auto"})
 
         assert len(results["replies"]) > 0, "No replies received"
@@ -298,3 +347,63 @@
         assert not final_message.tool_call
         assert len(final_message.text) > 0
         assert "paris" in final_message.text.lower()
+
+    @pytest.mark.integration
+    def test_live_run_with_mixed_tools(self, mixed_tools):
+        """
+        Integration test that verifies LlamaStackChatGenerator works with mixed Tool and Toolset.
+        This tests that the LLM can correctly invoke tools from both a standalone Tool and a Toolset.
+        """
+        initial_messages = [
+            ChatMessage.from_user("What's the weather like in Paris and what is the population of Berlin?")
+        ]
+        component = LlamaStackChatGenerator(model="ollama/llama3.2:3b", tools=mixed_tools)
+        results = component.run(messages=initial_messages)
+
+        assert len(results["replies"]) > 0, "No replies received"
+
+        # Find the message with tool calls
+        tool_call_message = None
+        for message in results["replies"]:
+            if message.tool_calls:
+                tool_call_message = message
+                break
+
+        assert tool_call_message is not None, "No message with tool call found"
+        assert isinstance(tool_call_message, ChatMessage), "Tool message is not a ChatMessage instance"
+        assert ChatMessage.is_from(tool_call_message, ChatRole.ASSISTANT), "Tool message is not from the assistant"
+
+        tool_calls = tool_call_message.tool_calls
+        assert len(tool_calls) == 2, f"Expected 2 tool calls, got {len(tool_calls)}"
+
+        # Verify we got calls to both weather and population tools
+        tool_names = {tc.tool_name for tc in tool_calls}
+        assert "weather" in tool_names, "Expected 'weather' tool call"
+        assert "population" in tool_names, "Expected 'population' tool call"
+
+        # Verify tool call details
+        for tool_call in tool_calls:
+            assert tool_call.id, "Tool call does not contain value for 'id' key"
+            assert tool_call.tool_name in ["weather", "population"]
+            assert "city" in tool_call.arguments
+            assert tool_call.arguments["city"] in ["Paris", "Berlin"]
+        assert tool_call_message.meta["finish_reason"] == "tool_calls"
+
+        # Mock the response we'd get from ToolInvoker
+        tool_result_messages = []
+        for tool_call in tool_calls:
+            if tool_call.tool_name == "weather":
+                result = "The weather in Paris is sunny and 32°C"
+            else:  # population
+                result = "The population of Berlin is 2.2 million"
+            tool_result_messages.append(ChatMessage.from_tool(tool_result=result, origin=tool_call))
+
+        new_messages = [*initial_messages, tool_call_message, *tool_result_messages]
+        results = component.run(new_messages)
+
+        assert len(results["replies"]) == 1
+        final_message = results["replies"][0]
+        assert not final_message.tool_call
+        assert len(final_message.text) > 0
+        assert "paris" in final_message.text.lower()
+        assert "berlin" in final_message.text.lower()
