
Commit 0ac7f3d

Mirko Ortensi authored and committed
first commit
1 parent 7618505 · commit 0ac7f3d


57 files changed, +12225 -1 lines changed

Dockerfile

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
FROM docker.io/python:3.9
WORKDIR /app
COPY src /app/src/
COPY requirements.txt wsgi.py /app/

RUN pip install --no-cache-dir -r requirements.txt

ENV GUNICORN_CMD_ARGS="--workers 1 --bind 0.0.0.0:8000 --timeout 600 --log-level debug --capture-output --error-logfile ./gunicorn.log"
ENV PYTHONUNBUFFERED=1
EXPOSE 8000

CMD [ "gunicorn", "wsgi:create_app()" ]
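The `CMD` hands gunicorn the application factory `wsgi:create_app()`. The actual `wsgi.py` is part of the commit but not rendered in this excerpt; the following is only a hedged sketch of what such a factory might look like, assuming a standard Flask application-factory layout and the flask_restx `Api` declared in `src/apis/__init__.py` below.

```python
# Hypothetical sketch of wsgi.py (the real file is not shown here). It assumes
# the flask_restx Api object defined in src/apis/__init__.py in this commit.
from flask import Flask


def create_app():
    app = Flask(__name__)
    from src.apis import api  # flask_restx Api instance declared in this commit
    api.init_app(app)         # mounts the REST API under its '/api' prefix
    return app


if __name__ == "__main__":
    # Local development fallback; in the container, gunicorn calls create_app().
    create_app().run(host="0.0.0.0", port=8000)
```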

README.md

Lines changed: 51 additions & 1 deletion
@@ -1 +1,51 @@
# MiniPilot

This application implements a chatbot that you can train with your own data. The example provided is a movie recommender system.

![demo](src/static/images/minipilot.gif)

The system uses:

- Redis Stack as a vector database to store the dataset, vectorize the entries, and perform [Vector Similarity Search (VSS)](https://redis.io/docs/latest/develop/interact/search-and-query/advanced-concepts/vectors/) for RAG (a minimal query sketch follows this list)
- The [IMDB movies dataset](https://www.kaggle.com/datasets/ashpalsingh1525/imdb-movies-dataset), which contains 10,000+ movies
- The OpenAI ChatGPT Large Language Model (LLM) [ChatCompletion API](https://platform.openai.com/docs/guides/gpt/chat-completions-api)
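
For illustration only (this snippet is not part of the commit), here is a minimal redis-py sketch of the kind of KNN query that VSS runs over the indexed movies. The field name `content_vector` is LangChain's default Redis vector field, and the index name is a placeholder for the timestamped index created by `initialize.py`; both are assumptions.

```python
import numpy as np
import redis
from langchain_community.embeddings import OpenAIEmbeddings
from redis.commands.search.query import Query

# Embed the query text with the same embedding model used at indexing time.
query_vector = np.array(OpenAIEmbeddings().embed_query("a space adventure movie"),
                        dtype=np.float32)

# KNN query against the vector field; "content_vector" is LangChain's default
# field name and the index name below is a placeholder (assumptions, see above).
knn = (
    Query("*=>[KNN 3 @content_vector $vec AS vector_score]")
    .sort_by("vector_score")
    .return_fields("names", "vector_score")
    .dialect(2)
)
conn = redis.Redis(host="127.0.0.1", port=6379)
results = conn.ft("minipilot_rag_2024_01_01_00_00_00_idx").search(
    knn, query_params={"vec": query_vector.tobytes()})
for doc in results.docs:
    print(doc.names, doc.vector_score)
```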

## Setup

Clone the repository:

```commandline
git clone https://github.com/mortensi/MiniPilot.git
```

Make sure you have an [OpenAI token](https://openai.com/api/pricing/), then install the requirements:

```commandline
pip install -r requirements.txt
```

Then set the environment variables in a `.env` file.

```commandline
DB_SERVICE="127.0.0.1"
DB_PORT=6379
DB_PWD=""

MINIPILOT_DEBUG="True"
MINIPILOT_MODEL="gpt-3.5-turbo-16k"

OPENAI_API_KEY="your-openai-key"
```

You can also use the `export` command.

```commandline
export DB_SERVICE="127.0.0.1" DB_PORT=6379 DB_PWD="" MINIPILOT_DEBUG="True" MINIPILOT_MODEL="gpt-3.5-turbo-16k" OPENAI_API_KEY="your-openai-key"
```
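
These variables are read with `os.getenv` at runtime (see `initialize.py` below). As a hedged illustration, not code from this commit, they could be loaded from the `.env` file with python-dotenv, which is pinned in `requirements.txt`:

```python
import os

from dotenv import load_dotenv

# Sketch only: load the .env file described above, falling back to the same
# defaults that initialize.py uses for the Redis connection settings.
load_dotenv()

redis_host = os.getenv("DB_SERVICE", "127.0.0.1")
redis_port = int(os.getenv("DB_PORT", 6379))
redis_password = os.getenv("DB_PWD", "")
model_name = os.getenv("MINIPILOT_MODEL", "gpt-3.5-turbo-16k")
debug = os.getenv("MINIPILOT_DEBUG", "False") == "True"
openai_api_key = os.environ["OPENAI_API_KEY"]  # required, no default
```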

Start the server with `./start.sh`.

Load the data with `python3 initialize.py`.

Point your browser to `http://127.0.0.1:5005/` and start asking questions.

imdb_movies.csv

Lines changed: 10179 additions & 0 deletions
Large diffs are not rendered by default.

initialize.py

Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
import csv
import os
from datetime import datetime
import logging

import redis
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores.redis import Redis


def generate_redis_connection_string():
    if os.getenv('DB_PWD', ''):
        connection_string = f"redis://:{os.getenv('DB_PWD', '')}@{os.getenv('DB_SERVICE', '127.0.0.1')}:{int(os.getenv('DB_PORT', 6379))}"
    else:
        connection_string = f"redis://{os.getenv('DB_SERVICE', '127.0.0.1')}:{int(os.getenv('DB_PORT', 6379))}"

    return connection_string


def load():
    conn = redis.from_url(generate_redis_connection_string())

    # Create a new, timestamped index
    index_name = f"minipilot_rag_{datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}_idx"

    index_schema = {
        "tag": [{"name": "genre"},
                {"name": "country"}],
        "text": [{"name": "names"}],
        "numeric": [{"name": "revenue"},
                    {"name": "score"},
                    {"name": "date_x"}]
    }

    vector_schema = {
        "algorithm": "HNSW"
    }

    # If no index exists for RAG yet, this is the first one: the alias must be pointed to it manually once indexing is done
    try:
        conn.ft('convai_rag_alias').info()
    except redis.exceptions.ResponseError:
        logging.warning("No alias exists for semantic search. Create the alias when indexing is done")

    # Validate that an OPENAI_API_KEY is available in the environment
    try:
        embedding_model = OpenAIEmbeddings()
    except Exception as e:
        logging.error(e)
        return

    doc_splitter = RecursiveCharacterTextSplitter(chunk_size=10000,
                                                  chunk_overlap=50,
                                                  length_function=len,
                                                  add_start_index=True)

    with open("imdb_movies.csv", encoding='utf-8') as csvf:
        csvReader = csv.DictReader(csvf)
        cnt = 0
        for row in csvReader:
            movie = f"movie title is: {row['names']}\n"
            movie += f"movie genre is: {row['genre']}\n"
            movie += f"movie crew is: {row['crew']}\n"
            movie += f"movie score is: {row['score']}\n"
            movie += f"movie overview is: {row['overview']}\n"
            movie += f"movie country is: {row['country']}\n"
            movie += f"movie revenue is: {row['revenue']}\n"

            cnt += 1
            splits = doc_splitter.split_text(row['overview'])
            unix_timestamp = int(datetime.strptime(row['date_x'].strip(), "%m/%d/%Y").timestamp())
            metadatas = {"names": row['names'],
                         "genre": row['genre'],
                         "country": row['country'],
                         "revenue": row['revenue'],
                         "score": row['score'],
                         "date_x": unix_timestamp}

            if len(splits) > 0:
                Redis.from_texts(texts=splits,
                                 metadatas=[metadatas] * len(splits),
                                 embedding=embedding_model,
                                 index_name=index_name,
                                 index_schema=index_schema,
                                 vector_schema=vector_schema,
                                 redis_url=generate_redis_connection_string())

load()
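
The script only warns that the alias has to be created once indexing is done. As an illustrative assumption (not code from this commit), redis-py's `FT.ALIASUPDATE` wrapper can point the alias checked above at the index that `load()` just built:

```python
import redis

# Hedged sketch: reuse generate_redis_connection_string() from initialize.py
# and point the search alias at the freshly built, timestamped index. The
# index name below is a placeholder for the one created by load().
conn = redis.from_url(generate_redis_connection_string())
conn.ft("minipilot_rag_2024_01_01_00_00_00_idx").aliasupdate("convai_rag_alias")
```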

requirements.txt

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
beautifulsoup4==4.12.3
Flask==3.0.1
flask_cors==4.0.0
flask_paginate==2023.10.24
flask-restx==1.3.0
flask_session==0.6.0
gunicorn==21.2.0
langchain==0.1.4
langchain_community==0.0.16
langchain_core==0.1.17
openai==1.10.0
python-dotenv==1.0.1
redis==5.0.1
redisvl==0.0.7
Requests==2.31.0
scrapy==2.11.0
spacy==3.7.4
spacy-legacy==3.0.12
tiktoken==0.5.2
WTForms==3.1.2
wtforms_json==0.3.5
sentence-transformers==2.3.0
tenacity==8.2.2
validators==0.22.0

src/.DS_Store

6 KB
Binary file not shown.

src/__init__.py

Whitespace-only changes.

src/_version.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
__version__ = "0.9.0"

src/apis/__init__.py

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
from flask_restx import Api

from .service import api as ns_service

authorizations = {
    'api_key': {
        'type': 'apiKey',
        'in': 'header',
        'name': 'admin-token'
    }
}

api = Api(
    title='Minipilot Server REST API',
    version='1.0',
    description='Welcome to the Minipilot Server REST API. Use this API to train your GenAI chatbot with online docs',
    doc='/api',
    prefix='/api',
    authorizations=authorizations,
    security='api_key'
)

api.add_namespace(ns_service)
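
The `authorizations` block declares an `api_key` scheme carried in an `admin-token` request header, and `security='api_key'` makes it the default scheme in the Swagger UI served at `/api`. A hedged sketch of supplying that header from a client follows; the base URL, port, and token value are illustrative assumptions.

```python
import requests

# Illustrative only: the 'admin-token' header name comes from the
# authorizations dict above; the URL, port, and token value are placeholders.
headers = {"admin-token": "your-admin-token", "session-id": "demo-session"}
resp = requests.get("http://127.0.0.1:5005/api/history", headers=headers)
print(resp.status_code, resp.json())
```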

src/apis/service.py

Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
import urllib

from flask import request, jsonify, Response
from flask_paginate import Pagination
from flask_restx import Resource, Namespace, reqparse, inputs
from langchain_community.chat_message_histories import RedisChatMessageHistory
from redis.commands.search.query import Query

from src.apis.validation import rate_limiter
from src.common.config import MINIPILOT_SEARCH_RESULTS, REDIS_CFG, MINIPILOT_HISTORY_TIMEOUT
from src.common.utils import get_db, parse_query_string, extract_keywords, generate_redis_connection_string, \
    history_to_json
from src.core.RedisRetrievalChain import RedisRetrievalChain

api = Namespace('Services', path="/", description='Chat and search services')


def min_length(min_len):
    def validate(s):
        if len(s) < min_len:
            raise ValueError(f'Minimum length is {min_len}')
        return s
    return validate


def validate_length(min_len, max_len):
    def validate(s):
        if len(s) < min_len:
            raise ValueError(f'Minimum length is {min_len}')
        if len(s) > max_len:
            raise ValueError(f'Maximum length is {max_len}')
        return s
    return validate


@api.route('/history')
class ChatHistory(Resource):
    @api.doc(params={'session-id': {'in': 'header', 'description': 'session-id'}})
    @api.doc(description='Get user conversation history', consumes=['application/json'])
    def get(self):
        """Get user conversation history"""
        session_id = str(request.headers.get("session-id"))
        redis_history = RedisChatMessageHistory(url=generate_redis_connection_string(REDIS_CFG["host"], REDIS_CFG["port"], REDIS_CFG["password"]),
                                                session_id=session_id,
                                                key_prefix='minipilot:history:',
                                                ttl=MINIPILOT_HISTORY_TIMEOUT)
        return history_to_json(redis_history.messages), 200


@api.route('/reset')
class ChatHistoryReset(Resource):
    @api.doc(params={'session-id': {'in': 'header', 'description': 'session-id'}})
    @api.doc(description='Reset user conversation history', consumes=['application/json'])
    def post(self):
        """Reset user conversation history"""
        session_id = str(request.headers.get("session-id"))
        engine = RedisRetrievalChain(session_id)
        engine.reset_history()
        return {"response": "Conversation restarted"}, 200


@api.route('/chat')
class Chat(Resource):
    service_query_parser = reqparse.RequestParser()
    service_query_parser.add_argument('q', type=validate_length(4, 500), required=True, help='Chat query', location='args')

    @api.expect(service_query_parser)
    @api.doc(params={'session-id': {'in': 'header', 'description': 'session-id'}})
    @rate_limiter(request)
    @api.doc(description='Ask a question in natural language: will answer, post the answer to the history and semantic cache', consumes=['application/json'])
    def post(self):
        """Ask a question in natural language: will answer, post the answer to the history and semantic cache"""
        args = self.service_query_parser.parse_args(req=request)
        session_id = str(request.headers.get("session-id"))

        engine = RedisRetrievalChain(session_id)
        engine.ask(args['q'])
        return Response(engine.streamer(), content_type="text/event-stream", headers={'X-Accel-Buffering': 'no'})


@api.route('/references')
class SearchReferences(Resource):
    service_query_parser = reqparse.RequestParser()
    service_query_parser.add_argument('q', type=str, required=True, help='References query', location='args')

    @api.expect(service_query_parser)
    @api.doc(description='Semantic references from a natural language query', consumes=['application/json'])
    def get(self):
        """Semantic references from a natural language query"""
        args = self.service_query_parser.parse_args(req=request)
        # This method is session-less and just performs a vector search; we reuse the RedisRetrievalChain utility
        # with a placeholder session id. TODO: clean up and use a new session-less constructor, or RedisVL
        session_id = "xxxxxxx"
        engine = RedisRetrievalChain(session_id)
        references = engine.references(urllib.parse.unquote(args['q']))
        return references, 200
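
Putting the pieces together, here is a hedged end-to-end example of calling the streaming `/chat` endpoint. This client is not part of the commit; the port follows the README's development setup on 5005 (8000 applies when running the Dockerfile's gunicorn command), and the base URL is an assumption.

```python
import requests

# Illustrative client for the POST /api/chat endpoint defined above. The query
# parameter 'q' and the 'session-id' header match the service.py definitions;
# the base URL and port are assumptions.
resp = requests.post(
    "http://127.0.0.1:5005/api/chat",
    params={"q": "Recommend a science fiction movie about space travel"},
    headers={"session-id": "demo-session"},
    stream=True,  # the endpoint returns a text/event-stream response
)
for line in resp.iter_lines():
    if line:
        print(line.decode())
```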
