package middleware import ( "errors" "net/http" "net/http/httptest" "strings" "testing" "github.com/labstack/echo/v4" "github.com/stretchr/testify/assert" ) func testKeyValidator(key string, c echo.Context) (bool, error) { switch key { case "valid-key": return true, nil case "error-key": return false, errors.New("some user defined error") default: return false, nil } } func TestKeyAuth(t *testing.T) { handlerCalled := false handler := func(c echo.Context) error { handlerCalled = true return c.String(http.StatusOK, "test") } middlewareChain := KeyAuth(testKeyValidator)(handler) e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(echo.HeaderAuthorization, "Bearer valid-key") rec := httptest.NewRecorder() c := e.NewContext(req, rec) err := middlewareChain(c) assert.NoError(t, err) assert.True(t, handlerCalled) } func TestKeyAuthWithConfig(t *testing.T) { var testCases = []struct { name string givenRequestFunc func() *http.Request givenRequest func(req *http.Request) whenConfig func(conf *KeyAuthConfig) expectHandlerCalled bool expectError string }{ { name: "ok, defaults, key from header", givenRequest: func(req *http.Request) { req.Header.Set(echo.HeaderAuthorization, "Bearer valid-key") }, expectHandlerCalled: true, }, { name: "ok, custom skipper", givenRequest: func(req *http.Request) { req.Header.Set(echo.HeaderAuthorization, "Bearer error-key") }, whenConfig: func(conf *KeyAuthConfig) { conf.Skipper = func(context echo.Context) bool { return true } }, expectHandlerCalled: true, }, { name: "nok, defaults, invalid key from header, Authorization: Bearer", givenRequest: func(req *http.Request) { req.Header.Set(echo.HeaderAuthorization, "Bearer invalid-key") }, expectHandlerCalled: false, expectError: "code=401, message=Unauthorized, internal=invalid key", }, { name: "nok, defaults, invalid scheme in header", givenRequest: func(req *http.Request) { req.Header.Set(echo.HeaderAuthorization, "Bear valid-key") }, expectHandlerCalled: false, expectError: "code=400, message=invalid key in the request header", }, { name: "nok, defaults, missing header", givenRequest: func(req *http.Request) {}, expectHandlerCalled: false, expectError: "code=400, message=missing key in request header", }, { name: "ok, custom key lookup from multiple places, query and header", givenRequest: func(req *http.Request) { req.URL.RawQuery = "key=invalid-key" req.Header.Set("API-Key", "valid-key") }, whenConfig: func(conf *KeyAuthConfig) { conf.KeyLookup = "query:key,header:API-Key" }, expectHandlerCalled: true, }, { name: "ok, custom key lookup, header", givenRequest: func(req *http.Request) { req.Header.Set("API-Key", "valid-key") }, whenConfig: func(conf *KeyAuthConfig) { conf.KeyLookup = "header:API-Key" }, expectHandlerCalled: true, }, { name: "nok, custom key lookup, missing header", givenRequest: func(req *http.Request) { }, whenConfig: func(conf *KeyAuthConfig) { conf.KeyLookup = "header:API-Key" }, expectHandlerCalled: false, expectError: "code=400, message=missing key in request header", }, { name: "ok, custom key lookup, query", givenRequest: func(req *http.Request) { q := req.URL.Query() q.Add("key", "valid-key") req.URL.RawQuery = q.Encode() }, whenConfig: func(conf *KeyAuthConfig) { conf.KeyLookup = "query:key" }, expectHandlerCalled: true, }, { name: "nok, custom key lookup, missing query param", whenConfig: func(conf *KeyAuthConfig) { conf.KeyLookup = "query:key" }, expectHandlerCalled: false, expectError: "code=400, message=missing key in the query string", }, { name: "ok, custom key lookup, form", givenRequestFunc: func() *http.Request { req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("key=valid-key")) req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm) return req }, whenConfig: func(conf *KeyAuthConfig) { conf.KeyLookup = "form:key" }, expectHandlerCalled: true, }, { name: "nok, custom key lookup, missing key in form", givenRequestFunc: func() *http.Request { req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("xxx=valid-key")) req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm) return req }, whenConfig: func(conf *KeyAuthConfig) { conf.KeyLookup = "form:key" }, expectHandlerCalled: false, expectError: "code=400, message=missing key in the form", }, { name: "ok, custom key lookup, cookie", givenRequest: func(req *http.Request) { req.AddCookie(&http.Cookie{ Name: "key", Value: "valid-key", }) q := req.URL.Query() q.Add("key", "valid-key") req.URL.RawQuery = q.Encode() }, whenConfig: func(conf *KeyAuthConfig) { conf.KeyLookup = "cookie:key" }, expectHandlerCalled: true, }, { name: "nok, custom key lookup, missing cookie param", whenConfig: func(conf *KeyAuthConfig) { conf.KeyLookup = "cookie:key" }, expectHandlerCalled: false, expectError: "code=400, message=missing key in cookies", }, { name: "nok, custom errorHandler, error from extractor", whenConfig: func(conf *KeyAuthConfig) { conf.KeyLookup = "header:token" conf.ErrorHandler = func(err error, context echo.Context) error { httpError := echo.NewHTTPError(http.StatusTeapot, "custom") httpError.Internal = err return httpError } }, expectHandlerCalled: false, expectError: "code=418, message=custom, internal=missing key in request header", }, { name: "nok, custom errorHandler, error from validator", givenRequest: func(req *http.Request) { req.Header.Set(echo.HeaderAuthorization, "Bearer error-key") }, whenConfig: func(conf *KeyAuthConfig) { conf.ErrorHandler = func(err error, context echo.Context) error { httpError := echo.NewHTTPError(http.StatusTeapot, "custom") httpError.Internal = err return httpError } }, expectHandlerCalled: false, expectError: "code=418, message=custom, internal=some user defined error", }, { name: "nok, defaults, error from validator", givenRequest: func(req *http.Request) { req.Header.Set(echo.HeaderAuthorization, "Bearer error-key") }, whenConfig: func(conf *KeyAuthConfig) {}, expectHandlerCalled: false, expectError: "code=401, message=Unauthorized, internal=some user defined error", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { handlerCalled := false handler := func(c echo.Context) error { handlerCalled = true return c.String(http.StatusOK, "test") } config := KeyAuthConfig{ Validator: testKeyValidator, } if tc.whenConfig != nil { tc.whenConfig(&config) } middlewareChain := KeyAuthWithConfig(config)(handler) e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) if tc.givenRequestFunc != nil { req = tc.givenRequestFunc() } if tc.givenRequest != nil { tc.givenRequest(req) } rec := httptest.NewRecorder() c := e.NewContext(req, rec) err := middlewareChain(c) assert.Equal(t, tc.expectHandlerCalled, handlerCalled) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { assert.NoError(t, err) } }) } } func TestKeyAuthWithConfig_panicsOnInvalidLookup(t *testing.T) { assert.PanicsWithError( t, "extractor source for lookup could not be split into needed parts: a", func() { handler := func(c echo.Context) error { return c.String(http.StatusOK, "test") } KeyAuthWithConfig(KeyAuthConfig{ Validator: testKeyValidator, KeyLookup: "a", })(handler) }, ) } func TestKeyAuthWithConfig_panicsOnEmptyValidator(t *testing.T) { assert.PanicsWithValue( t, "echo: key-auth middleware requires a validator function", func() { handler := func(c echo.Context) error { return c.String(http.StatusOK, "test") } KeyAuthWithConfig(KeyAuthConfig{ Validator: nil, })(handler) }, ) } func TestKeyAuthWithConfig_ContinueOnIgnoredError(t *testing.T) { var testCases = []struct { name string whenContinueOnIgnoredError bool givenKey string expectStatus int expectBody string }{ { name: "no error handler is called", whenContinueOnIgnoredError: true, givenKey: "valid-key", expectStatus: http.StatusTeapot, expectBody: "", }, { name: "ContinueOnIgnoredError is false and error handler is called for missing token", whenContinueOnIgnoredError: false, givenKey: "", // empty response with 200. This emulates previous behaviour when error handler swallowed the error expectStatus: http.StatusOK, expectBody: "", }, { name: "error handler is called for missing token", whenContinueOnIgnoredError: true, givenKey: "", expectStatus: http.StatusTeapot, expectBody: "public-auth", }, { name: "error handler is called for invalid token", whenContinueOnIgnoredError: true, givenKey: "x.x.x", expectStatus: http.StatusUnauthorized, expectBody: "{\"message\":\"Unauthorized\"}\n", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { e := echo.New() e.GET("/", func(c echo.Context) error { testValue, _ := c.Get("test").(string) return c.String(http.StatusTeapot, testValue) }) e.Use(KeyAuthWithConfig(KeyAuthConfig{ Validator: testKeyValidator, ErrorHandler: func(err error, c echo.Context) error { if _, ok := err.(*ErrKeyAuthMissing); ok { c.Set("test", "public-auth") return nil } return echo.ErrUnauthorized }, KeyLookup: "header:X-API-Key", ContinueOnIgnoredError: tc.whenContinueOnIgnoredError, })) req := httptest.NewRequest(http.MethodGet, "/", nil) if tc.givenKey != "" { req.Header.Set("X-API-Key", tc.givenKey) } res := httptest.NewRecorder() e.ServeHTTP(res, req) assert.Equal(t, tc.expectStatus, res.Code) assert.Equal(t, tc.expectBody, res.Body.String()) }) } }