Skip to content

Commit 2c8bf2b

Browse files
authored
feat(server): persist client info in sessions (#313)
* feat(server): persist client info in sessions Add SessionWithClientInfo interface and implementations to store and retrieve client information provided during initialization. This allows servers to access client implementation details throughout the session lifecycle. * refactor: use atomic.Value instead of mutex * chore: cleanup * fix: restore named parameter in handleInitialize method * chore: test order
1 parent 0c3f535 commit 2c8bf2b

File tree

5 files changed

+147
-18
lines changed

5 files changed

+147
-18
lines changed

server/server.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -512,7 +512,7 @@ func (s *MCPServer) AddNotificationHandler(
512512
func (s *MCPServer) handleInitialize(
513513
ctx context.Context,
514514
_ any,
515-
_ mcp.InitializeRequest,
515+
request mcp.InitializeRequest,
516516
) (*mcp.InitializeResult, *requestError) {
517517
capabilities := mcp.ServerCapabilities{}
518518

@@ -561,6 +561,11 @@ func (s *MCPServer) handleInitialize(
561561

562562
if session := ClientSessionFromContext(ctx); session != nil {
563563
session.Initialize()
564+
565+
// Store client info if the session supports it
566+
if sessionWithClientInfo, ok := session.(SessionWithClientInfo); ok {
567+
sessionWithClientInfo.SetClientInfo(request.Params.ClientInfo)
568+
}
564569
}
565570
return &result, nil
566571
}

server/session.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,15 @@ type SessionWithTools interface {
3939
SetSessionTools(tools map[string]ServerTool)
4040
}
4141

42+
// SessionWithClientInfo is an extension of ClientSession that can store client info
43+
type SessionWithClientInfo interface {
44+
ClientSession
45+
// GetClientInfo returns the client information for this session
46+
GetClientInfo() mcp.Implementation
47+
// SetClientInfo sets the client information for this session
48+
SetClientInfo(clientInfo mcp.Implementation)
49+
}
50+
4251
// clientSessionKey is the context key for storing current client notification channel.
4352
type clientSessionKey struct{}
4453

server/session_test.go

Lines changed: 90 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,10 @@ import (
99
"testing"
1010
"time"
1111

12-
"github.com/mark3labs/mcp-go/mcp"
1312
"github.com/stretchr/testify/assert"
1413
"github.com/stretchr/testify/require"
14+
15+
"github.com/mark3labs/mcp-go/mcp"
1516
)
1617

1718
// sessionTestClient implements the basic ClientSession interface for testing
@@ -99,12 +100,49 @@ func (f *sessionTestClientWithTools) SetSessionTools(tools map[string]ServerTool
99100
f.sessionTools = toolsCopy
100101
}
101102

103+
// sessionTestClientWithClientInfo implements the SessionWithClientInfo interface for testing
104+
type sessionTestClientWithClientInfo struct {
105+
sessionID string
106+
notificationChannel chan mcp.JSONRPCNotification
107+
initialized bool
108+
clientInfo atomic.Value
109+
}
110+
111+
func (f *sessionTestClientWithClientInfo) SessionID() string {
112+
return f.sessionID
113+
}
114+
115+
func (f *sessionTestClientWithClientInfo) NotificationChannel() chan<- mcp.JSONRPCNotification {
116+
return f.notificationChannel
117+
}
118+
119+
func (f *sessionTestClientWithClientInfo) Initialize() {
120+
f.initialized = true
121+
}
122+
123+
func (f *sessionTestClientWithClientInfo) Initialized() bool {
124+
return f.initialized
125+
}
126+
127+
func (f *sessionTestClientWithClientInfo) GetClientInfo() mcp.Implementation {
128+
if value := f.clientInfo.Load(); value != nil {
129+
if clientInfo, ok := value.(mcp.Implementation); ok {
130+
return clientInfo
131+
}
132+
}
133+
return mcp.Implementation{}
134+
}
135+
136+
func (f *sessionTestClientWithClientInfo) SetClientInfo(clientInfo mcp.Implementation) {
137+
f.clientInfo.Store(clientInfo)
138+
}
139+
102140
// sessionTestClientWithTools implements the SessionWithLogging interface for testing
103141
type sessionTestClientWithLogging struct {
104142
sessionID string
105143
notificationChannel chan mcp.JSONRPCNotification
106144
initialized bool
107-
loggingLevel atomic.Value
145+
loggingLevel atomic.Value
108146
}
109147

110148
func (f *sessionTestClientWithLogging) SessionID() string {
@@ -136,9 +174,10 @@ func (f *sessionTestClientWithLogging) GetLogLevel() mcp.LoggingLevel {
136174

137175
// Verify that all implementations satisfy their respective interfaces
138176
var (
139-
_ ClientSession = (*sessionTestClient)(nil)
140-
_ SessionWithTools = (*sessionTestClientWithTools)(nil)
141-
_ SessionWithLogging = (*sessionTestClientWithLogging)(nil)
177+
_ ClientSession = (*sessionTestClient)(nil)
178+
_ SessionWithTools = (*sessionTestClientWithTools)(nil)
179+
_ SessionWithLogging = (*sessionTestClientWithLogging)(nil)
180+
_ SessionWithClientInfo = (*sessionTestClientWithClientInfo)(nil)
142181
)
143182

144183
func TestSessionWithTools_Integration(t *testing.T) {
@@ -1041,4 +1080,49 @@ func TestMCPServer_SetLevel(t *testing.T) {
10411080
if session.GetLogLevel() != mcp.LoggingLevelCritical {
10421081
t.Errorf("Expected critical level, got %v", session.GetLogLevel())
10431082
}
1044-
}
1083+
}
1084+
1085+
func TestSessionWithClientInfo_Integration(t *testing.T) {
1086+
server := NewMCPServer("test-server", "1.0.0")
1087+
1088+
session := &sessionTestClientWithClientInfo{
1089+
sessionID: "session-1",
1090+
notificationChannel: make(chan mcp.JSONRPCNotification, 10),
1091+
initialized: false,
1092+
}
1093+
1094+
err := server.RegisterSession(context.Background(), session)
1095+
require.NoError(t, err)
1096+
1097+
clientInfo := mcp.Implementation{
1098+
Name: "test-client",
1099+
Version: "1.0.0",
1100+
}
1101+
1102+
initRequest := mcp.InitializeRequest{}
1103+
initRequest.Params.ClientInfo = clientInfo
1104+
initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
1105+
initRequest.Params.Capabilities = mcp.ClientCapabilities{}
1106+
1107+
sessionCtx := server.WithContext(context.Background(), session)
1108+
1109+
// Retrieve the session from context
1110+
retrievedSession := ClientSessionFromContext(sessionCtx)
1111+
require.NotNil(t, retrievedSession, "Session should be available from context")
1112+
assert.Equal(t, session.SessionID(), retrievedSession.SessionID(), "Session ID should match")
1113+
1114+
result, reqErr := server.handleInitialize(sessionCtx, 1, initRequest)
1115+
require.Nil(t, reqErr)
1116+
require.NotNil(t, result)
1117+
1118+
// Check if the session can be cast to SessionWithClientInfo
1119+
sessionWithClientInfo, ok := retrievedSession.(SessionWithClientInfo)
1120+
require.True(t, ok, "Session should implement SessionWithClientInfo")
1121+
1122+
assert.True(t, sessionWithClientInfo.Initialized(), "Session should be initialized")
1123+
1124+
storedClientInfo := sessionWithClientInfo.GetClientInfo()
1125+
1126+
assert.Equal(t, clientInfo.Name, storedClientInfo.Name, "Client name should match")
1127+
assert.Equal(t, clientInfo.Version, storedClientInfo.Version, "Client version should match")
1128+
}

server/sse.go

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"time"
1616

1717
"github.com/google/uuid"
18+
1819
"github.com/mark3labs/mcp-go/mcp"
1920
)
2021

@@ -27,7 +28,8 @@ type sseSession struct {
2728
notificationChannel chan mcp.JSONRPCNotification
2829
initialized atomic.Bool
2930
loggingLevel atomic.Value
30-
tools sync.Map // stores session-specific tools
31+
tools sync.Map // stores session-specific tools
32+
clientInfo atomic.Value // stores session-specific client info
3133
}
3234

3335
// SSEContextFunc is a function that takes an existing context and the current
@@ -93,10 +95,24 @@ func (s *sseSession) SetSessionTools(tools map[string]ServerTool) {
9395
}
9496
}
9597

98+
func (s *sseSession) GetClientInfo() mcp.Implementation {
99+
if value := s.clientInfo.Load(); value != nil {
100+
if clientInfo, ok := value.(mcp.Implementation); ok {
101+
return clientInfo
102+
}
103+
}
104+
return mcp.Implementation{}
105+
}
106+
107+
func (s *sseSession) SetClientInfo(clientInfo mcp.Implementation) {
108+
s.clientInfo.Store(clientInfo)
109+
}
110+
96111
var (
97-
_ ClientSession = (*sseSession)(nil)
98-
_ SessionWithTools = (*sseSession)(nil)
99-
_ SessionWithLogging = (*sseSession)(nil)
112+
_ ClientSession = (*sseSession)(nil)
113+
_ SessionWithTools = (*sseSession)(nil)
114+
_ SessionWithLogging = (*sseSession)(nil)
115+
_ SessionWithClientInfo = (*sseSession)(nil)
100116
)
101117

102118
// SSEServer implements a Server-Sent Events (SSE) based MCP server.

server/stdio.go

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,10 @@ func WithStdioContextFunc(fn StdioContextFunc) StdioOption {
5151

5252
// stdioSession is a static client session, since stdio has only one client.
5353
type stdioSession struct {
54-
notifications chan mcp.JSONRPCNotification
55-
initialized atomic.Bool
56-
loggingLevel atomic.Value
54+
notifications chan mcp.JSONRPCNotification
55+
initialized atomic.Bool
56+
loggingLevel atomic.Value
57+
clientInfo atomic.Value // stores session-specific client info
5758
}
5859

5960
func (s *stdioSession) SessionID() string {
@@ -74,11 +75,24 @@ func (s *stdioSession) Initialized() bool {
7475
return s.initialized.Load()
7576
}
7677

77-
func(s *stdioSession) SetLogLevel(level mcp.LoggingLevel) {
78+
func (s *stdioSession) GetClientInfo() mcp.Implementation {
79+
if value := s.clientInfo.Load(); value != nil {
80+
if clientInfo, ok := value.(mcp.Implementation); ok {
81+
return clientInfo
82+
}
83+
}
84+
return mcp.Implementation{}
85+
}
86+
87+
func (s *stdioSession) SetClientInfo(clientInfo mcp.Implementation) {
88+
s.clientInfo.Store(clientInfo)
89+
}
90+
91+
func (s *stdioSession) SetLogLevel(level mcp.LoggingLevel) {
7892
s.loggingLevel.Store(level)
7993
}
8094

81-
func(s *stdioSession) GetLogLevel() mcp.LoggingLevel {
95+
func (s *stdioSession) GetLogLevel() mcp.LoggingLevel {
8296
level := s.loggingLevel.Load()
8397
if level == nil {
8498
return mcp.LoggingLevelError
@@ -87,8 +101,9 @@ func(s *stdioSession) GetLogLevel() mcp.LoggingLevel {
87101
}
88102

89103
var (
90-
_ ClientSession = (*stdioSession)(nil)
91-
_ SessionWithLogging = (*stdioSession)(nil)
104+
_ ClientSession = (*stdioSession)(nil)
105+
_ SessionWithLogging = (*stdioSession)(nil)
106+
_ SessionWithClientInfo = (*stdioSession)(nil)
92107
)
93108

94109
var stdioSessionInstance = stdioSession{

0 commit comments

Comments
 (0)