Skip to content

Commit ab28227

Browse files
authored
Implement query_namespaces over grpc (#416)
## Problem I want to maintain parity across REST / GRPC implementations. This PR adds a query_namespaces implementation for the GRPC index client. ## Solution Use a ThreadPoolExecutor to execute queries in parallel, then aggregate the results QueryResultsAggregator. ## Usage ```python import random from pinecone.grpc import PineconeGRPC pc = PineconeGRPC(api_key="key") index = pc.Index(host="jen1024-dojoi3u.svc.apw5-4e34-81fa.pinecone.io", pool_threads=25) query_vec = [random.random() for i in range(1024)] combined_results = index.query_namespaces( vector=query_vec, namespaces=["ns1", "ns2", "ns3", "ns4"], include_values=False, include_metadata=True, filter={"genre": {"$eq": "drama"}}, top_k=50 ) for vec in combined_results.matches: print(vec.get('id'), vec.get('score')) print(combined_results.usage) ``` ## Type of Change - [ ] Bug fix (non-breaking change which fixes an issue) - [x] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [ ] This change requires a documentation update - [ ] Infrastructure change (CI configs, etc) - [ ] Non-code change (docs, etc) - [ ] None of the above: (explain here) ## Test Plan Describe specific steps for validating this change.
1 parent eade7dd commit ab28227

File tree

4 files changed

+60
-6
lines changed

4 files changed

+60
-6
lines changed

pinecone/grpc/base.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from pinecone import Config
1111
from .config import GRPCClientConfig
1212
from .grpc_runner import GrpcRunner
13+
from concurrent.futures import ThreadPoolExecutor
1314

1415
from pinecone_plugin_interface import load_and_install as install_plugins
1516

@@ -29,10 +30,12 @@ def __init__(
2930
config: Config,
3031
channel: Optional[Channel] = None,
3132
grpc_config: Optional[GRPCClientConfig] = None,
33+
pool_threads: Optional[int] = None,
3234
_endpoint_override: Optional[str] = None,
3335
):
3436
self.config = config
3537
self.grpc_client_config = grpc_config or GRPCClientConfig()
38+
self.pool_threads = pool_threads
3639

3740
self._endpoint_override = _endpoint_override
3841

@@ -58,6 +61,13 @@ def stub_openapi_client_builder(*args, **kwargs):
5861
except Exception as e:
5962
_logger.error(f"Error loading plugins in GRPCIndex: {e}")
6063

64+
@property
65+
def threadpool_executor(self):
66+
if self._pool is None:
67+
pt = self.pool_threads or 10
68+
self._pool = ThreadPoolExecutor(max_workers=pt)
69+
return self._pool
70+
6171
@property
6272
@abstractmethod
6373
def stub_class(self):

pinecone/grpc/index_grpc.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import logging
2-
from typing import Optional, Dict, Union, List, Tuple, Any, TypedDict, cast
2+
from typing import Optional, Dict, Union, List, Tuple, Any, TypedDict, Iterable, cast
33

44
from google.protobuf import json_format
55

66
from tqdm.autonotebook import tqdm
7+
from concurrent.futures import as_completed, Future
8+
79

810
from .utils import (
911
dict_to_proto_struct,
@@ -35,6 +37,7 @@
3537
SparseValues as GRPCSparseValues,
3638
)
3739
from pinecone import Vector as NonGRPCVector
40+
from pinecone.data.query_results_aggregator import QueryNamespacesResults, QueryResultsAggregator
3841
from pinecone.core.grpc.protos.vector_service_pb2_grpc import VectorServiceStub
3942
from .base import GRPCIndexBase
4043
from .future import PineconeGrpcFuture
@@ -402,6 +405,49 @@ def query(
402405
json_response = json_format.MessageToDict(response)
403406
return parse_query_response(json_response, _check_type=False)
404407

408+
def query_namespaces(
409+
self,
410+
vector: List[float],
411+
namespaces: List[str],
412+
top_k: Optional[int] = None,
413+
filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None,
414+
include_values: Optional[bool] = None,
415+
include_metadata: Optional[bool] = None,
416+
sparse_vector: Optional[Union[GRPCSparseValues, SparseVectorTypedDict]] = None,
417+
**kwargs,
418+
) -> QueryNamespacesResults:
419+
if namespaces is None or len(namespaces) == 0:
420+
raise ValueError("At least one namespace must be specified")
421+
if len(vector) == 0:
422+
raise ValueError("Query vector must not be empty")
423+
424+
overall_topk = top_k if top_k is not None else 10
425+
aggregator = QueryResultsAggregator(top_k=overall_topk)
426+
427+
target_namespaces = set(namespaces) # dedup namespaces
428+
futures = [
429+
self.threadpool_executor.submit(
430+
self.query,
431+
vector=vector,
432+
namespace=ns,
433+
top_k=overall_topk,
434+
filter=filter,
435+
include_values=include_values,
436+
include_metadata=include_metadata,
437+
sparse_vector=sparse_vector,
438+
async_req=False,
439+
**kwargs,
440+
)
441+
for ns in target_namespaces
442+
]
443+
444+
only_futures = cast(Iterable[Future], futures)
445+
for response in as_completed(only_futures):
446+
aggregator.add_results(response.result())
447+
448+
final_results = aggregator.get_results()
449+
return final_results
450+
405451
def update(
406452
self,
407453
id: str,

pinecone/grpc/pinecone.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,11 +124,13 @@ def Index(self, name: str = "", host: str = "", **kwargs):
124124
# Use host if it is provided, otherwise get host from describe_index
125125
index_host = host or self.index_host_store.get_host(self.index_api, self.config, name)
126126

127+
pt = kwargs.pop("pool_threads", None) or self.pool_threads
128+
127129
config = ConfigBuilder.build(
128130
api_key=self.config.api_key,
129131
host=index_host,
130132
source_tag=self.config.source_tag,
131133
proxy_url=self.config.proxy_url,
132134
ssl_ca_certs=self.config.ssl_ca_certs,
133135
)
134-
return GRPCIndex(index_name=name, config=config, **kwargs)
136+
return GRPCIndex(index_name=name, config=config, pool_threads=pt, **kwargs)

tests/integration/data/test_query_namespaces.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import pytest
2-
import os
32
from ..helpers import random_string, poll_stats_for_namespace
43
from pinecone.data.query_results_aggregator import (
54
QueryResultsAggregatorInvalidTopKError,
@@ -9,9 +8,6 @@
98
from pinecone import Vector
109

1110

12-
@pytest.mark.skipif(
13-
os.getenv("USE_GRPC") == "true", reason="query_namespaces currently only available via rest"
14-
)
1511
class TestQueryNamespacesRest:
1612
def test_query_namespaces(self, idx):
1713
ns_prefix = random_string(5)

0 commit comments

Comments
 (0)