Skip to content

Commit 4a99468

Browse files
authored
query_namespaces performance improvements (#417)
## Problem Want to improve the performance of the rest implementation of `query_namespaces` ## Solution - Add `pytest-benchmark` dev dependency and some basic performance tests to interrogate the impact of certain changes. For now these are only run on my local machine, but in the future these could potentially be expanded into an automated suite. - Pass `_preload_content=False` to tell the underlying generated code not to instantiate response objects for all the intermediate results. - Use `ThreadPoolExecutor` instead of older `ThreadPool` implementation from multiprocessing. This involved some changes to the generated code, but the benefit of this approach is that you get back a `concurrent.futures.Future` instead of an `ApplyResult` which is much more ergonomic. I'm planning to extract the edited files out of the code gen process very shortly, so there shouldn't be a concern about modifying generated files in this case. I gated this approach behind a new kwarg, `async_threadpool_executor`, that lives alongside `async_req`; eventually I would like to replace all usage of `async_req`'s ThreadPool with ThreadPoolExecutor to bring the rest and grpc implementations closer together, but I can't do that in this PR without creating a breaking change. The net effect of these changes seems to be about ~18% performance improvement. ## Type of Change - [ ] Bug fix (non-breaking change which fixes an issue) - [x] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [ ] This change requires a documentation update - [ ] Infrastructure change (CI configs, etc) - [ ] Non-code change (docs, etc) - [ ] None of the above: (explain here)
1 parent ab28227 commit 4a99468

File tree

7 files changed

+172
-6
lines changed

7 files changed

+172
-6
lines changed

pinecone/core/openapi/shared/api_client.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import atexit
33
import mimetypes
44
from multiprocessing.pool import ThreadPool
5+
from concurrent.futures import ThreadPoolExecutor
56
import io
67
import os
78
import re
@@ -70,6 +71,7 @@ class ApiClient(object):
7071
"""
7172

7273
_pool = None
74+
_threadpool_executor = None
7375

7476
def __init__(self, configuration=None, header_name=None, header_value=None, cookie=None, pool_threads=1):
7577
if configuration is None:
@@ -92,6 +94,9 @@ def __exit__(self, exc_type, exc_value, traceback):
9294
self.close()
9395

9496
def close(self):
97+
if self._threadpool_executor:
98+
self._threadpool_executor.shutdown()
99+
self._threadpool_executor = None
95100
if self._pool:
96101
self._pool.close()
97102
self._pool.join()
@@ -109,6 +114,12 @@ def pool(self):
109114
self._pool = ThreadPool(self.pool_threads)
110115
return self._pool
111116

117+
@property
118+
def threadpool_executor(self):
119+
if self._threadpool_executor is None:
120+
self._threadpool_executor = ThreadPoolExecutor(max_workers=self.pool_threads)
121+
return self._threadpool_executor
122+
112123
@property
113124
def user_agent(self):
114125
"""User agent for this API client"""
@@ -334,6 +345,7 @@ def call_api(
334345
response_type: typing.Optional[typing.Tuple[typing.Any]] = None,
335346
auth_settings: typing.Optional[typing.List[str]] = None,
336347
async_req: typing.Optional[bool] = None,
348+
async_threadpool_executor: typing.Optional[bool] = None,
337349
_return_http_data_only: typing.Optional[bool] = None,
338350
collection_formats: typing.Optional[typing.Dict[str, str]] = None,
339351
_preload_content: bool = True,
@@ -394,6 +406,27 @@ def call_api(
394406
If parameter async_req is False or missing,
395407
then the method will return the response directly.
396408
"""
409+
if async_threadpool_executor:
410+
return self.threadpool_executor.submit(
411+
self.__call_api,
412+
resource_path,
413+
method,
414+
path_params,
415+
query_params,
416+
header_params,
417+
body,
418+
post_params,
419+
files,
420+
response_type,
421+
auth_settings,
422+
_return_http_data_only,
423+
collection_formats,
424+
_preload_content,
425+
_request_timeout,
426+
_host,
427+
_check_type,
428+
)
429+
397430
if not async_req:
398431
return self.__call_api(
399432
resource_path,
@@ -690,6 +723,7 @@ def __init__(self, settings=None, params_map=None, root_map=None, headers_map=No
690723
self.params_map["all"].extend(
691724
[
692725
"async_req",
726+
"async_threadpool_executor",
693727
"_host_index",
694728
"_preload_content",
695729
"_request_timeout",
@@ -704,6 +738,7 @@ def __init__(self, settings=None, params_map=None, root_map=None, headers_map=No
704738
self.openapi_types = root_map["openapi_types"]
705739
extra_types = {
706740
"async_req": (bool,),
741+
"async_threadpool_executor": (bool, ),
707742
"_host_index": (none_type, int),
708743
"_preload_content": (bool,),
709744
"_request_timeout": (none_type, float, (float,), [float], int, (int,), [int]),
@@ -853,6 +888,7 @@ def call_with_http_info(self, **kwargs):
853888
response_type=self.settings["response_type"],
854889
auth_settings=self.settings["auth"],
855890
async_req=kwargs["async_req"],
891+
async_threadpool_executor=kwargs.get("async_threadpool_executor", None),
856892
_check_type=kwargs["_check_return_type"],
857893
_return_http_data_only=kwargs["_return_http_data_only"],
858894
_preload_content=kwargs["_preload_content"],

pinecone/core/openapi/shared/configuration.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -469,3 +469,23 @@ def host(self, value):
469469
"""Fix base path."""
470470
self._base_path = value
471471
self.server_index = None
472+
473+
def __repr__(self):
474+
attrs = [
475+
f"host={self.host}",
476+
f"api_key=***",
477+
f"api_key_prefix={self.api_key_prefix}",
478+
f"access_token={self.access_token}",
479+
f"connection_pool_maxsize={self.connection_pool_maxsize}",
480+
f"username={self.username}",
481+
f"password={self.password}",
482+
f"discard_unknown_keys={self.discard_unknown_keys}",
483+
f"disabled_client_side_validations={self.disabled_client_side_validations}",
484+
f"server_index={self.server_index}",
485+
f"server_variables={self.server_variables}",
486+
f"server_operation_index={self.server_operation_index}",
487+
f"server_operation_variables={self.server_operation_variables}",
488+
f"ssl_ca_cert={self.ssl_ca_cert}",
489+
490+
]
491+
return f"Configuration({', '.join(attrs)})"

pinecone/data/index.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from tqdm.autonotebook import tqdm
22

33
import logging
4+
import json
45
from typing import Union, List, Optional, Dict, Any
56

67
from pinecone.config import ConfigBuilder
@@ -34,7 +35,9 @@
3435
from .features.bulk_import import ImportFeatureMixin
3536
from .vector_factory import VectorFactory
3637
from .query_results_aggregator import QueryResultsAggregator, QueryNamespacesResults
38+
3739
from multiprocessing.pool import ApplyResult
40+
from concurrent.futures import as_completed
3841

3942
from pinecone_plugin_interface import load_and_install as install_plugins
4043

@@ -67,6 +70,7 @@
6770
"_check_return_type",
6871
"_host_index",
6972
"async_req",
73+
"async_threadpool_executor",
7074
)
7175

7276

@@ -447,7 +451,7 @@ def query(
447451
**kwargs,
448452
)
449453

450-
if kwargs.get("async_req", False):
454+
if kwargs.get("async_req", False) or kwargs.get("async_threadpool_executor", False):
451455
return response
452456
else:
453457
return parse_query_response(response)
@@ -491,6 +495,7 @@ def _query(
491495
("sparse_vector", sparse_vector),
492496
]
493497
)
498+
494499
response = self._vector_api.query(
495500
QueryRequest(
496501
**args_dict,
@@ -566,7 +571,7 @@ def query_namespaces(
566571
aggregator = QueryResultsAggregator(top_k=overall_topk)
567572

568573
target_namespaces = set(namespaces) # dedup namespaces
569-
async_results = [
574+
async_futures = [
570575
self.query(
571576
vector=vector,
572577
namespace=ns,
@@ -575,14 +580,16 @@ def query_namespaces(
575580
include_values=include_values,
576581
include_metadata=include_metadata,
577582
sparse_vector=sparse_vector,
578-
async_req=True,
583+
async_threadpool_executor=True,
584+
_preload_content=False,
579585
**kwargs,
580586
)
581587
for ns in target_namespaces
582588
]
583589

584-
for result in async_results:
585-
response = result.get()
590+
for result in as_completed(async_futures):
591+
raw_result = result.result()
592+
response = json.loads(raw_result.data.decode("utf-8"))
586593
aggregator.add_results(response)
587594

588595
final_results = aggregator.get_results()

poetry.lock

Lines changed: 32 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,9 @@ pytest-asyncio = "0.15.1"
8282
pytest-cov = "2.10.1"
8383
pytest-mock = "3.6.1"
8484
pytest-timeout = "2.2.0"
85+
pytest-benchmark = [
86+
{ version = '5.0.0', python = ">=3.9,<4.0" }
87+
]
8588
urllib3_mock = "0.3.3"
8689
responses = ">=0.8.1"
8790
ddtrace = "^2.14.4"

tests/perf/test_query_namespaces.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import time
2+
import random
3+
import pytest
4+
from pinecone import Pinecone
5+
from pinecone.grpc import PineconeGRPC
6+
7+
latencies = []
8+
9+
10+
def call_n_threads(index):
11+
query_vec = [random.random() for i in range(1024)]
12+
start = time.time()
13+
combined_results = index.query_namespaces(
14+
vector=query_vec,
15+
namespaces=["ns1", "ns2", "ns3", "ns4"],
16+
include_values=False,
17+
include_metadata=True,
18+
filter={"publication_date": {"$eq": "Last3Months"}},
19+
top_k=1000,
20+
)
21+
finish = time.time()
22+
# print(f"Query took {finish-start} seconds")
23+
latencies.append(finish - start)
24+
25+
return combined_results
26+
27+
28+
class TestQueryNamespacesRest:
29+
@pytest.mark.parametrize("n_threads", [4])
30+
def test_query_namespaces_grpc(self, benchmark, n_threads):
31+
pc = PineconeGRPC()
32+
index = pc.Index(
33+
host="jen1024-dojoi3u.svc.apw5-4e34-81fa.pinecone.io", pool_threads=n_threads
34+
)
35+
benchmark.pedantic(call_n_threads, (index,), rounds=10, warmup_rounds=1, iterations=5)
36+
37+
@pytest.mark.parametrize("n_threads", [4])
38+
def test_query_namespaces_rest(self, benchmark, n_threads):
39+
pc = Pinecone()
40+
index = pc.Index(
41+
host="jen1024-dojoi3u.svc.apw5-4e34-81fa.pinecone.io",
42+
pool_threads=n_threads,
43+
connection_pool_maxsize=20,
44+
)
45+
benchmark.pedantic(call_n_threads, (index,), rounds=10, warmup_rounds=1, iterations=5)
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import random
2+
from pinecone.data.query_results_aggregator import QueryResultsAggregator
3+
4+
5+
def fake_results(i):
6+
matches = [
7+
{"id": f"id{i}", "score": random.random(), "values": [random.random() for _ in range(768)]}
8+
for _ in range(1000)
9+
]
10+
matches.sort(key=lambda x: x["score"], reverse=True)
11+
return {"namespace": f"ns{i}", "matches": matches}
12+
13+
14+
def aggregate_results(responses):
15+
ag = QueryResultsAggregator(1000)
16+
for response in responses:
17+
ag.add_results(response)
18+
return ag.get_results()
19+
20+
21+
class TestQueryResultsAggregatorPerf:
22+
def test_my_stuff(self, benchmark):
23+
responses = [fake_results(i) for i in range(10)]
24+
benchmark(aggregate_results, responses)

0 commit comments

Comments
 (0)