Commit b28ddce

Authored by caseyclements, WaVEV, and Jibola
[Feature] Addition of MongoDB Atlas datastore (#428)
* docker compose file.
* search example.
* mongodb atlas datastore.
* refactor, docstring and notebook cleaning.
* docstring.
* fix attributes names.
* Functional tests.
* Example adjustement.
* setup.md
* remove some useless comments.
* wrong docker image.
* Minor documentation fixes.
* Update example.
* refactor.
* default as a default collection.
* TODO resolved.
* Refactor delete.
* fix readme and setup.md
* add warning when delete without criteria.
* rename private function.
* replace pymongo to motor and fix integration test.
* Refactor code and adjust tests
* wait for assert function.
* Update docs/providers/mongodb_atlas/setup.md (Co-authored-by: Jib <[email protected]>)
* Update datastore/providers/mongodb_atlas_datastore.py (Co-authored-by: Jib <[email protected]>)
* Increase oversampling factor to 10.
* Update tests/datastore/providers/mongodb_atlas/test_mongodb_datastore.py (Co-authored-by: Jib <[email protected]>)
* Update tests/datastore/providers/mongodb_atlas/test_mongodb_datastore.py (Co-authored-by: Jib <[email protected]>)
* Update datastore/providers/mongodb_atlas_datastore.py (Co-authored-by: Jib <[email protected]>)
* Init docstring.
* Default parameters
* Update datastore/providers/mongodb_atlas_datastore.py (Co-authored-by: Jib <[email protected]>)
* refactor sample_embeddings.
* Apply suggestions from code review (Co-authored-by: Jib <[email protected]>)
* refactor delete.
* Version added.
* Update datastore/providers/mongodb_atlas_datastore.py (Co-authored-by: Jib <[email protected]>)
* Removed _atlas from folder name to keep it simple and self-consistent
* Expanded setup.md
* Fixed a couple typos in docstrings
* Add optional EMBEDDING_DIMENSION to get_embedding
* Fixed typo in kwarg
* Extended setup.md
* Edits to environment variable table
* Added authentication token descriptions
* Removed hardcoded vector size
* Added semantic search example
* Added instructions to integration tests
* Cleanup
* Removed pathname from example.
* Override DataStore.upsert in MongoDBAtlasDataStore to increase performance.
* upsert now returns ids of chunks, which is what each datastore document is
* Added full integration test
* test_integration now uses FastAPI TestClient
* Retries query until response contains number requested

---------

Co-authored-by: Emanuel Lupi <[email protected]>
Co-authored-by: Jib <[email protected]>
1 parent b808c10 commit b28ddce

11 files changed: +1477 −6 lines changed

README.md

Lines changed: 13 additions & 2 deletions
@@ -42,6 +42,7 @@ This README provides detailed information on how to set up, develop, and deploy
 - [Choosing a Vector Database](#choosing-a-vector-database)
   - [Pinecone](#pinecone)
   - [Elasticsearch](#elasticsearch)
+  - [MongoDB Atlas](#mongodb-atlas)
   - [Weaviate](#weaviate)
   - [Zilliz](#zilliz)
   - [Milvus](#milvus)
@@ -190,6 +191,12 @@ Follow these steps to quickly set up and run the ChatGPT Retrieval Plugin:
 export ELASTICSEARCH_INDEX=<elasticsearch_index_name>
 export ELASTICSEARCH_REPLICAS=<elasticsearch_replicas>
 export ELASTICSEARCH_SHARDS=<elasticsearch_shards>
+
+# MongoDB Atlas
+export MONGODB_URI=<mongodb_uri>
+export MONGODB_DATABASE=<mongodb_database>
+export MONGODB_COLLECTION=<mongodb_collection>
+export MONGODB_INDEX=<mongodb_index>
 ```

 10. Run the API locally: `poetry run start`
@@ -352,8 +359,8 @@ poetry install
 The API requires the following environment variables to work:

 | Name | Required | Description |
-| ---------------- | -------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| `DATASTORE` | Yes | This specifies the vector database provider you want to use to store and query embeddings. You can choose from `elasticsearch`, `chroma`, `pinecone`, `weaviate`, `zilliz`, `milvus`, `qdrant`, `redis`, `azuresearch`, `supabase`, `postgres`, `analyticdb`. |
+| ---------------- | -------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `DATASTORE` | Yes | This specifies the vector database provider you want to use to store and query embeddings. You can choose from `elasticsearch`, `chroma`, `pinecone`, `weaviate`, `zilliz`, `milvus`, `qdrant`, `redis`, `azuresearch`, `supabase`, `postgres`, `analyticdb`, `mongodb-atlas`. |
 | `BEARER_TOKEN` | Yes | This is a secret token that you need to authenticate your requests to the API. You can generate one using any tool or method you prefer, such as [jwt.io](https://jwt.io/). |
 | `OPENAI_API_KEY` | Yes | This is your OpenAI API key that you need to generate embeddings using one of the OpenAI embeddings models. You can get an API key by creating an account on [OpenAI](https://openai.com/). |
@@ -434,6 +441,10 @@ For detailed setup instructions, refer to [`/docs/providers/llama/setup.md`](/do

 [Elasticsearch](https://www.elastic.co/guide/en/elasticsearch/reference/current/index.html) currently supports storing vectors through the `dense_vector` field type and uses them to calculate document scores. Elasticsearch 8.0 builds on this functionality to support fast, approximate nearest neighbor search (ANN). This represents a much more scalable approach, allowing vector search to run efficiently on large datasets. For detailed setup instructions, refer to [`/docs/providers/elasticsearch/setup.md`](/docs/providers/elasticsearch/setup.md).

+#### MongoDB Atlas
+
+[MongoDB Atlas](https://www.mongodb.com/docs/atlas/getting-started/) supports vector search through Atlas Vector Search indexes, which can currently be created on any collection holding vector embeddings of 2048 dimensions or fewer, alongside the other data on your Atlas cluster. Index creation is performed through the Atlas UI or the Atlas Administration API. For detailed setup instructions, refer to [`/docs/providers/mongodb_atlas/setup.md`](/docs/providers/mongodb_atlas/setup.md).
+
 ### Running the API locally

 To run the API locally, you first need to set the requisite environment variables with the `export` command:
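The Atlas setup described above boils down to a single JSON index definition. Below is a minimal sketch of that definition in Python, derived from the values in the new datastore module (the `embedding` path, 1536 dimensions, and `dotProduct` similarity come from its commented-out `_create_search_index` call); treat the exact shape as an assumption to verify against the Atlas documentation, and name the index to match `MONGODB_INDEX`.

```python
# Minimal Atlas Vector Search index definition matching the fields the new
# datastore writes: each stored chunk carries its vector under "embedding".
# The values mirror the commented-out _create_search_index call in
# mongodb_atlas_datastore.py; adjust numDimensions to your embedding model.
import json

index_definition = {
    "fields": [
        {
            "type": "vector",
            "path": "embedding",        # field queried by $vectorSearch
            "numDimensions": 1536,      # embedding width assumed by the plugin
            "similarity": "dotProduct", # similarity metric used for scoring
        }
    ]
}

# Paste this JSON into the Atlas UI (or send it via the Atlas Administration
# API) when creating a vector search index named after MONGODB_INDEX.
print(json.dumps(index_definition, indent=2))
```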

datastore/datastore.py

Lines changed: 2 additions & 2 deletions
@@ -44,7 +44,7 @@ async def upsert(
     @abstractmethod
     async def _upsert(self, chunks: Dict[str, List[DocumentChunk]]) -> List[str]:
         """
-        Takes in a list of list of document chunks and inserts them into the database.
+        Takes in a list of document chunks and inserts them into the database.
         Return a list of document ids.
         """

@@ -54,7 +54,7 @@ async def query(self, queries: List[Query]) -> List[QueryResult]:
         """
         Takes in a list of queries and filters and returns a list of query results with matching document chunks and scores.
         """
-        # get a list of of just the queries from the Query list
+        # get a list of just the queries from the Query list
         query_texts = [query.query for query in queries]
         query_embeddings = get_embeddings(query_texts)
         # hydrate the queries with embeddings
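For context, the comment fixed above sits in `DataStore.query` just before the queries are paired with their embeddings and handed to the provider's `_query`. A condensed sketch of that hydration step, assuming `QueryWithEmbedding` extends `Query` with an `embedding` field and that `get_embeddings` lives in `services.openai`, as elsewhere in this repo:

```python
# Condensed sketch of the hydration step in DataStore.query (not the verbatim
# file): each plain Query is paired with the embedding of its query text.
from models.models import Query, QueryWithEmbedding
from services.openai import get_embeddings

def hydrate(queries: list[Query]) -> list[QueryWithEmbedding]:
    # get a list of just the queries from the Query list
    query_texts = [query.query for query in queries]
    query_embeddings = get_embeddings(query_texts)
    # hydrate the queries with embeddings
    return [
        QueryWithEmbedding(**query.dict(), embedding=embedding)
        for query, embedding in zip(queries, query_embeddings)
    ]
```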

datastore/factory.py

Lines changed: 6 additions & 0 deletions
@@ -68,6 +68,12 @@ async def get_datastore() -> DataStore:
             )

             return ElasticsearchDataStore()
+        case "mongodb":
+            from datastore.providers.mongodb_atlas_datastore import (
+                MongoDBAtlasDataStore,
+            )
+
+            return MongoDBAtlasDataStore()
         case _:
             raise ValueError(
                 f"Unsupported vector database: {datastore}. "
datastore/providers/mongodb_atlas_datastore.py

Lines changed: 252 additions & 0 deletions
@@ -0,0 +1,252 @@
import os
from typing import Dict, List, Any, Optional
from loguru import logger
from importlib.metadata import version
from motor.motor_asyncio import AsyncIOMotorClient
from pymongo.driver_info import DriverInfo
from pymongo import UpdateOne

from datastore.datastore import DataStore
from functools import cached_property
from models.models import (
    Document,
    DocumentChunk,
    DocumentChunkWithScore,
    DocumentMetadataFilter,
    QueryResult,
    QueryWithEmbedding,
)
from services.chunks import get_document_chunks
from services.date import to_unix_timestamp


MONGODB_CONNECTION_URI = os.environ.get("MONGODB_URI")
MONGODB_DATABASE = os.environ.get("MONGODB_DATABASE", "default")
MONGODB_COLLECTION = os.environ.get("MONGODB_COLLECTION", "default")
MONGODB_INDEX = os.environ.get("MONGODB_INDEX", "default")
OVERSAMPLING_FACTOR = 10
MAX_CANDIDATES = 10_000


class MongoDBAtlasDataStore(DataStore):

    def __init__(
        self,
        atlas_connection_uri: str = MONGODB_CONNECTION_URI,
        index_name: str = MONGODB_INDEX,
        database_name: str = MONGODB_DATABASE,
        collection_name: str = MONGODB_COLLECTION,
        oversampling_factor: float = OVERSAMPLING_FACTOR,
    ):
        """
        Initialize a MongoDBAtlasDataStore instance.

        Parameters:
        - atlas_connection_uri (str, optional): Connection URI to the Atlas cluster.
          If not provided, the MONGODB_URI environment variable is used.
        - index_name (str, optional): Vector search index. If not provided, the default index name is used.
        - database_name (str, optional): Database. If not provided, the default database name is used.
        - collection_name (str, optional): Collection. If not provided, the default collection name is used.
        - oversampling_factor (float, optional): Multiplier applied to top_k to set the number of
          candidates examined during vector search. Default is OVERSAMPLING_FACTOR.

        Raises:
        - ValueError: If index_name is not a valid string.

        Attributes:
        - index_name (str): Name of the index.
        - database_name (str): Name of the database.
        - collection_name (str): Name of the collection.
        - oversampling_factor (float): Oversampling factor for vector search candidates.
        """
        self.atlas_connection_uri = atlas_connection_uri
        self.oversampling_factor = oversampling_factor
        self.database_name = database_name
        self.collection_name = collection_name

        if not (index_name and isinstance(index_name, str)):
            raise ValueError("Provide a valid index name")
        self.index_name = index_name

        # TODO: Create index via driver https://jira.mongodb.org/browse/PYTHON-4175
        # self._create_search_index(num_dimensions=1536, path="embedding", similarity="dotProduct", type="vector")

    @cached_property
    def client(self):
        # Lazily connect using the URI supplied to the constructor.
        return self._connect_to_mongodb_atlas(
            atlas_connection_uri=self.atlas_connection_uri
        )

    async def upsert(
        self, documents: List[Document], chunk_token_size: Optional[int] = None
    ) -> List[str]:
        """
        Takes in a list of Documents, chunks them, and upserts the chunks into the database.
        Return a list of the ids of the document chunks.
        """
        chunks = get_document_chunks(documents, chunk_token_size)
        return await self._upsert(chunks)

    async def _upsert(self, chunks: Dict[str, List[DocumentChunk]]) -> List[str]:
        """
        Takes in a list of document chunks and inserts them into the database.
        Return a list of document ids.
        """
        documents_to_upsert = []
        inserted_ids = []
        for chunk_list in chunks.values():
            for chunk in chunk_list:
                inserted_ids.append(chunk.id)
                documents_to_upsert.append(
                    UpdateOne({"_id": chunk.id}, {"$set": chunk.dict()}, upsert=True)
                )
        logger.info(f"Upsert documents into MongoDB collection: {self.database_name}: {self.collection_name}")
        await self.client[self.database_name][self.collection_name].bulk_write(documents_to_upsert)
        logger.info("Upsert successful")

        return inserted_ids

    async def _query(
        self,
        queries: List[QueryWithEmbedding],
    ) -> List[QueryResult]:
        """
        Takes in a list of queries with embeddings and filters and returns
        a list of query results with matching document chunks and scores.
        """
        results = []
        for query in queries:
            query_result = await self._execute_embedding_query(query)
            results.append(query_result)

        return results

    async def _execute_embedding_query(self, query: QueryWithEmbedding) -> QueryResult:
        """
        Execute a MongoDB query using vector search on the specified collection and
        return the result of the query, including matched documents and their scores.
        """
        pipeline = [
            {
                "$vectorSearch": {
                    "index": self.index_name,
                    "path": "embedding",
                    "queryVector": query.embedding,
                    # Examine more candidates than requested to improve recall,
                    # capped so large top_k values stay within server limits.
                    "numCandidates": min(query.top_k * self.oversampling_factor, MAX_CANDIDATES),
                    "limit": query.top_k,
                }
            },
            {
                "$project": {
                    "text": 1,
                    "metadata": 1,
                    "score": {"$meta": "vectorSearchScore"},
                }
            },
        ]

        async with self.client[self.database_name][self.collection_name].aggregate(pipeline) as cursor:
            results = [
                self._convert_mongodb_document_to_document_chunk_with_score(doc)
                async for doc in cursor
            ]

        return QueryResult(
            query=query.query,
            results=results,
        )

    async def delete(
        self,
        ids: Optional[List[str]] = None,
        filter: Optional[DocumentMetadataFilter] = None,
        delete_all: Optional[bool] = None,
    ) -> bool:
        """
        Removes documents by ids, filter, or everything in the datastore.
        Returns whether the operation was successful.

        Note that ids refer to those in the datastore,
        which are those of the **DocumentChunks**.
        """
        # Delete all documents from the collection if delete_all is True
        if delete_all:
            logger.info("Deleting all documents from collection")
            mg_filter = {}

        # Delete by ids
        elif ids:
            logger.info(f"Deleting documents with ids: {ids}")
            mg_filter = {"_id": {"$in": ids}}

        # Delete by filters
        elif filter:
            mg_filter = self._build_mongo_filter(filter)
            logger.info(f"Deleting documents with filter: {mg_filter}")

        # Do nothing
        else:
            logger.warning(
                f"No criteria set; nothing to delete. args: ids: {ids}, filter: {filter}, delete_all: {delete_all}"
            )
            return True

        try:
            await self.client[self.database_name][self.collection_name].delete_many(mg_filter)
            logger.info("Deleted documents successfully")
        except Exception as e:
            logger.error(f"Error deleting documents with filter: {mg_filter} -- error: {e}")
            return False

        return True

    def _convert_mongodb_document_to_document_chunk_with_score(
        self, document: Dict
    ) -> DocumentChunkWithScore:
        # Convert MongoDB document to DocumentChunkWithScore
        return DocumentChunkWithScore(
            id=document.get("_id"),
            text=document["text"],
            metadata=document.get("metadata"),
            score=document.get("score"),
        )

    def _build_mongo_filter(
        self, filter: Optional[DocumentMetadataFilter] = None
    ) -> Dict[str, Any]:
        """
        Generate MongoDB query filters based on the provided DocumentMetadataFilter.
        """
        if filter is None:
            return {}

        mongo_filters = {
            "$and": [],
        }

        # For each field in the MetadataFilter,
        # check if it has a value and add the corresponding MongoDB filter expression
        for field, value in filter.dict().items():
            if value is not None:
                if field == "start_date":
                    mongo_filters["$and"].append(
                        {"created_at": {"$gte": to_unix_timestamp(value)}}
                    )
                elif field == "end_date":
                    mongo_filters["$and"].append(
                        {"created_at": {"$lte": to_unix_timestamp(value)}}
                    )
                else:
                    mongo_filters["$and"].append(
                        {f"metadata.{field}": value}
                    )

        return mongo_filters

    @staticmethod
    def _connect_to_mongodb_atlas(atlas_connection_uri: str):
        """
        Establish a connection to MongoDB Atlas.
        """
        # Attach driver metadata so Atlas can identify the integration.
        client = AsyncIOMotorClient(
            atlas_connection_uri,
            driver=DriverInfo(name="Chatgpt Retrieval Plugin", version=version("chatgpt_retrieval_plugin")),
        )
        return client
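Putting the pieces together, a minimal usage sketch follows. It assumes a reachable Atlas cluster with the vector index above already built, the `MONGODB_*` and `OPENAI_API_KEY` environment variables exported, and the `Document` and `Query` models as defined in `models.models`.

```python
# Minimal end-to-end sketch: upsert one document, then run a semantic query.
# Assumes MONGODB_URI points at an Atlas cluster whose collection has the
# vector search index described above, and OPENAI_API_KEY is set so that
# embeddings can be generated.
import asyncio

from datastore.providers.mongodb_atlas_datastore import MongoDBAtlasDataStore
from models.models import Document, Query

async def main():
    datastore = MongoDBAtlasDataStore()

    # Chunk, embed, and upsert; returns the ids of the stored DocumentChunks.
    chunk_ids = await datastore.upsert(
        [Document(id="doc1", text="MongoDB Atlas supports approximate vector search.")]
    )
    print(f"Upserted chunks: {chunk_ids}")

    # Embeds the query text, then runs $vectorSearch under the hood.
    results = await datastore.query([Query(query="vector search on Atlas", top_k=3)])
    for chunk in results[0].results:
        print(chunk.score, chunk.text)

asyncio.run(main())
```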
