Skip to content

Commit 8372632

Browse files
committed
Show tqdm output, fail on empty results
1 parent c017328 commit 8372632

File tree

2 files changed

+41
-20
lines changed

2 files changed

+41
-20
lines changed

pinecone/grpc/index_grpc_asyncio.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -264,33 +264,46 @@ async def composite_query(
264264
include_values: Optional[bool] = None,
265265
include_metadata: Optional[bool] = None,
266266
sparse_vector: Optional[Union[GRPCSparseValues, SparseVectorTypedDict]] = None,
267+
show_progress: Optional[bool] = True,
267268
max_concurrent_requests: Optional[int] = None,
268269
semaphore: Optional[asyncio.Semaphore] = None,
269270
**kwargs,
270271
) -> Awaitable[CompositeQueryResults]:
271272
aggregator_lock = asyncio.Lock()
272273
semaphore = self._get_semaphore(max_concurrent_requests, semaphore)
274+
275+
# The caller may only want the topK=1 result across all queries,
276+
# but the aggregator needs at least 2 results from a query in order to
277+
# infer the index's metric type and aggregate scores correctly. So we'll
278+
# temporarily set topK to 2 for the subqueries, and then we'll take the
279+
# topK=1 results from the aggregated results.
273280
aggregator = QueryResultsAggregator(top_k=top_k)
281+
subquery_topk = top_k if top_k > 2 else 2
274282

283+
target_namespaces = set(namespaces) # dedup namespaces
275284
query_tasks = [
276285
self._query(
277286
vector=vector,
278287
namespace=ns,
279-
top_k=top_k,
288+
top_k=subquery_topk,
280289
filter=filter,
281290
include_values=include_values,
282291
include_metadata=include_metadata,
283292
sparse_vector=sparse_vector,
284293
semaphore=semaphore,
285294
**kwargs,
286295
)
287-
for ns in namespaces
296+
for ns in target_namespaces
288297
]
289298

290-
for query_task in asyncio.as_completed(query_tasks):
291-
response = await query_task
292-
async with aggregator_lock:
293-
aggregator.add_results(response)
299+
with tqdm(
300+
total=len(query_tasks), disable=not show_progress, desc="Querying namespaces"
301+
) as pbar:
302+
for query_task in asyncio.as_completed(query_tasks):
303+
response = await query_task
304+
pbar.update(1)
305+
async with aggregator_lock:
306+
aggregator.add_results(response)
294307

295308
final_results = aggregator.get_results()
296309
return final_results

pinecone/grpc/query_results_aggregator.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def __repr__(self):
8484
class QueryResultsAggregationEmptyResultsError(Exception):
8585
def __init__(self, namespace: str):
8686
super().__init__(
87-
f"Cannot infer metric type from empty query results. Query result for namespace '{namespace}' is empty. Have you spelled the namespace name correctly?"
87+
f"Query results for namespace '{namespace}' were empty. Check that you have upserted vectors into this namespace (see describe_index_stats) and that the namespace name is spelled correctly."
8888
)
8989

9090

@@ -111,7 +111,7 @@ def __init__(self, top_k: int):
111111
self.is_dotproduct = None
112112
self.read = False
113113

114-
def __is_dotproduct_index(self, matches):
114+
def _is_dotproduct_index(self, matches):
115115
# The interpretation of the score depends on the similarity metric used.
116116
# Unlike other index types, in indexes configured for dotproduct,
117117
# a higher score is better. We have to infer this is the case by inspecting
@@ -121,6 +121,20 @@ def __is_dotproduct_index(self, matches):
121121
return False
122122
return True
123123

124+
def _dotproduct_heap_item(self, match, ns):
125+
return (match.get("score"), -self.insertion_counter, match, ns)
126+
127+
def _non_dotproduct_heap_item(self, match, ns):
128+
return (-match.get("score"), -self.insertion_counter, match, ns)
129+
130+
def _process_matches(self, matches, ns, heap_item_fn):
131+
for match in matches:
132+
self.insertion_counter += 1
133+
if len(self.heap) < self.top_k:
134+
heapq.heappush(self.heap, heap_item_fn(match, ns))
135+
else:
136+
heapq.heappushpop(self.heap, heap_item_fn(match, ns))
137+
124138
def add_results(self, results: QueryResponse):
125139
if self.read:
126140
# This is mainly just to sanity check in test cases which get quite confusing
@@ -132,24 +146,18 @@ def add_results(self, results: QueryResponse):
132146
ns = results.get("namespace")
133147
self.usage_read_units += results.get("usage", {}).get("readUnits", 0)
134148

149+
if len(matches) == 0:
150+
raise QueryResultsAggregationEmptyResultsError(ns)
151+
135152
if self.is_dotproduct is None:
136-
if len(matches) == 0:
137-
raise QueryResultsAggregationEmptyResultsError(ns)
138153
if len(matches) == 1:
139154
raise QueryResultsAggregregatorNotEnoughResultsError(self.top_k, len(matches))
140-
self.is_dotproduct = self.__is_dotproduct_index(matches)
155+
self.is_dotproduct = self._is_dotproduct_index(matches)
141156

142-
print("is_dotproduct:", self.is_dotproduct)
143157
if self.is_dotproduct:
144-
raise NotImplementedError("Dotproduct indexes are not yet supported.")
158+
self._process_matches(matches, ns, self._dotproduct_heap_item)
145159
else:
146-
for match in matches:
147-
self.insertion_counter += 1
148-
score = match.get("score")
149-
if len(self.heap) < self.top_k:
150-
heapq.heappush(self.heap, (-score, -self.insertion_counter, match, ns))
151-
else:
152-
heapq.heappushpop(self.heap, (-score, -self.insertion_counter, match, ns))
160+
self._process_matches(matches, ns, self._non_dotproduct_heap_item)
153161

154162
def get_results(self) -> CompositeQueryResults:
155163
if self.read:

0 commit comments

Comments
 (0)