|
| 1 | +{ |
| 2 | + "cells": [ |
| 3 | + { |
| 4 | + "cell_type": "markdown", |
| 5 | + "id": "faa09879-9128-4864-8bb5-945ef9b8e84c", |
| 6 | + "metadata": {}, |
| 7 | + "source": [ |
| 8 | + "# RAG: Using Gemma LLM locally for question answering on private data" |
| 9 | + ] |
| 10 | + }, |
| 11 | + { |
| 12 | + "cell_type": "markdown", |
| 13 | + "id": "d047438b-6f18-47ed-aac9-12c741cefd06", |
| 14 | + "metadata": {}, |
| 15 | + "source": [ |
| 16 | + "In this notebook, we build a RAG system around [Google's Gemma](https://ai.google.dev/gemma) model. We generate sparse vectors with [Elastic's ELSER](https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-elser.html) model and store them in Elasticsearch. We then use semantic retrieval to find the most relevant passages and pass the top search results to Gemma as its context window. Gemma itself is loaded in a local environment with the [Hugging Face Transformers](https://huggingface.co/google/gemma-2b-it) library." |
| 17 | + ] |
| 18 | + }, |
| 19 | + { |
| 20 | + "cell_type": "markdown", |
| 21 | + "id": "1bd3acec-d490-4139-bab1-b874e1e7db8d", |
| 22 | + "metadata": {}, |
| 23 | + "source": [ |
| 24 | + "## Setup" |
| 25 | + ] |
| 26 | + }, |
| 27 | + { |
| 28 | + "cell_type": "markdown", |
| 29 | + "id": "ef406b8a-03fb-49c5-baed-18e03bcd36d9", |
| 30 | + "metadata": {}, |
| 31 | + "source": [ |
| 32 | + "**Elastic Credentials** - Create an [Elastic Cloud deployment](https://www.elastic.co/search-labs/tutorials/install-elasticsearch/elastic-cloud) to get your Elastic credentials (`ELASTIC_CLOUD_ID`, `ELASTIC_API_KEY`).\n", |
| 33 | + "\n", |
| 34 | + "**Hugging Face Token** - To get started with the [Gemma](https://huggingface.co/google/gemma-2b-it) model, you must agree to its terms on Hugging Face and generate an [access token](https://huggingface.co/docs/hub/en/security-tokens) with the `write` role.\n", |
| 35 | + "\n", |
| 36 | + "**Gemma Model** - We use [gemma-2b-it](https://huggingface.co/google/gemma-2b-it) in this notebook. Google has released four open models, and you can use any of the others instead: [gemma-2b](https://huggingface.co/google/gemma-2b), [gemma-7b](https://huggingface.co/google/gemma-7b), or [gemma-7b-it](https://huggingface.co/google/gemma-7b-it)." |
| 37 | + ] |
| 38 | + }, |
| 39 | + { |
| 40 | + "cell_type": "markdown", |
| 41 | + "id": "ac91d7a3-1198-4b11-a9c5-50028abc861b", |
| 42 | + "metadata": {}, |
| 43 | + "source": [ |
| 44 | + "## Install packages" |
| 45 | + ] |
| 46 | + }, |
| 47 | + { |
| 48 | + "cell_type": "code", |
| 49 | + "execution_count": null, |
| 50 | + "id": "fda41538-444c-48d7-80a0-b34b2e158b82", |
| 51 | + "metadata": {}, |
| 52 | + "outputs": [], |
| 53 | + "source": [ |
| 54 | + "pip install -q -U elasticsearch langchain transformers huggingface_hub" |
| 55 | + ] |
| 56 | + }, |
| 57 | + { |
| 58 | + "cell_type": "markdown", |
| 59 | + "id": "15c2e924-e5a2-439b-8e98-f13a162db7fe", |
| 60 | + "metadata": {}, |
| 61 | + "source": [ |
| 62 | + "## Import packages" |
| 63 | + ] |
| 64 | + }, |
| 65 | + { |
| 66 | + "cell_type": "code", |
| 67 | + "execution_count": null, |
| 68 | + "id": "7219411b-fae6-4c2a-b170-796bc30ed073", |
| 69 | + "metadata": {}, |
| 70 | + "outputs": [], |
| 71 | + "source": [ |
| 72 | + "import json\n", |
| 73 | + "import os\n", |
| 74 | + "from getpass import getpass\n", |
| 75 | + "from urllib.request import urlopen\n", |
| 76 | + "\n", |
| 77 | + "from elasticsearch import Elasticsearch, helpers\n", |
| 78 | + "from langchain.text_splitter import CharacterTextSplitter\n", |
| 79 | + "from langchain.vectorstores import ElasticsearchStore\n", |
| 80 | + "from langchain.llms import HuggingFacePipeline\n", |
| 81 | + "from langchain.chains import RetrievalQA\n", |
| 82 | + "from langchain.prompts import ChatPromptTemplate\n", |
| 83 | + "from langchain.schema.output_parser import StrOutputParser\n", |
| 84 | + "from langchain.schema.runnable import RunnablePassthrough\n", |
| 85 | + "from huggingface_hub import login\n", |
| 86 | + "from transformers import AutoTokenizer, AutoModelForCausalLM\n", |
| 87 | + "from transformers import pipeline" |
| 88 | + ] |
| 89 | + }, |
| 90 | + { |
| 91 | + "cell_type": "markdown", |
| 92 | + "id": "182a413f-e7fd-4361-8096-90736d3df33e", |
| 93 | + "metadata": {}, |
| 94 | + "source": [ |
| 95 | + "## Get Credentials" |
| 96 | + ] |
| 97 | + }, |
| 98 | + { |
| 99 | + "cell_type": "code", |
| 100 | + "execution_count": null, |
| 101 | + "id": "b184b3a5-0cc8-43f9-b15d-f5ccf48f574b", |
| 102 | + "metadata": {}, |
| 103 | + "outputs": [], |
| 104 | + "source": [ |
| 105 | + "ELASTIC_API_KEY = getpass(\"Elastic API Key :\")\n", |
| 106 | + "ELASTIC_CLOUD_ID = getpass(\"Elastic Cloud ID :\")\n", |
| 107 | + "elastic_index_name = \"gemma-rag\"" |
| 108 | + ] |
| 109 | + }, |
| 110 | + { |
| 111 | + "cell_type": "markdown", |
| 112 | + "id": "a2efbd81-70b9-409c-ab5f-796d538b42a1", |
| 113 | + "metadata": {}, |
| 114 | + "source": [ |
| 115 | + "## Add documents" |
| 116 | + ] |
| 117 | + }, |
| 118 | + { |
| 119 | + "cell_type": "markdown", |
| 120 | + "id": "161dfb9d-f11f-4de5-8489-6464ade0cdb2", |
| 121 | + "metadata": {}, |
| 122 | + "source": [ |
| 123 | + "### Download the sample dataset and deserialize the documents" |
| 124 | + ] |
| 125 | + }, |
| 126 | + { |
| 127 | + "cell_type": "code", |
| 128 | + "execution_count": null, |
| 129 | + "id": "49427546-7b37-48f4-a6fe-395736ea2d38", |
| 130 | + "metadata": {}, |
| 131 | + "outputs": [], |
| 132 | + "source": [ |
| 133 | + "url = \"https://raw.githubusercontent.com/elastic/elasticsearch-labs/main/datasets/workplace-documents.json\"\n", |
| 134 | + "\n", |
| 135 | + "response = urlopen(url)\n", |
| 136 | + "\n", |
| 137 | + "workplace_docs = json.loads(response.read())" |
| 138 | + ] |
| 139 | + }, |
| 140 | + { |
| 141 | + "cell_type": "markdown", |
| 142 | + "id": "f3bf0104-8b31-4b39-ad21-b372fd1fa0db", |
| 143 | + "metadata": {}, |
| 144 | + "source": [ |
| 145 | + "### Split Documents into Passages" |
| 146 | + ] |
| 147 | + }, |
| 148 | + { |
| 149 | + "cell_type": "code", |
| 150 | + "execution_count": null, |
| 151 | + "id": "79e55ed1-418e-48ed-b3e3-d28e10744eb5", |
| 152 | + "metadata": {}, |
| 153 | + "outputs": [], |
| 154 | + "source": [ |
| 155 | + "metadata = []\n", |
| 156 | + "content = []\n", |
| 157 | + "\n", |
| 158 | + "for doc in workplace_docs:\n", |
| 159 | + " content.append(doc[\"content\"])\n", |
| 160 | + " metadata.append(\n", |
| 161 | + " {\n", |
| 162 | + " \"name\": doc[\"name\"],\n", |
| 163 | + " \"summary\": doc[\"summary\"],\n", |
| 164 | + " \"rolePermissions\": doc[\"rolePermissions\"],\n", |
| 165 | + " }\n", |
| 166 | + " )\n", |
| 167 | + "\n", |
| 168 | + "text_splitter = CharacterTextSplitter(chunk_size=50, chunk_overlap=0)\n", |
| 169 | + "docs = text_splitter.create_documents(content, metadatas=metadata)" |
| 170 | + ] |
| 171 | + }, |
| 172 | + { |
| 173 | + "cell_type": "markdown", |
| 174 | + "id": "4264bc1b-23b1-4547-a7f0-670944c3e605", |
| 175 | + "metadata": {}, |
| 176 | + "source": [ |
| 177 | + "## Index Documents into Elasticsearch using ELSER\n", |
| 178 | + "\n", |
| 179 | + "Before indexing, ensure that you have [downloaded and deployed the ELSER model](https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-elser.html#download-deploy-elser) in your deployment and that it is running on an ML node." |
| 180 | + ] |
| 181 | + }, |
| 182 | + { |
| 183 | + "cell_type": "code", |
| 184 | + "execution_count": null, |
| 185 | + "id": "eb1db78e-e40a-4a5c-9d15-75ee2a1d0994", |
| 186 | + "metadata": {}, |
| 187 | + "outputs": [], |
| 188 | + "source": [ |
| 189 | + "es = ElasticsearchStore.from_documents(\n", |
| 190 | + " docs,\n", |
| 191 | + " es_cloud_id=ELASTIC_CLOUD_ID,\n", |
| 192 | + " es_api_key=ELASTIC_API_KEY,\n", |
| 193 | + " index_name=elastic_index_name,\n", |
| 194 | + " strategy=ElasticsearchStore.SparseVectorRetrievalStrategy(),\n", |
| 195 | + ")\n", |
| 196 | + "\n", |
| 197 | + "es" |
| 198 | + ] |
| 199 | + }, |
| 200 | + { |
| 201 | + "cell_type": "markdown", |
| 202 | + "id": "02b1ead9-c442-40e9-ba81-d4d286ea878b", |
| 203 | + "metadata": {}, |
| 204 | + "source": [ |
| 205 | + "## Hugging Face login" |
| 206 | + ] |
| 207 | + }, |
| 208 | + { |
| 209 | + "cell_type": "code", |
| 210 | + "execution_count": null, |
| 211 | + "id": "d2f651e4-e760-4b59-a8a3-57c58dfc229f", |
| 212 | + "metadata": {}, |
| 213 | + "outputs": [], |
| 214 | + "source": [ |
| 215 | + "from huggingface_hub import notebook_login\n", |
| 216 | + "\n", |
| 217 | + "notebook_login()" |
| 218 | + ] |
| 219 | + }, |
| 220 | + { |
| 221 | + "cell_type": "markdown", |
| 222 | + "id": "7454f551-71a9-4310-bb2a-3fe0e683daab", |
| 223 | + "metadata": {}, |
| 224 | + "source": [ |
| 225 | + "## Load the model and tokenizer (`google/gemma-2b-it`)" |
| 226 | + ] |
| 227 | + }, |
| 228 | + { |
| 229 | + "cell_type": "code", |
| 230 | + "execution_count": null, |
| 231 | + "id": "3e1d98eb-0f4e-4c41-a851-125b75502963", |
| 232 | + "metadata": {}, |
| 233 | + "outputs": [], |
| 234 | + "source": [ |
| 235 | + "model = AutoModelForCausalLM.from_pretrained(\"google/gemma-2b-it\")\n", |
| 236 | + "tokenizer = AutoTokenizer.from_pretrained(\"google/gemma-2b-it\")" |
| 237 | + ] |
| 238 | + }, |
| 239 | + { |
| 240 | + "cell_type": "markdown", |
| 241 | + "id": "11a12596-2ac1-4101-b189-d21d53d33b04", |
| 242 | + "metadata": {}, |
| 243 | + "source": [ |
| 244 | + "## Create a `text-generation` pipeline and wrap it as a LangChain LLM" |
| 245 | + ] |
| 246 | + }, |
| 247 | + { |
| 248 | + "cell_type": "code", |
| 249 | + "execution_count": null, |
| 250 | + "id": "623e74fb-5707-44f7-9dd8-d9499f7ab61e", |
| 251 | + "metadata": {}, |
| 252 | + "outputs": [], |
| 253 | + "source": [ |
| 254 | + "pipe = pipeline(\n", |
| 255 | + " \"text-generation\",\n", |
| 256 | + " model=model,\n", |
| 257 | + " tokenizer=tokenizer,\n", |
| 258 | + " max_new_tokens=1024,\n", |
| 259 | + ")\n", |
| 260 | + "\n", |
| 261 | + "llm = HuggingFacePipeline(\n", |
| 262 | + " pipeline=pipe,\n", |
| 263 | + " model_kwargs={\"temperature\": 0.7},\n", |
| 264 | + ")" |
| 265 | + ] |
| 266 | + }, |
| 267 | + { |
| 268 | + "cell_type": "markdown", |
| 269 | + "id": "49ce0e72-e419-4310-85e9-09077d6c40b2", |
| 270 | + "metadata": {}, |
| 271 | + "source": [ |
| 272 | + "## Format Docs" |
| 273 | + ] |
| 274 | + }, |
| 275 | + { |
| 276 | + "cell_type": "code", |
| 277 | + "execution_count": null, |
| 278 | + "id": "b3c07a75-9220-4a82-a92e-3fc2727ad3ba", |
| 279 | + "metadata": {}, |
| 280 | + "outputs": [], |
| 281 | + "source": [ |
| 282 | + "def format_docs(docs):\n", |
| 283 | + " return \"\\n\\n\".join(doc.page_content for doc in docs)" |
| 284 | + ] |
| 285 | + }, |
| 286 | + { |
| 287 | + "cell_type": "markdown", |
| 288 | + "id": "f6266222-6ec3-495a-8f14-460549bab89d", |
| 289 | + "metadata": {}, |
| 290 | + "source": [ |
| 291 | + "## Create a chain using a prompt template" |
| 292 | + ] |
| 293 | + }, |
| 294 | + { |
| 295 | + "cell_type": "code", |
| 296 | + "execution_count": null, |
| 297 | + "id": "ec203d1a-104b-4583-9ba1-a6b4b0354367", |
| 298 | + "metadata": {}, |
| 299 | + "outputs": [], |
| 300 | + "source": [ |
| 301 | + "retriever = es.as_retriever(search_kwargs={\"k\": 10})\n", |
| 302 | + "\n", |
| 303 | + "template = \"\"\"Answer the question based only on the following context:\\n\n", |
| 304 | + "\n", |
| 305 | + "{context}\n", |
| 306 | + "\n", |
| 307 | + "Question: {question}\n", |
| 308 | + "\"\"\"\n", |
| 309 | + "prompt = ChatPromptTemplate.from_template(template)\n", |
| 310 | + "\n", |
| 311 | + "\n", |
| 312 | + "chain = (\n", |
| 313 | + " {\"context\": retriever | format_docs, \"question\": RunnablePassthrough()}\n", |
| 314 | + " | prompt\n", |
| 315 | + " | llm\n", |
| 316 | + " | StrOutputParser()\n", |
| 317 | + ")" |
| 318 | + ] |
| 319 | + }, |
| 320 | + { |
| 321 | + "cell_type": "markdown", |
| 322 | + "id": "8ae892dd-7442-4d4d-a804-1d717266e596", |
| 323 | + "metadata": {}, |
| 324 | + "source": [ |
| 325 | + "## Ask a question" |
| 326 | + ] |
| 327 | + }, |
| 328 | + { |
| 329 | + "cell_type": "code", |
| 330 | + "execution_count": 11, |
| 331 | + "id": "ba312f17-44ae-423d-89a0-ea01eccd85b5", |
| 332 | + "metadata": {}, |
| 333 | + "outputs": [ |
| 334 | + { |
| 335 | + "data": { |
| 336 | + "text/plain": [ |
| 337 | + "'Answer: The pet policy in the office allows employees to bring pets to the office, subject to approval by the HR department. Pets covered under this policy include dogs, cats, and other small, non-exotic animals, subject to approval by the HR department.'" |
| 338 | + ] |
| 339 | + }, |
| 340 | + "execution_count": 11, |
| 341 | + "metadata": {}, |
| 342 | + "output_type": "execute_result" |
| 343 | + } |
| 344 | + ], |
| 345 | + "source": [ |
| 346 | + "chain.invoke(\"What is the pet policy in the office?\")" |
| 347 | + ] |
| 348 | + } |
| 349 | + ], |
| 350 | + "metadata": { |
| 351 | + "kernelspec": { |
| 352 | + "display_name": "Python 3 (ipykernel)", |
| 353 | + "language": "python", |
| 354 | + "name": "python3" |
| 355 | + }, |
| 356 | + "language_info": { |
| 357 | + "codemirror_mode": { |
| 358 | + "name": "ipython", |
| 359 | + "version": 3 |
| 360 | + }, |
| 361 | + "file_extension": ".py", |
| 362 | + "mimetype": "text/x-python", |
| 363 | + "name": "python", |
| 364 | + "nbconvert_exporter": "python", |
| 365 | + "pygments_lexer": "ipython3", |
| 366 | + "version": "3.11.4" |
| 367 | + } |
| 368 | + }, |
| 369 | + "nbformat": 4, |
| 370 | + "nbformat_minor": 5 |
| 371 | +} |
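
To make the chain in the notebook easier to follow, here is a minimal, dependency-free sketch of roughly what happens between retrieval and generation: the retrieved passages are joined into one context string and substituted into the prompt template. The `Passage` class, the `build_prompt` helper, and the sample passages are hypothetical stand-ins introduced only for illustration; they are not part of the notebook or of LangChain. LangChain's `ChatPromptTemplate` may also add message-role formatting that this sketch omits.

```python
from dataclasses import dataclass


@dataclass
class Passage:
    """Hypothetical stand-in for a retrieved document with a page_content field."""

    page_content: str


# Same wording as the notebook's prompt template.
TEMPLATE = (
    "Answer the question based only on the following context:\n\n"
    "{context}\n\n"
    "Question: {question}\n"
)


def format_docs(docs):
    # Mirrors the notebook's helper: join passages with a blank line between them.
    return "\n\n".join(doc.page_content for doc in docs)


def build_prompt(question, passages):
    # Hypothetical helper: fill the template with the joined passages and the question.
    return TEMPLATE.format(context=format_docs(passages), question=question)


# Made-up passages for illustration; not taken from the sample dataset.
passages = [
    Passage("Pets are allowed in the office with HR approval."),
    Passage("Covered pets include dogs and cats."),
]

prompt_text = build_prompt("What is the pet policy in the office?", passages)
print(prompt_text)
```

In the notebook itself, the retriever (with `k` set to 10) supplies the passages, and the resulting text is passed to the Gemma pipeline rather than printed.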