Skip to content

Commit 7b615ab

Browse files
authored
Bring more compatibility to OIDC verification (#19813)
* change something * add userinfo fetch * fixup * fixup * fixup * remove unnecessary code * fixup * change logic * fixup * fixup * fixup comment
1 parent 63858c2 commit 7b615ab

File tree

2 files changed

+51
-20
lines changed

2 files changed

+51
-20
lines changed

components/public-api-server/pkg/oidc/service.go

Lines changed: 50 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -278,21 +278,16 @@ func (s *Service) authenticate(ctx context.Context, params authenticateParams) (
278278
if err != nil {
279279
return nil, fmt.Errorf("failed to verify id_token: %w", err)
280280
}
281-
claims := map[string]interface{}{}
282-
err = idToken.Claims(&claims)
283-
if err != nil {
284-
return nil, fmt.Errorf("failed to unmarshal the payload of the ID token: %w", err)
285-
}
286281
if idToken.Nonce != params.NonceCookieValue {
287282
return nil, fmt.Errorf("nonce mismatch")
288283
}
289-
err = s.validateRequiredClaims(idToken)
284+
validatedClaims, err := s.validateRequiredClaims(ctx, provider, idToken)
290285
if err != nil {
291286
return nil, fmt.Errorf("failed to validate required claims: %w", err)
292287
}
293288
return &AuthFlowResult{
294289
IDToken: idToken,
295-
Claims: claims,
290+
Claims: validatedClaims,
296291
}, nil
297292
}
298293

@@ -338,24 +333,60 @@ func (s *Service) createSession(ctx context.Context, flowResult *AuthFlowResult,
338333
return nil, message, fmt.Errorf("unexpected status code: %v", res.StatusCode)
339334
}
340335

341-
func (s *Service) validateRequiredClaims(token *goidc.IDToken) error {
336+
func (s *Service) validateRequiredClaims(ctx context.Context, provider *oidc.Provider, token *goidc.IDToken) (jwt.MapClaims, error) {
342337
if len(token.Audience) < 1 {
343-
return fmt.Errorf("audience claim is missing")
344-
}
345-
var claims struct {
346-
Email string `json:"email,omitempty"`
347-
Name string `json:"name,omitempty"`
348-
jwt.RegisteredClaims
338+
return nil, fmt.Errorf("audience claim is missing")
349339
}
340+
var claims jwt.MapClaims
350341
err := token.Claims(&claims)
351342
if err != nil {
352-
return fmt.Errorf("failed to unmarshal claims of ID token: %w", err)
343+
return nil, fmt.Errorf("failed to unmarshal claims of ID token: %w", err)
344+
}
345+
requiredClaims := []string{"email", "name"}
346+
missingClaims := []string{}
347+
for _, claim := range requiredClaims {
348+
if _, ok := claims[claim]; !ok {
349+
missingClaims = append(missingClaims, claim)
350+
}
351+
}
352+
if len(missingClaims) > 0 {
353+
err = s.fillClaims(ctx, provider, claims, missingClaims)
354+
if err != nil {
355+
log.WithError(err).Error("failed to fill claims")
356+
}
357+
// continue
353358
}
354-
if claims.Email == "" {
355-
return fmt.Errorf("email claim is missing")
359+
for _, claim := range requiredClaims {
360+
if _, ok := claims[claim]; !ok {
361+
return nil, fmt.Errorf("%s claim is missing", claim)
362+
}
356363
}
357-
if claims.Name == "" {
358-
return fmt.Errorf("name claim is missing")
364+
return claims, nil
365+
}
366+
367+
func (s *Service) fillClaims(ctx context.Context, provider *oidc.Provider, claims jwt.MapClaims, missingClaims []string) error {
368+
oauth2Info := GetOAuth2ResultFromContext(ctx)
369+
if oauth2Info == nil {
370+
return fmt.Errorf("oauth2 info not found")
371+
}
372+
userinfo, err := provider.UserInfo(ctx, oauth2.StaticTokenSource(oauth2Info.OAuth2Token))
373+
if err != nil {
374+
return fmt.Errorf("failed to get userinfo: %w", err)
375+
}
376+
var userinfoClaims map[string]interface{}
377+
if err := userinfo.Claims(&userinfoClaims); err != nil {
378+
return fmt.Errorf("failed to unmarshal userinfo claims: %w", err)
379+
}
380+
for _, key := range missingClaims {
381+
switch key {
382+
case "email":
383+
// check userinfo definition to get more info
384+
claims["email"] = userinfo.Email
385+
default:
386+
if value, ok := userinfoClaims[key]; ok {
387+
claims[key] = value
388+
}
389+
}
359390
}
360391
return nil
361392
}

components/public-api-server/pkg/oidc/service_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ func Test_validateRequiredClaims(t *testing.T) {
344344
t.Run(tc.Label, func(t *testing.T) {
345345
token := createTestIDToken(t, tc.Claims)
346346

347-
err := service.validateRequiredClaims(token)
347+
_, err := service.validateRequiredClaims(context.Background(), nil, token)
348348
if tc.ExpectedError == "" {
349349
require.NoError(t, err)
350350
}

0 commit comments

Comments
 (0)