@@ -278,21 +278,16 @@ func (s *Service) authenticate(ctx context.Context, params authenticateParams) (
278
278
if err != nil {
279
279
return nil , fmt .Errorf ("failed to verify id_token: %w" , err )
280
280
}
281
- claims := map [string ]interface {}{}
282
- err = idToken .Claims (& claims )
283
- if err != nil {
284
- return nil , fmt .Errorf ("failed to unmarshal the payload of the ID token: %w" , err )
285
- }
286
281
if idToken .Nonce != params .NonceCookieValue {
287
282
return nil , fmt .Errorf ("nonce mismatch" )
288
283
}
289
- err = s .validateRequiredClaims (idToken )
284
+ validatedClaims , err : = s .validateRequiredClaims (ctx , provider , idToken )
290
285
if err != nil {
291
286
return nil , fmt .Errorf ("failed to validate required claims: %w" , err )
292
287
}
293
288
return & AuthFlowResult {
294
289
IDToken : idToken ,
295
- Claims : claims ,
290
+ Claims : validatedClaims ,
296
291
}, nil
297
292
}
298
293
@@ -338,24 +333,60 @@ func (s *Service) createSession(ctx context.Context, flowResult *AuthFlowResult,
338
333
return nil , message , fmt .Errorf ("unexpected status code: %v" , res .StatusCode )
339
334
}
340
335
341
- func (s * Service ) validateRequiredClaims (token * goidc.IDToken ) error {
336
+ func (s * Service ) validateRequiredClaims (ctx context. Context , provider * oidc. Provider , token * goidc.IDToken ) (jwt. MapClaims , error ) {
342
337
if len (token .Audience ) < 1 {
343
- return fmt .Errorf ("audience claim is missing" )
344
- }
345
- var claims struct {
346
- Email string `json:"email,omitempty"`
347
- Name string `json:"name,omitempty"`
348
- jwt.RegisteredClaims
338
+ return nil , fmt .Errorf ("audience claim is missing" )
349
339
}
340
+ var claims jwt.MapClaims
350
341
err := token .Claims (& claims )
351
342
if err != nil {
352
- return fmt .Errorf ("failed to unmarshal claims of ID token: %w" , err )
343
+ return nil , fmt .Errorf ("failed to unmarshal claims of ID token: %w" , err )
344
+ }
345
+ requiredClaims := []string {"email" , "name" }
346
+ missingClaims := []string {}
347
+ for _ , claim := range requiredClaims {
348
+ if _ , ok := claims [claim ]; ! ok {
349
+ missingClaims = append (missingClaims , claim )
350
+ }
351
+ }
352
+ if len (missingClaims ) > 0 {
353
+ err = s .fillClaims (ctx , provider , claims , missingClaims )
354
+ if err != nil {
355
+ log .WithError (err ).Error ("failed to fill claims" )
356
+ }
357
+ // continue
353
358
}
354
- if claims .Email == "" {
355
- return fmt .Errorf ("email claim is missing" )
359
+ for _ , claim := range requiredClaims {
360
+ if _ , ok := claims [claim ]; ! ok {
361
+ return nil , fmt .Errorf ("%s claim is missing" , claim )
362
+ }
356
363
}
357
- if claims .Name == "" {
358
- return fmt .Errorf ("name claim is missing" )
364
+ return claims , nil
365
+ }
366
+
367
+ func (s * Service ) fillClaims (ctx context.Context , provider * oidc.Provider , claims jwt.MapClaims , missingClaims []string ) error {
368
+ oauth2Info := GetOAuth2ResultFromContext (ctx )
369
+ if oauth2Info == nil {
370
+ return fmt .Errorf ("oauth2 info not found" )
371
+ }
372
+ userinfo , err := provider .UserInfo (ctx , oauth2 .StaticTokenSource (oauth2Info .OAuth2Token ))
373
+ if err != nil {
374
+ return fmt .Errorf ("failed to get userinfo: %w" , err )
375
+ }
376
+ var userinfoClaims map [string ]interface {}
377
+ if err := userinfo .Claims (& userinfoClaims ); err != nil {
378
+ return fmt .Errorf ("failed to unmarshal userinfo claims: %w" , err )
379
+ }
380
+ for _ , key := range missingClaims {
381
+ switch key {
382
+ case "email" :
383
+ // check userinfo definition to get more info
384
+ claims ["email" ] = userinfo .Email
385
+ default :
386
+ if value , ok := userinfoClaims [key ]; ok {
387
+ claims [key ] = value
388
+ }
389
+ }
359
390
}
360
391
return nil
361
392
}
0 commit comments