From 3247d408ee6bfeeae531e0221a8b61238babfd67 Mon Sep 17 00:00:00 2001 From: wayyoung <1017761807@qq.com> Date: Sun, 30 Mar 2025 20:40:38 +0800 Subject: [PATCH 1/6] support oceanbase --- main/xiaozhi-server/config.yaml | 6 ++ .../providers/memory/oceanbase/oceanbase.py | 75 +++++++++++++++++++ main/xiaozhi-server/requirements.txt | 2 +- 3 files changed, 82 insertions(+), 1 deletion(-) create mode 100644 main/xiaozhi-server/core/providers/memory/oceanbase/oceanbase.py diff --git a/main/xiaozhi-server/config.yaml b/main/xiaozhi-server/config.yaml index e2c35e670..7e4b4b455 100644 --- a/main/xiaozhi-server/config.yaml +++ b/main/xiaozhi-server/config.yaml @@ -161,6 +161,12 @@ Memory: mem_local_short: # 本地记忆功能,通过selected_module的llm总结,数据保存在本地,不会上传到服务器 type: mem_local_short + oceanbase: + uri: 127.0.0.1:2881 + user: root@sys + password: root + database: xiaozhi + table: xiaozhi ASR: FunASR: diff --git a/main/xiaozhi-server/core/providers/memory/oceanbase/oceanbase.py b/main/xiaozhi-server/core/providers/memory/oceanbase/oceanbase.py new file mode 100644 index 000000000..892c46ff6 --- /dev/null +++ b/main/xiaozhi-server/core/providers/memory/oceanbase/oceanbase.py @@ -0,0 +1,75 @@ +import traceback +from pyobvector import MilvusLikeClient +from ..base import MemoryProviderBase, logger + + +TAG = __name__ + +class MemoryProvider(MemoryProviderBase): + def __init__(self, config): + super().__init__(config) + self.uri = config.get("uri", "") + self.user = config.get("user", "root@xiaozhi") + self.password = config.get("password", "root") + self.db_name = config.get("database", "xiaozhi") + self.table_name = config.get("table_name", "xiaozhi") + logger.bind(tag=TAG).info(f"连接到 oceanbase 服务: {self.uri}") + self.client = self.connect_to_client() + + def connect_to_client(self): + try: + return MilvusLikeClient(uri=self.uri, user=self.user,password=self.password, db_name=self.db_name) + except Exception as e: + logger.bind(tag=TAG).error(f"连接到 oceanbase 服务时发生错误: {str(e)}") + logger.bind(tag=TAG).error(f"详细错误: {traceback.format_exc()}") + return None + + def init_memory(self, role_id, llm): + super().init_memory(role_id, llm) + + async def save_memory(self, msgs): + if not self.client or len(msgs) < 2: + return None + + try: + messages = [{"role": message.role, "content": message.content} for message in msgs if message.role != "system"] + for i in range(0, len(messages), 1): + self.client.insert(collection_name=self.table_name, data=messages[i:i+1]) + logger.bind(tag=TAG).debug("Save memory") + except Exception as e: + logger.bind(tag=TAG).error(f"保存记忆失败: {str(e)}") + return None + + async def query_memory(self, query: str) -> str: + if not self.client: + return "" + + try: + results = self.client.search( + collection_name=self.table_name, + data=[query], + anns_field="embedding", + limit=5, + output_fields=["id", "metadata"], + ) + memories = self.format_memories(results) + return "\n".join(f"- {memory[1]}" for memory in memories) + except Exception as e: + logger.bind(tag=TAG).error(f"查询记忆失败: {str(e)}") + return "" + + def format_memories(self, results): + memories = [] + for entry in results['results']: + timestamp = entry.get('updated_at', '') + if timestamp: + try: + dt = timestamp.split('.')[0] + formatted_time = dt.replace('T', ' ') + except: + formatted_time = timestamp + memory = entry.get('memory', '') + if timestamp and memory: + memories.append((timestamp, f"[{formatted_time}] {memory}")) + memories.sort(key=lambda x: x[0], reverse=True) + return memories \ No newline at end of file diff --git a/main/xiaozhi-server/requirements.txt b/main/xiaozhi-server/requirements.txt index 164d594d5..d0c29509c 100755 --- a/main/xiaozhi-server/requirements.txt +++ b/main/xiaozhi-server/requirements.txt @@ -23,5 +23,5 @@ bs4==0.0.2 modelscope==1.23.2 sherpa_onnx==1.11.0 mcp==1.4.1 - +pyobvector==0.2.4 cnlunar==0.2.0 From 10872b0e69df878aab1374c8c07097ccff8fc833 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=A0=E7=A3=8A?= Date: Mon, 7 Apr 2025 19:06:19 +0800 Subject: [PATCH 2/6] support memory oceanbase --- .../providers/memory/oceanbase/oceanbase.py | 55 ++++++++++++++----- 1 file changed, 42 insertions(+), 13 deletions(-) diff --git a/main/xiaozhi-server/core/providers/memory/oceanbase/oceanbase.py b/main/xiaozhi-server/core/providers/memory/oceanbase/oceanbase.py index 892c46ff6..0a87da061 100644 --- a/main/xiaozhi-server/core/providers/memory/oceanbase/oceanbase.py +++ b/main/xiaozhi-server/core/providers/memory/oceanbase/oceanbase.py @@ -1,61 +1,90 @@ +import os.path import traceback from pyobvector import MilvusLikeClient -from ..base import MemoryProviderBase, logger +from sentence_transformers import SentenceTransformer +from core.providers.memory.base import MemoryProviderBase TAG = __name__ +create_table_sql = """ +CREATE TABLE xiaozhi( + id INT AUTO_INCREMENT PRIMARY KEY, + role VARCHAR(200), + content VARCHAR(200), + embedding VECTOR(384), + VECTOR INDEX idx1(embedding) WITH (distance=L2, type=hnsw) + ); +""" + + class MemoryProvider(MemoryProviderBase): def __init__(self, config): super().__init__(config) - self.uri = config.get("uri", "") - self.user = config.get("user", "root@xiaozhi") - self.password = config.get("password", "root") + self.uri = config.get("uri", "localhost") + self.user = config.get("user", "root@test") + self.password = config.get("password", "") self.db_name = config.get("database", "xiaozhi") self.table_name = config.get("table_name", "xiaozhi") - logger.bind(tag=TAG).info(f"连接到 oceanbase 服务: {self.uri}") + self.model_path = config.get("model_path", "model/all-MiniLM-L6-v2") + if not os.path.exists(self.model_path): + raise Exception(f"模型路径不存在,请下载到: {self.model_path}") + print(f"连接到 oceanbase 服务: {self.uri}") self.client = self.connect_to_client() def connect_to_client(self): try: return MilvusLikeClient(uri=self.uri, user=self.user,password=self.password, db_name=self.db_name) except Exception as e: - logger.bind(tag=TAG).error(f"连接到 oceanbase 服务时发生错误: {str(e)}") - logger.bind(tag=TAG).error(f"详细错误: {traceback.format_exc()}") + print(f"连接到 oceanbase 服务时发生错误: {str(e)}") + print(f"请检查配置并确认表是否存在: 初始化sql: {create_table_sql}") + print(f"详细错误: {traceback.format_exc()}") return None + def _string_to_embeddings(self, sentences): + # 加载预训练的 'all-MiniLM-L6-v2' 模型 + model = SentenceTransformer(self.model_path ) + embeddings = model.encode(sentences) + return embeddings + def init_memory(self, role_id, llm): super().init_memory(role_id, llm) + pass async def save_memory(self, msgs): if not self.client or len(msgs) < 2: return None try: - messages = [{"role": message.role, "content": message.content} for message in msgs if message.role != "system"] + messages = [{"role": message.role, "content": message.content, + "embedding": self._string_to_embeddings(message.content)} for message in msgs if message.role != "system"] for i in range(0, len(messages), 1): self.client.insert(collection_name=self.table_name, data=messages[i:i+1]) - logger.bind(tag=TAG).debug("Save memory") + print(f"Save memory") except Exception as e: - logger.bind(tag=TAG).error(f"保存记忆失败: {str(e)}") + print(f"保存记忆失败: {str(e)}") return None async def query_memory(self, query: str) -> str: if not self.client: return "" + # 把 query 向量化 + query = self._string_to_embeddings(query) + try: results = self.client.search( collection_name=self.table_name, - data=[query], + data=query, anns_field="embedding", limit=5, - output_fields=["id", "metadata"], + output_fields=["role", "content"] ) + return results memories = self.format_memories(results) return "\n".join(f"- {memory[1]}" for memory in memories) except Exception as e: - logger.bind(tag=TAG).error(f"查询记忆失败: {str(e)}") + print(f"查询记忆失败: {str(e)}") return "" def format_memories(self, results): From d9337657e15c0da1130eac9a10ffab012dbd5fef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=A0=E7=A3=8A?= Date: Tue, 8 Apr 2025 20:22:40 +0800 Subject: [PATCH 3/6] support memory oceanbase --- main/xiaozhi-server/config.yaml | 7 ++++--- .../core/providers/memory/oceanbase/oceanbase.py | 16 ++++++++++++---- main/xiaozhi-server/requirements.txt | 2 ++ 3 files changed, 18 insertions(+), 7 deletions(-) diff --git a/main/xiaozhi-server/config.yaml b/main/xiaozhi-server/config.yaml index ddc5ff05c..4a92e44b1 100644 --- a/main/xiaozhi-server/config.yaml +++ b/main/xiaozhi-server/config.yaml @@ -196,11 +196,12 @@ Memory: # 本地记忆功能,通过selected_module的llm总结,数据保存在本地,不会上传到服务器 type: mem_local_short oceanbase: - uri: 127.0.0.1:2881 - user: root@sys - password: root + uri: "*****:**" + user: "***" + password: "***" database: xiaozhi table: xiaozhi + model_path: "models/all-MiniLM-L6-v2" ASR: FunASR: diff --git a/main/xiaozhi-server/core/providers/memory/oceanbase/oceanbase.py b/main/xiaozhi-server/core/providers/memory/oceanbase/oceanbase.py index 0a87da061..ee56bbfd9 100644 --- a/main/xiaozhi-server/core/providers/memory/oceanbase/oceanbase.py +++ b/main/xiaozhi-server/core/providers/memory/oceanbase/oceanbase.py @@ -11,7 +11,7 @@ CREATE TABLE xiaozhi( id INT AUTO_INCREMENT PRIMARY KEY, role VARCHAR(200), - content VARCHAR(200), + content text, embedding VECTOR(384), VECTOR INDEX idx1(embedding) WITH (distance=L2, type=hnsw) ); @@ -26,7 +26,7 @@ def __init__(self, config): self.password = config.get("password", "") self.db_name = config.get("database", "xiaozhi") self.table_name = config.get("table_name", "xiaozhi") - self.model_path = config.get("model_path", "model/all-MiniLM-L6-v2") + self.model_path = config.get("model_path", "models/all-MiniLM-L6-v2") if not os.path.exists(self.model_path): raise Exception(f"模型路径不存在,请下载到: {self.model_path}") print(f"连接到 oceanbase 服务: {self.uri}") @@ -56,12 +56,20 @@ async def save_memory(self, msgs): return None try: - messages = [{"role": message.role, "content": message.content, - "embedding": self._string_to_embeddings(message.content)} for message in msgs if message.role != "system"] + + messages =[] + for message in msgs: + if message.role != "system": + if message.content: + messages.append({"role": message.role, "content": message.content, + "embedding": self._string_to_embeddings(message.content)}) + for i in range(0, len(messages), 1): self.client.insert(collection_name=self.table_name, data=messages[i:i+1]) print(f"Save memory") except Exception as e: + import traceback + print(traceback.format_exc()) print(f"保存记忆失败: {str(e)}") return None diff --git a/main/xiaozhi-server/requirements.txt b/main/xiaozhi-server/requirements.txt index 81f16caee..f85d6c842 100755 --- a/main/xiaozhi-server/requirements.txt +++ b/main/xiaozhi-server/requirements.txt @@ -26,3 +26,5 @@ mcp==1.4.1 pyobvector==0.2.4 cnlunar==0.2.0 PySocks==1.7.1 +sentence-transformers==4.0.1 +transformers==4.50.3 From 5c2242d553ff7a6ae3dc2170664b478066d1a3b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=A0=E7=A3=8A?= Date: Tue, 8 Apr 2025 20:23:22 +0800 Subject: [PATCH 4/6] support memory oceanbase --- main/xiaozhi-server/core/providers/memory/oceanbase/oceanbase.py | 1 + 1 file changed, 1 insertion(+) diff --git a/main/xiaozhi-server/core/providers/memory/oceanbase/oceanbase.py b/main/xiaozhi-server/core/providers/memory/oceanbase/oceanbase.py index ee56bbfd9..95c6294ba 100644 --- a/main/xiaozhi-server/core/providers/memory/oceanbase/oceanbase.py +++ b/main/xiaozhi-server/core/providers/memory/oceanbase/oceanbase.py @@ -27,6 +27,7 @@ def __init__(self, config): self.db_name = config.get("database", "xiaozhi") self.table_name = config.get("table_name", "xiaozhi") self.model_path = config.get("model_path", "models/all-MiniLM-L6-v2") + self.model_path = os.path.abspath(self.model_path) if not os.path.exists(self.model_path): raise Exception(f"模型路径不存在,请下载到: {self.model_path}") print(f"连接到 oceanbase 服务: {self.uri}") From f63730b30f84e52c786b69104ea48067ac580699 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=A0=E7=A3=8A?= Date: Thu, 10 Apr 2025 15:25:15 +0800 Subject: [PATCH 5/6] add oceanbase --- docker-setup.sh | 11 +++++++++++ docs/Deployment.md | 11 +++++++++++ .../core/providers/memory/oceanbase/oceanbase.py | 2 +- 3 files changed, 23 insertions(+), 1 deletion(-) diff --git a/docker-setup.sh b/docker-setup.sh index 04ea547a3..2067ef39d 100755 --- a/docker-setup.sh +++ b/docker-setup.sh @@ -74,6 +74,17 @@ else $DOWNLOAD_CMD "data/.config.yaml" "https://raw.githubusercontent.com/xinnan-tech/xiaozhi-esp32-server/main/main/xiaozhi-server/config.yaml" fi +# 下载量化模型 +echo "下载量化模型..." +mkdir -p models/all-MiniLM-L6-v2 +if [ "$DOWNLOAD_CMD" = "powershell -Command Invoke-WebRequest -Uri" ]; then + $DOWNLOAD_CMD "https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/v0.2/all-MiniLM-L6-v2.zip" $DOWNLOAD_CMD_SUFFIX "models/all-MiniLM-L6-v2/all-MiniLM-L6-v2.zip" + +else + $DOWNLOAD_CMD "models/all-MiniLM-L6-v2.zip" "https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/v0.2/all-MiniLM-L6-v2.zip" +fi +unzip models/all-MiniLM-L6-v2.zip -d models/all-MiniLM-L6-v2 + # 检查文件是否存在 echo "检查文件完整性..." FILES_TO_CHECK="docker-compose.yml data/.config.yaml models/SenseVoiceSmall/model.pt" diff --git a/docs/Deployment.md b/docs/Deployment.md index dfac9dcde..14241f678 100644 --- a/docs/Deployment.md +++ b/docs/Deployment.md @@ -306,6 +306,17 @@ LLM: - 线路二:百度网盘下载[SenseVoiceSmall](https://pan.baidu.com/share/init?surl=QlgM58FHhYv1tFnUT_A8Sg&pwd=qvna) 提取码: `qvna` +如果需要使用 oceanbase 的长期记忆向量化能力需要额外下载量化模型,用于把历史对话向量化到成向量化数据。目前支持 `all-MiniLM-L6-v2` 模型。因为模型较大,需要独立下载, + +``` +wget https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/v0.2/all-MiniLM-L6-v2.zip +mkdir models/all-MiniLM-L6-v2 +``` +解压后放在`models/all-MiniLM-L6-v2`目录下。 + + +``` + ## 运行状态确认 如果你能看到,类似以下日志,则是本项目服务启动成功的标志。 diff --git a/main/xiaozhi-server/core/providers/memory/oceanbase/oceanbase.py b/main/xiaozhi-server/core/providers/memory/oceanbase/oceanbase.py index 95c6294ba..c5f49cd07 100644 --- a/main/xiaozhi-server/core/providers/memory/oceanbase/oceanbase.py +++ b/main/xiaozhi-server/core/providers/memory/oceanbase/oceanbase.py @@ -29,7 +29,7 @@ def __init__(self, config): self.model_path = config.get("model_path", "models/all-MiniLM-L6-v2") self.model_path = os.path.abspath(self.model_path) if not os.path.exists(self.model_path): - raise Exception(f"模型路径不存在,请下载到: {self.model_path}") + raise Exception(f"模型路径不存在,请下载量化模型 https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/v0.2/all-MiniLM-L6-v2.zip 并解压到: {self.model_path}") print(f"连接到 oceanbase 服务: {self.uri}") self.client = self.connect_to_client() From 5f41059559bdc683ef07e3a214ab1043315b0355 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=A0=E7=A3=8A?= Date: Thu, 10 Apr 2025 15:32:57 +0800 Subject: [PATCH 6/6] add oceanbase --- .../providers/memory/oceanbase/oceanbase.py | 37 ++++--------------- 1 file changed, 8 insertions(+), 29 deletions(-) diff --git a/main/xiaozhi-server/core/providers/memory/oceanbase/oceanbase.py b/main/xiaozhi-server/core/providers/memory/oceanbase/oceanbase.py index c5f49cd07..7e82a5dc3 100644 --- a/main/xiaozhi-server/core/providers/memory/oceanbase/oceanbase.py +++ b/main/xiaozhi-server/core/providers/memory/oceanbase/oceanbase.py @@ -2,8 +2,7 @@ import traceback from pyobvector import MilvusLikeClient from sentence_transformers import SentenceTransformer - -from core.providers.memory.base import MemoryProviderBase +from core.providers.memory.base import MemoryProviderBase ,logger TAG = __name__ @@ -30,16 +29,16 @@ def __init__(self, config): self.model_path = os.path.abspath(self.model_path) if not os.path.exists(self.model_path): raise Exception(f"模型路径不存在,请下载量化模型 https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/v0.2/all-MiniLM-L6-v2.zip 并解压到: {self.model_path}") - print(f"连接到 oceanbase 服务: {self.uri}") + logger.bind(tag=TAG).info(f"连接到 oceanbase 服务: {self.uri}") self.client = self.connect_to_client() def connect_to_client(self): try: return MilvusLikeClient(uri=self.uri, user=self.user,password=self.password, db_name=self.db_name) except Exception as e: - print(f"连接到 oceanbase 服务时发生错误: {str(e)}") - print(f"请检查配置并确认表是否存在: 初始化sql: {create_table_sql}") - print(f"详细错误: {traceback.format_exc()}") + logger.bind(tag=TAG).error(f"连接到 oceanbase 服务时发生错误: {str(e)}") + logger.bind(tag=TAG).error(f"请检查配置并确认表是否存在: 初始化sql: {create_table_sql}") + logger.bind(tag=TAG).error(f"详细错误: {traceback.format_exc()}") return None def _string_to_embeddings(self, sentences): @@ -67,11 +66,9 @@ async def save_memory(self, msgs): for i in range(0, len(messages), 1): self.client.insert(collection_name=self.table_name, data=messages[i:i+1]) - print(f"Save memory") + logger.bind(tag=TAG).info(f"Save memory") except Exception as e: - import traceback - print(traceback.format_exc()) - print(f"保存记忆失败: {str(e)}") + logger.bind(tag=TAG).error(f"保存记忆失败: {traceback.format_exc()}") return None async def query_memory(self, query: str) -> str: @@ -90,24 +87,6 @@ async def query_memory(self, query: str) -> str: output_fields=["role", "content"] ) return results - memories = self.format_memories(results) - return "\n".join(f"- {memory[1]}" for memory in memories) except Exception as e: - print(f"查询记忆失败: {str(e)}") + logger.bind(tag=TAG).error(f"查询记忆失败: {traceback.format_exc()}") return "" - - def format_memories(self, results): - memories = [] - for entry in results['results']: - timestamp = entry.get('updated_at', '') - if timestamp: - try: - dt = timestamp.split('.')[0] - formatted_time = dt.replace('T', ' ') - except: - formatted_time = timestamp - memory = entry.get('memory', '') - if timestamp and memory: - memories.append((timestamp, f"[{formatted_time}] {memory}")) - memories.sort(key=lambda x: x[0], reverse=True) - return memories \ No newline at end of file