From 4a1ccdfdc520eb90573a97a7d04fd9fc300c1629 Mon Sep 17 00:00:00 2001 From: Martti T Date: Mon, 24 Jan 2022 22:03:45 +0200 Subject: [PATCH] JWT, KeyAuth, CSRF multivalue extractors (#2060) * CSRF, JWT, KeyAuth middleware support for multivalue value extractors * Add flag to JWT and KeyAuth middleware to allow continuing execution `next(c)` when error handler decides to swallow the error (returns nil). --- echo.go | 4 +- middleware/csrf.go | 110 +++---- middleware/csrf_test.go | 276 +++++++++++++--- middleware/extractor.go | 184 +++++++++++ middleware/extractor_test.go | 587 +++++++++++++++++++++++++++++++++++ middleware/jwt.go | 190 ++++-------- middleware/jwt_test.go | 310 +++++++++++++----- middleware/key_auth.go | 173 +++++------ middleware/key_auth_test.go | 123 +++++++- middleware/middleware.go | 4 +- 10 files changed, 1564 insertions(+), 397 deletions(-) create mode 100644 middleware/extractor.go create mode 100644 middleware/extractor_test.go diff --git a/echo.go b/echo.go index 56255c6c..2e63cc6b 100644 --- a/echo.go +++ b/echo.go @@ -111,10 +111,10 @@ type ( } // MiddlewareFunc defines a function to process middleware. - MiddlewareFunc func(HandlerFunc) HandlerFunc + MiddlewareFunc func(next HandlerFunc) HandlerFunc // HandlerFunc defines a function to serve HTTP requests. - HandlerFunc func(Context) error + HandlerFunc func(c Context) error // HTTPErrorHandler is a centralized HTTP error handler. HTTPErrorHandler func(error, Context) diff --git a/middleware/csrf.go b/middleware/csrf.go index 7804997d..61299f5c 100644 --- a/middleware/csrf.go +++ b/middleware/csrf.go @@ -2,9 +2,7 @@ package middleware import ( "crypto/subtle" - "errors" "net/http" - "strings" "time" "github.com/labstack/echo/v4" @@ -21,13 +19,15 @@ type ( TokenLength uint8 `yaml:"token_length"` // Optional. Default value 32. - // TokenLookup is a string in the form of ":" that is used + // TokenLookup is a string in the form of ":" or ":,:" that is used // to extract token from the request. // Optional. Default value "header:X-CSRF-Token". // Possible values: - // - "header:" - // - "form:" + // - "header:" or "header::" // - "query:" + // - "form:" + // Multiple sources example: + // - "header:X-CSRF-Token,query:csrf" TokenLookup string `yaml:"token_lookup"` // Context key to store generated CSRF token into context. @@ -62,12 +62,11 @@ type ( // Optional. Default value SameSiteDefaultMode. CookieSameSite http.SameSite `yaml:"cookie_same_site"` } - - // csrfTokenExtractor defines a function that takes `echo.Context` and returns - // either a token or an error. - csrfTokenExtractor func(echo.Context) (string, error) ) +// ErrCSRFInvalid is returned when CSRF check fails +var ErrCSRFInvalid = echo.NewHTTPError(http.StatusForbidden, "invalid csrf token") + var ( // DefaultCSRFConfig is the default CSRF middleware config. DefaultCSRFConfig = CSRFConfig{ @@ -114,14 +113,9 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { config.CookieSecure = true } - // Initialize - parts := strings.Split(config.TokenLookup, ":") - extractor := csrfTokenFromHeader(parts[1]) - switch parts[0] { - case "form": - extractor = csrfTokenFromForm(parts[1]) - case "query": - extractor = csrfTokenFromQuery(parts[1]) + extractors, err := createExtractors(config.TokenLookup, "") + if err != nil { + panic(err) } return func(next echo.HandlerFunc) echo.HandlerFunc { @@ -130,28 +124,50 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { return next(c) } - req := c.Request() - k, err := c.Cookie(config.CookieName) token := "" - - // Generate token - if err != nil { - token = random.String(config.TokenLength) + if k, err := c.Cookie(config.CookieName); err != nil { + token = random.String(config.TokenLength) // Generate token } else { - // Reuse token - token = k.Value + token = k.Value // Reuse token } - switch req.Method { + switch c.Request().Method { case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace: default: // Validate token only for requests which are not defined as 'safe' by RFC7231 - clientToken, err := extractor(c) - if err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + var lastExtractorErr error + var lastTokenErr error + outer: + for _, extractor := range extractors { + clientTokens, err := extractor(c) + if err != nil { + lastExtractorErr = err + continue + } + + for _, clientToken := range clientTokens { + if validateCSRFToken(token, clientToken) { + lastTokenErr = nil + lastExtractorErr = nil + break outer + } + lastTokenErr = ErrCSRFInvalid + } } - if !validateCSRFToken(token, clientToken) { - return echo.NewHTTPError(http.StatusForbidden, "invalid csrf token") + if lastTokenErr != nil { + return lastTokenErr + } else if lastExtractorErr != nil { + // ugly part to preserve backwards compatible errors. someone could rely on them + if lastExtractorErr == errQueryExtractorValueMissing { + lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in the query string") + } else if lastExtractorErr == errFormExtractorValueMissing { + lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in the form parameter") + } else if lastExtractorErr == errHeaderExtractorValueMissing { + lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in request header") + } else { + lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, lastExtractorErr.Error()) + } + return lastExtractorErr } } @@ -184,38 +200,6 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { } } -// csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the -// provided request header. -func csrfTokenFromHeader(header string) csrfTokenExtractor { - return func(c echo.Context) (string, error) { - return c.Request().Header.Get(header), nil - } -} - -// csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the -// provided form parameter. -func csrfTokenFromForm(param string) csrfTokenExtractor { - return func(c echo.Context) (string, error) { - token := c.FormValue(param) - if token == "" { - return "", errors.New("missing csrf token in the form parameter") - } - return token, nil - } -} - -// csrfTokenFromQuery returns a `csrfTokenExtractor` that extracts token from the -// provided query parameter. -func csrfTokenFromQuery(param string) csrfTokenExtractor { - return func(c echo.Context) (string, error) { - token := c.QueryParam(param) - if token == "" { - return "", errors.New("missing csrf token in the query string") - } - return token, nil - } -} - func validateCSRFToken(token, clientToken string) bool { return subtle.ConstantTimeCompare([]byte(token), []byte(clientToken)) == 1 } diff --git a/middleware/csrf_test.go b/middleware/csrf_test.go index af1d2639..9aff82a9 100644 --- a/middleware/csrf_test.go +++ b/middleware/csrf_test.go @@ -1,7 +1,6 @@ package middleware import ( - "fmt" "net/http" "net/http/httptest" "net/url" @@ -13,14 +12,205 @@ import ( "github.com/stretchr/testify/assert" ) +func TestCSRF_tokenExtractors(t *testing.T) { + var testCases = []struct { + name string + whenTokenLookup string + whenCookieName string + givenCSRFCookie string + givenMethod string + givenQueryTokens map[string][]string + givenFormTokens map[string][]string + givenHeaderTokens map[string][]string + expectError string + }{ + { + name: "ok, multiple token lookups sources, succeeds on last one", + whenTokenLookup: "header:X-CSRF-Token,form:csrf", + givenCSRFCookie: "token", + givenMethod: http.MethodPost, + givenHeaderTokens: map[string][]string{ + echo.HeaderXCSRFToken: {"invalid_token"}, + }, + givenFormTokens: map[string][]string{ + "csrf": {"token"}, + }, + }, + { + name: "ok, token from POST form", + whenTokenLookup: "form:csrf", + givenCSRFCookie: "token", + givenMethod: http.MethodPost, + givenFormTokens: map[string][]string{ + "csrf": {"token"}, + }, + }, + { + name: "ok, token from POST form, second token passes", + whenTokenLookup: "form:csrf", + givenCSRFCookie: "token", + givenMethod: http.MethodPost, + givenFormTokens: map[string][]string{ + "csrf": {"invalid", "token"}, + }, + }, + { + name: "nok, invalid token from POST form", + whenTokenLookup: "form:csrf", + givenCSRFCookie: "token", + givenMethod: http.MethodPost, + givenFormTokens: map[string][]string{ + "csrf": {"invalid_token"}, + }, + expectError: "code=403, message=invalid csrf token", + }, + { + name: "nok, missing token from POST form", + whenTokenLookup: "form:csrf", + givenCSRFCookie: "token", + givenMethod: http.MethodPost, + givenFormTokens: map[string][]string{}, + expectError: "code=400, message=missing csrf token in the form parameter", + }, + { + name: "ok, token from POST header", + whenTokenLookup: "", // will use defaults + givenCSRFCookie: "token", + givenMethod: http.MethodPost, + givenHeaderTokens: map[string][]string{ + echo.HeaderXCSRFToken: {"token"}, + }, + }, + { + name: "ok, token from POST header, second token passes", + whenTokenLookup: "header:" + echo.HeaderXCSRFToken, + givenCSRFCookie: "token", + givenMethod: http.MethodPost, + givenHeaderTokens: map[string][]string{ + echo.HeaderXCSRFToken: {"invalid", "token"}, + }, + }, + { + name: "nok, invalid token from POST header", + whenTokenLookup: "header:" + echo.HeaderXCSRFToken, + givenCSRFCookie: "token", + givenMethod: http.MethodPost, + givenHeaderTokens: map[string][]string{ + echo.HeaderXCSRFToken: {"invalid_token"}, + }, + expectError: "code=403, message=invalid csrf token", + }, + { + name: "nok, missing token from POST header", + whenTokenLookup: "header:" + echo.HeaderXCSRFToken, + givenCSRFCookie: "token", + givenMethod: http.MethodPost, + givenHeaderTokens: map[string][]string{}, + expectError: "code=400, message=missing csrf token in request header", + }, + { + name: "ok, token from PUT query param", + whenTokenLookup: "query:csrf-param", + givenCSRFCookie: "token", + givenMethod: http.MethodPut, + givenQueryTokens: map[string][]string{ + "csrf-param": {"token"}, + }, + }, + { + name: "ok, token from PUT query form, second token passes", + whenTokenLookup: "query:csrf", + givenCSRFCookie: "token", + givenMethod: http.MethodPut, + givenQueryTokens: map[string][]string{ + "csrf": {"invalid", "token"}, + }, + }, + { + name: "nok, invalid token from PUT query form", + whenTokenLookup: "query:csrf", + givenCSRFCookie: "token", + givenMethod: http.MethodPut, + givenQueryTokens: map[string][]string{ + "csrf": {"invalid_token"}, + }, + expectError: "code=403, message=invalid csrf token", + }, + { + name: "nok, missing token from PUT query form", + whenTokenLookup: "query:csrf", + givenCSRFCookie: "token", + givenMethod: http.MethodPut, + givenQueryTokens: map[string][]string{}, + expectError: "code=400, message=missing csrf token in the query string", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + q := make(url.Values) + for queryParam, values := range tc.givenQueryTokens { + for _, v := range values { + q.Add(queryParam, v) + } + } + + f := make(url.Values) + for formKey, values := range tc.givenFormTokens { + for _, v := range values { + f.Add(formKey, v) + } + } + + var req *http.Request + switch tc.givenMethod { + case http.MethodGet: + req = httptest.NewRequest(http.MethodGet, "/?"+q.Encode(), nil) + case http.MethodPost, http.MethodPut: + req = httptest.NewRequest(http.MethodPost, "/?"+q.Encode(), strings.NewReader(f.Encode())) + req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm) + } + + for header, values := range tc.givenHeaderTokens { + for _, v := range values { + req.Header.Add(header, v) + } + } + + if tc.givenCSRFCookie != "" { + req.Header.Set(echo.HeaderCookie, "_csrf="+tc.givenCSRFCookie) + } + + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + csrf := CSRFWithConfig(CSRFConfig{ + TokenLookup: tc.whenTokenLookup, + CookieName: tc.whenCookieName, + }) + + h := csrf(func(c echo.Context) error { + return c.String(http.StatusOK, "test") + }) + + err := h(c) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + func TestCSRF(t *testing.T) { e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - csrf := CSRFWithConfig(CSRFConfig{ - TokenLength: 16, - }) + csrf := CSRF() h := csrf(func(c echo.Context) error { return c.String(http.StatusOK, "test") }) @@ -43,7 +233,7 @@ func TestCSRF(t *testing.T) { assert.Error(t, h(c)) // Valid CSRF token - token := random.String(16) + token := random.String(32) req.Header.Set(echo.HeaderCookie, "_csrf="+token) req.Header.Set(echo.HeaderXCSRFToken, token) if assert.NoError(t, h(c)) { @@ -51,38 +241,6 @@ func TestCSRF(t *testing.T) { } } -func TestCSRFTokenFromForm(t *testing.T) { - f := make(url.Values) - f.Set("csrf", "token") - e := echo.New() - req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode())) - req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm) - c := e.NewContext(req, nil) - token, err := csrfTokenFromForm("csrf")(c) - if assert.NoError(t, err) { - assert.Equal(t, "token", token) - } - _, err = csrfTokenFromForm("invalid")(c) - assert.Error(t, err) -} - -func TestCSRFTokenFromQuery(t *testing.T) { - q := make(url.Values) - q.Set("csrf", "token") - e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm) - req.URL.RawQuery = q.Encode() - c := e.NewContext(req, nil) - token, err := csrfTokenFromQuery("csrf")(c) - if assert.NoError(t, err) { - assert.Equal(t, "token", token) - } - _, err = csrfTokenFromQuery("invalid")(c) - assert.Error(t, err) - csrfTokenFromQuery("csrf") -} - func TestCSRFSetSameSiteMode(t *testing.T) { e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) @@ -135,7 +293,6 @@ func TestCSRFWithSameSiteDefaultMode(t *testing.T) { r := h(c) assert.NoError(t, r) - fmt.Println(rec.Header()["Set-Cookie"]) assert.NotRegexp(t, "SameSite=", rec.Header()["Set-Cookie"]) } @@ -158,3 +315,46 @@ func TestCSRFWithSameSiteModeNone(t *testing.T) { assert.Regexp(t, "SameSite=None", rec.Header()["Set-Cookie"]) assert.Regexp(t, "Secure", rec.Header()["Set-Cookie"]) } + +func TestCSRFConfig_skipper(t *testing.T) { + var testCases = []struct { + name string + whenSkip bool + expectCookies int + }{ + { + name: "do skip", + whenSkip: true, + expectCookies: 0, + }, + { + name: "do not skip", + whenSkip: false, + expectCookies: 1, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + csrf := CSRFWithConfig(CSRFConfig{ + Skipper: func(c echo.Context) bool { + return tc.whenSkip + }, + }) + + h := csrf(func(c echo.Context) error { + return c.String(http.StatusOK, "test") + }) + + r := h(c) + assert.NoError(t, r) + cookie := rec.Header()["Set-Cookie"] + assert.Len(t, cookie, tc.expectCookies) + }) + } +} diff --git a/middleware/extractor.go b/middleware/extractor.go new file mode 100644 index 00000000..a57ed4e1 --- /dev/null +++ b/middleware/extractor.go @@ -0,0 +1,184 @@ +package middleware + +import ( + "errors" + "fmt" + "github.com/labstack/echo/v4" + "net/textproto" + "strings" +) + +const ( + // extractorLimit is arbitrary number to limit values extractor can return. this limits possible resource exhaustion + // attack vector + extractorLimit = 20 +) + +var errHeaderExtractorValueMissing = errors.New("missing value in request header") +var errHeaderExtractorValueInvalid = errors.New("invalid value in request header") +var errQueryExtractorValueMissing = errors.New("missing value in the query string") +var errParamExtractorValueMissing = errors.New("missing value in path params") +var errCookieExtractorValueMissing = errors.New("missing value in cookies") +var errFormExtractorValueMissing = errors.New("missing value in the form") + +// ValuesExtractor defines a function for extracting values (keys/tokens) from the given context. +type ValuesExtractor func(c echo.Context) ([]string, error) + +func createExtractors(lookups string, authScheme string) ([]ValuesExtractor, error) { + if lookups == "" { + return nil, nil + } + sources := strings.Split(lookups, ",") + var extractors = make([]ValuesExtractor, 0) + for _, source := range sources { + parts := strings.Split(source, ":") + if len(parts) < 2 { + return nil, fmt.Errorf("extractor source for lookup could not be split into needed parts: %v", source) + } + + switch parts[0] { + case "query": + extractors = append(extractors, valuesFromQuery(parts[1])) + case "param": + extractors = append(extractors, valuesFromParam(parts[1])) + case "cookie": + extractors = append(extractors, valuesFromCookie(parts[1])) + case "form": + extractors = append(extractors, valuesFromForm(parts[1])) + case "header": + prefix := "" + if len(parts) > 2 { + prefix = parts[2] + } else if authScheme != "" && parts[1] == echo.HeaderAuthorization { + // backwards compatibility for JWT and KeyAuth: + // * we only apply this fix to Authorization as header we use and uses prefixes like "Bearer " etc + // * previously header extractor assumed that auth-scheme/prefix had a space as suffix we need to retain that + // behaviour for default values and Authorization header. + prefix = authScheme + if !strings.HasSuffix(prefix, " ") { + prefix += " " + } + } + extractors = append(extractors, valuesFromHeader(parts[1], prefix)) + } + } + return extractors, nil +} + +// valuesFromHeader returns a functions that extracts values from the request header. +// valuePrefix is parameter to remove first part (prefix) of the extracted value. This is useful if header value has static +// prefix like `Authorization: ` where part that we want to remove is ` ` +// note the space at the end. In case of basic authentication `Authorization: Basic ` prefix we want to remove +// is `Basic `. In case of JWT tokens `Authorization: Bearer ` prefix is `Bearer `. +// If prefix is left empty the whole value is returned. +func valuesFromHeader(header string, valuePrefix string) ValuesExtractor { + prefixLen := len(valuePrefix) + // standard library parses http.Request header keys in canonical form but we may provide something else so fix this + header = textproto.CanonicalMIMEHeaderKey(header) + return func(c echo.Context) ([]string, error) { + values := c.Request().Header.Values(header) + if len(values) == 0 { + return nil, errHeaderExtractorValueMissing + } + + result := make([]string, 0) + for i, value := range values { + if prefixLen == 0 { + result = append(result, value) + if i >= extractorLimit-1 { + break + } + continue + } + if len(value) > prefixLen && strings.EqualFold(value[:prefixLen], valuePrefix) { + result = append(result, value[prefixLen:]) + if i >= extractorLimit-1 { + break + } + } + } + + if len(result) == 0 { + if prefixLen > 0 { + return nil, errHeaderExtractorValueInvalid + } + return nil, errHeaderExtractorValueMissing + } + return result, nil + } +} + +// valuesFromQuery returns a function that extracts values from the query string. +func valuesFromQuery(param string) ValuesExtractor { + return func(c echo.Context) ([]string, error) { + result := c.QueryParams()[param] + if len(result) == 0 { + return nil, errQueryExtractorValueMissing + } else if len(result) > extractorLimit-1 { + result = result[:extractorLimit] + } + return result, nil + } +} + +// valuesFromParam returns a function that extracts values from the url param string. +func valuesFromParam(param string) ValuesExtractor { + return func(c echo.Context) ([]string, error) { + result := make([]string, 0) + paramVales := c.ParamValues() + for i, p := range c.ParamNames() { + if param == p { + result = append(result, paramVales[i]) + if i >= extractorLimit-1 { + break + } + } + } + if len(result) == 0 { + return nil, errParamExtractorValueMissing + } + return result, nil + } +} + +// valuesFromCookie returns a function that extracts values from the named cookie. +func valuesFromCookie(name string) ValuesExtractor { + return func(c echo.Context) ([]string, error) { + cookies := c.Cookies() + if len(cookies) == 0 { + return nil, errCookieExtractorValueMissing + } + + result := make([]string, 0) + for i, cookie := range cookies { + if name == cookie.Name { + result = append(result, cookie.Value) + if i >= extractorLimit-1 { + break + } + } + } + if len(result) == 0 { + return nil, errCookieExtractorValueMissing + } + return result, nil + } +} + +// valuesFromForm returns a function that extracts values from the form field. +func valuesFromForm(name string) ValuesExtractor { + return func(c echo.Context) ([]string, error) { + if parseErr := c.Request().ParseForm(); parseErr != nil { + return nil, fmt.Errorf("valuesFromForm parse form failed: %w", parseErr) + } + values := c.Request().Form[name] + if len(values) == 0 { + return nil, errFormExtractorValueMissing + } + if len(values) > extractorLimit-1 { + values = values[:extractorLimit] + } + result := append([]string{}, values...) + return result, nil + } +} diff --git a/middleware/extractor_test.go b/middleware/extractor_test.go new file mode 100644 index 00000000..ae4b30a8 --- /dev/null +++ b/middleware/extractor_test.go @@ -0,0 +1,587 @@ +package middleware + +import ( + "fmt" + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" +) + +type pathParam struct { + name string + value string +} + +func setPathParams(c echo.Context, params []pathParam) { + names := make([]string, 0, len(params)) + values := make([]string, 0, len(params)) + for _, pp := range params { + names = append(names, pp.name) + values = append(values, pp.value) + } + c.SetParamNames(names...) + c.SetParamValues(values...) +} + +func TestCreateExtractors(t *testing.T) { + var testCases = []struct { + name string + givenRequest func() *http.Request + givenPathParams []pathParam + whenLoopups string + expectValues []string + expectCreateError string + expectError string + }{ + { + name: "ok, header", + givenRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAuthorization, "Bearer token") + return req + }, + whenLoopups: "header:Authorization:Bearer ", + expectValues: []string{"token"}, + }, + { + name: "ok, form", + givenRequest: func() *http.Request { + f := make(url.Values) + f.Set("name", "Jon Snow") + + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode())) + req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm) + return req + }, + whenLoopups: "form:name", + expectValues: []string{"Jon Snow"}, + }, + { + name: "ok, cookie", + givenRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderCookie, "_csrf=token") + return req + }, + whenLoopups: "cookie:_csrf", + expectValues: []string{"token"}, + }, + { + name: "ok, param", + givenPathParams: []pathParam{ + {name: "id", value: "123"}, + }, + whenLoopups: "param:id", + expectValues: []string{"123"}, + }, + { + name: "ok, query", + givenRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/?id=999", nil) + return req + }, + whenLoopups: "query:id", + expectValues: []string{"999"}, + }, + { + name: "nok, invalid lookup", + whenLoopups: "query", + expectCreateError: "extractor source for lookup could not be split into needed parts: query", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + req := httptest.NewRequest(http.MethodGet, "/", nil) + if tc.givenRequest != nil { + req = tc.givenRequest() + } + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + if tc.givenPathParams != nil { + setPathParams(c, tc.givenPathParams) + } + + extractors, err := createExtractors(tc.whenLoopups, "") + if tc.expectCreateError != "" { + assert.EqualError(t, err, tc.expectCreateError) + return + } + assert.NoError(t, err) + + for _, e := range extractors { + values, eErr := e(c) + assert.Equal(t, tc.expectValues, values) + if tc.expectError != "" { + assert.EqualError(t, eErr, tc.expectError) + return + } + assert.NoError(t, eErr) + } + }) + } +} + +func TestValuesFromHeader(t *testing.T) { + exampleRequest := func(req *http.Request) { + req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==") + } + + var testCases = []struct { + name string + givenRequest func(req *http.Request) + whenName string + whenValuePrefix string + expectValues []string + expectError string + }{ + { + name: "ok, single value", + givenRequest: exampleRequest, + whenName: echo.HeaderAuthorization, + whenValuePrefix: "basic ", + expectValues: []string{"dXNlcjpwYXNzd29yZA=="}, + }, + { + name: "ok, single value, case insensitive", + givenRequest: exampleRequest, + whenName: echo.HeaderAuthorization, + whenValuePrefix: "Basic ", + expectValues: []string{"dXNlcjpwYXNzd29yZA=="}, + }, + { + name: "ok, multiple value", + givenRequest: func(req *http.Request) { + req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==") + req.Header.Add(echo.HeaderAuthorization, "basic dGVzdDp0ZXN0") + }, + whenName: echo.HeaderAuthorization, + whenValuePrefix: "basic ", + expectValues: []string{"dXNlcjpwYXNzd29yZA==", "dGVzdDp0ZXN0"}, + }, + { + name: "ok, empty prefix", + givenRequest: exampleRequest, + whenName: echo.HeaderAuthorization, + whenValuePrefix: "", + expectValues: []string{"basic dXNlcjpwYXNzd29yZA=="}, + }, + { + name: "nok, no matching due different prefix", + givenRequest: func(req *http.Request) { + req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==") + req.Header.Add(echo.HeaderAuthorization, "basic dGVzdDp0ZXN0") + }, + whenName: echo.HeaderAuthorization, + whenValuePrefix: "Bearer ", + expectError: errHeaderExtractorValueInvalid.Error(), + }, + { + name: "nok, no matching due different prefix", + givenRequest: func(req *http.Request) { + req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==") + req.Header.Add(echo.HeaderAuthorization, "basic dGVzdDp0ZXN0") + }, + whenName: echo.HeaderWWWAuthenticate, + whenValuePrefix: "", + expectError: errHeaderExtractorValueMissing.Error(), + }, + { + name: "nok, no headers", + givenRequest: nil, + whenName: echo.HeaderAuthorization, + whenValuePrefix: "basic ", + expectError: errHeaderExtractorValueMissing.Error(), + }, + { + name: "ok, prefix, cut values over extractorLimit", + givenRequest: func(req *http.Request) { + for i := 1; i <= 25; i++ { + req.Header.Add(echo.HeaderAuthorization, fmt.Sprintf("basic %v", i)) + } + }, + whenName: echo.HeaderAuthorization, + whenValuePrefix: "basic ", + expectValues: []string{ + "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", + "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", + }, + }, + { + name: "ok, cut values over extractorLimit", + givenRequest: func(req *http.Request) { + for i := 1; i <= 25; i++ { + req.Header.Add(echo.HeaderAuthorization, fmt.Sprintf("%v", i)) + } + }, + whenName: echo.HeaderAuthorization, + whenValuePrefix: "", + expectValues: []string{ + "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", + "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + req := httptest.NewRequest(http.MethodGet, "/", nil) + if tc.givenRequest != nil { + tc.givenRequest(req) + } + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + extractor := valuesFromHeader(tc.whenName, tc.whenValuePrefix) + + values, err := extractor(c) + assert.Equal(t, tc.expectValues, values) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValuesFromQuery(t *testing.T) { + var testCases = []struct { + name string + givenQueryPart string + whenName string + expectValues []string + expectError string + }{ + { + name: "ok, single value", + givenQueryPart: "?id=123&name=test", + whenName: "id", + expectValues: []string{"123"}, + }, + { + name: "ok, multiple value", + givenQueryPart: "?id=123&id=456&name=test", + whenName: "id", + expectValues: []string{"123", "456"}, + }, + { + name: "nok, missing value", + givenQueryPart: "?id=123&name=test", + whenName: "nope", + expectError: errQueryExtractorValueMissing.Error(), + }, + { + name: "ok, cut values over extractorLimit", + givenQueryPart: "?name=test" + + "&id=1&id=2&id=3&id=4&id=5&id=6&id=7&id=8&id=9&id=10" + + "&id=11&id=12&id=13&id=14&id=15&id=16&id=17&id=18&id=19&id=20" + + "&id=21&id=22&id=23&id=24&id=25", + whenName: "id", + expectValues: []string{ + "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", + "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + req := httptest.NewRequest(http.MethodGet, "/"+tc.givenQueryPart, nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + extractor := valuesFromQuery(tc.whenName) + + values, err := extractor(c) + assert.Equal(t, tc.expectValues, values) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValuesFromParam(t *testing.T) { + examplePathParams := []pathParam{ + {name: "id", value: "123"}, + {name: "gid", value: "456"}, + {name: "gid", value: "789"}, + } + examplePathParams20 := make([]pathParam, 0) + for i := 1; i < 25; i++ { + examplePathParams20 = append(examplePathParams20, pathParam{name: "id", value: fmt.Sprintf("%v", i)}) + } + + var testCases = []struct { + name string + givenPathParams []pathParam + whenName string + expectValues []string + expectError string + }{ + { + name: "ok, single value", + givenPathParams: examplePathParams, + whenName: "id", + expectValues: []string{"123"}, + }, + { + name: "ok, multiple value", + givenPathParams: examplePathParams, + whenName: "gid", + expectValues: []string{"456", "789"}, + }, + { + name: "nok, no values", + givenPathParams: nil, + whenName: "nope", + expectValues: nil, + expectError: errParamExtractorValueMissing.Error(), + }, + { + name: "nok, no matching value", + givenPathParams: examplePathParams, + whenName: "nope", + expectValues: nil, + expectError: errParamExtractorValueMissing.Error(), + }, + { + name: "ok, cut values over extractorLimit", + givenPathParams: examplePathParams20, + whenName: "id", + expectValues: []string{ + "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", + "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + if tc.givenPathParams != nil { + setPathParams(c, tc.givenPathParams) + } + + extractor := valuesFromParam(tc.whenName) + + values, err := extractor(c) + assert.Equal(t, tc.expectValues, values) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValuesFromCookie(t *testing.T) { + exampleRequest := func(req *http.Request) { + req.Header.Set(echo.HeaderCookie, "_csrf=token") + } + + var testCases = []struct { + name string + givenRequest func(req *http.Request) + whenName string + expectValues []string + expectError string + }{ + { + name: "ok, single value", + givenRequest: exampleRequest, + whenName: "_csrf", + expectValues: []string{"token"}, + }, + { + name: "ok, multiple value", + givenRequest: func(req *http.Request) { + req.Header.Add(echo.HeaderCookie, "_csrf=token") + req.Header.Add(echo.HeaderCookie, "_csrf=token2") + }, + whenName: "_csrf", + expectValues: []string{"token", "token2"}, + }, + { + name: "nok, no matching cookie", + givenRequest: exampleRequest, + whenName: "xxx", + expectValues: nil, + expectError: errCookieExtractorValueMissing.Error(), + }, + { + name: "nok, no cookies at all", + givenRequest: nil, + whenName: "xxx", + expectValues: nil, + expectError: errCookieExtractorValueMissing.Error(), + }, + { + name: "ok, cut values over extractorLimit", + givenRequest: func(req *http.Request) { + for i := 1; i < 25; i++ { + req.Header.Add(echo.HeaderCookie, fmt.Sprintf("_csrf=%v", i)) + } + }, + whenName: "_csrf", + expectValues: []string{ + "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", + "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + req := httptest.NewRequest(http.MethodGet, "/", nil) + if tc.givenRequest != nil { + tc.givenRequest(req) + } + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + extractor := valuesFromCookie(tc.whenName) + + values, err := extractor(c) + assert.Equal(t, tc.expectValues, values) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValuesFromForm(t *testing.T) { + examplePostFormRequest := func(mod func(v *url.Values)) *http.Request { + f := make(url.Values) + f.Set("name", "Jon Snow") + f.Set("emails[]", "jon@labstack.com") + if mod != nil { + mod(&f) + } + + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode())) + req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm) + + return req + } + exampleGetFormRequest := func(mod func(v *url.Values)) *http.Request { + f := make(url.Values) + f.Set("name", "Jon Snow") + f.Set("emails[]", "jon@labstack.com") + if mod != nil { + mod(&f) + } + + req := httptest.NewRequest(http.MethodGet, "/?"+f.Encode(), nil) + return req + } + + var testCases = []struct { + name string + givenRequest *http.Request + whenName string + expectValues []string + expectError string + }{ + { + name: "ok, POST form, single value", + givenRequest: examplePostFormRequest(nil), + whenName: "emails[]", + expectValues: []string{"jon@labstack.com"}, + }, + { + name: "ok, POST form, multiple value", + givenRequest: examplePostFormRequest(func(v *url.Values) { + v.Add("emails[]", "snow@labstack.com") + }), + whenName: "emails[]", + expectValues: []string{"jon@labstack.com", "snow@labstack.com"}, + }, + { + name: "ok, GET form, single value", + givenRequest: exampleGetFormRequest(nil), + whenName: "emails[]", + expectValues: []string{"jon@labstack.com"}, + }, + { + name: "ok, GET form, multiple value", + givenRequest: examplePostFormRequest(func(v *url.Values) { + v.Add("emails[]", "snow@labstack.com") + }), + whenName: "emails[]", + expectValues: []string{"jon@labstack.com", "snow@labstack.com"}, + }, + { + name: "nok, POST form, value missing", + givenRequest: examplePostFormRequest(nil), + whenName: "nope", + expectError: errFormExtractorValueMissing.Error(), + }, + { + name: "nok, POST form, form parsing error", + givenRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodPost, "/", nil) + req.Body = nil + return req + }(), + whenName: "name", + expectError: "valuesFromForm parse form failed: missing form body", + }, + { + name: "ok, cut values over extractorLimit", + givenRequest: examplePostFormRequest(func(v *url.Values) { + for i := 1; i < 25; i++ { + v.Add("id[]", fmt.Sprintf("%v", i)) + } + }), + whenName: "id[]", + expectValues: []string{ + "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", + "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + req := tc.givenRequest + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + extractor := valuesFromForm(tc.whenName) + + values, err := extractor(c) + assert.Equal(t, tc.expectValues, values) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/middleware/jwt.go b/middleware/jwt.go index 43605e37..bec5167e 100644 --- a/middleware/jwt.go +++ b/middleware/jwt.go @@ -1,3 +1,4 @@ +//go:build go1.15 // +build go1.15 package middleware @@ -5,12 +6,10 @@ package middleware import ( "errors" "fmt" - "net/http" - "reflect" - "strings" - "github.com/golang-jwt/jwt" "github.com/labstack/echo/v4" + "net/http" + "reflect" ) type ( @@ -22,7 +21,8 @@ type ( // BeforeFunc defines a function which is executed just before the middleware. BeforeFunc BeforeFunc - // SuccessHandler defines a function which is executed for a valid token. + // SuccessHandler defines a function which is executed for a valid token before middleware chain continues with next + // middleware or handler. SuccessHandler JWTSuccessHandler // ErrorHandler defines a function which is executed for an invalid token. @@ -32,6 +32,13 @@ type ( // ErrorHandlerWithContext is almost identical to ErrorHandler, but it's passed the current context. ErrorHandlerWithContext JWTErrorHandlerWithContext + // ContinueOnIgnoredError allows the next middleware/handler to be called when ErrorHandlerWithContext decides to + // ignore the error (by returning `nil`). + // This is useful when parts of your site/api allow public access and some authorized routes provide extra functionality. + // In that case you can use ErrorHandlerWithContext to set a default public JWT token value in the request context + // and continue. Some logic down the remaining execution chain needs to check that (public) token value then. + ContinueOnIgnoredError bool + // Signing key to validate token. // This is one of the three options to provide a token validation key. // The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey. @@ -61,20 +68,25 @@ type ( // to extract token from the request. // Optional. Default value "header:Authorization". // Possible values: - // - "header:" + // - "header:" or "header::" + // `` is argument value to cut/trim prefix of the extracted value. This is useful if header + // value has static prefix like `Authorization: ` where part that we + // want to cut is ` ` note the space at the end. + // In case of JWT tokens `Authorization: Bearer ` prefix we cut is `Bearer `. + // If prefix is left empty the whole value is returned. // - "query:" // - "param:" // - "cookie:" // - "form:" - // Multiply sources example: - // - "header: Authorization,cookie: myowncookie" + // Multiple sources example: + // - "header:Authorization,cookie:myowncookie" TokenLookup string // TokenLookupFuncs defines a list of user-defined functions that extract JWT token from the given context. // This is one of the two options to provide a token extractor. // The order of precedence is user-defined TokenLookupFuncs, and TokenLookup. // You can also provide both if you want. - TokenLookupFuncs []TokenLookupFunc + TokenLookupFuncs []ValuesExtractor // AuthScheme to be used in the Authorization header. // Optional. Default value "Bearer". @@ -100,16 +112,13 @@ type ( } // JWTSuccessHandler defines a function which is executed for a valid token. - JWTSuccessHandler func(echo.Context) + JWTSuccessHandler func(c echo.Context) // JWTErrorHandler defines a function which is executed for an invalid token. - JWTErrorHandler func(error) error + JWTErrorHandler func(err error) error // JWTErrorHandlerWithContext is almost identical to JWTErrorHandler, but it's passed the current context. - JWTErrorHandlerWithContext func(error, echo.Context) error - - // TokenLookupFunc defines a function for extracting JWT token from the given context. - TokenLookupFunc func(echo.Context) (string, error) + JWTErrorHandlerWithContext func(err error, c echo.Context) error ) // Algorithms @@ -183,25 +192,12 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { config.ParseTokenFunc = config.defaultParseToken } - // Initialize - // Split sources - sources := strings.Split(config.TokenLookup, ",") - var extractors = config.TokenLookupFuncs - for _, source := range sources { - parts := strings.Split(source, ":") - - switch parts[0] { - case "query": - extractors = append(extractors, jwtFromQuery(parts[1])) - case "param": - extractors = append(extractors, jwtFromParam(parts[1])) - case "cookie": - extractors = append(extractors, jwtFromCookie(parts[1])) - case "form": - extractors = append(extractors, jwtFromForm(parts[1])) - case "header": - extractors = append(extractors, jwtFromHeader(parts[1], config.AuthScheme)) - } + extractors, err := createExtractors(config.TokenLookup, config.AuthScheme) + if err != nil { + panic(err) + } + if len(config.TokenLookupFuncs) > 0 { + extractors = append(config.TokenLookupFuncs, extractors...) } return func(next echo.HandlerFunc) echo.HandlerFunc { @@ -213,48 +209,54 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { if config.BeforeFunc != nil { config.BeforeFunc(c) } - var auth string - var err error + + var lastExtractorErr error + var lastTokenErr error for _, extractor := range extractors { - // Extract token from extractor, if it's not fail break the loop and - // set auth - auth, err = extractor(c) - if err == nil { - break + auths, err := extractor(c) + if err != nil { + lastExtractorErr = ErrJWTMissing // backwards compatibility: all extraction errors are same (unlike KeyAuth) + continue + } + for _, auth := range auths { + token, err := config.ParseTokenFunc(auth, c) + if err != nil { + lastTokenErr = err + continue + } + // Store user information from token into context. + c.Set(config.ContextKey, token) + if config.SuccessHandler != nil { + config.SuccessHandler(c) + } + return next(c) } } - // If none of extractor has a token, handle error - if err != nil { - if config.ErrorHandler != nil { - return config.ErrorHandler(err) - } - - if config.ErrorHandlerWithContext != nil { - return config.ErrorHandlerWithContext(err, c) - } - return err - } - - token, err := config.ParseTokenFunc(auth, c) - if err == nil { - // Store user information from token into context. - c.Set(config.ContextKey, token) - if config.SuccessHandler != nil { - config.SuccessHandler(c) - } - return next(c) + // we are here only when we did not successfully extract or parse any of the tokens + err := lastTokenErr + if err == nil { // prioritize token errors over extracting errors + err = lastExtractorErr } if config.ErrorHandler != nil { return config.ErrorHandler(err) } if config.ErrorHandlerWithContext != nil { - return config.ErrorHandlerWithContext(err, c) + tmpErr := config.ErrorHandlerWithContext(err, c) + if config.ContinueOnIgnoredError && tmpErr == nil { + return next(c) + } + return tmpErr } - return &echo.HTTPError{ - Code: ErrJWTInvalid.Code, - Message: ErrJWTInvalid.Message, - Internal: err, + + // backwards compatible errors codes + if lastTokenErr != nil { + return &echo.HTTPError{ + Code: ErrJWTInvalid.Code, + Message: ErrJWTInvalid.Message, + Internal: err, + } } + return err // this is lastExtractorErr value } } } @@ -296,59 +298,3 @@ func (config *JWTConfig) defaultKeyFunc(t *jwt.Token) (interface{}, error) { return config.SigningKey, nil } - -// jwtFromHeader returns a `TokenLookupFunc` that extracts token from the request header. -func jwtFromHeader(header string, authScheme string) TokenLookupFunc { - return func(c echo.Context) (string, error) { - auth := c.Request().Header.Get(header) - l := len(authScheme) - if len(auth) > l+1 && strings.EqualFold(auth[:l], authScheme) { - return auth[l+1:], nil - } - return "", ErrJWTMissing - } -} - -// jwtFromQuery returns a `TokenLookupFunc` that extracts token from the query string. -func jwtFromQuery(param string) TokenLookupFunc { - return func(c echo.Context) (string, error) { - token := c.QueryParam(param) - if token == "" { - return "", ErrJWTMissing - } - return token, nil - } -} - -// jwtFromParam returns a `TokenLookupFunc` that extracts token from the url param string. -func jwtFromParam(param string) TokenLookupFunc { - return func(c echo.Context) (string, error) { - token := c.Param(param) - if token == "" { - return "", ErrJWTMissing - } - return token, nil - } -} - -// jwtFromCookie returns a `TokenLookupFunc` that extracts token from the named cookie. -func jwtFromCookie(name string) TokenLookupFunc { - return func(c echo.Context) (string, error) { - cookie, err := c.Cookie(name) - if err != nil { - return "", ErrJWTMissing - } - return cookie.Value, nil - } -} - -// jwtFromForm returns a `TokenLookupFunc` that extracts token from the form field. -func jwtFromForm(name string) TokenLookupFunc { - return func(c echo.Context) (string, error) { - field := c.FormValue(name) - if field == "" { - return "", ErrJWTMissing - } - return field, nil - } -} diff --git a/middleware/jwt_test.go b/middleware/jwt_test.go index 18454d0a..eee9df96 100644 --- a/middleware/jwt_test.go +++ b/middleware/jwt_test.go @@ -1,3 +1,4 @@ +//go:build go1.15 // +build go1.15 package middleware @@ -28,6 +29,26 @@ type jwtCustomClaims struct { jwtCustomInfo } +func TestJWT(t *testing.T) { + e := echo.New() + + e.GET("/", func(c echo.Context) error { + token := c.Get("user").(*jwt.Token) + return c.JSON(http.StatusOK, token.Claims) + }) + + e.Use(JWT([]byte("secret"))) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAuthorization, "bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ") + res := httptest.NewRecorder() + + e.ServeHTTP(res, req) + + assert.Equal(t, http.StatusOK, res.Code) + assert.Equal(t, `{"admin":true,"name":"John Doe","sub":"1234567890"}`+"\n", res.Body.String()) +} + func TestJWTRace(t *testing.T) { e := echo.New() handler := func(c echo.Context) error { @@ -64,8 +85,7 @@ func TestJWTRace(t *testing.T) { assert.Equal(t, claims.Admin, true) } -func TestJWT(t *testing.T) { - e := echo.New() +func TestJWTConfig(t *testing.T) { handler := func(c echo.Context) error { return c.String(http.StatusOK, "test") } @@ -74,7 +94,8 @@ func TestJWT(t *testing.T) { invalidKey := []byte("invalid-key") validAuth := DefaultJWTConfig.AuthScheme + " " + token - for _, tc := range []struct { + testCases := []struct { + name string expPanic bool expErrCode int // 0 for Success config JWTConfig @@ -82,166 +103,166 @@ func TestJWT(t *testing.T) { hdrAuth string hdrCookie string // test.Request doesn't provide SetCookie(); use name=val formValues map[string]string - info string }{ { + name: "No signing key provided", expPanic: true, - info: "No signing key provided", }, { + name: "Unexpected signing method", expErrCode: http.StatusBadRequest, config: JWTConfig{ SigningKey: validKey, SigningMethod: "RS256", }, - info: "Unexpected signing method", }, { + name: "Invalid key", expErrCode: http.StatusUnauthorized, hdrAuth: validAuth, config: JWTConfig{SigningKey: invalidKey}, - info: "Invalid key", }, { + name: "Valid JWT", hdrAuth: validAuth, config: JWTConfig{SigningKey: validKey}, - info: "Valid JWT", }, { + name: "Valid JWT with custom AuthScheme", hdrAuth: "Token" + " " + token, config: JWTConfig{AuthScheme: "Token", SigningKey: validKey}, - info: "Valid JWT with custom AuthScheme", }, { + name: "Valid JWT with custom claims", hdrAuth: validAuth, config: JWTConfig{ Claims: &jwtCustomClaims{}, SigningKey: []byte("secret"), }, - info: "Valid JWT with custom claims", }, { + name: "Invalid Authorization header", hdrAuth: "invalid-auth", expErrCode: http.StatusBadRequest, config: JWTConfig{SigningKey: validKey}, - info: "Invalid Authorization header", }, { + name: "Empty header auth field", config: JWTConfig{SigningKey: validKey}, expErrCode: http.StatusBadRequest, - info: "Empty header auth field", }, { + name: "Valid query method", config: JWTConfig{ SigningKey: validKey, TokenLookup: "query:jwt", }, reqURL: "/?a=b&jwt=" + token, - info: "Valid query method", }, { + name: "Invalid query param name", config: JWTConfig{ SigningKey: validKey, TokenLookup: "query:jwt", }, reqURL: "/?a=b&jwtxyz=" + token, expErrCode: http.StatusBadRequest, - info: "Invalid query param name", }, { + name: "Invalid query param value", config: JWTConfig{ SigningKey: validKey, TokenLookup: "query:jwt", }, reqURL: "/?a=b&jwt=invalid-token", expErrCode: http.StatusUnauthorized, - info: "Invalid query param value", }, { + name: "Empty query", config: JWTConfig{ SigningKey: validKey, TokenLookup: "query:jwt", }, reqURL: "/?a=b", expErrCode: http.StatusBadRequest, - info: "Empty query", }, { + name: "Valid param method", config: JWTConfig{ SigningKey: validKey, TokenLookup: "param:jwt", }, reqURL: "/" + token, - info: "Valid param method", }, { + name: "Valid cookie method", config: JWTConfig{ SigningKey: validKey, TokenLookup: "cookie:jwt", }, hdrCookie: "jwt=" + token, - info: "Valid cookie method", }, { + name: "Multiple jwt lookuop", config: JWTConfig{ SigningKey: validKey, TokenLookup: "query:jwt,cookie:jwt", }, hdrCookie: "jwt=" + token, - info: "Multiple jwt lookuop", }, { + name: "Invalid token with cookie method", config: JWTConfig{ SigningKey: validKey, TokenLookup: "cookie:jwt", }, expErrCode: http.StatusUnauthorized, hdrCookie: "jwt=invalid", - info: "Invalid token with cookie method", }, { + name: "Empty cookie", config: JWTConfig{ SigningKey: validKey, TokenLookup: "cookie:jwt", }, expErrCode: http.StatusBadRequest, - info: "Empty cookie", }, { + name: "Valid form method", config: JWTConfig{ SigningKey: validKey, TokenLookup: "form:jwt", }, formValues: map[string]string{"jwt": token}, - info: "Valid form method", }, { + name: "Invalid token with form method", config: JWTConfig{ SigningKey: validKey, TokenLookup: "form:jwt", }, expErrCode: http.StatusUnauthorized, formValues: map[string]string{"jwt": "invalid"}, - info: "Invalid token with form method", }, { + name: "Empty form field", config: JWTConfig{ SigningKey: validKey, TokenLookup: "form:jwt", }, expErrCode: http.StatusBadRequest, - info: "Empty form field", }, { + name: "Valid JWT with a valid key using a user-defined KeyFunc", hdrAuth: validAuth, config: JWTConfig{ KeyFunc: func(*jwt.Token) (interface{}, error) { return validKey, nil }, }, - info: "Valid JWT with a valid key using a user-defined KeyFunc", }, { + name: "Valid JWT with an invalid key using a user-defined KeyFunc", hdrAuth: validAuth, config: JWTConfig{ KeyFunc: func(*jwt.Token) (interface{}, error) { @@ -249,9 +270,9 @@ func TestJWT(t *testing.T) { }, }, expErrCode: http.StatusUnauthorized, - info: "Valid JWT with an invalid key using a user-defined KeyFunc", }, { + name: "Token verification does not pass using a user-defined KeyFunc", hdrAuth: validAuth, config: JWTConfig{ KeyFunc: func(*jwt.Token) (interface{}, error) { @@ -259,67 +280,70 @@ func TestJWT(t *testing.T) { }, }, expErrCode: http.StatusUnauthorized, - info: "Token verification does not pass using a user-defined KeyFunc", }, { + name: "Valid JWT with lower case AuthScheme", hdrAuth: strings.ToLower(DefaultJWTConfig.AuthScheme) + " " + token, config: JWTConfig{SigningKey: validKey}, - info: "Valid JWT with lower case AuthScheme", }, - } { - if tc.reqURL == "" { - tc.reqURL = "/" - } - - var req *http.Request - if len(tc.formValues) > 0 { - form := url.Values{} - for k, v := range tc.formValues { - form.Set(k, v) + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + if tc.reqURL == "" { + tc.reqURL = "/" } - req = httptest.NewRequest(http.MethodPost, tc.reqURL, strings.NewReader(form.Encode())) - req.Header.Set(echo.HeaderContentType, "application/x-www-form-urlencoded") - req.ParseForm() - } else { - req = httptest.NewRequest(http.MethodGet, tc.reqURL, nil) - } - res := httptest.NewRecorder() - req.Header.Set(echo.HeaderAuthorization, tc.hdrAuth) - req.Header.Set(echo.HeaderCookie, tc.hdrCookie) - c := e.NewContext(req, res) - if tc.reqURL == "/"+token { - c.SetParamNames("jwt") - c.SetParamValues(token) - } + var req *http.Request + if len(tc.formValues) > 0 { + form := url.Values{} + for k, v := range tc.formValues { + form.Set(k, v) + } + req = httptest.NewRequest(http.MethodPost, tc.reqURL, strings.NewReader(form.Encode())) + req.Header.Set(echo.HeaderContentType, "application/x-www-form-urlencoded") + req.ParseForm() + } else { + req = httptest.NewRequest(http.MethodGet, tc.reqURL, nil) + } + res := httptest.NewRecorder() + req.Header.Set(echo.HeaderAuthorization, tc.hdrAuth) + req.Header.Set(echo.HeaderCookie, tc.hdrCookie) + c := e.NewContext(req, res) - if tc.expPanic { - assert.Panics(t, func() { - JWTWithConfig(tc.config) - }, tc.info) - continue - } + if tc.reqURL == "/"+token { + c.SetParamNames("jwt") + c.SetParamValues(token) + } + + if tc.expPanic { + assert.Panics(t, func() { + JWTWithConfig(tc.config) + }, tc.name) + return + } + + if tc.expErrCode != 0 { + h := JWTWithConfig(tc.config)(handler) + he := h(c).(*echo.HTTPError) + assert.Equal(t, tc.expErrCode, he.Code, tc.name) + return + } - if tc.expErrCode != 0 { h := JWTWithConfig(tc.config)(handler) - he := h(c).(*echo.HTTPError) - assert.Equal(t, tc.expErrCode, he.Code, tc.info) - continue - } - - h := JWTWithConfig(tc.config)(handler) - if assert.NoError(t, h(c), tc.info) { - user := c.Get("user").(*jwt.Token) - switch claims := user.Claims.(type) { - case jwt.MapClaims: - assert.Equal(t, claims["name"], "John Doe", tc.info) - case *jwtCustomClaims: - assert.Equal(t, claims.Name, "John Doe", tc.info) - assert.Equal(t, claims.Admin, true, tc.info) - default: - panic("unexpected type of claims") + if assert.NoError(t, h(c), tc.name) { + user := c.Get("user").(*jwt.Token) + switch claims := user.Claims.(type) { + case jwt.MapClaims: + assert.Equal(t, claims["name"], "John Doe", tc.name) + case *jwtCustomClaims: + assert.Equal(t, claims.Name, "John Doe", tc.name) + assert.Equal(t, claims.Admin, true, tc.name) + default: + panic("unexpected type of claims") + } } - } + }) } } @@ -608,13 +632,14 @@ func TestJWTConfig_TokenLookupFuncs(t *testing.T) { e := echo.New() e.GET("/", func(c echo.Context) error { - return c.String(http.StatusOK, "test") + token := c.Get("user").(*jwt.Token) + return c.JSON(http.StatusOK, token.Claims) }) e.Use(JWTWithConfig(JWTConfig{ - TokenLookupFuncs: []TokenLookupFunc{ - func(c echo.Context) (string, error) { - return c.Request().Header.Get("X-API-Key"), nil + TokenLookupFuncs: []ValuesExtractor{ + func(c echo.Context) ([]string, error) { + return []string{c.Request().Header.Get("X-API-Key")}, nil }, }, SigningKey: []byte("secret"), @@ -626,4 +651,129 @@ func TestJWTConfig_TokenLookupFuncs(t *testing.T) { e.ServeHTTP(res, req) assert.Equal(t, http.StatusOK, res.Code) + assert.Equal(t, `{"admin":true,"name":"John Doe","sub":"1234567890"}`+"\n", res.Body.String()) +} + +func TestJWTConfig_SuccessHandler(t *testing.T) { + var testCases = []struct { + name string + givenToken string + expectCalled bool + expectStatus int + }{ + { + name: "ok, success handler is called", + givenToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ", + expectCalled: true, + expectStatus: http.StatusOK, + }, + { + name: "nok, success handler is not called", + givenToken: "x.x.x", + expectCalled: false, + expectStatus: http.StatusUnauthorized, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + e.GET("/", func(c echo.Context) error { + token := c.Get("user").(*jwt.Token) + return c.JSON(http.StatusOK, token.Claims) + }) + + wasCalled := false + e.Use(JWTWithConfig(JWTConfig{ + SuccessHandler: func(c echo.Context) { + wasCalled = true + }, + SigningKey: []byte("secret"), + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAuthorization, "bearer "+tc.givenToken) + res := httptest.NewRecorder() + + e.ServeHTTP(res, req) + + assert.Equal(t, tc.expectCalled, wasCalled) + assert.Equal(t, tc.expectStatus, res.Code) + }) + } +} + +func TestJWTConfig_ContinueOnIgnoredError(t *testing.T) { + var testCases = []struct { + name string + whenContinueOnIgnoredError bool + givenToken string + expectStatus int + expectBody string + }{ + { + name: "no error handler is called", + whenContinueOnIgnoredError: true, + givenToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ", + expectStatus: http.StatusTeapot, + expectBody: "", + }, + { + name: "ContinueOnIgnoredError is false and error handler is called for missing token", + whenContinueOnIgnoredError: false, + givenToken: "", + // empty response with 200. This emulates previous behaviour when error handler swallowed the error + expectStatus: http.StatusOK, + expectBody: "", + }, + { + name: "error handler is called for missing token", + whenContinueOnIgnoredError: true, + givenToken: "", + expectStatus: http.StatusTeapot, + expectBody: "public-token", + }, + { + name: "error handler is called for invalid token", + whenContinueOnIgnoredError: true, + givenToken: "x.x.x", + expectStatus: http.StatusUnauthorized, + expectBody: "{\"message\":\"Unauthorized\"}\n", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + e.GET("/", func(c echo.Context) error { + testValue, _ := c.Get("test").(string) + return c.String(http.StatusTeapot, testValue) + }) + + e.Use(JWTWithConfig(JWTConfig{ + ContinueOnIgnoredError: tc.whenContinueOnIgnoredError, + SigningKey: []byte("secret"), + ErrorHandlerWithContext: func(err error, c echo.Context) error { + if err == ErrJWTMissing { + c.Set("test", "public-token") + return nil + } + return echo.ErrUnauthorized + }, + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + if tc.givenToken != "" { + req.Header.Set(echo.HeaderAuthorization, "bearer "+tc.givenToken) + } + res := httptest.NewRecorder() + + e.ServeHTTP(res, req) + + assert.Equal(t, tc.expectStatus, res.Code) + assert.Equal(t, tc.expectBody, res.Body.String()) + }) + } } diff --git a/middleware/key_auth.go b/middleware/key_auth.go index 54f3b47f..e8a6b085 100644 --- a/middleware/key_auth.go +++ b/middleware/key_auth.go @@ -2,11 +2,8 @@ package middleware import ( "errors" - "fmt" - "net/http" - "strings" - "github.com/labstack/echo/v4" + "net/http" ) type ( @@ -15,15 +12,21 @@ type ( // Skipper defines a function to skip middleware. Skipper Skipper - // KeyLookup is a string in the form of ":" that is used + // KeyLookup is a string in the form of ":" or ":,:" that is used // to extract key from the request. // Optional. Default value "header:Authorization". // Possible values: - // - "header:" + // - "header:" or "header::" + // `` is argument value to cut/trim prefix of the extracted value. This is useful if header + // value has static prefix like `Authorization: ` where part that we + // want to cut is ` ` note the space at the end. + // In case of basic authentication `Authorization: Basic ` prefix we want to remove is `Basic `. // - "query:" // - "form:" // - "cookie:" - KeyLookup string `yaml:"key_lookup"` + // Multiple sources example: + // - "header:Authorization,header:X-Api-Key" + KeyLookup string // AuthScheme to be used in the Authorization header. // Optional. Default value "Bearer". @@ -36,15 +39,20 @@ type ( // ErrorHandler defines a function which is executed for an invalid key. // It may be used to define a custom error. ErrorHandler KeyAuthErrorHandler + + // ContinueOnIgnoredError allows the next middleware/handler to be called when ErrorHandler decides to + // ignore the error (by returning `nil`). + // This is useful when parts of your site/api allow public access and some authorized routes provide extra functionality. + // In that case you can use ErrorHandler to set a default public key auth value in the request context + // and continue. Some logic down the remaining execution chain needs to check that (public) key auth value then. + ContinueOnIgnoredError bool } // KeyAuthValidator defines a function to validate KeyAuth credentials. - KeyAuthValidator func(string, echo.Context) (bool, error) - - keyExtractor func(echo.Context) (string, error) + KeyAuthValidator func(auth string, c echo.Context) (bool, error) // KeyAuthErrorHandler defines a function which is executed for an invalid key. - KeyAuthErrorHandler func(error, echo.Context) error + KeyAuthErrorHandler func(err error, c echo.Context) error ) var ( @@ -56,6 +64,21 @@ var ( } ) +// ErrKeyAuthMissing is error type when KeyAuth middleware is unable to extract value from lookups +type ErrKeyAuthMissing struct { + Err error +} + +// Error returns errors text +func (e *ErrKeyAuthMissing) Error() string { + return e.Err.Error() +} + +// Unwrap unwraps error +func (e *ErrKeyAuthMissing) Unwrap() error { + return e.Err +} + // KeyAuth returns an KeyAuth middleware. // // For valid key it calls the next handler. @@ -85,16 +108,9 @@ func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc { panic("echo: key-auth middleware requires a validator function") } - // Initialize - parts := strings.Split(config.KeyLookup, ":") - extractor := keyFromHeader(parts[1], config.AuthScheme) - switch parts[0] { - case "query": - extractor = keyFromQuery(parts[1]) - case "form": - extractor = keyFromForm(parts[1]) - case "cookie": - extractor = keyFromCookie(parts[1]) + extractors, err := createExtractors(config.KeyLookup, config.AuthScheme) + if err != nil { + panic(err) } return func(next echo.HandlerFunc) echo.HandlerFunc { @@ -103,79 +119,62 @@ func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc { return next(c) } - // Extract and verify key - key, err := extractor(c) - if err != nil { - if config.ErrorHandler != nil { - return config.ErrorHandler(err, c) + var lastExtractorErr error + var lastValidatorErr error + for _, extractor := range extractors { + keys, err := extractor(c) + if err != nil { + lastExtractorErr = err + continue + } + for _, key := range keys { + valid, err := config.Validator(key, c) + if err != nil { + lastValidatorErr = err + continue + } + if valid { + return next(c) + } + lastValidatorErr = errors.New("invalid key") } - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - valid, err := config.Validator(key, c) - if err != nil { - if config.ErrorHandler != nil { - return config.ErrorHandler(err, c) + + // we are here only when we did not successfully extract and validate any of keys + err := lastValidatorErr + if err == nil { // prioritize validator errors over extracting errors + // ugly part to preserve backwards compatible errors. someone could rely on them + if lastExtractorErr == errQueryExtractorValueMissing { + err = errors.New("missing key in the query string") + } else if lastExtractorErr == errCookieExtractorValueMissing { + err = errors.New("missing key in cookies") + } else if lastExtractorErr == errFormExtractorValueMissing { + err = errors.New("missing key in the form") + } else if lastExtractorErr == errHeaderExtractorValueMissing { + err = errors.New("missing key in request header") + } else if lastExtractorErr == errHeaderExtractorValueInvalid { + err = errors.New("invalid key in the request header") + } else { + err = lastExtractorErr } + err = &ErrKeyAuthMissing{Err: err} + } + + if config.ErrorHandler != nil { + tmpErr := config.ErrorHandler(err, c) + if config.ContinueOnIgnoredError && tmpErr == nil { + return next(c) + } + return tmpErr + } + if lastValidatorErr != nil { // prioritize validator errors over extracting errors return &echo.HTTPError{ Code: http.StatusUnauthorized, - Message: "invalid key", - Internal: err, + Message: "Unauthorized", + Internal: lastValidatorErr, } - } else if valid { - return next(c) } - return echo.ErrUnauthorized + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } } } - -// keyFromHeader returns a `keyExtractor` that extracts key from the request header. -func keyFromHeader(header string, authScheme string) keyExtractor { - return func(c echo.Context) (string, error) { - auth := c.Request().Header.Get(header) - if auth == "" { - return "", errors.New("missing key in request header") - } - if header == echo.HeaderAuthorization { - l := len(authScheme) - if len(auth) > l+1 && auth[:l] == authScheme { - return auth[l+1:], nil - } - return "", errors.New("invalid key in the request header") - } - return auth, nil - } -} - -// keyFromQuery returns a `keyExtractor` that extracts key from the query string. -func keyFromQuery(param string) keyExtractor { - return func(c echo.Context) (string, error) { - key := c.QueryParam(param) - if key == "" { - return "", errors.New("missing key in the query string") - } - return key, nil - } -} - -// keyFromForm returns a `keyExtractor` that extracts key from the form. -func keyFromForm(param string) keyExtractor { - return func(c echo.Context) (string, error) { - key := c.FormValue(param) - if key == "" { - return "", errors.New("missing key in the form") - } - return key, nil - } -} - -// keyFromCookie returns a `keyExtractor` that extracts key from the form. -func keyFromCookie(cookieName string) keyExtractor { - return func(c echo.Context) (string, error) { - key, err := c.Cookie(cookieName) - if err != nil { - return "", fmt.Errorf("missing key in cookies: %w", err) - } - return key.Value, nil - } -} diff --git a/middleware/key_auth_test.go b/middleware/key_auth_test.go index 0cc513ab..ff8968c3 100644 --- a/middleware/key_auth_test.go +++ b/middleware/key_auth_test.go @@ -76,7 +76,7 @@ func TestKeyAuthWithConfig(t *testing.T) { req.Header.Set(echo.HeaderAuthorization, "Bearer invalid-key") }, expectHandlerCalled: false, - expectError: "code=401, message=Unauthorized", + expectError: "code=401, message=Unauthorized, internal=invalid key", }, { name: "nok, defaults, invalid scheme in header", @@ -92,6 +92,17 @@ func TestKeyAuthWithConfig(t *testing.T) { expectHandlerCalled: false, expectError: "code=400, message=missing key in request header", }, + { + name: "ok, custom key lookup from multiple places, query and header", + givenRequest: func(req *http.Request) { + req.URL.RawQuery = "key=invalid-key" + req.Header.Set("API-Key", "valid-key") + }, + whenConfig: func(conf *KeyAuthConfig) { + conf.KeyLookup = "query:key,header:API-Key" + }, + expectHandlerCalled: true, + }, { name: "ok, custom key lookup, header", givenRequest: func(req *http.Request) { @@ -179,7 +190,7 @@ func TestKeyAuthWithConfig(t *testing.T) { conf.KeyLookup = "cookie:key" }, expectHandlerCalled: false, - expectError: "code=400, message=missing key in cookies: http: named cookie not present", + expectError: "code=400, message=missing key in cookies", }, { name: "nok, custom errorHandler, error from extractor", @@ -216,7 +227,7 @@ func TestKeyAuthWithConfig(t *testing.T) { }, whenConfig: func(conf *KeyAuthConfig) {}, expectHandlerCalled: false, - expectError: "code=401, message=invalid key, internal=some user defined error", + expectError: "code=401, message=Unauthorized, internal=some user defined error", }, } @@ -257,3 +268,109 @@ func TestKeyAuthWithConfig(t *testing.T) { }) } } + +func TestKeyAuthWithConfig_panicsOnInvalidLookup(t *testing.T) { + assert.PanicsWithError( + t, + "extractor source for lookup could not be split into needed parts: a", + func() { + handler := func(c echo.Context) error { + return c.String(http.StatusOK, "test") + } + KeyAuthWithConfig(KeyAuthConfig{ + Validator: testKeyValidator, + KeyLookup: "a", + })(handler) + }, + ) +} + +func TestKeyAuthWithConfig_panicsOnEmptyValidator(t *testing.T) { + assert.PanicsWithValue( + t, + "echo: key-auth middleware requires a validator function", + func() { + handler := func(c echo.Context) error { + return c.String(http.StatusOK, "test") + } + KeyAuthWithConfig(KeyAuthConfig{ + Validator: nil, + })(handler) + }, + ) +} + +func TestKeyAuthWithConfig_ContinueOnIgnoredError(t *testing.T) { + var testCases = []struct { + name string + whenContinueOnIgnoredError bool + givenKey string + expectStatus int + expectBody string + }{ + { + name: "no error handler is called", + whenContinueOnIgnoredError: true, + givenKey: "valid-key", + expectStatus: http.StatusTeapot, + expectBody: "", + }, + { + name: "ContinueOnIgnoredError is false and error handler is called for missing token", + whenContinueOnIgnoredError: false, + givenKey: "", + // empty response with 200. This emulates previous behaviour when error handler swallowed the error + expectStatus: http.StatusOK, + expectBody: "", + }, + { + name: "error handler is called for missing token", + whenContinueOnIgnoredError: true, + givenKey: "", + expectStatus: http.StatusTeapot, + expectBody: "public-auth", + }, + { + name: "error handler is called for invalid token", + whenContinueOnIgnoredError: true, + givenKey: "x.x.x", + expectStatus: http.StatusUnauthorized, + expectBody: "{\"message\":\"Unauthorized\"}\n", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + e.GET("/", func(c echo.Context) error { + testValue, _ := c.Get("test").(string) + return c.String(http.StatusTeapot, testValue) + }) + + e.Use(KeyAuthWithConfig(KeyAuthConfig{ + Validator: testKeyValidator, + ErrorHandler: func(err error, c echo.Context) error { + if _, ok := err.(*ErrKeyAuthMissing); ok { + c.Set("test", "public-auth") + return nil + } + return echo.ErrUnauthorized + }, + KeyLookup: "header:X-API-Key", + ContinueOnIgnoredError: tc.whenContinueOnIgnoredError, + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + if tc.givenKey != "" { + req.Header.Set("X-API-Key", tc.givenKey) + } + res := httptest.NewRecorder() + + e.ServeHTTP(res, req) + + assert.Equal(t, tc.expectStatus, res.Code) + assert.Equal(t, tc.expectBody, res.Body.String()) + }) + } +} diff --git a/middleware/middleware.go b/middleware/middleware.go index a7ad73a5..f250ca49 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -12,10 +12,10 @@ import ( type ( // Skipper defines a function to skip middleware. Returning true skips processing // the middleware. - Skipper func(echo.Context) bool + Skipper func(c echo.Context) bool // BeforeFunc defines a function which is executed just before the middleware. - BeforeFunc func(echo.Context) + BeforeFunc func(c echo.Context) ) func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer {