Skip to content

Commit c017328

Browse files
committed
More test cases for aggregator
1 parent af3f614 commit c017328

File tree

5 files changed

+176
-83
lines changed

5 files changed

+176
-83
lines changed

.gitignore

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ venv.bak/
137137
.ropeproject
138138

139139
# pdocs documentation
140-
# We want to exclude any locally generated artifacts, but we rely on
140+
# We want to exclude any locally generated artifacts, but we rely on
141141
# keeping documentation assets in the docs/ folder.
142142
docs/*
143143
!docs/pinecone-python-client-fork.png
@@ -155,4 +155,5 @@ dmypy.json
155155
*.hdf5
156156
*~
157157

158-
tests/integration/proxy_config/logs
158+
tests/integration/proxy_config/logs
159+
*.parquet

pinecone/grpc/index_grpc_asyncio.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,18 @@ async def upsert(
112112
for batch in vector_batches
113113
]
114114

115-
return await tqdm.gather(*tasks, disable=not show_progress, desc="Upserted batches")
115+
if namespace is not None:
116+
pbar_desc = f"Upserted vectors in namespace '{namespace}'"
117+
else:
118+
pbar_desc = "Upserted vectors in namespace ''"
119+
120+
upserted_count = 0
121+
with tqdm(total=len(vectors), disable=not show_progress, desc=pbar_desc) as pbar:
122+
for task in asyncio.as_completed(tasks):
123+
res = await task
124+
pbar.update(res.upserted_count)
125+
upserted_count += res.upserted_count
126+
return UpsertResponse(upserted_count=upserted_count)
116127

117128
async def _upsert_batch(
118129
self,
@@ -173,12 +184,12 @@ async def _query(
173184
)
174185
return json_format.MessageToDict(response)
175186

176-
async def composite_query(
187+
async def query(
177188
self,
178189
vector: Optional[List[float]] = None,
179190
id: Optional[str] = None,
180191
namespace: Optional[str] = None,
181-
top_k: Optional[int] = None,
192+
top_k: Optional[int] = 10,
182193
filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None,
183194
include_values: Optional[bool] = None,
184195
include_metadata: Optional[bool] = None,
@@ -244,12 +255,11 @@ async def composite_query(
244255
)
245256
return parse_query_response(json_response, _check_type=False)
246257

247-
async def multi_namespace_query(
258+
async def composite_query(
248259
self,
249260
vector: Optional[List[float]] = None,
250-
id: Optional[str] = None,
251-
namespaces: Optional[str] = None,
252-
top_k: Optional[int] = None,
261+
namespaces: Optional[List[str]] = None,
262+
top_k: Optional[int] = 10,
253263
filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None,
254264
include_values: Optional[bool] = None,
255265
include_metadata: Optional[bool] = None,
@@ -258,12 +268,13 @@ async def multi_namespace_query(
258268
semaphore: Optional[asyncio.Semaphore] = None,
259269
**kwargs,
260270
) -> Awaitable[CompositeQueryResults]:
271+
aggregator_lock = asyncio.Lock()
261272
semaphore = self._get_semaphore(max_concurrent_requests, semaphore)
273+
aggregator = QueryResultsAggregator(top_k=top_k)
262274

263-
queries = [
275+
query_tasks = [
264276
self._query(
265277
vector=vector,
266-
id=id,
267278
namespace=ns,
268279
top_k=top_k,
269280
filter=filter,
@@ -276,13 +287,11 @@ async def multi_namespace_query(
276287
for ns in namespaces
277288
]
278289

279-
results = await asyncio.gather(*queries, return_exceptions=True)
290+
for query_task in asyncio.as_completed(query_tasks):
291+
response = await query_task
292+
async with aggregator_lock:
293+
aggregator.add_results(response)
280294

281-
aggregator = QueryResultsAggregator(top_k=top_k)
282-
for result in results:
283-
if isinstance(result, Exception):
284-
continue
285-
aggregator.add_results(result)
286295
final_results = aggregator.get_results()
287296
return final_results
288297

pinecone/grpc/pinecone.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ class PineconeGRPC(Pinecone):
4848
4949
"""
5050

51-
def Index(self, name: str = "", host: str = "", use_asyncio=False, **kwargs):
51+
def Index(self, name: str = "", host: str = "", **kwargs):
5252
"""
5353
Target an index for data operations.
5454
@@ -119,6 +119,12 @@ def Index(self, name: str = "", host: str = "", use_asyncio=False, **kwargs):
119119
index.query(vector=[...], top_k=10)
120120
```
121121
"""
122+
return self._init_index(name=name, host=host, use_asyncio=False, **kwargs)
123+
124+
def AsyncioIndex(self, name: str = "", host: str = "", **kwargs):
125+
return self._init_index(name=name, host=host, use_asyncio=True, **kwargs)
126+
127+
def _init_index(self, name: str, host: str, use_asyncio=False, **kwargs):
122128
if name == "" and host == "":
123129
raise ValueError("Either name or host must be specified")
124130

pinecone/grpc/query_results_aggregator.py

Lines changed: 59 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -81,39 +81,85 @@ def __repr__(self):
8181
)
8282

8383

84+
class QueryResultsAggregationEmptyResultsError(Exception):
85+
def __init__(self, namespace: str):
86+
super().__init__(
87+
f"Cannot infer metric type from empty query results. Query result for namespace '{namespace}' is empty. Have you spelled the namespace name correctly?"
88+
)
89+
90+
91+
class QueryResultsAggregregatorNotEnoughResultsError(Exception):
92+
def __init__(self, top_k: int, num_results: int):
93+
super().__init__(
94+
f"Cannot interpret results without at least two matches. In order to aggregate results from multiple queries, top_k must be greater than 1 in order to correctly infer the similarity metric from scores. Expected at least {top_k} results but got {num_results}."
95+
)
96+
97+
98+
class QueryResultsAggregatorInvalidTopKError(Exception):
99+
def __init__(self, top_k: int):
100+
super().__init__(f"Invalid top_k value {top_k}. top_k must be a positive integer.")
101+
102+
84103
class QueryResultsAggregator:
85104
def __init__(self, top_k: int):
105+
if top_k < 1:
106+
raise QueryResultsAggregatorInvalidTopKError(top_k)
86107
self.top_k = top_k
87108
self.usage_read_units = 0
88109
self.heap = []
89110
self.insertion_counter = 0
111+
self.is_dotproduct = None
90112
self.read = False
91113

114+
def __is_dotproduct_index(self, matches):
115+
# The interpretation of the score depends on the similarity metric used.
116+
# Unlike other index types, in indexes configured for dotproduct,
117+
# a higher score is better. We have to infer this is the case by inspecting
118+
# the order of the scores in the results.
119+
for i in range(1, len(matches)):
120+
if matches[i].get("score") > matches[i - 1].get("score"): # Found an increase
121+
return False
122+
return True
123+
92124
def add_results(self, results: QueryResponse):
93125
if self.read:
126+
# This is mainly just to sanity check in test cases which get quite confusing
127+
# if you read results twice due to the heap being emptied when constructing
128+
# the ordered results.
94129
raise ValueError("Results have already been read. Cannot add more results.")
95130

96-
self.usage_read_units += results.get("usage", {}).get("readUnits", 0)
131+
matches = results.get("matches", [])
97132
ns = results.get("namespace")
98-
for match in results.get("matches", []):
99-
self.insertion_counter += 1
100-
score = match.get("score")
101-
if len(self.heap) < self.top_k:
102-
heapq.heappush(self.heap, (-score, -self.insertion_counter, match, ns))
103-
else:
104-
heapq.heappushpop(self.heap, (-score, -self.insertion_counter, match, ns))
133+
self.usage_read_units += results.get("usage", {}).get("readUnits", 0)
134+
135+
if self.is_dotproduct is None:
136+
if len(matches) == 0:
137+
raise QueryResultsAggregationEmptyResultsError(ns)
138+
if len(matches) == 1:
139+
raise QueryResultsAggregregatorNotEnoughResultsError(self.top_k, len(matches))
140+
self.is_dotproduct = self.__is_dotproduct_index(matches)
141+
142+
print("is_dotproduct:", self.is_dotproduct)
143+
if self.is_dotproduct:
144+
raise NotImplementedError("Dotproduct indexes are not yet supported.")
145+
else:
146+
for match in matches:
147+
self.insertion_counter += 1
148+
score = match.get("score")
149+
if len(self.heap) < self.top_k:
150+
heapq.heappush(self.heap, (-score, -self.insertion_counter, match, ns))
151+
else:
152+
heapq.heappushpop(self.heap, (-score, -self.insertion_counter, match, ns))
105153

106154
def get_results(self) -> CompositeQueryResults:
107155
if self.read:
108-
# This is mainly just to sanity check in test cases which get quite confusing
109-
# if you read results twice due to the heap being emptied each time you read
110-
# results into an ordered list.
111-
raise ValueError("Results have already been read. Cannot read again.")
156+
return self.final_results
112157
self.read = True
113158

114-
return CompositeQueryResults(
159+
self.final_results = CompositeQueryResults(
115160
usage=Usage(read_units=self.usage_read_units),
116161
matches=[
117162
ScoredVectorWithNamespace(heapq.heappop(self.heap)) for _ in range(len(self.heap))
118163
][::-1],
119164
)
165+
return self.final_results

0 commit comments

Comments
 (0)