Skip to content

Commit c61624c

Browse files
authored
Merge pull request #8 from mark3labs/implement-notification-handling
Implement notification handling
2 parents 48c485e + d1c3cfc commit c61624c

File tree

10 files changed

+529
-283
lines changed

10 files changed

+529
-283
lines changed

README.md

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ func main() {
4646
}
4747
}
4848

49-
func helloHandler(ctx context.Context, arguments map[string]interface{}) (*mcp.CallToolResult, error) {
50-
name, ok := arguments["name"].(string)
49+
func helloHandler(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
50+
name, ok := request.Params.Arguments["name"].(string)
5151
if !ok {
5252
return mcp.NewToolResultError("name must be a string"), nil
5353
}
@@ -137,10 +137,10 @@ func main() {
137137
)
138138

139139
// Add the calculator handler
140-
s.AddTool(calculatorTool, func(args map[string]interface{}) (*mcp.CallToolResult, error) {
141-
op := args["operation"].(string)
142-
x := args["x"].(float64)
143-
y := args["y"].(float64)
140+
s.AddTool(calculatorTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
141+
op := request.Params.Arguments["operation"].(string)
142+
x := request.Params.Arguments["x"].(float64)
143+
y := request.Params.Arguments["y"].(float64)
144144

145145
var result float64
146146
switch op {
@@ -223,7 +223,7 @@ resource := mcp.NewResource(
223223
)
224224

225225
// Add resource with its handler
226-
s.AddResource(resource, func(ctx context.Context) ([]interface{}, error) {
226+
s.AddResource(resource, func(ctx context.Context, request mcp.ReadResourceRequest) ([]interface{}, error) {
227227
content, err := os.ReadFile("README.md")
228228
if err != nil {
229229
return nil, err
@@ -254,8 +254,8 @@ template := mcp.NewResourceTemplate(
254254
)
255255

256256
// Add template with its handler
257-
s.AddResourceTemplate(template, func(ctx context.Context, args map[string]interface{}) ([]interface{}, error) {
258-
userID := args["id"].(string)
257+
s.AddResourceTemplate(template, func(ctx context.Context, request mcp.ReadResourceRequest) ([]interface{}, error) {
258+
userID := request.Params.URI // Extract ID from the full URI
259259

260260
profile, err := getUserProfile(userID) // Your DB/API call here
261261
if err != nil {
@@ -303,10 +303,10 @@ calculatorTool := mcp.NewTool("calculate",
303303
),
304304
)
305305

306-
s.AddTool(calculatorTool, func(args map[string]interface{}) (*mcp.CallToolResult, error) {
307-
op := args["operation"].(string)
308-
x := args["x"].(float64)
309-
y := args["y"].(float64)
306+
s.AddTool(calculatorTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
307+
op := request.Params.Arguments["operation"].(string)
308+
x := request.Params.Arguments["x"].(float64)
309+
y := request.Params.Arguments["y"].(float64)
310310

311311
var result float64
312312
switch op {
@@ -346,11 +346,11 @@ httpTool := mcp.NewTool("http_request",
346346
),
347347
)
348348

349-
s.AddTool(httpTool, func(args map[string]interface{}) (*mcp.CallToolResult, error) {
350-
method := args["method"].(string)
351-
url := args["url"].(string)
349+
s.AddTool(httpTool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
350+
method := request.Params.Arguments["method"].(string)
351+
url := request.Params.Arguments["url"].(string)
352352
body := ""
353-
if b, ok := args["body"].(string); ok {
353+
if b, ok := request.Params.Arguments["body"].(string); ok {
354354
body = b
355355
}
356356

@@ -413,8 +413,8 @@ s.AddPrompt(mcp.NewPrompt("greeting",
413413
mcp.WithArgument("name",
414414
mcp.ArgumentDescription("Name of the person to greet"),
415415
),
416-
), func(args map[string]string) (*mcp.GetPromptResult, error) {
417-
name := args["name"]
416+
), func(ctx context.Context, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
417+
name := request.Params.Arguments["name"].(string)
418418
if name == "" {
419419
name = "friend"
420420
}
@@ -437,8 +437,8 @@ s.AddPrompt(mcp.NewPrompt("code_review",
437437
mcp.ArgumentDescription("Pull request number to review"),
438438
mcp.RequiredArgument(),
439439
),
440-
), func(args map[string]string) (*mcp.GetPromptResult, error) {
441-
prNumber := args["pr_number"]
440+
), func(ctx context.Context, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
441+
prNumber := request.Params.Arguments["pr_number"].(string)
442442
if prNumber == "" {
443443
return nil, fmt.Errorf("pr_number is required")
444444
}
@@ -468,8 +468,8 @@ s.AddPrompt(mcp.NewPrompt("query_builder",
468468
mcp.ArgumentDescription("Name of the table to query"),
469469
mcp.RequiredArgument(),
470470
),
471-
), func(args map[string]string) (*mcp.GetPromptResult, error) {
472-
tableName := args["table"]
471+
), func(ctx context.Context, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
472+
tableName := request.Params.Arguments["table"].(string)
473473
if tableName == "" {
474474
return nil, fmt.Errorf("table name is required")
475475
}

client/sse_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,11 @@ func TestSSEMCPClient(t *testing.T) {
2626
Type: "object",
2727
Properties: map[string]interface{}{},
2828
},
29-
}, func(arguments map[string]interface{}) (*mcp.CallToolResult, error) {
29+
}, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
3030
return &mcp.CallToolResult{}, nil
3131
})
3232

33-
// Create test server
33+
// Initialize
3434
testServer := server.NewTestServer(mcpServer)
3535
defer testServer.Close()
3636

examples/everything/main.go

Lines changed: 105 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package main
22

33
import (
4+
"context"
5+
"flag"
46
"fmt"
57
"log"
68
"time"
@@ -79,6 +81,12 @@ func NewMCPServer() *MCPServer {
7981
mcp.Required(),
8082
),
8183
), s.handleEchoTool)
84+
85+
s.server.AddTool(
86+
mcp.NewTool("notify"),
87+
s.handleSendNotification,
88+
)
89+
8290
s.server.AddTool(mcp.NewTool(string(ADD),
8391
mcp.WithDescription("Adds two numbers"),
8492
mcp.WithNumber("a",
@@ -127,7 +135,7 @@ func NewMCPServer() *MCPServer {
127135
mcp.WithDescription("Returns the MCP_TINY_IMAGE"),
128136
), s.handleGetTinyImageTool)
129137

130-
s.server.AddNotificationHandler(s.handleNotification)
138+
s.server.AddNotificationHandler("notification", s.handleNotification)
131139

132140
go s.runUpdateInterval()
133141

@@ -177,6 +185,7 @@ func (s *MCPServer) runUpdateInterval() {
177185
}
178186

179187
func (s *MCPServer) handleReadResource(
188+
ctx context.Context,
180189
request mcp.ReadResourceRequest,
181190
) ([]interface{}, error) {
182191
return []interface{}{
@@ -191,6 +200,7 @@ func (s *MCPServer) handleReadResource(
191200
}
192201

193202
func (s *MCPServer) handleResourceTemplate(
203+
ctx context.Context,
194204
request mcp.ReadResourceRequest,
195205
) ([]interface{}, error) {
196206
return []interface{}{
@@ -205,7 +215,8 @@ func (s *MCPServer) handleResourceTemplate(
205215
}
206216

207217
func (s *MCPServer) handleSimplePrompt(
208-
arguments map[string]string,
218+
ctx context.Context,
219+
request mcp.GetPromptRequest,
209220
) (*mcp.GetPromptResult, error) {
210221
return &mcp.GetPromptResult{
211222
Description: "A simple prompt without arguments",
@@ -222,8 +233,10 @@ func (s *MCPServer) handleSimplePrompt(
222233
}
223234

224235
func (s *MCPServer) handleComplexPrompt(
225-
arguments map[string]string,
236+
ctx context.Context,
237+
request mcp.GetPromptRequest,
226238
) (*mcp.GetPromptResult, error) {
239+
arguments := request.Params.Arguments
227240
return &mcp.GetPromptResult{
228241
Description: "A complex prompt with arguments",
229242
Messages: []mcp.PromptMessage{
@@ -258,8 +271,10 @@ func (s *MCPServer) handleComplexPrompt(
258271
}
259272

260273
func (s *MCPServer) handleEchoTool(
261-
arguments map[string]interface{},
274+
ctx context.Context,
275+
request mcp.CallToolRequest,
262276
) (*mcp.CallToolResult, error) {
277+
arguments := request.Params.Arguments
263278
message, ok := arguments["message"].(string)
264279
if !ok {
265280
return nil, fmt.Errorf("invalid message argument")
@@ -275,8 +290,10 @@ func (s *MCPServer) handleEchoTool(
275290
}
276291

277292
func (s *MCPServer) handleAddTool(
278-
arguments map[string]interface{},
293+
ctx context.Context,
294+
request mcp.CallToolRequest,
279295
) (*mcp.CallToolResult, error) {
296+
arguments := request.Params.Arguments
280297
a, ok1 := arguments["a"].(float64)
281298
b, ok2 := arguments["b"].(float64)
282299
if !ok1 || !ok2 {
@@ -293,35 +310,65 @@ func (s *MCPServer) handleAddTool(
293310
}, nil
294311
}
295312

313+
func (s *MCPServer) handleSendNotification(
314+
ctx context.Context,
315+
request mcp.CallToolRequest,
316+
) (*mcp.CallToolResult, error) {
317+
318+
server := server.ServerFromContext(ctx)
319+
320+
err := server.SendNotificationToClient(
321+
"notifications/progress",
322+
map[string]interface{}{
323+
"progress": 10,
324+
"total": 10,
325+
"progressToken": 0,
326+
},
327+
)
328+
if err != nil {
329+
return nil, fmt.Errorf("failed to send notification: %w", err)
330+
}
331+
332+
return &mcp.CallToolResult{
333+
Content: []interface{}{
334+
mcp.TextContent{
335+
Type: "text",
336+
Text: "notification sent successfully",
337+
},
338+
},
339+
}, nil
340+
}
341+
342+
func (s *MCPServer) ServeSSE(addr string) *server.SSEServer {
343+
return server.NewSSEServer(s.server, fmt.Sprintf("http://%s", addr))
344+
}
345+
346+
func (s *MCPServer) ServeStdio() error {
347+
return server.ServeStdio(s.server)
348+
}
349+
296350
func (s *MCPServer) handleLongRunningOperationTool(
297-
arguments map[string]interface{},
351+
ctx context.Context,
352+
request mcp.CallToolRequest,
298353
) (*mcp.CallToolResult, error) {
354+
arguments := request.Params.Arguments
355+
progressToken := request.Params.Meta.ProgressToken
299356
duration, _ := arguments["duration"].(float64)
300357
steps, _ := arguments["steps"].(float64)
301358
stepDuration := duration / steps
302-
progressToken, _ := arguments["_meta"].(map[string]interface{})["progressToken"].(mcp.ProgressToken)
359+
server := server.ServerFromContext(ctx)
303360

304361
for i := 1; i < int(steps)+1; i++ {
305362
time.Sleep(time.Duration(stepDuration * float64(time.Second)))
306363
if progressToken != nil {
307-
// s.server.HandleMessage(
308-
// context.Background(),
309-
// mcp.JSONRPCNotification{
310-
// JSONRPC: mcp.JSONRPC_VERSION,
311-
// Notification: mcp.Notification{
312-
// Method: "progress",
313-
// Params: struct {
314-
// Meta map[string]interface{} `json:"_meta,omitempty"`
315-
// }{
316-
// Meta: map[string]interface{}{
317-
// "progress": i,
318-
// "total": int(steps),
319-
// "progressToken": progressToken,
320-
// },
321-
// },
322-
// },
323-
// },
324-
// )
364+
server.SendNotificationToClient(
365+
"notifications/progress",
366+
map[string]interface{}{
367+
"progress": i,
368+
"total": int(steps),
369+
"progressToken": progressToken,
370+
},
371+
)
325372
}
326373
}
327374

@@ -361,7 +408,8 @@ func (s *MCPServer) handleLongRunningOperationTool(
361408
// }
362409

363410
func (s *MCPServer) handleGetTinyImageTool(
364-
arguments map[string]interface{},
411+
ctx context.Context,
412+
request mcp.CallToolRequest,
365413
) (*mcp.CallToolResult, error) {
366414
return &mcp.CallToolResult{
367415
Content: []interface{}{
@@ -382,7 +430,10 @@ func (s *MCPServer) handleGetTinyImageTool(
382430
}, nil
383431
}
384432

385-
func (s *MCPServer) handleNotification(notification mcp.JSONRPCNotification) {
433+
func (s *MCPServer) handleNotification(
434+
ctx context.Context,
435+
notification mcp.JSONRPCNotification,
436+
) {
386437
log.Printf("Received notification: %s", notification.Method)
387438
}
388439

@@ -391,9 +442,34 @@ func (s *MCPServer) Serve() error {
391442
}
392443

393444
func main() {
445+
var transport string
446+
flag.StringVar(&transport, "t", "stdio", "Transport type (stdio or sse)")
447+
flag.StringVar(
448+
&transport,
449+
"transport",
450+
"stdio",
451+
"Transport type (stdio or sse)",
452+
)
453+
flag.Parse()
454+
394455
server := NewMCPServer()
395-
if err := server.Serve(); err != nil {
396-
log.Fatalf("Server error: %v", err)
456+
457+
switch transport {
458+
case "stdio":
459+
if err := server.ServeStdio(); err != nil {
460+
log.Fatalf("Server error: %v", err)
461+
}
462+
case "sse":
463+
sseServer := server.ServeSSE("localhost:8080")
464+
log.Printf("SSE server listening on :8080")
465+
if err := sseServer.Start(":8080"); err != nil {
466+
log.Fatalf("Server error: %v", err)
467+
}
468+
default:
469+
log.Fatalf(
470+
"Invalid transport type: %s. Must be 'stdio' or 'sse'",
471+
transport,
472+
)
397473
}
398474
}
399475

0 commit comments

Comments
 (0)