Skip to content

Commit af1d980

Browse files
authored
Merge pull request #1836 from oracle-devrel/lsa-custom-rag4
added mcp server for semantic search
2 parents 6370e5f + 34545e3 commit af1d980

File tree

2 files changed

+132
-3
lines changed

2 files changed

+132
-3
lines changed

ai/gen-ai-agents/custom_rag_agent/config.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,22 +31,24 @@
3131

3232
# embeddings
3333
EMBED_MODEL_ID = "cohere.embed-multilingual-v3.0"
34+
# EMBED_MODEL_ID = "cohere.embed-multilingual-image-v3.0"
3435

3536
# LLM
3637
# this is the default model
3738
LLM_MODEL_ID = "meta.llama-3.3-70b-instruct"
3839
TEMPERATURE = 0.1
39-
MAX_TOKENS = 1024
40+
MAX_TOKENS = 2048
4041

4142
# for the UI
4243
LANGUAGE_LIST = ["same as the question", "en", "fr", "it", "es"]
43-
MODEL_LIST = ["meta.llama-3.3-70b-instruct", "cohere.command-r-plus-08-2024"]
44+
# replaced command-r with command-a
45+
MODEL_LIST = ["meta.llama-3.3-70b-instruct", "cohere.command-a-03-2025"]
4446

4547
ENABLE_USER_FEEDBACK = True
4648

4749
# semantic search
4850
TOP_K = 6
49-
COLLECTION_LIST = ["BOOKS", "CNAF"]
51+
COLLECTION_LIST = ["DEV_COACHING", "BOOKS", "CNAF"]
5052

5153
# OCI general
5254
COMPARTMENT_ID = "ocid1.compartment.oc1..aaaaaaaaushuwb2evpuf7rcpl4r7ugmqoe7ekmaiik3ra3m7gec3d234eknq"
@@ -69,3 +71,8 @@
6971
# for loading
7072
CHUNK_SIZE = 2000
7173
CHUNK_OVERLAP = 100
74+
75+
# for MCP server
76+
TRANSPORT = "streamable-http"
77+
HOST = "0.0.0.0"
78+
PORT = 9000
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
"""
2+
Semantic Search exposed as an MCP tool
3+
4+
Author: L. Saetta
5+
License: MIT
6+
"""
7+
8+
from typing import Annotated
9+
from pydantic import Field
10+
import oracledb
11+
from fastmcp import FastMCP
12+
from langchain_community.vectorstores.utils import DistanceStrategy
13+
from langchain_community.embeddings import OCIGenAIEmbeddings
14+
from langchain_community.vectorstores.oraclevs import OracleVS
15+
from utils import get_console_logger
16+
17+
from config import DEBUG
18+
from config import AUTH, EMBED_MODEL_ID, SERVICE_ENDPOINT, COMPARTMENT_ID
19+
from config import TRANSPORT, HOST, PORT
20+
from config_private import CONNECT_ARGS
21+
22+
logger = get_console_logger()
23+
24+
mcp = FastMCP("Demo Semantic Search as MCP server")
25+
26+
27+
#
28+
# Helper functions
29+
#
30+
def get_connection():
31+
"""
32+
get a connection to the DB
33+
"""
34+
return oracledb.connect(**CONNECT_ARGS)
35+
36+
37+
def get_embedding_model():
38+
"""
39+
Create the Embedding Model
40+
"""
41+
embed_model = OCIGenAIEmbeddings(
42+
auth_type=AUTH,
43+
model_id=EMBED_MODEL_ID,
44+
service_endpoint=SERVICE_ENDPOINT,
45+
compartment_id=COMPARTMENT_ID,
46+
)
47+
return embed_model
48+
49+
50+
@mcp.tool
51+
def semantic_search(
52+
query: Annotated[
53+
str, Field(description="The search query to find relevant documents.")
54+
],
55+
top_k: Annotated[int, Field(description="TOP_K parameter for search")] = 5,
56+
collection_name: Annotated[
57+
str, Field(description="The name of DB table")
58+
] = "BOOKS",
59+
) -> dict:
60+
"""
61+
Perform a semantic search based on the provided query.
62+
Args:
63+
query (str): The search query.
64+
top_k (int): The number of top results to return.
65+
Returns:
66+
dict: a dictionary containing the relevant documents.
67+
"""
68+
try:
69+
# must be the same embedding model used during load in the Vector Store
70+
embed_model = get_embedding_model()
71+
72+
# get a connection to the DB and init VS
73+
with get_connection() as conn:
74+
v_store = OracleVS(
75+
client=conn,
76+
table_name=collection_name,
77+
distance_strategy=DistanceStrategy.COSINE,
78+
embedding_function=embed_model,
79+
)
80+
81+
relevant_docs = v_store.similarity_search(query=query, k=top_k)
82+
83+
if DEBUG:
84+
logger.info("Result from similarity search:")
85+
logger.info(relevant_docs)
86+
87+
except Exception as e:
88+
logger.error("Error in vector_store.invoke: %s", e)
89+
error = str(e)
90+
return {"error": error}
91+
92+
result = {"relevant_docs": relevant_docs}
93+
94+
return result
95+
96+
@mcp.tool
97+
def get_collections() -> list:
98+
"""
99+
Get the list of collections (DB tables) available in the Oracle Vector Store.
100+
Returns:
101+
list: A list of collection names.
102+
"""
103+
with get_connection() as conn:
104+
cursor = conn.cursor()
105+
106+
cursor.execute(
107+
"""SELECT DISTINCT utc.table_name
108+
FROM user_tab_columns utc
109+
WHERE utc.data_type = 'VECTOR'
110+
ORDER BY 1 ASC"""
111+
)
112+
collections = [row[0] for row in cursor.fetchall()]
113+
return collections
114+
115+
if __name__ == "__main__":
116+
mcp.run(
117+
transport=TRANSPORT,
118+
# Bind to all interfaces
119+
host=HOST,
120+
port=PORT,
121+
log_level="INFO",
122+
)

0 commit comments

Comments
 (0)