-
Notifications
You must be signed in to change notification settings - Fork 518
Add Chroma Retrieval #323 #324
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,7 @@ | ||
from setuptools import setup | ||
|
||
setup() | ||
setup( | ||
install_requires=[ | ||
'chromadb>=0.4.22', | ||
] | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,14 @@ | ||
from .retriever import Retriever | ||
from .amazon_kb_retriever import AmazonKnowledgeBasesRetriever, AmazonKnowledgeBasesRetrieverOptions | ||
from .amazon_kb_retriever import ( | ||
AmazonKnowledgeBasesRetriever, | ||
AmazonKnowledgeBasesRetrieverOptions | ||
) | ||
from .chroma_retriever import ChromaRetriever, ChromaRetrieverOptions | ||
|
||
__all__ = [ | ||
'Retriever', | ||
'AmazonKnowledgeBasesRetriever', | ||
'AmazonKnowledgeBasesRetrieverOptions' | ||
'AmazonKnowledgeBasesRetrieverOptions', | ||
'ChromaRetriever', | ||
'ChromaRetrieverOptions' | ||
] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
from dataclasses import dataclass | ||
from typing import Any, Optional, Dict, List | ||
import chromadb | ||
from chromadb.config import Settings | ||
from agent_squad.retrievers import Retriever | ||
|
||
@dataclass | ||
class ChromaRetrieverOptions: | ||
"""Options for Chroma Retriever.""" | ||
collection_name: str | ||
persist_directory: Optional[str] = None | ||
host: Optional[str] = None | ||
port: Optional[int] = None | ||
ssl: Optional[bool] = False | ||
n_results: Optional[int] = 4 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. suggest higher default value like 10? |
||
embedding_function: Optional[Any] = None | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should this be typed? also im thinking it shouldnt be optional? |
||
|
||
class ChromaRetriever(Retriever): | ||
def __init__(self, options: ChromaRetrieverOptions): | ||
super().__init__(options) | ||
self.options = options | ||
|
||
if not self.options.collection_name: | ||
raise ValueError("collection_name is required in options") | ||
|
||
# Initialize ChromaDB client | ||
if self.options.host and self.options.port: | ||
self.client = chromadb.HttpClient( | ||
host=self.options.host, | ||
port=self.options.port, | ||
ssl=self.options.ssl | ||
) | ||
else: | ||
self.client = chromadb.Client( | ||
Settings( | ||
persist_directory=self.options.persist_directory, | ||
is_persistent=bool(self.options.persist_directory) | ||
) | ||
) | ||
|
||
# Get or create collection | ||
self.collection = self.client.get_or_create_collection( | ||
name=self.options.collection_name, | ||
embedding_function=self.options.embedding_function | ||
) | ||
|
||
async def retrieve(self, text: str) -> List[Dict[str, Any]]: | ||
""" | ||
Retrieve documents from ChromaDB based on the input text. | ||
|
||
Args: | ||
text (str): The input text to base the retrieval on. | ||
|
||
Returns: | ||
List[Dict[str, Any]]: List of retrieved documents with their metadata. | ||
""" | ||
if not text: | ||
raise ValueError("Input text is required for retrieve") | ||
|
||
results = self.collection.query( | ||
query_texts=[text], | ||
n_results=self.options.n_results | ||
) | ||
|
||
# Format results | ||
formatted_results = [] | ||
for i in range(len(results['documents'][0])): | ||
formatted_results.append({ | ||
'content': { | ||
'text': results['documents'][0][i] | ||
}, | ||
'metadata': ( | ||
results['metadatas'][0][i] if results['metadatas'] else {} | ||
) | ||
}) | ||
|
||
return formatted_results | ||
|
||
async def retrieve_and_combine_results(self, text: str) -> str: | ||
""" | ||
Retrieve documents and combine their content into a single string. | ||
|
||
Args: | ||
text (str): The input text to base the retrieval on. | ||
|
||
Returns: | ||
str: Combined text from retrieved documents. | ||
""" | ||
results = await self.retrieve(text) | ||
return self.combine_retrieval_results(results) | ||
|
||
async def retrieve_and_generate(self, text: str) -> Any: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what does this do and is there a reason it cant be implemented? |
||
""" | ||
Placeholder for retrieve and generate functionality. | ||
This can be implemented based on specific requirements. | ||
|
||
Args: | ||
text (str): The input text to base the retrieval on. | ||
|
||
Returns: | ||
Any: Generated content based on retrieved documents. | ||
""" | ||
raise NotImplementedError( | ||
"retrieve_and_generate is not implemented for ChromaRetriever" | ||
) | ||
|
||
@staticmethod | ||
def combine_retrieval_results( | ||
retrieval_results: List[Dict[str, Any]] | ||
) -> str: | ||
""" | ||
Combine retrieval results into a single string. | ||
|
||
Args: | ||
retrieval_results (List[Dict[str, Any]]): List of retrieved documents. | ||
|
||
Returns: | ||
str: Combined text from retrieved documents. | ||
""" | ||
return "\n".join( | ||
result['content']['text'] | ||
for result in retrieval_results | ||
if result and result.get('content') and isinstance( | ||
result['content'].get('text'), str | ||
) | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import pytest | ||
from unittest.mock import Mock, patch | ||
from agent_squad.retrievers import ChromaRetriever, ChromaRetrieverOptions | ||
|
||
|
||
@pytest.fixture | ||
def mock_chroma_client(): | ||
with patch('chromadb.Client') as mock_client: | ||
mock_collection = Mock() | ||
mock_collection.query.return_value = { | ||
'documents': [['doc1', 'doc2']], | ||
'metadatas': [['meta1', 'meta2']] | ||
} | ||
mock_client.return_value.get_or_create_collection.return_value = ( | ||
mock_collection | ||
) | ||
yield mock_client | ||
|
||
|
||
@pytest.fixture | ||
def chroma_retriever(mock_chroma_client): | ||
options = ChromaRetrieverOptions( | ||
collection_name='test_collection', | ||
persist_directory='./test_data' | ||
) | ||
return ChromaRetriever(options) | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_retrieve(chroma_retriever): | ||
results = await chroma_retriever.retrieve("test query") | ||
assert len(results) == 2 | ||
assert results[0]['content']['text'] == 'doc1' | ||
assert results[0]['metadata'] == 'meta1' | ||
assert results[1]['content']['text'] == 'doc2' | ||
assert results[1]['metadata'] == 'meta2' | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_retrieve_and_combine_results(chroma_retriever): | ||
result = await chroma_retriever.retrieve_and_combine_results("test query") | ||
assert result == "doc1\ndoc2" | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_retrieve_and_generate(chroma_retriever): | ||
with pytest.raises(NotImplementedError): | ||
await chroma_retriever.retrieve_and_generate("test query") | ||
|
||
|
||
def test_init_without_collection_name(): | ||
with pytest.raises(ValueError): | ||
ChromaRetriever(ChromaRetrieverOptions(collection_name='')) | ||
|
||
|
||
def test_init_with_remote_client(): | ||
options = ChromaRetrieverOptions( | ||
collection_name='test_collection', | ||
host='localhost', | ||
port=8000 | ||
) | ||
with patch('chromadb.HttpClient') as mock_client: | ||
mock_collection = Mock() | ||
mock_client.return_value.get_or_create_collection.return_value = ( | ||
mock_collection | ||
) | ||
retriever = ChromaRetriever(options) | ||
assert retriever.client is not None | ||
mock_client.assert_called_once_with( | ||
host='localhost', | ||
port=8000, | ||
ssl=False | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,11 @@ | ||
module.exports = { | ||
preset: 'ts-jest', | ||
testEnvironment: 'node', | ||
transform: { | ||
'^.+\\.tsx?$': ['ts-jest', { | ||
tsconfig: 'tsconfig.test.json' | ||
}] | ||
}, | ||
moduleFileExtensions: ['ts', 'tsx', 'js', 'jsx', 'json', 'node'], | ||
modulePathIgnorePatterns: ['<rootDir>/examples/'], | ||
}; |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
will the maintainers be ok with this?