
Commit 5674ab2

Merge pull request #29 from qdrant/add-filtering-support

Add filtering support

2 parents: 036e4b3 + 7a0a35e


45 files changed: +1413 -283 lines

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 .idea/
+.pytest_cache/
 __pycache__
 *.pyc
 NOTES.md

benchmark/dataset.py

Lines changed: 9 additions & 5 deletions
@@ -2,10 +2,11 @@
 import shutil
 import tarfile
 import urllib.request
-from dataclasses import dataclass
-from typing import Optional
+from dataclasses import dataclass, field
+from typing import Dict, Optional
 
 from benchmark import DATASETS_DIR
+from dataset_reader.ann_compound_reader import AnnCompoundReader
 from dataset_reader.ann_h5_reader import AnnH5Reader
 from dataset_reader.base_reader import BaseReader
 from dataset_reader.json_reader import JSONReader
@@ -19,9 +20,10 @@ class DatasetConfig:
     type: str
     path: str
     link: Optional[str] = None
+    schema: Optional[Dict[str, str]] = field(default_factory=dict)
 
 
-READER_TYPE = {"h5": AnnH5Reader, "jsonl": JSONReader}
+READER_TYPE = {"h5": AnnH5Reader, "jsonl": JSONReader, "tar": AnnCompoundReader}
 
 
 class Dataset:
@@ -39,9 +41,11 @@ def download(self):
         print(f"Downloading {self.config.link}...")
         tmp_path, _ = urllib.request.urlretrieve(self.config.link)
 
-        if tmp_path.endswith(".tgz") or tmp_path.endswith(".tar.gz"):
+        if self.config.link.endswith(".tgz") or self.config.link.endswith(
+            ".tar.gz"
+        ):
             print(f"Extracting: {tmp_path} -> {target_path}")
-            (DATASETS_DIR / self.config.path).mkdir(exist_ok=True)
+            (DATASETS_DIR / self.config.path).mkdir(exist_ok=True, parents=True)
             file = tarfile.open(tmp_path)
             file.extractall(target_path)
             file.close()
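
The "tar" entry in READER_TYPE routes the new compound datasets to AnnCompoundReader, and DatasetConfig gains an optional schema describing the payload fields. A minimal sketch of that lookup; it assumes DatasetConfig also declares the name, vector_size and distance fields that appear in datasets.json, since the hunk above starts below them:

from benchmark.dataset import READER_TYPE, DatasetConfig

# Values copied from one of the new datasets.json entries, for illustration only.
config = DatasetConfig(
    name="arxiv-titles-384-angular-filters",
    vector_size=384,
    distance="cosine",
    type="tar",
    path="arxiv-titles-384-angular/arxiv",
    link="https://storage.googleapis.com/ann-filtered-benchmark/datasets/arxiv.tar.gz",
    schema={"update_date_ts": "int", "labels": "keyword", "submitter": "keyword"},
)

# "tar" now resolves to AnnCompoundReader; "h5" and "jsonl" keep their existing readers.
reader_class = READER_TYPE[config.type]
print(reader_class.__name__)  # AnnCompoundReader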

dataset_reader/ann_compound_reader.py

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+import json
+from typing import Iterator, List
+
+import numpy as np
+
+from dataset_reader.base_reader import Query
+from dataset_reader.json_reader import JSONReader
+
+
+class AnnCompoundReader(JSONReader):
+    """
+    A reader created specifically to read the format used in
+    https://github.com/qdrant/ann-filtering-benchmark-datasets, in which vectors
+    and their metadata are stored in separate files.
+    """
+
+    VECTORS_FILE = "vectors.npy"
+    QUERIES_FILE = "tests.jsonl"
+
+    def read_vectors(self) -> Iterator[List[float]]:
+        vectors = np.load(self.path / self.VECTORS_FILE)
+        for vector in vectors:
+            if self.normalize:
+                vector = vector / np.linalg.norm(vector)
+            yield vector.tolist()
+
+    def read_queries(self) -> Iterator[Query]:
+        with open(self.path / self.QUERIES_FILE) as payloads_fp:
+            for idx, row in enumerate(payloads_fp):
+                row_json = json.loads(row)
+                vector = np.array(row_json["query"])
+                if self.normalize:
+                    vector /= np.linalg.norm(vector)
+                yield Query(
+                    vector=vector.tolist(),
+                    meta_conditions=row_json["conditions"],
+                    expected_result=row_json["closest_ids"],
+                    expected_scores=row_json["closest_scores"],
+                )
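
read_queries expects every line of tests.jsonl to be a JSON object with query, conditions, closest_ids and closest_scores keys; the exact structure of conditions is defined by the dataset repository rather than by this commit. A rough usage sketch, assuming the dataset has already been downloaded and extracted (the path below is illustrative):

from pathlib import Path

from dataset_reader.ann_compound_reader import AnnCompoundReader

# Hypothetical local path; in practice it comes from DatasetConfig.path after Dataset.download().
reader = AnnCompoundReader(Path("datasets/arxiv-titles-384-angular/arxiv"), normalize=True)

for query in reader.read_queries():
    # meta_conditions carries the filter that each engine translates into its own filter DSL.
    print(query.meta_conditions, query.expected_result[:5])
    break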

dataset_reader/json_reader.py

Lines changed: 11 additions & 11 deletions
@@ -6,46 +6,46 @@
 
 from dataset_reader.base_reader import BaseReader, Query, Record
 
-VECTORS_FILE = "vectors.jsonl"
-PAYLOADS_FILE = "payloads.jsonl"
-QUERIES_FILE = "queries.jsonl"
-NEIGHBOURS_FILE = "neighbours.jsonl"
-
 
 class JSONReader(BaseReader):
+    VECTORS_FILE = "vectors.jsonl"
+    PAYLOADS_FILE = "payloads.jsonl"
+    QUERIES_FILE = "queries.jsonl"
+    NEIGHBOURS_FILE = "neighbours.jsonl"
+
     def __init__(self, path: Path, normalize=False):
         self.path = path
         self.normalize = normalize
 
     def read_payloads(self) -> Iterator[dict]:
-        if not (self.path / PAYLOADS_FILE).exists():
+        if not (self.path / self.PAYLOADS_FILE).exists():
             while True:
                 yield {}
-        with open(self.path / PAYLOADS_FILE, "r") as json_fp:
+        with open(self.path / self.PAYLOADS_FILE, "r") as json_fp:
             for json_line in json_fp:
                 line = json.loads(json_line)
                 yield line
 
     def read_vectors(self) -> Iterator[List[float]]:
-        with open(self.path / VECTORS_FILE, "r") as json_fp:
+        with open(self.path / self.VECTORS_FILE, "r") as json_fp:
             for json_line in json_fp:
                 vector = json.loads(json_line)
                 if self.normalize:
                     vector = vector / np.linalg.norm(vector)
                 yield vector
 
     def read_neighbours(self) -> Iterator[Optional[List[int]]]:
-        if not (self.path / NEIGHBOURS_FILE).exists():
+        if not (self.path / self.NEIGHBOURS_FILE).exists():
             while True:
                 yield None
 
-        with open(self.path / NEIGHBOURS_FILE, "r") as json_fp:
+        with open(self.path / self.NEIGHBOURS_FILE, "r") as json_fp:
             for json_line in json_fp:
                 line = json.loads(json_line)
                 yield line
 
     def read_query_vectors(self) -> Iterator[List[float]]:
-        with open(self.path / QUERIES_FILE, "r") as json_fp:
+        with open(self.path / self.QUERIES_FILE, "r") as json_fp:
             for json_line in json_fp:
                 vector = json.loads(json_line)
                 if self.normalize:
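
Turning the module-level file names into class attributes is what lets AnnCompoundReader above swap in vectors.npy and tests.jsonl by overriding two attributes instead of re-implementing the reader. The same pattern, sketched with a hypothetical reader name:

from dataset_reader.json_reader import JSONReader


class MyJSONReader(JSONReader):
    # Illustrative subclass: only the file names change, all read_* methods are inherited.
    VECTORS_FILE = "my_vectors.jsonl"
    QUERIES_FILE = "my_queries.jsonl"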

datasets/datasets.json

Lines changed: 143 additions & 0 deletions
@@ -39,6 +39,149 @@
     "path": "gist-960-angular/gist-960-angular.hdf5",
     "link": "http://ann-benchmarks.com/gist-960-euclidean.hdf5"
   },
+  {
+    "name": "h-and-m-2048-angular-filters",
+    "vector_size": 2048,
+    "distance": "cosine",
+    "type": "tar",
+    "path": "h-and-m-2048-angular/hnm",
+    "link": "https://storage.googleapis.com/ann-filtered-benchmark/datasets/hnm.tgz",
+    "schema": {
+      "product_code": "int",
+      "prod_name": "keyword",
+      "product_type_no": "int",
+      "product_type_name": "keyword",
+      "product_group_name": "keyword",
+      "graphical_appearance_no": "int",
+      "graphical_appearance_name": "keyword",
+      "colour_group_code": "int",
+      "colour_group_name": "keyword",
+      "perceived_colour_value_id": "int",
+      "perceived_colour_value_name": "keyword",
+      "perceived_colour_master_id": "int",
+      "perceived_colour_master_name": "keyword",
+      "department_no": "int",
+      "department_name": "keyword",
+      "index_code": "keyword",
+      "index_name": "keyword",
+      "index_group_no": "int",
+      "index_group_name": "keyword",
+      "section_no": "int",
+      "section_name": "keyword",
+      "garment_group_no": "int",
+      "garment_group_name": "keyword",
+      "detail_desc": "text"
+    }
+  },
+  {
+    "name": "arxiv-titles-384-angular-filters",
+    "vector_size": 384,
+    "distance": "cosine",
+    "type": "tar",
+    "path": "arxiv-titles-384-angular/arxiv",
+    "link": "https://storage.googleapis.com/ann-filtered-benchmark/datasets/arxiv.tar.gz",
+    "schema": {
+      "update_date_ts": "int",
+      "labels": "keyword",
+      "submitter": "keyword"
+    }
+  },
+  {
+    "name": "random-match-keyword-100-angular-filters",
+    "vector_size": 100,
+    "distance": "cosine",
+    "type": "tar",
+    "path": "random-match-keyword-100-angular/random_keywords_1m",
+    "link": "https://storage.googleapis.com/ann-filtered-benchmark/datasets/random_keywords_1m.tgz",
+    "schema": {
+      "a": "keyword",
+      "b": "keyword"
+    }
+  },
+  {
+    "name": "random-match-int-100-angular-filters",
+    "vector_size": 100,
+    "distance": "cosine",
+    "type": "tar",
+    "path": "random-match-int-100-angular/random_ints_1m",
+    "link": "https://storage.googleapis.com/ann-filtered-benchmark/datasets/random_ints_1m.tgz",
+    "schema": {
+      "a": "int",
+      "b": "int"
+    }
+  },
+  {
+    "name": "random-range-100-angular-filters",
+    "vector_size": 100,
+    "distance": "cosine",
+    "type": "tar",
+    "path": "random-range-100-angular/random_float_1m",
+    "link": "https://storage.googleapis.com/ann-filtered-benchmark/datasets/random_float_1m.tgz",
+    "schema": {
+      "a": "float",
+      "b": "float"
+    }
+  },
+  {
+    "name": "random-geo-radius-100-angular-filters",
+    "vector_size": 100,
+    "distance": "cosine",
+    "type": "tar",
+    "path": "random-geo-radius-100-angular/random_geo_1m",
+    "link": "https://storage.googleapis.com/ann-filtered-benchmark/datasets/random_geo_1m.tgz",
+    "schema": {
+      "a": "geo",
+      "b": "geo"
+    }
+  },
+  {
+    "name": "random-match-keyword-2048-angular-filters",
+    "vector_size": 2048,
+    "distance": "cosine",
+    "type": "tar",
+    "path": "random-match-keyword-2048-angular/random_keywords_100k",
+    "link": "https://storage.googleapis.com/ann-filtered-benchmark/datasets/random_keywords_100k.tgz",
+    "schema": {
+      "a": "keyword",
+      "b": "keyword"
+    }
+  },
+  {
+    "name": "random-match-int-2048-angular-filters",
+    "vector_size": 2048,
+    "distance": "cosine",
+    "type": "tar",
+    "path": "random-match-int-2048-angular/random_ints_100k",
+    "link": "https://storage.googleapis.com/ann-filtered-benchmark/datasets/random_ints_100k.tgz",
+    "schema": {
+      "a": "int",
+      "b": "int"
+    }
+  },
+  {
+    "name": "random-range-2048-angular-filters",
+    "vector_size": 2048,
+    "distance": "cosine",
+    "type": "tar",
+    "path": "random-range-2048-angular/random_float_100k",
+    "link": "https://storage.googleapis.com/ann-filtered-benchmark/datasets/random_float_100k.tgz",
+    "schema": {
+      "a": "float",
+      "b": "float"
+    }
+  },
+  {
+    "name": "random-geo-radius-2048-angular-filters",
+    "vector_size": 2048,
+    "distance": "cosine",
+    "type": "tar",
+    "path": "random-geo-radius-2048-angular/random_geo_100k",
+    "link": "https://storage.googleapis.com/ann-filtered-benchmark/datasets/random_geo_100k.tgz",
+    "schema": {
+      "a": "geo",
+      "b": "geo"
+    }
+  },
   {
     "name": "random-100",
     "vector_size": 100,

engine/base_client/client.py

Lines changed: 5 additions & 5 deletions
@@ -33,10 +33,12 @@ def save_search_results(
         experiments_file = (
             f"{self.name}-{dataset_name}-search-{search_id}-{timestamp}.json"
         )
-        with open(RESULTS_DIR / experiments_file, "w") as out:
+        result_path = RESULTS_DIR / experiments_file
+        with open(result_path, "w") as out:
             out.write(
                 json.dumps({"params": search_params, "results": results}, indent=2)
             )
+        return result_path
 
     def save_upload_results(
         self, dataset_name: str, results: dict, upload_params: dict
@@ -60,10 +62,7 @@ def run_experiment(self, dataset: Dataset, skip_upload: bool = False):
 
         if not skip_upload:
             print("Experiment stage: Configure")
-            self.configurator.configure(
-                distance=dataset.config.distance,
-                vector_size=dataset.config.vector_size,
-            )
+            self.configurator.configure(dataset)
 
             print("Experiment stage: Upload")
             upload_stats = self.uploader.upload(
@@ -88,3 +87,4 @@ def run_experiment(self, dataset: Dataset, skip_upload: bool = False):
             dataset.config.name, search_stats, search_id, search_params
         )
         print("Experiment stage: Done")
+        print("Results saved to: ", RESULTS_DIR)

engine/base_client/configure.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from typing import Optional
22

3+
from benchmark.dataset import Dataset
4+
35

46
class BaseConfigurator:
57
DISTANCE_MAPPING = {}
@@ -12,12 +14,12 @@ def __init__(self, host, collection_params: dict, connection_params: dict):
1214
def clean(self):
1315
raise NotImplementedError()
1416

15-
def recreate(self, distance, vector_size, collection_params):
17+
def recreate(self, dataset: Dataset, collection_params):
1618
raise NotImplementedError()
1719

18-
def configure(self, distance, vector_size) -> Optional[dict]:
20+
def configure(self, dataset: Dataset) -> Optional[dict]:
1921
self.clean()
20-
return self.recreate(distance, vector_size, self.collection_params) or {}
22+
return self.recreate(dataset, self.collection_params) or {}
2123

2224
def execution_params(self, distance, vector_size) -> dict:
2325
return {}
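
Passing the whole Dataset down to recreate gives concrete configurators access not only to distance and vector_size but also to the new schema field, which is what engine-specific payload indexes for filtered search need. A rough sketch of what a subclass might look like; the class and its method bodies are illustrative and not taken from this commit:

from benchmark.dataset import Dataset
from engine.base_client.configure import BaseConfigurator


class ExampleConfigurator(BaseConfigurator):
    # Hypothetical engine adapter, shown only to illustrate the new signature.
    DISTANCE_MAPPING = {"cosine": "cosine", "l2": "euclidean"}

    def clean(self):
        pass  # drop any previously created collection here

    def recreate(self, dataset: Dataset, collection_params):
        distance = self.DISTANCE_MAPPING[dataset.config.distance]
        print(f"create collection: dim={dataset.config.vector_size}, distance={distance}")
        # The schema drives per-field index creation for filtered search.
        for field_name, field_type in (dataset.config.schema or {}).items():
            print(f"create payload index: {field_name} -> {field_type}")
        return {}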
