Skip to content

Commit af3f614

Browse files
committed
WIP on asyncio index and composite_query method
1 parent ddef712 commit af3f614

File tree

9 files changed

+1000
-92
lines changed

9 files changed

+1000
-92
lines changed

pinecone/control/pinecone.py

Lines changed: 9 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
import time
22
import logging
3-
from typing import Optional, Dict, Any, Union, List, Tuple, Literal
3+
from typing import Optional, Dict, Any, Union, Literal
44

55
from .index_host_store import IndexHostStore
66

@@ -10,7 +10,12 @@
1010
from pinecone.core.openapi.shared.api_client import ApiClient
1111

1212

13-
from pinecone.utils import normalize_host, setup_openapi_client, build_plugin_setup_client
13+
from pinecone.utils import (
14+
normalize_host,
15+
setup_openapi_client,
16+
build_plugin_setup_client,
17+
parse_non_empty_args,
18+
)
1419
from pinecone.core.openapi.control.models import (
1520
CreateCollectionRequest,
1621
CreateIndexRequest,
@@ -317,9 +322,6 @@ def create_index(
317322

318323
api_instance = self.index_api
319324

320-
def _parse_non_empty_args(args: List[Tuple[str, Any]]) -> Dict[str, Any]:
321-
return {arg_name: val for arg_name, val in args if val is not None}
322-
323325
if deletion_protection in ["enabled", "disabled"]:
324326
dp = DeletionProtection(deletion_protection)
325327
else:
@@ -329,7 +331,7 @@ def _parse_non_empty_args(args: List[Tuple[str, Any]]) -> Dict[str, Any]:
329331
if "serverless" in spec:
330332
index_spec = IndexSpec(serverless=ServerlessSpecModel(**spec["serverless"]))
331333
elif "pod" in spec:
332-
args_dict = _parse_non_empty_args(
334+
args_dict = parse_non_empty_args(
333335
[
334336
("environment", spec["pod"].get("environment")),
335337
("metadata_config", spec["pod"].get("metadata_config")),
@@ -351,7 +353,7 @@ def _parse_non_empty_args(args: List[Tuple[str, Any]]) -> Dict[str, Any]:
351353
serverless=ServerlessSpecModel(cloud=spec.cloud, region=spec.region)
352354
)
353355
elif isinstance(spec, PodSpec):
354-
args_dict = _parse_non_empty_args(
356+
args_dict = parse_non_empty_args(
355357
[
356358
("replicas", spec.replicas),
357359
("shards", spec.shards),

pinecone/grpc/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
"""
4646

4747
from .index_grpc import GRPCIndex
48+
from .index_grpc_asyncio import GRPCIndexAsyncio
4849
from .pinecone import PineconeGRPC
4950
from .config import GRPCClientConfig
5051

pinecone/grpc/grpc_runner.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
from functools import wraps
23
from typing import Dict, Tuple, Optional
34

@@ -62,20 +63,32 @@ async def run_asyncio(
6263
credentials: Optional[CallCredentials] = None,
6364
wait_for_ready: Optional[bool] = None,
6465
compression: Optional[Compression] = None,
66+
semaphore: Optional[asyncio.Semaphore] = None,
6567
):
6668
@wraps(func)
6769
async def wrapped():
6870
user_provided_metadata = metadata or {}
6971
_metadata = self._prepare_metadata(user_provided_metadata)
7072
try:
71-
return await func(
72-
request,
73-
timeout=timeout,
74-
metadata=_metadata,
75-
credentials=credentials,
76-
wait_for_ready=wait_for_ready,
77-
compression=compression,
78-
)
73+
if semaphore is not None:
74+
async with semaphore:
75+
return await func(
76+
request,
77+
timeout=timeout,
78+
metadata=_metadata,
79+
credentials=credentials,
80+
wait_for_ready=wait_for_ready,
81+
compression=compression,
82+
)
83+
else:
84+
return await func(
85+
request,
86+
timeout=timeout,
87+
metadata=_metadata,
88+
credentials=credentials,
89+
wait_for_ready=wait_for_ready,
90+
compression=compression,
91+
)
7992
except _InactiveRpcError as e:
8093
raise PineconeException(e._state.debug_error_string) from e
8194

pinecone/grpc/index_grpc.py

Lines changed: 17 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,23 @@
11
import logging
2-
from typing import Optional, Dict, Union, List, Tuple, Any, TypedDict, cast
2+
from typing import Optional, Dict, Union, List, cast
33

44
from google.protobuf import json_format
55

66
from tqdm.autonotebook import tqdm
77

8+
from pinecone.utils import parse_non_empty_args
89
from .utils import (
910
dict_to_proto_struct,
1011
parse_fetch_response,
1112
parse_query_response,
1213
parse_stats_response,
14+
parse_sparse_values_arg,
1315
)
1416
from .vector_factory_grpc import VectorFactoryGRPC
17+
from .base import GRPCIndexBase
18+
from .future import PineconeGrpcFuture
19+
from .sparse_vector import SparseVectorTypedDict
20+
from .config import GRPCClientConfig
1521

1622
from pinecone.core.openapi.data.models import (
1723
FetchResponse,
@@ -36,10 +42,7 @@
3642
)
3743
from pinecone import Vector as NonGRPCVector
3844
from pinecone.core.grpc.protos.vector_service_pb2_grpc import VectorServiceStub
39-
from .base import GRPCIndexBase
40-
from .future import PineconeGrpcFuture
4145

42-
from .config import GRPCClientConfig
4346
from pinecone.config import Config
4447
from grpc._channel import Channel
4548

@@ -49,11 +52,6 @@
4952
_logger = logging.getLogger(__name__)
5053

5154

52-
class SparseVectorTypedDict(TypedDict):
53-
indices: List[int]
54-
values: List[float]
55-
56-
5755
class GRPCIndex(GRPCIndexBase):
5856
"""A client for interacting with a Pinecone index via GRPC API."""
5957

@@ -152,7 +150,7 @@ def upsert(
152150

153151
vectors = list(map(VectorFactoryGRPC.build, vectors))
154152
if async_req:
155-
args_dict = self._parse_non_empty_args([("namespace", namespace)])
153+
args_dict = parse_non_empty_args([("namespace", namespace)])
156154
request = UpsertRequest(vectors=vectors, **args_dict, **kwargs)
157155
future = self.runner.run(self.stub.Upsert.future, request, timeout=timeout)
158156
return PineconeGrpcFuture(future)
@@ -178,7 +176,7 @@ def upsert(
178176
def _upsert_batch(
179177
self, vectors: List[GRPCVector], namespace: Optional[str], timeout: Optional[int], **kwargs
180178
) -> UpsertResponse:
181-
args_dict = self._parse_non_empty_args([("namespace", namespace)])
179+
args_dict = parse_non_empty_args([("namespace", namespace)])
182180
request = UpsertRequest(vectors=vectors, **args_dict)
183181
return self.runner.run(self.stub.Upsert, request, timeout=timeout, **kwargs)
184182

@@ -285,7 +283,7 @@ def delete(
285283
else:
286284
filter_struct = None
287285

288-
args_dict = self._parse_non_empty_args(
286+
args_dict = parse_non_empty_args(
289287
[
290288
("ids", ids),
291289
("delete_all", delete_all),
@@ -322,7 +320,7 @@ def fetch(
322320
"""
323321
timeout = kwargs.pop("timeout", None)
324322

325-
args_dict = self._parse_non_empty_args([("namespace", namespace)])
323+
args_dict = parse_non_empty_args([("namespace", namespace)])
326324

327325
request = FetchRequest(ids=ids, **args_dict, **kwargs)
328326
response = self.runner.run(self.stub.Fetch, request, timeout=timeout)
@@ -388,8 +386,8 @@ def query(
388386
else:
389387
filter_struct = None
390388

391-
sparse_vector = self._parse_sparse_values_arg(sparse_vector)
392-
args_dict = self._parse_non_empty_args(
389+
sparse_vector = parse_sparse_values_arg(sparse_vector)
390+
args_dict = parse_non_empty_args(
393391
[
394392
("vector", vector),
395393
("id", id),
@@ -456,8 +454,8 @@ def update(
456454
set_metadata_struct = None
457455

458456
timeout = kwargs.pop("timeout", None)
459-
sparse_values = self._parse_sparse_values_arg(sparse_values)
460-
args_dict = self._parse_non_empty_args(
457+
sparse_values = parse_sparse_values_arg(sparse_values)
458+
args_dict = parse_non_empty_args(
461459
[
462460
("values", values),
463461
("set_metadata", set_metadata_struct),
@@ -506,7 +504,7 @@ def list_paginated(
506504
507505
Returns: SimpleListResponse object which contains the list of ids, the namespace name, pagination information, and usage showing the number of read_units consumed.
508506
"""
509-
args_dict = self._parse_non_empty_args(
507+
args_dict = parse_non_empty_args(
510508
[
511509
("prefix", prefix),
512510
("limit", limit),
@@ -585,36 +583,10 @@ def describe_index_stats(
585583
filter_struct = dict_to_proto_struct(filter)
586584
else:
587585
filter_struct = None
588-
args_dict = self._parse_non_empty_args([("filter", filter_struct)])
586+
args_dict = parse_non_empty_args([("filter", filter_struct)])
589587
timeout = kwargs.pop("timeout", None)
590588

591589
request = DescribeIndexStatsRequest(**args_dict)
592590
response = self.runner.run(self.stub.DescribeIndexStats, request, timeout=timeout)
593591
json_response = json_format.MessageToDict(response)
594592
return parse_stats_response(json_response)
595-
596-
@staticmethod
597-
def _parse_non_empty_args(args: List[Tuple[str, Any]]) -> Dict[str, Any]:
598-
return {arg_name: val for arg_name, val in args if val is not None}
599-
600-
@staticmethod
601-
def _parse_sparse_values_arg(
602-
sparse_values: Optional[Union[GRPCSparseValues, SparseVectorTypedDict]],
603-
) -> Optional[GRPCSparseValues]:
604-
if sparse_values is None:
605-
return None
606-
607-
if isinstance(sparse_values, GRPCSparseValues):
608-
return sparse_values
609-
610-
if (
611-
not isinstance(sparse_values, dict)
612-
or "indices" not in sparse_values
613-
or "values" not in sparse_values
614-
):
615-
raise ValueError(
616-
"Invalid sparse values argument. Expected a dict of: {'indices': List[int], 'values': List[float]}."
617-
f"Received: {sparse_values}"
618-
)
619-
620-
return GRPCSparseValues(indices=sparse_values["indices"], values=sparse_values["values"])

0 commit comments

Comments
 (0)