Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 18 additions & 20 deletions oapi_validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,25 +30,22 @@ type ErrorHandler func(w http.ResponseWriter, message string, statusCode int)

// ErrorHandlerWithOpts is called when there is an error in validation, with more information about the `error` that occurred and which request is currently being processed.
//
// There are a number of known types that the `error` can be:
//
// - `*openapi3filter.SecurityRequirementsError` - if the `AuthenticationFunc` has failed to authenticate the request
// - `*openapi3filter.RequestError` - if a bad request has been made
//
// Additionally, if you have set `openapi3filter.Options#MultiError`:
//
// - `openapi3.MultiError` (https://pkg.go.dev/github.com/getkin/kin-openapi/openapi3#MultiError)
//
// If both an `ErrorHandlerWithOpts` and `ErrorHandler` are set, the `ErrorHandlerWithOpts` takes precedence.
//
// NOTE that this should ideally be used instead of ErrorHandler
type ErrorHandlerWithOpts func(ctx context.Context, w http.ResponseWriter, r *http.Request, opts ErrorHandlerOpts)
type ErrorHandlerWithOpts func(ctx context.Context, err error, w http.ResponseWriter, r *http.Request, opts ErrorHandlerOpts)

// ErrorHandlerOpts contains additional options that are passed to the `ErrorHandlerWithOpts` function in the case of an error being returned by the middleware
type ErrorHandlerOpts struct {
// Error is the underlying error that triggered this error handler to be executed.
//
// Known error types:
//
// - `*openapi3filter.SecurityRequirementsError` - if the `AuthenticationFunc` has failed to authenticate the request
// - `*openapi3filter.RequestError` - if a bad request has been made
//
// Additionally, if you have set `openapi3filter.Options#MultiError`:
//
// - `openapi3.MultiError` (https://pkg.go.dev/github.com/getkin/kin-openapi/openapi3#MultiError)
Error error

// StatusCode indicates the HTTP Status Code that the OpenAPI validation middleware _suggests_ is returned to the user.
//
// NOTE that this is very much a suggestion, and can be overridden if you believe you have a better approach.
Expand Down Expand Up @@ -163,11 +160,10 @@ func performRequestValidationForErrorHandlerWithOpts(next http.Handler, w http.R
if err != nil {
errOpts := ErrorHandlerOpts{
// MatchedRoute will be nil, as we've not matched a route we know about
Error: err,
StatusCode: http.StatusNotFound,
}

options.ErrorHandlerWithOpts(r.Context(), w, r, errOpts)
options.ErrorHandlerWithOpts(r.Context(), err, w, r, errOpts)
return
}

Expand Down Expand Up @@ -197,26 +193,28 @@ func performRequestValidationForErrorHandlerWithOpts(next http.Handler, w http.R
return
}

var theErr error

switch e := err.(type) {
case openapi3.MultiError:
errOpts.Error = e
theErr = e
errOpts.StatusCode = determineStatusCodeForMultiError(e)
case *openapi3filter.RequestError:
// We've got a bad request
errOpts.Error = e
theErr = e
errOpts.StatusCode = http.StatusBadRequest
case *openapi3filter.SecurityRequirementsError:
errOpts.Error = e
theErr = e
errOpts.StatusCode = http.StatusUnauthorized
default:
// This should never happen today, but if our upstream code changes,
// we don't want to crash the server, so handle the unexpected error.
// return http.StatusInternalServerError,
errOpts.Error = fmt.Errorf("error validating route: %w", e)
theErr = fmt.Errorf("error validating route: %w", e)
errOpts.StatusCode = http.StatusUnauthorized
}

options.ErrorHandlerWithOpts(r.Context(), w, r, errOpts)
options.ErrorHandlerWithOpts(r.Context(), theErr, w, r, errOpts)
}

// validateRequest is called from the middleware above and actually does the work
Expand Down
8 changes: 2 additions & 6 deletions oapi_validate_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -469,9 +469,7 @@ components:
return fmt.Errorf("this check always fails - don't let anyone in!")
}

errorHandlerFunc := func(ctx context.Context, w http.ResponseWriter, r *http.Request, opts middleware.ErrorHandlerOpts) {
err := opts.Error

errorHandlerFunc := func(ctx context.Context, err error, w http.ResponseWriter, r *http.Request, opts middleware.ErrorHandlerOpts) {
if opts.MatchedRoute == nil {
fmt.Printf("ErrorHandlerWithOpts: An HTTP %d was returned by the middleware with error message: %s\n", opts.StatusCode, err.Error())

Expand Down Expand Up @@ -716,9 +714,7 @@ paths:
w.WriteHeader(http.StatusMethodNotAllowed)
})

errorHandlerFunc := func(ctx context.Context, w http.ResponseWriter, r *http.Request, opts middleware.ErrorHandlerOpts) {
err := opts.Error

errorHandlerFunc := func(ctx context.Context, err error, w http.ResponseWriter, r *http.Request, opts middleware.ErrorHandlerOpts) {
if opts.MatchedRoute == nil {
fmt.Printf("ErrorHandlerWithOpts: An HTTP %d was returned by the middleware with error message: %s\n", opts.StatusCode, err.Error())

Expand Down