Skip to content

Commit e088df6

Browse files
committed
Refactor token verification into an interface
Signed-off-by: Brian Groux <[email protected]>
1 parent 8d03c00 commit e088df6

File tree

9 files changed

+746
-680
lines changed

9 files changed

+746
-680
lines changed

server/server.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -791,6 +791,22 @@ func checkOIDCConfigChange(currentOIDCConfig *settings_util.OIDCConfig, newArgoC
791791
return false
792792
}
793793

794+
func checkJWTConfigChange(currentJWTConfig *settings_util.JWTConfig, newArgoCDSettings *settings_util.ArgoCDSettings) bool {
795+
newJWTConfig := newArgoCDSettings.JWTConfig
796+
797+
if (currentJWTConfig != nil && newJWTConfig == nil) || (currentJWTConfig == nil && newJWTConfig != nil) {
798+
return true
799+
}
800+
801+
if currentJWTConfig != nil && newJWTConfig != nil {
802+
if !reflect.DeepEqual(*currentJWTConfig, *newJWTConfig) {
803+
return true
804+
}
805+
}
806+
807+
return false
808+
}
809+
794810
// watchSettings watches the configmap and secret for any setting updates that would warrant a
795811
// restart of the API server.
796812
func (server *ArgoCDServer) watchSettings() {
@@ -799,6 +815,7 @@ func (server *ArgoCDServer) watchSettings() {
799815

800816
prevURL := server.settings.URL
801817
prevAdditionalURLs := server.settings.AdditionalURLs
818+
prevJWTConfig := server.settings.JWTConfig
802819
prevOIDCConfig := server.settings.OIDCConfig()
803820
prevDexCfgBytes, err := dexutil.GenerateDexConfigYAML(server.settings, server.DexTLSConfig == nil || server.DexTLSConfig.DisableTLS)
804821
errorsutil.CheckError(err)
@@ -822,6 +839,10 @@ func (server *ArgoCDServer) watchSettings() {
822839
log.Infof("dex config modified. restarting")
823840
break
824841
}
842+
if checkJWTConfigChange(prevJWTConfig, server.settings) {
843+
log.Infof("jwt config modified. restarting")
844+
break
845+
}
825846
if checkOIDCConfigChange(prevOIDCConfig, server.settings) {
826847
log.Infof("oidc config modified. restarting")
827848
break

util/jwt/token/external.go

Lines changed: 258 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,258 @@
1+
package token
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"errors"
7+
"fmt"
8+
"net/http"
9+
"strings"
10+
"sync"
11+
"time"
12+
13+
"github.com/go-jose/go-jose/v4"
14+
jwtgo "github.com/golang-jwt/jwt/v5"
15+
log "github.com/sirupsen/logrus"
16+
17+
"github.com/argoproj/argo-cd/v3/util/settings"
18+
)
19+
20+
type externalTokenVerifier struct {
21+
client *http.Client
22+
23+
jwksCache *jose.JSONWebKeySet
24+
jwksExpiry time.Time
25+
jwksCacheMux sync.Mutex
26+
defaultCacheTTL time.Duration
27+
}
28+
29+
func NewExternalTokenVerifier(client *http.Client) Verifier {
30+
return &externalTokenVerifier{
31+
client: client,
32+
defaultCacheTTL: 5 * time.Minute,
33+
}
34+
}
35+
36+
// VerifyToken verifies an externally injected JWT token using the configured JWK Set URL
37+
func (v *externalTokenVerifier) Verify(ctx context.Context, tokenString string, argoSettings *settings.ArgoCDSettings) (jwtgo.Claims, error) {
38+
if !argoSettings.IsJWTConfigured() {
39+
return nil, errors.New("valid JWT configuration not found")
40+
}
41+
42+
cacheTTL := v.defaultCacheTTL
43+
if argoSettings.JWTConfig.CacheTTL != "" {
44+
ttl, err := time.ParseDuration(argoSettings.JWTConfig.CacheTTL)
45+
if err != nil {
46+
log.Warnf("Invalid JWT cache TTL %q, using default (%d)", argoSettings.JWTConfig.CacheTTL, cacheTTL)
47+
} else {
48+
cacheTTL = ttl
49+
}
50+
}
51+
52+
jwks, err := v.getJWKS(ctx, argoSettings.JWTConfig.JWKSetURL, cacheTTL)
53+
if err != nil {
54+
return nil, fmt.Errorf("failed to get JWKS: %w", err)
55+
}
56+
57+
// Determine signing algorithm, default to RS256 if not set
58+
allowedSigningAlg := "RS256"
59+
if argoSettings.JWTConfig.SigningAlgorithm != "" {
60+
allowedSigningAlg = argoSettings.JWTConfig.SigningAlgorithm
61+
}
62+
63+
// --- Key Function ---
64+
keyFunc := func(token *jwtgo.Token) (any, error) {
65+
// Ensure the signing method is expected before continuing.
66+
// The WithValidMethods option below enforces this, but double-checking here is fine.
67+
if token.Method.Alg() != allowedSigningAlg {
68+
return nil, fmt.Errorf("unexpected signing algorithm in external JWT: %v", token.Header["alg"])
69+
}
70+
71+
kid, ok := token.Header["kid"].(string)
72+
if !ok {
73+
return nil, errors.New("kid header not found in external JWT")
74+
}
75+
76+
var key *jose.JSONWebKey
77+
for _, k := range jwks.Keys {
78+
if k.KeyID == kid {
79+
key = &k
80+
break
81+
}
82+
}
83+
if key == nil {
84+
return nil, fmt.Errorf("no key found for kid in external JWT: %q", kid)
85+
}
86+
87+
if key.Algorithm != "" && key.Algorithm != token.Header["alg"] {
88+
return nil, fmt.Errorf("algorithm mismatch for kid %q: expected %v, got %v. External JWT issuer may be misconfigured/broken", kid, key.Algorithm, token.Header["alg"])
89+
}
90+
91+
return key.Key, nil
92+
}
93+
// --- End Key Function ---
94+
95+
// --- Parser Options ---
96+
opts := []jwtgo.ParserOption{
97+
jwtgo.WithValidMethods([]string{allowedSigningAlg}), // Enforce expected signing algorithm
98+
// Add other standard validation options based on config
99+
}
100+
if argoSettings.JWTConfig.Issuer != "" {
101+
opts = append(opts, jwtgo.WithIssuer(argoSettings.JWTConfig.Issuer))
102+
}
103+
if argoSettings.JWTConfig.Audience != "" {
104+
opts = append(opts, jwtgo.WithAudience(argoSettings.JWTConfig.Audience))
105+
}
106+
// By default, Parse validates exp, nbf, iat. Add options if specific behavior is needed.
107+
// opts = append(opts, jwtgo.WithExpirationRequired()) // Uncomment if expiration MUST be present
108+
// opts = append(opts, jwtgo.WithIssuedAt()) // Enforces iat check
109+
// --- End Parser Options ---
110+
111+
// --- Parse and Validate ---
112+
parser := jwtgo.NewParser(opts...)
113+
token, err := parser.Parse(tokenString, keyFunc)
114+
if err != nil {
115+
// Log the specific parsing/verification error for better debugging
116+
log.Debugf("externalJWT parsing/verification failed: %v", err)
117+
// Check for specific validation errors if needed for more context
118+
if errors.Is(err, jwtgo.ErrTokenInvalidIssuer) {
119+
return nil, fmt.Errorf("invalid issuer claim in external JWT: %w", err)
120+
}
121+
if errors.Is(err, jwtgo.ErrTokenInvalidAudience) {
122+
return nil, fmt.Errorf("invalid audience claim in external JWT: %w", err)
123+
}
124+
if errors.Is(err, jwtgo.ErrTokenExpired) {
125+
return nil, fmt.Errorf("external JWT is expired: %w", err)
126+
}
127+
// Return a generic error for other parsing/signature issues
128+
return nil, fmt.Errorf("failed to parse/verify external JWT: %w", err)
129+
}
130+
// --- End Parse and Validate ---
131+
132+
// --- Custom Claim Checks ---
133+
claims, ok := token.Claims.(jwtgo.MapClaims)
134+
if !ok {
135+
// This should ideally not happen if parsing succeeded, but check anyway.
136+
return nil, errors.New("invalid external JWT claims format after successful parse")
137+
}
138+
139+
if argoSettings.JWTConfig.EmailClaim != "" {
140+
if _, ok := claims[argoSettings.JWTConfig.EmailClaim]; !ok {
141+
log.Warnf("Required email claim %q not found in external JWT", argoSettings.JWTConfig.EmailClaim)
142+
// Depending on requirements, you might want to return an error here instead of just logging.
143+
// For now, let's allow it but log a warning.
144+
// return nil, fmt.Errorf("required email claim %q not found", argoSettings.JWTConfig.EmailClaim)
145+
}
146+
}
147+
148+
if argoSettings.JWTConfig.UsernameClaim != "" {
149+
if _, ok := claims[argoSettings.JWTConfig.UsernameClaim]; !ok {
150+
log.Warnf("Required username claim %q not found in external JWT", argoSettings.JWTConfig.UsernameClaim)
151+
// Depending on requirements, you might want to return an error here instead of just logging.
152+
// For now, let's allow it but log a warning.
153+
// return nil, fmt.Errorf("required username claim %q not found", argoSettings.JWTConfig.UsernameClaim)
154+
}
155+
}
156+
157+
// Verify audience if configured
158+
if argoSettings.JWTConfig.Audience != "" {
159+
audience, err := claims.GetAudience()
160+
if err != nil {
161+
// Consider if audience claim is mandatory based on your policy
162+
// return nil, fmt.Errorf("failed to get audience claim: %w", err)
163+
log.Debugf("Failed to get audience claim from external JWT, continuing verification: %v", err)
164+
} else {
165+
validAud := false
166+
for _, aud := range audience {
167+
if aud == argoSettings.JWTConfig.Audience {
168+
validAud = true
169+
break
170+
}
171+
}
172+
if !validAud {
173+
return nil, fmt.Errorf("invalid audience claim in external JWT, expected aud %q not found in %v. Perhaps someone is trying to use a token from a different issuer", argoSettings.JWTConfig.Audience, audience)
174+
}
175+
}
176+
}
177+
178+
// Parse groups and set claim for later handling at "groups" scope
179+
if argoSettings.JWTConfig.GroupsClaim != "" {
180+
if groups, ok := getNestedClaim(claims, argoSettings.JWTConfig.GroupsClaim); ok {
181+
// groups should be an array of strings...
182+
if groupsSlice, ok := groups.([]any); ok {
183+
stringGroups := make([]string, 0, len(groupsSlice))
184+
for _, group := range groupsSlice {
185+
if groupStr, ok := group.(string); ok {
186+
stringGroups = append(stringGroups, groupStr)
187+
}
188+
}
189+
claims["groups"] = stringGroups
190+
}
191+
} else {
192+
log.Warnf("Groups claim %q not found in JWT", argoSettings.JWTConfig.GroupsClaim)
193+
}
194+
}
195+
196+
// --- End Custom Claim Checks ---
197+
198+
return claims, nil
199+
}
200+
201+
func (v *externalTokenVerifier) getJWKS(ctx context.Context, jwksURL string, cacheTTL time.Duration) (*jose.JSONWebKeySet, error) {
202+
v.jwksCacheMux.Lock()
203+
defer v.jwksCacheMux.Unlock()
204+
205+
if v.jwksCache != nil && time.Now().Before(v.jwksExpiry) {
206+
return v.jwksCache, nil
207+
}
208+
209+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, jwksURL, http.NoBody)
210+
if err != nil {
211+
return nil, err
212+
}
213+
214+
resp, err := v.client.Do(req)
215+
if err != nil {
216+
return nil, fmt.Errorf("failed to fetch JWKS: %w", err)
217+
}
218+
defer resp.Body.Close()
219+
220+
var jwks jose.JSONWebKeySet
221+
err = json.NewDecoder(resp.Body).Decode(&jwks)
222+
if err != nil {
223+
return nil, fmt.Errorf("failed to decode JWKS: %w", err)
224+
}
225+
226+
v.jwksCache = &jwks
227+
v.jwksExpiry = time.Now().Add(cacheTTL)
228+
229+
log.Debug("Token verified using JWT")
230+
return &jwks, nil
231+
}
232+
233+
// getNestedClaim retrieves a value from a nested map using a dot-separated path.
234+
// For example, given path "user.profile.name", it will traverse:
235+
// data["user"]["profile"]["name"]
236+
// Returns the value and true if found, nil and false otherwise.
237+
func getNestedClaim(data map[string]any, path string) (any, bool) {
238+
keys := strings.Split(path, ".")
239+
var current any = data
240+
241+
for i, key := range keys {
242+
currentMap, ok := current.(map[string]any)
243+
if !ok {
244+
return nil, false
245+
}
246+
247+
value, exists := currentMap[key]
248+
if !exists {
249+
return nil, false
250+
}
251+
252+
if i == len(keys)-1 {
253+
return value, true
254+
}
255+
current = value
256+
}
257+
return nil, false
258+
}

0 commit comments

Comments
 (0)