mirror of
https://github.com/labstack/echo.git
synced 2025-01-01 22:09:21 +02:00
add a custom error handler to key-auth middleware (#1847)
* add a custom error handler to key-auth middleware
This commit is contained in:
parent
76f186ad3b
commit
7256cb2274
@ -30,12 +30,19 @@ type (
|
||||
// Validator is a function to validate key.
|
||||
// Required.
|
||||
Validator KeyAuthValidator
|
||||
|
||||
// ErrorHandler defines a function which is executed for an invalid key.
|
||||
// It may be used to define a custom error.
|
||||
ErrorHandler KeyAuthErrorHandler
|
||||
}
|
||||
|
||||
// KeyAuthValidator defines a function to validate KeyAuth credentials.
|
||||
KeyAuthValidator func(string, echo.Context) (bool, error)
|
||||
|
||||
keyExtractor func(echo.Context) (string, error)
|
||||
|
||||
// KeyAuthErrorHandler defines a function which is executed for an invalid key.
|
||||
KeyAuthErrorHandler func(error, echo.Context) error
|
||||
)
|
||||
|
||||
var (
|
||||
@ -95,10 +102,16 @@ func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc {
|
||||
// Extract and verify key
|
||||
key, err := extractor(c)
|
||||
if err != nil {
|
||||
if config.ErrorHandler != nil {
|
||||
return config.ErrorHandler(err, c)
|
||||
}
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
valid, err := config.Validator(key, c)
|
||||
if err != nil {
|
||||
if config.ErrorHandler != nil {
|
||||
return config.ErrorHandler(err, c)
|
||||
}
|
||||
return &echo.HTTPError{
|
||||
Code: http.StatusUnauthorized,
|
||||
Message: "invalid key",
|
||||
|
@ -1,9 +1,9 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@ -11,65 +11,225 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func testKeyValidator(key string, c echo.Context) (bool, error) {
|
||||
switch key {
|
||||
case "valid-key":
|
||||
return true, nil
|
||||
case "error-key":
|
||||
return false, errors.New("some user defined error")
|
||||
default:
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeyAuth(t *testing.T) {
|
||||
handlerCalled := false
|
||||
handler := func(c echo.Context) error {
|
||||
handlerCalled = true
|
||||
return c.String(http.StatusOK, "test")
|
||||
}
|
||||
middlewareChain := KeyAuth(testKeyValidator)(handler)
|
||||
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set(echo.HeaderAuthorization, "Bearer valid-key")
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
config := KeyAuthConfig{
|
||||
Validator: func(key string, c echo.Context) (bool, error) {
|
||||
return key == "valid-key", nil
|
||||
|
||||
err := middlewareChain(c)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, handlerCalled)
|
||||
}
|
||||
|
||||
func TestKeyAuthWithConfig(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
givenRequestFunc func() *http.Request
|
||||
givenRequest func(req *http.Request)
|
||||
whenConfig func(conf *KeyAuthConfig)
|
||||
expectHandlerCalled bool
|
||||
expectError string
|
||||
}{
|
||||
{
|
||||
name: "ok, defaults, key from header",
|
||||
givenRequest: func(req *http.Request) {
|
||||
req.Header.Set(echo.HeaderAuthorization, "Bearer valid-key")
|
||||
},
|
||||
expectHandlerCalled: true,
|
||||
},
|
||||
{
|
||||
name: "ok, custom skipper",
|
||||
givenRequest: func(req *http.Request) {
|
||||
req.Header.Set(echo.HeaderAuthorization, "Bearer error-key")
|
||||
},
|
||||
whenConfig: func(conf *KeyAuthConfig) {
|
||||
conf.Skipper = func(context echo.Context) bool {
|
||||
return true
|
||||
}
|
||||
},
|
||||
expectHandlerCalled: true,
|
||||
},
|
||||
{
|
||||
name: "nok, defaults, invalid key from header, Authorization: Bearer",
|
||||
givenRequest: func(req *http.Request) {
|
||||
req.Header.Set(echo.HeaderAuthorization, "Bearer invalid-key")
|
||||
},
|
||||
expectHandlerCalled: false,
|
||||
expectError: "code=401, message=Unauthorized",
|
||||
},
|
||||
{
|
||||
name: "nok, defaults, invalid scheme in header",
|
||||
givenRequest: func(req *http.Request) {
|
||||
req.Header.Set(echo.HeaderAuthorization, "Bear valid-key")
|
||||
},
|
||||
expectHandlerCalled: false,
|
||||
expectError: "code=400, message=invalid key in the request header",
|
||||
},
|
||||
{
|
||||
name: "nok, defaults, missing header",
|
||||
givenRequest: func(req *http.Request) {},
|
||||
expectHandlerCalled: false,
|
||||
expectError: "code=400, message=missing key in request header",
|
||||
},
|
||||
{
|
||||
name: "ok, custom key lookup, header",
|
||||
givenRequest: func(req *http.Request) {
|
||||
req.Header.Set("API-Key", "valid-key")
|
||||
},
|
||||
whenConfig: func(conf *KeyAuthConfig) {
|
||||
conf.KeyLookup = "header:API-Key"
|
||||
},
|
||||
expectHandlerCalled: true,
|
||||
},
|
||||
{
|
||||
name: "nok, custom key lookup, missing header",
|
||||
givenRequest: func(req *http.Request) {
|
||||
},
|
||||
whenConfig: func(conf *KeyAuthConfig) {
|
||||
conf.KeyLookup = "header:API-Key"
|
||||
},
|
||||
expectHandlerCalled: false,
|
||||
expectError: "code=400, message=missing key in request header",
|
||||
},
|
||||
{
|
||||
name: "ok, custom key lookup, query",
|
||||
givenRequest: func(req *http.Request) {
|
||||
q := req.URL.Query()
|
||||
q.Add("key", "valid-key")
|
||||
req.URL.RawQuery = q.Encode()
|
||||
},
|
||||
whenConfig: func(conf *KeyAuthConfig) {
|
||||
conf.KeyLookup = "query:key"
|
||||
},
|
||||
expectHandlerCalled: true,
|
||||
},
|
||||
{
|
||||
name: "nok, custom key lookup, missing query param",
|
||||
whenConfig: func(conf *KeyAuthConfig) {
|
||||
conf.KeyLookup = "query:key"
|
||||
},
|
||||
expectHandlerCalled: false,
|
||||
expectError: "code=400, message=missing key in the query string",
|
||||
},
|
||||
{
|
||||
name: "ok, custom key lookup, form",
|
||||
givenRequestFunc: func() *http.Request {
|
||||
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("key=valid-key"))
|
||||
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm)
|
||||
return req
|
||||
},
|
||||
whenConfig: func(conf *KeyAuthConfig) {
|
||||
conf.KeyLookup = "form:key"
|
||||
},
|
||||
expectHandlerCalled: true,
|
||||
},
|
||||
{
|
||||
name: "nok, custom key lookup, missing key in form",
|
||||
givenRequestFunc: func() *http.Request {
|
||||
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("xxx=valid-key"))
|
||||
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm)
|
||||
return req
|
||||
},
|
||||
whenConfig: func(conf *KeyAuthConfig) {
|
||||
conf.KeyLookup = "form:key"
|
||||
},
|
||||
expectHandlerCalled: false,
|
||||
expectError: "code=400, message=missing key in the form",
|
||||
},
|
||||
{
|
||||
name: "nok, custom errorHandler, error from extractor",
|
||||
whenConfig: func(conf *KeyAuthConfig) {
|
||||
conf.KeyLookup = "header:token"
|
||||
conf.ErrorHandler = func(err error, context echo.Context) error {
|
||||
httpError := echo.NewHTTPError(http.StatusTeapot, "custom")
|
||||
httpError.Internal = err
|
||||
return httpError
|
||||
}
|
||||
},
|
||||
expectHandlerCalled: false,
|
||||
expectError: "code=418, message=custom, internal=missing key in request header",
|
||||
},
|
||||
{
|
||||
name: "nok, custom errorHandler, error from validator",
|
||||
givenRequest: func(req *http.Request) {
|
||||
req.Header.Set(echo.HeaderAuthorization, "Bearer error-key")
|
||||
},
|
||||
whenConfig: func(conf *KeyAuthConfig) {
|
||||
conf.ErrorHandler = func(err error, context echo.Context) error {
|
||||
httpError := echo.NewHTTPError(http.StatusTeapot, "custom")
|
||||
httpError.Internal = err
|
||||
return httpError
|
||||
}
|
||||
},
|
||||
expectHandlerCalled: false,
|
||||
expectError: "code=418, message=custom, internal=some user defined error",
|
||||
},
|
||||
{
|
||||
name: "nok, defaults, error from validator",
|
||||
givenRequest: func(req *http.Request) {
|
||||
req.Header.Set(echo.HeaderAuthorization, "Bearer error-key")
|
||||
},
|
||||
whenConfig: func(conf *KeyAuthConfig) {},
|
||||
expectHandlerCalled: false,
|
||||
expectError: "code=401, message=invalid key, internal=some user defined error",
|
||||
},
|
||||
}
|
||||
h := KeyAuthWithConfig(config)(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "test")
|
||||
})
|
||||
|
||||
assert := assert.New(t)
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
handlerCalled := false
|
||||
handler := func(c echo.Context) error {
|
||||
handlerCalled = true
|
||||
return c.String(http.StatusOK, "test")
|
||||
}
|
||||
config := KeyAuthConfig{
|
||||
Validator: testKeyValidator,
|
||||
}
|
||||
if tc.whenConfig != nil {
|
||||
tc.whenConfig(&config)
|
||||
}
|
||||
middlewareChain := KeyAuthWithConfig(config)(handler)
|
||||
|
||||
// Valid key
|
||||
auth := DefaultKeyAuthConfig.AuthScheme + " " + "valid-key"
|
||||
req.Header.Set(echo.HeaderAuthorization, auth)
|
||||
assert.NoError(h(c))
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
if tc.givenRequestFunc != nil {
|
||||
req = tc.givenRequestFunc()
|
||||
}
|
||||
if tc.givenRequest != nil {
|
||||
tc.givenRequest(req)
|
||||
}
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
// Invalid key
|
||||
auth = DefaultKeyAuthConfig.AuthScheme + " " + "invalid-key"
|
||||
req.Header.Set(echo.HeaderAuthorization, auth)
|
||||
he := h(c).(*echo.HTTPError)
|
||||
assert.Equal(http.StatusUnauthorized, he.Code)
|
||||
err := middlewareChain(c)
|
||||
|
||||
// Missing Authorization header
|
||||
req.Header.Del(echo.HeaderAuthorization)
|
||||
he = h(c).(*echo.HTTPError)
|
||||
assert.Equal(http.StatusBadRequest, he.Code)
|
||||
|
||||
// Key from custom header
|
||||
config.KeyLookup = "header:API-Key"
|
||||
h = KeyAuthWithConfig(config)(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "test")
|
||||
})
|
||||
req.Header.Set("API-Key", "valid-key")
|
||||
assert.NoError(h(c))
|
||||
|
||||
// Key from query string
|
||||
config.KeyLookup = "query:key"
|
||||
h = KeyAuthWithConfig(config)(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "test")
|
||||
})
|
||||
q := req.URL.Query()
|
||||
q.Add("key", "valid-key")
|
||||
req.URL.RawQuery = q.Encode()
|
||||
assert.NoError(h(c))
|
||||
|
||||
// Key from form
|
||||
config.KeyLookup = "form:key"
|
||||
h = KeyAuthWithConfig(config)(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "test")
|
||||
})
|
||||
f := make(url.Values)
|
||||
f.Set("key", "valid-key")
|
||||
req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode()))
|
||||
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm)
|
||||
c = e.NewContext(req, rec)
|
||||
assert.NoError(h(c))
|
||||
assert.Equal(t, tc.expectHandlerCalled, handlerCalled)
|
||||
if tc.expectError != "" {
|
||||
assert.EqualError(t, err, tc.expectError)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user