Skip to content

Commit 4b9d73a

Browse files
Add mongo storage
1 parent 24419f6 commit 4b9d73a

File tree

3 files changed

+69
-6
lines changed

3 files changed

+69
-6
lines changed

main.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,18 @@
1515
from src.models import OpenAIModel
1616
from src.memory import Memory
1717
from src.logger import logger
18-
from src.storage import Storage
18+
from src.storage import Storage, FileStorage, MongoStorage
1919
from src.utils import get_role_and_content
2020
from src.service.youtube import Youtube, YoutubeTranscriptReader
2121
from src.service.website import Website, WebsiteReader
22+
from src.mongodb import mongodb
2223

2324
load_dotenv('.env')
2425

2526
app = Flask(__name__)
2627
line_bot_api = LineBotApi(os.getenv('LINE_CHANNEL_ACCESS_TOKEN'))
2728
handler = WebhookHandler(os.getenv('LINE_CHANNEL_SECRET'))
28-
storage = Storage('db.json')
29+
storage = None
2930
youtube = Youtube(step=4)
3031
website = Website()
3132

@@ -62,8 +63,9 @@ def handle_text_message(event):
6263
if not is_successful:
6364
raise ValueError('Invalid API token')
6465
model_management[user_id] = model
65-
api_keys[user_id] = api_key
66-
storage.save(api_keys)
66+
storage.save({
67+
user_id: api_key
68+
})
6769
msg = TextSendMessage(text='Token 有效,註冊成功')
6870

6971
elif text.startswith('/指令說明'):
@@ -180,6 +182,11 @@ def home():
180182

181183

182184
if __name__ == "__main__":
185+
if os.getenv('USE_MONGO'):
186+
mongodb.connect_to_database()
187+
storage = Storage(MongoStorage(mongodb.db))
188+
else:
189+
storage = Storage(FileStorage('db.json'))
183190
try:
184191
data = storage.load()
185192
for user_id in data.keys():

src/mongodb.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import os
2+
3+
from pymongo import MongoClient
4+
5+
6+
class MongoDB():
7+
"""
8+
Environment Variables:
9+
MONGODB__PATH
10+
MONGODB__DBNAME
11+
"""
12+
client: None
13+
db: None
14+
15+
def connect_to_database(self, mongo_path=None, db_name=None):
16+
mongo_path = mongo_path or os.getenv('MONGODB__PATH')
17+
db_name = db_name or os.getenv('MONGODB__DBNAME')
18+
self.client = MongoClient(mongo_path)
19+
assert self.client.config.command('ping')['ok'] == 1.0
20+
self.db = self.client[db_name]
21+
22+
23+
mongodb = MongoDB()

src/storage.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,48 @@
11
import json
22

33

4-
class Storage():
4+
class FileStorage:
55
def __init__(self, file_name):
66
self.fine_name = file_name
77

88
def save(self, data):
9-
with open(self.fine_name, 'w', newline='') as f:
9+
with open(self.fine_name, 'a+', newline='') as f:
1010
json.dump(data, f)
1111

1212
def load(self):
1313
with open(self.fine_name, newline='') as jsonfile:
1414
data = json.load(jsonfile)
1515
return data
16+
17+
18+
class MongoStorage:
19+
def __init__(self, db):
20+
self.db = db
21+
22+
def save(self, data):
23+
self.db['api_key'].update_one({
24+
'user_id': data.get('user_id')
25+
}, {
26+
'$set': {
27+
'user_id': data.get('user_id'),
28+
'api_key': data.get('api_key'),
29+
}
30+
}, upsert=True)
31+
32+
def load(self):
33+
data = list(self.db['api_key'].find())
34+
res = {}
35+
for i in range(len(data)):
36+
res[data[i]['user_id']] = data[i]['api_key']
37+
return res
38+
39+
40+
class Storage:
41+
def __init__(self, storage):
42+
self.storage = storage
43+
44+
def save(self, data):
45+
self.storage.save(data)
46+
47+
def load(self):
48+
return self.storage.load()

0 commit comments

Comments
 (0)