Skip to content

Commit 7f2ea88

Browse files
yash025Yashwanth H L
andauthored
Add option to StreamableHTTPServer to allow custom http server instance (#347)
* Add WithStreamableHTTPServer option to StreamableHTTPServer to allow setting a custom HTTP server instance, similar to existing functionality in SSE. * Add better documentation notes to WithHTTPServer and WithStreamableHTTPServer --------- Co-authored-by: Yashwanth H L <[email protected]>
1 parent d250b38 commit 7f2ea88

File tree

3 files changed

+77
-7
lines changed

3 files changed

+77
-7
lines changed

server/sse.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,9 @@ func WithSSEEndpoint(endpoint string) SSEOption {
227227
}
228228
}
229229

230-
// WithHTTPServer sets the HTTP server instance
230+
// WithHTTPServer sets the HTTP server instance.
231+
// NOTE: When providing a custom HTTP server, you must handle routing yourself
232+
// If routing is not set up, the server will start but won't handle any MCP requests.
231233
func WithHTTPServer(srv *http.Server) SSEOption {
232234
return func(s *SSEServer) {
233235
s.srv = srv

server/streamable_http.go

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,15 @@ func WithHTTPContextFunc(fn HTTPContextFunc) StreamableHTTPOption {
7474
}
7575
}
7676

77+
// WithStreamableHTTPServer sets the HTTP server instance for StreamableHTTPServer.
78+
// NOTE: When providing a custom HTTP server, you must handle routing yourself
79+
// If routing is not set up, the server will start but won't handle any MCP requests.
80+
func WithStreamableHTTPServer(srv *http.Server) StreamableHTTPOption {
81+
return func(s *StreamableHTTPServer) {
82+
s.httpServer = srv
83+
}
84+
}
85+
7786
// WithLogger sets the logger for the server
7887
func WithLogger(logger util.Logger) StreamableHTTPOption {
7988
return func(s *StreamableHTTPServer) {
@@ -156,15 +165,24 @@ func (s *StreamableHTTPServer) ServeHTTP(w http.ResponseWriter, r *http.Request)
156165
// s.Start(":8080")
157166
func (s *StreamableHTTPServer) Start(addr string) error {
158167
s.mu.Lock()
159-
mux := http.NewServeMux()
160-
mux.Handle(s.endpointPath, s)
161-
s.httpServer = &http.Server{
162-
Addr: addr,
163-
Handler: mux,
168+
if s.httpServer == nil {
169+
mux := http.NewServeMux()
170+
mux.Handle(s.endpointPath, s)
171+
s.httpServer = &http.Server{
172+
Addr: addr,
173+
Handler: mux,
174+
}
175+
} else {
176+
if s.httpServer.Addr == "" {
177+
s.httpServer.Addr = addr
178+
} else if s.httpServer.Addr != addr {
179+
return fmt.Errorf("conflicting listen address: WithStreamableHTTPServer(%q) vs Start(%q)", s.httpServer.Addr, addr)
180+
}
164181
}
182+
srv := s.httpServer
165183
s.mu.Unlock()
166184

167-
return s.httpServer.ListenAndServe()
185+
return srv.ListenAndServe()
168186
}
169187

170188
// Shutdown gracefully stops the server, closing all active sessions

server/streamable_http_test.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -670,6 +670,56 @@ func TestStreamableHTTP_SessionWithTools(t *testing.T) {
670670
})
671671
}
672672

673+
func TestStreamableHTTPServer_WithOptions(t *testing.T) {
674+
t.Run("WithStreamableHTTPServer sets httpServer field", func(t *testing.T) {
675+
mcpServer := NewMCPServer("test", "1.0.0")
676+
customServer := &http.Server{Addr: ":9999"}
677+
httpServer := NewStreamableHTTPServer(mcpServer, WithStreamableHTTPServer(customServer))
678+
679+
if httpServer.httpServer != customServer {
680+
t.Errorf("Expected httpServer to be set to custom server instance, got %v", httpServer.httpServer)
681+
}
682+
})
683+
684+
t.Run("Start with conflicting address returns error", func(t *testing.T) {
685+
mcpServer := NewMCPServer("test", "1.0.0")
686+
customServer := &http.Server{Addr: ":9999"}
687+
httpServer := NewStreamableHTTPServer(mcpServer, WithStreamableHTTPServer(customServer))
688+
689+
err := httpServer.Start(":8888")
690+
if err == nil {
691+
t.Error("Expected error for conflicting address, got nil")
692+
} else if !strings.Contains(err.Error(), "conflicting listen address") {
693+
t.Errorf("Expected error message to contain 'conflicting listen address', got '%s'", err.Error())
694+
}
695+
})
696+
697+
t.Run("Options consistency test", func(t *testing.T) {
698+
mcpServer := NewMCPServer("test", "1.0.0")
699+
endpointPath := "/test-mcp"
700+
customServer := &http.Server{}
701+
702+
// Options to test
703+
options := []StreamableHTTPOption{
704+
WithEndpointPath(endpointPath),
705+
WithStreamableHTTPServer(customServer),
706+
}
707+
708+
// Apply options multiple times and verify consistency
709+
for i := 0; i < 10; i++ {
710+
server := NewStreamableHTTPServer(mcpServer, options...)
711+
712+
if server.endpointPath != endpointPath {
713+
t.Errorf("Expected endpointPath %s, got %s", endpointPath, server.endpointPath)
714+
}
715+
716+
if server.httpServer != customServer {
717+
t.Errorf("Expected httpServer to match, got %v", server.httpServer)
718+
}
719+
}
720+
})
721+
}
722+
673723
func postJSON(url string, bodyObject any) (*http.Response, error) {
674724
jsonBody, _ := json.Marshal(bodyObject)
675725
req, _ := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(jsonBody))

0 commit comments

Comments
 (0)