1
0
mirror of https://github.com/labstack/echo.git synced 2025-01-26 03:20:08 +02:00

JWT, KeyAuth, CSRF multivalue extractors (#2060)

* CSRF, JWT, KeyAuth middleware support for multivalue value extractors
* Add flag to JWT and KeyAuth middleware to allow continuing execution `next(c)` when error handler decides to swallow the error (returns nil).
This commit is contained in:
Martti T 2022-01-24 22:03:45 +02:00 committed by GitHub
parent 9e9924d763
commit 4a1ccdfdc5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 1564 additions and 397 deletions

View File

@ -111,10 +111,10 @@ type (
} }
// MiddlewareFunc defines a function to process middleware. // MiddlewareFunc defines a function to process middleware.
MiddlewareFunc func(HandlerFunc) HandlerFunc MiddlewareFunc func(next HandlerFunc) HandlerFunc
// HandlerFunc defines a function to serve HTTP requests. // HandlerFunc defines a function to serve HTTP requests.
HandlerFunc func(Context) error HandlerFunc func(c Context) error
// HTTPErrorHandler is a centralized HTTP error handler. // HTTPErrorHandler is a centralized HTTP error handler.
HTTPErrorHandler func(error, Context) HTTPErrorHandler func(error, Context)

View File

@ -2,9 +2,7 @@ package middleware
import ( import (
"crypto/subtle" "crypto/subtle"
"errors"
"net/http" "net/http"
"strings"
"time" "time"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
@ -21,13 +19,15 @@ type (
TokenLength uint8 `yaml:"token_length"` TokenLength uint8 `yaml:"token_length"`
// Optional. Default value 32. // Optional. Default value 32.
// TokenLookup is a string in the form of "<source>:<key>" that is used // TokenLookup is a string in the form of "<source>:<name>" or "<source>:<name>,<source>:<name>" that is used
// to extract token from the request. // to extract token from the request.
// Optional. Default value "header:X-CSRF-Token". // Optional. Default value "header:X-CSRF-Token".
// Possible values: // Possible values:
// - "header:<name>" // - "header:<name>" or "header:<name>:<cut-prefix>"
// - "form:<name>"
// - "query:<name>" // - "query:<name>"
// - "form:<name>"
// Multiple sources example:
// - "header:X-CSRF-Token,query:csrf"
TokenLookup string `yaml:"token_lookup"` TokenLookup string `yaml:"token_lookup"`
// Context key to store generated CSRF token into context. // Context key to store generated CSRF token into context.
@ -62,12 +62,11 @@ type (
// Optional. Default value SameSiteDefaultMode. // Optional. Default value SameSiteDefaultMode.
CookieSameSite http.SameSite `yaml:"cookie_same_site"` CookieSameSite http.SameSite `yaml:"cookie_same_site"`
} }
// csrfTokenExtractor defines a function that takes `echo.Context` and returns
// either a token or an error.
csrfTokenExtractor func(echo.Context) (string, error)
) )
// ErrCSRFInvalid is returned when CSRF check fails
var ErrCSRFInvalid = echo.NewHTTPError(http.StatusForbidden, "invalid csrf token")
var ( var (
// DefaultCSRFConfig is the default CSRF middleware config. // DefaultCSRFConfig is the default CSRF middleware config.
DefaultCSRFConfig = CSRFConfig{ DefaultCSRFConfig = CSRFConfig{
@ -114,14 +113,9 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
config.CookieSecure = true config.CookieSecure = true
} }
// Initialize extractors, err := createExtractors(config.TokenLookup, "")
parts := strings.Split(config.TokenLookup, ":") if err != nil {
extractor := csrfTokenFromHeader(parts[1]) panic(err)
switch parts[0] {
case "form":
extractor = csrfTokenFromForm(parts[1])
case "query":
extractor = csrfTokenFromQuery(parts[1])
} }
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
@ -130,28 +124,50 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
return next(c) return next(c)
} }
req := c.Request()
k, err := c.Cookie(config.CookieName)
token := "" token := ""
if k, err := c.Cookie(config.CookieName); err != nil {
// Generate token token = random.String(config.TokenLength) // Generate token
if err != nil {
token = random.String(config.TokenLength)
} else { } else {
// Reuse token token = k.Value // Reuse token
token = k.Value
} }
switch req.Method { switch c.Request().Method {
case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace: case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
default: default:
// Validate token only for requests which are not defined as 'safe' by RFC7231 // Validate token only for requests which are not defined as 'safe' by RFC7231
clientToken, err := extractor(c) var lastExtractorErr error
var lastTokenErr error
outer:
for _, extractor := range extractors {
clientTokens, err := extractor(c)
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error()) lastExtractorErr = err
continue
} }
if !validateCSRFToken(token, clientToken) {
return echo.NewHTTPError(http.StatusForbidden, "invalid csrf token") for _, clientToken := range clientTokens {
if validateCSRFToken(token, clientToken) {
lastTokenErr = nil
lastExtractorErr = nil
break outer
}
lastTokenErr = ErrCSRFInvalid
}
}
if lastTokenErr != nil {
return lastTokenErr
} else if lastExtractorErr != nil {
// ugly part to preserve backwards compatible errors. someone could rely on them
if lastExtractorErr == errQueryExtractorValueMissing {
lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in the query string")
} else if lastExtractorErr == errFormExtractorValueMissing {
lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in the form parameter")
} else if lastExtractorErr == errHeaderExtractorValueMissing {
lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in request header")
} else {
lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, lastExtractorErr.Error())
}
return lastExtractorErr
} }
} }
@ -184,38 +200,6 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
} }
} }
// csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the
// provided request header.
func csrfTokenFromHeader(header string) csrfTokenExtractor {
return func(c echo.Context) (string, error) {
return c.Request().Header.Get(header), nil
}
}
// csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the
// provided form parameter.
func csrfTokenFromForm(param string) csrfTokenExtractor {
return func(c echo.Context) (string, error) {
token := c.FormValue(param)
if token == "" {
return "", errors.New("missing csrf token in the form parameter")
}
return token, nil
}
}
// csrfTokenFromQuery returns a `csrfTokenExtractor` that extracts token from the
// provided query parameter.
func csrfTokenFromQuery(param string) csrfTokenExtractor {
return func(c echo.Context) (string, error) {
token := c.QueryParam(param)
if token == "" {
return "", errors.New("missing csrf token in the query string")
}
return token, nil
}
}
func validateCSRFToken(token, clientToken string) bool { func validateCSRFToken(token, clientToken string) bool {
return subtle.ConstantTimeCompare([]byte(token), []byte(clientToken)) == 1 return subtle.ConstantTimeCompare([]byte(token), []byte(clientToken)) == 1
} }

View File

@ -1,7 +1,6 @@
package middleware package middleware
import ( import (
"fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
@ -13,14 +12,205 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestCSRF_tokenExtractors(t *testing.T) {
var testCases = []struct {
name string
whenTokenLookup string
whenCookieName string
givenCSRFCookie string
givenMethod string
givenQueryTokens map[string][]string
givenFormTokens map[string][]string
givenHeaderTokens map[string][]string
expectError string
}{
{
name: "ok, multiple token lookups sources, succeeds on last one",
whenTokenLookup: "header:X-CSRF-Token,form:csrf",
givenCSRFCookie: "token",
givenMethod: http.MethodPost,
givenHeaderTokens: map[string][]string{
echo.HeaderXCSRFToken: {"invalid_token"},
},
givenFormTokens: map[string][]string{
"csrf": {"token"},
},
},
{
name: "ok, token from POST form",
whenTokenLookup: "form:csrf",
givenCSRFCookie: "token",
givenMethod: http.MethodPost,
givenFormTokens: map[string][]string{
"csrf": {"token"},
},
},
{
name: "ok, token from POST form, second token passes",
whenTokenLookup: "form:csrf",
givenCSRFCookie: "token",
givenMethod: http.MethodPost,
givenFormTokens: map[string][]string{
"csrf": {"invalid", "token"},
},
},
{
name: "nok, invalid token from POST form",
whenTokenLookup: "form:csrf",
givenCSRFCookie: "token",
givenMethod: http.MethodPost,
givenFormTokens: map[string][]string{
"csrf": {"invalid_token"},
},
expectError: "code=403, message=invalid csrf token",
},
{
name: "nok, missing token from POST form",
whenTokenLookup: "form:csrf",
givenCSRFCookie: "token",
givenMethod: http.MethodPost,
givenFormTokens: map[string][]string{},
expectError: "code=400, message=missing csrf token in the form parameter",
},
{
name: "ok, token from POST header",
whenTokenLookup: "", // will use defaults
givenCSRFCookie: "token",
givenMethod: http.MethodPost,
givenHeaderTokens: map[string][]string{
echo.HeaderXCSRFToken: {"token"},
},
},
{
name: "ok, token from POST header, second token passes",
whenTokenLookup: "header:" + echo.HeaderXCSRFToken,
givenCSRFCookie: "token",
givenMethod: http.MethodPost,
givenHeaderTokens: map[string][]string{
echo.HeaderXCSRFToken: {"invalid", "token"},
},
},
{
name: "nok, invalid token from POST header",
whenTokenLookup: "header:" + echo.HeaderXCSRFToken,
givenCSRFCookie: "token",
givenMethod: http.MethodPost,
givenHeaderTokens: map[string][]string{
echo.HeaderXCSRFToken: {"invalid_token"},
},
expectError: "code=403, message=invalid csrf token",
},
{
name: "nok, missing token from POST header",
whenTokenLookup: "header:" + echo.HeaderXCSRFToken,
givenCSRFCookie: "token",
givenMethod: http.MethodPost,
givenHeaderTokens: map[string][]string{},
expectError: "code=400, message=missing csrf token in request header",
},
{
name: "ok, token from PUT query param",
whenTokenLookup: "query:csrf-param",
givenCSRFCookie: "token",
givenMethod: http.MethodPut,
givenQueryTokens: map[string][]string{
"csrf-param": {"token"},
},
},
{
name: "ok, token from PUT query form, second token passes",
whenTokenLookup: "query:csrf",
givenCSRFCookie: "token",
givenMethod: http.MethodPut,
givenQueryTokens: map[string][]string{
"csrf": {"invalid", "token"},
},
},
{
name: "nok, invalid token from PUT query form",
whenTokenLookup: "query:csrf",
givenCSRFCookie: "token",
givenMethod: http.MethodPut,
givenQueryTokens: map[string][]string{
"csrf": {"invalid_token"},
},
expectError: "code=403, message=invalid csrf token",
},
{
name: "nok, missing token from PUT query form",
whenTokenLookup: "query:csrf",
givenCSRFCookie: "token",
givenMethod: http.MethodPut,
givenQueryTokens: map[string][]string{},
expectError: "code=400, message=missing csrf token in the query string",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
q := make(url.Values)
for queryParam, values := range tc.givenQueryTokens {
for _, v := range values {
q.Add(queryParam, v)
}
}
f := make(url.Values)
for formKey, values := range tc.givenFormTokens {
for _, v := range values {
f.Add(formKey, v)
}
}
var req *http.Request
switch tc.givenMethod {
case http.MethodGet:
req = httptest.NewRequest(http.MethodGet, "/?"+q.Encode(), nil)
case http.MethodPost, http.MethodPut:
req = httptest.NewRequest(http.MethodPost, "/?"+q.Encode(), strings.NewReader(f.Encode()))
req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm)
}
for header, values := range tc.givenHeaderTokens {
for _, v := range values {
req.Header.Add(header, v)
}
}
if tc.givenCSRFCookie != "" {
req.Header.Set(echo.HeaderCookie, "_csrf="+tc.givenCSRFCookie)
}
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
csrf := CSRFWithConfig(CSRFConfig{
TokenLookup: tc.whenTokenLookup,
CookieName: tc.whenCookieName,
})
h := csrf(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})
err := h(c)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
})
}
}
func TestCSRF(t *testing.T) { func TestCSRF(t *testing.T) {
e := echo.New() e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
csrf := CSRFWithConfig(CSRFConfig{ csrf := CSRF()
TokenLength: 16,
})
h := csrf(func(c echo.Context) error { h := csrf(func(c echo.Context) error {
return c.String(http.StatusOK, "test") return c.String(http.StatusOK, "test")
}) })
@ -43,7 +233,7 @@ func TestCSRF(t *testing.T) {
assert.Error(t, h(c)) assert.Error(t, h(c))
// Valid CSRF token // Valid CSRF token
token := random.String(16) token := random.String(32)
req.Header.Set(echo.HeaderCookie, "_csrf="+token) req.Header.Set(echo.HeaderCookie, "_csrf="+token)
req.Header.Set(echo.HeaderXCSRFToken, token) req.Header.Set(echo.HeaderXCSRFToken, token)
if assert.NoError(t, h(c)) { if assert.NoError(t, h(c)) {
@ -51,38 +241,6 @@ func TestCSRF(t *testing.T) {
} }
} }
func TestCSRFTokenFromForm(t *testing.T) {
f := make(url.Values)
f.Set("csrf", "token")
e := echo.New()
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode()))
req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm)
c := e.NewContext(req, nil)
token, err := csrfTokenFromForm("csrf")(c)
if assert.NoError(t, err) {
assert.Equal(t, "token", token)
}
_, err = csrfTokenFromForm("invalid")(c)
assert.Error(t, err)
}
func TestCSRFTokenFromQuery(t *testing.T) {
q := make(url.Values)
q.Set("csrf", "token")
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm)
req.URL.RawQuery = q.Encode()
c := e.NewContext(req, nil)
token, err := csrfTokenFromQuery("csrf")(c)
if assert.NoError(t, err) {
assert.Equal(t, "token", token)
}
_, err = csrfTokenFromQuery("invalid")(c)
assert.Error(t, err)
csrfTokenFromQuery("csrf")
}
func TestCSRFSetSameSiteMode(t *testing.T) { func TestCSRFSetSameSiteMode(t *testing.T) {
e := echo.New() e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil) req := httptest.NewRequest(http.MethodGet, "/", nil)
@ -135,7 +293,6 @@ func TestCSRFWithSameSiteDefaultMode(t *testing.T) {
r := h(c) r := h(c)
assert.NoError(t, r) assert.NoError(t, r)
fmt.Println(rec.Header()["Set-Cookie"])
assert.NotRegexp(t, "SameSite=", rec.Header()["Set-Cookie"]) assert.NotRegexp(t, "SameSite=", rec.Header()["Set-Cookie"])
} }
@ -158,3 +315,46 @@ func TestCSRFWithSameSiteModeNone(t *testing.T) {
assert.Regexp(t, "SameSite=None", rec.Header()["Set-Cookie"]) assert.Regexp(t, "SameSite=None", rec.Header()["Set-Cookie"])
assert.Regexp(t, "Secure", rec.Header()["Set-Cookie"]) assert.Regexp(t, "Secure", rec.Header()["Set-Cookie"])
} }
func TestCSRFConfig_skipper(t *testing.T) {
var testCases = []struct {
name string
whenSkip bool
expectCookies int
}{
{
name: "do skip",
whenSkip: true,
expectCookies: 0,
},
{
name: "do not skip",
whenSkip: false,
expectCookies: 1,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
csrf := CSRFWithConfig(CSRFConfig{
Skipper: func(c echo.Context) bool {
return tc.whenSkip
},
})
h := csrf(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})
r := h(c)
assert.NoError(t, r)
cookie := rec.Header()["Set-Cookie"]
assert.Len(t, cookie, tc.expectCookies)
})
}
}

184
middleware/extractor.go Normal file
View File

@ -0,0 +1,184 @@
package middleware
import (
"errors"
"fmt"
"github.com/labstack/echo/v4"
"net/textproto"
"strings"
)
const (
// extractorLimit is arbitrary number to limit values extractor can return. this limits possible resource exhaustion
// attack vector
extractorLimit = 20
)
var errHeaderExtractorValueMissing = errors.New("missing value in request header")
var errHeaderExtractorValueInvalid = errors.New("invalid value in request header")
var errQueryExtractorValueMissing = errors.New("missing value in the query string")
var errParamExtractorValueMissing = errors.New("missing value in path params")
var errCookieExtractorValueMissing = errors.New("missing value in cookies")
var errFormExtractorValueMissing = errors.New("missing value in the form")
// ValuesExtractor defines a function for extracting values (keys/tokens) from the given context.
type ValuesExtractor func(c echo.Context) ([]string, error)
func createExtractors(lookups string, authScheme string) ([]ValuesExtractor, error) {
if lookups == "" {
return nil, nil
}
sources := strings.Split(lookups, ",")
var extractors = make([]ValuesExtractor, 0)
for _, source := range sources {
parts := strings.Split(source, ":")
if len(parts) < 2 {
return nil, fmt.Errorf("extractor source for lookup could not be split into needed parts: %v", source)
}
switch parts[0] {
case "query":
extractors = append(extractors, valuesFromQuery(parts[1]))
case "param":
extractors = append(extractors, valuesFromParam(parts[1]))
case "cookie":
extractors = append(extractors, valuesFromCookie(parts[1]))
case "form":
extractors = append(extractors, valuesFromForm(parts[1]))
case "header":
prefix := ""
if len(parts) > 2 {
prefix = parts[2]
} else if authScheme != "" && parts[1] == echo.HeaderAuthorization {
// backwards compatibility for JWT and KeyAuth:
// * we only apply this fix to Authorization as header we use and uses prefixes like "Bearer <token-value>" etc
// * previously header extractor assumed that auth-scheme/prefix had a space as suffix we need to retain that
// behaviour for default values and Authorization header.
prefix = authScheme
if !strings.HasSuffix(prefix, " ") {
prefix += " "
}
}
extractors = append(extractors, valuesFromHeader(parts[1], prefix))
}
}
return extractors, nil
}
// valuesFromHeader returns a functions that extracts values from the request header.
// valuePrefix is parameter to remove first part (prefix) of the extracted value. This is useful if header value has static
// prefix like `Authorization: <auth-scheme> <authorisation-parameters>` where part that we want to remove is `<auth-scheme> `
// note the space at the end. In case of basic authentication `Authorization: Basic <credentials>` prefix we want to remove
// is `Basic `. In case of JWT tokens `Authorization: Bearer <token>` prefix is `Bearer `.
// If prefix is left empty the whole value is returned.
func valuesFromHeader(header string, valuePrefix string) ValuesExtractor {
prefixLen := len(valuePrefix)
// standard library parses http.Request header keys in canonical form but we may provide something else so fix this
header = textproto.CanonicalMIMEHeaderKey(header)
return func(c echo.Context) ([]string, error) {
values := c.Request().Header.Values(header)
if len(values) == 0 {
return nil, errHeaderExtractorValueMissing
}
result := make([]string, 0)
for i, value := range values {
if prefixLen == 0 {
result = append(result, value)
if i >= extractorLimit-1 {
break
}
continue
}
if len(value) > prefixLen && strings.EqualFold(value[:prefixLen], valuePrefix) {
result = append(result, value[prefixLen:])
if i >= extractorLimit-1 {
break
}
}
}
if len(result) == 0 {
if prefixLen > 0 {
return nil, errHeaderExtractorValueInvalid
}
return nil, errHeaderExtractorValueMissing
}
return result, nil
}
}
// valuesFromQuery returns a function that extracts values from the query string.
func valuesFromQuery(param string) ValuesExtractor {
return func(c echo.Context) ([]string, error) {
result := c.QueryParams()[param]
if len(result) == 0 {
return nil, errQueryExtractorValueMissing
} else if len(result) > extractorLimit-1 {
result = result[:extractorLimit]
}
return result, nil
}
}
// valuesFromParam returns a function that extracts values from the url param string.
func valuesFromParam(param string) ValuesExtractor {
return func(c echo.Context) ([]string, error) {
result := make([]string, 0)
paramVales := c.ParamValues()
for i, p := range c.ParamNames() {
if param == p {
result = append(result, paramVales[i])
if i >= extractorLimit-1 {
break
}
}
}
if len(result) == 0 {
return nil, errParamExtractorValueMissing
}
return result, nil
}
}
// valuesFromCookie returns a function that extracts values from the named cookie.
func valuesFromCookie(name string) ValuesExtractor {
return func(c echo.Context) ([]string, error) {
cookies := c.Cookies()
if len(cookies) == 0 {
return nil, errCookieExtractorValueMissing
}
result := make([]string, 0)
for i, cookie := range cookies {
if name == cookie.Name {
result = append(result, cookie.Value)
if i >= extractorLimit-1 {
break
}
}
}
if len(result) == 0 {
return nil, errCookieExtractorValueMissing
}
return result, nil
}
}
// valuesFromForm returns a function that extracts values from the form field.
func valuesFromForm(name string) ValuesExtractor {
return func(c echo.Context) ([]string, error) {
if parseErr := c.Request().ParseForm(); parseErr != nil {
return nil, fmt.Errorf("valuesFromForm parse form failed: %w", parseErr)
}
values := c.Request().Form[name]
if len(values) == 0 {
return nil, errFormExtractorValueMissing
}
if len(values) > extractorLimit-1 {
values = values[:extractorLimit]
}
result := append([]string{}, values...)
return result, nil
}
}

View File

@ -0,0 +1,587 @@
package middleware
import (
"fmt"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
)
type pathParam struct {
name string
value string
}
func setPathParams(c echo.Context, params []pathParam) {
names := make([]string, 0, len(params))
values := make([]string, 0, len(params))
for _, pp := range params {
names = append(names, pp.name)
values = append(values, pp.value)
}
c.SetParamNames(names...)
c.SetParamValues(values...)
}
func TestCreateExtractors(t *testing.T) {
var testCases = []struct {
name string
givenRequest func() *http.Request
givenPathParams []pathParam
whenLoopups string
expectValues []string
expectCreateError string
expectError string
}{
{
name: "ok, header",
givenRequest: func() *http.Request {
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAuthorization, "Bearer token")
return req
},
whenLoopups: "header:Authorization:Bearer ",
expectValues: []string{"token"},
},
{
name: "ok, form",
givenRequest: func() *http.Request {
f := make(url.Values)
f.Set("name", "Jon Snow")
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode()))
req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm)
return req
},
whenLoopups: "form:name",
expectValues: []string{"Jon Snow"},
},
{
name: "ok, cookie",
givenRequest: func() *http.Request {
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderCookie, "_csrf=token")
return req
},
whenLoopups: "cookie:_csrf",
expectValues: []string{"token"},
},
{
name: "ok, param",
givenPathParams: []pathParam{
{name: "id", value: "123"},
},
whenLoopups: "param:id",
expectValues: []string{"123"},
},
{
name: "ok, query",
givenRequest: func() *http.Request {
req := httptest.NewRequest(http.MethodGet, "/?id=999", nil)
return req
},
whenLoopups: "query:id",
expectValues: []string{"999"},
},
{
name: "nok, invalid lookup",
whenLoopups: "query",
expectCreateError: "extractor source for lookup could not be split into needed parts: query",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
if tc.givenRequest != nil {
req = tc.givenRequest()
}
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
if tc.givenPathParams != nil {
setPathParams(c, tc.givenPathParams)
}
extractors, err := createExtractors(tc.whenLoopups, "")
if tc.expectCreateError != "" {
assert.EqualError(t, err, tc.expectCreateError)
return
}
assert.NoError(t, err)
for _, e := range extractors {
values, eErr := e(c)
assert.Equal(t, tc.expectValues, values)
if tc.expectError != "" {
assert.EqualError(t, eErr, tc.expectError)
return
}
assert.NoError(t, eErr)
}
})
}
}
func TestValuesFromHeader(t *testing.T) {
exampleRequest := func(req *http.Request) {
req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==")
}
var testCases = []struct {
name string
givenRequest func(req *http.Request)
whenName string
whenValuePrefix string
expectValues []string
expectError string
}{
{
name: "ok, single value",
givenRequest: exampleRequest,
whenName: echo.HeaderAuthorization,
whenValuePrefix: "basic ",
expectValues: []string{"dXNlcjpwYXNzd29yZA=="},
},
{
name: "ok, single value, case insensitive",
givenRequest: exampleRequest,
whenName: echo.HeaderAuthorization,
whenValuePrefix: "Basic ",
expectValues: []string{"dXNlcjpwYXNzd29yZA=="},
},
{
name: "ok, multiple value",
givenRequest: func(req *http.Request) {
req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==")
req.Header.Add(echo.HeaderAuthorization, "basic dGVzdDp0ZXN0")
},
whenName: echo.HeaderAuthorization,
whenValuePrefix: "basic ",
expectValues: []string{"dXNlcjpwYXNzd29yZA==", "dGVzdDp0ZXN0"},
},
{
name: "ok, empty prefix",
givenRequest: exampleRequest,
whenName: echo.HeaderAuthorization,
whenValuePrefix: "",
expectValues: []string{"basic dXNlcjpwYXNzd29yZA=="},
},
{
name: "nok, no matching due different prefix",
givenRequest: func(req *http.Request) {
req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==")
req.Header.Add(echo.HeaderAuthorization, "basic dGVzdDp0ZXN0")
},
whenName: echo.HeaderAuthorization,
whenValuePrefix: "Bearer ",
expectError: errHeaderExtractorValueInvalid.Error(),
},
{
name: "nok, no matching due different prefix",
givenRequest: func(req *http.Request) {
req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==")
req.Header.Add(echo.HeaderAuthorization, "basic dGVzdDp0ZXN0")
},
whenName: echo.HeaderWWWAuthenticate,
whenValuePrefix: "",
expectError: errHeaderExtractorValueMissing.Error(),
},
{
name: "nok, no headers",
givenRequest: nil,
whenName: echo.HeaderAuthorization,
whenValuePrefix: "basic ",
expectError: errHeaderExtractorValueMissing.Error(),
},
{
name: "ok, prefix, cut values over extractorLimit",
givenRequest: func(req *http.Request) {
for i := 1; i <= 25; i++ {
req.Header.Add(echo.HeaderAuthorization, fmt.Sprintf("basic %v", i))
}
},
whenName: echo.HeaderAuthorization,
whenValuePrefix: "basic ",
expectValues: []string{
"1", "2", "3", "4", "5", "6", "7", "8", "9", "10",
"11", "12", "13", "14", "15", "16", "17", "18", "19", "20",
},
},
{
name: "ok, cut values over extractorLimit",
givenRequest: func(req *http.Request) {
for i := 1; i <= 25; i++ {
req.Header.Add(echo.HeaderAuthorization, fmt.Sprintf("%v", i))
}
},
whenName: echo.HeaderAuthorization,
whenValuePrefix: "",
expectValues: []string{
"1", "2", "3", "4", "5", "6", "7", "8", "9", "10",
"11", "12", "13", "14", "15", "16", "17", "18", "19", "20",
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
if tc.givenRequest != nil {
tc.givenRequest(req)
}
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
extractor := valuesFromHeader(tc.whenName, tc.whenValuePrefix)
values, err := extractor(c)
assert.Equal(t, tc.expectValues, values)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
})
}
}
func TestValuesFromQuery(t *testing.T) {
var testCases = []struct {
name string
givenQueryPart string
whenName string
expectValues []string
expectError string
}{
{
name: "ok, single value",
givenQueryPart: "?id=123&name=test",
whenName: "id",
expectValues: []string{"123"},
},
{
name: "ok, multiple value",
givenQueryPart: "?id=123&id=456&name=test",
whenName: "id",
expectValues: []string{"123", "456"},
},
{
name: "nok, missing value",
givenQueryPart: "?id=123&name=test",
whenName: "nope",
expectError: errQueryExtractorValueMissing.Error(),
},
{
name: "ok, cut values over extractorLimit",
givenQueryPart: "?name=test" +
"&id=1&id=2&id=3&id=4&id=5&id=6&id=7&id=8&id=9&id=10" +
"&id=11&id=12&id=13&id=14&id=15&id=16&id=17&id=18&id=19&id=20" +
"&id=21&id=22&id=23&id=24&id=25",
whenName: "id",
expectValues: []string{
"1", "2", "3", "4", "5", "6", "7", "8", "9", "10",
"11", "12", "13", "14", "15", "16", "17", "18", "19", "20",
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/"+tc.givenQueryPart, nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
extractor := valuesFromQuery(tc.whenName)
values, err := extractor(c)
assert.Equal(t, tc.expectValues, values)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
})
}
}
func TestValuesFromParam(t *testing.T) {
examplePathParams := []pathParam{
{name: "id", value: "123"},
{name: "gid", value: "456"},
{name: "gid", value: "789"},
}
examplePathParams20 := make([]pathParam, 0)
for i := 1; i < 25; i++ {
examplePathParams20 = append(examplePathParams20, pathParam{name: "id", value: fmt.Sprintf("%v", i)})
}
var testCases = []struct {
name string
givenPathParams []pathParam
whenName string
expectValues []string
expectError string
}{
{
name: "ok, single value",
givenPathParams: examplePathParams,
whenName: "id",
expectValues: []string{"123"},
},
{
name: "ok, multiple value",
givenPathParams: examplePathParams,
whenName: "gid",
expectValues: []string{"456", "789"},
},
{
name: "nok, no values",
givenPathParams: nil,
whenName: "nope",
expectValues: nil,
expectError: errParamExtractorValueMissing.Error(),
},
{
name: "nok, no matching value",
givenPathParams: examplePathParams,
whenName: "nope",
expectValues: nil,
expectError: errParamExtractorValueMissing.Error(),
},
{
name: "ok, cut values over extractorLimit",
givenPathParams: examplePathParams20,
whenName: "id",
expectValues: []string{
"1", "2", "3", "4", "5", "6", "7", "8", "9", "10",
"11", "12", "13", "14", "15", "16", "17", "18", "19", "20",
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
if tc.givenPathParams != nil {
setPathParams(c, tc.givenPathParams)
}
extractor := valuesFromParam(tc.whenName)
values, err := extractor(c)
assert.Equal(t, tc.expectValues, values)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
})
}
}
func TestValuesFromCookie(t *testing.T) {
exampleRequest := func(req *http.Request) {
req.Header.Set(echo.HeaderCookie, "_csrf=token")
}
var testCases = []struct {
name string
givenRequest func(req *http.Request)
whenName string
expectValues []string
expectError string
}{
{
name: "ok, single value",
givenRequest: exampleRequest,
whenName: "_csrf",
expectValues: []string{"token"},
},
{
name: "ok, multiple value",
givenRequest: func(req *http.Request) {
req.Header.Add(echo.HeaderCookie, "_csrf=token")
req.Header.Add(echo.HeaderCookie, "_csrf=token2")
},
whenName: "_csrf",
expectValues: []string{"token", "token2"},
},
{
name: "nok, no matching cookie",
givenRequest: exampleRequest,
whenName: "xxx",
expectValues: nil,
expectError: errCookieExtractorValueMissing.Error(),
},
{
name: "nok, no cookies at all",
givenRequest: nil,
whenName: "xxx",
expectValues: nil,
expectError: errCookieExtractorValueMissing.Error(),
},
{
name: "ok, cut values over extractorLimit",
givenRequest: func(req *http.Request) {
for i := 1; i < 25; i++ {
req.Header.Add(echo.HeaderCookie, fmt.Sprintf("_csrf=%v", i))
}
},
whenName: "_csrf",
expectValues: []string{
"1", "2", "3", "4", "5", "6", "7", "8", "9", "10",
"11", "12", "13", "14", "15", "16", "17", "18", "19", "20",
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
if tc.givenRequest != nil {
tc.givenRequest(req)
}
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
extractor := valuesFromCookie(tc.whenName)
values, err := extractor(c)
assert.Equal(t, tc.expectValues, values)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
})
}
}
func TestValuesFromForm(t *testing.T) {
examplePostFormRequest := func(mod func(v *url.Values)) *http.Request {
f := make(url.Values)
f.Set("name", "Jon Snow")
f.Set("emails[]", "jon@labstack.com")
if mod != nil {
mod(&f)
}
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode()))
req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm)
return req
}
exampleGetFormRequest := func(mod func(v *url.Values)) *http.Request {
f := make(url.Values)
f.Set("name", "Jon Snow")
f.Set("emails[]", "jon@labstack.com")
if mod != nil {
mod(&f)
}
req := httptest.NewRequest(http.MethodGet, "/?"+f.Encode(), nil)
return req
}
var testCases = []struct {
name string
givenRequest *http.Request
whenName string
expectValues []string
expectError string
}{
{
name: "ok, POST form, single value",
givenRequest: examplePostFormRequest(nil),
whenName: "emails[]",
expectValues: []string{"jon@labstack.com"},
},
{
name: "ok, POST form, multiple value",
givenRequest: examplePostFormRequest(func(v *url.Values) {
v.Add("emails[]", "snow@labstack.com")
}),
whenName: "emails[]",
expectValues: []string{"jon@labstack.com", "snow@labstack.com"},
},
{
name: "ok, GET form, single value",
givenRequest: exampleGetFormRequest(nil),
whenName: "emails[]",
expectValues: []string{"jon@labstack.com"},
},
{
name: "ok, GET form, multiple value",
givenRequest: examplePostFormRequest(func(v *url.Values) {
v.Add("emails[]", "snow@labstack.com")
}),
whenName: "emails[]",
expectValues: []string{"jon@labstack.com", "snow@labstack.com"},
},
{
name: "nok, POST form, value missing",
givenRequest: examplePostFormRequest(nil),
whenName: "nope",
expectError: errFormExtractorValueMissing.Error(),
},
{
name: "nok, POST form, form parsing error",
givenRequest: func() *http.Request {
req := httptest.NewRequest(http.MethodPost, "/", nil)
req.Body = nil
return req
}(),
whenName: "name",
expectError: "valuesFromForm parse form failed: missing form body",
},
{
name: "ok, cut values over extractorLimit",
givenRequest: examplePostFormRequest(func(v *url.Values) {
for i := 1; i < 25; i++ {
v.Add("id[]", fmt.Sprintf("%v", i))
}
}),
whenName: "id[]",
expectValues: []string{
"1", "2", "3", "4", "5", "6", "7", "8", "9", "10",
"11", "12", "13", "14", "15", "16", "17", "18", "19", "20",
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
req := tc.givenRequest
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
extractor := valuesFromForm(tc.whenName)
values, err := extractor(c)
assert.Equal(t, tc.expectValues, values)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
})
}
}

View File

@ -1,3 +1,4 @@
//go:build go1.15
// +build go1.15 // +build go1.15
package middleware package middleware
@ -5,12 +6,10 @@ package middleware
import ( import (
"errors" "errors"
"fmt" "fmt"
"net/http"
"reflect"
"strings"
"github.com/golang-jwt/jwt" "github.com/golang-jwt/jwt"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"net/http"
"reflect"
) )
type ( type (
@ -22,7 +21,8 @@ type (
// BeforeFunc defines a function which is executed just before the middleware. // BeforeFunc defines a function which is executed just before the middleware.
BeforeFunc BeforeFunc BeforeFunc BeforeFunc
// SuccessHandler defines a function which is executed for a valid token. // SuccessHandler defines a function which is executed for a valid token before middleware chain continues with next
// middleware or handler.
SuccessHandler JWTSuccessHandler SuccessHandler JWTSuccessHandler
// ErrorHandler defines a function which is executed for an invalid token. // ErrorHandler defines a function which is executed for an invalid token.
@ -32,6 +32,13 @@ type (
// ErrorHandlerWithContext is almost identical to ErrorHandler, but it's passed the current context. // ErrorHandlerWithContext is almost identical to ErrorHandler, but it's passed the current context.
ErrorHandlerWithContext JWTErrorHandlerWithContext ErrorHandlerWithContext JWTErrorHandlerWithContext
// ContinueOnIgnoredError allows the next middleware/handler to be called when ErrorHandlerWithContext decides to
// ignore the error (by returning `nil`).
// This is useful when parts of your site/api allow public access and some authorized routes provide extra functionality.
// In that case you can use ErrorHandlerWithContext to set a default public JWT token value in the request context
// and continue. Some logic down the remaining execution chain needs to check that (public) token value then.
ContinueOnIgnoredError bool
// Signing key to validate token. // Signing key to validate token.
// This is one of the three options to provide a token validation key. // This is one of the three options to provide a token validation key.
// The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey. // The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey.
@ -61,20 +68,25 @@ type (
// to extract token from the request. // to extract token from the request.
// Optional. Default value "header:Authorization". // Optional. Default value "header:Authorization".
// Possible values: // Possible values:
// - "header:<name>" // - "header:<name>" or "header:<name>:<cut-prefix>"
// `<cut-prefix>` is argument value to cut/trim prefix of the extracted value. This is useful if header
// value has static prefix like `Authorization: <auth-scheme> <authorisation-parameters>` where part that we
// want to cut is `<auth-scheme> ` note the space at the end.
// In case of JWT tokens `Authorization: Bearer <token>` prefix we cut is `Bearer `.
// If prefix is left empty the whole value is returned.
// - "query:<name>" // - "query:<name>"
// - "param:<name>" // - "param:<name>"
// - "cookie:<name>" // - "cookie:<name>"
// - "form:<name>" // - "form:<name>"
// Multiply sources example: // Multiple sources example:
// - "header: Authorization,cookie: myowncookie" // - "header:Authorization,cookie:myowncookie"
TokenLookup string TokenLookup string
// TokenLookupFuncs defines a list of user-defined functions that extract JWT token from the given context. // TokenLookupFuncs defines a list of user-defined functions that extract JWT token from the given context.
// This is one of the two options to provide a token extractor. // This is one of the two options to provide a token extractor.
// The order of precedence is user-defined TokenLookupFuncs, and TokenLookup. // The order of precedence is user-defined TokenLookupFuncs, and TokenLookup.
// You can also provide both if you want. // You can also provide both if you want.
TokenLookupFuncs []TokenLookupFunc TokenLookupFuncs []ValuesExtractor
// AuthScheme to be used in the Authorization header. // AuthScheme to be used in the Authorization header.
// Optional. Default value "Bearer". // Optional. Default value "Bearer".
@ -100,16 +112,13 @@ type (
} }
// JWTSuccessHandler defines a function which is executed for a valid token. // JWTSuccessHandler defines a function which is executed for a valid token.
JWTSuccessHandler func(echo.Context) JWTSuccessHandler func(c echo.Context)
// JWTErrorHandler defines a function which is executed for an invalid token. // JWTErrorHandler defines a function which is executed for an invalid token.
JWTErrorHandler func(error) error JWTErrorHandler func(err error) error
// JWTErrorHandlerWithContext is almost identical to JWTErrorHandler, but it's passed the current context. // JWTErrorHandlerWithContext is almost identical to JWTErrorHandler, but it's passed the current context.
JWTErrorHandlerWithContext func(error, echo.Context) error JWTErrorHandlerWithContext func(err error, c echo.Context) error
// TokenLookupFunc defines a function for extracting JWT token from the given context.
TokenLookupFunc func(echo.Context) (string, error)
) )
// Algorithms // Algorithms
@ -183,25 +192,12 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
config.ParseTokenFunc = config.defaultParseToken config.ParseTokenFunc = config.defaultParseToken
} }
// Initialize extractors, err := createExtractors(config.TokenLookup, config.AuthScheme)
// Split sources if err != nil {
sources := strings.Split(config.TokenLookup, ",") panic(err)
var extractors = config.TokenLookupFuncs
for _, source := range sources {
parts := strings.Split(source, ":")
switch parts[0] {
case "query":
extractors = append(extractors, jwtFromQuery(parts[1]))
case "param":
extractors = append(extractors, jwtFromParam(parts[1]))
case "cookie":
extractors = append(extractors, jwtFromCookie(parts[1]))
case "form":
extractors = append(extractors, jwtFromForm(parts[1]))
case "header":
extractors = append(extractors, jwtFromHeader(parts[1], config.AuthScheme))
} }
if len(config.TokenLookupFuncs) > 0 {
extractors = append(config.TokenLookupFuncs, extractors...)
} }
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
@ -213,30 +209,21 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
if config.BeforeFunc != nil { if config.BeforeFunc != nil {
config.BeforeFunc(c) config.BeforeFunc(c)
} }
var auth string
var err error var lastExtractorErr error
var lastTokenErr error
for _, extractor := range extractors { for _, extractor := range extractors {
// Extract token from extractor, if it's not fail break the loop and auths, err := extractor(c)
// set auth
auth, err = extractor(c)
if err == nil {
break
}
}
// If none of extractor has a token, handle error
if err != nil { if err != nil {
if config.ErrorHandler != nil { lastExtractorErr = ErrJWTMissing // backwards compatibility: all extraction errors are same (unlike KeyAuth)
return config.ErrorHandler(err) continue
} }
for _, auth := range auths {
if config.ErrorHandlerWithContext != nil {
return config.ErrorHandlerWithContext(err, c)
}
return err
}
token, err := config.ParseTokenFunc(auth, c) token, err := config.ParseTokenFunc(auth, c)
if err == nil { if err != nil {
lastTokenErr = err
continue
}
// Store user information from token into context. // Store user information from token into context.
c.Set(config.ContextKey, token) c.Set(config.ContextKey, token)
if config.SuccessHandler != nil { if config.SuccessHandler != nil {
@ -244,18 +231,33 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
} }
return next(c) return next(c)
} }
}
// we are here only when we did not successfully extract or parse any of the tokens
err := lastTokenErr
if err == nil { // prioritize token errors over extracting errors
err = lastExtractorErr
}
if config.ErrorHandler != nil { if config.ErrorHandler != nil {
return config.ErrorHandler(err) return config.ErrorHandler(err)
} }
if config.ErrorHandlerWithContext != nil { if config.ErrorHandlerWithContext != nil {
return config.ErrorHandlerWithContext(err, c) tmpErr := config.ErrorHandlerWithContext(err, c)
if config.ContinueOnIgnoredError && tmpErr == nil {
return next(c)
} }
return tmpErr
}
// backwards compatible errors codes
if lastTokenErr != nil {
return &echo.HTTPError{ return &echo.HTTPError{
Code: ErrJWTInvalid.Code, Code: ErrJWTInvalid.Code,
Message: ErrJWTInvalid.Message, Message: ErrJWTInvalid.Message,
Internal: err, Internal: err,
} }
} }
return err // this is lastExtractorErr value
}
} }
} }
@ -296,59 +298,3 @@ func (config *JWTConfig) defaultKeyFunc(t *jwt.Token) (interface{}, error) {
return config.SigningKey, nil return config.SigningKey, nil
} }
// jwtFromHeader returns a `TokenLookupFunc` that extracts token from the request header.
func jwtFromHeader(header string, authScheme string) TokenLookupFunc {
return func(c echo.Context) (string, error) {
auth := c.Request().Header.Get(header)
l := len(authScheme)
if len(auth) > l+1 && strings.EqualFold(auth[:l], authScheme) {
return auth[l+1:], nil
}
return "", ErrJWTMissing
}
}
// jwtFromQuery returns a `TokenLookupFunc` that extracts token from the query string.
func jwtFromQuery(param string) TokenLookupFunc {
return func(c echo.Context) (string, error) {
token := c.QueryParam(param)
if token == "" {
return "", ErrJWTMissing
}
return token, nil
}
}
// jwtFromParam returns a `TokenLookupFunc` that extracts token from the url param string.
func jwtFromParam(param string) TokenLookupFunc {
return func(c echo.Context) (string, error) {
token := c.Param(param)
if token == "" {
return "", ErrJWTMissing
}
return token, nil
}
}
// jwtFromCookie returns a `TokenLookupFunc` that extracts token from the named cookie.
func jwtFromCookie(name string) TokenLookupFunc {
return func(c echo.Context) (string, error) {
cookie, err := c.Cookie(name)
if err != nil {
return "", ErrJWTMissing
}
return cookie.Value, nil
}
}
// jwtFromForm returns a `TokenLookupFunc` that extracts token from the form field.
func jwtFromForm(name string) TokenLookupFunc {
return func(c echo.Context) (string, error) {
field := c.FormValue(name)
if field == "" {
return "", ErrJWTMissing
}
return field, nil
}
}

View File

@ -1,3 +1,4 @@
//go:build go1.15
// +build go1.15 // +build go1.15
package middleware package middleware
@ -28,6 +29,26 @@ type jwtCustomClaims struct {
jwtCustomInfo jwtCustomInfo
} }
func TestJWT(t *testing.T) {
e := echo.New()
e.GET("/", func(c echo.Context) error {
token := c.Get("user").(*jwt.Token)
return c.JSON(http.StatusOK, token.Claims)
})
e.Use(JWT([]byte("secret")))
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAuthorization, "bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ")
res := httptest.NewRecorder()
e.ServeHTTP(res, req)
assert.Equal(t, http.StatusOK, res.Code)
assert.Equal(t, `{"admin":true,"name":"John Doe","sub":"1234567890"}`+"\n", res.Body.String())
}
func TestJWTRace(t *testing.T) { func TestJWTRace(t *testing.T) {
e := echo.New() e := echo.New()
handler := func(c echo.Context) error { handler := func(c echo.Context) error {
@ -64,8 +85,7 @@ func TestJWTRace(t *testing.T) {
assert.Equal(t, claims.Admin, true) assert.Equal(t, claims.Admin, true)
} }
func TestJWT(t *testing.T) { func TestJWTConfig(t *testing.T) {
e := echo.New()
handler := func(c echo.Context) error { handler := func(c echo.Context) error {
return c.String(http.StatusOK, "test") return c.String(http.StatusOK, "test")
} }
@ -74,7 +94,8 @@ func TestJWT(t *testing.T) {
invalidKey := []byte("invalid-key") invalidKey := []byte("invalid-key")
validAuth := DefaultJWTConfig.AuthScheme + " " + token validAuth := DefaultJWTConfig.AuthScheme + " " + token
for _, tc := range []struct { testCases := []struct {
name string
expPanic bool expPanic bool
expErrCode int // 0 for Success expErrCode int // 0 for Success
config JWTConfig config JWTConfig
@ -82,166 +103,166 @@ func TestJWT(t *testing.T) {
hdrAuth string hdrAuth string
hdrCookie string // test.Request doesn't provide SetCookie(); use name=val hdrCookie string // test.Request doesn't provide SetCookie(); use name=val
formValues map[string]string formValues map[string]string
info string
}{ }{
{ {
name: "No signing key provided",
expPanic: true, expPanic: true,
info: "No signing key provided",
}, },
{ {
name: "Unexpected signing method",
expErrCode: http.StatusBadRequest, expErrCode: http.StatusBadRequest,
config: JWTConfig{ config: JWTConfig{
SigningKey: validKey, SigningKey: validKey,
SigningMethod: "RS256", SigningMethod: "RS256",
}, },
info: "Unexpected signing method",
}, },
{ {
name: "Invalid key",
expErrCode: http.StatusUnauthorized, expErrCode: http.StatusUnauthorized,
hdrAuth: validAuth, hdrAuth: validAuth,
config: JWTConfig{SigningKey: invalidKey}, config: JWTConfig{SigningKey: invalidKey},
info: "Invalid key",
}, },
{ {
name: "Valid JWT",
hdrAuth: validAuth, hdrAuth: validAuth,
config: JWTConfig{SigningKey: validKey}, config: JWTConfig{SigningKey: validKey},
info: "Valid JWT",
}, },
{ {
name: "Valid JWT with custom AuthScheme",
hdrAuth: "Token" + " " + token, hdrAuth: "Token" + " " + token,
config: JWTConfig{AuthScheme: "Token", SigningKey: validKey}, config: JWTConfig{AuthScheme: "Token", SigningKey: validKey},
info: "Valid JWT with custom AuthScheme",
}, },
{ {
name: "Valid JWT with custom claims",
hdrAuth: validAuth, hdrAuth: validAuth,
config: JWTConfig{ config: JWTConfig{
Claims: &jwtCustomClaims{}, Claims: &jwtCustomClaims{},
SigningKey: []byte("secret"), SigningKey: []byte("secret"),
}, },
info: "Valid JWT with custom claims",
}, },
{ {
name: "Invalid Authorization header",
hdrAuth: "invalid-auth", hdrAuth: "invalid-auth",
expErrCode: http.StatusBadRequest, expErrCode: http.StatusBadRequest,
config: JWTConfig{SigningKey: validKey}, config: JWTConfig{SigningKey: validKey},
info: "Invalid Authorization header",
}, },
{ {
name: "Empty header auth field",
config: JWTConfig{SigningKey: validKey}, config: JWTConfig{SigningKey: validKey},
expErrCode: http.StatusBadRequest, expErrCode: http.StatusBadRequest,
info: "Empty header auth field",
}, },
{ {
name: "Valid query method",
config: JWTConfig{ config: JWTConfig{
SigningKey: validKey, SigningKey: validKey,
TokenLookup: "query:jwt", TokenLookup: "query:jwt",
}, },
reqURL: "/?a=b&jwt=" + token, reqURL: "/?a=b&jwt=" + token,
info: "Valid query method",
}, },
{ {
name: "Invalid query param name",
config: JWTConfig{ config: JWTConfig{
SigningKey: validKey, SigningKey: validKey,
TokenLookup: "query:jwt", TokenLookup: "query:jwt",
}, },
reqURL: "/?a=b&jwtxyz=" + token, reqURL: "/?a=b&jwtxyz=" + token,
expErrCode: http.StatusBadRequest, expErrCode: http.StatusBadRequest,
info: "Invalid query param name",
}, },
{ {
name: "Invalid query param value",
config: JWTConfig{ config: JWTConfig{
SigningKey: validKey, SigningKey: validKey,
TokenLookup: "query:jwt", TokenLookup: "query:jwt",
}, },
reqURL: "/?a=b&jwt=invalid-token", reqURL: "/?a=b&jwt=invalid-token",
expErrCode: http.StatusUnauthorized, expErrCode: http.StatusUnauthorized,
info: "Invalid query param value",
}, },
{ {
name: "Empty query",
config: JWTConfig{ config: JWTConfig{
SigningKey: validKey, SigningKey: validKey,
TokenLookup: "query:jwt", TokenLookup: "query:jwt",
}, },
reqURL: "/?a=b", reqURL: "/?a=b",
expErrCode: http.StatusBadRequest, expErrCode: http.StatusBadRequest,
info: "Empty query",
}, },
{ {
name: "Valid param method",
config: JWTConfig{ config: JWTConfig{
SigningKey: validKey, SigningKey: validKey,
TokenLookup: "param:jwt", TokenLookup: "param:jwt",
}, },
reqURL: "/" + token, reqURL: "/" + token,
info: "Valid param method",
}, },
{ {
name: "Valid cookie method",
config: JWTConfig{ config: JWTConfig{
SigningKey: validKey, SigningKey: validKey,
TokenLookup: "cookie:jwt", TokenLookup: "cookie:jwt",
}, },
hdrCookie: "jwt=" + token, hdrCookie: "jwt=" + token,
info: "Valid cookie method",
}, },
{ {
name: "Multiple jwt lookuop",
config: JWTConfig{ config: JWTConfig{
SigningKey: validKey, SigningKey: validKey,
TokenLookup: "query:jwt,cookie:jwt", TokenLookup: "query:jwt,cookie:jwt",
}, },
hdrCookie: "jwt=" + token, hdrCookie: "jwt=" + token,
info: "Multiple jwt lookuop",
}, },
{ {
name: "Invalid token with cookie method",
config: JWTConfig{ config: JWTConfig{
SigningKey: validKey, SigningKey: validKey,
TokenLookup: "cookie:jwt", TokenLookup: "cookie:jwt",
}, },
expErrCode: http.StatusUnauthorized, expErrCode: http.StatusUnauthorized,
hdrCookie: "jwt=invalid", hdrCookie: "jwt=invalid",
info: "Invalid token with cookie method",
}, },
{ {
name: "Empty cookie",
config: JWTConfig{ config: JWTConfig{
SigningKey: validKey, SigningKey: validKey,
TokenLookup: "cookie:jwt", TokenLookup: "cookie:jwt",
}, },
expErrCode: http.StatusBadRequest, expErrCode: http.StatusBadRequest,
info: "Empty cookie",
}, },
{ {
name: "Valid form method",
config: JWTConfig{ config: JWTConfig{
SigningKey: validKey, SigningKey: validKey,
TokenLookup: "form:jwt", TokenLookup: "form:jwt",
}, },
formValues: map[string]string{"jwt": token}, formValues: map[string]string{"jwt": token},
info: "Valid form method",
}, },
{ {
name: "Invalid token with form method",
config: JWTConfig{ config: JWTConfig{
SigningKey: validKey, SigningKey: validKey,
TokenLookup: "form:jwt", TokenLookup: "form:jwt",
}, },
expErrCode: http.StatusUnauthorized, expErrCode: http.StatusUnauthorized,
formValues: map[string]string{"jwt": "invalid"}, formValues: map[string]string{"jwt": "invalid"},
info: "Invalid token with form method",
}, },
{ {
name: "Empty form field",
config: JWTConfig{ config: JWTConfig{
SigningKey: validKey, SigningKey: validKey,
TokenLookup: "form:jwt", TokenLookup: "form:jwt",
}, },
expErrCode: http.StatusBadRequest, expErrCode: http.StatusBadRequest,
info: "Empty form field",
}, },
{ {
name: "Valid JWT with a valid key using a user-defined KeyFunc",
hdrAuth: validAuth, hdrAuth: validAuth,
config: JWTConfig{ config: JWTConfig{
KeyFunc: func(*jwt.Token) (interface{}, error) { KeyFunc: func(*jwt.Token) (interface{}, error) {
return validKey, nil return validKey, nil
}, },
}, },
info: "Valid JWT with a valid key using a user-defined KeyFunc",
}, },
{ {
name: "Valid JWT with an invalid key using a user-defined KeyFunc",
hdrAuth: validAuth, hdrAuth: validAuth,
config: JWTConfig{ config: JWTConfig{
KeyFunc: func(*jwt.Token) (interface{}, error) { KeyFunc: func(*jwt.Token) (interface{}, error) {
@ -249,9 +270,9 @@ func TestJWT(t *testing.T) {
}, },
}, },
expErrCode: http.StatusUnauthorized, expErrCode: http.StatusUnauthorized,
info: "Valid JWT with an invalid key using a user-defined KeyFunc",
}, },
{ {
name: "Token verification does not pass using a user-defined KeyFunc",
hdrAuth: validAuth, hdrAuth: validAuth,
config: JWTConfig{ config: JWTConfig{
KeyFunc: func(*jwt.Token) (interface{}, error) { KeyFunc: func(*jwt.Token) (interface{}, error) {
@ -259,14 +280,16 @@ func TestJWT(t *testing.T) {
}, },
}, },
expErrCode: http.StatusUnauthorized, expErrCode: http.StatusUnauthorized,
info: "Token verification does not pass using a user-defined KeyFunc",
}, },
{ {
name: "Valid JWT with lower case AuthScheme",
hdrAuth: strings.ToLower(DefaultJWTConfig.AuthScheme) + " " + token, hdrAuth: strings.ToLower(DefaultJWTConfig.AuthScheme) + " " + token,
config: JWTConfig{SigningKey: validKey}, config: JWTConfig{SigningKey: validKey},
info: "Valid JWT with lower case AuthScheme",
}, },
} { }
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
if tc.reqURL == "" { if tc.reqURL == "" {
tc.reqURL = "/" tc.reqURL = "/"
} }
@ -296,30 +319,31 @@ func TestJWT(t *testing.T) {
if tc.expPanic { if tc.expPanic {
assert.Panics(t, func() { assert.Panics(t, func() {
JWTWithConfig(tc.config) JWTWithConfig(tc.config)
}, tc.info) }, tc.name)
continue return
} }
if tc.expErrCode != 0 { if tc.expErrCode != 0 {
h := JWTWithConfig(tc.config)(handler) h := JWTWithConfig(tc.config)(handler)
he := h(c).(*echo.HTTPError) he := h(c).(*echo.HTTPError)
assert.Equal(t, tc.expErrCode, he.Code, tc.info) assert.Equal(t, tc.expErrCode, he.Code, tc.name)
continue return
} }
h := JWTWithConfig(tc.config)(handler) h := JWTWithConfig(tc.config)(handler)
if assert.NoError(t, h(c), tc.info) { if assert.NoError(t, h(c), tc.name) {
user := c.Get("user").(*jwt.Token) user := c.Get("user").(*jwt.Token)
switch claims := user.Claims.(type) { switch claims := user.Claims.(type) {
case jwt.MapClaims: case jwt.MapClaims:
assert.Equal(t, claims["name"], "John Doe", tc.info) assert.Equal(t, claims["name"], "John Doe", tc.name)
case *jwtCustomClaims: case *jwtCustomClaims:
assert.Equal(t, claims.Name, "John Doe", tc.info) assert.Equal(t, claims.Name, "John Doe", tc.name)
assert.Equal(t, claims.Admin, true, tc.info) assert.Equal(t, claims.Admin, true, tc.name)
default: default:
panic("unexpected type of claims") panic("unexpected type of claims")
} }
} }
})
} }
} }
@ -608,13 +632,14 @@ func TestJWTConfig_TokenLookupFuncs(t *testing.T) {
e := echo.New() e := echo.New()
e.GET("/", func(c echo.Context) error { e.GET("/", func(c echo.Context) error {
return c.String(http.StatusOK, "test") token := c.Get("user").(*jwt.Token)
return c.JSON(http.StatusOK, token.Claims)
}) })
e.Use(JWTWithConfig(JWTConfig{ e.Use(JWTWithConfig(JWTConfig{
TokenLookupFuncs: []TokenLookupFunc{ TokenLookupFuncs: []ValuesExtractor{
func(c echo.Context) (string, error) { func(c echo.Context) ([]string, error) {
return c.Request().Header.Get("X-API-Key"), nil return []string{c.Request().Header.Get("X-API-Key")}, nil
}, },
}, },
SigningKey: []byte("secret"), SigningKey: []byte("secret"),
@ -626,4 +651,129 @@ func TestJWTConfig_TokenLookupFuncs(t *testing.T) {
e.ServeHTTP(res, req) e.ServeHTTP(res, req)
assert.Equal(t, http.StatusOK, res.Code) assert.Equal(t, http.StatusOK, res.Code)
assert.Equal(t, `{"admin":true,"name":"John Doe","sub":"1234567890"}`+"\n", res.Body.String())
}
func TestJWTConfig_SuccessHandler(t *testing.T) {
var testCases = []struct {
name string
givenToken string
expectCalled bool
expectStatus int
}{
{
name: "ok, success handler is called",
givenToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ",
expectCalled: true,
expectStatus: http.StatusOK,
},
{
name: "nok, success handler is not called",
givenToken: "x.x.x",
expectCalled: false,
expectStatus: http.StatusUnauthorized,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
e.GET("/", func(c echo.Context) error {
token := c.Get("user").(*jwt.Token)
return c.JSON(http.StatusOK, token.Claims)
})
wasCalled := false
e.Use(JWTWithConfig(JWTConfig{
SuccessHandler: func(c echo.Context) {
wasCalled = true
},
SigningKey: []byte("secret"),
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAuthorization, "bearer "+tc.givenToken)
res := httptest.NewRecorder()
e.ServeHTTP(res, req)
assert.Equal(t, tc.expectCalled, wasCalled)
assert.Equal(t, tc.expectStatus, res.Code)
})
}
}
func TestJWTConfig_ContinueOnIgnoredError(t *testing.T) {
var testCases = []struct {
name string
whenContinueOnIgnoredError bool
givenToken string
expectStatus int
expectBody string
}{
{
name: "no error handler is called",
whenContinueOnIgnoredError: true,
givenToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ",
expectStatus: http.StatusTeapot,
expectBody: "",
},
{
name: "ContinueOnIgnoredError is false and error handler is called for missing token",
whenContinueOnIgnoredError: false,
givenToken: "",
// empty response with 200. This emulates previous behaviour when error handler swallowed the error
expectStatus: http.StatusOK,
expectBody: "",
},
{
name: "error handler is called for missing token",
whenContinueOnIgnoredError: true,
givenToken: "",
expectStatus: http.StatusTeapot,
expectBody: "public-token",
},
{
name: "error handler is called for invalid token",
whenContinueOnIgnoredError: true,
givenToken: "x.x.x",
expectStatus: http.StatusUnauthorized,
expectBody: "{\"message\":\"Unauthorized\"}\n",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
e.GET("/", func(c echo.Context) error {
testValue, _ := c.Get("test").(string)
return c.String(http.StatusTeapot, testValue)
})
e.Use(JWTWithConfig(JWTConfig{
ContinueOnIgnoredError: tc.whenContinueOnIgnoredError,
SigningKey: []byte("secret"),
ErrorHandlerWithContext: func(err error, c echo.Context) error {
if err == ErrJWTMissing {
c.Set("test", "public-token")
return nil
}
return echo.ErrUnauthorized
},
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
if tc.givenToken != "" {
req.Header.Set(echo.HeaderAuthorization, "bearer "+tc.givenToken)
}
res := httptest.NewRecorder()
e.ServeHTTP(res, req)
assert.Equal(t, tc.expectStatus, res.Code)
assert.Equal(t, tc.expectBody, res.Body.String())
})
}
} }

View File

@ -2,11 +2,8 @@ package middleware
import ( import (
"errors" "errors"
"fmt"
"net/http"
"strings"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"net/http"
) )
type ( type (
@ -15,15 +12,21 @@ type (
// Skipper defines a function to skip middleware. // Skipper defines a function to skip middleware.
Skipper Skipper Skipper Skipper
// KeyLookup is a string in the form of "<source>:<name>" that is used // KeyLookup is a string in the form of "<source>:<name>" or "<source>:<name>,<source>:<name>" that is used
// to extract key from the request. // to extract key from the request.
// Optional. Default value "header:Authorization". // Optional. Default value "header:Authorization".
// Possible values: // Possible values:
// - "header:<name>" // - "header:<name>" or "header:<name>:<cut-prefix>"
// `<cut-prefix>` is argument value to cut/trim prefix of the extracted value. This is useful if header
// value has static prefix like `Authorization: <auth-scheme> <authorisation-parameters>` where part that we
// want to cut is `<auth-scheme> ` note the space at the end.
// In case of basic authentication `Authorization: Basic <credentials>` prefix we want to remove is `Basic `.
// - "query:<name>" // - "query:<name>"
// - "form:<name>" // - "form:<name>"
// - "cookie:<name>" // - "cookie:<name>"
KeyLookup string `yaml:"key_lookup"` // Multiple sources example:
// - "header:Authorization,header:X-Api-Key"
KeyLookup string
// AuthScheme to be used in the Authorization header. // AuthScheme to be used in the Authorization header.
// Optional. Default value "Bearer". // Optional. Default value "Bearer".
@ -36,15 +39,20 @@ type (
// ErrorHandler defines a function which is executed for an invalid key. // ErrorHandler defines a function which is executed for an invalid key.
// It may be used to define a custom error. // It may be used to define a custom error.
ErrorHandler KeyAuthErrorHandler ErrorHandler KeyAuthErrorHandler
// ContinueOnIgnoredError allows the next middleware/handler to be called when ErrorHandler decides to
// ignore the error (by returning `nil`).
// This is useful when parts of your site/api allow public access and some authorized routes provide extra functionality.
// In that case you can use ErrorHandler to set a default public key auth value in the request context
// and continue. Some logic down the remaining execution chain needs to check that (public) key auth value then.
ContinueOnIgnoredError bool
} }
// KeyAuthValidator defines a function to validate KeyAuth credentials. // KeyAuthValidator defines a function to validate KeyAuth credentials.
KeyAuthValidator func(string, echo.Context) (bool, error) KeyAuthValidator func(auth string, c echo.Context) (bool, error)
keyExtractor func(echo.Context) (string, error)
// KeyAuthErrorHandler defines a function which is executed for an invalid key. // KeyAuthErrorHandler defines a function which is executed for an invalid key.
KeyAuthErrorHandler func(error, echo.Context) error KeyAuthErrorHandler func(err error, c echo.Context) error
) )
var ( var (
@ -56,6 +64,21 @@ var (
} }
) )
// ErrKeyAuthMissing is error type when KeyAuth middleware is unable to extract value from lookups
type ErrKeyAuthMissing struct {
Err error
}
// Error returns errors text
func (e *ErrKeyAuthMissing) Error() string {
return e.Err.Error()
}
// Unwrap unwraps error
func (e *ErrKeyAuthMissing) Unwrap() error {
return e.Err
}
// KeyAuth returns an KeyAuth middleware. // KeyAuth returns an KeyAuth middleware.
// //
// For valid key it calls the next handler. // For valid key it calls the next handler.
@ -85,16 +108,9 @@ func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc {
panic("echo: key-auth middleware requires a validator function") panic("echo: key-auth middleware requires a validator function")
} }
// Initialize extractors, err := createExtractors(config.KeyLookup, config.AuthScheme)
parts := strings.Split(config.KeyLookup, ":") if err != nil {
extractor := keyFromHeader(parts[1], config.AuthScheme) panic(err)
switch parts[0] {
case "query":
extractor = keyFromQuery(parts[1])
case "form":
extractor = keyFromForm(parts[1])
case "cookie":
extractor = keyFromCookie(parts[1])
} }
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
@ -103,79 +119,62 @@ func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc {
return next(c) return next(c)
} }
// Extract and verify key var lastExtractorErr error
key, err := extractor(c) var lastValidatorErr error
for _, extractor := range extractors {
keys, err := extractor(c)
if err != nil { if err != nil {
lastExtractorErr = err
continue
}
for _, key := range keys {
valid, err := config.Validator(key, c)
if err != nil {
lastValidatorErr = err
continue
}
if valid {
return next(c)
}
lastValidatorErr = errors.New("invalid key")
}
}
// we are here only when we did not successfully extract and validate any of keys
err := lastValidatorErr
if err == nil { // prioritize validator errors over extracting errors
// ugly part to preserve backwards compatible errors. someone could rely on them
if lastExtractorErr == errQueryExtractorValueMissing {
err = errors.New("missing key in the query string")
} else if lastExtractorErr == errCookieExtractorValueMissing {
err = errors.New("missing key in cookies")
} else if lastExtractorErr == errFormExtractorValueMissing {
err = errors.New("missing key in the form")
} else if lastExtractorErr == errHeaderExtractorValueMissing {
err = errors.New("missing key in request header")
} else if lastExtractorErr == errHeaderExtractorValueInvalid {
err = errors.New("invalid key in the request header")
} else {
err = lastExtractorErr
}
err = &ErrKeyAuthMissing{Err: err}
}
if config.ErrorHandler != nil { if config.ErrorHandler != nil {
return config.ErrorHandler(err, c) tmpErr := config.ErrorHandler(err, c)
if config.ContinueOnIgnoredError && tmpErr == nil {
return next(c)
}
return tmpErr
}
if lastValidatorErr != nil { // prioritize validator errors over extracting errors
return &echo.HTTPError{
Code: http.StatusUnauthorized,
Message: "Unauthorized",
Internal: lastValidatorErr,
}
} }
return echo.NewHTTPError(http.StatusBadRequest, err.Error()) return echo.NewHTTPError(http.StatusBadRequest, err.Error())
} }
valid, err := config.Validator(key, c)
if err != nil {
if config.ErrorHandler != nil {
return config.ErrorHandler(err, c)
}
return &echo.HTTPError{
Code: http.StatusUnauthorized,
Message: "invalid key",
Internal: err,
}
} else if valid {
return next(c)
}
return echo.ErrUnauthorized
}
}
}
// keyFromHeader returns a `keyExtractor` that extracts key from the request header.
func keyFromHeader(header string, authScheme string) keyExtractor {
return func(c echo.Context) (string, error) {
auth := c.Request().Header.Get(header)
if auth == "" {
return "", errors.New("missing key in request header")
}
if header == echo.HeaderAuthorization {
l := len(authScheme)
if len(auth) > l+1 && auth[:l] == authScheme {
return auth[l+1:], nil
}
return "", errors.New("invalid key in the request header")
}
return auth, nil
}
}
// keyFromQuery returns a `keyExtractor` that extracts key from the query string.
func keyFromQuery(param string) keyExtractor {
return func(c echo.Context) (string, error) {
key := c.QueryParam(param)
if key == "" {
return "", errors.New("missing key in the query string")
}
return key, nil
}
}
// keyFromForm returns a `keyExtractor` that extracts key from the form.
func keyFromForm(param string) keyExtractor {
return func(c echo.Context) (string, error) {
key := c.FormValue(param)
if key == "" {
return "", errors.New("missing key in the form")
}
return key, nil
}
}
// keyFromCookie returns a `keyExtractor` that extracts key from the form.
func keyFromCookie(cookieName string) keyExtractor {
return func(c echo.Context) (string, error) {
key, err := c.Cookie(cookieName)
if err != nil {
return "", fmt.Errorf("missing key in cookies: %w", err)
}
return key.Value, nil
} }
} }

View File

@ -76,7 +76,7 @@ func TestKeyAuthWithConfig(t *testing.T) {
req.Header.Set(echo.HeaderAuthorization, "Bearer invalid-key") req.Header.Set(echo.HeaderAuthorization, "Bearer invalid-key")
}, },
expectHandlerCalled: false, expectHandlerCalled: false,
expectError: "code=401, message=Unauthorized", expectError: "code=401, message=Unauthorized, internal=invalid key",
}, },
{ {
name: "nok, defaults, invalid scheme in header", name: "nok, defaults, invalid scheme in header",
@ -92,6 +92,17 @@ func TestKeyAuthWithConfig(t *testing.T) {
expectHandlerCalled: false, expectHandlerCalled: false,
expectError: "code=400, message=missing key in request header", expectError: "code=400, message=missing key in request header",
}, },
{
name: "ok, custom key lookup from multiple places, query and header",
givenRequest: func(req *http.Request) {
req.URL.RawQuery = "key=invalid-key"
req.Header.Set("API-Key", "valid-key")
},
whenConfig: func(conf *KeyAuthConfig) {
conf.KeyLookup = "query:key,header:API-Key"
},
expectHandlerCalled: true,
},
{ {
name: "ok, custom key lookup, header", name: "ok, custom key lookup, header",
givenRequest: func(req *http.Request) { givenRequest: func(req *http.Request) {
@ -179,7 +190,7 @@ func TestKeyAuthWithConfig(t *testing.T) {
conf.KeyLookup = "cookie:key" conf.KeyLookup = "cookie:key"
}, },
expectHandlerCalled: false, expectHandlerCalled: false,
expectError: "code=400, message=missing key in cookies: http: named cookie not present", expectError: "code=400, message=missing key in cookies",
}, },
{ {
name: "nok, custom errorHandler, error from extractor", name: "nok, custom errorHandler, error from extractor",
@ -216,7 +227,7 @@ func TestKeyAuthWithConfig(t *testing.T) {
}, },
whenConfig: func(conf *KeyAuthConfig) {}, whenConfig: func(conf *KeyAuthConfig) {},
expectHandlerCalled: false, expectHandlerCalled: false,
expectError: "code=401, message=invalid key, internal=some user defined error", expectError: "code=401, message=Unauthorized, internal=some user defined error",
}, },
} }
@ -257,3 +268,109 @@ func TestKeyAuthWithConfig(t *testing.T) {
}) })
} }
} }
func TestKeyAuthWithConfig_panicsOnInvalidLookup(t *testing.T) {
assert.PanicsWithError(
t,
"extractor source for lookup could not be split into needed parts: a",
func() {
handler := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
KeyAuthWithConfig(KeyAuthConfig{
Validator: testKeyValidator,
KeyLookup: "a",
})(handler)
},
)
}
func TestKeyAuthWithConfig_panicsOnEmptyValidator(t *testing.T) {
assert.PanicsWithValue(
t,
"echo: key-auth middleware requires a validator function",
func() {
handler := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
KeyAuthWithConfig(KeyAuthConfig{
Validator: nil,
})(handler)
},
)
}
func TestKeyAuthWithConfig_ContinueOnIgnoredError(t *testing.T) {
var testCases = []struct {
name string
whenContinueOnIgnoredError bool
givenKey string
expectStatus int
expectBody string
}{
{
name: "no error handler is called",
whenContinueOnIgnoredError: true,
givenKey: "valid-key",
expectStatus: http.StatusTeapot,
expectBody: "",
},
{
name: "ContinueOnIgnoredError is false and error handler is called for missing token",
whenContinueOnIgnoredError: false,
givenKey: "",
// empty response with 200. This emulates previous behaviour when error handler swallowed the error
expectStatus: http.StatusOK,
expectBody: "",
},
{
name: "error handler is called for missing token",
whenContinueOnIgnoredError: true,
givenKey: "",
expectStatus: http.StatusTeapot,
expectBody: "public-auth",
},
{
name: "error handler is called for invalid token",
whenContinueOnIgnoredError: true,
givenKey: "x.x.x",
expectStatus: http.StatusUnauthorized,
expectBody: "{\"message\":\"Unauthorized\"}\n",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
e.GET("/", func(c echo.Context) error {
testValue, _ := c.Get("test").(string)
return c.String(http.StatusTeapot, testValue)
})
e.Use(KeyAuthWithConfig(KeyAuthConfig{
Validator: testKeyValidator,
ErrorHandler: func(err error, c echo.Context) error {
if _, ok := err.(*ErrKeyAuthMissing); ok {
c.Set("test", "public-auth")
return nil
}
return echo.ErrUnauthorized
},
KeyLookup: "header:X-API-Key",
ContinueOnIgnoredError: tc.whenContinueOnIgnoredError,
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
if tc.givenKey != "" {
req.Header.Set("X-API-Key", tc.givenKey)
}
res := httptest.NewRecorder()
e.ServeHTTP(res, req)
assert.Equal(t, tc.expectStatus, res.Code)
assert.Equal(t, tc.expectBody, res.Body.String())
})
}
}

View File

@ -12,10 +12,10 @@ import (
type ( type (
// Skipper defines a function to skip middleware. Returning true skips processing // Skipper defines a function to skip middleware. Returning true skips processing
// the middleware. // the middleware.
Skipper func(echo.Context) bool Skipper func(c echo.Context) bool
// BeforeFunc defines a function which is executed just before the middleware. // BeforeFunc defines a function which is executed just before the middleware.
BeforeFunc func(echo.Context) BeforeFunc func(c echo.Context)
) )
func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer { func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer {