Skip to content

Commit 0b40783

Browse files
Merge pull request #28 from qdrant/pre-commit-ci-update-config
[pre-commit.ci] pre-commit suggestions
2 parents aaa3bb1 + 6f074a4 commit 0b40783

File tree

17 files changed: 117 additions (+117) and 85 deletions (-85)

17 files changed

+117
-85
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@ repos:
1717
- id: check-added-large-files
1818

1919
- repo: https://github.com/psf/black
20-
rev: 22.6.0
20+
rev: 22.8.0
2121
hooks:
2222
- id: black
2323
name: "Black: The uncompromising Python code formatter"

benchmark/config_read.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
import json
33
import os
44

5-
from benchmark import ROOT_DIR, DATASETS_DIR
5+
from benchmark import DATASETS_DIR, ROOT_DIR
66

77

88
def read_engine_configs() -> dict:

engine/base_client/client.py

Lines changed: 23 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -14,19 +14,19 @@
1414

1515
class BaseClient:
1616
def __init__(
17-
self,
18-
name: str, # name of the experiment
19-
configurator: BaseConfigurator,
20-
uploader: BaseUploader,
21-
searchers: List[BaseSearcher],
17+
self,
18+
name: str, # name of the experiment
19+
configurator: BaseConfigurator,
20+
uploader: BaseUploader,
21+
searchers: List[BaseSearcher],
2222
):
2323
self.name = name
2424
self.configurator = configurator
2525
self.uploader = uploader
2626
self.searchers = searchers
2727

2828
def save_search_results(
29-
self, dataset_name: str, results: dict, search_id: int, search_params: dict
29+
self, dataset_name: str, results: dict, search_id: int, search_params: dict
3030
):
3131
now = datetime.now()
3232
timestamp = now.strftime("%Y-%m-%d-%H-%M-%S")
@@ -38,7 +38,9 @@ def save_search_results(
3838
json.dumps({"params": search_params, "results": results}, indent=2)
3939
)
4040

41-
def save_upload_results(self, dataset_name: str, results: dict, upload_params: dict):
41+
def save_upload_results(
42+
self, dataset_name: str, results: dict, upload_params: dict
43+
):
4244
now = datetime.now()
4345
timestamp = now.strftime("%Y-%m-%d-%H-%M-%S")
4446
experiments_file = f"{self.name}-{dataset_name}-upload-{timestamp}.json"
@@ -51,8 +53,8 @@ def save_upload_results(self, dataset_name: str, results: dict, upload_params: d
5153

5254
def run_experiment(self, dataset: Dataset, skip_upload: bool = False):
5355
execution_params = self.configurator.execution_params(
54-
distance=dataset.config.distance,
55-
vector_size=dataset.config.vector_size)
56+
distance=dataset.config.distance, vector_size=dataset.config.vector_size
57+
)
5658

5759
reader = dataset.get_reader(execution_params.get("normalize", False))
5860

@@ -65,18 +67,23 @@ def run_experiment(self, dataset: Dataset, skip_upload: bool = False):
6567

6668
print("Experiment stage: Upload")
6769
upload_stats = self.uploader.upload(
68-
distance=dataset.config.distance,
69-
records=reader.read_data()
70+
distance=dataset.config.distance, records=reader.read_data()
71+
)
72+
self.save_upload_results(
73+
dataset.config.name,
74+
upload_stats,
75+
upload_params={
76+
**self.uploader.upload_params,
77+
**self.configurator.collection_params,
78+
},
7079
)
71-
self.save_upload_results(dataset.config.name, upload_stats, upload_params={
72-
**self.uploader.upload_params,
73-
**self.configurator.collection_params
74-
})
7580

7681
print("Experiment stage: Search")
7782
for search_id, searcher in enumerate(self.searchers):
7883
search_params = {**searcher.search_params}
79-
search_stats = searcher.search_all(dataset.config.distance, reader.read_queries())
84+
search_stats = searcher.search_all(
85+
dataset.config.distance, reader.read_queries()
86+
)
8087
self.save_search_results(
8188
dataset.config.name, search_stats, search_id, search_params
8289
)

engine/base_client/search.py

Lines changed: 16 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,9 @@ def __init__(self, host, connection_params, search_params):
2020
self.search_params = search_params
2121

2222
@classmethod
23-
def init_client(cls, host: str, distance, connection_params: dict, search_params: dict):
23+
def init_client(
24+
cls, host: str, distance, connection_params: dict, search_params: dict
25+
):
2426
raise NotImplementedError()
2527

2628
@classmethod
@@ -63,20 +65,29 @@ def search_all(
6365
top = self.search_params.pop("top", None)
6466

6567
# setup_search may require initialized client
66-
self.init_client(self.host, distance, self.connection_params, self.search_params)
68+
self.init_client(
69+
self.host, distance, self.connection_params, self.search_params
70+
)
6771
self.setup_search()
6872

6973
search_one = functools.partial(self.__class__._search_one, top=top)
7074

7175
if parallel == 1:
72-
precisions, latencies = list(zip(*[search_one(query) for query in tqdm.tqdm(queries)]))
76+
precisions, latencies = list(
77+
zip(*[search_one(query) for query in tqdm.tqdm(queries)])
78+
)
7379
else:
7480
ctx = get_context(self.get_mp_start_method())
7581

7682
with ctx.Pool(
7783
processes=parallel,
7884
initializer=self.__class__.init_client,
79-
initargs=(self.host, distance, self.connection_params, self.search_params),
85+
initargs=(
86+
self.host,
87+
distance,
88+
self.connection_params,
89+
self.search_params,
90+
),
8091
) as pool:
8192
precisions, latencies = list(
8293
zip(*pool.imap_unordered(search_one, iterable=tqdm.tqdm(queries)))
@@ -101,4 +112,4 @@ def setup_search(self):
101112
pass
102113

103114
def post_search(self):
104-
pass
115+
pass

engine/base_client/upload.py

Lines changed: 22 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -25,31 +25,40 @@ def init_client(cls, host, distance, connection_params: dict, upload_params: dic
2525
raise NotImplementedError()
2626

2727
def upload(
28-
self,
29-
distance,
30-
records: Iterable[Record],
28+
self,
29+
distance,
30+
records: Iterable[Record],
3131
) -> dict:
3232
latencies = []
3333
start = time.perf_counter()
3434
parallel = self.upload_params.pop("parallel", 1)
3535
batch_size = self.upload_params.pop("batch_size", 64)
3636

37-
self.init_client(self.host, distance, self.connection_params, self.upload_params)
37+
self.init_client(
38+
self.host, distance, self.connection_params, self.upload_params
39+
)
3840

3941
if parallel == 1:
4042
for batch in iter_batches(tqdm.tqdm(records), batch_size):
4143
latencies.append(self._upload_batch(batch))
4244
else:
4345
ctx = get_context(self.get_mp_start_method())
4446
with ctx.Pool(
45-
processes=int(parallel),
46-
initializer=self.__class__.init_client,
47-
initargs=(self.host, distance, self.connection_params, self.upload_params),
47+
processes=int(parallel),
48+
initializer=self.__class__.init_client,
49+
initargs=(
50+
self.host,
51+
distance,
52+
self.connection_params,
53+
self.upload_params,
54+
),
4855
) as pool:
49-
latencies = list(pool.imap(
50-
self.__class__._upload_batch,
51-
iter_batches(tqdm.tqdm(records), batch_size),
52-
))
56+
latencies = list(
57+
pool.imap(
58+
self.__class__._upload_batch,
59+
iter_batches(tqdm.tqdm(records), batch_size),
60+
)
61+
)
5362

5463
upload_time = time.perf_counter() - start
5564

@@ -68,7 +77,7 @@ def upload(
6877

6978
@classmethod
7079
def _upload_batch(
71-
cls, batch: Tuple[List[int], List[list], List[Optional[dict]]]
80+
cls, batch: Tuple[List[int], List[list], List[Optional[dict]]]
7281
) -> float:
7382
ids, vectors, metadata = batch
7483
start = time.perf_counter()
@@ -81,6 +90,6 @@ def post_upload(cls, distance):
8190

8291
@classmethod
8392
def upload_batch(
84-
cls, ids: List[int], vectors: List[list], metadata: List[Optional[dict]]
93+
cls, ids: List[int], vectors: List[list], metadata: List[Optional[dict]]
8594
):
8695
raise NotImplementedError()

engine/clients/elasticsearch/configure.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -36,7 +36,9 @@ def __init__(self, host, collection_params: dict, connection_params: dict):
3636

3737
def clean(self):
3838
try:
39-
self.client.indices.delete(index=ELASTIC_INDEX, timeout="5m", master_timeout="5m")
39+
self.client.indices.delete(
40+
index=ELASTIC_INDEX, timeout="5m", master_timeout="5m"
41+
)
4042
except NotFoundError:
4143
pass
4244

engine/clients/elasticsearch/search.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
1+
import multiprocessing as mp
12
import uuid
23
from typing import List, Tuple
3-
import multiprocessing as mp
44

55
from elasticsearch import Elasticsearch
66

@@ -14,7 +14,6 @@
1414

1515

1616
class ClosableElastic(Elasticsearch):
17-
1817
def __del__(self):
1918
self.close()
2019

@@ -25,7 +24,7 @@ class ElasticSearcher(BaseSearcher):
2524

2625
@classmethod
2726
def get_mp_start_method(cls):
28-
return 'forkserver' if 'forkserver' in mp.get_all_start_methods() else 'spawn'
27+
return "forkserver" if "forkserver" in mp.get_all_start_methods() else "spawn"
2928

3029
@classmethod
3130
def init_client(cls, host, distance, connection_params: dict, search_params: dict):

engine/clients/elasticsearch/upload.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
1-
import uuid
21
import multiprocessing as mp
2+
import uuid
33
from typing import List, Optional
44

55
from elasticsearch import Elasticsearch
@@ -14,7 +14,6 @@
1414

1515

1616
class ClosableElastic(Elasticsearch):
17-
1817
def __del__(self):
1918
self.close()
2019

@@ -25,7 +24,7 @@ class ElasticUploader(BaseUploader):
2524

2625
@classmethod
2726
def get_mp_start_method(cls):
28-
return 'forkserver' if 'forkserver' in mp.get_all_start_methods() else 'spawn'
27+
return "forkserver" if "forkserver" in mp.get_all_start_methods() else "spawn"
2928

3029
@classmethod
3130
def init_client(cls, host, distance, connection_params, upload_params):

engine/clients/milvus/config.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -10,4 +10,4 @@
1010
# Milvus does not support cosine. Cosine is equal to IP of normalized vectors
1111
Distance.COSINE: "IP"
1212
# Jaccard, Tanimoto, Hamming distance, Superstructure and Substructure are also available
13-
}
13+
}

engine/clients/milvus/configure.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,6 @@
1818

1919

2020
class MilvusConfigurator(BaseConfigurator):
21-
2221
def __init__(self, host, collection_params: dict, connection_params: dict):
2322
super().__init__(host, collection_params, connection_params)
2423
self.client = connections.connect(

0 commit comments

Comments
 (0)