
Commit 4b46384

Amnah199 and anakin87 authored
feat: support structured outputs in LlamaStackChatGenerator (#2535)
* Support structured outputs
* Fix tests
* Fix linting
* Update tests
* Update integrations/llama_stack/tests/test_llama_stack_chat_generator.py

Co-authored-by: Stefano Fiorucci <[email protected]>
1 parent 475f3d0 commit 4b46384

File tree: 3 files changed, +151 −3 lines changed

Lines changed: 35 additions & 0 deletions (new file)

@@ -0,0 +1,35 @@
+# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]>
+#
+# SPDX-License-Identifier: Apache-2.0
+
+
+# This example demonstrates how to use the LlamaStackChatGenerator component
+# with structured outputs.
+# To run this example, you will need to
+# set up Llama Stack Server and have a model available
+
+from haystack.dataclasses import ChatMessage
+from pydantic import BaseModel
+
+from haystack_integrations.components.generators.llama_stack import LlamaStackChatGenerator
+
+
+class NobelPrizeInfo(BaseModel):
+    recipient_name: str
+    award_year: int
+    category: str
+    achievement_description: str
+    nationality: str
+
+
+chat_messages = [
+    ChatMessage.from_user(
+        "In 2021, American scientist David Julius received the Nobel Prize in"
+        " Physiology or Medicine for his groundbreaking discoveries on how the human body"
+        " senses temperature and touch."
+    )
+]
+component = LlamaStackChatGenerator(generation_kwargs={"response_format": NobelPrizeInfo})
+results = component.run(chat_messages)
+
+# print(results)
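Because `response_format` constrains the reply to the schema (unless the model returns a tool call), the reply text can be loaded straight back into the Pydantic model. A minimal follow-up sketch, assuming the example above has run against a live server and returned valid JSON:

# Sketch: validate the structured reply back into the Pydantic model.
# Assumes `results` comes from the run above and no tool call was returned.
reply_text = results["replies"][0].text
prize_info = NobelPrizeInfo.model_validate_json(reply_text)  # Pydantic v2 API
print(prize_info.recipient_name, prize_info.award_year)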

integrations/llama_stack/src/haystack_integrations/components/generators/llama_stack/chat/chat_generator.py

Lines changed: 26 additions & 1 deletion
@@ -10,6 +10,8 @@
 from haystack.tools import ToolsType, deserialize_tools_or_toolset_inplace, serialize_tools_or_toolset
 from haystack.utils import deserialize_callable, serialize_callable
 from haystack.utils.auth import Secret
+from openai.lib._pydantic import to_strict_json_schema
+from pydantic import BaseModel
 
 logger = logging.getLogger(__name__)
 
@@ -94,6 +96,13 @@ def __init__(
             events as they become available, with the stream terminated by a data: [DONE] message.
         - `safe_prompt`: Whether to inject a safety prompt before all conversations.
         - `random_seed`: The seed to use for random sampling.
+        - `response_format`: A JSON schema or a Pydantic model that enforces the structure of the model's response.
+          If provided, the output will always be validated against this
+          format (unless the model returns a tool call).
+          For details, see the [OpenAI Structured Outputs documentation](https://platform.openai.com/docs/guides/structured-outputs).
+          Notes:
+          - For structured outputs with streaming,
+            the `response_format` must be a JSON schema and not a Pydantic model.
     :param timeout:
         Timeout for client calls using OpenAI API. If not set, it defaults to either the
         `OPENAI_TIMEOUT` environment variable, or 30 seconds.
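The streaming note above is the one behavioral caveat: with a streaming callback, `response_format` has to be the JSON-schema form. A minimal sketch of that combination, assuming a running Llama Stack server; the model name mirrors the one used in the tests below, and the schema itself is illustrative:

from haystack.components.generators.utils import print_streaming_chunk
from haystack.dataclasses import ChatMessage

from haystack_integrations.components.generators.llama_stack import LlamaStackChatGenerator

# With streaming enabled, `response_format` must be a JSON schema, not a Pydantic model.
city_schema = {
    "type": "json_schema",
    "json_schema": {
        "name": "CapitalCity",
        "strict": True,
        "schema": {
            "type": "object",
            "properties": {"city": {"type": "string"}, "country": {"type": "string"}},
            "required": ["city", "country"],
            "additionalProperties": False,
        },
    },
}

generator = LlamaStackChatGenerator(
    model="ollama/llama3.2:3b",
    streaming_callback=print_streaming_chunk,
    generation_kwargs={"response_format": city_schema},
)
results = generator.run([ChatMessage.from_user("What's the capital of France?")])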
@@ -137,13 +146,29 @@ def to_dict(self) -> dict[str, Any]:
             The serialized component as a dictionary.
         """
         callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
+        generation_kwargs = self.generation_kwargs.copy()
+        response_format = generation_kwargs.get("response_format")
+        # If the response format is a Pydantic model, it's converted to openai's json schema format
+        # If it's already a json schema, it's left as is
+        if response_format and isinstance(response_format, type) and issubclass(response_format, BaseModel):
+            json_schema = {
+                "type": "json_schema",
+                "json_schema": {
+                    "name": response_format.__name__,
+                    "strict": True,
+                    "schema": to_strict_json_schema(response_format),
+                },
+            }
+
+            generation_kwargs["response_format"] = json_schema
+
         return default_to_dict(
             self,
             model=self.model,
             streaming_callback=callback_name,
             api_base_url=self.api_base_url,
             organization=self.organization,
-            generation_kwargs=self.generation_kwargs,
+            generation_kwargs=generation_kwargs,
             timeout=self.timeout,
             max_retries=self.max_retries,
             tools=serialize_tools_or_toolset(self.tools),
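The effect of this conversion is observable without a server: configure the component with a Pydantic model, serialize it, and the `response_format` entry comes back as the JSON-schema payload. A small sketch, assuming Haystack's standard `default_to_dict` layout with an `init_parameters` key; the `CapitalCity` model is illustrative:

from pydantic import BaseModel

from haystack_integrations.components.generators.llama_stack import LlamaStackChatGenerator

class CapitalCity(BaseModel):
    city: str
    country: str

component = LlamaStackChatGenerator(
    model="ollama/llama3.2:3b",
    generation_kwargs={"response_format": CapitalCity},
)
config = component.to_dict()
response_format = config["init_parameters"]["generation_kwargs"]["response_format"]
# The Pydantic model has been replaced by a strict JSON-schema wrapper:
assert response_format["type"] == "json_schema"
assert response_format["json_schema"]["name"] == "CapitalCity"
assert response_format["json_schema"]["strict"] is True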

integrations/llama_stack/tests/test_llama_stack_chat_generator.py

Lines changed: 90 additions & 2 deletions
@@ -1,3 +1,4 @@
+import json
 from datetime import datetime
 from unittest.mock import patch
 

@@ -8,10 +9,22 @@
 from haystack.tools import Tool, Toolset
 from openai.types.chat import ChatCompletion, ChatCompletionMessage
 from openai.types.chat.chat_completion import Choice
+from pydantic import BaseModel
 
 from haystack_integrations.components.generators.llama_stack.chat.chat_generator import LlamaStackChatGenerator
 
 
+class CalendarEvent(BaseModel):
+    event_name: str
+    event_date: str
+    event_location: str
+
+
+@pytest.fixture
+def calendar_event_model():
+    return CalendarEvent
+
+
 @pytest.fixture
 def chat_messages():
     return [
@@ -135,12 +148,17 @@ def test_to_dict_default(
 
     def test_to_dict_with_parameters(
         self,
+        calendar_event_model,
     ):
         component = LlamaStackChatGenerator(
             model="ollama/llama3.2:3b",
             streaming_callback=print_streaming_chunk,
             api_base_url="test-base-url",
-            generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
+            generation_kwargs={
+                "max_tokens": 10,
+                "some_test_param": "test-params",
+                "response_format": calendar_event_model,
+            },
             timeout=10,
             max_retries=10,
             tools=None,
@@ -158,7 +176,28 @@ def test_to_dict_with_parameters(
             "model": "ollama/llama3.2:3b",
             "api_base_url": "test-base-url",
             "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
-            "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
+            "generation_kwargs": {
+                "max_tokens": 10,
+                "some_test_param": "test-params",
+                "response_format": {
+                    "type": "json_schema",
+                    "json_schema": {
+                        "name": "CalendarEvent",
+                        "strict": True,
+                        "schema": {
+                            "properties": {
+                                "event_name": {"title": "Event Name", "type": "string"},
+                                "event_date": {"title": "Event Date", "type": "string"},
+                                "event_location": {"title": "Event Location", "type": "string"},
+                            },
+                            "required": ["event_name", "event_date", "event_location"],
+                            "title": "CalendarEvent",
+                            "type": "object",
+                            "additionalProperties": False,
+                        },
+                    },
+                },
+            },
             "timeout": 10,
             "max_retries": 10,
             "tools": None,
@@ -407,3 +446,52 @@ def test_live_run_with_mixed_tools(self, mixed_tools):
         assert len(final_message.text) > 0
         assert "paris" in final_message.text.lower()
         assert "berlin" in final_message.text.lower()
+
+    @pytest.mark.integration
+    def test_live_run_with_response_format_json_schema(self):
+        response_schema = {
+            "type": "json_schema",
+            "json_schema": {
+                "name": "CapitalCity",
+                "strict": True,
+                "schema": {
+                    "title": "CapitalCity",
+                    "type": "object",
+                    "properties": {
+                        "city": {"title": "City", "type": "string"},
+                        "country": {"title": "Country", "type": "string"},
+                    },
+                    "required": ["city", "country"],
+                    "additionalProperties": False,
+                },
+            },
+        }
+
+        chat_messages = [ChatMessage.from_user("What's the capital of France?")]
+        comp = LlamaStackChatGenerator(
+            model="ollama/llama3.2:3b", generation_kwargs={"response_format": response_schema}
+        )
+        results = comp.run(chat_messages)
+        assert len(results["replies"]) == 1
+        message: ChatMessage = results["replies"][0]
+        msg = json.loads(message.text)
+        assert "Paris" in msg["city"]
+        assert isinstance(msg["country"], str)
+        assert "France" in msg["country"]
+        assert message.meta["finish_reason"] == "stop"
+
+    @pytest.mark.integration
+    def test_live_run_with_response_format_pydantic_model(self, calendar_event_model):
+        chat_messages = [
+            ChatMessage.from_user("The marketing summit takes place on October 12th at the Hilton Hotel downtown.")
+        ]
+        component = LlamaStackChatGenerator(
+            model="ollama/llama3.2:3b", generation_kwargs={"response_format": calendar_event_model}
+        )
+        results = component.run(chat_messages)
+        assert len(results["replies"]) == 1
+        message: ChatMessage = results["replies"][0]
+        msg = json.loads(message.text)
+        assert "Marketing Summit" in msg["event_name"]
+        assert isinstance(msg["event_date"], str)
+        assert isinstance(msg["event_location"], str)
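The hand-written schema in the first live test above can also be derived from a Pydantic model using the same `to_strict_json_schema` helper the component imports, which is one way to get a streaming-compatible `response_format` without writing JSON by hand. A sketch, under the assumption that the helper's output matches the wrapper that `to_dict` builds:

from openai.lib._pydantic import to_strict_json_schema
from pydantic import BaseModel

class CapitalCity(BaseModel):
    city: str
    country: str

# Build the same wrapper shape that LlamaStackChatGenerator.to_dict() produces,
# so it can also be passed as `response_format` when streaming.
response_schema = {
    "type": "json_schema",
    "json_schema": {
        "name": CapitalCity.__name__,
        "strict": True,
        "schema": to_strict_json_schema(CapitalCity),
    },
}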
