Skip to content

Commit 0c3f535

Browse files
feat: Implement OAuth in the client (#296)
* Implement OAuth in the client * Fix linting issues * More fixes * Update client/oauth.go Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> * Update client/transport/oauth.go Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> * fix * oauth discovery * fix * handle invalid urls * more error handling * use errors.As * get baseURL from server * Fix misleading naming --------- Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
1 parent 617c676 commit 0c3f535

File tree

10 files changed

+1889
-8
lines changed

10 files changed

+1889
-8
lines changed

client/oauth.go

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
package client
2+
3+
import (
4+
"errors"
5+
"fmt"
6+
7+
"github.com/mark3labs/mcp-go/client/transport"
8+
)
9+
10+
// OAuthConfig is a convenience type that wraps transport.OAuthConfig
11+
type OAuthConfig = transport.OAuthConfig
12+
13+
// Token is a convenience type that wraps transport.Token
14+
type Token = transport.Token
15+
16+
// TokenStore is a convenience type that wraps transport.TokenStore
17+
type TokenStore = transport.TokenStore
18+
19+
// MemoryTokenStore is a convenience type that wraps transport.MemoryTokenStore
20+
type MemoryTokenStore = transport.MemoryTokenStore
21+
22+
// NewMemoryTokenStore is a convenience function that wraps transport.NewMemoryTokenStore
23+
var NewMemoryTokenStore = transport.NewMemoryTokenStore
24+
25+
// NewOAuthStreamableHttpClient creates a new streamable-http-based MCP client with OAuth support.
26+
// Returns an error if the URL is invalid.
27+
func NewOAuthStreamableHttpClient(baseURL string, oauthConfig OAuthConfig, options ...transport.StreamableHTTPCOption) (*Client, error) {
28+
// Add OAuth option to the list of options
29+
options = append(options, transport.WithOAuth(oauthConfig))
30+
31+
trans, err := transport.NewStreamableHTTP(baseURL, options...)
32+
if err != nil {
33+
return nil, fmt.Errorf("failed to create HTTP transport: %w", err)
34+
}
35+
return NewClient(trans), nil
36+
}
37+
38+
// GenerateCodeVerifier generates a code verifier for PKCE
39+
var GenerateCodeVerifier = transport.GenerateCodeVerifier
40+
41+
// GenerateCodeChallenge generates a code challenge from a code verifier
42+
var GenerateCodeChallenge = transport.GenerateCodeChallenge
43+
44+
// GenerateState generates a state parameter for OAuth
45+
var GenerateState = transport.GenerateState
46+
47+
// OAuthAuthorizationRequiredError is returned when OAuth authorization is required
48+
type OAuthAuthorizationRequiredError = transport.OAuthAuthorizationRequiredError
49+
50+
// IsOAuthAuthorizationRequiredError checks if an error is an OAuthAuthorizationRequiredError
51+
func IsOAuthAuthorizationRequiredError(err error) bool {
52+
var target *OAuthAuthorizationRequiredError
53+
return errors.As(err, &target)
54+
}
55+
56+
// GetOAuthHandler extracts the OAuthHandler from an OAuthAuthorizationRequiredError
57+
func GetOAuthHandler(err error) *transport.OAuthHandler {
58+
var oauthErr *OAuthAuthorizationRequiredError
59+
if errors.As(err, &oauthErr) {
60+
return oauthErr.Handler
61+
}
62+
return nil
63+
}

client/oauth_test.go

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
package client
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"fmt"
7+
"net/http"
8+
"net/http/httptest"
9+
"testing"
10+
"time"
11+
12+
"github.com/mark3labs/mcp-go/client/transport"
13+
)
14+
15+
func TestNewOAuthStreamableHttpClient(t *testing.T) {
16+
// Create a test server
17+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
18+
// Check for Authorization header
19+
authHeader := r.Header.Get("Authorization")
20+
if authHeader != "Bearer test-token" {
21+
w.WriteHeader(http.StatusUnauthorized)
22+
return
23+
}
24+
25+
// Return a successful response
26+
w.WriteHeader(http.StatusOK)
27+
w.Header().Set("Content-Type", "application/json")
28+
if err := json.NewEncoder(w).Encode(map[string]any{
29+
"jsonrpc": "2.0",
30+
"id": 1,
31+
"result": map[string]any{
32+
"protocolVersion": "2024-11-05",
33+
"serverInfo": map[string]any{
34+
"name": "test-server",
35+
"version": "1.0.0",
36+
},
37+
"capabilities": map[string]any{},
38+
},
39+
}); err != nil {
40+
t.Errorf("Failed to encode JSON response: %v", err)
41+
}
42+
}))
43+
defer server.Close()
44+
45+
// Create a token store with a valid token
46+
tokenStore := NewMemoryTokenStore()
47+
validToken := &Token{
48+
AccessToken: "test-token",
49+
TokenType: "Bearer",
50+
RefreshToken: "refresh-token",
51+
ExpiresIn: 3600,
52+
ExpiresAt: time.Now().Add(1 * time.Hour), // Valid for 1 hour
53+
}
54+
if err := tokenStore.SaveToken(validToken); err != nil {
55+
t.Fatalf("Failed to save token: %v", err)
56+
}
57+
58+
// Create OAuth config
59+
oauthConfig := OAuthConfig{
60+
ClientID: "test-client",
61+
RedirectURI: "http://localhost:8085/callback",
62+
Scopes: []string{"mcp.read", "mcp.write"},
63+
TokenStore: tokenStore,
64+
PKCEEnabled: true,
65+
}
66+
67+
// Create client with OAuth
68+
client, err := NewOAuthStreamableHttpClient(server.URL, oauthConfig)
69+
if err != nil {
70+
t.Fatalf("Failed to create client: %v", err)
71+
}
72+
73+
// Start the client
74+
if err := client.Start(context.Background()); err != nil {
75+
t.Fatalf("Failed to start client: %v", err)
76+
}
77+
defer client.Close()
78+
79+
// Verify that the client was created successfully
80+
trans := client.GetTransport()
81+
streamableHTTP, ok := trans.(*transport.StreamableHTTP)
82+
if !ok {
83+
t.Fatalf("Expected transport to be *transport.StreamableHTTP, got %T", trans)
84+
}
85+
86+
// Verify OAuth is enabled
87+
if !streamableHTTP.IsOAuthEnabled() {
88+
t.Errorf("Expected IsOAuthEnabled() to return true")
89+
}
90+
91+
// Verify the OAuth handler is set
92+
if streamableHTTP.GetOAuthHandler() == nil {
93+
t.Errorf("Expected GetOAuthHandler() to return a handler")
94+
}
95+
}
96+
97+
func TestIsOAuthAuthorizationRequiredError(t *testing.T) {
98+
// Create a test error
99+
err := &transport.OAuthAuthorizationRequiredError{
100+
Handler: transport.NewOAuthHandler(transport.OAuthConfig{}),
101+
}
102+
103+
// Verify IsOAuthAuthorizationRequiredError returns true
104+
if !IsOAuthAuthorizationRequiredError(err) {
105+
t.Errorf("Expected IsOAuthAuthorizationRequiredError to return true")
106+
}
107+
108+
// Verify GetOAuthHandler returns the handler
109+
handler := GetOAuthHandler(err)
110+
if handler == nil {
111+
t.Errorf("Expected GetOAuthHandler to return a handler")
112+
}
113+
114+
// Test with a different error
115+
err2 := fmt.Errorf("some other error")
116+
117+
// Verify IsOAuthAuthorizationRequiredError returns false
118+
if IsOAuthAuthorizationRequiredError(err2) {
119+
t.Errorf("Expected IsOAuthAuthorizationRequiredError to return false")
120+
}
121+
122+
// Verify GetOAuthHandler returns nil
123+
handler = GetOAuthHandler(err2)
124+
if handler != nil {
125+
t.Errorf("Expected GetOAuthHandler to return nil")
126+
}
127+
}

0 commit comments

Comments
 (0)