Skip to content

Commit 5851625

Browse files
authored
Mongo (#6)
1 parent 80e304f commit 5851625

File tree

8 files changed

+221
-8
lines changed

8 files changed

+221
-8
lines changed

labml_db/driver/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88

99
class DbDriver:
10-
def __init__(self, serializer: 'Serializer', model_cls: Type['Model']):
10+
def __init__(self, serializer: Optional['Serializer'], model_cls: Type['Model']):
1111
self.model_name = model_cls.__name__
1212
self._serializer = serializer
1313

labml_db/driver/mongo.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
from typing import List, Type, TYPE_CHECKING, Optional, Dict
2+
3+
from labml_db.serializer.utils import encode_keys, decode_keys
4+
5+
from . import DbDriver
6+
from ..types import ModelDict
7+
8+
if TYPE_CHECKING:
9+
import pymongo
10+
from ..model import Model
11+
12+
13+
class MongoDbDriver(DbDriver):
    """Model storage driver that keeps each model as a MongoDB document.

    One collection is used per model class, named after ``self.model_name``.
    The model key is stored as the document ``_id``; the rest of the document
    is the model dictionary with its keys passed through ``encode_keys`` on
    the way in and ``decode_keys`` on the way out.
    """

    def __init__(self, model_cls: Type['Model'], db: 'pymongo.mongo_client.database.Database'):
        """Bind the driver to the collection for ``model_cls`` inside ``db``."""
        # No serializer is passed: dictionaries are handed to pymongo directly.
        super().__init__(None, model_cls)
        self._collection = db[self.model_name]

    def _to_obj_id(self, key: str):
        """Convert a model key into the value stored as the document ``_id``."""
        # The full key string is used as ``_id`` (no ``ObjectId`` conversion).
        return str(key)

    def _to_key(self, mongo_id: any):
        """Convert a document ``_id`` back into a model key (identity mapping)."""
        return mongo_id

    def _load_data(self, data: ModelDict):
        """Turn a MongoDB document into a model dictionary.

        Returns ``None`` when ``data`` is ``None`` (document not found). The
        ``_id`` field is dropped and the remaining keys are decoded. The
        document passed in is left untouched so the same document can be
        decoded more than once.
        """
        if data is None:
            return None
        fields = {name: value for name, value in data.items() if name != '_id'}
        return decode_keys(fields)

    def _dump_data(self, key: str, data: ModelDict):
        """Build the MongoDB document for ``data`` stored under ``key``."""
        # Work on a shallow copy so the caller's dictionary does not gain ``_id``.
        doc: Dict = encode_keys(data.copy())
        doc['_id'] = self._to_obj_id(key)

        return doc

    def mload_dict(self, key: List[str]) -> List[Optional[ModelDict]]:
        """Load several models with one query.

        The result has one entry per requested key, in request order; keys
        with no stored document yield ``None``. A key that appears more than
        once in ``key`` is filled in at every one of its positions.
        """
        obj_keys = [self._to_obj_id(k) for k in key]
        res: List[Optional[ModelDict]] = [None] * len(obj_keys)
        if not obj_keys:
            return res

        # Each id maps to every position it was requested at, so duplicated
        # keys do not leave earlier positions empty.
        positions: Dict[str, List[int]] = {}
        for i, obj_key in enumerate(obj_keys):
            positions.setdefault(obj_key, []).append(i)

        for doc in self._collection.find({'_id': {'$in': list(positions)}}):
            for i in positions[doc['_id']]:
                # Decode once per position so duplicates do not share a dict.
                res[i] = self._load_data(doc)

        return res

    def load_dict(self, key: str) -> Optional[ModelDict]:
        """Load a single model dictionary, or ``None`` if it is not stored."""
        doc = self._collection.find_one({'_id': self._to_obj_id(key)})
        return self._load_data(doc)

    def msave_dict(self, key: List[str], data: List[ModelDict]):
        """Insert or replace several models with one unordered bulk write.

        ``key`` and ``data`` are paired positionally. Nothing is sent when
        there is nothing to save, because pymongo rejects an empty bulk write.
        """
        objs = [self._dump_data(k, d) for k, d in zip(key, data)]
        if not objs:
            return

        # Imported lazily so the package can be imported without pymongo.
        from pymongo import ReplaceOne
        replacements = [ReplaceOne({'_id': doc['_id']}, doc, upsert=True) for doc in objs]
        self._collection.bulk_write(replacements, ordered=False)

    def save_dict(self, key: str, data: ModelDict):
        """Insert or replace the document stored under ``key``."""
        obj = self._dump_data(key, data)

        self._collection.replace_one({'_id': obj['_id']}, obj, upsert=True)

    def delete(self, key: str):
        """Delete the document stored under ``key``, if any."""
        self._collection.delete_one({'_id': self._to_obj_id(key)})

    def get_all(self) -> List[str]:
        """Return the keys of all documents in the collection."""
        # Only ``_id`` is projected; document bodies are not needed here.
        cur = self._collection.find(projection=['_id'])
        return [self._to_key(doc['_id']) for doc in cur]

labml_db/index_driver/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Type, TYPE_CHECKING, List
1+
from typing import Type, TYPE_CHECKING, List, Optional
22

33
if TYPE_CHECKING:
44
from . import Index
@@ -11,10 +11,10 @@ def __init__(self, index_cls: Type['Index']):
1111
def delete(self, index_key: str):
1212
raise NotImplementedError
1313

14-
def get(self, index_key: str) -> str:
14+
def get(self, index_key: str) -> Optional[str]:
1515
raise NotImplementedError
1616

17-
def mget(self, index_key: List[str]) -> List[str]:
17+
def mget(self, index_key: List[str]) -> List[Optional[str]]:
1818
raise NotImplementedError
1919

2020
def set(self, index_key: str, model_key: str):

labml_db/index_driver/mongo.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from typing import Dict, Type, Optional, TYPE_CHECKING, List
2+
3+
from . import IndexDbDriver
4+
5+
if TYPE_CHECKING:
6+
import pymongo
7+
from .. import Index
8+
9+
10+
class MongoIndexDbDriver(IndexDbDriver):
    """Index driver that stores index entries in a MongoDB collection.

    Each entry is a document ``{'_id': index_key, 'value': model_key}`` in
    the collection ``_index_<index name>``.
    """

    # Declared for typing only; nothing in this class assigns or reads it.
    _cache: Optional[Dict[str, str]]

    def __init__(self, index_cls: Type['Index'], db: 'pymongo.mongo_client.database.Database'):
        """Bind the driver to the index collection for ``index_cls`` in ``db``."""
        super().__init__(index_cls)
        # ``index_name`` is presumably set by ``IndexDbDriver.__init__`` -- confirm.
        self._index = db[f'_index_{self.index_name}']

    def delete(self, index_key: str):
        """Remove the entry for ``index_key``, if any."""
        self._index.delete_one({'_id': index_key})

    def get(self, index_key: str) -> Optional[str]:
        """Return the model key stored for ``index_key``, or ``None`` if absent."""
        entry = self._index.find_one({'_id': index_key})
        if entry is None:
            return None
        return entry['value']

    def mget(self, index_key: List[str]) -> List[Optional[str]]:
        """Look up several index keys with one query.

        The result has one entry per requested key, in request order; missing
        keys yield ``None``. A key that appears more than once in
        ``index_key`` is filled in at every one of its positions.
        """
        res: List[Optional[str]] = [None] * len(index_key)
        if not index_key:
            return res

        # Each key maps to every position it was requested at, so duplicated
        # keys do not leave earlier positions empty.
        positions: Dict[str, List[int]] = {}
        for i, k in enumerate(index_key):
            positions.setdefault(k, []).append(i)

        for entry in self._index.find({'_id': {'$in': list(positions)}}):
            for i in positions[entry['_id']]:
                res[i] = entry['value']

        return res

    def set(self, index_key: str, model_key: str):
        """Insert or replace the entry mapping ``index_key`` to ``model_key``."""
        self._index.replace_one(
            {'_id': index_key},
            {'_id': index_key, 'value': model_key},
            upsert=True,
        )

    def get_all(self):
        """Return the ``value`` (model key) of every entry in the index."""
        return [entry['value'] for entry in self._index.find()]

setup.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,15 @@
55

66
setuptools.setup(
77
name='labml_db',
8-
version='0.0.14',
8+
version='0.0.15',
99
author="Varuna Jayasiri, Nipun Wijerathne",
1010
11-
description="Minimalistic ORM for JSON/YAML/Pickle file based DB",
11+
description="Minimalistic ORM for JSON/YAML/Pickle file based/redis/mongo DB",
1212
long_description=long_description,
1313
long_description_content_type="text/x-rst",
14-
url="https://github.com/lab-ml/db",
14+
url="https://github.com/labmlai/db",
1515
project_urls={
16-
'Documentation': 'https://lab-ml.com/'
16+
'Documentation': 'https://labml.ai'
1717
},
1818
packages=setuptools.find_packages(exclude=('test',
1919
'test.*')),

test/mongo/__init__.py

Whitespace-only changes.

test/mongo/api_check.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from labml import monit
2+
from labml.logger import inspect
3+
from pymongo import MongoClient
4+
5+
CONNECTION_STRING = "mongodb://localhost:27017/"
6+
7+
8+
def get_database():
    """Connect to the MongoDB server at ``CONNECTION_STRING`` and return the ``papers`` database.

    The client and the list of database names are printed with ``inspect``
    along the way.
    """
    mongo = MongoClient(CONNECTION_STRING)
    inspect(mongo)

    database_names = mongo.list_database_names()
    inspect(database_names)

    return mongo['papers']
18+
19+
20+
def test():
    """Insert one sample paper into the ``papers`` collection and print every stored document."""
    database = get_database()
    inspect(database)
    papers = database['papers']
    inspect(papers)

    with monit.section('Insert'):
        sample = {
            "_id": "abcd",
            "title": "Attention is all you need",
            "abstract": "abstract"
        }
        papers.insert_many([sample])

    with monit.section('Find'):
        for doc in papers.find():
            inspect(type(doc))
            inspect(doc)
42+
# Script entry point: run the manual pymongo API check.
if __name__ == "__main__":
    test()

test/mongo/simple.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
from pymongo import MongoClient
2+
3+
from labml_db import Model, Index
4+
from labml_db.driver.mongo import MongoDbDriver
5+
from labml_db.index_driver.mongo import MongoIndexDbDriver
6+
from test.simple import User, Project, UsernameIndex
7+
8+
9+
def test_setup():
    """Register Mongo-backed drivers for the test models and the username index.

    Uses the ``testdb`` database on the MongoDB server at localhost:27017.
    """
    client = MongoClient("mongodb://localhost:27017/")
    database = client['testdb']

    model_drivers = [MongoDbDriver(model_cls, database) for model_cls in (User, Project)]
    Model.set_db_drivers(model_drivers)

    index_drivers = [MongoIndexDbDriver(UsernameIndex, database)]
    Index.set_db_drivers(index_drivers)
18+
19+
20+
def test():
    """Create a project and a user that references it, save both, then load the project back via the user."""
    project = Project(name='nlp')
    john = User(name='John')
    john.projects.append(project.key)
    john.occupation = 'test'
    other = User(name='X')
    # Print both project lists side by side: only John's had a project appended.
    print(john.projects, other.projects)
    john.save()
    project.save()

    first_project_key = john.projects[0]
    loaded_project = first_project_key.load()
    print(loaded_project.name)
31+
32+
33+
def test_load():
    """Load and print every stored user, then the name of every stored project."""
    users = [user_key.load() for user_key in User.get_all()]
    print(users)

    project_names = [project_key.load().name for project_key in Project.get_all()]
    print(project_names)
38+
39+
40+
def test_index():
    """Replace the user indexed under 'John' with a freshly saved one and print it."""
    existing_key = UsernameIndex.get('John')
    if existing_key:
        # A previous run left a user behind: show it, then delete it.
        print('index', existing_key.load())
        existing_key.delete()

    john = User(name='John')
    john.save()
    UsernameIndex.set(john.name, john.key)

    print(john.key, john.name, john.projects)
51+
52+
53+
if __name__ == '__main__':
    # Drivers must be registered first; the remaining steps run in order.
    for step in (test_setup, test, test_load, test_index):
        step()

0 commit comments

Comments
 (0)