Skip to content

Commit cc4b9af

Browse files
authored
support customize request header (#315)
1 parent 2084a38 commit cc4b9af

File tree

5 files changed

+129
-1
lines changed

5 files changed

+129
-1
lines changed

client/sse.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ func WithHeaders(headers map[string]string) transport.ClientOption {
1212
return transport.WithHeaders(headers)
1313
}
1414

15+
func WithHeaderFunc(headerFunc transport.HTTPHeaderFunc) transport.ClientOption {
16+
return transport.WithHeaderFunc(headerFunc)
17+
}
18+
1519
func WithHTTPClient(httpClient *http.Client) transport.ClientOption {
1620
return transport.WithHTTPClient(httpClient)
1721
}

client/sse_test.go

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package client
22

33
import (
44
"context"
5+
"net/http"
56
"testing"
67
"time"
78

@@ -11,6 +12,13 @@ import (
1112
"github.com/mark3labs/mcp-go/server"
1213
)
1314

15+
type contextKey string
16+
17+
const (
18+
testHeaderKey contextKey = "X-Test-Header"
19+
testHeaderFuncKey contextKey = "X-Test-Header-Func"
20+
)
21+
1422
func TestSSEMCPClient(t *testing.T) {
1523
// Create MCP server with capabilities
1624
mcpServer := server.NewMCPServer(
@@ -41,9 +49,29 @@ func TestSSEMCPClient(t *testing.T) {
4149
},
4250
}, nil
4351
})
52+
mcpServer.AddTool(mcp.NewTool(
53+
"test-tool-for-http-header",
54+
mcp.WithDescription("Test tool for http header"),
55+
), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
56+
// , X-Test-Header-Func
57+
return &mcp.CallToolResult{
58+
Content: []mcp.Content{
59+
mcp.TextContent{
60+
Type: "text",
61+
Text: "context from header: " + ctx.Value(testHeaderKey).(string) + ", " + ctx.Value(testHeaderFuncKey).(string),
62+
},
63+
},
64+
}, nil
65+
})
4466

4567
// Initialize
46-
testServer := server.NewTestServer(mcpServer)
68+
testServer := server.NewTestServer(mcpServer,
69+
server.WithHTTPContextFunc(func(ctx context.Context, r *http.Request) context.Context {
70+
ctx = context.WithValue(ctx, testHeaderKey, r.Header.Get("X-Test-Header"))
71+
ctx = context.WithValue(ctx, testHeaderFuncKey, r.Header.Get("X-Test-Header-Func"))
72+
return ctx
73+
}),
74+
)
4775
defer testServer.Close()
4876

4977
t.Run("Can create client", func(t *testing.T) {
@@ -250,4 +278,56 @@ func TestSSEMCPClient(t *testing.T) {
250278
t.Errorf("Expected 1 content item, got %d", len(result.Content))
251279
}
252280
})
281+
282+
t.Run("CallTool with customized header", func(t *testing.T) {
283+
client, err := NewSSEMCPClient(testServer.URL+"/sse",
284+
WithHeaders(map[string]string{
285+
"X-Test-Header": "test-header-value",
286+
}),
287+
WithHeaderFunc(func(ctx context.Context) map[string]string {
288+
return map[string]string{
289+
"X-Test-Header-Func": "test-header-func-value",
290+
}
291+
}),
292+
)
293+
if err != nil {
294+
t.Fatalf("Failed to create client: %v", err)
295+
}
296+
defer client.Close()
297+
298+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
299+
defer cancel()
300+
301+
if err := client.Start(ctx); err != nil {
302+
t.Fatalf("Failed to start client: %v", err)
303+
}
304+
305+
// Initialize
306+
initRequest := mcp.InitializeRequest{}
307+
initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
308+
initRequest.Params.ClientInfo = mcp.Implementation{
309+
Name: "test-client",
310+
Version: "1.0.0",
311+
}
312+
313+
_, err = client.Initialize(ctx, initRequest)
314+
if err != nil {
315+
t.Fatalf("Failed to initialize: %v", err)
316+
}
317+
318+
request := mcp.CallToolRequest{}
319+
request.Params.Name = "test-tool-for-http-header"
320+
321+
result, err := client.CallTool(ctx, request)
322+
if err != nil {
323+
t.Fatalf("CallTool failed: %v", err)
324+
}
325+
326+
if len(result.Content) != 1 {
327+
t.Errorf("Expected 1 content item, got %d", len(result.Content))
328+
}
329+
if result.Content[0].(mcp.TextContent).Text != "context from header: test-header-value, test-header-func-value" {
330+
t.Errorf("Got %q, want %q", result.Content[0].(mcp.TextContent).Text, "context from header: test-header-value, test-header-func-value")
331+
}
332+
})
253333
}

client/transport/interface.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@ import (
77
"github.com/mark3labs/mcp-go/mcp"
88
)
99

10+
// HTTPHeaderFunc is a function that extracts header entries from the given context
11+
// and returns them as key-value pairs. This is typically used to add context values
12+
// as HTTP headers in outgoing requests.
13+
type HTTPHeaderFunc func(context.Context) map[string]string
14+
1015
// Interface for the transport layer.
1116
type Interface interface {
1217
// Start the connection. Start should only be called once.

client/transport/sse.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ type SSE struct {
3131
notifyMu sync.RWMutex
3232
endpointChan chan struct{}
3333
headers map[string]string
34+
headerFunc HTTPHeaderFunc
3435

3536
started atomic.Bool
3637
closed atomic.Bool
@@ -45,6 +46,12 @@ func WithHeaders(headers map[string]string) ClientOption {
4546
}
4647
}
4748

49+
func WithHeaderFunc(headerFunc HTTPHeaderFunc) ClientOption {
50+
return func(sc *SSE) {
51+
sc.headerFunc = headerFunc
52+
}
53+
}
54+
4855
func WithHTTPClient(httpClient *http.Client) ClientOption {
4956
return func(sc *SSE) {
5057
sc.httpClient = httpClient
@@ -99,6 +106,11 @@ func (c *SSE) Start(ctx context.Context) error {
99106
for k, v := range c.headers {
100107
req.Header.Set(k, v)
101108
}
109+
if c.headerFunc != nil {
110+
for k, v := range c.headerFunc(ctx) {
111+
req.Header.Set(k, v)
112+
}
113+
}
102114

103115
resp, err := c.httpClient.Do(req)
104116
if err != nil {
@@ -269,6 +281,11 @@ func (c *SSE) SendRequest(
269281
for k, v := range c.headers {
270282
req.Header.Set(k, v)
271283
}
284+
if c.headerFunc != nil {
285+
for k, v := range c.headerFunc(ctx) {
286+
req.Header.Set(k, v)
287+
}
288+
}
272289

273290
// Create string key for map lookup
274291
idKey := request.ID.String()
@@ -368,6 +385,11 @@ func (c *SSE) SendNotification(ctx context.Context, notification mcp.JSONRPCNoti
368385
for k, v := range c.headers {
369386
req.Header.Set(k, v)
370387
}
388+
if c.headerFunc != nil {
389+
for k, v := range c.headerFunc(ctx) {
390+
req.Header.Set(k, v)
391+
}
392+
}
371393

372394
resp, err := c.httpClient.Do(req)
373395
if err != nil {

client/transport/streamable_http.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,12 @@ func WithHTTPHeaders(headers map[string]string) StreamableHTTPCOption {
2626
}
2727
}
2828

29+
func WithHTTPHeaderFunc(headerFunc HTTPHeaderFunc) StreamableHTTPCOption {
30+
return func(sc *StreamableHTTP) {
31+
sc.headerFunc = headerFunc
32+
}
33+
}
34+
2935
// WithHTTPTimeout sets the timeout for a HTTP request and stream.
3036
func WithHTTPTimeout(timeout time.Duration) StreamableHTTPCOption {
3137
return func(sc *StreamableHTTP) {
@@ -52,6 +58,7 @@ type StreamableHTTP struct {
5258
baseURL *url.URL
5359
httpClient *http.Client
5460
headers map[string]string
61+
headerFunc HTTPHeaderFunc
5562

5663
sessionID atomic.Value // string
5764

@@ -172,6 +179,11 @@ func (c *StreamableHTTP) SendRequest(
172179
for k, v := range c.headers {
173180
req.Header.Set(k, v)
174181
}
182+
if c.headerFunc != nil {
183+
for k, v := range c.headerFunc(ctx) {
184+
req.Header.Set(k, v)
185+
}
186+
}
175187

176188
// Send request
177189
resp, err := c.httpClient.Do(req)
@@ -362,6 +374,11 @@ func (c *StreamableHTTP) SendNotification(ctx context.Context, notification mcp.
362374
for k, v := range c.headers {
363375
req.Header.Set(k, v)
364376
}
377+
if c.headerFunc != nil {
378+
for k, v := range c.headerFunc(ctx) {
379+
req.Header.Set(k, v)
380+
}
381+
}
365382

366383
// Send request
367384
resp, err := c.httpClient.Do(req)

0 commit comments

Comments
 (0)