Skip to content

Commit aed8c4e

Browse files
add Cohere support to the chatbot example (#199)
1 parent ec14ed9 commit aed8c4e

File tree

6 files changed

+112
-12
lines changed

6 files changed

+112
-12
lines changed

example-apps/chatbot-rag-app/README.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,16 @@ export MISTRAL_API_ENDPOINT=... # optional
128128
export MISTRAL_MODEL=... # optional
129129
```
130130

131+
### Cohere
132+
133+
To use Cohere you need to set the following environment variables:
134+
135+
```
136+
export LLM_TYPE=cohere
137+
export COHERE_API_KEY=...
138+
export COHERE_MODEL=... # optional
139+
```
140+
131141
## Running the App
132142

133143
Once you have indexed data into the Elasticsearch index, there are two ways to run the app: via Docker or locally. Docker is advised for testing & production use. Locally is advised for development.

example-apps/chatbot-rag-app/api/chat.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,10 @@ def ask_question(question, session_id):
6464

6565
answer = ""
6666
for chunk in get_llm().stream(qa_prompt):
67-
yield f"data: {chunk.content}\n\n"
67+
content = chunk.content.replace(
68+
"\n", " "
69+
) # the stream can get messed up with newlines
70+
yield f"data: {content}\n\n"
6871
answer += chunk.content
6972

7073
yield f"data: {DONE_TAG}\n\n"

example-apps/chatbot-rag-app/api/llm_integrations.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
ChatVertexAI,
44
AzureChatOpenAI,
55
BedrockChat,
6+
ChatCohere,
67
)
7-
from langchain_core.messages import HumanMessage
88
from langchain_mistralai.chat_models import ChatMistralAI
99
import os
1010
import vertexai
@@ -76,12 +76,21 @@ def init_mistral_chat(temperature):
7676
return ChatMistralAI(**kwargs)
7777

7878

79+
def init_cohere_chat(temperature):
80+
COHERE_API_KEY = os.getenv("COHERE_API_KEY")
81+
COHERE_MODEL = os.getenv("COHERE_MODEL")
82+
return ChatCohere(
83+
cohere_api_key=COHERE_API_KEY, model=COHERE_MODEL, temperature=temperature
84+
)
85+
86+
7987
MAP_LLM_TYPE_TO_CHAT_MODEL = {
8088
"azure": init_azure_chat,
8189
"bedrock": init_bedrock,
8290
"openai": init_openai_chat,
8391
"vertex": init_vertex_chat,
8492
"mistral": init_mistral_chat,
93+
"cohere": init_cohere_chat,
8594
}
8695

8796

example-apps/chatbot-rag-app/env.example

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,8 @@ ES_INDEX_CHAT_HISTORY=workplace-app-docs-chat-history
4040
# MISTRAL_API_KEY=
4141
# MISTRAL_API_ENDPOINT=
4242
# MISTRAL_MODEL=
43+
44+
# Uncomment and complete if you want to use Cohere
45+
# LLM_TYPE=cohere
46+
# COHERE_API_KEY=
47+
# COHERE_MODEL=

example-apps/chatbot-rag-app/requirements.in

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ boto3
2323
# Mistral dependencies
2424
langchain-mistralai
2525

26+
# Cohere dependencies
27+
cohere
28+
2629
# TBD if these are still needed
2730
exceptiongroup
2831
importlib-metadata

example-apps/chatbot-rag-app/requirements.txt

Lines changed: 80 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,24 @@
66
#
77
aiohttp==3.8.5
88
# via
9+
# cohere
910
# langchain
11+
# langchain-community
1012
# openai
1113
aiosignal==1.3.1
1214
# via aiohttp
1315
annotated-types==0.5.0
1416
# via pydantic
1517
anyio==3.7.1
1618
# via
17-
# langchain
19+
# httpx
1820
# langchain-core
1921
async-timeout==4.0.3
2022
# via aiohttp
2123
attrs==23.1.0
2224
# via aiohttp
25+
backoff==2.2.1
26+
# via cohere
2327
blinker==1.6.2
2428
# via flask
2529
boto3==1.28.61
@@ -35,6 +39,8 @@ cachetools==5.3.1
3539
certifi==2023.7.22
3640
# via
3741
# elastic-transport
42+
# httpcore
43+
# httpx
3844
# requests
3945
charset-normalizer==3.2.0
4046
# via
@@ -44,8 +50,12 @@ click==8.1.7
4450
# via
4551
# flask
4652
# pip-tools
53+
cohere==4.52
54+
# via -r requirements.in
4755
dataclasses-json==0.5.14
48-
# via langchain
56+
# via
57+
# langchain
58+
# langchain-community
4959
elastic-transport==8.4.0
5060
# via elasticsearch
5161
elasticsearch==8.12.1
@@ -54,6 +64,10 @@ elasticsearch==8.12.1
5464
# langchain-elasticsearch
5565
exceptiongroup==1.2.0
5666
# via -r requirements.in
67+
fastavro==1.9.4
68+
# via cohere
69+
filelock==3.13.1
70+
# via huggingface-hub
5771
flask==2.3.3
5872
# via
5973
# -r requirements.in
@@ -64,6 +78,8 @@ frozenlist==1.4.0
6478
# via
6579
# aiohttp
6680
# aiosignal
81+
fsspec==2024.2.0
82+
# via huggingface-hub
6783
google-api-core[grpc]==2.14.0
6884
# via
6985
# google-cloud-aiplatform
@@ -112,13 +128,24 @@ grpcio-status==1.59.3
112128
# via
113129
# -r requirements.in
114130
# google-api-core
131+
h11==0.14.0
132+
# via httpcore
133+
httpcore==1.0.4
134+
# via httpx
135+
httpx==0.25.2
136+
# via mistralai
137+
huggingface-hub==0.21.4
138+
# via tokenizers
115139
idna==3.4
116140
# via
117141
# anyio
142+
# httpx
118143
# requests
119144
# yarl
120145
importlib-metadata==6.8.0
121-
# via -r requirements.in
146+
# via
147+
# -r requirements.in
148+
# cohere
122149
itsdangerous==2.1.2
123150
# via flask
124151
jinja2==3.1.2
@@ -135,20 +162,31 @@ jsonpointer==2.4
135162
# via jsonpatch
136163
langchain==0.1.9
137164
# via -r requirements.in
138-
langchain-core==0.1.23
139-
# via langchain-elasticsearch
165+
langchain-community==0.0.27
166+
# via langchain
167+
langchain-core==0.1.30
168+
# via
169+
# langchain
170+
# langchain-community
171+
# langchain-elasticsearch
172+
# langchain-mistralai
140173
langchain-elasticsearch==0.1.0
141174
# via -r requirements.in
175+
langchain-mistralai==0.0.5
176+
# via -r requirements.in
142177
langsmith==0.1.10
143178
# via
144179
# langchain
180+
# langchain-community
145181
# langchain-core
146182
markupsafe==2.1.3
147183
# via
148184
# jinja2
149185
# werkzeug
150186
marshmallow==3.20.1
151187
# via dataclasses-json
188+
mistralai==0.1.3
189+
# via langchain-mistralai
152190
multidict==6.0.4
153191
# via
154192
# aiohttp
@@ -160,18 +198,28 @@ numexpr==2.8.5
160198
numpy==1.25.2
161199
# via
162200
# langchain
201+
# langchain-community
163202
# langchain-elasticsearch
164203
# numexpr
204+
# pandas
205+
# pyarrow
165206
# shapely
166207
openai==0.27.9
167208
# via -r requirements.in
209+
orjson==3.9.15
210+
# via
211+
# langsmith
212+
# mistralai
168213
packaging==23.2
169214
# via
170215
# build
171216
# google-cloud-aiplatform
172217
# google-cloud-bigquery
218+
# huggingface-hub
173219
# langchain-core
174220
# marshmallow
221+
pandas==2.2.1
222+
# via mistralai
175223
pip-tools==7.3.0
176224
# via -r requirements.in
177225
proto-plus==1.22.3
@@ -189,6 +237,8 @@ protobuf==4.25.1
189237
# grpc-google-iam-v1
190238
# grpcio-status
191239
# proto-plus
240+
pyarrow==15.0.1
241+
# via mistralai
192242
pyasn1==0.5.0
193243
# via
194244
# pyasn1-modules
@@ -200,6 +250,7 @@ pydantic==2.5.2
200250
# langchain
201251
# langchain-core
202252
# langsmith
253+
# mistralai
203254
pydantic-core==2.14.5
204255
# via pydantic
205256
pyproject-hooks==1.0.0
@@ -208,20 +259,28 @@ python-dateutil==2.8.2
208259
# via
209260
# botocore
210261
# google-cloud-bigquery
262+
# pandas
211263
python-dotenv==1.0.0
212264
# via -r requirements.in
265+
pytz==2024.1
266+
# via pandas
213267
pyyaml==6.0.1
214268
# via
269+
# huggingface-hub
215270
# langchain
271+
# langchain-community
216272
# langchain-core
217273
regex==2023.10.3
218274
# via tiktoken
219275
requests==2.31.0
220276
# via
277+
# cohere
221278
# google-api-core
222279
# google-cloud-bigquery
223280
# google-cloud-storage
281+
# huggingface-hub
224282
# langchain
283+
# langchain-community
225284
# langchain-core
226285
# langsmith
227286
# openai
@@ -235,28 +294,41 @@ shapely==2.0.2
235294
six==1.16.0
236295
# via python-dateutil
237296
sniffio==1.3.0
238-
# via anyio
297+
# via
298+
# anyio
299+
# httpx
239300
sqlalchemy==2.0.20
240-
# via langchain
301+
# via
302+
# langchain
303+
# langchain-community
241304
tenacity==8.2.3
242305
# via
243306
# langchain
307+
# langchain-community
244308
# langchain-core
245309
tiktoken==0.5.1
246310
# via -r requirements.in
311+
tokenizers==0.15.2
312+
# via langchain-mistralai
247313
tqdm==4.66.1
248-
# via openai
314+
# via
315+
# huggingface-hub
316+
# openai
249317
typing-extensions==4.7.1
250318
# via
319+
# huggingface-hub
251320
# pydantic
252321
# pydantic-core
253322
# sqlalchemy
254323
# typing-inspect
255324
typing-inspect==0.9.0
256325
# via dataclasses-json
326+
tzdata==2024.1
327+
# via pandas
257328
urllib3==1.26.16
258329
# via
259330
# botocore
331+
# cohere
260332
# elastic-transport
261333
# requests
262334
werkzeug==2.3.7
@@ -268,8 +340,6 @@ yarl==1.9.2
268340
zipp==3.17.0
269341
# via importlib-metadata
270342

271-
langchain-mistralai==0.0.5
272-
# via -r requirements.in
273343
# The following packages are considered to be unsafe in a requirements file:
274344
# pip
275345
# setuptools

0 commit comments

Comments
 (0)