1
+ package client
2
+
3
+ import (
4
+ "context"
5
+ "encoding/json"
6
+ "fmt"
7
+ "net/http"
8
+ "net/http/httptest"
9
+ "testing"
10
+ "time"
11
+
12
+ "github.com/mark3labs/mcp-go/client/transport"
13
+ )
14
+
15
+ func TestNewOAuthStreamableHttpClient (t * testing.T ) {
16
+ // Create a test server
17
+ server := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
18
+ // Check for Authorization header
19
+ authHeader := r .Header .Get ("Authorization" )
20
+ if authHeader != "Bearer test-token" {
21
+ w .WriteHeader (http .StatusUnauthorized )
22
+ return
23
+ }
24
+
25
+ // Return a successful response
26
+ w .WriteHeader (http .StatusOK )
27
+ w .Header ().Set ("Content-Type" , "application/json" )
28
+ if err := json .NewEncoder (w ).Encode (map [string ]any {
29
+ "jsonrpc" : "2.0" ,
30
+ "id" : 1 ,
31
+ "result" : map [string ]any {
32
+ "protocolVersion" : "2024-11-05" ,
33
+ "serverInfo" : map [string ]any {
34
+ "name" : "test-server" ,
35
+ "version" : "1.0.0" ,
36
+ },
37
+ "capabilities" : map [string ]any {},
38
+ },
39
+ }); err != nil {
40
+ t .Errorf ("Failed to encode JSON response: %v" , err )
41
+ }
42
+ }))
43
+ defer server .Close ()
44
+
45
+ // Create a token store with a valid token
46
+ tokenStore := NewMemoryTokenStore ()
47
+ validToken := & Token {
48
+ AccessToken : "test-token" ,
49
+ TokenType : "Bearer" ,
50
+ RefreshToken : "refresh-token" ,
51
+ ExpiresIn : 3600 ,
52
+ ExpiresAt : time .Now ().Add (1 * time .Hour ), // Valid for 1 hour
53
+ }
54
+ if err := tokenStore .SaveToken (validToken ); err != nil {
55
+ t .Fatalf ("Failed to save token: %v" , err )
56
+ }
57
+
58
+ // Create OAuth config
59
+ oauthConfig := OAuthConfig {
60
+ ClientID : "test-client" ,
61
+ RedirectURI : "http://localhost:8085/callback" ,
62
+ Scopes : []string {"mcp.read" , "mcp.write" },
63
+ TokenStore : tokenStore ,
64
+ PKCEEnabled : true ,
65
+ }
66
+
67
+ // Create client with OAuth
68
+ client , err := NewOAuthStreamableHttpClient (server .URL , oauthConfig )
69
+ if err != nil {
70
+ t .Fatalf ("Failed to create client: %v" , err )
71
+ }
72
+
73
+ // Start the client
74
+ if err := client .Start (context .Background ()); err != nil {
75
+ t .Fatalf ("Failed to start client: %v" , err )
76
+ }
77
+ defer client .Close ()
78
+
79
+ // Verify that the client was created successfully
80
+ trans := client .GetTransport ()
81
+ streamableHTTP , ok := trans .(* transport.StreamableHTTP )
82
+ if ! ok {
83
+ t .Fatalf ("Expected transport to be *transport.StreamableHTTP, got %T" , trans )
84
+ }
85
+
86
+ // Verify OAuth is enabled
87
+ if ! streamableHTTP .IsOAuthEnabled () {
88
+ t .Errorf ("Expected IsOAuthEnabled() to return true" )
89
+ }
90
+
91
+ // Verify the OAuth handler is set
92
+ if streamableHTTP .GetOAuthHandler () == nil {
93
+ t .Errorf ("Expected GetOAuthHandler() to return a handler" )
94
+ }
95
+ }
96
+
97
+ func TestIsOAuthAuthorizationRequiredError (t * testing.T ) {
98
+ // Create a test error
99
+ err := & transport.OAuthAuthorizationRequiredError {
100
+ Handler : transport .NewOAuthHandler (transport.OAuthConfig {}),
101
+ }
102
+
103
+ // Verify IsOAuthAuthorizationRequiredError returns true
104
+ if ! IsOAuthAuthorizationRequiredError (err ) {
105
+ t .Errorf ("Expected IsOAuthAuthorizationRequiredError to return true" )
106
+ }
107
+
108
+ // Verify GetOAuthHandler returns the handler
109
+ handler := GetOAuthHandler (err )
110
+ if handler == nil {
111
+ t .Errorf ("Expected GetOAuthHandler to return a handler" )
112
+ }
113
+
114
+ // Test with a different error
115
+ err2 := fmt .Errorf ("some other error" )
116
+
117
+ // Verify IsOAuthAuthorizationRequiredError returns false
118
+ if IsOAuthAuthorizationRequiredError (err2 ) {
119
+ t .Errorf ("Expected IsOAuthAuthorizationRequiredError to return false" )
120
+ }
121
+
122
+ // Verify GetOAuthHandler returns nil
123
+ handler = GetOAuthHandler (err2 )
124
+ if handler != nil {
125
+ t .Errorf ("Expected GetOAuthHandler to return nil" )
126
+ }
127
+ }
0 commit comments