Skip to content

Commit 003b17e

Browse files
fix(mcp): CORS middleware, renaming, error handling (#9423)
**Description** This PR: * Adds a CORS middleware handler to the MCP SSE server * Adds sanity checking for malformed MCP arguments (prevent panics) * Renames tools, resources and prompts to use the snake_case convention of MCP [discussion](https://hypermode.slack.com/archives/C08N5MCE9U4/p1747936419495399) * Uses the correct error handling for returning errors to clients * Adds a comprehensive integration test for the MCP SSE implementation Note, there's a MCP demo client [here](https://github.com/hypermodeinc/dgraph-experimental/tree/main/mcp-client-apps/simple-mcp-demo) that you can spin up to test the SSE implementation. **Checklist** - [x] Code compiles correctly and linting passes locally - [x] Tests added for new functionality, or regression tests for bug fixes added as applicable
1 parent 684cbb2 commit 003b17e

File tree

5 files changed

+330
-37
lines changed

5 files changed

+330
-37
lines changed

dgraph/cmd/alpha/run.go

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -484,10 +484,20 @@ func setupMcp(baseMux *http.ServeMux, connectionString, url string, readOnly boo
484484
}
485485

486486
sse := server.NewSSEServer(s,
487-
server.WithBasePath(url),
487+
server.WithStaticBasePath(url),
488488
)
489-
baseMux.HandleFunc(url, sse.ServeHTTP)
490-
baseMux.HandleFunc(url+"/", sse.ServeHTTP)
489+
490+
corsHandler := func(w http.ResponseWriter, r *http.Request) {
491+
x.AddCorsHeaders(w)
492+
if r.Method == http.MethodOptions {
493+
w.WriteHeader(http.StatusOK)
494+
return
495+
}
496+
sse.ServeHTTP(w, r)
497+
}
498+
499+
baseMux.HandleFunc(url, corsHandler)
500+
baseMux.HandleFunc(url+"/", corsHandler)
491501
return nil
492502
}
493503

dgraph/cmd/mcp/mcp_server.go

Lines changed: 60 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ func NewMCPServer(connectionString string, readOnly bool) (*server.MCPServer, er
6060
server.WithRecovery(),
6161
)
6262

63-
schemaTool := mcp.NewTool("Get-Schema",
63+
schemaTool := mcp.NewTool("get_schema",
6464
mcp.WithDescription("Get Dgraph DQL Schema from dgraph db"),
6565
mcp.WithToolAnnotation(mcp.ToolAnnotation{
6666
ReadOnlyHint: &True,
@@ -70,8 +70,8 @@ func NewMCPServer(connectionString string, readOnly bool) (*server.MCPServer, er
7070
}),
7171
)
7272

73-
queryTool := mcp.NewTool("Run-Query",
74-
mcp.WithDescription("Run Dgraph Query on dgraph db"),
73+
queryTool := mcp.NewTool("run_query",
74+
mcp.WithDescription("Run Dgraph DQL Query on dgraph db"),
7575
mcp.WithString("query",
7676
mcp.Required(),
7777
mcp.Description("The query to perform"),
@@ -85,7 +85,7 @@ func NewMCPServer(connectionString string, readOnly bool) (*server.MCPServer, er
8585
)
8686

8787
if !readOnly {
88-
alterSchemaTool := mcp.NewTool("Alter-Schema",
88+
alterSchemaTool := mcp.NewTool("alter_schema",
8989
mcp.WithDescription("Alter Dgraph DQL Schema in dgraph db"),
9090
mcp.WithString("schema",
9191
mcp.Required(),
@@ -100,24 +100,34 @@ func NewMCPServer(connectionString string, readOnly bool) (*server.MCPServer, er
100100
)
101101

102102
s.AddTool(alterSchemaTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
103-
schema, ok := request.GetArguments()["schema"].(string)
103+
args := request.GetArguments()
104+
if args == nil {
105+
return mcp.NewToolResultError("Schema must be present"), nil
106+
}
107+
108+
schemaArg, ok := args["schema"]
109+
if !ok || schemaArg == nil {
110+
return mcp.NewToolResultError("Schema must be present"), nil
111+
}
112+
113+
schema, ok := schemaArg.(string)
104114
if !ok {
105-
return nil, fmt.Errorf("schema must be present")
115+
return mcp.NewToolResultError("Schema must be a string"), nil
106116
}
107117

108118
// Execute alter operation
109119
conn, err := getConn(connectionString)
110120
if err != nil {
111-
return nil, fmt.Errorf("error opening connection with Dgraph Alpha: %v", err)
121+
return mcp.NewToolResultErrorFromErr("Error opening connection with Dgraph Alpha", err), nil
112122
}
113123
if err = conn.SetSchema(ctx, dgo.RootNamespace, schema); err != nil {
114-
return nil, fmt.Errorf("schema alteration failed: %v", err)
124+
return mcp.NewToolResultErrorFromErr("Schema alteration failed", err), nil
115125
}
116126

117127
return mcp.NewToolResultText("Schema updated successfully"), nil
118128
})
119129

120-
mutationTool := mcp.NewTool("Run-Mutation",
130+
mutationTool := mcp.NewTool("run_mutation",
121131
mcp.WithDescription("Run DQL Mutation on dgraph db"),
122132
mcp.WithString("mutation",
123133
mcp.Required(),
@@ -134,7 +144,7 @@ func NewMCPServer(connectionString string, readOnly bool) (*server.MCPServer, er
134144
s.AddTool(mutationTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
135145
conn, err := getConn(connectionString)
136146
if err != nil {
137-
return nil, err
147+
return mcp.NewToolResultErrorFromErr("Error opening connection with Dgraph Alpha", err), nil
138148
}
139149
txn := conn.NewTxn()
140150
defer func() {
@@ -143,25 +153,35 @@ func NewMCPServer(connectionString string, readOnly bool) (*server.MCPServer, er
143153
glog.Errorf("failed to discard transaction: %v", err)
144154
}
145155
}()
146-
mutation, ok := request.GetArguments()["mutation"].(string)
156+
args := request.GetArguments()
157+
if args == nil {
158+
return mcp.NewToolResultError("Mutation must be present"), nil
159+
}
160+
161+
mutationArg, ok := args["mutation"]
162+
if !ok || mutationArg == nil {
163+
return mcp.NewToolResultError("Mutation must be present"), nil
164+
}
165+
166+
mutation, ok := mutationArg.(string)
147167
if !ok {
148-
return nil, fmt.Errorf("mutation must present")
168+
return mcp.NewToolResultError("Mutation must be a string"), nil
149169
}
150170
resp, err := txn.Mutate(ctx, &api.Mutation{
151171
SetJson: []byte(mutation),
152172
CommitNow: true,
153173
})
154174
if err != nil {
155-
return mcp.NewToolResultError(err.Error()), nil
175+
return mcp.NewToolResultErrorFromErr("Error running mutation", err), nil
156176
}
157-
return mcp.NewToolResultText(string(resp.GetJson())), nil
177+
return mcp.NewToolResultText(fmt.Sprintf("Mutation completed, %d UIDs created", len(resp.Uids)/2)), nil
158178
})
159179
}
160180

161181
s.AddTool(queryTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
162182
conn, err := getConn(connectionString)
163183
if err != nil {
164-
return nil, err
184+
return mcp.NewToolResultErrorFromErr("Error opening connection with Dgraph Alpha", err), nil
165185
}
166186
txn := conn.NewTxn()
167187
defer func() {
@@ -170,18 +190,29 @@ func NewMCPServer(connectionString string, readOnly bool) (*server.MCPServer, er
170190
glog.Errorf("failed to discard transaction: %v", err)
171191
}
172192
}()
173-
op := request.GetArguments()["query"].(string)
193+
args := request.GetArguments()
194+
if args == nil {
195+
return mcp.NewToolResultError("Query must be present"), nil
196+
}
197+
queryArg, ok := args["query"]
198+
if !ok || queryArg == nil {
199+
return mcp.NewToolResultError("Query must be present"), nil
200+
}
201+
op, ok := queryArg.(string)
202+
if !ok {
203+
return mcp.NewToolResultError("Query must be a string"), nil
204+
}
174205
resp, err := txn.Query(ctx, op)
175206
if err != nil {
176-
return mcp.NewToolResultError(err.Error()), nil
207+
return mcp.NewToolResultErrorFromErr("Error running query", err), nil
177208
}
178209
return mcp.NewToolResultText(string(resp.GetJson())), nil
179210
})
180211

181212
s.AddTool(schemaTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
182213
conn, err := getConn(connectionString)
183214
if err != nil {
184-
return nil, err
215+
return mcp.NewToolResultErrorFromErr("Error opening connection with Dgraph Alpha", err), nil
185216
}
186217
txn := conn.NewTxn()
187218
defer func() {
@@ -192,27 +223,27 @@ func NewMCPServer(connectionString string, readOnly bool) (*server.MCPServer, er
192223
}()
193224
resp, err := txn.Query(ctx, "schema {}")
194225
if err != nil {
195-
return mcp.NewToolResultError(err.Error()), nil
226+
return mcp.NewToolResultErrorFromErr("Error running query", err), nil
196227
}
197228
return mcp.NewToolResultText(string(resp.GetJson())), nil
198229
})
199230

200231
schemaResource := mcp.NewResource(
201232
"dgraph://schema",
202-
"Dgraph Schema",
203-
mcp.WithResourceDescription("The current Dgraph schema"),
233+
"dgraph_schema",
234+
mcp.WithResourceDescription("The current Dgraph DQL schema"),
204235
mcp.WithMIMEType("text/plain"),
205236
)
206237

207238
s.AddResource(schemaResource, func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) {
208239
// Execute operation
209240
conn, err := getConn(connectionString)
210241
if err != nil {
211-
return nil, err
242+
return nil, fmt.Errorf("error opening connection with Dgraph Alpha: %w", err)
212243
}
213244
resp, err := conn.NewTxn().Query(ctx, "schema {}")
214245
if err != nil {
215-
return nil, fmt.Errorf("failed to get schema: %v", err)
246+
return nil, fmt.Errorf("error running query: %w", err)
216247
}
217248

218249
return []mcp.ResourceContents{
@@ -224,7 +255,7 @@ func NewMCPServer(connectionString string, readOnly bool) (*server.MCPServer, er
224255
}, nil
225256
})
226257

227-
commonQueriesTool := mcp.NewTool("Get-Common-Queries",
258+
commonQueriesTool := mcp.NewTool("get_common_queries",
228259
mcp.WithDescription("Get common queries that you can run on the db. If you are seeing issues with your queries, you can check this tool once."),
229260
mcp.WithToolAnnotation(mcp.ToolAnnotation{
230261
ReadOnlyHint: &True,
@@ -257,16 +288,16 @@ func NewMCPServer(connectionString string, readOnly bool) (*server.MCPServer, er
257288
})
258289

259290
commonQueries := mcp.NewResource(
260-
"dgraph://common-queries",
261-
"Dgraph common queries",
291+
"dgraph://common_queries",
292+
"dgraph_common_queries",
262293
mcp.WithResourceDescription("The current Dgraph common queries that you can use to fix your queries"),
263294
mcp.WithMIMEType("text/plain"),
264295
)
265296

266297
s.AddResource(commonQueries, func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) {
267298
return []mcp.ResourceContents{
268299
mcp.TextResourceContents{
269-
URI: "dgraph://commmon-queries",
300+
URI: "dgraph://commmon_queries",
270301
MIMEType: "text/plain",
271302
Text: `
272303
{
@@ -322,11 +353,11 @@ func NewMCPServer(connectionString string, readOnly bool) (*server.MCPServer, er
322353

323354
func addPrompt(s *server.MCPServer) {
324355
prompt := string(promptBytes)
325-
s.AddPrompt(mcp.NewPrompt("Quick start prompt",
356+
s.AddPrompt(mcp.NewPrompt("quick_start_prompt",
326357
mcp.WithPromptDescription("A quick Start prompt for new users and llms"),
327358
), func(ctx context.Context, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
328359
return mcp.NewGetPromptResult(
329-
"A quick start prompt",
360+
"quick_start_prompt",
330361
[]mcp.PromptMessage{
331362
mcp.NewPromptMessage(
332363
mcp.RoleAssistant,

0 commit comments

Comments
 (0)