Skip to content

Commit 9426ea2

Browse files
authored
fix: Update model in tests for IBM Watsonx and default models (#2515)
* Update model * update default model * change default model * Update model defaults * fix test
1 parent 08b2ea3 commit 9426ea2

File tree

8 files changed

+52
-87
lines changed

8 files changed

+52
-87
lines changed

integrations/watsonx/src/haystack_integrations/components/embedders/watsonx/document_embedder.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class WatsonxDocumentEmbedder:
2929
]
3030
3131
document_embedder = WatsonxDocumentEmbedder(
32-
model="ibm/slate-30m-english-rtrvr",
32+
model="ibm/slate-30m-english-rtrvr-v2",
3333
api_key=Secret.from_env_var("WATSONX_API_KEY"),
3434
api_base_url="https://us-south.ml.cloud.ibm.com",
3535
project_id=Secret.from_env_var("WATSONX_PROJECT_ID"),
@@ -45,7 +45,7 @@ class WatsonxDocumentEmbedder:
4545
def __init__(
4646
self,
4747
*,
48-
model: str = "ibm/slate-30m-english-rtrvr",
48+
model: str = "ibm/slate-30m-english-rtrvr-v2",
4949
api_key: Secret = Secret.from_env_var("WATSONX_API_KEY"), # noqa: B008
5050
api_base_url: str = "https://us-south.ml.cloud.ibm.com",
5151
project_id: Secret = Secret.from_env_var("WATSONX_PROJECT_ID"), # noqa: B008
@@ -64,7 +64,7 @@ def __init__(
6464
6565
:param model:
6666
The name of the model to use for calculating embeddings.
67-
Default is "ibm/slate-30m-english-rtrvr".
67+
Default is "ibm/slate-30m-english-rtrvr-v2".
6868
:param api_key:
6969
The WATSONX API key. Can be set via environment variable WATSONX_API_KEY.
7070
:param api_base_url:

integrations/watsonx/src/haystack_integrations/components/embedders/watsonx/text_embedder.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class WatsonxTextEmbedder:
2525
text_to_embed = "I love pizza!"
2626
2727
text_embedder = WatsonxTextEmbedder(
28-
model="ibm/slate-30m-english-rtrvr",
28+
model="ibm/slate-30m-english-rtrvr-v2",
2929
api_key=Secret.from_env_var("WATSONX_API_KEY"),
3030
api_base_url="https://us-south.ml.cloud.ibm.com",
3131
project_id=Secret.from_env_var("WATSONX_PROJECT_ID"),
@@ -34,15 +34,15 @@ class WatsonxTextEmbedder:
3434
print(text_embedder.run(text_to_embed))
3535
3636
# {'embedding': [0.017020374536514282, -0.023255806416273117, ...],
37-
# 'meta': {'model': 'ibm/slate-30m-english-rtrvr',
37+
# 'meta': {'model': 'ibm/slate-30m-english-rtrvr-v2',
3838
# 'truncated_input_tokens': 3}}
3939
```
4040
"""
4141

4242
def __init__(
4343
self,
4444
*,
45-
model: str = "ibm/slate-30m-english-rtrvr",
45+
model: str = "ibm/slate-30m-english-rtrvr-v2",
4646
api_key: Secret = Secret.from_env_var("WATSONX_API_KEY"), # noqa: B008
4747
api_base_url: str = "https://us-south.ml.cloud.ibm.com",
4848
project_id: Secret = Secret.from_env_var("WATSONX_PROJECT_ID"), # noqa: B008
@@ -57,7 +57,7 @@ def __init__(
5757
5858
:param model:
5959
The name of the IBM watsonx model to use for calculating embeddings.
60-
Default is "ibm/slate-30m-english-rtrvr".
60+
Default is "ibm/slate-30m-english-rtrvr-v2".
6161
:param api_key:
6262
The WATSONX API key. Can be set via environment variable WATSONX_API_KEY.
6363
:param api_base_url:

integrations/watsonx/src/haystack_integrations/components/generators/watsonx/chat/chat_generator.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,8 @@ class WatsonxChatGenerator:
3939
models. It supports the [ChatMessage](https://docs.haystack.deepset.ai/docs/chatmessage) format for both input
4040
and output, including multimodal inputs with text and images.
4141
42-
The generator works with IBM's foundation models including:
43-
- granite-13b-chat-v2
44-
- llama-2-70b-chat
45-
- llama-3-70b-instruct
46-
- llama-3-2-11b-vision-instruct (multimodal)
47-
- llama-3-2-90b-vision-instruct (multimodal)
48-
- pixtral-12b (multimodal)
49-
- Other watsonx.ai chat models
42+
The generator works with IBM's foundation models that are listed
43+
[here](https://dataplatform.cloud.ibm.com/docs/content/wsj/analyze-data/fm-models.html?context=wx&audience=wdp).
5044
5145
You can customize the generation behavior by passing parameters to the watsonx.ai API through the
5246
`generation_kwargs` argument. These parameters are passed directly to the watsonx.ai inference endpoint.
@@ -98,7 +92,7 @@ def __init__(
9892
self,
9993
*,
10094
api_key: Secret = Secret.from_env_var("WATSONX_API_KEY"), # noqa: B008
101-
model: str = "ibm/granite-3-2b-instruct",
95+
model: str = "ibm/granite-3-3-8b-instruct",
10296
project_id: Secret = Secret.from_env_var("WATSONX_PROJECT_ID"), # noqa: B008
10397
api_base_url: str = "https://us-south.ml.cloud.ibm.com",
10498
generation_kwargs: dict[str, Any] | None = None,

integrations/watsonx/src/haystack_integrations/components/generators/watsonx/generator.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,8 @@ class WatsonxGenerator(WatsonxChatGenerator):
2121
This component extends WatsonxChatGenerator to provide the standard Generator interface that works with prompt
2222
strings instead of ChatMessage objects.
2323
24-
The generator works with IBM's foundation models including:
25-
- granite-13b-chat-v2
26-
- llama-2-70b-chat
27-
- llama-3-70b-instruct
28-
- Other watsonx.ai chat models
24+
The generator works with IBM's foundation models that are listed
25+
[here](https://dataplatform.cloud.ibm.com/docs/content/wsj/analyze-data/fm-models.html?context=wx&audience=wdp).
2926
3027
You can customize the generation behavior by passing parameters to the watsonx.ai API through the
3128
`generation_kwargs` argument. These parameters are passed directly to the watsonx.ai inference endpoint.
@@ -74,7 +71,7 @@ def __init__(
7471
self,
7572
*,
7673
api_key: Secret = Secret.from_env_var("WATSONX_API_KEY"), # noqa: B008
77-
model: str = "ibm/granite-3-2b-instruct",
74+
model: str = "ibm/granite-3-3-8b-instruct",
7875
project_id: Secret = Secret.from_env_var("WATSONX_PROJECT_ID"), # noqa: B008
7976
api_base_url: str = "https://us-south.ml.cloud.ibm.com",
8077
system_prompt: str | None = None,

integrations/watsonx/tests/test_chat_generator.py

Lines changed: 13 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -100,31 +100,30 @@ async def __anext__(self):
100100

101101
def test_init_default(self, mock_watsonx):
102102
generator = WatsonxChatGenerator(
103-
model="ibm/granite-3-2b-instruct", project_id=Secret.from_token("fake-project-id")
103+
model="ibm/granite-3-3-8b-instruct", project_id=Secret.from_token("fake-project-id")
104104
)
105105

106106
_, kwargs = mock_watsonx["model"].call_args
107-
assert kwargs["model_id"] == "ibm/granite-3-2b-instruct"
107+
assert kwargs["model_id"] == "ibm/granite-3-3-8b-instruct"
108108
assert kwargs["project_id"] == "fake-project-id"
109109
assert kwargs["verify"] is None
110110

111-
assert generator.model == "ibm/granite-3-2b-instruct"
111+
assert generator.model == "ibm/granite-3-3-8b-instruct"
112112
assert isinstance(generator.project_id, Secret)
113113
assert generator.project_id.resolve_value() == "fake-project-id"
114114
assert generator.api_base_url == "https://us-south.ml.cloud.ibm.com"
115115

116116
def test_init_with_all_params(self, mock_watsonx):
117117
generator = WatsonxChatGenerator(
118118
api_key=Secret.from_token("test-api-key"),
119-
model="ibm/granite-3-2b-instruct",
120119
project_id=Secret.from_token("test-project"),
121120
api_base_url="https://custom-url.com",
122121
generation_kwargs={"max_tokens": 100, "temperature": 0.7, "top_p": 0.9},
123122
verify=False,
124123
)
125124

126125
_, kwargs = mock_watsonx["model"].call_args
127-
assert kwargs["model_id"] == "ibm/granite-3-2b-instruct"
126+
assert kwargs["model_id"] == "ibm/granite-3-3-8b-instruct"
128127
assert kwargs["project_id"] == "test-project"
129128
assert kwargs["verify"] is False
130129

@@ -135,11 +134,10 @@ def test_init_fails_without_project(self, mock_watsonx):
135134
os.environ.pop("WATSONX_PROJECT_ID", None)
136135

137136
with pytest.raises(ValueError, match="None of the following authentication environment variables are set"):
138-
WatsonxChatGenerator(api_key=Secret.from_token("test-api-key"), model="ibm/granite-3-2b-instruct")
137+
WatsonxChatGenerator(api_key=Secret.from_token("test-api-key"))
139138

140139
def test_to_dict(self, mock_watsonx):
141140
generator = WatsonxChatGenerator(
142-
model="ibm/granite-3-2b-instruct",
143141
project_id=Secret.from_env_var("WATSONX_PROJECT_ID"),
144142
generation_kwargs={"max_tokens": 100},
145143
)
@@ -150,7 +148,7 @@ def test_to_dict(self, mock_watsonx):
150148
"type": "haystack_integrations.components.generators.watsonx.chat.chat_generator.WatsonxChatGenerator",
151149
"init_parameters": {
152150
"api_key": {"env_vars": ["WATSONX_API_KEY"], "strict": True, "type": "env_var"},
153-
"model": "ibm/granite-3-2b-instruct",
151+
"model": "ibm/granite-3-3-8b-instruct",
154152
"project_id": {"env_vars": ["WATSONX_PROJECT_ID"], "strict": True, "type": "env_var"},
155153
"api_base_url": "https://us-south.ml.cloud.ibm.com",
156154
"generation_kwargs": {"max_tokens": 100},
@@ -164,7 +162,6 @@ def test_to_dict(self, mock_watsonx):
164162

165163
def test_to_dict_with_params(self, mock_watsonx):
166164
generator = WatsonxChatGenerator(
167-
model="ibm/granite-3-2b-instruct",
168165
project_id=Secret.from_env_var("WATSONX_PROJECT_ID"),
169166
generation_kwargs={"max_tokens": 100},
170167
streaming_callback=print_streaming_chunk,
@@ -176,7 +173,7 @@ def test_to_dict_with_params(self, mock_watsonx):
176173
"type": "haystack_integrations.components.generators.watsonx.chat.chat_generator.WatsonxChatGenerator",
177174
"init_parameters": {
178175
"api_key": {"env_vars": ["WATSONX_API_KEY"], "strict": True, "type": "env_var"},
179-
"model": "ibm/granite-3-2b-instruct",
176+
"model": "ibm/granite-3-3-8b-instruct",
180177
"project_id": {"env_vars": ["WATSONX_PROJECT_ID"], "strict": True, "type": "env_var"},
181178
"api_base_url": "https://us-south.ml.cloud.ibm.com",
182179
"generation_kwargs": {"max_tokens": 100},
@@ -194,14 +191,14 @@ def test_from_dict(self, mock_watsonx):
194191
"type": "haystack_integrations.components.generators.watsonx.chat.chat_generator.WatsonxChatGenerator",
195192
"init_parameters": {
196193
"api_key": {"env_vars": ["WATSONX_API_KEY"], "strict": True, "type": "env_var"},
197-
"model": "ibm/granite-3-2b-instruct",
194+
"model": "ibm/granite-3-3-8b-instruct",
198195
"project_id": {"env_vars": ["WATSONX_PROJECT_ID"], "strict": True, "type": "env_var"},
199196
"generation_kwargs": {"max_tokens": 100},
200197
},
201198
}
202199

203200
generator = WatsonxChatGenerator.from_dict(data)
204-
assert generator.model == "ibm/granite-3-2b-instruct"
201+
assert generator.model == "ibm/granite-3-3-8b-instruct"
205202
assert isinstance(generator.project_id, Secret)
206203
assert generator.project_id.resolve_value() == "fake-project-id"
207204
assert generator.generation_kwargs == {"max_tokens": 100}
@@ -212,7 +209,7 @@ def test_from_dict_with_callback(self, mock_watsonx):
212209
"type": "haystack_integrations.components.generators.watsonx.chat.chat_generator.WatsonxChatGenerator",
213210
"init_parameters": {
214211
"api_key": {"env_vars": ["WATSONX_API_KEY"], "strict": True, "type": "env_var"},
215-
"model": "ibm/granite-3-2b-instruct",
212+
"model": "ibm/granite-3-3-8b-instruct",
216213
"project_id": {"env_vars": ["WATSONX_PROJECT_ID"], "strict": True, "type": "env_var"},
217214
"streaming_callback": callback_str,
218215
},
@@ -224,7 +221,6 @@ def test_from_dict_with_callback(self, mock_watsonx):
224221
def test_run_single_message(self, mock_watsonx):
225222
generator = WatsonxChatGenerator(
226223
api_key=Secret.from_token("test-api-key"),
227-
model="ibm/granite-3-2b-instruct",
228224
project_id=Secret.from_token("test-project"),
229225
)
230226

@@ -242,7 +238,6 @@ def test_run_single_message(self, mock_watsonx):
242238
def test_run_with_generation_params(self, mock_watsonx):
243239
generator = WatsonxChatGenerator(
244240
api_key=Secret.from_token("test-api-key"),
245-
model="ibm/granite-3-2b-instruct",
246241
project_id=Secret.from_token("test-project"),
247242
generation_kwargs={"max_tokens": 100, "temperature": 0.7, "top_p": 0.9},
248243
)
@@ -287,7 +282,6 @@ def test_run_with_streaming(self, mock_watsonx):
287282
def test_run_with_empty_messages(self, mock_watsonx):
288283
generator = WatsonxChatGenerator(
289284
api_key=Secret.from_token("test-api-key"),
290-
model="ibm/granite-3-2b-instruct",
291285
project_id=Secret.from_token("test-project"),
292286
)
293287

@@ -296,7 +290,6 @@ def test_run_with_empty_messages(self, mock_watsonx):
296290

297291
def test_skips_tool_messages(self, mock_watsonx):
298292
generator = WatsonxChatGenerator(
299-
model="ibm/granite-3-2b-instruct",
300293
project_id=Secret.from_token("test-project"),
301294
)
302295

@@ -313,7 +306,6 @@ def custom_callback(chunk: StreamingChunk):
313306
pass
314307

315308
generator = WatsonxChatGenerator(
316-
model="ibm/granite-3-2b-instruct",
317309
project_id=Secret.from_token("test-project"),
318310
streaming_callback=custom_callback,
319311
)
@@ -327,7 +319,6 @@ def run_callback(chunk: StreamingChunk):
327319
pass
328320

329321
generator = WatsonxChatGenerator(
330-
model="ibm/granite-3-2b-instruct",
331322
project_id=Secret.from_token("test-project"),
332323
streaming_callback=init_callback,
333324
)
@@ -343,7 +334,6 @@ def run_callback(chunk: StreamingChunk):
343334
async def test_run_async_single_message(self, mock_watsonx):
344335
generator = WatsonxChatGenerator(
345336
api_key=Secret.from_token("test-api-key"),
346-
model="ibm/granite-3-2b-instruct",
347337
project_id=Secret.from_token("test-project"),
348338
)
349339

@@ -358,7 +348,6 @@ async def test_run_async_single_message(self, mock_watsonx):
358348
async def test_run_async_streaming(self, mock_watsonx):
359349
generator = WatsonxChatGenerator(
360350
api_key=Secret.from_token("test-api-key"),
361-
model="ibm/granite-3-2b-instruct",
362351
project_id=Secret.from_token("test-project"),
363352
)
364353
received_chunks = []
@@ -551,7 +540,7 @@ class TestWatsonxChatGeneratorIntegration:
551540
)
552541
def test_live_run(self):
553542
generator = WatsonxChatGenerator(
554-
model="ibm/granite-3-2b-instruct",
543+
model="ibm/granite-3-3-8b-instruct",
555544
project_id=Secret.from_env_var("WATSONX_PROJECT_ID"),
556545
generation_kwargs={"max_tokens": 50, "temperature": 0.7, "top_p": 0.9},
557546
)
@@ -572,7 +561,7 @@ def test_live_run(self):
572561
)
573562
def test_live_run_streaming(self):
574563
generator = WatsonxChatGenerator(
575-
model="ibm/granite-3-2b-instruct", project_id=Secret.from_env_var("WATSONX_PROJECT_ID")
564+
model="ibm/granite-3-3-8b-instruct", project_id=Secret.from_env_var("WATSONX_PROJECT_ID")
576565
)
577566
collected_chunks = []
578567

@@ -597,7 +586,7 @@ def callback(chunk: StreamingChunk):
597586
)
598587
async def test_live_run_async(self):
599588
generator = WatsonxChatGenerator(
600-
model="ibm/granite-3-2b-instruct", project_id=Secret.from_env_var("WATSONX_PROJECT_ID")
589+
model="ibm/granite-3-3-8b-instruct", project_id=Secret.from_env_var("WATSONX_PROJECT_ID")
601590
)
602591
messages = [ChatMessage.from_user("What's the capital of Germany? Answer concisely.")]
603592
results = await generator.run_async(messages=messages)

integrations/watsonx/tests/test_document_embedder.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def test_init_default(self, mock_watsonx):
4444
api_key="fake-api-key", url="https://us-south.ml.cloud.ibm.com"
4545
)
4646
mock_watsonx["embeddings"].assert_called_once_with(
47-
model_id="ibm/slate-30m-english-rtrvr",
47+
model_id="ibm/slate-30m-english-rtrvr-v2",
4848
credentials=mock_watsonx["creds_instance"],
4949
project_id="fake-project-id",
5050
params=None,
@@ -53,7 +53,7 @@ def test_init_default(self, mock_watsonx):
5353
max_retries=None,
5454
)
5555

56-
assert embedder.model == "ibm/slate-30m-english-rtrvr"
56+
assert embedder.model == "ibm/slate-30m-english-rtrvr-v2"
5757
assert embedder.prefix == ""
5858
assert embedder.suffix == ""
5959
assert embedder.batch_size == 1000
@@ -64,7 +64,6 @@ def test_init_default(self, mock_watsonx):
6464
def test_init_with_parameters(self, mock_watsonx):
6565
embedder = WatsonxDocumentEmbedder(
6666
api_key=Secret.from_token("fake-api-key"),
67-
model="ibm/slate-125m-english-rtrvr",
6867
api_base_url="https://custom-url.ibm.com",
6968
project_id=Secret.from_token("custom-project-id"),
7069
truncate_input_tokens=128,
@@ -78,7 +77,7 @@ def test_init_with_parameters(self, mock_watsonx):
7877

7978
mock_watsonx["credentials"].assert_called_once_with(api_key="fake-api-key", url="https://custom-url.ibm.com")
8079
mock_watsonx["embeddings"].assert_called_once_with(
81-
model_id="ibm/slate-125m-english-rtrvr",
80+
model_id="ibm/slate-30m-english-rtrvr-v2",
8281
credentials=mock_watsonx["creds_instance"],
8382
project_id="custom-project-id",
8483
params={"truncate_input_tokens": 128},
@@ -110,7 +109,7 @@ def test_to_dict(self, mock_watsonx):
110109
"type": "haystack_integrations.components.embedders.watsonx.document_embedder.WatsonxDocumentEmbedder",
111110
"init_parameters": {
112111
"api_key": {"env_vars": ["WATSONX_API_KEY"], "strict": True, "type": "env_var"},
113-
"model": "ibm/slate-30m-english-rtrvr",
112+
"model": "ibm/slate-30m-english-rtrvr-v2",
114113
"api_base_url": "https://us-south.ml.cloud.ibm.com",
115114
"project_id": {"env_vars": ["WATSONX_PROJECT_ID"], "strict": True, "type": "env_var"},
116115
"truncate_input_tokens": None,
@@ -173,7 +172,7 @@ def test_run_empty_documents(self, mock_watsonx):
173172
result = embedder.run(documents=[])
174173
assert result == {
175174
"documents": [],
176-
"meta": {"model": "ibm/slate-30m-english-rtrvr", "truncate_input_tokens": None, "batch_size": 1000},
175+
"meta": {"model": "ibm/slate-30m-english-rtrvr-v2", "truncate_input_tokens": None, "batch_size": 1000},
177176
}
178177

179178

@@ -196,7 +195,6 @@ def test_documents(self):
196195
def test_run(self, test_documents):
197196
"""Test real API call with documents"""
198197
embedder = WatsonxDocumentEmbedder(
199-
model="ibm/slate-30m-english-rtrvr",
200198
api_key=Secret.from_env_var("WATSONX_API_KEY"),
201199
project_id=Secret.from_env_var("WATSONX_PROJECT_ID"),
202200
truncate_input_tokens=128,
@@ -209,7 +207,7 @@ def test_run(self, test_documents):
209207
assert len(doc.embedding) > 0
210208
assert all(isinstance(x, float) for x in doc.embedding)
211209

212-
assert result["meta"]["model"] == "ibm/slate-30m-english-rtrvr"
210+
assert result["meta"]["model"] == "ibm/slate-30m-english-rtrvr-v2"
213211

214212
@pytest.mark.skipif(
215213
not os.environ.get("WATSONX_API_KEY") or not os.environ.get("WATSONX_PROJECT_ID"),
@@ -218,7 +216,6 @@ def test_run(self, test_documents):
218216
def test_batch_processing(self, test_documents):
219217
"""Test that batch processing works"""
220218
embedder = WatsonxDocumentEmbedder(
221-
model="ibm/slate-30m-english-rtrvr",
222219
api_key=Secret.from_env_var("WATSONX_API_KEY"),
223220
project_id=Secret.from_env_var("WATSONX_PROJECT_ID"),
224221
batch_size=2,
@@ -239,7 +236,6 @@ def test_text_truncation(self):
239236
long_document = Document(content=long_content)
240237

241238
embedder = WatsonxDocumentEmbedder(
242-
model="ibm/slate-30m-english-rtrvr",
243239
api_key=Secret.from_env_var("WATSONX_API_KEY"),
244240
project_id=Secret.from_env_var("WATSONX_PROJECT_ID"),
245241
truncate_input_tokens=4,

0 commit comments

Comments
 (0)