|
| 1 | +package token |
| 2 | + |
| 3 | +import ( |
| 4 | + "context" |
| 5 | + "encoding/json" |
| 6 | + "errors" |
| 7 | + "fmt" |
| 8 | + "net/http" |
| 9 | + "strings" |
| 10 | + "sync" |
| 11 | + "time" |
| 12 | + |
| 13 | + "github.com/go-jose/go-jose/v4" |
| 14 | + jwtgo "github.com/golang-jwt/jwt/v5" |
| 15 | + log "github.com/sirupsen/logrus" |
| 16 | + |
| 17 | + "github.com/argoproj/argo-cd/v3/util/settings" |
| 18 | +) |
| 19 | + |
| 20 | +type externalTokenVerifier struct { |
| 21 | + client *http.Client |
| 22 | + |
| 23 | + jwksCache *jose.JSONWebKeySet |
| 24 | + jwksExpiry time.Time |
| 25 | + jwksCacheMux sync.Mutex |
| 26 | + defaultCacheTTL time.Duration |
| 27 | +} |
| 28 | + |
| 29 | +func NewExternalTokenVerifier(client *http.Client) Verifier { |
| 30 | + return &externalTokenVerifier{ |
| 31 | + client: client, |
| 32 | + defaultCacheTTL: 5 * time.Minute, |
| 33 | + } |
| 34 | +} |
| 35 | + |
| 36 | +// VerifyToken verifies an externally injected JWT token using the configured JWK Set URL |
| 37 | +func (v *externalTokenVerifier) Verify(ctx context.Context, tokenString string, argoSettings *settings.ArgoCDSettings) (jwtgo.Claims, error) { |
| 38 | + if !argoSettings.IsJWTConfigured() { |
| 39 | + return nil, errors.New("valid JWT configuration not found") |
| 40 | + } |
| 41 | + |
| 42 | + cacheTTL := v.defaultCacheTTL |
| 43 | + if argoSettings.JWTConfig.CacheTTL != "" { |
| 44 | + ttl, err := time.ParseDuration(argoSettings.JWTConfig.CacheTTL) |
| 45 | + if err != nil { |
| 46 | + log.Warnf("Invalid JWT cache TTL %q, using default (%d)", argoSettings.JWTConfig.CacheTTL, cacheTTL) |
| 47 | + } else { |
| 48 | + cacheTTL = ttl |
| 49 | + } |
| 50 | + } |
| 51 | + |
| 52 | + jwks, err := v.getJWKS(ctx, argoSettings.JWTConfig.JWKSetURL, cacheTTL) |
| 53 | + if err != nil { |
| 54 | + return nil, fmt.Errorf("failed to get JWKS: %w", err) |
| 55 | + } |
| 56 | + |
| 57 | + // Determine signing algorithm, default to RS256 if not set |
| 58 | + allowedSigningAlg := "RS256" |
| 59 | + if argoSettings.JWTConfig.SigningAlgorithm != "" { |
| 60 | + allowedSigningAlg = argoSettings.JWTConfig.SigningAlgorithm |
| 61 | + } |
| 62 | + |
| 63 | + // --- Key Function --- |
| 64 | + keyFunc := func(token *jwtgo.Token) (any, error) { |
| 65 | + // Ensure the signing method is expected before continuing. |
| 66 | + // The WithValidMethods option below enforces this, but double-checking here is fine. |
| 67 | + if token.Method.Alg() != allowedSigningAlg { |
| 68 | + return nil, fmt.Errorf("unexpected signing algorithm in external JWT: %v", token.Header["alg"]) |
| 69 | + } |
| 70 | + |
| 71 | + kid, ok := token.Header["kid"].(string) |
| 72 | + if !ok { |
| 73 | + return nil, errors.New("kid header not found in external JWT") |
| 74 | + } |
| 75 | + |
| 76 | + var key *jose.JSONWebKey |
| 77 | + for _, k := range jwks.Keys { |
| 78 | + if k.KeyID == kid { |
| 79 | + key = &k |
| 80 | + break |
| 81 | + } |
| 82 | + } |
| 83 | + if key == nil { |
| 84 | + return nil, fmt.Errorf("no key found for kid in external JWT: %q", kid) |
| 85 | + } |
| 86 | + |
| 87 | + if key.Algorithm != "" && key.Algorithm != token.Header["alg"] { |
| 88 | + return nil, fmt.Errorf("algorithm mismatch for kid %q: expected %v, got %v. External JWT issuer may be misconfigured/broken", kid, key.Algorithm, token.Header["alg"]) |
| 89 | + } |
| 90 | + |
| 91 | + return key.Key, nil |
| 92 | + } |
| 93 | + // --- End Key Function --- |
| 94 | + |
| 95 | + // --- Parser Options --- |
| 96 | + opts := []jwtgo.ParserOption{ |
| 97 | + jwtgo.WithValidMethods([]string{allowedSigningAlg}), // Enforce expected signing algorithm |
| 98 | + // Add other standard validation options based on config |
| 99 | + } |
| 100 | + if argoSettings.JWTConfig.Issuer != "" { |
| 101 | + opts = append(opts, jwtgo.WithIssuer(argoSettings.JWTConfig.Issuer)) |
| 102 | + } |
| 103 | + if argoSettings.JWTConfig.Audience != "" { |
| 104 | + opts = append(opts, jwtgo.WithAudience(argoSettings.JWTConfig.Audience)) |
| 105 | + } |
| 106 | + // By default, Parse validates exp, nbf, iat. Add options if specific behavior is needed. |
| 107 | + // opts = append(opts, jwtgo.WithExpirationRequired()) // Uncomment if expiration MUST be present |
| 108 | + // opts = append(opts, jwtgo.WithIssuedAt()) // Enforces iat check |
| 109 | + // --- End Parser Options --- |
| 110 | + |
| 111 | + // --- Parse and Validate --- |
| 112 | + parser := jwtgo.NewParser(opts...) |
| 113 | + token, err := parser.Parse(tokenString, keyFunc) |
| 114 | + if err != nil { |
| 115 | + // Log the specific parsing/verification error for better debugging |
| 116 | + log.Debugf("externalJWT parsing/verification failed: %v", err) |
| 117 | + // Check for specific validation errors if needed for more context |
| 118 | + if errors.Is(err, jwtgo.ErrTokenInvalidIssuer) { |
| 119 | + return nil, fmt.Errorf("invalid issuer claim in external JWT: %w", err) |
| 120 | + } |
| 121 | + if errors.Is(err, jwtgo.ErrTokenInvalidAudience) { |
| 122 | + return nil, fmt.Errorf("invalid audience claim in external JWT: %w", err) |
| 123 | + } |
| 124 | + if errors.Is(err, jwtgo.ErrTokenExpired) { |
| 125 | + return nil, fmt.Errorf("external JWT is expired: %w", err) |
| 126 | + } |
| 127 | + // Return a generic error for other parsing/signature issues |
| 128 | + return nil, fmt.Errorf("failed to parse/verify external JWT: %w", err) |
| 129 | + } |
| 130 | + // --- End Parse and Validate --- |
| 131 | + |
| 132 | + // --- Custom Claim Checks --- |
| 133 | + claims, ok := token.Claims.(jwtgo.MapClaims) |
| 134 | + if !ok { |
| 135 | + // This should ideally not happen if parsing succeeded, but check anyway. |
| 136 | + return nil, errors.New("invalid external JWT claims format after successful parse") |
| 137 | + } |
| 138 | + |
| 139 | + if argoSettings.JWTConfig.EmailClaim != "" { |
| 140 | + if _, ok := claims[argoSettings.JWTConfig.EmailClaim]; !ok { |
| 141 | + log.Warnf("Required email claim %q not found in external JWT", argoSettings.JWTConfig.EmailClaim) |
| 142 | + // Depending on requirements, you might want to return an error here instead of just logging. |
| 143 | + // For now, let's allow it but log a warning. |
| 144 | + // return nil, fmt.Errorf("required email claim %q not found", argoSettings.JWTConfig.EmailClaim) |
| 145 | + } |
| 146 | + } |
| 147 | + |
| 148 | + if argoSettings.JWTConfig.UsernameClaim != "" { |
| 149 | + if _, ok := claims[argoSettings.JWTConfig.UsernameClaim]; !ok { |
| 150 | + log.Warnf("Required username claim %q not found in external JWT", argoSettings.JWTConfig.UsernameClaim) |
| 151 | + // Depending on requirements, you might want to return an error here instead of just logging. |
| 152 | + // For now, let's allow it but log a warning. |
| 153 | + // return nil, fmt.Errorf("required username claim %q not found", argoSettings.JWTConfig.UsernameClaim) |
| 154 | + } |
| 155 | + } |
| 156 | + |
| 157 | + // Verify audience if configured |
| 158 | + if argoSettings.JWTConfig.Audience != "" { |
| 159 | + audience, err := claims.GetAudience() |
| 160 | + if err != nil { |
| 161 | + // Consider if audience claim is mandatory based on your policy |
| 162 | + // return nil, fmt.Errorf("failed to get audience claim: %w", err) |
| 163 | + log.Debugf("Failed to get audience claim from external JWT, continuing verification: %v", err) |
| 164 | + } else { |
| 165 | + validAud := false |
| 166 | + for _, aud := range audience { |
| 167 | + if aud == argoSettings.JWTConfig.Audience { |
| 168 | + validAud = true |
| 169 | + break |
| 170 | + } |
| 171 | + } |
| 172 | + if !validAud { |
| 173 | + return nil, fmt.Errorf("invalid audience claim in external JWT, expected aud %q not found in %v. Perhaps someone is trying to use a token from a different issuer", argoSettings.JWTConfig.Audience, audience) |
| 174 | + } |
| 175 | + } |
| 176 | + } |
| 177 | + |
| 178 | + // Parse groups and set claim for later handling at "groups" scope |
| 179 | + if argoSettings.JWTConfig.GroupsClaim != "" { |
| 180 | + if groups, ok := getNestedClaim(claims, argoSettings.JWTConfig.GroupsClaim); ok { |
| 181 | + // groups should be an array of strings... |
| 182 | + if groupsSlice, ok := groups.([]any); ok { |
| 183 | + stringGroups := make([]string, 0, len(groupsSlice)) |
| 184 | + for _, group := range groupsSlice { |
| 185 | + if groupStr, ok := group.(string); ok { |
| 186 | + stringGroups = append(stringGroups, groupStr) |
| 187 | + } |
| 188 | + } |
| 189 | + claims["groups"] = stringGroups |
| 190 | + } |
| 191 | + } else { |
| 192 | + log.Warnf("Groups claim %q not found in JWT", argoSettings.JWTConfig.GroupsClaim) |
| 193 | + } |
| 194 | + } |
| 195 | + |
| 196 | + // --- End Custom Claim Checks --- |
| 197 | + |
| 198 | + return claims, nil |
| 199 | +} |
| 200 | + |
| 201 | +func (v *externalTokenVerifier) getJWKS(ctx context.Context, jwksURL string, cacheTTL time.Duration) (*jose.JSONWebKeySet, error) { |
| 202 | + v.jwksCacheMux.Lock() |
| 203 | + defer v.jwksCacheMux.Unlock() |
| 204 | + |
| 205 | + if v.jwksCache != nil && time.Now().Before(v.jwksExpiry) { |
| 206 | + return v.jwksCache, nil |
| 207 | + } |
| 208 | + |
| 209 | + req, err := http.NewRequestWithContext(ctx, http.MethodGet, jwksURL, http.NoBody) |
| 210 | + if err != nil { |
| 211 | + return nil, err |
| 212 | + } |
| 213 | + |
| 214 | + resp, err := v.client.Do(req) |
| 215 | + if err != nil { |
| 216 | + return nil, fmt.Errorf("failed to fetch JWKS: %w", err) |
| 217 | + } |
| 218 | + defer resp.Body.Close() |
| 219 | + |
| 220 | + var jwks jose.JSONWebKeySet |
| 221 | + err = json.NewDecoder(resp.Body).Decode(&jwks) |
| 222 | + if err != nil { |
| 223 | + return nil, fmt.Errorf("failed to decode JWKS: %w", err) |
| 224 | + } |
| 225 | + |
| 226 | + v.jwksCache = &jwks |
| 227 | + v.jwksExpiry = time.Now().Add(cacheTTL) |
| 228 | + |
| 229 | + log.Debug("Token verified using JWT") |
| 230 | + return &jwks, nil |
| 231 | +} |
| 232 | + |
| 233 | +// getNestedClaim retrieves a value from a nested map using a dot-separated path. |
| 234 | +// For example, given path "user.profile.name", it will traverse: |
| 235 | +// data["user"]["profile"]["name"] |
| 236 | +// Returns the value and true if found, nil and false otherwise. |
| 237 | +func getNestedClaim(data map[string]any, path string) (any, bool) { |
| 238 | + keys := strings.Split(path, ".") |
| 239 | + var current any = data |
| 240 | + |
| 241 | + for i, key := range keys { |
| 242 | + currentMap, ok := current.(map[string]any) |
| 243 | + if !ok { |
| 244 | + return nil, false |
| 245 | + } |
| 246 | + |
| 247 | + value, exists := currentMap[key] |
| 248 | + if !exists { |
| 249 | + return nil, false |
| 250 | + } |
| 251 | + |
| 252 | + if i == len(keys)-1 { |
| 253 | + return value, true |
| 254 | + } |
| 255 | + current = value |
| 256 | + } |
| 257 | + return nil, false |
| 258 | +} |
0 commit comments