Skip to content

Add Chroma Retrieval #323 #324

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion python/setup.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from setuptools import setup

setup()
setup(
install_requires=[

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will the maintainers be ok with this?

'chromadb>=0.4.22',
]
)
10 changes: 8 additions & 2 deletions python/src/agent_squad/retrievers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
from .retriever import Retriever
from .amazon_kb_retriever import AmazonKnowledgeBasesRetriever, AmazonKnowledgeBasesRetrieverOptions
from .amazon_kb_retriever import (
AmazonKnowledgeBasesRetriever,
AmazonKnowledgeBasesRetrieverOptions
)
from .chroma_retriever import ChromaRetriever, ChromaRetrieverOptions

__all__ = [
'Retriever',
'AmazonKnowledgeBasesRetriever',
'AmazonKnowledgeBasesRetrieverOptions'
'AmazonKnowledgeBasesRetrieverOptions',
'ChromaRetriever',
'ChromaRetrieverOptions'
]
126 changes: 126 additions & 0 deletions python/src/agent_squad/retrievers/chroma_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
from dataclasses import dataclass
from typing import Any, Optional, Dict, List
import chromadb
from chromadb.config import Settings
from agent_squad.retrievers import Retriever

@dataclass
class ChromaRetrieverOptions:
"""Options for Chroma Retriever."""
collection_name: str
persist_directory: Optional[str] = None
host: Optional[str] = None
port: Optional[int] = None
ssl: Optional[bool] = False
n_results: Optional[int] = 4

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggest higher default value like 10?

embedding_function: Optional[Any] = None

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this be typed? also im thinking it shouldnt be optional?


class ChromaRetriever(Retriever):
def __init__(self, options: ChromaRetrieverOptions):
super().__init__(options)
self.options = options

if not self.options.collection_name:
raise ValueError("collection_name is required in options")

# Initialize ChromaDB client
if self.options.host and self.options.port:
self.client = chromadb.HttpClient(
host=self.options.host,
port=self.options.port,
ssl=self.options.ssl
)
else:
self.client = chromadb.Client(
Settings(
persist_directory=self.options.persist_directory,
is_persistent=bool(self.options.persist_directory)
)
)

# Get or create collection
self.collection = self.client.get_or_create_collection(
name=self.options.collection_name,
embedding_function=self.options.embedding_function
)

async def retrieve(self, text: str) -> List[Dict[str, Any]]:
"""
Retrieve documents from ChromaDB based on the input text.

Args:
text (str): The input text to base the retrieval on.

Returns:
List[Dict[str, Any]]: List of retrieved documents with their metadata.
"""
if not text:
raise ValueError("Input text is required for retrieve")

results = self.collection.query(
query_texts=[text],
n_results=self.options.n_results
)

# Format results
formatted_results = []
for i in range(len(results['documents'][0])):
formatted_results.append({
'content': {
'text': results['documents'][0][i]
},
'metadata': (
results['metadatas'][0][i] if results['metadatas'] else {}
)
})

return formatted_results

async def retrieve_and_combine_results(self, text: str) -> str:
"""
Retrieve documents and combine their content into a single string.

Args:
text (str): The input text to base the retrieval on.

Returns:
str: Combined text from retrieved documents.
"""
results = await self.retrieve(text)
return self.combine_retrieval_results(results)

async def retrieve_and_generate(self, text: str) -> Any:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what does this do and is there a reason it cant be implemented?

"""
Placeholder for retrieve and generate functionality.
This can be implemented based on specific requirements.

Args:
text (str): The input text to base the retrieval on.

Returns:
Any: Generated content based on retrieved documents.
"""
raise NotImplementedError(
"retrieve_and_generate is not implemented for ChromaRetriever"
)

@staticmethod
def combine_retrieval_results(
retrieval_results: List[Dict[str, Any]]
) -> str:
"""
Combine retrieval results into a single string.

Args:
retrieval_results (List[Dict[str, Any]]): List of retrieved documents.

Returns:
str: Combined text from retrieved documents.
"""
return "\n".join(
result['content']['text']
for result in retrieval_results
if result and result.get('content') and isinstance(
result['content'].get('text'), str
)
)
73 changes: 73 additions & 0 deletions python/src/tests/retrievers/test_chroma_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import pytest
from unittest.mock import Mock, patch
from agent_squad.retrievers import ChromaRetriever, ChromaRetrieverOptions


@pytest.fixture
def mock_chroma_client():
with patch('chromadb.Client') as mock_client:
mock_collection = Mock()
mock_collection.query.return_value = {
'documents': [['doc1', 'doc2']],
'metadatas': [['meta1', 'meta2']]
}
mock_client.return_value.get_or_create_collection.return_value = (
mock_collection
)
yield mock_client


@pytest.fixture
def chroma_retriever(mock_chroma_client):
options = ChromaRetrieverOptions(
collection_name='test_collection',
persist_directory='./test_data'
)
return ChromaRetriever(options)


@pytest.mark.asyncio
async def test_retrieve(chroma_retriever):
results = await chroma_retriever.retrieve("test query")
assert len(results) == 2
assert results[0]['content']['text'] == 'doc1'
assert results[0]['metadata'] == 'meta1'
assert results[1]['content']['text'] == 'doc2'
assert results[1]['metadata'] == 'meta2'


@pytest.mark.asyncio
async def test_retrieve_and_combine_results(chroma_retriever):
result = await chroma_retriever.retrieve_and_combine_results("test query")
assert result == "doc1\ndoc2"


@pytest.mark.asyncio
async def test_retrieve_and_generate(chroma_retriever):
with pytest.raises(NotImplementedError):
await chroma_retriever.retrieve_and_generate("test query")


def test_init_without_collection_name():
with pytest.raises(ValueError):
ChromaRetriever(ChromaRetrieverOptions(collection_name=''))


def test_init_with_remote_client():
options = ChromaRetrieverOptions(
collection_name='test_collection',
host='localhost',
port=8000
)
with patch('chromadb.HttpClient') as mock_client:
mock_collection = Mock()
mock_client.return_value.get_or_create_collection.return_value = (
mock_collection
)
retriever = ChromaRetriever(options)
assert retriever.client is not None
mock_client.assert_called_once_with(
host='localhost',
port=8000,
ssl=False
)
5 changes: 5 additions & 0 deletions typescript/jest.config.js
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
module.exports = {
preset: 'ts-jest',
testEnvironment: 'node',
transform: {
'^.+\\.tsx?$': ['ts-jest', {
tsconfig: 'tsconfig.test.json'
}]
},
moduleFileExtensions: ['ts', 'tsx', 'js', 'jsx', 'json', 'node'],
modulePathIgnorePatterns: ['<rootDir>/examples/'],
};
62 changes: 56 additions & 6 deletions typescript/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading