mirror of
https://github.com/labstack/echo.git
synced 2024-11-24 08:22:21 +02:00
4a1ccdfdc5
* CSRF, JWT, KeyAuth middleware support for multivalue value extractors * Add flag to JWT and KeyAuth middleware to allow continuing execution `next(c)` when error handler decides to swallow the error (returns nil).
377 lines
10 KiB
Go
377 lines
10 KiB
Go
package middleware
|
|
|
|
import (
|
|
"errors"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/labstack/echo/v4"
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
func testKeyValidator(key string, c echo.Context) (bool, error) {
|
|
switch key {
|
|
case "valid-key":
|
|
return true, nil
|
|
case "error-key":
|
|
return false, errors.New("some user defined error")
|
|
default:
|
|
return false, nil
|
|
}
|
|
}
|
|
|
|
func TestKeyAuth(t *testing.T) {
|
|
handlerCalled := false
|
|
handler := func(c echo.Context) error {
|
|
handlerCalled = true
|
|
return c.String(http.StatusOK, "test")
|
|
}
|
|
middlewareChain := KeyAuth(testKeyValidator)(handler)
|
|
|
|
e := echo.New()
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.Header.Set(echo.HeaderAuthorization, "Bearer valid-key")
|
|
rec := httptest.NewRecorder()
|
|
c := e.NewContext(req, rec)
|
|
|
|
err := middlewareChain(c)
|
|
|
|
assert.NoError(t, err)
|
|
assert.True(t, handlerCalled)
|
|
}
|
|
|
|
func TestKeyAuthWithConfig(t *testing.T) {
|
|
var testCases = []struct {
|
|
name string
|
|
givenRequestFunc func() *http.Request
|
|
givenRequest func(req *http.Request)
|
|
whenConfig func(conf *KeyAuthConfig)
|
|
expectHandlerCalled bool
|
|
expectError string
|
|
}{
|
|
{
|
|
name: "ok, defaults, key from header",
|
|
givenRequest: func(req *http.Request) {
|
|
req.Header.Set(echo.HeaderAuthorization, "Bearer valid-key")
|
|
},
|
|
expectHandlerCalled: true,
|
|
},
|
|
{
|
|
name: "ok, custom skipper",
|
|
givenRequest: func(req *http.Request) {
|
|
req.Header.Set(echo.HeaderAuthorization, "Bearer error-key")
|
|
},
|
|
whenConfig: func(conf *KeyAuthConfig) {
|
|
conf.Skipper = func(context echo.Context) bool {
|
|
return true
|
|
}
|
|
},
|
|
expectHandlerCalled: true,
|
|
},
|
|
{
|
|
name: "nok, defaults, invalid key from header, Authorization: Bearer",
|
|
givenRequest: func(req *http.Request) {
|
|
req.Header.Set(echo.HeaderAuthorization, "Bearer invalid-key")
|
|
},
|
|
expectHandlerCalled: false,
|
|
expectError: "code=401, message=Unauthorized, internal=invalid key",
|
|
},
|
|
{
|
|
name: "nok, defaults, invalid scheme in header",
|
|
givenRequest: func(req *http.Request) {
|
|
req.Header.Set(echo.HeaderAuthorization, "Bear valid-key")
|
|
},
|
|
expectHandlerCalled: false,
|
|
expectError: "code=400, message=invalid key in the request header",
|
|
},
|
|
{
|
|
name: "nok, defaults, missing header",
|
|
givenRequest: func(req *http.Request) {},
|
|
expectHandlerCalled: false,
|
|
expectError: "code=400, message=missing key in request header",
|
|
},
|
|
{
|
|
name: "ok, custom key lookup from multiple places, query and header",
|
|
givenRequest: func(req *http.Request) {
|
|
req.URL.RawQuery = "key=invalid-key"
|
|
req.Header.Set("API-Key", "valid-key")
|
|
},
|
|
whenConfig: func(conf *KeyAuthConfig) {
|
|
conf.KeyLookup = "query:key,header:API-Key"
|
|
},
|
|
expectHandlerCalled: true,
|
|
},
|
|
{
|
|
name: "ok, custom key lookup, header",
|
|
givenRequest: func(req *http.Request) {
|
|
req.Header.Set("API-Key", "valid-key")
|
|
},
|
|
whenConfig: func(conf *KeyAuthConfig) {
|
|
conf.KeyLookup = "header:API-Key"
|
|
},
|
|
expectHandlerCalled: true,
|
|
},
|
|
{
|
|
name: "nok, custom key lookup, missing header",
|
|
givenRequest: func(req *http.Request) {
|
|
},
|
|
whenConfig: func(conf *KeyAuthConfig) {
|
|
conf.KeyLookup = "header:API-Key"
|
|
},
|
|
expectHandlerCalled: false,
|
|
expectError: "code=400, message=missing key in request header",
|
|
},
|
|
{
|
|
name: "ok, custom key lookup, query",
|
|
givenRequest: func(req *http.Request) {
|
|
q := req.URL.Query()
|
|
q.Add("key", "valid-key")
|
|
req.URL.RawQuery = q.Encode()
|
|
},
|
|
whenConfig: func(conf *KeyAuthConfig) {
|
|
conf.KeyLookup = "query:key"
|
|
},
|
|
expectHandlerCalled: true,
|
|
},
|
|
{
|
|
name: "nok, custom key lookup, missing query param",
|
|
whenConfig: func(conf *KeyAuthConfig) {
|
|
conf.KeyLookup = "query:key"
|
|
},
|
|
expectHandlerCalled: false,
|
|
expectError: "code=400, message=missing key in the query string",
|
|
},
|
|
{
|
|
name: "ok, custom key lookup, form",
|
|
givenRequestFunc: func() *http.Request {
|
|
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("key=valid-key"))
|
|
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm)
|
|
return req
|
|
},
|
|
whenConfig: func(conf *KeyAuthConfig) {
|
|
conf.KeyLookup = "form:key"
|
|
},
|
|
expectHandlerCalled: true,
|
|
},
|
|
{
|
|
name: "nok, custom key lookup, missing key in form",
|
|
givenRequestFunc: func() *http.Request {
|
|
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("xxx=valid-key"))
|
|
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm)
|
|
return req
|
|
},
|
|
whenConfig: func(conf *KeyAuthConfig) {
|
|
conf.KeyLookup = "form:key"
|
|
},
|
|
expectHandlerCalled: false,
|
|
expectError: "code=400, message=missing key in the form",
|
|
},
|
|
{
|
|
name: "ok, custom key lookup, cookie",
|
|
givenRequest: func(req *http.Request) {
|
|
req.AddCookie(&http.Cookie{
|
|
Name: "key",
|
|
Value: "valid-key",
|
|
})
|
|
q := req.URL.Query()
|
|
q.Add("key", "valid-key")
|
|
req.URL.RawQuery = q.Encode()
|
|
},
|
|
whenConfig: func(conf *KeyAuthConfig) {
|
|
conf.KeyLookup = "cookie:key"
|
|
},
|
|
expectHandlerCalled: true,
|
|
},
|
|
{
|
|
name: "nok, custom key lookup, missing cookie param",
|
|
whenConfig: func(conf *KeyAuthConfig) {
|
|
conf.KeyLookup = "cookie:key"
|
|
},
|
|
expectHandlerCalled: false,
|
|
expectError: "code=400, message=missing key in cookies",
|
|
},
|
|
{
|
|
name: "nok, custom errorHandler, error from extractor",
|
|
whenConfig: func(conf *KeyAuthConfig) {
|
|
conf.KeyLookup = "header:token"
|
|
conf.ErrorHandler = func(err error, context echo.Context) error {
|
|
httpError := echo.NewHTTPError(http.StatusTeapot, "custom")
|
|
httpError.Internal = err
|
|
return httpError
|
|
}
|
|
},
|
|
expectHandlerCalled: false,
|
|
expectError: "code=418, message=custom, internal=missing key in request header",
|
|
},
|
|
{
|
|
name: "nok, custom errorHandler, error from validator",
|
|
givenRequest: func(req *http.Request) {
|
|
req.Header.Set(echo.HeaderAuthorization, "Bearer error-key")
|
|
},
|
|
whenConfig: func(conf *KeyAuthConfig) {
|
|
conf.ErrorHandler = func(err error, context echo.Context) error {
|
|
httpError := echo.NewHTTPError(http.StatusTeapot, "custom")
|
|
httpError.Internal = err
|
|
return httpError
|
|
}
|
|
},
|
|
expectHandlerCalled: false,
|
|
expectError: "code=418, message=custom, internal=some user defined error",
|
|
},
|
|
{
|
|
name: "nok, defaults, error from validator",
|
|
givenRequest: func(req *http.Request) {
|
|
req.Header.Set(echo.HeaderAuthorization, "Bearer error-key")
|
|
},
|
|
whenConfig: func(conf *KeyAuthConfig) {},
|
|
expectHandlerCalled: false,
|
|
expectError: "code=401, message=Unauthorized, internal=some user defined error",
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
handlerCalled := false
|
|
handler := func(c echo.Context) error {
|
|
handlerCalled = true
|
|
return c.String(http.StatusOK, "test")
|
|
}
|
|
config := KeyAuthConfig{
|
|
Validator: testKeyValidator,
|
|
}
|
|
if tc.whenConfig != nil {
|
|
tc.whenConfig(&config)
|
|
}
|
|
middlewareChain := KeyAuthWithConfig(config)(handler)
|
|
|
|
e := echo.New()
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
if tc.givenRequestFunc != nil {
|
|
req = tc.givenRequestFunc()
|
|
}
|
|
if tc.givenRequest != nil {
|
|
tc.givenRequest(req)
|
|
}
|
|
rec := httptest.NewRecorder()
|
|
c := e.NewContext(req, rec)
|
|
|
|
err := middlewareChain(c)
|
|
|
|
assert.Equal(t, tc.expectHandlerCalled, handlerCalled)
|
|
if tc.expectError != "" {
|
|
assert.EqualError(t, err, tc.expectError)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestKeyAuthWithConfig_panicsOnInvalidLookup(t *testing.T) {
|
|
assert.PanicsWithError(
|
|
t,
|
|
"extractor source for lookup could not be split into needed parts: a",
|
|
func() {
|
|
handler := func(c echo.Context) error {
|
|
return c.String(http.StatusOK, "test")
|
|
}
|
|
KeyAuthWithConfig(KeyAuthConfig{
|
|
Validator: testKeyValidator,
|
|
KeyLookup: "a",
|
|
})(handler)
|
|
},
|
|
)
|
|
}
|
|
|
|
func TestKeyAuthWithConfig_panicsOnEmptyValidator(t *testing.T) {
|
|
assert.PanicsWithValue(
|
|
t,
|
|
"echo: key-auth middleware requires a validator function",
|
|
func() {
|
|
handler := func(c echo.Context) error {
|
|
return c.String(http.StatusOK, "test")
|
|
}
|
|
KeyAuthWithConfig(KeyAuthConfig{
|
|
Validator: nil,
|
|
})(handler)
|
|
},
|
|
)
|
|
}
|
|
|
|
func TestKeyAuthWithConfig_ContinueOnIgnoredError(t *testing.T) {
|
|
var testCases = []struct {
|
|
name string
|
|
whenContinueOnIgnoredError bool
|
|
givenKey string
|
|
expectStatus int
|
|
expectBody string
|
|
}{
|
|
{
|
|
name: "no error handler is called",
|
|
whenContinueOnIgnoredError: true,
|
|
givenKey: "valid-key",
|
|
expectStatus: http.StatusTeapot,
|
|
expectBody: "",
|
|
},
|
|
{
|
|
name: "ContinueOnIgnoredError is false and error handler is called for missing token",
|
|
whenContinueOnIgnoredError: false,
|
|
givenKey: "",
|
|
// empty response with 200. This emulates previous behaviour when error handler swallowed the error
|
|
expectStatus: http.StatusOK,
|
|
expectBody: "",
|
|
},
|
|
{
|
|
name: "error handler is called for missing token",
|
|
whenContinueOnIgnoredError: true,
|
|
givenKey: "",
|
|
expectStatus: http.StatusTeapot,
|
|
expectBody: "public-auth",
|
|
},
|
|
{
|
|
name: "error handler is called for invalid token",
|
|
whenContinueOnIgnoredError: true,
|
|
givenKey: "x.x.x",
|
|
expectStatus: http.StatusUnauthorized,
|
|
expectBody: "{\"message\":\"Unauthorized\"}\n",
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
e := echo.New()
|
|
|
|
e.GET("/", func(c echo.Context) error {
|
|
testValue, _ := c.Get("test").(string)
|
|
return c.String(http.StatusTeapot, testValue)
|
|
})
|
|
|
|
e.Use(KeyAuthWithConfig(KeyAuthConfig{
|
|
Validator: testKeyValidator,
|
|
ErrorHandler: func(err error, c echo.Context) error {
|
|
if _, ok := err.(*ErrKeyAuthMissing); ok {
|
|
c.Set("test", "public-auth")
|
|
return nil
|
|
}
|
|
return echo.ErrUnauthorized
|
|
},
|
|
KeyLookup: "header:X-API-Key",
|
|
ContinueOnIgnoredError: tc.whenContinueOnIgnoredError,
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
if tc.givenKey != "" {
|
|
req.Header.Set("X-API-Key", tc.givenKey)
|
|
}
|
|
res := httptest.NewRecorder()
|
|
|
|
e.ServeHTTP(res, req)
|
|
|
|
assert.Equal(t, tc.expectStatus, res.Code)
|
|
assert.Equal(t, tc.expectBody, res.Body.String())
|
|
})
|
|
}
|
|
}
|