Skip to content

Commit 42ba054

Browse files
authored
Merge pull request #70 from zc277584121/main
feat: add runtime filter
2 parents 4e86337 + cd9fd51 commit 42ba054

File tree

2 files changed

+142
-5
lines changed

2 files changed

+142
-5
lines changed

src/milvus_haystack/milvus_embedding_retriever.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,17 +57,23 @@ def from_dict(cls, data: Dict[str, Any]) -> "MilvusEmbeddingRetriever":
5757
return default_from_dict(cls, data)
5858

5959
@component.output_types(documents=List[Document])
60-
def run(self, query_embedding: List[float], top_k: Optional[int] = None) -> Dict[str, List[Document]]:
60+
def run(
61+
self, query_embedding: List[float], top_k: Optional[int] = None, filters: Optional[Dict[str, Any]] = None
62+
) -> Dict[str, List[Document]]:
6163
"""
6264
Retrieve documents from the `MilvusDocumentStore`, based on their dense embeddings.
6365
6466
:param query_embedding: Embedding of the query.
67+
:param top_k: Optional number of documents to retrieve. If provided, overrides the top_k
68+
set during initialization.
69+
:param filters: Optional filters to apply at runtime. If provided, overrides the filters
70+
set during initialization.
6571
:return: List of Document similar to `query_embedding`.
6672
"""
6773
docs = self.document_store._embedding_retrieval(
6874
query_embedding=query_embedding,
69-
filters=self.filters,
7075
top_k=top_k or self.top_k,
76+
filters=filters or self.filters,
7177
)
7278
return {"documents": docs}
7379

@@ -125,17 +131,23 @@ def run(
125131
query_sparse_embedding: Optional[SparseEmbedding] = None,
126132
query_text: Optional[str] = None,
127133
top_k: Optional[int] = None,
134+
filters: Optional[Dict[str, Any]] = None,
128135
) -> Dict[str, List[Document]]:
129136
"""
130137
Retrieve documents from the `MilvusDocumentStore`, based on their sparse embeddings.
131138
132139
:param query_sparse_embedding: Sparse Embedding of the query.
140+
:param query_text: Optional text query for sparse retrieval.
141+
:param top_k: Optional number of documents to retrieve. If provided, overrides the top_k
142+
set during initialization.
143+
:param filters: Optional filters to apply at runtime. If provided, overrides the filters
144+
set during initialization.
133145
:return: List of Document similar to `query_embedding`.
134146
"""
135147
docs = self.document_store._sparse_embedding_retrieval(
136148
query_sparse_embedding=query_sparse_embedding,
137-
filters=self.filters,
138149
top_k=top_k or self.top_k,
150+
filters=filters or self.filters,
139151
query_text=query_text,
140152
)
141153
return {"documents": docs}
@@ -221,18 +233,24 @@ def run(
221233
query_sparse_embedding: Optional[SparseEmbedding] = None,
222234
query_text: Optional[str] = None,
223235
top_k: Optional[int] = None,
236+
filters: Optional[Dict[str, Any]] = None,
224237
):
225238
"""
226239
Retrieve documents from the `MilvusDocumentStore`, based on their dense and sparse embeddings.
227240
228241
:param query_embedding: Dense Embedding of the query.
229242
:param query_sparse_embedding: Sparse Embedding of the query.
243+
:param query_text: Optional text query for sparse retrieval.
244+
:param top_k: Optional number of documents to retrieve. If provided, overrides the top_k
245+
set during initialization.
246+
:param filters: Optional filters to apply at runtime. If provided, overrides the filters
247+
set during initialization.
230248
:return: List of Document similar to `query_embedding`.
231249
"""
232250
docs = self.document_store._hybrid_retrieval(
233251
query_embedding=query_embedding,
234252
query_sparse_embedding=query_sparse_embedding,
235-
filters=self.filters,
253+
filters=filters or self.filters,
236254
top_k=top_k or self.top_k,
237255
reranker=self.reranker,
238256
query_text=query_text,

tests/test_embedding_retriever.py

Lines changed: 120 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22
from dataclasses import fields
3-
from typing import List
3+
from typing import Any, Dict, List
44

55
import numpy as np
66
import pytest
@@ -77,6 +77,42 @@ def test_run(self, document_store: MilvusDocumentStore):
7777
assert len(res["documents"]) == 10
7878
assert_docs_equal_except_score(res["documents"][0], documents[5])
7979

80+
def test_run_using_filters(self, document_store: MilvusDocumentStore):
81+
"""Test that filters are properly applied at runtime"""
82+
documents = []
83+
for i in range(10):
84+
doc = Document(
85+
content=f"Document {i}",
86+
meta={
87+
"name": f"name_{i}",
88+
"page": str(100 + i),
89+
"chapter": "intro" if i < 5 else "outro",
90+
"number": i,
91+
"date": "1969-07-21T20:17:40",
92+
},
93+
embedding=l2_normalization([0.5] * 63 + [0.1 * i]),
94+
)
95+
documents.append(doc)
96+
document_store.write_documents(documents)
97+
98+
# Test with runtime filters
99+
retriever = MilvusEmbeddingRetriever(document_store)
100+
query_embedding = l2_normalization([0.5] * 64)
101+
102+
# Filter: chapter == "intro" (should return 5 documents)
103+
filters: Dict[str, Any] = {"field": "chapter", "operator": "==", "value": "intro"}
104+
res = retriever.run(query_embedding, filters=filters)
105+
assert len(res["documents"]) == 5
106+
for doc in res["documents"]:
107+
assert doc.meta["chapter"] == "intro"
108+
109+
# Filter: number >= 5 (should return 5 documents)
110+
filters = {"field": "number", "operator": ">=", "value": 5} # type: ignore[no-redef]
111+
res = retriever.run(query_embedding, filters=filters)
112+
assert len(res["documents"]) == 5
113+
for doc in res["documents"]:
114+
assert doc.meta["number"] >= 5
115+
80116
def test_to_dict(self, document_store: MilvusDocumentStore):
81117
expected_dict = {
82118
"type": "src.milvus_haystack.document_store.MilvusDocumentStore",
@@ -210,6 +246,43 @@ def test_run(self, document_store: MilvusDocumentStore, documents: List[Document
210246
assert len(res["documents"]) == 10
211247
assert_docs_equal_except_score(res["documents"][0], documents[5])
212248

249+
def test_run_using_filters(self, document_store: MilvusDocumentStore):
250+
"""Test that filters are properly applied at runtime for sparse retrieval"""
251+
documents = []
252+
for i in range(10):
253+
doc = Document(
254+
content=f"Document {i}",
255+
meta={
256+
"name": f"name_{i}",
257+
"page": str(100 + i),
258+
"chapter": "intro" if i < 5 else "outro",
259+
"number": i,
260+
"date": "1969-07-21T20:17:40",
261+
},
262+
embedding=l2_normalization([0.5] * 64),
263+
sparse_embedding=SparseEmbedding(indices=[0, 1, 2 + i], values=[1.0, 2.0, 3.0]),
264+
)
265+
documents.append(doc)
266+
document_store.write_documents(documents)
267+
268+
# Test with runtime filters
269+
retriever = MilvusSparseEmbeddingRetriever(document_store)
270+
sparse_query_embedding = SparseEmbedding(indices=[0, 1, 2 + 5], values=[1.0, 2.0, 3.0])
271+
272+
# Filter: chapter == "outro" (should return 5 documents)
273+
filters: Dict[str, Any] = {"field": "chapter", "operator": "==", "value": "outro"}
274+
res = retriever.run(sparse_query_embedding, filters=filters)
275+
assert len(res["documents"]) == 5
276+
for doc in res["documents"]:
277+
assert doc.meta["chapter"] == "outro"
278+
279+
# Filter: number < 3 (should return 3 documents)
280+
filters = {"field": "number", "operator": "<", "value": 3} # type: ignore[no-redef]
281+
res = retriever.run(sparse_query_embedding, filters=filters)
282+
assert len(res["documents"]) == 3
283+
for doc in res["documents"]:
284+
assert doc.meta["number"] < 3
285+
213286
def test_fail_without_sparse_field(self, documents: List[Document]):
214287
document_store = MilvusDocumentStore(
215288
connection_args=DEFAULT_CONNECTION_ARGS,
@@ -366,6 +439,52 @@ def test_run(self, document_store: MilvusDocumentStore, documents: List[Document
366439
assert len(res["documents"]) == 10
367440
assert_docs_equal_except_score(res["documents"][0], documents[5])
368441

442+
def test_run_using_filters(self, document_store: MilvusDocumentStore):
443+
"""Test that filters are properly applied at runtime for hybrid retrieval"""
444+
documents = []
445+
for i in range(10):
446+
doc = Document(
447+
content=f"Hybrid Document {i}",
448+
meta={
449+
"name": f"name_{i}",
450+
"page": str(100 + i),
451+
"chapter": "intro" if i < 5 else "outro",
452+
"number": i,
453+
"date": "1969-07-21T20:17:40",
454+
},
455+
embedding=l2_normalization([0.5] * 63 + [0.45 + 0.01 * i]),
456+
sparse_embedding=SparseEmbedding(indices=[0, 1, 2 + i], values=[1.0, 2.0, 3.0]),
457+
)
458+
documents.append(doc)
459+
document_store.write_documents(documents)
460+
461+
# Test with runtime filters
462+
retriever = MilvusHybridRetriever(document_store)
463+
query_embedding = l2_normalization([0.5] * 64)
464+
sparse_query_embedding = SparseEmbedding(indices=[0, 1, 2 + 5], values=[1.0, 2.0, 3.0])
465+
466+
# Filter: chapter == "intro" (should return 5 documents)
467+
filters: Dict[str, Any] = {"field": "chapter", "operator": "==", "value": "intro"}
468+
res = retriever.run(
469+
query_embedding=query_embedding,
470+
query_sparse_embedding=sparse_query_embedding,
471+
filters=filters,
472+
)
473+
assert len(res["documents"]) == 5
474+
for doc in res["documents"]:
475+
assert doc.meta["chapter"] == "intro"
476+
477+
# Filter: number in [2, 4, 6, 8] (should return 4 documents)
478+
filters = {"field": "number", "operator": "in", "value": [2, 4, 6, 8]} # type: ignore[no-redef]
479+
res = retriever.run(
480+
query_embedding=query_embedding,
481+
query_sparse_embedding=sparse_query_embedding,
482+
filters=filters,
483+
)
484+
assert len(res["documents"]) == 4
485+
for doc in res["documents"]:
486+
assert doc.meta["number"] in [2, 4, 6, 8]
487+
369488
def test_fail_without_sparse_field(self, documents: List[Document]):
370489
document_store = MilvusDocumentStore(
371490
connection_args=DEFAULT_CONNECTION_ARGS,

0 commit comments

Comments
 (0)