Skip to content

Commit bb12633

Browse files
sixcolorsgaby
andauthored
Revert "🔥 feat: Add support for context.Context in keyauth middleware" (#3364)
Revert "🔥 feat: Add support for context.Context in keyauth middleware (#3287)" This reverts commit 4177ab4. Co-authored-by: Juan Calderon-Perez <[email protected]>
1 parent 36b9381 commit bb12633

File tree

2 files changed

+29
-75
lines changed

2 files changed

+29
-75
lines changed

middleware/keyauth/keyauth.go

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
package keyauth
33

44
import (
5-
"context"
65
"errors"
76
"fmt"
87
"net/url"
@@ -60,10 +59,7 @@ func New(config ...Config) fiber.Handler {
6059
valid, err := cfg.Validator(c, key)
6160

6261
if err == nil && valid {
63-
// Store in both Locals and Context
6462
c.Locals(tokenKey, key)
65-
ctx := context.WithValue(c.Context(), tokenKey, key)
66-
c.SetContext(ctx)
6763
return cfg.SuccessHandler(c)
6864
}
6965
return cfg.ErrorHandler(c, err)
@@ -72,20 +68,12 @@ func New(config ...Config) fiber.Handler {
7268

7369
// TokenFromContext returns the bearer token from the request context.
7470
// returns an empty string if the token does not exist
75-
func TokenFromContext(c any) string {
76-
switch ctx := c.(type) {
77-
case context.Context:
78-
if token, ok := ctx.Value(tokenKey).(string); ok {
79-
return token
80-
}
81-
case fiber.Ctx:
82-
if token, ok := ctx.Locals(tokenKey).(string); ok {
83-
return token
84-
}
85-
default:
86-
panic("unsupported context type, expected fiber.Ctx or context.Context")
71+
func TokenFromContext(c fiber.Ctx) string {
72+
token, ok := c.Locals(tokenKey).(string)
73+
if !ok {
74+
return ""
8775
}
88-
return ""
76+
return token
8977
}
9078

9179
// MultipleKeySourceLookup creates a CustomKeyLookup function that checks multiple sources until one is found

middleware/keyauth/keyauth_test.go

Lines changed: 24 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -503,67 +503,33 @@ func Test_TokenFromContext_None(t *testing.T) {
503503
}
504504

505505
func Test_TokenFromContext(t *testing.T) {
506-
// Test that TokenFromContext returns the correct token
507-
t.Run("fiber.Ctx", func(t *testing.T) {
508-
app := fiber.New()
509-
app.Use(New(Config{
510-
KeyLookup: "header:Authorization",
511-
AuthScheme: "Basic",
512-
Validator: func(_ fiber.Ctx, key string) (bool, error) {
513-
if key == CorrectKey {
514-
return true, nil
515-
}
516-
return false, ErrMissingOrMalformedAPIKey
517-
},
518-
}))
519-
app.Get("/", func(c fiber.Ctx) error {
520-
return c.SendString(TokenFromContext(c))
521-
})
522-
523-
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
524-
req.Header.Add("Authorization", "Basic "+CorrectKey)
525-
res, err := app.Test(req)
526-
require.NoError(t, err)
527-
528-
body, err := io.ReadAll(res.Body)
529-
require.NoError(t, err)
530-
require.Equal(t, CorrectKey, string(body))
506+
app := fiber.New()
507+
// Wire up keyauth middleware to set TokenFromContext now
508+
app.Use(New(Config{
509+
KeyLookup: "header:Authorization",
510+
AuthScheme: "Basic",
511+
Validator: func(_ fiber.Ctx, key string) (bool, error) {
512+
if key == CorrectKey {
513+
return true, nil
514+
}
515+
return false, ErrMissingOrMalformedAPIKey
516+
},
517+
}))
518+
// Define a test handler that checks TokenFromContext
519+
app.Get("/", func(c fiber.Ctx) error {
520+
return c.SendString(TokenFromContext(c))
531521
})
532522

533-
t.Run("context.Context", func(t *testing.T) {
534-
app := fiber.New()
535-
app.Use(New(Config{
536-
KeyLookup: "header:Authorization",
537-
AuthScheme: "Basic",
538-
Validator: func(_ fiber.Ctx, key string) (bool, error) {
539-
if key == CorrectKey {
540-
return true, nil
541-
}
542-
return false, ErrMissingOrMalformedAPIKey
543-
},
544-
}))
545-
// Verify that TokenFromContext works with context.Context
546-
app.Get("/", func(c fiber.Ctx) error {
547-
ctx := c.Context()
548-
token := TokenFromContext(ctx)
549-
return c.SendString(token)
550-
})
551-
552-
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
553-
req.Header.Add("Authorization", "Basic "+CorrectKey)
554-
res, err := app.Test(req)
555-
require.NoError(t, err)
556-
557-
body, err := io.ReadAll(res.Body)
558-
require.NoError(t, err)
559-
require.Equal(t, CorrectKey, string(body))
560-
})
523+
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
524+
req.Header.Add("Authorization", "Basic "+CorrectKey)
525+
// Send
526+
res, err := app.Test(req)
527+
require.NoError(t, err)
561528

562-
t.Run("invalid context type", func(t *testing.T) {
563-
require.Panics(t, func() {
564-
_ = TokenFromContext("invalid")
565-
})
566-
})
529+
// Read the response body into a string
530+
body, err := io.ReadAll(res.Body)
531+
require.NoError(t, err)
532+
require.Equal(t, CorrectKey, string(body))
567533
}
568534

569535
func Test_AuthSchemeToken(t *testing.T) {

0 commit comments

Comments
 (0)