Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions src/milvus_haystack/milvus_embedding_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,17 +57,23 @@ def from_dict(cls, data: Dict[str, Any]) -> "MilvusEmbeddingRetriever":
return default_from_dict(cls, data)

@component.output_types(documents=List[Document])
def run(
    self, query_embedding: List[float], top_k: Optional[int] = None, filters: Optional[Dict[str, Any]] = None
) -> Dict[str, List[Document]]:
    """
    Retrieve documents from the `MilvusDocumentStore`, based on their dense embeddings.

    :param query_embedding: Embedding of the query.
    :param top_k: Optional number of documents to retrieve. If provided, overrides the top_k
        set during initialization.
    :param filters: Optional filters to apply at runtime. If provided, overrides the filters
        set during initialization.
    :return: Dictionary with a single key, `documents`, holding whatever the document store's
        embedding retrieval returned for `query_embedding`.
    """
    # NOTE: `or` falls back on every falsy value, not only on None. Passing
    # `top_k=0` or `filters={}` therefore still uses the values stored on the
    # retriever instead of overriding them.
    docs = self.document_store._embedding_retrieval(
        query_embedding=query_embedding,
        filters=filters or self.filters,
        top_k=top_k or self.top_k,
    )
    return {"documents": docs}

Expand Down Expand Up @@ -125,17 +131,23 @@ def run(
query_sparse_embedding: Optional[SparseEmbedding] = None,
query_text: Optional[str] = None,
top_k: Optional[int] = None,
filters: Optional[Dict[str, Any]] = None,
) -> Dict[str, List[Document]]:
"""
Retrieve documents from the `MilvusDocumentStore`, based on their sparse embeddings.

:param query_sparse_embedding: Sparse Embedding of the query.
:param query_text: Optional text query for sparse retrieval.
:param top_k: Optional number of documents to retrieve. If provided, overrides the top_k
set during initialization.
:param filters: Optional filters to apply at runtime. If provided, overrides the filters
set during initialization.
:return: List of Document similar to `query_embedding`.
"""
docs = self.document_store._sparse_embedding_retrieval(
query_sparse_embedding=query_sparse_embedding,
filters=self.filters,
top_k=top_k or self.top_k,
filters=filters or self.filters,
query_text=query_text,
)
return {"documents": docs}
Expand Down Expand Up @@ -221,18 +233,24 @@ def run(
query_sparse_embedding: Optional[SparseEmbedding] = None,
query_text: Optional[str] = None,
top_k: Optional[int] = None,
filters: Optional[Dict[str, Any]] = None,
):
"""
Retrieve documents from the `MilvusDocumentStore`, based on their dense and sparse embeddings.

:param query_embedding: Dense Embedding of the query.
:param query_sparse_embedding: Sparse Embedding of the query.
:param query_text: Optional text query for sparse retrieval.
:param top_k: Optional number of documents to retrieve. If provided, overrides the top_k
set during initialization.
:param filters: Optional filters to apply at runtime. If provided, overrides the filters
set during initialization.
:return: List of Document similar to `query_embedding`.
"""
docs = self.document_store._hybrid_retrieval(
query_embedding=query_embedding,
query_sparse_embedding=query_sparse_embedding,
filters=self.filters,
filters=filters or self.filters,
top_k=top_k or self.top_k,
reranker=self.reranker,
query_text=query_text,
Expand Down
121 changes: 120 additions & 1 deletion tests/test_embedding_retriever.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from dataclasses import fields
from typing import List
from typing import Any, Dict, List

import numpy as np
import pytest
Expand Down Expand Up @@ -77,6 +77,42 @@ def test_run(self, document_store: MilvusDocumentStore):
assert len(res["documents"]) == 10
assert_docs_equal_except_score(res["documents"][0], documents[5])

def test_run_using_filters(self, document_store: MilvusDocumentStore):
    """Check that filters handed to ``run`` restrict dense retrieval results."""
    # Ten documents: numbers 0-4 are tagged "intro", 5-9 "outro". Only the
    # last embedding component differs from one document to the next.
    corpus = [
        Document(
            content=f"Document {idx}",
            meta={
                "name": f"name_{idx}",
                "page": str(100 + idx),
                "chapter": "intro" if idx < 5 else "outro",
                "number": idx,
                "date": "1969-07-21T20:17:40",
            },
            embedding=l2_normalization([0.5] * 63 + [0.1 * idx]),
        )
        for idx in range(10)
    ]
    document_store.write_documents(corpus)

    # The retriever is built with the store only; filters are supplied per call.
    retriever = MilvusEmbeddingRetriever(document_store)
    query_embedding = l2_normalization([0.5] * 64)

    # chapter == "intro" matches documents 0-4.
    intro_only: Dict[str, Any] = {"field": "chapter", "operator": "==", "value": "intro"}
    intro_hits = retriever.run(query_embedding, filters=intro_only)["documents"]
    assert len(intro_hits) == 5
    for hit in intro_hits:
        assert hit.meta["chapter"] == "intro"

    # number >= 5 matches documents 5-9.
    upper_half: Dict[str, Any] = {"field": "number", "operator": ">=", "value": 5}
    upper_hits = retriever.run(query_embedding, filters=upper_half)["documents"]
    assert len(upper_hits) == 5
    for hit in upper_hits:
        assert hit.meta["number"] >= 5


def test_to_dict(self, document_store: MilvusDocumentStore):
expected_dict = {
"type": "src.milvus_haystack.document_store.MilvusDocumentStore",
Expand Down Expand Up @@ -210,6 +246,43 @@ def test_run(self, document_store: MilvusDocumentStore, documents: List[Document
assert len(res["documents"]) == 10
assert_docs_equal_except_score(res["documents"][0], documents[5])

def test_run_using_filters(self, document_store: MilvusDocumentStore):
    """Check that filters handed to ``run`` restrict sparse retrieval results."""
    # Ten documents sharing one dense embedding; each sparse embedding has
    # its third index shifted by the document number.
    corpus = [
        Document(
            content=f"Document {idx}",
            meta={
                "name": f"name_{idx}",
                "page": str(100 + idx),
                "chapter": "intro" if idx < 5 else "outro",
                "number": idx,
                "date": "1969-07-21T20:17:40",
            },
            embedding=l2_normalization([0.5] * 64),
            sparse_embedding=SparseEmbedding(indices=[0, 1, 2 + idx], values=[1.0, 2.0, 3.0]),
        )
        for idx in range(10)
    ]
    document_store.write_documents(corpus)

    # The retriever is built with the store only; filters are supplied per call.
    retriever = MilvusSparseEmbeddingRetriever(document_store)
    # Same sparse indices as the document with number 5.
    sparse_query_embedding = SparseEmbedding(indices=[0, 1, 2 + 5], values=[1.0, 2.0, 3.0])

    # chapter == "outro" matches documents 5-9.
    outro_only: Dict[str, Any] = {"field": "chapter", "operator": "==", "value": "outro"}
    outro_hits = retriever.run(sparse_query_embedding, filters=outro_only)["documents"]
    assert len(outro_hits) == 5
    for hit in outro_hits:
        assert hit.meta["chapter"] == "outro"

    # number < 3 matches documents 0-2.
    below_three: Dict[str, Any] = {"field": "number", "operator": "<", "value": 3}
    low_hits = retriever.run(sparse_query_embedding, filters=below_three)["documents"]
    assert len(low_hits) == 3
    for hit in low_hits:
        assert hit.meta["number"] < 3


def test_fail_without_sparse_field(self, documents: List[Document]):
document_store = MilvusDocumentStore(
connection_args=DEFAULT_CONNECTION_ARGS,
Expand Down Expand Up @@ -366,6 +439,52 @@ def test_run(self, document_store: MilvusDocumentStore, documents: List[Document
assert len(res["documents"]) == 10
assert_docs_equal_except_score(res["documents"][0], documents[5])

def test_run_using_filters(self, document_store: MilvusDocumentStore):
    """Check that filters handed to ``run`` restrict hybrid retrieval results."""
    # Ten documents carrying both a dense and a sparse embedding; numbers 0-4
    # are tagged "intro", 5-9 "outro".
    corpus = [
        Document(
            content=f"Hybrid Document {idx}",
            meta={
                "name": f"name_{idx}",
                "page": str(100 + idx),
                "chapter": "intro" if idx < 5 else "outro",
                "number": idx,
                "date": "1969-07-21T20:17:40",
            },
            embedding=l2_normalization([0.5] * 63 + [0.45 + 0.01 * idx]),
            sparse_embedding=SparseEmbedding(indices=[0, 1, 2 + idx], values=[1.0, 2.0, 3.0]),
        )
        for idx in range(10)
    ]
    document_store.write_documents(corpus)

    # The retriever is built with the store only; filters are supplied per call.
    retriever = MilvusHybridRetriever(document_store)
    query_embedding = l2_normalization([0.5] * 64)
    # Same sparse indices as the document with number 5.
    sparse_query_embedding = SparseEmbedding(indices=[0, 1, 2 + 5], values=[1.0, 2.0, 3.0])

    # chapter == "intro" matches documents 0-4.
    intro_only: Dict[str, Any] = {"field": "chapter", "operator": "==", "value": "intro"}
    intro_hits = retriever.run(
        query_embedding=query_embedding,
        query_sparse_embedding=sparse_query_embedding,
        filters=intro_only,
    )["documents"]
    assert len(intro_hits) == 5
    for hit in intro_hits:
        assert hit.meta["chapter"] == "intro"

    # number in [2, 4, 6, 8] matches exactly four documents.
    even_subset: Dict[str, Any] = {"field": "number", "operator": "in", "value": [2, 4, 6, 8]}
    even_hits = retriever.run(
        query_embedding=query_embedding,
        query_sparse_embedding=sparse_query_embedding,
        filters=even_subset,
    )["documents"]
    assert len(even_hits) == 4
    for hit in even_hits:
        assert hit.meta["number"] in [2, 4, 6, 8]


def test_fail_without_sparse_field(self, documents: List[Document]):
document_store = MilvusDocumentStore(
connection_args=DEFAULT_CONNECTION_ARGS,
Expand Down