Skip to content

Commit 2577745

Browse files
committed
Consolidate shared code across middlewares
Remove case for SecurityRequirementsError Improve test coverage
1 parent f9f18ff commit 2577745

File tree

7 files changed

+198
-109
lines changed

7 files changed

+198
-109
lines changed

context.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
package ginmiddleware
2+
3+
import (
4+
"context"
5+
6+
"github.com/gin-gonic/gin"
7+
)
8+
9+
func getRequestContext(
10+
c *gin.Context,
11+
options *Options,
12+
) context.Context {
13+
requestContext := context.WithValue(context.Background(), GinContextKey, c)
14+
if options != nil {
15+
requestContext = context.WithValue(requestContext, UserDataKey, options.UserData)
16+
}
17+
18+
return requestContext
19+
}

error.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
package ginmiddleware
2+
3+
import (
4+
"errors"
5+
"net/http"
6+
7+
"github.com/getkin/kin-openapi/routers"
8+
"github.com/gin-gonic/gin"
9+
)
10+
11+
func handleValidationError(
12+
c *gin.Context,
13+
err error,
14+
options *Options,
15+
generalStatusCode int,
16+
) {
17+
var errorHandler ErrorHandler
18+
// if an error handler is provided, use that
19+
if options != nil && options.ErrorHandler != nil {
20+
errorHandler = options.ErrorHandler
21+
} else {
22+
errorHandler = func(c *gin.Context, message string, statusCode int) {
23+
c.AbortWithStatusJSON(statusCode, gin.H{"error": message})
24+
}
25+
}
26+
27+
if errors.Is(err, routers.ErrPathNotFound) {
28+
errorHandler(c, err.Error(), http.StatusNotFound)
29+
} else {
30+
errorHandler(c, err.Error(), generalStatusCode)
31+
}
32+
33+
// in case the handler didn't internally call Abort, stop the chain
34+
c.Abort()
35+
}

oapi_validate_request.go

Lines changed: 5 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
package ginmiddleware
1616

1717
import (
18-
"context"
1918
"errors"
2019
"fmt"
2120
"log"
@@ -65,22 +64,7 @@ func OapiRequestValidatorWithOptions(swagger *openapi3.T, options *Options) gin.
6564
return func(c *gin.Context) {
6665
err := ValidateRequestFromContext(c, router, options)
6766
if err != nil {
68-
// using errors.Is did not work
69-
if options != nil && options.ErrorHandler != nil && err.Error() == routers.ErrPathNotFound.Error() {
70-
options.ErrorHandler(c, err.Error(), http.StatusNotFound)
71-
// in case the handler didn't internally call Abort, stop the chain
72-
c.Abort()
73-
} else if options != nil && options.ErrorHandler != nil {
74-
options.ErrorHandler(c, err.Error(), http.StatusBadRequest)
75-
// in case the handler didn't internally call Abort, stop the chain
76-
c.Abort()
77-
} else if err.Error() == routers.ErrPathNotFound.Error() {
78-
// note: i am not sure if this is the best way to handle this
79-
c.AbortWithStatusJSON(http.StatusNotFound, gin.H{"error": err.Error()})
80-
} else {
81-
// note: i am not sure if this is the best way to handle this
82-
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
83-
}
67+
handleValidationError(c, err, options, http.StatusBadRequest)
8468
}
8569
c.Next()
8670
}
@@ -89,38 +73,14 @@ func OapiRequestValidatorWithOptions(swagger *openapi3.T, options *Options) gin.
8973
// ValidateRequestFromContext is called from the middleware above and actually does the work
9074
// of validating a request.
9175
func ValidateRequestFromContext(c *gin.Context, router routers.Router, options *Options) error {
92-
req := c.Request
93-
route, pathParams, err := router.FindRoute(req)
94-
95-
// We failed to find a matching route for the request.
76+
validationInput, err := getRequestValidationInput(c.Request, router, options)
9677
if err != nil {
97-
switch e := err.(type) {
98-
case *routers.RouteError:
99-
// We've got a bad request, the path requested doesn't match
100-
// either server, or path, or something.
101-
return errors.New(e.Reason)
102-
default:
103-
// This should never happen today, but if our upstream code changes,
104-
// we don't want to crash the server, so handle the unexpected error.
105-
return fmt.Errorf("error validating route: %s", err.Error())
106-
}
107-
}
108-
109-
validationInput := &openapi3filter.RequestValidationInput{
110-
Request: req,
111-
PathParams: pathParams,
112-
Route: route,
78+
return fmt.Errorf("error getting request validation input from route: %w", err)
11379
}
11480

115-
// Pass the gin context into the request validator, so that any callbacks
81+
// Pass the gin context into the response validator, so that any callbacks
11682
// which it invokes make it available.
117-
requestContext := context.WithValue(context.Background(), GinContextKey, c)
118-
119-
if options != nil {
120-
validationInput.Options = &options.Options
121-
validationInput.ParamDecoder = options.ParamDecoder
122-
requestContext = context.WithValue(requestContext, UserDataKey, options.UserData)
123-
}
83+
requestContext := getRequestContext(c, options)
12484

12585
err = openapi3filter.ValidateRequest(requestContext, validationInput)
12686
if err != nil {

oapi_validate_response.go

Lines changed: 11 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ package ginmiddleware
1616

1717
import (
1818
"bytes"
19-
"context"
2019
"errors"
2120
"fmt"
2221
"io"
@@ -47,9 +46,9 @@ func OapiResponseValidatorFromYamlFile(path string) (gin.HandlerFunc, error) {
4746
return OapiRequestValidator(swagger), nil
4847
}
4948

50-
// OapiRequestValidator is an gin middleware function which validates incoming HTTP requests
49+
// OapiResponseValidator is an gin middleware function which validates outgoing HTTP responses
5150
// to make sure that they conform to the given OAPI 3.0 specification. When
52-
// OAPI validation fails on the request, we return an HTTP/400 with error message
51+
// OAPI validation fails on the request, we return an HTTP/500 with error message
5352
func OapiResponseValidator(swagger *openapi3.T) gin.HandlerFunc {
5453
return OapiResponseValidatorWithOptions(swagger, nil)
5554
}
@@ -67,64 +66,28 @@ func OapiResponseValidatorWithOptions(swagger *openapi3.T, options *Options) gin
6766
return func(c *gin.Context) {
6867
err := ValidateResponseFromContext(c, router, options)
6968
if err != nil {
70-
if options != nil && options.ErrorHandler != nil {
71-
options.ErrorHandler(c, err.Error(), http.StatusInternalServerError)
72-
// in case the handler didn't internally call Abort, stop the chain
73-
c.Abort()
74-
} else {
75-
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
76-
}
69+
handleValidationError(c, err, options, http.StatusInternalServerError)
7770
}
78-
}
79-
}
8071

81-
type responseInterceptor struct {
82-
gin.ResponseWriter
83-
body *bytes.Buffer
84-
}
85-
86-
func (w *responseInterceptor) Write(b []byte) (int, error) {
87-
return w.body.Write(b)
72+
// in case an error was encountered before Next() was called, call it here
73+
c.Next()
74+
}
8875
}
8976

9077
// ValidateResponseFromContext is called from the middleware above and actually does the work
9178
// of validating a response.
9279
func ValidateResponseFromContext(c *gin.Context, router routers.Router, options *Options) error {
93-
req := c.Request
94-
route, pathParams, err := router.FindRoute(req)
95-
96-
// We failed to find a matching route for the request.
80+
reqValidationInput, err := getRequestValidationInput(c.Request, router, options)
9781
if err != nil {
98-
switch e := err.(type) {
99-
case *routers.RouteError:
100-
// We've got a bad request, the path requested doesn't match
101-
// either server, or path, or something.
102-
return errors.New(e.Reason)
103-
default:
104-
// This should never happen today, but if our upstream code changes,
105-
// we don't want to crash the server, so handle the unexpected error.
106-
return fmt.Errorf("error validating route: %s", err.Error())
107-
}
108-
}
109-
110-
reqValidationInput := &openapi3filter.RequestValidationInput{
111-
Request: req,
112-
PathParams: pathParams,
113-
Route: route,
82+
return fmt.Errorf("error getting request validation input from route: %w", err)
11483
}
11584

116-
// Pass the gin context into the request validator, so that any callbacks
85+
// Pass the gin context into the response validator, so that any callbacks
11786
// which it invokes make it available.
118-
requestContext := context.WithValue(context.Background(), GinContextKey, c)
119-
120-
if options != nil {
121-
reqValidationInput.Options = &options.Options
122-
reqValidationInput.ParamDecoder = options.ParamDecoder
123-
requestContext = context.WithValue(requestContext, UserDataKey, options.UserData)
124-
}
87+
requestContext := getRequestContext(c, options)
12588

12689
// wrap the response writer in a bodyWriter so we can capture the response body
127-
bw := &responseInterceptor{ResponseWriter: c.Writer, body: bytes.NewBufferString("")}
90+
bw := newResponseInterceptor(c.Writer)
12891
c.Writer = bw
12992

13093
// Call the next handler in the chain, which will actually process the request
@@ -146,7 +109,6 @@ func ValidateResponseFromContext(c *gin.Context, router routers.Router, options
146109
}
147110

148111
err = openapi3filter.ValidateResponse(requestContext, rspValidationInput)
149-
150112
if err != nil {
151113
// restore the original response writer
152114
c.Writer = bw.ResponseWriter
@@ -164,8 +126,6 @@ func ValidateResponseFromContext(c *gin.Context, router routers.Router, options
164126
// openapi errors seem to be multi-line with a decent message on the first
165127
errorLines := strings.Split(e.Error(), "\n")
166128
return fmt.Errorf("error in openapi3filter.ResponseError: %s", errorLines[0])
167-
case *openapi3filter.SecurityRequirementsError:
168-
return fmt.Errorf("error in openapi3filter.SecurityRequirementsError: %s", e.Error())
169129
default:
170130
// This should never happen today, but if our upstream code changes,
171131
// we don't want to crash the server, so handle the unexpected error.

oapi_validate_response_test.go

Lines changed: 59 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,9 @@ func TestOapiResponseValidator(t *testing.T) {
3535
swagger, err := openapi3.NewLoader().LoadFromData(testResponseSchema)
3636
require.NoError(t, err, "Error initializing swagger")
3737

38-
// Create a new echo router
38+
// Create a new gin router
3939
g := gin.New()
4040

41-
// Set up an authenticator to check authenticated function. It will allow
42-
// access to "someScope", but disallow others.
4341
options := Options{
4442
ErrorHandler: func(c *gin.Context, message string, statusCode int) {
4543
c.String(statusCode, "test: "+message)
@@ -51,17 +49,21 @@ func TestOapiResponseValidator(t *testing.T) {
5149
UserData: "hi!",
5250
}
5351

54-
// Install our OpenApi based request validator
52+
// Install our OpenApi based response validator
5553
g.Use(OapiResponseValidatorWithOptions(swagger, &options))
5654

57-
tests := []struct {
58-
name string
59-
operationID string
60-
}{
61-
{
62-
name: "GET /resource",
63-
operationID: "getResource",
64-
},
55+
// Test an incorrect route
56+
{
57+
rec := doGet(t, g, "http://deepmap.ai/incorrect")
58+
assert.Equal(t, http.StatusNotFound, rec.Code)
59+
assert.Contains(t, rec.Body.String(), "no matching operation was found")
60+
}
61+
62+
// Test wrong server
63+
{
64+
rec := doGet(t, g, "http://wrongserver.ai/resource")
65+
assert.Equal(t, http.StatusNotFound, rec.Code)
66+
assert.Contains(t, rec.Body.String(), "no matching operation was found")
6567
}
6668

6769
// getResource
@@ -235,7 +237,7 @@ func TestOapiResponseValidator(t *testing.T) {
235237

236238
rec := doPost(t, g, "http://deepmap.ai/resource", gin.H{"name": "Wilhelm Scream"})
237239
assert.Equal(t, tt.wantStatus, rec.Code)
238-
if tt.wantStatus == http.StatusOK {
240+
if tt.wantStatus == http.StatusCreated {
239241
switch tt.contentType {
240242
case "application/json":
241243
assert.JSONEq(t, tt.wantRsp, rec.Body.String())
@@ -249,6 +251,20 @@ func TestOapiResponseValidator(t *testing.T) {
249251
}
250252
}
251253

254+
tests := []struct {
255+
name string
256+
operationID string
257+
}{
258+
{
259+
name: "GET /resource",
260+
operationID: "getResource",
261+
},
262+
{
263+
name: "POST /resource",
264+
operationID: "createResource",
265+
},
266+
}
267+
252268
for _, tt := range tests {
253269
t.Run(tt.name, func(t *testing.T) {
254270
switch tt.operationID {
@@ -259,5 +275,35 @@ func TestOapiResponseValidator(t *testing.T) {
259275
}
260276
})
261277
}
278+
}
262279

280+
func TestOapiResponseValidatorNoOptions(t *testing.T) {
281+
swagger, err := openapi3.NewLoader().LoadFromData(testResponseSchema)
282+
require.NoError(t, err, "Error initializing swagger")
283+
284+
mw := OapiResponseValidator(swagger)
285+
assert.NotNil(t, mw, "Response validator is nil")
286+
}
287+
288+
func TestOapiResponseValidatorFromYamlFile(t *testing.T) {
289+
// Test that we can load a response validator from a yaml file
290+
{
291+
mw, err := OapiResponseValidatorFromYamlFile("test_response_spec.yaml")
292+
assert.NoError(t, err, "Error initializing response validator")
293+
assert.NotNil(t, mw, "Response validator is nil")
294+
}
295+
296+
// Test that we get an error when the file does not exist
297+
{
298+
mw, err := OapiResponseValidatorFromYamlFile("nonexistent.yaml")
299+
assert.Error(t, err, "Expected error initializing response validator")
300+
assert.Nil(t, mw, "Response validator is not nil")
301+
}
302+
303+
// Test that we get an error when the file is not a valid yaml file
304+
{
305+
mw, err := OapiResponseValidatorFromYamlFile("README.md")
306+
assert.Error(t, err, "Expected error initializing response validator")
307+
assert.Nil(t, mw, "Response validator is not nil")
308+
}
263309
}

response_interceptor.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
package ginmiddleware
2+
3+
import (
4+
"bytes"
5+
6+
"github.com/gin-gonic/gin"
7+
)
8+
9+
type responseInterceptor struct {
10+
gin.ResponseWriter
11+
body *bytes.Buffer
12+
}
13+
14+
var _ gin.ResponseWriter = &responseInterceptor{}
15+
16+
func newResponseInterceptor(w gin.ResponseWriter) *responseInterceptor {
17+
return &responseInterceptor{
18+
ResponseWriter: w,
19+
body: bytes.NewBufferString(""),
20+
}
21+
}
22+
23+
func (w *responseInterceptor) Write(b []byte) (int, error) {
24+
return w.body.Write(b)
25+
}

0 commit comments

Comments
 (0)