Skip to content

Commit 95a1445

Browse files
authored
chore: Update CohereChatGenerator default model to command-a-03-2025 (#2553)
* Update CohereChatGenerator default model to command-a-03-2025 * Lint * Fix test
1 parent ed73460 commit 95a1445

File tree

7 files changed

+37
-38
lines changed

7 files changed

+37
-38
lines changed

integrations/cohere/examples/cohere_generation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
#
88
# The pipeline workflow:
99
# 1. Receives a user message requesting to create a JSON object from "Peter Parker" aka Superman.
10-
# 2. Processes the message through components to generate a response using Cohere command-r-08-2024 model.
10+
# 2. Processes the message through components to generate a response using Cohere command-a-03-2025 model.
1111
# 3. Validates the generated response against a predefined JSON schema for person data.
1212
# 4. If the response does not meet the schema, the JsonSchemaValidator provides details on how to correct the errors.
1313
# 4a. The pipeline loops back, using the error information to generate a new JSON object until it satisfies the schema.
@@ -38,7 +38,7 @@
3838

3939
# Add components to the pipeline
4040
pipe.add_component("joiner", BranchJoiner(list[ChatMessage]))
41-
pipe.add_component("fc_llm", CohereChatGenerator(model="command-r-08-2024"))
41+
pipe.add_component("fc_llm", CohereChatGenerator())
4242
pipe.add_component("validator", JsonSchemaValidator(json_schema=person_schema))
4343
(pipe.add_component("adapter", OutputAdapter("{{chat_message}}", list[ChatMessage])),)
4444
# And connect them

integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -439,7 +439,7 @@ class CohereChatGenerator:
439439
from haystack.utils import Secret
440440
from haystack_integrations.components.generators.cohere import CohereChatGenerator
441441
442-
client = CohereChatGenerator(model="command-r-08-2024", api_key=Secret.from_env_var("COHERE_API_KEY"))
442+
client = CohereChatGenerator(api_key=Secret.from_env_var("COHERE_API_KEY"))
443443
messages = [ChatMessage.from_user("What's Natural Language Processing?")]
444444
client.run(messages)
445445
@@ -499,7 +499,7 @@ def weather(city: str) -> str:
499499
500500
# Create and set up the pipeline
501501
pipeline = Pipeline()
502-
pipeline.add_component("generator", CohereChatGenerator(model="command-r-08-2024", tools=[weather_tool]))
502+
pipeline.add_component("generator", CohereChatGenerator(tools=[weather_tool]))
503503
pipeline.add_component("tool_invoker", ToolInvoker(tools=[weather_tool]))
504504
pipeline.connect("generator", "tool_invoker")
505505
@@ -517,7 +517,7 @@ def weather(city: str) -> str:
517517
def __init__(
518518
self,
519519
api_key: Secret = Secret.from_env_var(["COHERE_API_KEY", "CO_API_KEY"]),
520-
model: str = "command-r-08-2024",
520+
model: str = "command-a-03-2025",
521521
streaming_callback: Optional[StreamingCallbackT] = None,
522522
api_base_url: Optional[str] = None,
523523
generation_kwargs: Optional[dict[str, Any]] = None,

integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ class CohereGenerator(CohereChatGenerator):
3232
def __init__(
3333
self,
3434
api_key: Secret = Secret.from_env_var(["COHERE_API_KEY", "CO_API_KEY"]),
35-
model: str = "command-r-08-2024",
35+
model: str = "command-a-03-2025",
3636
streaming_callback: Optional[Callable] = None,
3737
api_base_url: Optional[str] = None,
3838
**kwargs: Any,

integrations/cohere/tests/test_chat_generator.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def test_init_default(self, monkeypatch):
156156

157157
component = CohereChatGenerator()
158158
assert component.api_key == Secret.from_env_var(["COHERE_API_KEY", "CO_API_KEY"])
159-
assert component.model == "command-r-08-2024"
159+
assert component.model == "command-a-03-2025"
160160
assert component.streaming_callback is None
161161
assert component.api_base_url == "https://api.cohere.com"
162162
assert not component.generation_kwargs
@@ -194,7 +194,7 @@ def test_to_dict_default(self, monkeypatch):
194194
assert data == {
195195
"type": "haystack_integrations.components.generators.cohere.chat.chat_generator.CohereChatGenerator",
196196
"init_parameters": {
197-
"model": "command-r-08-2024",
197+
"model": "command-a-03-2025",
198198
"streaming_callback": None,
199199
"api_key": {
200200
"env_vars": ["COHERE_API_KEY", "CO_API_KEY"],
@@ -246,7 +246,7 @@ def test_from_dict(self, monkeypatch):
246246
data = {
247247
"type": "haystack_integrations.components.generators.cohere.chat.chat_generator.CohereChatGenerator",
248248
"init_parameters": {
249-
"model": "command-r-08-2024",
249+
"model": "command-a-03-2025",
250250
"api_base_url": "test-base-url",
251251
"api_key": {
252252
"env_vars": ["ENV_VAR"],
@@ -261,7 +261,7 @@ def test_from_dict(self, monkeypatch):
261261
},
262262
}
263263
component = CohereChatGenerator.from_dict(data)
264-
assert component.model == "command-r-08-2024"
264+
assert component.model == "command-a-03-2025"
265265
assert component.streaming_callback is print_streaming_chunk
266266
assert component.api_base_url == "test-base-url"
267267
assert component.generation_kwargs == {
@@ -275,7 +275,7 @@ def test_from_dict_fail_wo_env_var(self, monkeypatch):
275275
data = {
276276
"type": "haystack_integrations.components.generators.cohere.chat.chat_generator.CohereChatGenerator",
277277
"init_parameters": {
278-
"model": "command-r-08-2024",
278+
"model": "command-a-03-2025",
279279
"api_base_url": "test-base-url",
280280
"api_key": {
281281
"env_vars": ["COHERE_API_KEY", "CO_API_KEY"],
@@ -307,7 +307,6 @@ def test_serde_in_pipeline(self, monkeypatch):
307307
)
308308

309309
generator = CohereChatGenerator(
310-
model="command-r-08-2024",
311310
generation_kwargs={"temperature": 0.7},
312311
streaming_callback=print_streaming_chunk,
313312
tools=[tool],
@@ -326,7 +325,7 @@ def test_serde_in_pipeline(self, monkeypatch):
326325
"generator": {
327326
"type": "haystack_integrations.components.generators.cohere.chat.chat_generator.CohereChatGenerator", # noqa: E501
328327
"init_parameters": {
329-
"model": "command-r-08-2024",
328+
"model": "command-a-03-2025",
330329
"api_key": {"type": "env_var", "env_vars": ["COHERE_API_KEY", "CO_API_KEY"], "strict": True},
331330
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
332331
"api_base_url": "https://api.cohere.com",
@@ -545,7 +544,7 @@ def test_tools_use_old_way(self):
545544
},
546545
}
547546
]
548-
client = CohereChatGenerator(model="command-r-08-2024")
547+
client = CohereChatGenerator()
549548
response = client.run(
550549
messages=[ChatMessage.from_user("What is the current price of AAPL?")],
551550
generation_kwargs={"tools": tools_schema},
@@ -581,7 +580,7 @@ def test_tools_use_with_tools(self):
581580
function=stock_price,
582581
)
583582
initial_messages = [ChatMessage.from_user("What is the current price of AAPL?")]
584-
client = CohereChatGenerator(model="command-r-08-2024")
583+
client = CohereChatGenerator()
585584
response = client.run(
586585
messages=initial_messages,
587586
tools=[stock_price_tool],
@@ -636,7 +635,7 @@ def test_live_run_with_tools_streaming(self):
636635

637636
initial_messages = [ChatMessage.from_user("What's the weather like in Paris?")]
638637
component = CohereChatGenerator(
639-
model="command-r-08-2024", # Cohere's model that supports tools
638+
# Cohere's model that supports tools
640639
tools=[weather_tool],
641640
streaming_callback=print_streaming_chunk,
642641
)

integrations/cohere/tests/test_chat_generator_async.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ async def test_tools_use_with_tools_async(self):
7474
function=stock_price,
7575
)
7676
initial_messages = [ChatMessage.from_user("What is the current price of AAPL?")]
77-
client = CohereChatGenerator(model="command-r-08-2024")
77+
client = CohereChatGenerator()
7878
response = await client.run_async(
7979
messages=initial_messages,
8080
tools=[stock_price_tool],
@@ -137,7 +137,7 @@ async def print_streaming_chunk_async(chunk: StreamingChunk) -> None:
137137

138138
initial_messages = [ChatMessage.from_user("What's the weather like in Paris?")]
139139
component = CohereChatGenerator(
140-
model="command-r-08-2024", # Cohere's model that supports tools
140+
# Cohere's model that supports tools
141141
tools=[weather_tool],
142142
streaming_callback=print_streaming_chunk_async,
143143
)

integrations/cohere/tests/test_chat_generator_chunks.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def expected_streaming_chunks():
136136
finish_reason=None,
137137
tool_calls=None,
138138
meta={
139-
"model": "command-r-08-2024",
139+
"model": "command-a-03-2025",
140140
},
141141
),
142142
# Chunk 2: Tool plan delta
@@ -147,7 +147,7 @@ def expected_streaming_chunks():
147147
finish_reason=None,
148148
tool_calls=None,
149149
meta={
150-
"model": "command-r-08-2024",
150+
"model": "command-a-03-2025",
151151
},
152152
),
153153
# Chunk 3: Tool call start
@@ -165,7 +165,7 @@ def expected_streaming_chunks():
165165
)
166166
],
167167
meta={
168-
"model": "command-r-08-2024",
168+
"model": "command-a-03-2025",
169169
"tool_call_id": "call_weather_paris_123",
170170
},
171171
),
@@ -183,7 +183,7 @@ def expected_streaming_chunks():
183183
)
184184
],
185185
meta={
186-
"model": "command-r-08-2024",
186+
"model": "command-a-03-2025",
187187
},
188188
),
189189
# Chunk 5: Tool call delta - more arguments
@@ -200,7 +200,7 @@ def expected_streaming_chunks():
200200
)
201201
],
202202
meta={
203-
"model": "command-r-08-2024",
203+
"model": "command-a-03-2025",
204204
},
205205
),
206206
# Chunk 6: Tool call delta - city name
@@ -217,7 +217,7 @@ def expected_streaming_chunks():
217217
)
218218
],
219219
meta={
220-
"model": "command-r-08-2024",
220+
"model": "command-a-03-2025",
221221
},
222222
),
223223
# Chunk 7: Tool call delta - closing brace
@@ -234,7 +234,7 @@ def expected_streaming_chunks():
234234
)
235235
],
236236
meta={
237-
"model": "command-r-08-2024",
237+
"model": "command-a-03-2025",
238238
},
239239
),
240240
# Chunk 8: Tool call end
@@ -245,7 +245,7 @@ def expected_streaming_chunks():
245245
finish_reason=None,
246246
tool_calls=None,
247247
meta={
248-
"model": "command-r-08-2024",
248+
"model": "command-a-03-2025",
249249
},
250250
),
251251
# Chunk 9: Tool call start - second tool
@@ -263,7 +263,7 @@ def expected_streaming_chunks():
263263
)
264264
],
265265
meta={
266-
"model": "command-r-08-2024",
266+
"model": "command-a-03-2025",
267267
"tool_call_id": "call_weather_berlin_456",
268268
},
269269
),
@@ -281,7 +281,7 @@ def expected_streaming_chunks():
281281
)
282282
],
283283
meta={
284-
"model": "command-r-08-2024",
284+
"model": "command-a-03-2025",
285285
},
286286
),
287287
# Chunk 11: Tool call delta - more second tool arguments
@@ -298,7 +298,7 @@ def expected_streaming_chunks():
298298
)
299299
],
300300
meta={
301-
"model": "command-r-08-2024",
301+
"model": "command-a-03-2025",
302302
},
303303
),
304304
# Chunk 12: Tool call delta - second city name
@@ -315,7 +315,7 @@ def expected_streaming_chunks():
315315
)
316316
],
317317
meta={
318-
"model": "command-r-08-2024",
318+
"model": "command-a-03-2025",
319319
},
320320
),
321321
# Chunk 13: Tool call delta - closing brace for second tool
@@ -332,7 +332,7 @@ def expected_streaming_chunks():
332332
)
333333
],
334334
meta={
335-
"model": "command-r-08-2024",
335+
"model": "command-a-03-2025",
336336
},
337337
),
338338
# Chunk 14: Tool call end - second tool
@@ -343,7 +343,7 @@ def expected_streaming_chunks():
343343
finish_reason=None,
344344
tool_calls=None,
345345
meta={
346-
"model": "command-r-08-2024",
346+
"model": "command-a-03-2025",
347347
},
348348
),
349349
# Chunk 15: Message end with finish reason and usage
@@ -354,7 +354,7 @@ def expected_streaming_chunks():
354354
finish_reason="tool_calls",
355355
tool_calls=None,
356356
meta={
357-
"model": "command-r-08-2024",
357+
"model": "command-a-03-2025",
358358
"finish_reason": "TOOL_CALLS",
359359
"usage": {
360360
"prompt_tokens": 9,
@@ -373,7 +373,7 @@ def test_convert_cohere_chunk_to_streaming_chunk_complete_sequence(self, cohere_
373373
for cohere_chunk, haystack_chunk in zip(cohere_chunks, expected_streaming_chunks):
374374
stream_chunk = _convert_cohere_chunk_to_streaming_chunk(
375375
chunk=cohere_chunk,
376-
model="command-r-08-2024",
376+
model="command-a-03-2025",
377377
)
378378
assert stream_chunk == haystack_chunk
379379

integrations/cohere/tests/test_generator.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def test_init_default(self, monkeypatch):
1919
monkeypatch.setenv("COHERE_API_KEY", "foo")
2020
component = CohereGenerator()
2121
assert component.api_key == Secret.from_env_var(["COHERE_API_KEY", "CO_API_KEY"])
22-
assert component.model == "command-r-08-2024"
22+
assert component.model == "command-a-03-2025"
2323
assert component.streaming_callback is None
2424
assert component.api_base_url == COHERE_API_URL
2525
assert component.model_parameters == {}
@@ -47,7 +47,7 @@ def test_to_dict_default(self, monkeypatch):
4747
assert data == {
4848
"type": "haystack_integrations.components.generators.cohere.generator.CohereGenerator",
4949
"init_parameters": {
50-
"model": "command-r-08-2024",
50+
"model": "command-a-03-2025",
5151
"api_key": {"env_vars": ["COHERE_API_KEY", "CO_API_KEY"], "strict": True, "type": "env_var"},
5252
"streaming_callback": None,
5353
"api_base_url": COHERE_API_URL,
@@ -86,7 +86,7 @@ def test_from_dict(self, monkeypatch):
8686
data = {
8787
"type": "haystack_integrations.components.generators.cohere.generator.CohereGenerator",
8888
"init_parameters": {
89-
"model": "command-r-08-2024",
89+
"model": "command-a-03-2025",
9090
"max_tokens": 10,
9191
"api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
9292
"some_test_param": "test-params",
@@ -97,7 +97,7 @@ def test_from_dict(self, monkeypatch):
9797
}
9898
component: CohereGenerator = CohereGenerator.from_dict(data)
9999
assert component.api_key == Secret.from_env_var("ENV_VAR", strict=False)
100-
assert component.model == "command-r-08-2024"
100+
assert component.model == "command-a-03-2025"
101101
assert component.streaming_callback == print_streaming_chunk
102102
assert component.api_base_url == "test-base-url"
103103
assert component.model_parameters == {"max_tokens": 10, "some_test_param": "test-params"}

0 commit comments

Comments
 (0)