Skip to content

Commit 5faa4cb

Browse files
committed
start weaviate
1 parent 555b4ae commit 5faa4cb

File tree

19 files changed

+218
-34
lines changed

19 files changed

+218
-34
lines changed

benchmark/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1+
import os
12
from pathlib import Path
23

34
# The base directory points to the main directory of the project, so every
# path to data loaded from files can be expressed relative to this root.
BASE_DIRECTORY = Path(__file__).parent.parent
DATASETS_DIR = BASE_DIRECTORY.joinpath("datasets")

# CODE_DIR is the directory holding this package (kept as a plain string);
# ROOT_DIR is its parent directory, exposed as a Path.
CODE_DIR = os.path.dirname(__file__)
ROOT_DIR = Path(CODE_DIR).parent

benchmark/settings.py

Lines changed: 0 additions & 5 deletions
This file was deleted.

dataset_reader/ann_h5_reader.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import h5py
44

5+
from benchmark import DATASETS_DIR
56
from dataset_reader.base_reader import BaseReader, Record, Query
67

78

@@ -33,7 +34,6 @@ def read_data(self) -> Iterator[Record]:
3334

3435
if __name__ == '__main__':
3536
import os
36-
from benchmark.settings import DATASET_DIR
3737

3838
# h5py file 4 keys:
3939
# `train` - float vectors (num vectors 1183514)
@@ -42,7 +42,7 @@ def read_data(self) -> Iterator[Record]:
4242
# contains info about 100 nearest neighbors)
4343
# `distances` - float - distances for nearest neighbors for test vectors
4444

45-
test_path = os.path.join(DATASET_DIR, 'glove-100-angular', 'glove-100-angular.hdf5')
45+
test_path = os.path.join(DATASETS_DIR, 'glove-100-angular', 'glove-100-angular.hdf5')
4646
record = next(AnnH5Reader(test_path).read_data())
4747
print(record, end='\n\n')
4848

engine/base_client/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import List
44

55
from benchmark.dataset import Dataset
6-
from benchmark.settings import ROOT_DIR
6+
from benchmark import ROOT_DIR
77
from engine.base_client.configure import BaseConfigurator
88
from engine.base_client.search import BaseSearcher
99
from engine.base_client.upload import BaseUploader

engine/base_client/search.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ def search_all(
5454
parallel = self.search_params.pop("parallel", 1)
5555
top = self.search_params.pop("top", None)
5656

57+
self.setup_search()
58+
5759
search_one = functools.partial(self.__class__._search_one, top=top)
5860

5961
if parallel == 1:
@@ -88,3 +90,6 @@ def search_all(
8890

8991
def set_process_start_method(self, start_method):
    """Store ``start_method`` on this instance as ``MP_CONTEXT``.

    Args:
        start_method: Value to remember; presumably a multiprocessing start
            method name or context -- confirm against the code that reads
            ``MP_CONTEXT``.
    """
    self.MP_CONTEXT = start_method
93+
94+
def setup_search(self):
    """Hook that ``search_all`` calls before it starts issuing queries.

    This implementation intentionally does nothing and returns ``None``.
    Presumably engine-specific searchers override it -- confirm.
    """
    return None

engine/clients/client_factory.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from engine.clients.qdrant.search import QdrantSearcher
1010
from engine.clients.qdrant.upload import QdrantUploader
1111

12-
1312
ENGINE_CONFIGURATORS = {
1413
"qdrant": QdrantConfigurator,
1514
}
@@ -31,17 +30,17 @@ def _create_configurator(self, experiment) -> BaseConfigurator:
3130
engine_configurator_class = ENGINE_CONFIGURATORS[experiment["engine"]]
3231
engine_configurator = engine_configurator_class(
3332
self.host,
34-
experiment.get("collection_params", {}),
35-
experiment.get("connection_params", {}),
33+
collection_params={**experiment.get("collection_params", {})},
34+
connection_params={**experiment.get("connection_params", {})},
3635
)
3736
return engine_configurator
3837

3938
def _create_uploader(self, experiment) -> BaseUploader:
4039
engine_uploader_class = ENGINE_UPLOADERS[experiment["engine"]]
4140
engine_uploader = engine_uploader_class(
4241
self.host,
43-
experiment.get("connection_params", {}),
44-
experiment.get("upload_params", {}),
42+
connection_params={**experiment.get("connection_params", {})},
43+
upload_params={**experiment.get("upload_params", {})},
4544
)
4645
return engine_uploader
4746

@@ -51,7 +50,7 @@ def _create_searchers(self, experiment) -> List[BaseSearcher]:
5150
engine_searchers = [
5251
engine_searcher_class(
5352
self.host,
54-
connection_params=experiment.get("connection_params", {}),
53+
connection_params={**experiment.get("connection_params", {})},
5554
search_params=search_params,
5655
)
5756
for search_params in experiment.get("search_params", [{}])

engine/clients/qdrant/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
QDRANT_COLLECTION_NAME = "benchmark_collection"

engine/clients/qdrant/config.py

Lines changed: 0 additions & 1 deletion
This file was deleted.

engine/clients/qdrant/configure.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from engine.base_client.configure import BaseConfigurator
55
from engine.base_client.distances import Distance
6-
from engine.clients.qdrant.config import QDRANT_COLLECTION_NAME
6+
from engine.clients.qdrant import QDRANT_COLLECTION_NAME
77

88

99
class QdrantConfigurator(BaseConfigurator):

engine/clients/qdrant/search.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from qdrant_client.http import models as rest
55

66
from engine.base_client.search import BaseSearcher
7-
from engine.clients.qdrant.config import QDRANT_COLLECTION_NAME
7+
from engine.clients.qdrant import QDRANT_COLLECTION_NAME
88

99

1010
class QdrantSearcher(BaseSearcher):

engine/clients/qdrant/upload.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from qdrant_client.http.models import Batch, CollectionStatus
66

77
from engine.base_client.upload import BaseUploader
8-
from engine.clients.qdrant.config import QDRANT_COLLECTION_NAME
8+
from engine.clients.qdrant import QDRANT_COLLECTION_NAME
99

1010

1111
class QdrantUploader(BaseUploader):

engine/clients/weaviate/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Name of the Weaviate class that holds the benchmark vectors.
# Weaviate stores class names with a leading capital letter, so the constant
# must already be capitalised; otherwise comparisons against the names
# reported by the schema never match.
WEAVIATE_CLASS_NAME = "Benchmark"

# Port used to build the server URL when `connection_params` has no "port".
WEAVIATE_DEFAULT_PORT = 8080

engine/clients/weaviate/configure.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from weaviate import Client
2+
3+
from engine.base_client.configure import BaseConfigurator
4+
from engine.base_client.distances import Distance
5+
from engine.clients.weaviate import WEAVIATE_CLASS_NAME, WEAVIATE_DEFAULT_PORT
6+
7+
8+
class WeaviateConfigurator(BaseConfigurator):
    """Drop and create the Weaviate class used by the benchmark."""

    # Benchmark distance enum -> value of Weaviate's `vectorIndexConfig.distance`.
    DISTANCE_MAPPING = {
        Distance.L2_SQUARED: "l2-squared",
        Distance.COSINE: "cosine",
        Distance.DOT: "dot",
    }

    def __init__(self, host, collection_params: dict, connection_params: dict):
        """Build a Weaviate client for ``host``.

        Args:
            host: Host name or address of the Weaviate server.
            collection_params: Forwarded to ``BaseConfigurator.__init__``.
            connection_params: Forwarded to ``BaseConfigurator.__init__`` and
                then, minus an optional ``"port"`` entry, passed to
                ``weaviate.Client`` as keyword arguments. The ``"port"`` key
                is popped from this dict (default ``WEAVIATE_DEFAULT_PORT``).
        """
        super().__init__(host, collection_params, connection_params)
        port = connection_params.pop("port", WEAVIATE_DEFAULT_PORT)
        self.client = Client(f"http://{host}:{port}", **connection_params)

    def clean(self):
        """Delete the benchmark class from the schema if it is present."""
        schema = self.client.schema.get()
        if any(cl["class"] == WEAVIATE_CLASS_NAME for cl in schema["classes"]):
            self.client.schema.delete_class(WEAVIATE_CLASS_NAME)

    def recreate(
        self, distance, vector_size, collection_params,
    ):
        """Create the benchmark class with the requested index settings.

        Args:
            distance: Benchmark distance; translated through
                ``DISTANCE_MAPPING`` (an unmapped value yields ``None``).
            vector_size: Not used by this implementation.
            collection_params: Experiment settings. Its optional
                ``"vectorIndexConfig"`` mapping overrides the defaults below.
        """
        # The section is optional: the client factory passes `{}` when the
        # experiment defines no collection parameters.
        overrides = (collection_params or {}).get("vectorIndexConfig", {})
        self.client.schema.create_class({
            "class": WEAVIATE_CLASS_NAME,
            "vectorizer": "none",
            "properties": [],
            "vectorIndexConfig": {
                "vectorCacheMaxObjects": 1000000000,
                "distance": self.DISTANCE_MAPPING.get(distance),
                # Listed last so experiment values win over the defaults.
                **overrides,
            },
        })

engine/clients/weaviate/search.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from typing import Optional, Tuple, List
2+
3+
from weaviate import Client
4+
5+
from engine.base_client.search import BaseSearcher
6+
from engine.clients.qdrant import QDRANT_COLLECTION_NAME
7+
from engine.clients.weaviate import WEAVIATE_DEFAULT_PORT
8+
9+
10+
class WeaviateSearcher(BaseSearcher):
    """Run nearest-neighbour queries against a Weaviate server."""

    # Result count used when the caller passes ``top=None``.
    DEFAULT_TOP = 10

    search_params = {}
    client: Client = None
    # Name of the Weaviate class to query; assigned by `init_client`.
    collection: Optional[str] = None

    @classmethod
    def init_client(cls, host, connection_params: dict, search_params: dict):
        """Create the shared client and remember the search settings.

        Args:
            host: Host name or address of the Weaviate server.
            connection_params: Keyword arguments for ``weaviate.Client``; an
                optional ``"port"`` entry is popped from this dict
                (default ``WEAVIATE_DEFAULT_PORT``).
            search_params: Stored on the class as ``search_params``.
        """
        # Local import: the module header only imports the default port.
        from engine.clients.weaviate import WEAVIATE_CLASS_NAME

        port = connection_params.pop("port", WEAVIATE_DEFAULT_PORT)
        cls.client = Client(f"http://{host}:{port}", **connection_params)
        cls.search_params = search_params
        # Weaviate stores class names with a leading capital letter, so
        # normalise the name before using it in queries.
        cls.collection = WEAVIATE_CLASS_NAME[:1].upper() + WEAVIATE_CLASS_NAME[1:]

    @classmethod
    def conditions_to_filter(cls, _meta_conditions):
        """Translate metadata conditions into a filter (not implemented)."""
        # ToDo: implement
        return None

    @classmethod
    def search_one(cls, vector, meta_conditions, top) -> List[Tuple[int, float]]:
        """Return the nearest neighbours of ``vector``.

        Args:
            vector: Query vector.
            meta_conditions: Currently ignored; filtering is not implemented.
            top: Number of hits to request; ``DEFAULT_TOP`` when ``None``.

        Returns:
            A list of ``(id, certainty)`` pairs in the order Weaviate
            returned them.

        Raises:
            RuntimeError: If the response carries an ``"errors"`` entry.
        """
        limit = cls.DEFAULT_TOP if top is None else top
        res = (
            # NOTE(review): `certainty` may only be available for cosine
            # distance -- confirm for the other configured metrics.
            cls.client.query.get(cls.collection, ["_additional {id certainty}"])
            .with_near_vector({"vector": vector})
            .with_limit(limit)
            .do()
        )
        if "errors" in res:
            raise RuntimeError(f"Weaviate query failed: {res['errors']}")

        # A single class is queried, so take the only entry instead of
        # relying on the exact spelling of the class name in the response.
        hits = next(iter(res["data"]["Get"].values())) or []

        pairs = []
        for hit in hits:
            extra = hit["_additional"]
            # Weaviate ids are UUID strings; read the hex digits as an int.
            # Assumes objects are uploaded with uuid.UUID(int=<record id>)
            # -- TODO confirm once upload_batch is implemented.
            pairs.append((int(extra["id"].replace("-", ""), 16), extra["certainty"]))
        return pairs


# Backward-compatible alias for the original (copy-pasted) class name.
QdrantSearcher = WeaviateSearcher

engine/clients/weaviate/upload.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from typing import Optional, List
2+
3+
from weaviate import Client
4+
5+
from engine.base_client.upload import BaseUploader
6+
from engine.clients.weaviate import WEAVIATE_DEFAULT_PORT
7+
8+
9+
class WeaviateUploader(BaseUploader):
    """Uploader for Weaviate; the batch upload itself is still a stub."""

    client = None
    upload_params = {}

    @classmethod
    def init_client(cls, host, connection_params, upload_params):
        """Create the shared client and keep both parameter dicts on the class.

        An optional ``"port"`` entry is popped from ``connection_params``
        (default ``WEAVIATE_DEFAULT_PORT``); the remaining entries are passed
        to ``weaviate.Client`` as keyword arguments.
        """
        endpoint = f"http://{host}:{connection_params.pop('port', WEAVIATE_DEFAULT_PORT)}"
        cls.client = Client(endpoint, **connection_params)

        cls.connection_params = connection_params
        cls.upload_params = upload_params

    @classmethod
    def upload_batch(
        cls,
        ids: List[int],
        vectors: List[list],
        metadata: Optional[List[dict]],
    ):
        """Placeholder: accepts a batch but does nothing with it yet."""
        return None
29+
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
[
2+
{
3+
"name": "weaviate-default",
4+
"engine": "weaviate",
5+
"connection_params": {},
6+
"collection_params": {
7+
"vectorIndexConfig": {
8+
"ef": 100,
9+
"efConstruction": 100,
10+
"maxConnections": 16
11+
}
12+
},
13+
"search_params": [
14+
{"parallel": 1},
15+
{"parallel": 2},
16+
{"parallel": 4},
17+
{"parallel": 100}
18+
]
19+
}
20+
]

0 commit comments

Comments
 (0)