Skip to content

Commit 0eb5301

Browse files
Notebook for RAG using Google's Gemma, Hugging Face and Elastic (#197)
* notebook for RAG using gemma, huggingface and elastic * notebook for RAG using gemma, huggingface and elastic * notebook for RAG using gemma, huggingface and elastic * notebook for RAG using gemma, huggingface and elastic * notebook for RAG using gemma, huggingface and elastic * notebook for RAG using gemma, huggingface and elastic * notebook for RAG using gemma, huggingface and elastic * notebook for RAG using gemma, huggingface and elastic * Update README.md Added subfolders for Gemini and Gemma * Update find-notebooks-to-test.sh Added Gemma notebook
1 parent 41c0a05 commit 0eb5301

File tree

3 files changed

+375
-1
lines changed

3 files changed

+375
-1
lines changed

bin/find-notebooks-to-test.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ EXEMPT_NOTEBOOKS=(
1616
"notebooks/integrations/gemini/vector-search-gemini-elastic.ipynb"
1717
"notebooks/integrations/gemini/qa-langchain-gemini-elasticsearch.ipynb"
1818
"notebooks/integrations/openai/openai-KNN-RAG.ipynb"
19+
"notebooks/integrations/gemma/rag-gemma-huggingface-elastic.ipynb"
1920
)
2021

2122
ALL_NOTEBOOKS=$(find notebooks -name "*.ipynb" | grep -v "_nbtest" | grep -v ".ipynb_checkpoints" | sort)

notebooks/integrations/README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,6 @@ The following subfolders contain notebooks that demonstrate how to integrate pop
44

55
- [OpenAI](./openai/README.md)
66
- [Hugging Face](./hugging-face/README.md)
7-
- [LlamaIndex](./llama-index/README.md)
7+
- [LlamaIndex](./llama-index/README.md)
8+
- [Gemini](./gemini)
9+
- [Gemma](./gemma)
Lines changed: 371 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,371 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"id": "faa09879-9128-4864-8bb5-945ef9b8e84c",
6+
"metadata": {},
7+
"source": [
8+
"# RAG: Using Gemma LLM locally for question answering on private data"
9+
]
10+
},
11+
{
12+
"cell_type": "markdown",
13+
"id": "d047438b-6f18-47ed-aac9-12c741cefd06",
14+
"metadata": {},
15+
"source": [
16+
"In this notebook, our aim is to develop a RAG system utilizing [Google's Gemma](https://ai.google.dev/gemma) model. We'll generate vectors with [Elastic's ELSER](https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-elser.html) model and store them in Elasticsearch. Additionally, we'll explore semantic retrieval techniques and present the top search results as a context window to the Gemma model. Furthermore, we'll utilize the [Hugging Face transformer](https://huggingface.co/google/gemma-2b-it) library to load Gemma on a local environment."
17+
]
18+
},
19+
{
20+
"cell_type": "markdown",
21+
"id": "1bd3acec-d490-4139-bab1-b874e1e7db8d",
22+
"metadata": {},
23+
"source": [
24+
"## Setup"
25+
]
26+
},
27+
{
28+
"cell_type": "markdown",
29+
"id": "ef406b8a-03fb-49c5-baed-18e03bcd36d9",
30+
"metadata": {},
31+
"source": [
32+
"**Elastic Credentials** - Create an [Elastic Cloud deployment](https://www.elastic.co/search-labs/tutorials/install-elasticsearch/elastic-cloud) to get all Elastic credentials (`ELASTIC_CLOUD_ID`,` ELASTIC_API_KEY`).\n",
33+
"\n",
34+
"**Hugging Face Token** - To get started with the [Gemma](https://huggingface.co/google/gemma-2b-it) model, it is necessary to agree to the terms on Hugging Face and generate the [access token](https://huggingface.co/docs/hub/en/security-tokens) with `write` role.\n",
35+
"\n",
36+
"**Gemma Model** - We're going to use [gemma-2b-it](https://huggingface.co/google/gemma-2b-it), though Google has released 4 open models. You can use any of them i.e. [gemma-2b](https://huggingface.co/google/gemma-2b), [gemma-7b](https://huggingface.co/google/gemma-7b), [gemma-7b-it](https://huggingface.co/google/gemma-7b-it)"
37+
]
38+
},
39+
{
40+
"cell_type": "markdown",
41+
"id": "ac91d7a3-1198-4b11-a9c5-50028abc861b",
42+
"metadata": {},
43+
"source": [
44+
"## Install packages"
45+
]
46+
},
47+
{
48+
"cell_type": "code",
49+
"execution_count": null,
50+
"id": "fda41538-444c-48d7-80a0-b34b2e158b82",
51+
"metadata": {},
52+
"outputs": [],
53+
"source": [
54+
"pip install -q -U elasticsearch langchain transformers huggingface_hub"
55+
]
56+
},
57+
{
58+
"cell_type": "markdown",
59+
"id": "15c2e924-e5a2-439b-8e98-f13a162db7fe",
60+
"metadata": {},
61+
"source": [
62+
"## Import packages"
63+
]
64+
},
65+
{
66+
"cell_type": "code",
67+
"execution_count": null,
68+
"id": "7219411b-fae6-4c2a-b170-796bc30ed073",
69+
"metadata": {},
70+
"outputs": [],
71+
"source": [
72+
"import json\n",
73+
"import os\n",
74+
"from getpass import getpass\n",
75+
"from urllib.request import urlopen\n",
76+
"\n",
77+
"from elasticsearch import Elasticsearch, helpers\n",
78+
"from langchain.text_splitter import CharacterTextSplitter\n",
79+
"from langchain.vectorstores import ElasticsearchStore\n",
80+
"from langchain import HuggingFacePipeline\n",
81+
"from langchain.chains import RetrievalQA\n",
82+
"from langchain.prompts import ChatPromptTemplate\n",
83+
"from langchain.schema.output_parser import StrOutputParser\n",
84+
"from langchain.schema.runnable import RunnablePassthrough\n",
85+
"from huggingface_hub import login\n",
86+
"from transformers import AutoTokenizer, AutoModelForCausalLM\n",
87+
"from transformers import AutoTokenizer, pipeline"
88+
]
89+
},
90+
{
91+
"cell_type": "markdown",
92+
"id": "182a413f-e7fd-4361-8096-90736d3df33e",
93+
"metadata": {},
94+
"source": [
95+
"## Get Credentials"
96+
]
97+
},
98+
{
99+
"cell_type": "code",
100+
"execution_count": null,
101+
"id": "b184b3a5-0cc8-43f9-b15d-f5ccf48f574b",
102+
"metadata": {},
103+
"outputs": [],
104+
"source": [
105+
"ELASTIC_API_KEY = getpass(\"Elastic API Key :\")\n",
106+
"ELASTIC_CLOUD_ID = getpass(\"Elastic Cloud ID :\")\n",
107+
"elastic_index_name = \"gemma-rag\""
108+
]
109+
},
110+
{
111+
"cell_type": "markdown",
112+
"id": "a2efbd81-70b9-409c-ab5f-796d538b42a1",
113+
"metadata": {},
114+
"source": [
115+
"## Add documents"
116+
]
117+
},
118+
{
119+
"cell_type": "markdown",
120+
"id": "161dfb9d-f11f-4de5-8489-6464ade0cdb2",
121+
"metadata": {},
122+
"source": [
123+
"### Let's download the sample dataset and deserialize the document."
124+
]
125+
},
126+
{
127+
"cell_type": "code",
128+
"execution_count": null,
129+
"id": "49427546-7b37-48f4-a6fe-395736ea2d38",
130+
"metadata": {},
131+
"outputs": [],
132+
"source": [
133+
"url = \"https://raw.githubusercontent.com/elastic/elasticsearch-labs/main/datasets/workplace-documents.json\"\n",
134+
"\n",
135+
"response = urlopen(url)\n",
136+
"\n",
137+
"workplace_docs = json.loads(response.read())"
138+
]
139+
},
140+
{
141+
"cell_type": "markdown",
142+
"id": "f3bf0104-8b31-4b39-ad21-b372fd1fa0db",
143+
"metadata": {},
144+
"source": [
145+
"### Split Documents into Passages"
146+
]
147+
},
148+
{
149+
"cell_type": "code",
150+
"execution_count": null,
151+
"id": "79e55ed1-418e-48ed-b3e3-d28e10744eb5",
152+
"metadata": {},
153+
"outputs": [],
154+
"source": [
155+
"metadata = []\n",
156+
"content = []\n",
157+
"\n",
158+
"for doc in workplace_docs:\n",
159+
" content.append(doc[\"content\"])\n",
160+
" metadata.append(\n",
161+
" {\n",
162+
" \"name\": doc[\"name\"],\n",
163+
" \"summary\": doc[\"summary\"],\n",
164+
" \"rolePermissions\": doc[\"rolePermissions\"],\n",
165+
" }\n",
166+
" )\n",
167+
"\n",
168+
"text_splitter = CharacterTextSplitter(chunk_size=50, chunk_overlap=0)\n",
169+
"docs = text_splitter.create_documents(content, metadatas=metadata)"
170+
]
171+
},
172+
{
173+
"cell_type": "markdown",
174+
"id": "4264bc1b-23b1-4547-a7f0-670944c3e605",
175+
"metadata": {},
176+
"source": [
177+
"## Index Documents into Elasticsearch using ELSER\n",
178+
"\n",
179+
"Before we begin indexing, ensure you have [downloaded and deployed the ELSER model](https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-elser.html#download-deploy-elser) in your deployment and is running on the ML node."
180+
]
181+
},
182+
{
183+
"cell_type": "code",
184+
"execution_count": null,
185+
"id": "eb1db78e-e40a-4a5c-9d15-75ee2a1d0994",
186+
"metadata": {},
187+
"outputs": [],
188+
"source": [
189+
"es = ElasticsearchStore.from_documents(\n",
190+
" docs,\n",
191+
" es_cloud_id=ELASTIC_CLOUD_ID,\n",
192+
" es_api_key=ELASTIC_API_KEY,\n",
193+
" index_name=elastic_index_name,\n",
194+
" strategy=ElasticsearchStore.SparseVectorRetrievalStrategy(),\n",
195+
")\n",
196+
"\n",
197+
"es"
198+
]
199+
},
200+
{
201+
"cell_type": "markdown",
202+
"id": "02b1ead9-c442-40e9-ba81-d4d286ea878b",
203+
"metadata": {},
204+
"source": [
205+
"## Hugging Face login"
206+
]
207+
},
208+
{
209+
"cell_type": "code",
210+
"execution_count": null,
211+
"id": "d2f651e4-e760-4b59-a8a3-57c58dfc229f",
212+
"metadata": {},
213+
"outputs": [],
214+
"source": [
215+
"from huggingface_hub import notebook_login\n",
216+
"\n",
217+
"notebook_login()"
218+
]
219+
},
220+
{
221+
"cell_type": "markdown",
222+
"id": "7454f551-71a9-4310-bb2a-3fe0e683daab",
223+
"metadata": {},
224+
"source": [
225+
"## Initialize the tokenizer with the model (`google/gemma-2b-it`)"
226+
]
227+
},
228+
{
229+
"cell_type": "code",
230+
"execution_count": null,
231+
"id": "3e1d98eb-0f4e-4c41-a851-125b75502963",
232+
"metadata": {},
233+
"outputs": [],
234+
"source": [
235+
"model = AutoModelForCausalLM.from_pretrained(\"google/gemma-2b-it\")\n",
236+
"tokenizer = AutoTokenizer.from_pretrained(\"google/gemma-2b-it\")"
237+
]
238+
},
239+
{
240+
"cell_type": "markdown",
241+
"id": "11a12596-2ac1-4101-b189-d21d53d33b04",
242+
"metadata": {},
243+
"source": [
244+
"## Create a `text-generation` pipeline and initialize with LLM"
245+
]
246+
},
247+
{
248+
"cell_type": "code",
249+
"execution_count": null,
250+
"id": "623e74fb-5707-44f7-9dd8-d9499f7ab61e",
251+
"metadata": {},
252+
"outputs": [],
253+
"source": [
254+
"pipe = pipeline(\n",
255+
" \"text-generation\",\n",
256+
" model=model,\n",
257+
" tokenizer=tokenizer,\n",
258+
" max_new_tokens=1024,\n",
259+
")\n",
260+
"\n",
261+
"llm = HuggingFacePipeline(\n",
262+
" pipeline=pipe,\n",
263+
" model_kwargs={\"temperature\": 0.7},\n",
264+
")"
265+
]
266+
},
267+
{
268+
"cell_type": "markdown",
269+
"id": "49ce0e72-e419-4310-85e9-09077d6c40b2",
270+
"metadata": {},
271+
"source": [
272+
"## Format Docs"
273+
]
274+
},
275+
{
276+
"cell_type": "code",
277+
"execution_count": null,
278+
"id": "b3c07a75-9220-4a82-a92e-3fc2727ad3ba",
279+
"metadata": {},
280+
"outputs": [],
281+
"source": [
282+
"def format_docs(docs):\n",
283+
" return \"\\n\\n\".join(doc.page_content for doc in docs)"
284+
]
285+
},
286+
{
287+
"cell_type": "markdown",
288+
"id": "f6266222-6ec3-495a-8f14-460549bab89d",
289+
"metadata": {},
290+
"source": [
291+
"## Create a chain using Prompt template"
292+
]
293+
},
294+
{
295+
"cell_type": "code",
296+
"execution_count": null,
297+
"id": "ec203d1a-104b-4583-9ba1-a6b4b0354367",
298+
"metadata": {},
299+
"outputs": [],
300+
"source": [
301+
"retriever = es.as_retriever(search_kwargs={\"k\": 10})\n",
302+
"\n",
303+
"template = \"\"\"Answer the question based only on the following context:\\n\n",
304+
"\n",
305+
"{context}\n",
306+
"\n",
307+
"Question: {question}\n",
308+
"\"\"\"\n",
309+
"prompt = ChatPromptTemplate.from_template(template)\n",
310+
"\n",
311+
"\n",
312+
"chain = (\n",
313+
" {\"context\": retriever | format_docs, \"question\": RunnablePassthrough()}\n",
314+
" | prompt\n",
315+
" | llm\n",
316+
" | StrOutputParser()\n",
317+
")"
318+
]
319+
},
320+
{
321+
"cell_type": "markdown",
322+
"id": "8ae892dd-7442-4d4d-a804-1d717266e596",
323+
"metadata": {},
324+
"source": [
325+
"## Ask question"
326+
]
327+
},
328+
{
329+
"cell_type": "code",
330+
"execution_count": 11,
331+
"id": "ba312f17-44ae-423d-89a0-ea01eccd85b5",
332+
"metadata": {},
333+
"outputs": [
334+
{
335+
"data": {
336+
"text/plain": [
337+
"'Answer: The pet policy in the office allows employees to bring pets to the office, subject to approval by the HR department. Pets covered under this policy include dogs, cats, and other small, non-exotic animals, subject to approval by the HR department.'"
338+
]
339+
},
340+
"execution_count": 11,
341+
"metadata": {},
342+
"output_type": "execute_result"
343+
}
344+
],
345+
"source": [
346+
"chain.invoke(\"What is the pet policy in the office?\")"
347+
]
348+
}
349+
],
350+
"metadata": {
351+
"kernelspec": {
352+
"display_name": "Python 3 (ipykernel)",
353+
"language": "python",
354+
"name": "python3"
355+
},
356+
"language_info": {
357+
"codemirror_mode": {
358+
"name": "ipython",
359+
"version": 3
360+
},
361+
"file_extension": ".py",
362+
"mimetype": "text/x-python",
363+
"name": "python",
364+
"nbconvert_exporter": "python",
365+
"pygments_lexer": "ipython3",
366+
"version": "3.11.4"
367+
}
368+
},
369+
"nbformat": 4,
370+
"nbformat_minor": 5
371+
}

0 commit comments

Comments
 (0)