mirror of
https://github.com/labstack/echo.git
synced 2025-01-12 01:22:21 +02:00
JWT, KeyAuth, CSRF multivalue extractors (#2060)
* CSRF, JWT, KeyAuth middleware support for multivalue value extractors * Add flag to JWT and KeyAuth middleware to allow continuing execution `next(c)` when error handler decides to swallow the error (returns nil).
This commit is contained in:
parent
9e9924d763
commit
4a1ccdfdc5
4
echo.go
4
echo.go
@ -111,10 +111,10 @@ type (
|
||||
}
|
||||
|
||||
// MiddlewareFunc defines a function to process middleware.
|
||||
MiddlewareFunc func(HandlerFunc) HandlerFunc
|
||||
MiddlewareFunc func(next HandlerFunc) HandlerFunc
|
||||
|
||||
// HandlerFunc defines a function to serve HTTP requests.
|
||||
HandlerFunc func(Context) error
|
||||
HandlerFunc func(c Context) error
|
||||
|
||||
// HTTPErrorHandler is a centralized HTTP error handler.
|
||||
HTTPErrorHandler func(error, Context)
|
||||
|
@ -2,9 +2,7 @@ package middleware
|
||||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
@ -21,13 +19,15 @@ type (
|
||||
TokenLength uint8 `yaml:"token_length"`
|
||||
// Optional. Default value 32.
|
||||
|
||||
// TokenLookup is a string in the form of "<source>:<key>" that is used
|
||||
// TokenLookup is a string in the form of "<source>:<name>" or "<source>:<name>,<source>:<name>" that is used
|
||||
// to extract token from the request.
|
||||
// Optional. Default value "header:X-CSRF-Token".
|
||||
// Possible values:
|
||||
// - "header:<name>"
|
||||
// - "form:<name>"
|
||||
// - "header:<name>" or "header:<name>:<cut-prefix>"
|
||||
// - "query:<name>"
|
||||
// - "form:<name>"
|
||||
// Multiple sources example:
|
||||
// - "header:X-CSRF-Token,query:csrf"
|
||||
TokenLookup string `yaml:"token_lookup"`
|
||||
|
||||
// Context key to store generated CSRF token into context.
|
||||
@ -62,12 +62,11 @@ type (
|
||||
// Optional. Default value SameSiteDefaultMode.
|
||||
CookieSameSite http.SameSite `yaml:"cookie_same_site"`
|
||||
}
|
||||
|
||||
// csrfTokenExtractor defines a function that takes `echo.Context` and returns
|
||||
// either a token or an error.
|
||||
csrfTokenExtractor func(echo.Context) (string, error)
|
||||
)
|
||||
|
||||
// ErrCSRFInvalid is returned when CSRF check fails
|
||||
var ErrCSRFInvalid = echo.NewHTTPError(http.StatusForbidden, "invalid csrf token")
|
||||
|
||||
var (
|
||||
// DefaultCSRFConfig is the default CSRF middleware config.
|
||||
DefaultCSRFConfig = CSRFConfig{
|
||||
@ -114,14 +113,9 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
|
||||
config.CookieSecure = true
|
||||
}
|
||||
|
||||
// Initialize
|
||||
parts := strings.Split(config.TokenLookup, ":")
|
||||
extractor := csrfTokenFromHeader(parts[1])
|
||||
switch parts[0] {
|
||||
case "form":
|
||||
extractor = csrfTokenFromForm(parts[1])
|
||||
case "query":
|
||||
extractor = csrfTokenFromQuery(parts[1])
|
||||
extractors, err := createExtractors(config.TokenLookup, "")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
@ -130,28 +124,50 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
|
||||
return next(c)
|
||||
}
|
||||
|
||||
req := c.Request()
|
||||
k, err := c.Cookie(config.CookieName)
|
||||
token := ""
|
||||
|
||||
// Generate token
|
||||
if err != nil {
|
||||
token = random.String(config.TokenLength)
|
||||
if k, err := c.Cookie(config.CookieName); err != nil {
|
||||
token = random.String(config.TokenLength) // Generate token
|
||||
} else {
|
||||
// Reuse token
|
||||
token = k.Value
|
||||
token = k.Value // Reuse token
|
||||
}
|
||||
|
||||
switch req.Method {
|
||||
switch c.Request().Method {
|
||||
case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
|
||||
default:
|
||||
// Validate token only for requests which are not defined as 'safe' by RFC7231
|
||||
clientToken, err := extractor(c)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
var lastExtractorErr error
|
||||
var lastTokenErr error
|
||||
outer:
|
||||
for _, extractor := range extractors {
|
||||
clientTokens, err := extractor(c)
|
||||
if err != nil {
|
||||
lastExtractorErr = err
|
||||
continue
|
||||
}
|
||||
|
||||
for _, clientToken := range clientTokens {
|
||||
if validateCSRFToken(token, clientToken) {
|
||||
lastTokenErr = nil
|
||||
lastExtractorErr = nil
|
||||
break outer
|
||||
}
|
||||
lastTokenErr = ErrCSRFInvalid
|
||||
}
|
||||
}
|
||||
if !validateCSRFToken(token, clientToken) {
|
||||
return echo.NewHTTPError(http.StatusForbidden, "invalid csrf token")
|
||||
if lastTokenErr != nil {
|
||||
return lastTokenErr
|
||||
} else if lastExtractorErr != nil {
|
||||
// ugly part to preserve backwards compatible errors. someone could rely on them
|
||||
if lastExtractorErr == errQueryExtractorValueMissing {
|
||||
lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in the query string")
|
||||
} else if lastExtractorErr == errFormExtractorValueMissing {
|
||||
lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in the form parameter")
|
||||
} else if lastExtractorErr == errHeaderExtractorValueMissing {
|
||||
lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in request header")
|
||||
} else {
|
||||
lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, lastExtractorErr.Error())
|
||||
}
|
||||
return lastExtractorErr
|
||||
}
|
||||
}
|
||||
|
||||
@ -184,38 +200,6 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the
|
||||
// provided request header.
|
||||
func csrfTokenFromHeader(header string) csrfTokenExtractor {
|
||||
return func(c echo.Context) (string, error) {
|
||||
return c.Request().Header.Get(header), nil
|
||||
}
|
||||
}
|
||||
|
||||
// csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the
|
||||
// provided form parameter.
|
||||
func csrfTokenFromForm(param string) csrfTokenExtractor {
|
||||
return func(c echo.Context) (string, error) {
|
||||
token := c.FormValue(param)
|
||||
if token == "" {
|
||||
return "", errors.New("missing csrf token in the form parameter")
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
|
||||
// csrfTokenFromQuery returns a `csrfTokenExtractor` that extracts token from the
|
||||
// provided query parameter.
|
||||
func csrfTokenFromQuery(param string) csrfTokenExtractor {
|
||||
return func(c echo.Context) (string, error) {
|
||||
token := c.QueryParam(param)
|
||||
if token == "" {
|
||||
return "", errors.New("missing csrf token in the query string")
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
|
||||
func validateCSRFToken(token, clientToken string) bool {
|
||||
return subtle.ConstantTimeCompare([]byte(token), []byte(clientToken)) == 1
|
||||
}
|
||||
|
@ -1,7 +1,6 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
@ -13,14 +12,205 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCSRF_tokenExtractors(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
whenTokenLookup string
|
||||
whenCookieName string
|
||||
givenCSRFCookie string
|
||||
givenMethod string
|
||||
givenQueryTokens map[string][]string
|
||||
givenFormTokens map[string][]string
|
||||
givenHeaderTokens map[string][]string
|
||||
expectError string
|
||||
}{
|
||||
{
|
||||
name: "ok, multiple token lookups sources, succeeds on last one",
|
||||
whenTokenLookup: "header:X-CSRF-Token,form:csrf",
|
||||
givenCSRFCookie: "token",
|
||||
givenMethod: http.MethodPost,
|
||||
givenHeaderTokens: map[string][]string{
|
||||
echo.HeaderXCSRFToken: {"invalid_token"},
|
||||
},
|
||||
givenFormTokens: map[string][]string{
|
||||
"csrf": {"token"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ok, token from POST form",
|
||||
whenTokenLookup: "form:csrf",
|
||||
givenCSRFCookie: "token",
|
||||
givenMethod: http.MethodPost,
|
||||
givenFormTokens: map[string][]string{
|
||||
"csrf": {"token"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ok, token from POST form, second token passes",
|
||||
whenTokenLookup: "form:csrf",
|
||||
givenCSRFCookie: "token",
|
||||
givenMethod: http.MethodPost,
|
||||
givenFormTokens: map[string][]string{
|
||||
"csrf": {"invalid", "token"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "nok, invalid token from POST form",
|
||||
whenTokenLookup: "form:csrf",
|
||||
givenCSRFCookie: "token",
|
||||
givenMethod: http.MethodPost,
|
||||
givenFormTokens: map[string][]string{
|
||||
"csrf": {"invalid_token"},
|
||||
},
|
||||
expectError: "code=403, message=invalid csrf token",
|
||||
},
|
||||
{
|
||||
name: "nok, missing token from POST form",
|
||||
whenTokenLookup: "form:csrf",
|
||||
givenCSRFCookie: "token",
|
||||
givenMethod: http.MethodPost,
|
||||
givenFormTokens: map[string][]string{},
|
||||
expectError: "code=400, message=missing csrf token in the form parameter",
|
||||
},
|
||||
{
|
||||
name: "ok, token from POST header",
|
||||
whenTokenLookup: "", // will use defaults
|
||||
givenCSRFCookie: "token",
|
||||
givenMethod: http.MethodPost,
|
||||
givenHeaderTokens: map[string][]string{
|
||||
echo.HeaderXCSRFToken: {"token"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ok, token from POST header, second token passes",
|
||||
whenTokenLookup: "header:" + echo.HeaderXCSRFToken,
|
||||
givenCSRFCookie: "token",
|
||||
givenMethod: http.MethodPost,
|
||||
givenHeaderTokens: map[string][]string{
|
||||
echo.HeaderXCSRFToken: {"invalid", "token"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "nok, invalid token from POST header",
|
||||
whenTokenLookup: "header:" + echo.HeaderXCSRFToken,
|
||||
givenCSRFCookie: "token",
|
||||
givenMethod: http.MethodPost,
|
||||
givenHeaderTokens: map[string][]string{
|
||||
echo.HeaderXCSRFToken: {"invalid_token"},
|
||||
},
|
||||
expectError: "code=403, message=invalid csrf token",
|
||||
},
|
||||
{
|
||||
name: "nok, missing token from POST header",
|
||||
whenTokenLookup: "header:" + echo.HeaderXCSRFToken,
|
||||
givenCSRFCookie: "token",
|
||||
givenMethod: http.MethodPost,
|
||||
givenHeaderTokens: map[string][]string{},
|
||||
expectError: "code=400, message=missing csrf token in request header",
|
||||
},
|
||||
{
|
||||
name: "ok, token from PUT query param",
|
||||
whenTokenLookup: "query:csrf-param",
|
||||
givenCSRFCookie: "token",
|
||||
givenMethod: http.MethodPut,
|
||||
givenQueryTokens: map[string][]string{
|
||||
"csrf-param": {"token"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ok, token from PUT query form, second token passes",
|
||||
whenTokenLookup: "query:csrf",
|
||||
givenCSRFCookie: "token",
|
||||
givenMethod: http.MethodPut,
|
||||
givenQueryTokens: map[string][]string{
|
||||
"csrf": {"invalid", "token"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "nok, invalid token from PUT query form",
|
||||
whenTokenLookup: "query:csrf",
|
||||
givenCSRFCookie: "token",
|
||||
givenMethod: http.MethodPut,
|
||||
givenQueryTokens: map[string][]string{
|
||||
"csrf": {"invalid_token"},
|
||||
},
|
||||
expectError: "code=403, message=invalid csrf token",
|
||||
},
|
||||
{
|
||||
name: "nok, missing token from PUT query form",
|
||||
whenTokenLookup: "query:csrf",
|
||||
givenCSRFCookie: "token",
|
||||
givenMethod: http.MethodPut,
|
||||
givenQueryTokens: map[string][]string{},
|
||||
expectError: "code=400, message=missing csrf token in the query string",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := echo.New()
|
||||
|
||||
q := make(url.Values)
|
||||
for queryParam, values := range tc.givenQueryTokens {
|
||||
for _, v := range values {
|
||||
q.Add(queryParam, v)
|
||||
}
|
||||
}
|
||||
|
||||
f := make(url.Values)
|
||||
for formKey, values := range tc.givenFormTokens {
|
||||
for _, v := range values {
|
||||
f.Add(formKey, v)
|
||||
}
|
||||
}
|
||||
|
||||
var req *http.Request
|
||||
switch tc.givenMethod {
|
||||
case http.MethodGet:
|
||||
req = httptest.NewRequest(http.MethodGet, "/?"+q.Encode(), nil)
|
||||
case http.MethodPost, http.MethodPut:
|
||||
req = httptest.NewRequest(http.MethodPost, "/?"+q.Encode(), strings.NewReader(f.Encode()))
|
||||
req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm)
|
||||
}
|
||||
|
||||
for header, values := range tc.givenHeaderTokens {
|
||||
for _, v := range values {
|
||||
req.Header.Add(header, v)
|
||||
}
|
||||
}
|
||||
|
||||
if tc.givenCSRFCookie != "" {
|
||||
req.Header.Set(echo.HeaderCookie, "_csrf="+tc.givenCSRFCookie)
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
csrf := CSRFWithConfig(CSRFConfig{
|
||||
TokenLookup: tc.whenTokenLookup,
|
||||
CookieName: tc.whenCookieName,
|
||||
})
|
||||
|
||||
h := csrf(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "test")
|
||||
})
|
||||
|
||||
err := h(c)
|
||||
if tc.expectError != "" {
|
||||
assert.EqualError(t, err, tc.expectError)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCSRF(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
csrf := CSRFWithConfig(CSRFConfig{
|
||||
TokenLength: 16,
|
||||
})
|
||||
csrf := CSRF()
|
||||
h := csrf(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "test")
|
||||
})
|
||||
@ -43,7 +233,7 @@ func TestCSRF(t *testing.T) {
|
||||
assert.Error(t, h(c))
|
||||
|
||||
// Valid CSRF token
|
||||
token := random.String(16)
|
||||
token := random.String(32)
|
||||
req.Header.Set(echo.HeaderCookie, "_csrf="+token)
|
||||
req.Header.Set(echo.HeaderXCSRFToken, token)
|
||||
if assert.NoError(t, h(c)) {
|
||||
@ -51,38 +241,6 @@ func TestCSRF(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCSRFTokenFromForm(t *testing.T) {
|
||||
f := make(url.Values)
|
||||
f.Set("csrf", "token")
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode()))
|
||||
req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm)
|
||||
c := e.NewContext(req, nil)
|
||||
token, err := csrfTokenFromForm("csrf")(c)
|
||||
if assert.NoError(t, err) {
|
||||
assert.Equal(t, "token", token)
|
||||
}
|
||||
_, err = csrfTokenFromForm("invalid")(c)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestCSRFTokenFromQuery(t *testing.T) {
|
||||
q := make(url.Values)
|
||||
q.Set("csrf", "token")
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm)
|
||||
req.URL.RawQuery = q.Encode()
|
||||
c := e.NewContext(req, nil)
|
||||
token, err := csrfTokenFromQuery("csrf")(c)
|
||||
if assert.NoError(t, err) {
|
||||
assert.Equal(t, "token", token)
|
||||
}
|
||||
_, err = csrfTokenFromQuery("invalid")(c)
|
||||
assert.Error(t, err)
|
||||
csrfTokenFromQuery("csrf")
|
||||
}
|
||||
|
||||
func TestCSRFSetSameSiteMode(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
@ -135,7 +293,6 @@ func TestCSRFWithSameSiteDefaultMode(t *testing.T) {
|
||||
|
||||
r := h(c)
|
||||
assert.NoError(t, r)
|
||||
fmt.Println(rec.Header()["Set-Cookie"])
|
||||
assert.NotRegexp(t, "SameSite=", rec.Header()["Set-Cookie"])
|
||||
}
|
||||
|
||||
@ -158,3 +315,46 @@ func TestCSRFWithSameSiteModeNone(t *testing.T) {
|
||||
assert.Regexp(t, "SameSite=None", rec.Header()["Set-Cookie"])
|
||||
assert.Regexp(t, "Secure", rec.Header()["Set-Cookie"])
|
||||
}
|
||||
|
||||
func TestCSRFConfig_skipper(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
whenSkip bool
|
||||
expectCookies int
|
||||
}{
|
||||
{
|
||||
name: "do skip",
|
||||
whenSkip: true,
|
||||
expectCookies: 0,
|
||||
},
|
||||
{
|
||||
name: "do not skip",
|
||||
whenSkip: false,
|
||||
expectCookies: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
csrf := CSRFWithConfig(CSRFConfig{
|
||||
Skipper: func(c echo.Context) bool {
|
||||
return tc.whenSkip
|
||||
},
|
||||
})
|
||||
|
||||
h := csrf(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "test")
|
||||
})
|
||||
|
||||
r := h(c)
|
||||
assert.NoError(t, r)
|
||||
cookie := rec.Header()["Set-Cookie"]
|
||||
assert.Len(t, cookie, tc.expectCookies)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
184
middleware/extractor.go
Normal file
184
middleware/extractor.go
Normal file
@ -0,0 +1,184 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/labstack/echo/v4"
|
||||
"net/textproto"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
// extractorLimit is arbitrary number to limit values extractor can return. this limits possible resource exhaustion
|
||||
// attack vector
|
||||
extractorLimit = 20
|
||||
)
|
||||
|
||||
var errHeaderExtractorValueMissing = errors.New("missing value in request header")
|
||||
var errHeaderExtractorValueInvalid = errors.New("invalid value in request header")
|
||||
var errQueryExtractorValueMissing = errors.New("missing value in the query string")
|
||||
var errParamExtractorValueMissing = errors.New("missing value in path params")
|
||||
var errCookieExtractorValueMissing = errors.New("missing value in cookies")
|
||||
var errFormExtractorValueMissing = errors.New("missing value in the form")
|
||||
|
||||
// ValuesExtractor defines a function for extracting values (keys/tokens) from the given context.
|
||||
type ValuesExtractor func(c echo.Context) ([]string, error)
|
||||
|
||||
func createExtractors(lookups string, authScheme string) ([]ValuesExtractor, error) {
|
||||
if lookups == "" {
|
||||
return nil, nil
|
||||
}
|
||||
sources := strings.Split(lookups, ",")
|
||||
var extractors = make([]ValuesExtractor, 0)
|
||||
for _, source := range sources {
|
||||
parts := strings.Split(source, ":")
|
||||
if len(parts) < 2 {
|
||||
return nil, fmt.Errorf("extractor source for lookup could not be split into needed parts: %v", source)
|
||||
}
|
||||
|
||||
switch parts[0] {
|
||||
case "query":
|
||||
extractors = append(extractors, valuesFromQuery(parts[1]))
|
||||
case "param":
|
||||
extractors = append(extractors, valuesFromParam(parts[1]))
|
||||
case "cookie":
|
||||
extractors = append(extractors, valuesFromCookie(parts[1]))
|
||||
case "form":
|
||||
extractors = append(extractors, valuesFromForm(parts[1]))
|
||||
case "header":
|
||||
prefix := ""
|
||||
if len(parts) > 2 {
|
||||
prefix = parts[2]
|
||||
} else if authScheme != "" && parts[1] == echo.HeaderAuthorization {
|
||||
// backwards compatibility for JWT and KeyAuth:
|
||||
// * we only apply this fix to Authorization as header we use and uses prefixes like "Bearer <token-value>" etc
|
||||
// * previously header extractor assumed that auth-scheme/prefix had a space as suffix we need to retain that
|
||||
// behaviour for default values and Authorization header.
|
||||
prefix = authScheme
|
||||
if !strings.HasSuffix(prefix, " ") {
|
||||
prefix += " "
|
||||
}
|
||||
}
|
||||
extractors = append(extractors, valuesFromHeader(parts[1], prefix))
|
||||
}
|
||||
}
|
||||
return extractors, nil
|
||||
}
|
||||
|
||||
// valuesFromHeader returns a functions that extracts values from the request header.
|
||||
// valuePrefix is parameter to remove first part (prefix) of the extracted value. This is useful if header value has static
|
||||
// prefix like `Authorization: <auth-scheme> <authorisation-parameters>` where part that we want to remove is `<auth-scheme> `
|
||||
// note the space at the end. In case of basic authentication `Authorization: Basic <credentials>` prefix we want to remove
|
||||
// is `Basic `. In case of JWT tokens `Authorization: Bearer <token>` prefix is `Bearer `.
|
||||
// If prefix is left empty the whole value is returned.
|
||||
func valuesFromHeader(header string, valuePrefix string) ValuesExtractor {
|
||||
prefixLen := len(valuePrefix)
|
||||
// standard library parses http.Request header keys in canonical form but we may provide something else so fix this
|
||||
header = textproto.CanonicalMIMEHeaderKey(header)
|
||||
return func(c echo.Context) ([]string, error) {
|
||||
values := c.Request().Header.Values(header)
|
||||
if len(values) == 0 {
|
||||
return nil, errHeaderExtractorValueMissing
|
||||
}
|
||||
|
||||
result := make([]string, 0)
|
||||
for i, value := range values {
|
||||
if prefixLen == 0 {
|
||||
result = append(result, value)
|
||||
if i >= extractorLimit-1 {
|
||||
break
|
||||
}
|
||||
continue
|
||||
}
|
||||
if len(value) > prefixLen && strings.EqualFold(value[:prefixLen], valuePrefix) {
|
||||
result = append(result, value[prefixLen:])
|
||||
if i >= extractorLimit-1 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(result) == 0 {
|
||||
if prefixLen > 0 {
|
||||
return nil, errHeaderExtractorValueInvalid
|
||||
}
|
||||
return nil, errHeaderExtractorValueMissing
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
|
||||
// valuesFromQuery returns a function that extracts values from the query string.
|
||||
func valuesFromQuery(param string) ValuesExtractor {
|
||||
return func(c echo.Context) ([]string, error) {
|
||||
result := c.QueryParams()[param]
|
||||
if len(result) == 0 {
|
||||
return nil, errQueryExtractorValueMissing
|
||||
} else if len(result) > extractorLimit-1 {
|
||||
result = result[:extractorLimit]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
|
||||
// valuesFromParam returns a function that extracts values from the url param string.
|
||||
func valuesFromParam(param string) ValuesExtractor {
|
||||
return func(c echo.Context) ([]string, error) {
|
||||
result := make([]string, 0)
|
||||
paramVales := c.ParamValues()
|
||||
for i, p := range c.ParamNames() {
|
||||
if param == p {
|
||||
result = append(result, paramVales[i])
|
||||
if i >= extractorLimit-1 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(result) == 0 {
|
||||
return nil, errParamExtractorValueMissing
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
|
||||
// valuesFromCookie returns a function that extracts values from the named cookie.
|
||||
func valuesFromCookie(name string) ValuesExtractor {
|
||||
return func(c echo.Context) ([]string, error) {
|
||||
cookies := c.Cookies()
|
||||
if len(cookies) == 0 {
|
||||
return nil, errCookieExtractorValueMissing
|
||||
}
|
||||
|
||||
result := make([]string, 0)
|
||||
for i, cookie := range cookies {
|
||||
if name == cookie.Name {
|
||||
result = append(result, cookie.Value)
|
||||
if i >= extractorLimit-1 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(result) == 0 {
|
||||
return nil, errCookieExtractorValueMissing
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
|
||||
// valuesFromForm returns a function that extracts values from the form field.
|
||||
func valuesFromForm(name string) ValuesExtractor {
|
||||
return func(c echo.Context) ([]string, error) {
|
||||
if parseErr := c.Request().ParseForm(); parseErr != nil {
|
||||
return nil, fmt.Errorf("valuesFromForm parse form failed: %w", parseErr)
|
||||
}
|
||||
values := c.Request().Form[name]
|
||||
if len(values) == 0 {
|
||||
return nil, errFormExtractorValueMissing
|
||||
}
|
||||
if len(values) > extractorLimit-1 {
|
||||
values = values[:extractorLimit]
|
||||
}
|
||||
result := append([]string{}, values...)
|
||||
return result, nil
|
||||
}
|
||||
}
|
587
middleware/extractor_test.go
Normal file
587
middleware/extractor_test.go
Normal file
@ -0,0 +1,587 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type pathParam struct {
|
||||
name string
|
||||
value string
|
||||
}
|
||||
|
||||
func setPathParams(c echo.Context, params []pathParam) {
|
||||
names := make([]string, 0, len(params))
|
||||
values := make([]string, 0, len(params))
|
||||
for _, pp := range params {
|
||||
names = append(names, pp.name)
|
||||
values = append(values, pp.value)
|
||||
}
|
||||
c.SetParamNames(names...)
|
||||
c.SetParamValues(values...)
|
||||
}
|
||||
|
||||
func TestCreateExtractors(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
givenRequest func() *http.Request
|
||||
givenPathParams []pathParam
|
||||
whenLoopups string
|
||||
expectValues []string
|
||||
expectCreateError string
|
||||
expectError string
|
||||
}{
|
||||
{
|
||||
name: "ok, header",
|
||||
givenRequest: func() *http.Request {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set(echo.HeaderAuthorization, "Bearer token")
|
||||
return req
|
||||
},
|
||||
whenLoopups: "header:Authorization:Bearer ",
|
||||
expectValues: []string{"token"},
|
||||
},
|
||||
{
|
||||
name: "ok, form",
|
||||
givenRequest: func() *http.Request {
|
||||
f := make(url.Values)
|
||||
f.Set("name", "Jon Snow")
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode()))
|
||||
req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm)
|
||||
return req
|
||||
},
|
||||
whenLoopups: "form:name",
|
||||
expectValues: []string{"Jon Snow"},
|
||||
},
|
||||
{
|
||||
name: "ok, cookie",
|
||||
givenRequest: func() *http.Request {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set(echo.HeaderCookie, "_csrf=token")
|
||||
return req
|
||||
},
|
||||
whenLoopups: "cookie:_csrf",
|
||||
expectValues: []string{"token"},
|
||||
},
|
||||
{
|
||||
name: "ok, param",
|
||||
givenPathParams: []pathParam{
|
||||
{name: "id", value: "123"},
|
||||
},
|
||||
whenLoopups: "param:id",
|
||||
expectValues: []string{"123"},
|
||||
},
|
||||
{
|
||||
name: "ok, query",
|
||||
givenRequest: func() *http.Request {
|
||||
req := httptest.NewRequest(http.MethodGet, "/?id=999", nil)
|
||||
return req
|
||||
},
|
||||
whenLoopups: "query:id",
|
||||
expectValues: []string{"999"},
|
||||
},
|
||||
{
|
||||
name: "nok, invalid lookup",
|
||||
whenLoopups: "query",
|
||||
expectCreateError: "extractor source for lookup could not be split into needed parts: query",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := echo.New()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
if tc.givenRequest != nil {
|
||||
req = tc.givenRequest()
|
||||
}
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
if tc.givenPathParams != nil {
|
||||
setPathParams(c, tc.givenPathParams)
|
||||
}
|
||||
|
||||
extractors, err := createExtractors(tc.whenLoopups, "")
|
||||
if tc.expectCreateError != "" {
|
||||
assert.EqualError(t, err, tc.expectCreateError)
|
||||
return
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
for _, e := range extractors {
|
||||
values, eErr := e(c)
|
||||
assert.Equal(t, tc.expectValues, values)
|
||||
if tc.expectError != "" {
|
||||
assert.EqualError(t, eErr, tc.expectError)
|
||||
return
|
||||
}
|
||||
assert.NoError(t, eErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValuesFromHeader(t *testing.T) {
|
||||
exampleRequest := func(req *http.Request) {
|
||||
req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==")
|
||||
}
|
||||
|
||||
var testCases = []struct {
|
||||
name string
|
||||
givenRequest func(req *http.Request)
|
||||
whenName string
|
||||
whenValuePrefix string
|
||||
expectValues []string
|
||||
expectError string
|
||||
}{
|
||||
{
|
||||
name: "ok, single value",
|
||||
givenRequest: exampleRequest,
|
||||
whenName: echo.HeaderAuthorization,
|
||||
whenValuePrefix: "basic ",
|
||||
expectValues: []string{"dXNlcjpwYXNzd29yZA=="},
|
||||
},
|
||||
{
|
||||
name: "ok, single value, case insensitive",
|
||||
givenRequest: exampleRequest,
|
||||
whenName: echo.HeaderAuthorization,
|
||||
whenValuePrefix: "Basic ",
|
||||
expectValues: []string{"dXNlcjpwYXNzd29yZA=="},
|
||||
},
|
||||
{
|
||||
name: "ok, multiple value",
|
||||
givenRequest: func(req *http.Request) {
|
||||
req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==")
|
||||
req.Header.Add(echo.HeaderAuthorization, "basic dGVzdDp0ZXN0")
|
||||
},
|
||||
whenName: echo.HeaderAuthorization,
|
||||
whenValuePrefix: "basic ",
|
||||
expectValues: []string{"dXNlcjpwYXNzd29yZA==", "dGVzdDp0ZXN0"},
|
||||
},
|
||||
{
|
||||
name: "ok, empty prefix",
|
||||
givenRequest: exampleRequest,
|
||||
whenName: echo.HeaderAuthorization,
|
||||
whenValuePrefix: "",
|
||||
expectValues: []string{"basic dXNlcjpwYXNzd29yZA=="},
|
||||
},
|
||||
{
|
||||
name: "nok, no matching due different prefix",
|
||||
givenRequest: func(req *http.Request) {
|
||||
req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==")
|
||||
req.Header.Add(echo.HeaderAuthorization, "basic dGVzdDp0ZXN0")
|
||||
},
|
||||
whenName: echo.HeaderAuthorization,
|
||||
whenValuePrefix: "Bearer ",
|
||||
expectError: errHeaderExtractorValueInvalid.Error(),
|
||||
},
|
||||
{
|
||||
name: "nok, no matching due different prefix",
|
||||
givenRequest: func(req *http.Request) {
|
||||
req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==")
|
||||
req.Header.Add(echo.HeaderAuthorization, "basic dGVzdDp0ZXN0")
|
||||
},
|
||||
whenName: echo.HeaderWWWAuthenticate,
|
||||
whenValuePrefix: "",
|
||||
expectError: errHeaderExtractorValueMissing.Error(),
|
||||
},
|
||||
{
|
||||
name: "nok, no headers",
|
||||
givenRequest: nil,
|
||||
whenName: echo.HeaderAuthorization,
|
||||
whenValuePrefix: "basic ",
|
||||
expectError: errHeaderExtractorValueMissing.Error(),
|
||||
},
|
||||
{
|
||||
name: "ok, prefix, cut values over extractorLimit",
|
||||
givenRequest: func(req *http.Request) {
|
||||
for i := 1; i <= 25; i++ {
|
||||
req.Header.Add(echo.HeaderAuthorization, fmt.Sprintf("basic %v", i))
|
||||
}
|
||||
},
|
||||
whenName: echo.HeaderAuthorization,
|
||||
whenValuePrefix: "basic ",
|
||||
expectValues: []string{
|
||||
"1", "2", "3", "4", "5", "6", "7", "8", "9", "10",
|
||||
"11", "12", "13", "14", "15", "16", "17", "18", "19", "20",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ok, cut values over extractorLimit",
|
||||
givenRequest: func(req *http.Request) {
|
||||
for i := 1; i <= 25; i++ {
|
||||
req.Header.Add(echo.HeaderAuthorization, fmt.Sprintf("%v", i))
|
||||
}
|
||||
},
|
||||
whenName: echo.HeaderAuthorization,
|
||||
whenValuePrefix: "",
|
||||
expectValues: []string{
|
||||
"1", "2", "3", "4", "5", "6", "7", "8", "9", "10",
|
||||
"11", "12", "13", "14", "15", "16", "17", "18", "19", "20",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := echo.New()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
if tc.givenRequest != nil {
|
||||
tc.givenRequest(req)
|
||||
}
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
extractor := valuesFromHeader(tc.whenName, tc.whenValuePrefix)
|
||||
|
||||
values, err := extractor(c)
|
||||
assert.Equal(t, tc.expectValues, values)
|
||||
if tc.expectError != "" {
|
||||
assert.EqualError(t, err, tc.expectError)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValuesFromQuery(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
givenQueryPart string
|
||||
whenName string
|
||||
expectValues []string
|
||||
expectError string
|
||||
}{
|
||||
{
|
||||
name: "ok, single value",
|
||||
givenQueryPart: "?id=123&name=test",
|
||||
whenName: "id",
|
||||
expectValues: []string{"123"},
|
||||
},
|
||||
{
|
||||
name: "ok, multiple value",
|
||||
givenQueryPart: "?id=123&id=456&name=test",
|
||||
whenName: "id",
|
||||
expectValues: []string{"123", "456"},
|
||||
},
|
||||
{
|
||||
name: "nok, missing value",
|
||||
givenQueryPart: "?id=123&name=test",
|
||||
whenName: "nope",
|
||||
expectError: errQueryExtractorValueMissing.Error(),
|
||||
},
|
||||
{
|
||||
name: "ok, cut values over extractorLimit",
|
||||
givenQueryPart: "?name=test" +
|
||||
"&id=1&id=2&id=3&id=4&id=5&id=6&id=7&id=8&id=9&id=10" +
|
||||
"&id=11&id=12&id=13&id=14&id=15&id=16&id=17&id=18&id=19&id=20" +
|
||||
"&id=21&id=22&id=23&id=24&id=25",
|
||||
whenName: "id",
|
||||
expectValues: []string{
|
||||
"1", "2", "3", "4", "5", "6", "7", "8", "9", "10",
|
||||
"11", "12", "13", "14", "15", "16", "17", "18", "19", "20",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := echo.New()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/"+tc.givenQueryPart, nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
extractor := valuesFromQuery(tc.whenName)
|
||||
|
||||
values, err := extractor(c)
|
||||
assert.Equal(t, tc.expectValues, values)
|
||||
if tc.expectError != "" {
|
||||
assert.EqualError(t, err, tc.expectError)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValuesFromParam(t *testing.T) {
|
||||
examplePathParams := []pathParam{
|
||||
{name: "id", value: "123"},
|
||||
{name: "gid", value: "456"},
|
||||
{name: "gid", value: "789"},
|
||||
}
|
||||
examplePathParams20 := make([]pathParam, 0)
|
||||
for i := 1; i < 25; i++ {
|
||||
examplePathParams20 = append(examplePathParams20, pathParam{name: "id", value: fmt.Sprintf("%v", i)})
|
||||
}
|
||||
|
||||
var testCases = []struct {
|
||||
name string
|
||||
givenPathParams []pathParam
|
||||
whenName string
|
||||
expectValues []string
|
||||
expectError string
|
||||
}{
|
||||
{
|
||||
name: "ok, single value",
|
||||
givenPathParams: examplePathParams,
|
||||
whenName: "id",
|
||||
expectValues: []string{"123"},
|
||||
},
|
||||
{
|
||||
name: "ok, multiple value",
|
||||
givenPathParams: examplePathParams,
|
||||
whenName: "gid",
|
||||
expectValues: []string{"456", "789"},
|
||||
},
|
||||
{
|
||||
name: "nok, no values",
|
||||
givenPathParams: nil,
|
||||
whenName: "nope",
|
||||
expectValues: nil,
|
||||
expectError: errParamExtractorValueMissing.Error(),
|
||||
},
|
||||
{
|
||||
name: "nok, no matching value",
|
||||
givenPathParams: examplePathParams,
|
||||
whenName: "nope",
|
||||
expectValues: nil,
|
||||
expectError: errParamExtractorValueMissing.Error(),
|
||||
},
|
||||
{
|
||||
name: "ok, cut values over extractorLimit",
|
||||
givenPathParams: examplePathParams20,
|
||||
whenName: "id",
|
||||
expectValues: []string{
|
||||
"1", "2", "3", "4", "5", "6", "7", "8", "9", "10",
|
||||
"11", "12", "13", "14", "15", "16", "17", "18", "19", "20",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := echo.New()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
if tc.givenPathParams != nil {
|
||||
setPathParams(c, tc.givenPathParams)
|
||||
}
|
||||
|
||||
extractor := valuesFromParam(tc.whenName)
|
||||
|
||||
values, err := extractor(c)
|
||||
assert.Equal(t, tc.expectValues, values)
|
||||
if tc.expectError != "" {
|
||||
assert.EqualError(t, err, tc.expectError)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValuesFromCookie(t *testing.T) {
|
||||
exampleRequest := func(req *http.Request) {
|
||||
req.Header.Set(echo.HeaderCookie, "_csrf=token")
|
||||
}
|
||||
|
||||
var testCases = []struct {
|
||||
name string
|
||||
givenRequest func(req *http.Request)
|
||||
whenName string
|
||||
expectValues []string
|
||||
expectError string
|
||||
}{
|
||||
{
|
||||
name: "ok, single value",
|
||||
givenRequest: exampleRequest,
|
||||
whenName: "_csrf",
|
||||
expectValues: []string{"token"},
|
||||
},
|
||||
{
|
||||
name: "ok, multiple value",
|
||||
givenRequest: func(req *http.Request) {
|
||||
req.Header.Add(echo.HeaderCookie, "_csrf=token")
|
||||
req.Header.Add(echo.HeaderCookie, "_csrf=token2")
|
||||
},
|
||||
whenName: "_csrf",
|
||||
expectValues: []string{"token", "token2"},
|
||||
},
|
||||
{
|
||||
name: "nok, no matching cookie",
|
||||
givenRequest: exampleRequest,
|
||||
whenName: "xxx",
|
||||
expectValues: nil,
|
||||
expectError: errCookieExtractorValueMissing.Error(),
|
||||
},
|
||||
{
|
||||
name: "nok, no cookies at all",
|
||||
givenRequest: nil,
|
||||
whenName: "xxx",
|
||||
expectValues: nil,
|
||||
expectError: errCookieExtractorValueMissing.Error(),
|
||||
},
|
||||
{
|
||||
name: "ok, cut values over extractorLimit",
|
||||
givenRequest: func(req *http.Request) {
|
||||
for i := 1; i < 25; i++ {
|
||||
req.Header.Add(echo.HeaderCookie, fmt.Sprintf("_csrf=%v", i))
|
||||
}
|
||||
},
|
||||
whenName: "_csrf",
|
||||
expectValues: []string{
|
||||
"1", "2", "3", "4", "5", "6", "7", "8", "9", "10",
|
||||
"11", "12", "13", "14", "15", "16", "17", "18", "19", "20",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := echo.New()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
if tc.givenRequest != nil {
|
||||
tc.givenRequest(req)
|
||||
}
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
extractor := valuesFromCookie(tc.whenName)
|
||||
|
||||
values, err := extractor(c)
|
||||
assert.Equal(t, tc.expectValues, values)
|
||||
if tc.expectError != "" {
|
||||
assert.EqualError(t, err, tc.expectError)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValuesFromForm(t *testing.T) {
|
||||
examplePostFormRequest := func(mod func(v *url.Values)) *http.Request {
|
||||
f := make(url.Values)
|
||||
f.Set("name", "Jon Snow")
|
||||
f.Set("emails[]", "jon@labstack.com")
|
||||
if mod != nil {
|
||||
mod(&f)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode()))
|
||||
req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm)
|
||||
|
||||
return req
|
||||
}
|
||||
exampleGetFormRequest := func(mod func(v *url.Values)) *http.Request {
|
||||
f := make(url.Values)
|
||||
f.Set("name", "Jon Snow")
|
||||
f.Set("emails[]", "jon@labstack.com")
|
||||
if mod != nil {
|
||||
mod(&f)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/?"+f.Encode(), nil)
|
||||
return req
|
||||
}
|
||||
|
||||
var testCases = []struct {
|
||||
name string
|
||||
givenRequest *http.Request
|
||||
whenName string
|
||||
expectValues []string
|
||||
expectError string
|
||||
}{
|
||||
{
|
||||
name: "ok, POST form, single value",
|
||||
givenRequest: examplePostFormRequest(nil),
|
||||
whenName: "emails[]",
|
||||
expectValues: []string{"jon@labstack.com"},
|
||||
},
|
||||
{
|
||||
name: "ok, POST form, multiple value",
|
||||
givenRequest: examplePostFormRequest(func(v *url.Values) {
|
||||
v.Add("emails[]", "snow@labstack.com")
|
||||
}),
|
||||
whenName: "emails[]",
|
||||
expectValues: []string{"jon@labstack.com", "snow@labstack.com"},
|
||||
},
|
||||
{
|
||||
name: "ok, GET form, single value",
|
||||
givenRequest: exampleGetFormRequest(nil),
|
||||
whenName: "emails[]",
|
||||
expectValues: []string{"jon@labstack.com"},
|
||||
},
|
||||
{
|
||||
name: "ok, GET form, multiple value",
|
||||
givenRequest: examplePostFormRequest(func(v *url.Values) {
|
||||
v.Add("emails[]", "snow@labstack.com")
|
||||
}),
|
||||
whenName: "emails[]",
|
||||
expectValues: []string{"jon@labstack.com", "snow@labstack.com"},
|
||||
},
|
||||
{
|
||||
name: "nok, POST form, value missing",
|
||||
givenRequest: examplePostFormRequest(nil),
|
||||
whenName: "nope",
|
||||
expectError: errFormExtractorValueMissing.Error(),
|
||||
},
|
||||
{
|
||||
name: "nok, POST form, form parsing error",
|
||||
givenRequest: func() *http.Request {
|
||||
req := httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
req.Body = nil
|
||||
return req
|
||||
}(),
|
||||
whenName: "name",
|
||||
expectError: "valuesFromForm parse form failed: missing form body",
|
||||
},
|
||||
{
|
||||
name: "ok, cut values over extractorLimit",
|
||||
givenRequest: examplePostFormRequest(func(v *url.Values) {
|
||||
for i := 1; i < 25; i++ {
|
||||
v.Add("id[]", fmt.Sprintf("%v", i))
|
||||
}
|
||||
}),
|
||||
whenName: "id[]",
|
||||
expectValues: []string{
|
||||
"1", "2", "3", "4", "5", "6", "7", "8", "9", "10",
|
||||
"11", "12", "13", "14", "15", "16", "17", "18", "19", "20",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := echo.New()
|
||||
|
||||
req := tc.givenRequest
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
extractor := valuesFromForm(tc.whenName)
|
||||
|
||||
values, err := extractor(c)
|
||||
assert.Equal(t, tc.expectValues, values)
|
||||
if tc.expectError != "" {
|
||||
assert.EqualError(t, err, tc.expectError)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -1,3 +1,4 @@
|
||||
//go:build go1.15
|
||||
// +build go1.15
|
||||
|
||||
package middleware
|
||||
@ -5,12 +6,10 @@ package middleware
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/labstack/echo/v4"
|
||||
"net/http"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
type (
|
||||
@ -22,7 +21,8 @@ type (
|
||||
// BeforeFunc defines a function which is executed just before the middleware.
|
||||
BeforeFunc BeforeFunc
|
||||
|
||||
// SuccessHandler defines a function which is executed for a valid token.
|
||||
// SuccessHandler defines a function which is executed for a valid token before middleware chain continues with next
|
||||
// middleware or handler.
|
||||
SuccessHandler JWTSuccessHandler
|
||||
|
||||
// ErrorHandler defines a function which is executed for an invalid token.
|
||||
@ -32,6 +32,13 @@ type (
|
||||
// ErrorHandlerWithContext is almost identical to ErrorHandler, but it's passed the current context.
|
||||
ErrorHandlerWithContext JWTErrorHandlerWithContext
|
||||
|
||||
// ContinueOnIgnoredError allows the next middleware/handler to be called when ErrorHandlerWithContext decides to
|
||||
// ignore the error (by returning `nil`).
|
||||
// This is useful when parts of your site/api allow public access and some authorized routes provide extra functionality.
|
||||
// In that case you can use ErrorHandlerWithContext to set a default public JWT token value in the request context
|
||||
// and continue. Some logic down the remaining execution chain needs to check that (public) token value then.
|
||||
ContinueOnIgnoredError bool
|
||||
|
||||
// Signing key to validate token.
|
||||
// This is one of the three options to provide a token validation key.
|
||||
// The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey.
|
||||
@ -61,20 +68,25 @@ type (
|
||||
// to extract token from the request.
|
||||
// Optional. Default value "header:Authorization".
|
||||
// Possible values:
|
||||
// - "header:<name>"
|
||||
// - "header:<name>" or "header:<name>:<cut-prefix>"
|
||||
// `<cut-prefix>` is argument value to cut/trim prefix of the extracted value. This is useful if header
|
||||
// value has static prefix like `Authorization: <auth-scheme> <authorisation-parameters>` where part that we
|
||||
// want to cut is `<auth-scheme> ` note the space at the end.
|
||||
// In case of JWT tokens `Authorization: Bearer <token>` prefix we cut is `Bearer `.
|
||||
// If prefix is left empty the whole value is returned.
|
||||
// - "query:<name>"
|
||||
// - "param:<name>"
|
||||
// - "cookie:<name>"
|
||||
// - "form:<name>"
|
||||
// Multiply sources example:
|
||||
// - "header: Authorization,cookie: myowncookie"
|
||||
// Multiple sources example:
|
||||
// - "header:Authorization,cookie:myowncookie"
|
||||
TokenLookup string
|
||||
|
||||
// TokenLookupFuncs defines a list of user-defined functions that extract JWT token from the given context.
|
||||
// This is one of the two options to provide a token extractor.
|
||||
// The order of precedence is user-defined TokenLookupFuncs, and TokenLookup.
|
||||
// You can also provide both if you want.
|
||||
TokenLookupFuncs []TokenLookupFunc
|
||||
TokenLookupFuncs []ValuesExtractor
|
||||
|
||||
// AuthScheme to be used in the Authorization header.
|
||||
// Optional. Default value "Bearer".
|
||||
@ -100,16 +112,13 @@ type (
|
||||
}
|
||||
|
||||
// JWTSuccessHandler defines a function which is executed for a valid token.
|
||||
JWTSuccessHandler func(echo.Context)
|
||||
JWTSuccessHandler func(c echo.Context)
|
||||
|
||||
// JWTErrorHandler defines a function which is executed for an invalid token.
|
||||
JWTErrorHandler func(error) error
|
||||
JWTErrorHandler func(err error) error
|
||||
|
||||
// JWTErrorHandlerWithContext is almost identical to JWTErrorHandler, but it's passed the current context.
|
||||
JWTErrorHandlerWithContext func(error, echo.Context) error
|
||||
|
||||
// TokenLookupFunc defines a function for extracting JWT token from the given context.
|
||||
TokenLookupFunc func(echo.Context) (string, error)
|
||||
JWTErrorHandlerWithContext func(err error, c echo.Context) error
|
||||
)
|
||||
|
||||
// Algorithms
|
||||
@ -183,25 +192,12 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
|
||||
config.ParseTokenFunc = config.defaultParseToken
|
||||
}
|
||||
|
||||
// Initialize
|
||||
// Split sources
|
||||
sources := strings.Split(config.TokenLookup, ",")
|
||||
var extractors = config.TokenLookupFuncs
|
||||
for _, source := range sources {
|
||||
parts := strings.Split(source, ":")
|
||||
|
||||
switch parts[0] {
|
||||
case "query":
|
||||
extractors = append(extractors, jwtFromQuery(parts[1]))
|
||||
case "param":
|
||||
extractors = append(extractors, jwtFromParam(parts[1]))
|
||||
case "cookie":
|
||||
extractors = append(extractors, jwtFromCookie(parts[1]))
|
||||
case "form":
|
||||
extractors = append(extractors, jwtFromForm(parts[1]))
|
||||
case "header":
|
||||
extractors = append(extractors, jwtFromHeader(parts[1], config.AuthScheme))
|
||||
}
|
||||
extractors, err := createExtractors(config.TokenLookup, config.AuthScheme)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if len(config.TokenLookupFuncs) > 0 {
|
||||
extractors = append(config.TokenLookupFuncs, extractors...)
|
||||
}
|
||||
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
@ -213,48 +209,54 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
|
||||
if config.BeforeFunc != nil {
|
||||
config.BeforeFunc(c)
|
||||
}
|
||||
var auth string
|
||||
var err error
|
||||
|
||||
var lastExtractorErr error
|
||||
var lastTokenErr error
|
||||
for _, extractor := range extractors {
|
||||
// Extract token from extractor, if it's not fail break the loop and
|
||||
// set auth
|
||||
auth, err = extractor(c)
|
||||
if err == nil {
|
||||
break
|
||||
auths, err := extractor(c)
|
||||
if err != nil {
|
||||
lastExtractorErr = ErrJWTMissing // backwards compatibility: all extraction errors are same (unlike KeyAuth)
|
||||
continue
|
||||
}
|
||||
for _, auth := range auths {
|
||||
token, err := config.ParseTokenFunc(auth, c)
|
||||
if err != nil {
|
||||
lastTokenErr = err
|
||||
continue
|
||||
}
|
||||
// Store user information from token into context.
|
||||
c.Set(config.ContextKey, token)
|
||||
if config.SuccessHandler != nil {
|
||||
config.SuccessHandler(c)
|
||||
}
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
// If none of extractor has a token, handle error
|
||||
if err != nil {
|
||||
if config.ErrorHandler != nil {
|
||||
return config.ErrorHandler(err)
|
||||
}
|
||||
|
||||
if config.ErrorHandlerWithContext != nil {
|
||||
return config.ErrorHandlerWithContext(err, c)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
token, err := config.ParseTokenFunc(auth, c)
|
||||
if err == nil {
|
||||
// Store user information from token into context.
|
||||
c.Set(config.ContextKey, token)
|
||||
if config.SuccessHandler != nil {
|
||||
config.SuccessHandler(c)
|
||||
}
|
||||
return next(c)
|
||||
// we are here only when we did not successfully extract or parse any of the tokens
|
||||
err := lastTokenErr
|
||||
if err == nil { // prioritize token errors over extracting errors
|
||||
err = lastExtractorErr
|
||||
}
|
||||
if config.ErrorHandler != nil {
|
||||
return config.ErrorHandler(err)
|
||||
}
|
||||
if config.ErrorHandlerWithContext != nil {
|
||||
return config.ErrorHandlerWithContext(err, c)
|
||||
tmpErr := config.ErrorHandlerWithContext(err, c)
|
||||
if config.ContinueOnIgnoredError && tmpErr == nil {
|
||||
return next(c)
|
||||
}
|
||||
return tmpErr
|
||||
}
|
||||
return &echo.HTTPError{
|
||||
Code: ErrJWTInvalid.Code,
|
||||
Message: ErrJWTInvalid.Message,
|
||||
Internal: err,
|
||||
|
||||
// backwards compatible errors codes
|
||||
if lastTokenErr != nil {
|
||||
return &echo.HTTPError{
|
||||
Code: ErrJWTInvalid.Code,
|
||||
Message: ErrJWTInvalid.Message,
|
||||
Internal: err,
|
||||
}
|
||||
}
|
||||
return err // this is lastExtractorErr value
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -296,59 +298,3 @@ func (config *JWTConfig) defaultKeyFunc(t *jwt.Token) (interface{}, error) {
|
||||
|
||||
return config.SigningKey, nil
|
||||
}
|
||||
|
||||
// jwtFromHeader returns a `TokenLookupFunc` that extracts token from the request header.
|
||||
func jwtFromHeader(header string, authScheme string) TokenLookupFunc {
|
||||
return func(c echo.Context) (string, error) {
|
||||
auth := c.Request().Header.Get(header)
|
||||
l := len(authScheme)
|
||||
if len(auth) > l+1 && strings.EqualFold(auth[:l], authScheme) {
|
||||
return auth[l+1:], nil
|
||||
}
|
||||
return "", ErrJWTMissing
|
||||
}
|
||||
}
|
||||
|
||||
// jwtFromQuery returns a `TokenLookupFunc` that extracts token from the query string.
|
||||
func jwtFromQuery(param string) TokenLookupFunc {
|
||||
return func(c echo.Context) (string, error) {
|
||||
token := c.QueryParam(param)
|
||||
if token == "" {
|
||||
return "", ErrJWTMissing
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
|
||||
// jwtFromParam returns a `TokenLookupFunc` that extracts token from the url param string.
|
||||
func jwtFromParam(param string) TokenLookupFunc {
|
||||
return func(c echo.Context) (string, error) {
|
||||
token := c.Param(param)
|
||||
if token == "" {
|
||||
return "", ErrJWTMissing
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
|
||||
// jwtFromCookie returns a `TokenLookupFunc` that extracts token from the named cookie.
|
||||
func jwtFromCookie(name string) TokenLookupFunc {
|
||||
return func(c echo.Context) (string, error) {
|
||||
cookie, err := c.Cookie(name)
|
||||
if err != nil {
|
||||
return "", ErrJWTMissing
|
||||
}
|
||||
return cookie.Value, nil
|
||||
}
|
||||
}
|
||||
|
||||
// jwtFromForm returns a `TokenLookupFunc` that extracts token from the form field.
|
||||
func jwtFromForm(name string) TokenLookupFunc {
|
||||
return func(c echo.Context) (string, error) {
|
||||
field := c.FormValue(name)
|
||||
if field == "" {
|
||||
return "", ErrJWTMissing
|
||||
}
|
||||
return field, nil
|
||||
}
|
||||
}
|
||||
|
@ -1,3 +1,4 @@
|
||||
//go:build go1.15
|
||||
// +build go1.15
|
||||
|
||||
package middleware
|
||||
@ -28,6 +29,26 @@ type jwtCustomClaims struct {
|
||||
jwtCustomInfo
|
||||
}
|
||||
|
||||
func TestJWT(t *testing.T) {
|
||||
e := echo.New()
|
||||
|
||||
e.GET("/", func(c echo.Context) error {
|
||||
token := c.Get("user").(*jwt.Token)
|
||||
return c.JSON(http.StatusOK, token.Claims)
|
||||
})
|
||||
|
||||
e.Use(JWT([]byte("secret")))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set(echo.HeaderAuthorization, "bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ")
|
||||
res := httptest.NewRecorder()
|
||||
|
||||
e.ServeHTTP(res, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.Code)
|
||||
assert.Equal(t, `{"admin":true,"name":"John Doe","sub":"1234567890"}`+"\n", res.Body.String())
|
||||
}
|
||||
|
||||
func TestJWTRace(t *testing.T) {
|
||||
e := echo.New()
|
||||
handler := func(c echo.Context) error {
|
||||
@ -64,8 +85,7 @@ func TestJWTRace(t *testing.T) {
|
||||
assert.Equal(t, claims.Admin, true)
|
||||
}
|
||||
|
||||
func TestJWT(t *testing.T) {
|
||||
e := echo.New()
|
||||
func TestJWTConfig(t *testing.T) {
|
||||
handler := func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "test")
|
||||
}
|
||||
@ -74,7 +94,8 @@ func TestJWT(t *testing.T) {
|
||||
invalidKey := []byte("invalid-key")
|
||||
validAuth := DefaultJWTConfig.AuthScheme + " " + token
|
||||
|
||||
for _, tc := range []struct {
|
||||
testCases := []struct {
|
||||
name string
|
||||
expPanic bool
|
||||
expErrCode int // 0 for Success
|
||||
config JWTConfig
|
||||
@ -82,166 +103,166 @@ func TestJWT(t *testing.T) {
|
||||
hdrAuth string
|
||||
hdrCookie string // test.Request doesn't provide SetCookie(); use name=val
|
||||
formValues map[string]string
|
||||
info string
|
||||
}{
|
||||
{
|
||||
name: "No signing key provided",
|
||||
expPanic: true,
|
||||
info: "No signing key provided",
|
||||
},
|
||||
{
|
||||
name: "Unexpected signing method",
|
||||
expErrCode: http.StatusBadRequest,
|
||||
config: JWTConfig{
|
||||
SigningKey: validKey,
|
||||
SigningMethod: "RS256",
|
||||
},
|
||||
info: "Unexpected signing method",
|
||||
},
|
||||
{
|
||||
name: "Invalid key",
|
||||
expErrCode: http.StatusUnauthorized,
|
||||
hdrAuth: validAuth,
|
||||
config: JWTConfig{SigningKey: invalidKey},
|
||||
info: "Invalid key",
|
||||
},
|
||||
{
|
||||
name: "Valid JWT",
|
||||
hdrAuth: validAuth,
|
||||
config: JWTConfig{SigningKey: validKey},
|
||||
info: "Valid JWT",
|
||||
},
|
||||
{
|
||||
name: "Valid JWT with custom AuthScheme",
|
||||
hdrAuth: "Token" + " " + token,
|
||||
config: JWTConfig{AuthScheme: "Token", SigningKey: validKey},
|
||||
info: "Valid JWT with custom AuthScheme",
|
||||
},
|
||||
{
|
||||
name: "Valid JWT with custom claims",
|
||||
hdrAuth: validAuth,
|
||||
config: JWTConfig{
|
||||
Claims: &jwtCustomClaims{},
|
||||
SigningKey: []byte("secret"),
|
||||
},
|
||||
info: "Valid JWT with custom claims",
|
||||
},
|
||||
{
|
||||
name: "Invalid Authorization header",
|
||||
hdrAuth: "invalid-auth",
|
||||
expErrCode: http.StatusBadRequest,
|
||||
config: JWTConfig{SigningKey: validKey},
|
||||
info: "Invalid Authorization header",
|
||||
},
|
||||
{
|
||||
name: "Empty header auth field",
|
||||
config: JWTConfig{SigningKey: validKey},
|
||||
expErrCode: http.StatusBadRequest,
|
||||
info: "Empty header auth field",
|
||||
},
|
||||
{
|
||||
name: "Valid query method",
|
||||
config: JWTConfig{
|
||||
SigningKey: validKey,
|
||||
TokenLookup: "query:jwt",
|
||||
},
|
||||
reqURL: "/?a=b&jwt=" + token,
|
||||
info: "Valid query method",
|
||||
},
|
||||
{
|
||||
name: "Invalid query param name",
|
||||
config: JWTConfig{
|
||||
SigningKey: validKey,
|
||||
TokenLookup: "query:jwt",
|
||||
},
|
||||
reqURL: "/?a=b&jwtxyz=" + token,
|
||||
expErrCode: http.StatusBadRequest,
|
||||
info: "Invalid query param name",
|
||||
},
|
||||
{
|
||||
name: "Invalid query param value",
|
||||
config: JWTConfig{
|
||||
SigningKey: validKey,
|
||||
TokenLookup: "query:jwt",
|
||||
},
|
||||
reqURL: "/?a=b&jwt=invalid-token",
|
||||
expErrCode: http.StatusUnauthorized,
|
||||
info: "Invalid query param value",
|
||||
},
|
||||
{
|
||||
name: "Empty query",
|
||||
config: JWTConfig{
|
||||
SigningKey: validKey,
|
||||
TokenLookup: "query:jwt",
|
||||
},
|
||||
reqURL: "/?a=b",
|
||||
expErrCode: http.StatusBadRequest,
|
||||
info: "Empty query",
|
||||
},
|
||||
{
|
||||
name: "Valid param method",
|
||||
config: JWTConfig{
|
||||
SigningKey: validKey,
|
||||
TokenLookup: "param:jwt",
|
||||
},
|
||||
reqURL: "/" + token,
|
||||
info: "Valid param method",
|
||||
},
|
||||
{
|
||||
name: "Valid cookie method",
|
||||
config: JWTConfig{
|
||||
SigningKey: validKey,
|
||||
TokenLookup: "cookie:jwt",
|
||||
},
|
||||
hdrCookie: "jwt=" + token,
|
||||
info: "Valid cookie method",
|
||||
},
|
||||
{
|
||||
name: "Multiple jwt lookuop",
|
||||
config: JWTConfig{
|
||||
SigningKey: validKey,
|
||||
TokenLookup: "query:jwt,cookie:jwt",
|
||||
},
|
||||
hdrCookie: "jwt=" + token,
|
||||
info: "Multiple jwt lookuop",
|
||||
},
|
||||
{
|
||||
name: "Invalid token with cookie method",
|
||||
config: JWTConfig{
|
||||
SigningKey: validKey,
|
||||
TokenLookup: "cookie:jwt",
|
||||
},
|
||||
expErrCode: http.StatusUnauthorized,
|
||||
hdrCookie: "jwt=invalid",
|
||||
info: "Invalid token with cookie method",
|
||||
},
|
||||
{
|
||||
name: "Empty cookie",
|
||||
config: JWTConfig{
|
||||
SigningKey: validKey,
|
||||
TokenLookup: "cookie:jwt",
|
||||
},
|
||||
expErrCode: http.StatusBadRequest,
|
||||
info: "Empty cookie",
|
||||
},
|
||||
{
|
||||
name: "Valid form method",
|
||||
config: JWTConfig{
|
||||
SigningKey: validKey,
|
||||
TokenLookup: "form:jwt",
|
||||
},
|
||||
formValues: map[string]string{"jwt": token},
|
||||
info: "Valid form method",
|
||||
},
|
||||
{
|
||||
name: "Invalid token with form method",
|
||||
config: JWTConfig{
|
||||
SigningKey: validKey,
|
||||
TokenLookup: "form:jwt",
|
||||
},
|
||||
expErrCode: http.StatusUnauthorized,
|
||||
formValues: map[string]string{"jwt": "invalid"},
|
||||
info: "Invalid token with form method",
|
||||
},
|
||||
{
|
||||
name: "Empty form field",
|
||||
config: JWTConfig{
|
||||
SigningKey: validKey,
|
||||
TokenLookup: "form:jwt",
|
||||
},
|
||||
expErrCode: http.StatusBadRequest,
|
||||
info: "Empty form field",
|
||||
},
|
||||
{
|
||||
name: "Valid JWT with a valid key using a user-defined KeyFunc",
|
||||
hdrAuth: validAuth,
|
||||
config: JWTConfig{
|
||||
KeyFunc: func(*jwt.Token) (interface{}, error) {
|
||||
return validKey, nil
|
||||
},
|
||||
},
|
||||
info: "Valid JWT with a valid key using a user-defined KeyFunc",
|
||||
},
|
||||
{
|
||||
name: "Valid JWT with an invalid key using a user-defined KeyFunc",
|
||||
hdrAuth: validAuth,
|
||||
config: JWTConfig{
|
||||
KeyFunc: func(*jwt.Token) (interface{}, error) {
|
||||
@ -249,9 +270,9 @@ func TestJWT(t *testing.T) {
|
||||
},
|
||||
},
|
||||
expErrCode: http.StatusUnauthorized,
|
||||
info: "Valid JWT with an invalid key using a user-defined KeyFunc",
|
||||
},
|
||||
{
|
||||
name: "Token verification does not pass using a user-defined KeyFunc",
|
||||
hdrAuth: validAuth,
|
||||
config: JWTConfig{
|
||||
KeyFunc: func(*jwt.Token) (interface{}, error) {
|
||||
@ -259,67 +280,70 @@ func TestJWT(t *testing.T) {
|
||||
},
|
||||
},
|
||||
expErrCode: http.StatusUnauthorized,
|
||||
info: "Token verification does not pass using a user-defined KeyFunc",
|
||||
},
|
||||
{
|
||||
name: "Valid JWT with lower case AuthScheme",
|
||||
hdrAuth: strings.ToLower(DefaultJWTConfig.AuthScheme) + " " + token,
|
||||
config: JWTConfig{SigningKey: validKey},
|
||||
info: "Valid JWT with lower case AuthScheme",
|
||||
},
|
||||
} {
|
||||
if tc.reqURL == "" {
|
||||
tc.reqURL = "/"
|
||||
}
|
||||
|
||||
var req *http.Request
|
||||
if len(tc.formValues) > 0 {
|
||||
form := url.Values{}
|
||||
for k, v := range tc.formValues {
|
||||
form.Set(k, v)
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := echo.New()
|
||||
if tc.reqURL == "" {
|
||||
tc.reqURL = "/"
|
||||
}
|
||||
req = httptest.NewRequest(http.MethodPost, tc.reqURL, strings.NewReader(form.Encode()))
|
||||
req.Header.Set(echo.HeaderContentType, "application/x-www-form-urlencoded")
|
||||
req.ParseForm()
|
||||
} else {
|
||||
req = httptest.NewRequest(http.MethodGet, tc.reqURL, nil)
|
||||
}
|
||||
res := httptest.NewRecorder()
|
||||
req.Header.Set(echo.HeaderAuthorization, tc.hdrAuth)
|
||||
req.Header.Set(echo.HeaderCookie, tc.hdrCookie)
|
||||
c := e.NewContext(req, res)
|
||||
|
||||
if tc.reqURL == "/"+token {
|
||||
c.SetParamNames("jwt")
|
||||
c.SetParamValues(token)
|
||||
}
|
||||
var req *http.Request
|
||||
if len(tc.formValues) > 0 {
|
||||
form := url.Values{}
|
||||
for k, v := range tc.formValues {
|
||||
form.Set(k, v)
|
||||
}
|
||||
req = httptest.NewRequest(http.MethodPost, tc.reqURL, strings.NewReader(form.Encode()))
|
||||
req.Header.Set(echo.HeaderContentType, "application/x-www-form-urlencoded")
|
||||
req.ParseForm()
|
||||
} else {
|
||||
req = httptest.NewRequest(http.MethodGet, tc.reqURL, nil)
|
||||
}
|
||||
res := httptest.NewRecorder()
|
||||
req.Header.Set(echo.HeaderAuthorization, tc.hdrAuth)
|
||||
req.Header.Set(echo.HeaderCookie, tc.hdrCookie)
|
||||
c := e.NewContext(req, res)
|
||||
|
||||
if tc.expPanic {
|
||||
assert.Panics(t, func() {
|
||||
JWTWithConfig(tc.config)
|
||||
}, tc.info)
|
||||
continue
|
||||
}
|
||||
if tc.reqURL == "/"+token {
|
||||
c.SetParamNames("jwt")
|
||||
c.SetParamValues(token)
|
||||
}
|
||||
|
||||
if tc.expPanic {
|
||||
assert.Panics(t, func() {
|
||||
JWTWithConfig(tc.config)
|
||||
}, tc.name)
|
||||
return
|
||||
}
|
||||
|
||||
if tc.expErrCode != 0 {
|
||||
h := JWTWithConfig(tc.config)(handler)
|
||||
he := h(c).(*echo.HTTPError)
|
||||
assert.Equal(t, tc.expErrCode, he.Code, tc.name)
|
||||
return
|
||||
}
|
||||
|
||||
if tc.expErrCode != 0 {
|
||||
h := JWTWithConfig(tc.config)(handler)
|
||||
he := h(c).(*echo.HTTPError)
|
||||
assert.Equal(t, tc.expErrCode, he.Code, tc.info)
|
||||
continue
|
||||
}
|
||||
|
||||
h := JWTWithConfig(tc.config)(handler)
|
||||
if assert.NoError(t, h(c), tc.info) {
|
||||
user := c.Get("user").(*jwt.Token)
|
||||
switch claims := user.Claims.(type) {
|
||||
case jwt.MapClaims:
|
||||
assert.Equal(t, claims["name"], "John Doe", tc.info)
|
||||
case *jwtCustomClaims:
|
||||
assert.Equal(t, claims.Name, "John Doe", tc.info)
|
||||
assert.Equal(t, claims.Admin, true, tc.info)
|
||||
default:
|
||||
panic("unexpected type of claims")
|
||||
if assert.NoError(t, h(c), tc.name) {
|
||||
user := c.Get("user").(*jwt.Token)
|
||||
switch claims := user.Claims.(type) {
|
||||
case jwt.MapClaims:
|
||||
assert.Equal(t, claims["name"], "John Doe", tc.name)
|
||||
case *jwtCustomClaims:
|
||||
assert.Equal(t, claims.Name, "John Doe", tc.name)
|
||||
assert.Equal(t, claims.Admin, true, tc.name)
|
||||
default:
|
||||
panic("unexpected type of claims")
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -608,13 +632,14 @@ func TestJWTConfig_TokenLookupFuncs(t *testing.T) {
|
||||
e := echo.New()
|
||||
|
||||
e.GET("/", func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "test")
|
||||
token := c.Get("user").(*jwt.Token)
|
||||
return c.JSON(http.StatusOK, token.Claims)
|
||||
})
|
||||
|
||||
e.Use(JWTWithConfig(JWTConfig{
|
||||
TokenLookupFuncs: []TokenLookupFunc{
|
||||
func(c echo.Context) (string, error) {
|
||||
return c.Request().Header.Get("X-API-Key"), nil
|
||||
TokenLookupFuncs: []ValuesExtractor{
|
||||
func(c echo.Context) ([]string, error) {
|
||||
return []string{c.Request().Header.Get("X-API-Key")}, nil
|
||||
},
|
||||
},
|
||||
SigningKey: []byte("secret"),
|
||||
@ -626,4 +651,129 @@ func TestJWTConfig_TokenLookupFuncs(t *testing.T) {
|
||||
e.ServeHTTP(res, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.Code)
|
||||
assert.Equal(t, `{"admin":true,"name":"John Doe","sub":"1234567890"}`+"\n", res.Body.String())
|
||||
}
|
||||
|
||||
func TestJWTConfig_SuccessHandler(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
givenToken string
|
||||
expectCalled bool
|
||||
expectStatus int
|
||||
}{
|
||||
{
|
||||
name: "ok, success handler is called",
|
||||
givenToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ",
|
||||
expectCalled: true,
|
||||
expectStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "nok, success handler is not called",
|
||||
givenToken: "x.x.x",
|
||||
expectCalled: false,
|
||||
expectStatus: http.StatusUnauthorized,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := echo.New()
|
||||
|
||||
e.GET("/", func(c echo.Context) error {
|
||||
token := c.Get("user").(*jwt.Token)
|
||||
return c.JSON(http.StatusOK, token.Claims)
|
||||
})
|
||||
|
||||
wasCalled := false
|
||||
e.Use(JWTWithConfig(JWTConfig{
|
||||
SuccessHandler: func(c echo.Context) {
|
||||
wasCalled = true
|
||||
},
|
||||
SigningKey: []byte("secret"),
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set(echo.HeaderAuthorization, "bearer "+tc.givenToken)
|
||||
res := httptest.NewRecorder()
|
||||
|
||||
e.ServeHTTP(res, req)
|
||||
|
||||
assert.Equal(t, tc.expectCalled, wasCalled)
|
||||
assert.Equal(t, tc.expectStatus, res.Code)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWTConfig_ContinueOnIgnoredError(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
whenContinueOnIgnoredError bool
|
||||
givenToken string
|
||||
expectStatus int
|
||||
expectBody string
|
||||
}{
|
||||
{
|
||||
name: "no error handler is called",
|
||||
whenContinueOnIgnoredError: true,
|
||||
givenToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ",
|
||||
expectStatus: http.StatusTeapot,
|
||||
expectBody: "",
|
||||
},
|
||||
{
|
||||
name: "ContinueOnIgnoredError is false and error handler is called for missing token",
|
||||
whenContinueOnIgnoredError: false,
|
||||
givenToken: "",
|
||||
// empty response with 200. This emulates previous behaviour when error handler swallowed the error
|
||||
expectStatus: http.StatusOK,
|
||||
expectBody: "",
|
||||
},
|
||||
{
|
||||
name: "error handler is called for missing token",
|
||||
whenContinueOnIgnoredError: true,
|
||||
givenToken: "",
|
||||
expectStatus: http.StatusTeapot,
|
||||
expectBody: "public-token",
|
||||
},
|
||||
{
|
||||
name: "error handler is called for invalid token",
|
||||
whenContinueOnIgnoredError: true,
|
||||
givenToken: "x.x.x",
|
||||
expectStatus: http.StatusUnauthorized,
|
||||
expectBody: "{\"message\":\"Unauthorized\"}\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := echo.New()
|
||||
|
||||
e.GET("/", func(c echo.Context) error {
|
||||
testValue, _ := c.Get("test").(string)
|
||||
return c.String(http.StatusTeapot, testValue)
|
||||
})
|
||||
|
||||
e.Use(JWTWithConfig(JWTConfig{
|
||||
ContinueOnIgnoredError: tc.whenContinueOnIgnoredError,
|
||||
SigningKey: []byte("secret"),
|
||||
ErrorHandlerWithContext: func(err error, c echo.Context) error {
|
||||
if err == ErrJWTMissing {
|
||||
c.Set("test", "public-token")
|
||||
return nil
|
||||
}
|
||||
return echo.ErrUnauthorized
|
||||
},
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
if tc.givenToken != "" {
|
||||
req.Header.Set(echo.HeaderAuthorization, "bearer "+tc.givenToken)
|
||||
}
|
||||
res := httptest.NewRecorder()
|
||||
|
||||
e.ServeHTTP(res, req)
|
||||
|
||||
assert.Equal(t, tc.expectStatus, res.Code)
|
||||
assert.Equal(t, tc.expectBody, res.Body.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -2,11 +2,8 @@ package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type (
|
||||
@ -15,15 +12,21 @@ type (
|
||||
// Skipper defines a function to skip middleware.
|
||||
Skipper Skipper
|
||||
|
||||
// KeyLookup is a string in the form of "<source>:<name>" that is used
|
||||
// KeyLookup is a string in the form of "<source>:<name>" or "<source>:<name>,<source>:<name>" that is used
|
||||
// to extract key from the request.
|
||||
// Optional. Default value "header:Authorization".
|
||||
// Possible values:
|
||||
// - "header:<name>"
|
||||
// - "header:<name>" or "header:<name>:<cut-prefix>"
|
||||
// `<cut-prefix>` is argument value to cut/trim prefix of the extracted value. This is useful if header
|
||||
// value has static prefix like `Authorization: <auth-scheme> <authorisation-parameters>` where part that we
|
||||
// want to cut is `<auth-scheme> ` note the space at the end.
|
||||
// In case of basic authentication `Authorization: Basic <credentials>` prefix we want to remove is `Basic `.
|
||||
// - "query:<name>"
|
||||
// - "form:<name>"
|
||||
// - "cookie:<name>"
|
||||
KeyLookup string `yaml:"key_lookup"`
|
||||
// Multiple sources example:
|
||||
// - "header:Authorization,header:X-Api-Key"
|
||||
KeyLookup string
|
||||
|
||||
// AuthScheme to be used in the Authorization header.
|
||||
// Optional. Default value "Bearer".
|
||||
@ -36,15 +39,20 @@ type (
|
||||
// ErrorHandler defines a function which is executed for an invalid key.
|
||||
// It may be used to define a custom error.
|
||||
ErrorHandler KeyAuthErrorHandler
|
||||
|
||||
// ContinueOnIgnoredError allows the next middleware/handler to be called when ErrorHandler decides to
|
||||
// ignore the error (by returning `nil`).
|
||||
// This is useful when parts of your site/api allow public access and some authorized routes provide extra functionality.
|
||||
// In that case you can use ErrorHandler to set a default public key auth value in the request context
|
||||
// and continue. Some logic down the remaining execution chain needs to check that (public) key auth value then.
|
||||
ContinueOnIgnoredError bool
|
||||
}
|
||||
|
||||
// KeyAuthValidator defines a function to validate KeyAuth credentials.
|
||||
KeyAuthValidator func(string, echo.Context) (bool, error)
|
||||
|
||||
keyExtractor func(echo.Context) (string, error)
|
||||
KeyAuthValidator func(auth string, c echo.Context) (bool, error)
|
||||
|
||||
// KeyAuthErrorHandler defines a function which is executed for an invalid key.
|
||||
KeyAuthErrorHandler func(error, echo.Context) error
|
||||
KeyAuthErrorHandler func(err error, c echo.Context) error
|
||||
)
|
||||
|
||||
var (
|
||||
@ -56,6 +64,21 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
// ErrKeyAuthMissing is error type when KeyAuth middleware is unable to extract value from lookups
|
||||
type ErrKeyAuthMissing struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
// Error returns errors text
|
||||
func (e *ErrKeyAuthMissing) Error() string {
|
||||
return e.Err.Error()
|
||||
}
|
||||
|
||||
// Unwrap unwraps error
|
||||
func (e *ErrKeyAuthMissing) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
// KeyAuth returns an KeyAuth middleware.
|
||||
//
|
||||
// For valid key it calls the next handler.
|
||||
@ -85,16 +108,9 @@ func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc {
|
||||
panic("echo: key-auth middleware requires a validator function")
|
||||
}
|
||||
|
||||
// Initialize
|
||||
parts := strings.Split(config.KeyLookup, ":")
|
||||
extractor := keyFromHeader(parts[1], config.AuthScheme)
|
||||
switch parts[0] {
|
||||
case "query":
|
||||
extractor = keyFromQuery(parts[1])
|
||||
case "form":
|
||||
extractor = keyFromForm(parts[1])
|
||||
case "cookie":
|
||||
extractor = keyFromCookie(parts[1])
|
||||
extractors, err := createExtractors(config.KeyLookup, config.AuthScheme)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
@ -103,79 +119,62 @@ func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc {
|
||||
return next(c)
|
||||
}
|
||||
|
||||
// Extract and verify key
|
||||
key, err := extractor(c)
|
||||
if err != nil {
|
||||
if config.ErrorHandler != nil {
|
||||
return config.ErrorHandler(err, c)
|
||||
var lastExtractorErr error
|
||||
var lastValidatorErr error
|
||||
for _, extractor := range extractors {
|
||||
keys, err := extractor(c)
|
||||
if err != nil {
|
||||
lastExtractorErr = err
|
||||
continue
|
||||
}
|
||||
for _, key := range keys {
|
||||
valid, err := config.Validator(key, c)
|
||||
if err != nil {
|
||||
lastValidatorErr = err
|
||||
continue
|
||||
}
|
||||
if valid {
|
||||
return next(c)
|
||||
}
|
||||
lastValidatorErr = errors.New("invalid key")
|
||||
}
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
valid, err := config.Validator(key, c)
|
||||
if err != nil {
|
||||
if config.ErrorHandler != nil {
|
||||
return config.ErrorHandler(err, c)
|
||||
|
||||
// we are here only when we did not successfully extract and validate any of keys
|
||||
err := lastValidatorErr
|
||||
if err == nil { // prioritize validator errors over extracting errors
|
||||
// ugly part to preserve backwards compatible errors. someone could rely on them
|
||||
if lastExtractorErr == errQueryExtractorValueMissing {
|
||||
err = errors.New("missing key in the query string")
|
||||
} else if lastExtractorErr == errCookieExtractorValueMissing {
|
||||
err = errors.New("missing key in cookies")
|
||||
} else if lastExtractorErr == errFormExtractorValueMissing {
|
||||
err = errors.New("missing key in the form")
|
||||
} else if lastExtractorErr == errHeaderExtractorValueMissing {
|
||||
err = errors.New("missing key in request header")
|
||||
} else if lastExtractorErr == errHeaderExtractorValueInvalid {
|
||||
err = errors.New("invalid key in the request header")
|
||||
} else {
|
||||
err = lastExtractorErr
|
||||
}
|
||||
err = &ErrKeyAuthMissing{Err: err}
|
||||
}
|
||||
|
||||
if config.ErrorHandler != nil {
|
||||
tmpErr := config.ErrorHandler(err, c)
|
||||
if config.ContinueOnIgnoredError && tmpErr == nil {
|
||||
return next(c)
|
||||
}
|
||||
return tmpErr
|
||||
}
|
||||
if lastValidatorErr != nil { // prioritize validator errors over extracting errors
|
||||
return &echo.HTTPError{
|
||||
Code: http.StatusUnauthorized,
|
||||
Message: "invalid key",
|
||||
Internal: err,
|
||||
Message: "Unauthorized",
|
||||
Internal: lastValidatorErr,
|
||||
}
|
||||
} else if valid {
|
||||
return next(c)
|
||||
}
|
||||
return echo.ErrUnauthorized
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// keyFromHeader returns a `keyExtractor` that extracts key from the request header.
|
||||
func keyFromHeader(header string, authScheme string) keyExtractor {
|
||||
return func(c echo.Context) (string, error) {
|
||||
auth := c.Request().Header.Get(header)
|
||||
if auth == "" {
|
||||
return "", errors.New("missing key in request header")
|
||||
}
|
||||
if header == echo.HeaderAuthorization {
|
||||
l := len(authScheme)
|
||||
if len(auth) > l+1 && auth[:l] == authScheme {
|
||||
return auth[l+1:], nil
|
||||
}
|
||||
return "", errors.New("invalid key in the request header")
|
||||
}
|
||||
return auth, nil
|
||||
}
|
||||
}
|
||||
|
||||
// keyFromQuery returns a `keyExtractor` that extracts key from the query string.
|
||||
func keyFromQuery(param string) keyExtractor {
|
||||
return func(c echo.Context) (string, error) {
|
||||
key := c.QueryParam(param)
|
||||
if key == "" {
|
||||
return "", errors.New("missing key in the query string")
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
}
|
||||
|
||||
// keyFromForm returns a `keyExtractor` that extracts key from the form.
|
||||
func keyFromForm(param string) keyExtractor {
|
||||
return func(c echo.Context) (string, error) {
|
||||
key := c.FormValue(param)
|
||||
if key == "" {
|
||||
return "", errors.New("missing key in the form")
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
}
|
||||
|
||||
// keyFromCookie returns a `keyExtractor` that extracts key from the form.
|
||||
func keyFromCookie(cookieName string) keyExtractor {
|
||||
return func(c echo.Context) (string, error) {
|
||||
key, err := c.Cookie(cookieName)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("missing key in cookies: %w", err)
|
||||
}
|
||||
return key.Value, nil
|
||||
}
|
||||
}
|
||||
|
@ -76,7 +76,7 @@ func TestKeyAuthWithConfig(t *testing.T) {
|
||||
req.Header.Set(echo.HeaderAuthorization, "Bearer invalid-key")
|
||||
},
|
||||
expectHandlerCalled: false,
|
||||
expectError: "code=401, message=Unauthorized",
|
||||
expectError: "code=401, message=Unauthorized, internal=invalid key",
|
||||
},
|
||||
{
|
||||
name: "nok, defaults, invalid scheme in header",
|
||||
@ -92,6 +92,17 @@ func TestKeyAuthWithConfig(t *testing.T) {
|
||||
expectHandlerCalled: false,
|
||||
expectError: "code=400, message=missing key in request header",
|
||||
},
|
||||
{
|
||||
name: "ok, custom key lookup from multiple places, query and header",
|
||||
givenRequest: func(req *http.Request) {
|
||||
req.URL.RawQuery = "key=invalid-key"
|
||||
req.Header.Set("API-Key", "valid-key")
|
||||
},
|
||||
whenConfig: func(conf *KeyAuthConfig) {
|
||||
conf.KeyLookup = "query:key,header:API-Key"
|
||||
},
|
||||
expectHandlerCalled: true,
|
||||
},
|
||||
{
|
||||
name: "ok, custom key lookup, header",
|
||||
givenRequest: func(req *http.Request) {
|
||||
@ -179,7 +190,7 @@ func TestKeyAuthWithConfig(t *testing.T) {
|
||||
conf.KeyLookup = "cookie:key"
|
||||
},
|
||||
expectHandlerCalled: false,
|
||||
expectError: "code=400, message=missing key in cookies: http: named cookie not present",
|
||||
expectError: "code=400, message=missing key in cookies",
|
||||
},
|
||||
{
|
||||
name: "nok, custom errorHandler, error from extractor",
|
||||
@ -216,7 +227,7 @@ func TestKeyAuthWithConfig(t *testing.T) {
|
||||
},
|
||||
whenConfig: func(conf *KeyAuthConfig) {},
|
||||
expectHandlerCalled: false,
|
||||
expectError: "code=401, message=invalid key, internal=some user defined error",
|
||||
expectError: "code=401, message=Unauthorized, internal=some user defined error",
|
||||
},
|
||||
}
|
||||
|
||||
@ -257,3 +268,109 @@ func TestKeyAuthWithConfig(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeyAuthWithConfig_panicsOnInvalidLookup(t *testing.T) {
|
||||
assert.PanicsWithError(
|
||||
t,
|
||||
"extractor source for lookup could not be split into needed parts: a",
|
||||
func() {
|
||||
handler := func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "test")
|
||||
}
|
||||
KeyAuthWithConfig(KeyAuthConfig{
|
||||
Validator: testKeyValidator,
|
||||
KeyLookup: "a",
|
||||
})(handler)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func TestKeyAuthWithConfig_panicsOnEmptyValidator(t *testing.T) {
|
||||
assert.PanicsWithValue(
|
||||
t,
|
||||
"echo: key-auth middleware requires a validator function",
|
||||
func() {
|
||||
handler := func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "test")
|
||||
}
|
||||
KeyAuthWithConfig(KeyAuthConfig{
|
||||
Validator: nil,
|
||||
})(handler)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func TestKeyAuthWithConfig_ContinueOnIgnoredError(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
whenContinueOnIgnoredError bool
|
||||
givenKey string
|
||||
expectStatus int
|
||||
expectBody string
|
||||
}{
|
||||
{
|
||||
name: "no error handler is called",
|
||||
whenContinueOnIgnoredError: true,
|
||||
givenKey: "valid-key",
|
||||
expectStatus: http.StatusTeapot,
|
||||
expectBody: "",
|
||||
},
|
||||
{
|
||||
name: "ContinueOnIgnoredError is false and error handler is called for missing token",
|
||||
whenContinueOnIgnoredError: false,
|
||||
givenKey: "",
|
||||
// empty response with 200. This emulates previous behaviour when error handler swallowed the error
|
||||
expectStatus: http.StatusOK,
|
||||
expectBody: "",
|
||||
},
|
||||
{
|
||||
name: "error handler is called for missing token",
|
||||
whenContinueOnIgnoredError: true,
|
||||
givenKey: "",
|
||||
expectStatus: http.StatusTeapot,
|
||||
expectBody: "public-auth",
|
||||
},
|
||||
{
|
||||
name: "error handler is called for invalid token",
|
||||
whenContinueOnIgnoredError: true,
|
||||
givenKey: "x.x.x",
|
||||
expectStatus: http.StatusUnauthorized,
|
||||
expectBody: "{\"message\":\"Unauthorized\"}\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := echo.New()
|
||||
|
||||
e.GET("/", func(c echo.Context) error {
|
||||
testValue, _ := c.Get("test").(string)
|
||||
return c.String(http.StatusTeapot, testValue)
|
||||
})
|
||||
|
||||
e.Use(KeyAuthWithConfig(KeyAuthConfig{
|
||||
Validator: testKeyValidator,
|
||||
ErrorHandler: func(err error, c echo.Context) error {
|
||||
if _, ok := err.(*ErrKeyAuthMissing); ok {
|
||||
c.Set("test", "public-auth")
|
||||
return nil
|
||||
}
|
||||
return echo.ErrUnauthorized
|
||||
},
|
||||
KeyLookup: "header:X-API-Key",
|
||||
ContinueOnIgnoredError: tc.whenContinueOnIgnoredError,
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
if tc.givenKey != "" {
|
||||
req.Header.Set("X-API-Key", tc.givenKey)
|
||||
}
|
||||
res := httptest.NewRecorder()
|
||||
|
||||
e.ServeHTTP(res, req)
|
||||
|
||||
assert.Equal(t, tc.expectStatus, res.Code)
|
||||
assert.Equal(t, tc.expectBody, res.Body.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -12,10 +12,10 @@ import (
|
||||
type (
|
||||
// Skipper defines a function to skip middleware. Returning true skips processing
|
||||
// the middleware.
|
||||
Skipper func(echo.Context) bool
|
||||
Skipper func(c echo.Context) bool
|
||||
|
||||
// BeforeFunc defines a function which is executed just before the middleware.
|
||||
BeforeFunc func(echo.Context)
|
||||
BeforeFunc func(c echo.Context)
|
||||
)
|
||||
|
||||
func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer {
|
||||
|
Loading…
Reference in New Issue
Block a user