
Commit e668c89

Add query_namespaces (#409)
## Problem

Sometimes people would like to run a query across multiple namespaces.

## Solution

Run a query for each namespace in parallel, then merge the results using a heap.

```python
from pinecone import Pinecone
import random

pc = Pinecone(api_key='api-key')
index = pc.Index(
    host="https://indexhost/",
    pool_threads=10
)

dimension = 1536  # set this to your index's dimension
query_vec = [random.random()] * dimension

combined_results = index.query_namespaces(
    vector=query_vec,
    namespaces=["ns1", "ns2", "ns3", "ns4"],
    include_values=False,
    include_metadata=True,
    filter={"publication_date": {"$eq": "Last3Months"}},
    top_k=100
)
```

## TODO

A grpc implementation of this will follow in a separate PR. I have WIP on it, but some mypy type issues were causing me headaches and I'd rather land this stuff first.

## Type of Change

- [x] New feature (non-breaking change which adds functionality)

## Test Plan

Added integration tests.
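The merge step named in the solution can be pictured with the standard library alone: `heapq.merge` lazily combines several score-sorted sequences into one. This is only a minimal illustration of the idea, not the code added in this commit (the real aggregator introduced below also infers the similarity metric and tracks read units):

```python
import heapq
from itertools import islice

# Hypothetical per-namespace responses, each sorted by descending score,
# as a dotproduct index would return them.
ns1_matches = [{"id": "a", "score": 0.9}, {"id": "b", "score": 0.5}]
ns2_matches = [{"id": "c", "score": 0.8}, {"id": "d", "score": 0.2}]

# Negating the score turns "descending" into the ascending order that
# heapq.merge expects; islice stops consuming after top_k items.
top_k = 3
merged = heapq.merge(ns1_matches, ns2_matches, key=lambda m: -m["score"])
print(list(islice(merged, top_k)))  # ids: a, c, b
```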
1 parent b4bfae8 commit e668c89

File tree: 7 files changed, +1120 -26 lines

pinecone/core/openapi/shared/api_client.py

Lines changed: 44 additions & 19 deletions
```diff
@@ -8,6 +8,24 @@
 import typing
 from urllib.parse import quote
 from urllib3.fields import RequestField
+import time
+import random
+
+
+def retry_api_call(
+    func, args=(), kwargs={}, retries=3, backoff=1, jitter=0.5
+):
+    attempts = 0
+    while attempts < retries:
+        try:
+            return func(*args, **kwargs)  # Attempt to call __call_api
+        except Exception as e:
+            attempts += 1
+            if attempts >= retries:
+                print(f"API call failed after {attempts} attempts: {e}")
+                raise  # Re-raise exception if retries are exhausted
+            sleep_time = backoff * (2 ** (attempts - 1)) + random.uniform(0, jitter)
+            # print(f"Retrying ({attempts}/{retries}) in {sleep_time:.2f} seconds after error: {e}")
+            time.sleep(sleep_time)
 
 
 from pinecone.core.openapi.shared import rest
@@ -397,25 +415,32 @@ def call_api(
         )
 
         return self.pool.apply_async(
-            self.__call_api,
-            (
-                resource_path,
-                method,
-                path_params,
-                query_params,
-                header_params,
-                body,
-                post_params,
-                files,
-                response_type,
-                auth_settings,
-                _return_http_data_only,
-                collection_formats,
-                _preload_content,
-                _request_timeout,
-                _host,
-                _check_type,
-            ),
+            retry_api_call,
+            args=(
+                self.__call_api,  # Pass the API call function as the first argument
+                (
+                    resource_path,
+                    method,
+                    path_params,
+                    query_params,
+                    header_params,
+                    body,
+                    post_params,
+                    files,
+                    response_type,
+                    auth_settings,
+                    _return_http_data_only,
+                    collection_formats,
+                    _preload_content,
+                    _request_timeout,
+                    _host,
+                    _check_type,
+                ),
+                {},  # empty kwargs dictionary
+                3,  # retries
+                1,  # backoff time
+                0.5  # jitter
+            )
        )

    def request(
```
pinecone/data/index.py

Lines changed: 82 additions & 2 deletions
```diff
@@ -33,6 +33,8 @@
 )
 from .features.bulk_import import ImportFeatureMixin
 from .vector_factory import VectorFactory
+from .query_results_aggregator import QueryResultsAggregator, QueryNamespacesResults
+from multiprocessing.pool import ApplyResult
 
 from pinecone_plugin_interface import load_and_install as install_plugins
 
@@ -387,7 +389,7 @@ def query(
             Union[SparseValues, Dict[str, Union[List[float], List[int]]]]
         ] = None,
         **kwargs,
-    ) -> QueryResponse:
+    ) -> Union[QueryResponse, ApplyResult]:
         """
         The Query operation searches a namespace, using a query vector.
         It retrieves the ids of the most similar items in a namespace, along with their similarity scores.
@@ -429,6 +431,39 @@ def query(
         and namespace name.
         """
 
+        response = self._query(
+            *args,
+            top_k=top_k,
+            vector=vector,
+            id=id,
+            namespace=namespace,
+            filter=filter,
+            include_values=include_values,
+            include_metadata=include_metadata,
+            sparse_vector=sparse_vector,
+            **kwargs,
+        )
+
+        if kwargs.get("async_req", False):
+            return response
+        else:
+            return parse_query_response(response)
+
+    def _query(
+        self,
+        *args,
+        top_k: int,
+        vector: Optional[List[float]] = None,
+        id: Optional[str] = None,
+        namespace: Optional[str] = None,
+        filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None,
+        include_values: Optional[bool] = None,
+        include_metadata: Optional[bool] = None,
+        sparse_vector: Optional[
+            Union[SparseValues, Dict[str, Union[List[float], List[int]]]]
+        ] = None,
+        **kwargs,
+    ) -> QueryResponse:
         if len(args) > 0:
             raise ValueError(
                 "The argument order for `query()` has changed; please use keyword arguments instead of positional arguments. Example: index.query(vector=[0.1, 0.2, 0.3], top_k=10, namespace='my_namespace')"
@@ -461,7 +496,52 @@ def query(
             ),
             **{k: v for k, v in kwargs.items() if k in _OPENAPI_ENDPOINT_PARAMS},
         )
-        return parse_query_response(response)
+        return response
+
+    @validate_and_convert_errors
+    def query_namespaces(
+        self,
+        vector: List[float],
+        namespaces: List[str],
+        top_k: Optional[int] = None,
+        filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None,
+        include_values: Optional[bool] = None,
+        include_metadata: Optional[bool] = None,
+        sparse_vector: Optional[
+            Union[SparseValues, Dict[str, Union[List[float], List[int]]]]
+        ] = None,
+        **kwargs,
+    ) -> QueryNamespacesResults:
+        if namespaces is None or len(namespaces) == 0:
+            raise ValueError("At least one namespace must be specified")
+        if len(vector) == 0:
+            raise ValueError("Query vector must not be empty")
+
+        overall_topk = top_k if top_k is not None else 10
+        aggregator = QueryResultsAggregator(top_k=overall_topk)
+
+        target_namespaces = set(namespaces)  # dedup namespaces
+        async_results = [
+            self.query(
+                vector=vector,
+                namespace=ns,
+                top_k=overall_topk,
+                filter=filter,
+                include_values=include_values,
+                include_metadata=include_metadata,
+                sparse_vector=sparse_vector,
+                async_req=True,
+                **kwargs,
+            )
+            for ns in target_namespaces
+        ]
+
+        for result in async_results:
+            response = result.get()
+            aggregator.add_results(response)
+
+        final_results = aggregator.get_results()
+        return final_results
 
     @validate_and_convert_errors
     def update(
```
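Because each per-namespace `query()` above is issued with `async_req=True`, it returns a `multiprocessing.pool.ApplyResult` immediately, and `result.get()` then blocks until that response arrives, so the fan-out is bounded by the index's thread pool. A usage sketch (the host, vector, and namespace names are placeholders, not values from this commit):

```python
from pinecone import Pinecone

pc = Pinecone(api_key="YOUR_API_KEY")

# pool_threads sizes the thread pool the per-namespace queries fan out
# across; with fewer threads than namespaces, some queries will queue.
index = pc.Index(host="https://your-index-host", pool_threads=4)

results = index.query_namespaces(
    vector=[0.1, 0.2, 0.3],     # must match your index's dimension
    namespaces=["ns1", "ns2"],  # duplicates are deduplicated via set()
    top_k=5,
    include_metadata=True,
)
print(results.usage)             # read units summed across namespaces
for match in results.matches:
    print(match.namespace, match.id, match.score)
```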
pinecone/data/query_results_aggregator.py

Lines changed: 193 additions & 0 deletions (new file)
```python
from typing import List, Tuple, Optional, Any, Dict
import json
import heapq
from pinecone.core.openapi.data.models import Usage
from pinecone.core.openapi.data.models import QueryResponse as OpenAPIQueryResponse

from dataclasses import dataclass, asdict


@dataclass
class ScoredVectorWithNamespace:
    namespace: str
    score: float
    id: str
    values: List[float]
    sparse_values: dict
    metadata: dict

    def __init__(self, aggregate_results_heap_tuple: Tuple[float, int, object, str]):
        json_vector = aggregate_results_heap_tuple[2]
        self.namespace = aggregate_results_heap_tuple[3]
        self.id = json_vector.get("id")  # type: ignore
        self.score = json_vector.get("score")  # type: ignore
        self.values = json_vector.get("values")  # type: ignore
        self.sparse_values = json_vector.get("sparse_values", None)  # type: ignore
        self.metadata = json_vector.get("metadata", None)  # type: ignore

    def __getitem__(self, key):
        if hasattr(self, key):
            return getattr(self, key)
        else:
            raise KeyError(f"'{key}' not found in ScoredVectorWithNamespace")

    def get(self, key, default=None):
        return getattr(self, key, default)

    def __repr__(self):
        return json.dumps(self._truncate(asdict(self)), indent=4)

    def __json__(self):
        return self._truncate(asdict(self))

    def _truncate(self, obj, max_items=2):
        """
        Recursively traverse and truncate lists that exceed max_items length.
        Only display the "... X more" message if at least 2 elements are hidden.
        """
        if obj is None:
            return None  # Skip None values
        elif isinstance(obj, list):
            filtered_list = [self._truncate(i, max_items) for i in obj if i is not None]
            if len(filtered_list) > max_items:
                # Show the truncation message only if more than 1 item is hidden
                remaining_items = len(filtered_list) - max_items
                if remaining_items > 1:
                    return filtered_list[:max_items] + [f"... {remaining_items} more"]
                else:
                    # If only 1 item remains, show it
                    return filtered_list
            return filtered_list
        elif isinstance(obj, dict):
            # Recursively process dictionaries, omitting None values
            return {k: self._truncate(v, max_items) for k, v in obj.items() if v is not None}
        return obj


@dataclass
class QueryNamespacesResults:
    usage: Usage
    matches: List[ScoredVectorWithNamespace]

    def __getitem__(self, key):
        if hasattr(self, key):
            return getattr(self, key)
        else:
            raise KeyError(f"'{key}' not found in QueryNamespacesResults")

    def get(self, key, default=None):
        return getattr(self, key, default)

    def __repr__(self):
        return json.dumps(
            {
                "usage": self.usage.to_dict(),
                "matches": [match.__json__() for match in self.matches],
            },
            indent=4,
        )


class QueryResultsAggregregatorNotEnoughResultsError(Exception):
    def __init__(self):
        super().__init__(
            "Cannot interpret results without at least two matches. In order to aggregate results from multiple queries, top_k must be greater than 1 in order to correctly infer the similarity metric from scores."
        )


class QueryResultsAggregatorInvalidTopKError(Exception):
    def __init__(self, top_k: int):
        super().__init__(
            f"Invalid top_k value {top_k}. To aggregate results from multiple queries the top_k must be at least 2."
        )


class QueryResultsAggregator:
    def __init__(self, top_k: int):
        if top_k < 2:
            raise QueryResultsAggregatorInvalidTopKError(top_k)
        self.top_k = top_k
        self.usage_read_units = 0
        self.heap: List[Tuple[float, int, object, str]] = []
        self.insertion_counter = 0
        self.is_dotproduct = None
        self.read = False
        self.final_results: Optional[QueryNamespacesResults] = None

    def _is_dotproduct_index(self, matches):
        # The interpretation of the score depends on the similarity metric used.
        # Unlike other index types, in indexes configured for dotproduct,
        # a higher score is better. We have to infer this is the case by inspecting
        # the order of the scores in the results.
        for i in range(1, len(matches)):
            if matches[i].get("score") > matches[i - 1].get("score"):  # Found an increase
                return False
        return True

    def _dotproduct_heap_item(self, match, ns):
        return (match.get("score"), -self.insertion_counter, match, ns)

    def _non_dotproduct_heap_item(self, match, ns):
        return (-match.get("score"), -self.insertion_counter, match, ns)

    def _process_matches(self, matches, ns, heap_item_fn):
        for match in matches:
            self.insertion_counter += 1
            if len(self.heap) < self.top_k:
                heapq.heappush(self.heap, heap_item_fn(match, ns))
            else:
                # Assume we have dotproduct scores sorted in descending order
                if self.is_dotproduct and match["score"] < self.heap[0][0]:
                    # No further matches can improve the top-K heap
                    break
                elif not self.is_dotproduct and match["score"] > -self.heap[0][0]:
                    # No further matches can improve the top-K heap
                    break
                heapq.heappushpop(self.heap, heap_item_fn(match, ns))

    def add_results(self, results: Dict[str, Any]):
        if self.read:
            # This is mainly just to sanity check in test cases which get quite confusing
            # if you read results twice due to the heap being emptied when constructing
            # the ordered results.
            raise ValueError("Results have already been read. Cannot add more results.")

        matches = results.get("matches", [])
        ns: str = results.get("namespace", "")
        if isinstance(results, OpenAPIQueryResponse):
            self.usage_read_units += results.usage.read_units
        else:
            self.usage_read_units += results.get("usage", {}).get("readUnits", 0)

        if len(matches) == 0:
            return

        if self.is_dotproduct is None:
            if len(matches) == 1:
                # This condition should match the second time we add results containing
                # only one match. We need at least two matches in a single response in order
                # to infer the similarity metric
                raise QueryResultsAggregregatorNotEnoughResultsError()
            self.is_dotproduct = self._is_dotproduct_index(matches)

        if self.is_dotproduct:
            self._process_matches(matches, ns, self._dotproduct_heap_item)
        else:
            self._process_matches(matches, ns, self._non_dotproduct_heap_item)

    def get_results(self) -> QueryNamespacesResults:
        if self.read:
            if self.final_results is not None:
                return self.final_results
            else:
                # I don't think this branch can ever actually be reached, but the type checker disagrees
                raise ValueError("Results have already been read. Cannot get results again.")
        self.read = True

        self.final_results = QueryNamespacesResults(
            usage=Usage(read_units=self.usage_read_units),
            matches=[
                ScoredVectorWithNamespace(heapq.heappop(self.heap)) for _ in range(len(self.heap))
            ][::-1],
        )
        return self.final_results
```
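For a feel of the aggregator in isolation, here is a small sketch feeding it dict-shaped responses with fabricated scores (descending within each response, so a dotproduct-style metric is inferred from the first batch); the module path in the import is assumed from the `.query_results_aggregator` import in pinecone/data/index.py:

```python
from pinecone.data.query_results_aggregator import QueryResultsAggregator  # assumed path

aggregator = QueryResultsAggregator(top_k=2)

# Scores descend within each response, so _is_dotproduct_index infers a
# dotproduct-style metric (higher is better) from the first batch.
aggregator.add_results({
    "namespace": "ns1",
    "matches": [{"id": "a", "score": 0.9}, {"id": "b", "score": 0.3}],
    "usage": {"readUnits": 5},
})
aggregator.add_results({
    "namespace": "ns2",
    "matches": [{"id": "c", "score": 0.8}, {"id": "d", "score": 0.2}],
    "usage": {"readUnits": 5},
})

results = aggregator.get_results()
print(results.usage.read_units)  # 10, summed across both responses
for match in results.matches:    # best-first: ("ns1", "a"), then ("ns2", "c")
    print(match.namespace, match.id, match.score)
```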
