@@ -9,9 +9,10 @@ import (
9
9
"testing"
10
10
"time"
11
11
12
- "github.com/mark3labs/mcp-go/mcp"
13
12
"github.com/stretchr/testify/assert"
14
13
"github.com/stretchr/testify/require"
14
+
15
+ "github.com/mark3labs/mcp-go/mcp"
15
16
)
16
17
17
18
// sessionTestClient implements the basic ClientSession interface for testing
@@ -99,12 +100,49 @@ func (f *sessionTestClientWithTools) SetSessionTools(tools map[string]ServerTool
99
100
f .sessionTools = toolsCopy
100
101
}
101
102
103
+ // sessionTestClientWithClientInfo implements the SessionWithClientInfo interface for testing
104
+ type sessionTestClientWithClientInfo struct {
105
+ sessionID string
106
+ notificationChannel chan mcp.JSONRPCNotification
107
+ initialized bool
108
+ clientInfo atomic.Value
109
+ }
110
+
111
+ func (f * sessionTestClientWithClientInfo ) SessionID () string {
112
+ return f .sessionID
113
+ }
114
+
115
+ func (f * sessionTestClientWithClientInfo ) NotificationChannel () chan <- mcp.JSONRPCNotification {
116
+ return f .notificationChannel
117
+ }
118
+
119
+ func (f * sessionTestClientWithClientInfo ) Initialize () {
120
+ f .initialized = true
121
+ }
122
+
123
+ func (f * sessionTestClientWithClientInfo ) Initialized () bool {
124
+ return f .initialized
125
+ }
126
+
127
+ func (f * sessionTestClientWithClientInfo ) GetClientInfo () mcp.Implementation {
128
+ if value := f .clientInfo .Load (); value != nil {
129
+ if clientInfo , ok := value .(mcp.Implementation ); ok {
130
+ return clientInfo
131
+ }
132
+ }
133
+ return mcp.Implementation {}
134
+ }
135
+
136
+ func (f * sessionTestClientWithClientInfo ) SetClientInfo (clientInfo mcp.Implementation ) {
137
+ f .clientInfo .Store (clientInfo )
138
+ }
139
+
102
140
// sessionTestClientWithTools implements the SessionWithLogging interface for testing
103
141
type sessionTestClientWithLogging struct {
104
142
sessionID string
105
143
notificationChannel chan mcp.JSONRPCNotification
106
144
initialized bool
107
- loggingLevel atomic.Value
145
+ loggingLevel atomic.Value
108
146
}
109
147
110
148
func (f * sessionTestClientWithLogging ) SessionID () string {
@@ -136,9 +174,10 @@ func (f *sessionTestClientWithLogging) GetLogLevel() mcp.LoggingLevel {
136
174
137
175
// Verify that all implementations satisfy their respective interfaces
138
176
var (
139
- _ ClientSession = (* sessionTestClient )(nil )
140
- _ SessionWithTools = (* sessionTestClientWithTools )(nil )
141
- _ SessionWithLogging = (* sessionTestClientWithLogging )(nil )
177
+ _ ClientSession = (* sessionTestClient )(nil )
178
+ _ SessionWithTools = (* sessionTestClientWithTools )(nil )
179
+ _ SessionWithLogging = (* sessionTestClientWithLogging )(nil )
180
+ _ SessionWithClientInfo = (* sessionTestClientWithClientInfo )(nil )
142
181
)
143
182
144
183
func TestSessionWithTools_Integration (t * testing.T ) {
@@ -1041,4 +1080,49 @@ func TestMCPServer_SetLevel(t *testing.T) {
1041
1080
if session .GetLogLevel () != mcp .LoggingLevelCritical {
1042
1081
t .Errorf ("Expected critical level, got %v" , session .GetLogLevel ())
1043
1082
}
1044
- }
1083
+ }
1084
+
1085
+ func TestSessionWithClientInfo_Integration (t * testing.T ) {
1086
+ server := NewMCPServer ("test-server" , "1.0.0" )
1087
+
1088
+ session := & sessionTestClientWithClientInfo {
1089
+ sessionID : "session-1" ,
1090
+ notificationChannel : make (chan mcp.JSONRPCNotification , 10 ),
1091
+ initialized : false ,
1092
+ }
1093
+
1094
+ err := server .RegisterSession (context .Background (), session )
1095
+ require .NoError (t , err )
1096
+
1097
+ clientInfo := mcp.Implementation {
1098
+ Name : "test-client" ,
1099
+ Version : "1.0.0" ,
1100
+ }
1101
+
1102
+ initRequest := mcp.InitializeRequest {}
1103
+ initRequest .Params .ClientInfo = clientInfo
1104
+ initRequest .Params .ProtocolVersion = mcp .LATEST_PROTOCOL_VERSION
1105
+ initRequest .Params .Capabilities = mcp.ClientCapabilities {}
1106
+
1107
+ sessionCtx := server .WithContext (context .Background (), session )
1108
+
1109
+ // Retrieve the session from context
1110
+ retrievedSession := ClientSessionFromContext (sessionCtx )
1111
+ require .NotNil (t , retrievedSession , "Session should be available from context" )
1112
+ assert .Equal (t , session .SessionID (), retrievedSession .SessionID (), "Session ID should match" )
1113
+
1114
+ result , reqErr := server .handleInitialize (sessionCtx , 1 , initRequest )
1115
+ require .Nil (t , reqErr )
1116
+ require .NotNil (t , result )
1117
+
1118
+ // Check if the session can be cast to SessionWithClientInfo
1119
+ sessionWithClientInfo , ok := retrievedSession .(SessionWithClientInfo )
1120
+ require .True (t , ok , "Session should implement SessionWithClientInfo" )
1121
+
1122
+ assert .True (t , sessionWithClientInfo .Initialized (), "Session should be initialized" )
1123
+
1124
+ storedClientInfo := sessionWithClientInfo .GetClientInfo ()
1125
+
1126
+ assert .Equal (t , clientInfo .Name , storedClientInfo .Name , "Client name should match" )
1127
+ assert .Equal (t , clientInfo .Version , storedClientInfo .Version , "Client version should match" )
1128
+ }
0 commit comments