Skip to content

Commit 0e8a8e1

Browse files
committed
new: add type annotations, add comment for hdf5 dataset
1 parent f1d873d commit 0e8a8e1

File tree

3 files changed

+23
-9
lines changed

3 files changed

+23
-9
lines changed

dataset_reader/h5_reader.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,14 @@ def read_data(self) -> Iterator[Record]:
2727
if __name__ == '__main__':
2828
import os
2929
from benchmark.settings import DATASET_DIR
30+
31+
# h5py file 4 keys:
32+
# `train` - float vectors (num vectors 1183514)
33+
# `test` - float vectors (num vectors 10000)
34+
# `neighbors` - int - indices of nearest neighbors for test (num items 10k, each item
35+
# contains info about 100 nearest neighbors)
36+
# `distances` - float - distances for nearest neighbors for test vectors
37+
3038
test_path = os.path.join(DATASET_DIR, 'glove-100-angular', 'glove-100-angular.hdf5')
3139
record = next(H5Reader(test_path).read_data())
32-
print(record)
40+
print(record)

engine/base_client/search.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@ class BaseSearcher:
99
MP_CONTEXT = None
1010

1111
@classmethod
12-
def init_client(cls, host, connection_params, search_params):
12+
def init_client(cls, host: str, connection_params: dict, search_params: dict):
1313
cls.search_params = search_params
1414
raise NotImplementedError()
1515

1616
@classmethod
17-
def search_one(cls, vector, meta_conditions) -> List[Tuple[int, float]]:
17+
def search_one(cls, vector: List[float], meta_conditions) -> List[Tuple[int, float]]:
1818
raise NotImplementedError()
1919

2020
@classmethod

engine/base_client/upload.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,14 @@ def init_client(cls, host, connection_params: dict):
1616
raise NotImplementedError()
1717

1818
@classmethod
19-
def upload(cls, url, records: Iterable[Record], batch_size, parallel, connection_params):
19+
def upload(
20+
cls,
21+
url: str,
22+
records: Iterable[Record],
23+
batch_size: int,
24+
parallel: int,
25+
connection_params: dict,
26+
) -> List[float]:
2027
latencies = []
2128

2229
if parallel == 1:
@@ -38,16 +45,15 @@ def upload(cls, url, records: Iterable[Record], batch_size, parallel, connection
3845
return latencies
3946

4047
@classmethod
41-
def _upload_batch(cls, ids, vectors, metadata) -> float:
48+
def _upload_batch(
49+
cls, ids: List[int], vectors: List[list], metadata: List[Optional[dict]]
50+
) -> float:
4251
start = time.perf_counter()
4352
cls.upload_batch(ids, vectors, metadata)
4453
return time.perf_counter() - start
4554

4655
@classmethod
4756
def upload_batch(
48-
cls,
49-
ids: List[int],
50-
vectors: List[list],
51-
metadata: List[Optional[dict]]
57+
cls, ids: List[int], vectors: List[list], metadata: List[Optional[dict]]
5258
):
5359
raise NotImplementedError()

0 commit comments

Comments
 (0)