mirror of
https://github.com/labstack/echo.git
synced 2025-01-26 03:20:08 +02:00
JWT, KeyAuth, CSRF multivalue extractors (#2060)
* CSRF, JWT, KeyAuth middleware support for multivalue value extractors * Add flag to JWT and KeyAuth middleware to allow continuing execution `next(c)` when error handler decides to swallow the error (returns nil).
This commit is contained in:
parent
9e9924d763
commit
4a1ccdfdc5
4
echo.go
4
echo.go
@ -111,10 +111,10 @@ type (
|
|||||||
}
|
}
|
||||||
|
|
||||||
// MiddlewareFunc defines a function to process middleware.
|
// MiddlewareFunc defines a function to process middleware.
|
||||||
MiddlewareFunc func(HandlerFunc) HandlerFunc
|
MiddlewareFunc func(next HandlerFunc) HandlerFunc
|
||||||
|
|
||||||
// HandlerFunc defines a function to serve HTTP requests.
|
// HandlerFunc defines a function to serve HTTP requests.
|
||||||
HandlerFunc func(Context) error
|
HandlerFunc func(c Context) error
|
||||||
|
|
||||||
// HTTPErrorHandler is a centralized HTTP error handler.
|
// HTTPErrorHandler is a centralized HTTP error handler.
|
||||||
HTTPErrorHandler func(error, Context)
|
HTTPErrorHandler func(error, Context)
|
||||||
|
@ -2,9 +2,7 @@ package middleware
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/subtle"
|
"crypto/subtle"
|
||||||
"errors"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/labstack/echo/v4"
|
"github.com/labstack/echo/v4"
|
||||||
@ -21,13 +19,15 @@ type (
|
|||||||
TokenLength uint8 `yaml:"token_length"`
|
TokenLength uint8 `yaml:"token_length"`
|
||||||
// Optional. Default value 32.
|
// Optional. Default value 32.
|
||||||
|
|
||||||
// TokenLookup is a string in the form of "<source>:<key>" that is used
|
// TokenLookup is a string in the form of "<source>:<name>" or "<source>:<name>,<source>:<name>" that is used
|
||||||
// to extract token from the request.
|
// to extract token from the request.
|
||||||
// Optional. Default value "header:X-CSRF-Token".
|
// Optional. Default value "header:X-CSRF-Token".
|
||||||
// Possible values:
|
// Possible values:
|
||||||
// - "header:<name>"
|
// - "header:<name>" or "header:<name>:<cut-prefix>"
|
||||||
// - "form:<name>"
|
|
||||||
// - "query:<name>"
|
// - "query:<name>"
|
||||||
|
// - "form:<name>"
|
||||||
|
// Multiple sources example:
|
||||||
|
// - "header:X-CSRF-Token,query:csrf"
|
||||||
TokenLookup string `yaml:"token_lookup"`
|
TokenLookup string `yaml:"token_lookup"`
|
||||||
|
|
||||||
// Context key to store generated CSRF token into context.
|
// Context key to store generated CSRF token into context.
|
||||||
@ -62,12 +62,11 @@ type (
|
|||||||
// Optional. Default value SameSiteDefaultMode.
|
// Optional. Default value SameSiteDefaultMode.
|
||||||
CookieSameSite http.SameSite `yaml:"cookie_same_site"`
|
CookieSameSite http.SameSite `yaml:"cookie_same_site"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// csrfTokenExtractor defines a function that takes `echo.Context` and returns
|
|
||||||
// either a token or an error.
|
|
||||||
csrfTokenExtractor func(echo.Context) (string, error)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ErrCSRFInvalid is returned when CSRF check fails
|
||||||
|
var ErrCSRFInvalid = echo.NewHTTPError(http.StatusForbidden, "invalid csrf token")
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// DefaultCSRFConfig is the default CSRF middleware config.
|
// DefaultCSRFConfig is the default CSRF middleware config.
|
||||||
DefaultCSRFConfig = CSRFConfig{
|
DefaultCSRFConfig = CSRFConfig{
|
||||||
@ -114,14 +113,9 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
|
|||||||
config.CookieSecure = true
|
config.CookieSecure = true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize
|
extractors, err := createExtractors(config.TokenLookup, "")
|
||||||
parts := strings.Split(config.TokenLookup, ":")
|
if err != nil {
|
||||||
extractor := csrfTokenFromHeader(parts[1])
|
panic(err)
|
||||||
switch parts[0] {
|
|
||||||
case "form":
|
|
||||||
extractor = csrfTokenFromForm(parts[1])
|
|
||||||
case "query":
|
|
||||||
extractor = csrfTokenFromQuery(parts[1])
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||||
@ -130,28 +124,50 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
|
|||||||
return next(c)
|
return next(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
req := c.Request()
|
|
||||||
k, err := c.Cookie(config.CookieName)
|
|
||||||
token := ""
|
token := ""
|
||||||
|
if k, err := c.Cookie(config.CookieName); err != nil {
|
||||||
// Generate token
|
token = random.String(config.TokenLength) // Generate token
|
||||||
if err != nil {
|
|
||||||
token = random.String(config.TokenLength)
|
|
||||||
} else {
|
} else {
|
||||||
// Reuse token
|
token = k.Value // Reuse token
|
||||||
token = k.Value
|
|
||||||
}
|
}
|
||||||
|
|
||||||
switch req.Method {
|
switch c.Request().Method {
|
||||||
case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
|
case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
|
||||||
default:
|
default:
|
||||||
// Validate token only for requests which are not defined as 'safe' by RFC7231
|
// Validate token only for requests which are not defined as 'safe' by RFC7231
|
||||||
clientToken, err := extractor(c)
|
var lastExtractorErr error
|
||||||
|
var lastTokenErr error
|
||||||
|
outer:
|
||||||
|
for _, extractor := range extractors {
|
||||||
|
clientTokens, err := extractor(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
lastExtractorErr = err
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
if !validateCSRFToken(token, clientToken) {
|
|
||||||
return echo.NewHTTPError(http.StatusForbidden, "invalid csrf token")
|
for _, clientToken := range clientTokens {
|
||||||
|
if validateCSRFToken(token, clientToken) {
|
||||||
|
lastTokenErr = nil
|
||||||
|
lastExtractorErr = nil
|
||||||
|
break outer
|
||||||
|
}
|
||||||
|
lastTokenErr = ErrCSRFInvalid
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if lastTokenErr != nil {
|
||||||
|
return lastTokenErr
|
||||||
|
} else if lastExtractorErr != nil {
|
||||||
|
// ugly part to preserve backwards compatible errors. someone could rely on them
|
||||||
|
if lastExtractorErr == errQueryExtractorValueMissing {
|
||||||
|
lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in the query string")
|
||||||
|
} else if lastExtractorErr == errFormExtractorValueMissing {
|
||||||
|
lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in the form parameter")
|
||||||
|
} else if lastExtractorErr == errHeaderExtractorValueMissing {
|
||||||
|
lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in request header")
|
||||||
|
} else {
|
||||||
|
lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, lastExtractorErr.Error())
|
||||||
|
}
|
||||||
|
return lastExtractorErr
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -184,38 +200,6 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the
|
|
||||||
// provided request header.
|
|
||||||
func csrfTokenFromHeader(header string) csrfTokenExtractor {
|
|
||||||
return func(c echo.Context) (string, error) {
|
|
||||||
return c.Request().Header.Get(header), nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the
|
|
||||||
// provided form parameter.
|
|
||||||
func csrfTokenFromForm(param string) csrfTokenExtractor {
|
|
||||||
return func(c echo.Context) (string, error) {
|
|
||||||
token := c.FormValue(param)
|
|
||||||
if token == "" {
|
|
||||||
return "", errors.New("missing csrf token in the form parameter")
|
|
||||||
}
|
|
||||||
return token, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// csrfTokenFromQuery returns a `csrfTokenExtractor` that extracts token from the
|
|
||||||
// provided query parameter.
|
|
||||||
func csrfTokenFromQuery(param string) csrfTokenExtractor {
|
|
||||||
return func(c echo.Context) (string, error) {
|
|
||||||
token := c.QueryParam(param)
|
|
||||||
if token == "" {
|
|
||||||
return "", errors.New("missing csrf token in the query string")
|
|
||||||
}
|
|
||||||
return token, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func validateCSRFToken(token, clientToken string) bool {
|
func validateCSRFToken(token, clientToken string) bool {
|
||||||
return subtle.ConstantTimeCompare([]byte(token), []byte(clientToken)) == 1
|
return subtle.ConstantTimeCompare([]byte(token), []byte(clientToken)) == 1
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
@ -13,14 +12,205 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestCSRF_tokenExtractors(t *testing.T) {
|
||||||
|
var testCases = []struct {
|
||||||
|
name string
|
||||||
|
whenTokenLookup string
|
||||||
|
whenCookieName string
|
||||||
|
givenCSRFCookie string
|
||||||
|
givenMethod string
|
||||||
|
givenQueryTokens map[string][]string
|
||||||
|
givenFormTokens map[string][]string
|
||||||
|
givenHeaderTokens map[string][]string
|
||||||
|
expectError string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "ok, multiple token lookups sources, succeeds on last one",
|
||||||
|
whenTokenLookup: "header:X-CSRF-Token,form:csrf",
|
||||||
|
givenCSRFCookie: "token",
|
||||||
|
givenMethod: http.MethodPost,
|
||||||
|
givenHeaderTokens: map[string][]string{
|
||||||
|
echo.HeaderXCSRFToken: {"invalid_token"},
|
||||||
|
},
|
||||||
|
givenFormTokens: map[string][]string{
|
||||||
|
"csrf": {"token"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok, token from POST form",
|
||||||
|
whenTokenLookup: "form:csrf",
|
||||||
|
givenCSRFCookie: "token",
|
||||||
|
givenMethod: http.MethodPost,
|
||||||
|
givenFormTokens: map[string][]string{
|
||||||
|
"csrf": {"token"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok, token from POST form, second token passes",
|
||||||
|
whenTokenLookup: "form:csrf",
|
||||||
|
givenCSRFCookie: "token",
|
||||||
|
givenMethod: http.MethodPost,
|
||||||
|
givenFormTokens: map[string][]string{
|
||||||
|
"csrf": {"invalid", "token"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nok, invalid token from POST form",
|
||||||
|
whenTokenLookup: "form:csrf",
|
||||||
|
givenCSRFCookie: "token",
|
||||||
|
givenMethod: http.MethodPost,
|
||||||
|
givenFormTokens: map[string][]string{
|
||||||
|
"csrf": {"invalid_token"},
|
||||||
|
},
|
||||||
|
expectError: "code=403, message=invalid csrf token",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nok, missing token from POST form",
|
||||||
|
whenTokenLookup: "form:csrf",
|
||||||
|
givenCSRFCookie: "token",
|
||||||
|
givenMethod: http.MethodPost,
|
||||||
|
givenFormTokens: map[string][]string{},
|
||||||
|
expectError: "code=400, message=missing csrf token in the form parameter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok, token from POST header",
|
||||||
|
whenTokenLookup: "", // will use defaults
|
||||||
|
givenCSRFCookie: "token",
|
||||||
|
givenMethod: http.MethodPost,
|
||||||
|
givenHeaderTokens: map[string][]string{
|
||||||
|
echo.HeaderXCSRFToken: {"token"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok, token from POST header, second token passes",
|
||||||
|
whenTokenLookup: "header:" + echo.HeaderXCSRFToken,
|
||||||
|
givenCSRFCookie: "token",
|
||||||
|
givenMethod: http.MethodPost,
|
||||||
|
givenHeaderTokens: map[string][]string{
|
||||||
|
echo.HeaderXCSRFToken: {"invalid", "token"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nok, invalid token from POST header",
|
||||||
|
whenTokenLookup: "header:" + echo.HeaderXCSRFToken,
|
||||||
|
givenCSRFCookie: "token",
|
||||||
|
givenMethod: http.MethodPost,
|
||||||
|
givenHeaderTokens: map[string][]string{
|
||||||
|
echo.HeaderXCSRFToken: {"invalid_token"},
|
||||||
|
},
|
||||||
|
expectError: "code=403, message=invalid csrf token",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nok, missing token from POST header",
|
||||||
|
whenTokenLookup: "header:" + echo.HeaderXCSRFToken,
|
||||||
|
givenCSRFCookie: "token",
|
||||||
|
givenMethod: http.MethodPost,
|
||||||
|
givenHeaderTokens: map[string][]string{},
|
||||||
|
expectError: "code=400, message=missing csrf token in request header",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok, token from PUT query param",
|
||||||
|
whenTokenLookup: "query:csrf-param",
|
||||||
|
givenCSRFCookie: "token",
|
||||||
|
givenMethod: http.MethodPut,
|
||||||
|
givenQueryTokens: map[string][]string{
|
||||||
|
"csrf-param": {"token"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok, token from PUT query form, second token passes",
|
||||||
|
whenTokenLookup: "query:csrf",
|
||||||
|
givenCSRFCookie: "token",
|
||||||
|
givenMethod: http.MethodPut,
|
||||||
|
givenQueryTokens: map[string][]string{
|
||||||
|
"csrf": {"invalid", "token"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nok, invalid token from PUT query form",
|
||||||
|
whenTokenLookup: "query:csrf",
|
||||||
|
givenCSRFCookie: "token",
|
||||||
|
givenMethod: http.MethodPut,
|
||||||
|
givenQueryTokens: map[string][]string{
|
||||||
|
"csrf": {"invalid_token"},
|
||||||
|
},
|
||||||
|
expectError: "code=403, message=invalid csrf token",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nok, missing token from PUT query form",
|
||||||
|
whenTokenLookup: "query:csrf",
|
||||||
|
givenCSRFCookie: "token",
|
||||||
|
givenMethod: http.MethodPut,
|
||||||
|
givenQueryTokens: map[string][]string{},
|
||||||
|
expectError: "code=400, message=missing csrf token in the query string",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
e := echo.New()
|
||||||
|
|
||||||
|
q := make(url.Values)
|
||||||
|
for queryParam, values := range tc.givenQueryTokens {
|
||||||
|
for _, v := range values {
|
||||||
|
q.Add(queryParam, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
f := make(url.Values)
|
||||||
|
for formKey, values := range tc.givenFormTokens {
|
||||||
|
for _, v := range values {
|
||||||
|
f.Add(formKey, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var req *http.Request
|
||||||
|
switch tc.givenMethod {
|
||||||
|
case http.MethodGet:
|
||||||
|
req = httptest.NewRequest(http.MethodGet, "/?"+q.Encode(), nil)
|
||||||
|
case http.MethodPost, http.MethodPut:
|
||||||
|
req = httptest.NewRequest(http.MethodPost, "/?"+q.Encode(), strings.NewReader(f.Encode()))
|
||||||
|
req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm)
|
||||||
|
}
|
||||||
|
|
||||||
|
for header, values := range tc.givenHeaderTokens {
|
||||||
|
for _, v := range values {
|
||||||
|
req.Header.Add(header, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if tc.givenCSRFCookie != "" {
|
||||||
|
req.Header.Set(echo.HeaderCookie, "_csrf="+tc.givenCSRFCookie)
|
||||||
|
}
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
|
||||||
|
csrf := CSRFWithConfig(CSRFConfig{
|
||||||
|
TokenLookup: tc.whenTokenLookup,
|
||||||
|
CookieName: tc.whenCookieName,
|
||||||
|
})
|
||||||
|
|
||||||
|
h := csrf(func(c echo.Context) error {
|
||||||
|
return c.String(http.StatusOK, "test")
|
||||||
|
})
|
||||||
|
|
||||||
|
err := h(c)
|
||||||
|
if tc.expectError != "" {
|
||||||
|
assert.EqualError(t, err, tc.expectError)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestCSRF(t *testing.T) {
|
func TestCSRF(t *testing.T) {
|
||||||
e := echo.New()
|
e := echo.New()
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
c := e.NewContext(req, rec)
|
c := e.NewContext(req, rec)
|
||||||
csrf := CSRFWithConfig(CSRFConfig{
|
csrf := CSRF()
|
||||||
TokenLength: 16,
|
|
||||||
})
|
|
||||||
h := csrf(func(c echo.Context) error {
|
h := csrf(func(c echo.Context) error {
|
||||||
return c.String(http.StatusOK, "test")
|
return c.String(http.StatusOK, "test")
|
||||||
})
|
})
|
||||||
@ -43,7 +233,7 @@ func TestCSRF(t *testing.T) {
|
|||||||
assert.Error(t, h(c))
|
assert.Error(t, h(c))
|
||||||
|
|
||||||
// Valid CSRF token
|
// Valid CSRF token
|
||||||
token := random.String(16)
|
token := random.String(32)
|
||||||
req.Header.Set(echo.HeaderCookie, "_csrf="+token)
|
req.Header.Set(echo.HeaderCookie, "_csrf="+token)
|
||||||
req.Header.Set(echo.HeaderXCSRFToken, token)
|
req.Header.Set(echo.HeaderXCSRFToken, token)
|
||||||
if assert.NoError(t, h(c)) {
|
if assert.NoError(t, h(c)) {
|
||||||
@ -51,38 +241,6 @@ func TestCSRF(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCSRFTokenFromForm(t *testing.T) {
|
|
||||||
f := make(url.Values)
|
|
||||||
f.Set("csrf", "token")
|
|
||||||
e := echo.New()
|
|
||||||
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode()))
|
|
||||||
req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm)
|
|
||||||
c := e.NewContext(req, nil)
|
|
||||||
token, err := csrfTokenFromForm("csrf")(c)
|
|
||||||
if assert.NoError(t, err) {
|
|
||||||
assert.Equal(t, "token", token)
|
|
||||||
}
|
|
||||||
_, err = csrfTokenFromForm("invalid")(c)
|
|
||||||
assert.Error(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCSRFTokenFromQuery(t *testing.T) {
|
|
||||||
q := make(url.Values)
|
|
||||||
q.Set("csrf", "token")
|
|
||||||
e := echo.New()
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
||||||
req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm)
|
|
||||||
req.URL.RawQuery = q.Encode()
|
|
||||||
c := e.NewContext(req, nil)
|
|
||||||
token, err := csrfTokenFromQuery("csrf")(c)
|
|
||||||
if assert.NoError(t, err) {
|
|
||||||
assert.Equal(t, "token", token)
|
|
||||||
}
|
|
||||||
_, err = csrfTokenFromQuery("invalid")(c)
|
|
||||||
assert.Error(t, err)
|
|
||||||
csrfTokenFromQuery("csrf")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCSRFSetSameSiteMode(t *testing.T) {
|
func TestCSRFSetSameSiteMode(t *testing.T) {
|
||||||
e := echo.New()
|
e := echo.New()
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
@ -135,7 +293,6 @@ func TestCSRFWithSameSiteDefaultMode(t *testing.T) {
|
|||||||
|
|
||||||
r := h(c)
|
r := h(c)
|
||||||
assert.NoError(t, r)
|
assert.NoError(t, r)
|
||||||
fmt.Println(rec.Header()["Set-Cookie"])
|
|
||||||
assert.NotRegexp(t, "SameSite=", rec.Header()["Set-Cookie"])
|
assert.NotRegexp(t, "SameSite=", rec.Header()["Set-Cookie"])
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -158,3 +315,46 @@ func TestCSRFWithSameSiteModeNone(t *testing.T) {
|
|||||||
assert.Regexp(t, "SameSite=None", rec.Header()["Set-Cookie"])
|
assert.Regexp(t, "SameSite=None", rec.Header()["Set-Cookie"])
|
||||||
assert.Regexp(t, "Secure", rec.Header()["Set-Cookie"])
|
assert.Regexp(t, "Secure", rec.Header()["Set-Cookie"])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCSRFConfig_skipper(t *testing.T) {
|
||||||
|
var testCases = []struct {
|
||||||
|
name string
|
||||||
|
whenSkip bool
|
||||||
|
expectCookies int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "do skip",
|
||||||
|
whenSkip: true,
|
||||||
|
expectCookies: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "do not skip",
|
||||||
|
whenSkip: false,
|
||||||
|
expectCookies: 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
e := echo.New()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
|
||||||
|
csrf := CSRFWithConfig(CSRFConfig{
|
||||||
|
Skipper: func(c echo.Context) bool {
|
||||||
|
return tc.whenSkip
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
h := csrf(func(c echo.Context) error {
|
||||||
|
return c.String(http.StatusOK, "test")
|
||||||
|
})
|
||||||
|
|
||||||
|
r := h(c)
|
||||||
|
assert.NoError(t, r)
|
||||||
|
cookie := rec.Header()["Set-Cookie"]
|
||||||
|
assert.Len(t, cookie, tc.expectCookies)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
184
middleware/extractor.go
Normal file
184
middleware/extractor.go
Normal file
@ -0,0 +1,184 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"github.com/labstack/echo/v4"
|
||||||
|
"net/textproto"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// extractorLimit is arbitrary number to limit values extractor can return. this limits possible resource exhaustion
|
||||||
|
// attack vector
|
||||||
|
extractorLimit = 20
|
||||||
|
)
|
||||||
|
|
||||||
|
var errHeaderExtractorValueMissing = errors.New("missing value in request header")
|
||||||
|
var errHeaderExtractorValueInvalid = errors.New("invalid value in request header")
|
||||||
|
var errQueryExtractorValueMissing = errors.New("missing value in the query string")
|
||||||
|
var errParamExtractorValueMissing = errors.New("missing value in path params")
|
||||||
|
var errCookieExtractorValueMissing = errors.New("missing value in cookies")
|
||||||
|
var errFormExtractorValueMissing = errors.New("missing value in the form")
|
||||||
|
|
||||||
|
// ValuesExtractor defines a function for extracting values (keys/tokens) from the given context.
|
||||||
|
type ValuesExtractor func(c echo.Context) ([]string, error)
|
||||||
|
|
||||||
|
func createExtractors(lookups string, authScheme string) ([]ValuesExtractor, error) {
|
||||||
|
if lookups == "" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
sources := strings.Split(lookups, ",")
|
||||||
|
var extractors = make([]ValuesExtractor, 0)
|
||||||
|
for _, source := range sources {
|
||||||
|
parts := strings.Split(source, ":")
|
||||||
|
if len(parts) < 2 {
|
||||||
|
return nil, fmt.Errorf("extractor source for lookup could not be split into needed parts: %v", source)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch parts[0] {
|
||||||
|
case "query":
|
||||||
|
extractors = append(extractors, valuesFromQuery(parts[1]))
|
||||||
|
case "param":
|
||||||
|
extractors = append(extractors, valuesFromParam(parts[1]))
|
||||||
|
case "cookie":
|
||||||
|
extractors = append(extractors, valuesFromCookie(parts[1]))
|
||||||
|
case "form":
|
||||||
|
extractors = append(extractors, valuesFromForm(parts[1]))
|
||||||
|
case "header":
|
||||||
|
prefix := ""
|
||||||
|
if len(parts) > 2 {
|
||||||
|
prefix = parts[2]
|
||||||
|
} else if authScheme != "" && parts[1] == echo.HeaderAuthorization {
|
||||||
|
// backwards compatibility for JWT and KeyAuth:
|
||||||
|
// * we only apply this fix to Authorization as header we use and uses prefixes like "Bearer <token-value>" etc
|
||||||
|
// * previously header extractor assumed that auth-scheme/prefix had a space as suffix we need to retain that
|
||||||
|
// behaviour for default values and Authorization header.
|
||||||
|
prefix = authScheme
|
||||||
|
if !strings.HasSuffix(prefix, " ") {
|
||||||
|
prefix += " "
|
||||||
|
}
|
||||||
|
}
|
||||||
|
extractors = append(extractors, valuesFromHeader(parts[1], prefix))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return extractors, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// valuesFromHeader returns a functions that extracts values from the request header.
|
||||||
|
// valuePrefix is parameter to remove first part (prefix) of the extracted value. This is useful if header value has static
|
||||||
|
// prefix like `Authorization: <auth-scheme> <authorisation-parameters>` where part that we want to remove is `<auth-scheme> `
|
||||||
|
// note the space at the end. In case of basic authentication `Authorization: Basic <credentials>` prefix we want to remove
|
||||||
|
// is `Basic `. In case of JWT tokens `Authorization: Bearer <token>` prefix is `Bearer `.
|
||||||
|
// If prefix is left empty the whole value is returned.
|
||||||
|
func valuesFromHeader(header string, valuePrefix string) ValuesExtractor {
|
||||||
|
prefixLen := len(valuePrefix)
|
||||||
|
// standard library parses http.Request header keys in canonical form but we may provide something else so fix this
|
||||||
|
header = textproto.CanonicalMIMEHeaderKey(header)
|
||||||
|
return func(c echo.Context) ([]string, error) {
|
||||||
|
values := c.Request().Header.Values(header)
|
||||||
|
if len(values) == 0 {
|
||||||
|
return nil, errHeaderExtractorValueMissing
|
||||||
|
}
|
||||||
|
|
||||||
|
result := make([]string, 0)
|
||||||
|
for i, value := range values {
|
||||||
|
if prefixLen == 0 {
|
||||||
|
result = append(result, value)
|
||||||
|
if i >= extractorLimit-1 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if len(value) > prefixLen && strings.EqualFold(value[:prefixLen], valuePrefix) {
|
||||||
|
result = append(result, value[prefixLen:])
|
||||||
|
if i >= extractorLimit-1 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(result) == 0 {
|
||||||
|
if prefixLen > 0 {
|
||||||
|
return nil, errHeaderExtractorValueInvalid
|
||||||
|
}
|
||||||
|
return nil, errHeaderExtractorValueMissing
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// valuesFromQuery returns a function that extracts values from the query string.
|
||||||
|
func valuesFromQuery(param string) ValuesExtractor {
|
||||||
|
return func(c echo.Context) ([]string, error) {
|
||||||
|
result := c.QueryParams()[param]
|
||||||
|
if len(result) == 0 {
|
||||||
|
return nil, errQueryExtractorValueMissing
|
||||||
|
} else if len(result) > extractorLimit-1 {
|
||||||
|
result = result[:extractorLimit]
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// valuesFromParam returns a function that extracts values from the url param string.
|
||||||
|
func valuesFromParam(param string) ValuesExtractor {
|
||||||
|
return func(c echo.Context) ([]string, error) {
|
||||||
|
result := make([]string, 0)
|
||||||
|
paramVales := c.ParamValues()
|
||||||
|
for i, p := range c.ParamNames() {
|
||||||
|
if param == p {
|
||||||
|
result = append(result, paramVales[i])
|
||||||
|
if i >= extractorLimit-1 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(result) == 0 {
|
||||||
|
return nil, errParamExtractorValueMissing
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// valuesFromCookie returns a function that extracts values from the named cookie.
|
||||||
|
func valuesFromCookie(name string) ValuesExtractor {
|
||||||
|
return func(c echo.Context) ([]string, error) {
|
||||||
|
cookies := c.Cookies()
|
||||||
|
if len(cookies) == 0 {
|
||||||
|
return nil, errCookieExtractorValueMissing
|
||||||
|
}
|
||||||
|
|
||||||
|
result := make([]string, 0)
|
||||||
|
for i, cookie := range cookies {
|
||||||
|
if name == cookie.Name {
|
||||||
|
result = append(result, cookie.Value)
|
||||||
|
if i >= extractorLimit-1 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(result) == 0 {
|
||||||
|
return nil, errCookieExtractorValueMissing
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// valuesFromForm returns a function that extracts values from the form field.
|
||||||
|
func valuesFromForm(name string) ValuesExtractor {
|
||||||
|
return func(c echo.Context) ([]string, error) {
|
||||||
|
if parseErr := c.Request().ParseForm(); parseErr != nil {
|
||||||
|
return nil, fmt.Errorf("valuesFromForm parse form failed: %w", parseErr)
|
||||||
|
}
|
||||||
|
values := c.Request().Form[name]
|
||||||
|
if len(values) == 0 {
|
||||||
|
return nil, errFormExtractorValueMissing
|
||||||
|
}
|
||||||
|
if len(values) > extractorLimit-1 {
|
||||||
|
values = values[:extractorLimit]
|
||||||
|
}
|
||||||
|
result := append([]string{}, values...)
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
}
|
587
middleware/extractor_test.go
Normal file
587
middleware/extractor_test.go
Normal file
@ -0,0 +1,587 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/labstack/echo/v4"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type pathParam struct {
|
||||||
|
name string
|
||||||
|
value string
|
||||||
|
}
|
||||||
|
|
||||||
|
func setPathParams(c echo.Context, params []pathParam) {
|
||||||
|
names := make([]string, 0, len(params))
|
||||||
|
values := make([]string, 0, len(params))
|
||||||
|
for _, pp := range params {
|
||||||
|
names = append(names, pp.name)
|
||||||
|
values = append(values, pp.value)
|
||||||
|
}
|
||||||
|
c.SetParamNames(names...)
|
||||||
|
c.SetParamValues(values...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateExtractors(t *testing.T) {
|
||||||
|
var testCases = []struct {
|
||||||
|
name string
|
||||||
|
givenRequest func() *http.Request
|
||||||
|
givenPathParams []pathParam
|
||||||
|
whenLoopups string
|
||||||
|
expectValues []string
|
||||||
|
expectCreateError string
|
||||||
|
expectError string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "ok, header",
|
||||||
|
givenRequest: func() *http.Request {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.Header.Set(echo.HeaderAuthorization, "Bearer token")
|
||||||
|
return req
|
||||||
|
},
|
||||||
|
whenLoopups: "header:Authorization:Bearer ",
|
||||||
|
expectValues: []string{"token"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok, form",
|
||||||
|
givenRequest: func() *http.Request {
|
||||||
|
f := make(url.Values)
|
||||||
|
f.Set("name", "Jon Snow")
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode()))
|
||||||
|
req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm)
|
||||||
|
return req
|
||||||
|
},
|
||||||
|
whenLoopups: "form:name",
|
||||||
|
expectValues: []string{"Jon Snow"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok, cookie",
|
||||||
|
givenRequest: func() *http.Request {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.Header.Set(echo.HeaderCookie, "_csrf=token")
|
||||||
|
return req
|
||||||
|
},
|
||||||
|
whenLoopups: "cookie:_csrf",
|
||||||
|
expectValues: []string{"token"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok, param",
|
||||||
|
givenPathParams: []pathParam{
|
||||||
|
{name: "id", value: "123"},
|
||||||
|
},
|
||||||
|
whenLoopups: "param:id",
|
||||||
|
expectValues: []string{"123"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok, query",
|
||||||
|
givenRequest: func() *http.Request {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/?id=999", nil)
|
||||||
|
return req
|
||||||
|
},
|
||||||
|
whenLoopups: "query:id",
|
||||||
|
expectValues: []string{"999"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nok, invalid lookup",
|
||||||
|
whenLoopups: "query",
|
||||||
|
expectCreateError: "extractor source for lookup could not be split into needed parts: query",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
e := echo.New()
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
if tc.givenRequest != nil {
|
||||||
|
req = tc.givenRequest()
|
||||||
|
}
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
if tc.givenPathParams != nil {
|
||||||
|
setPathParams(c, tc.givenPathParams)
|
||||||
|
}
|
||||||
|
|
||||||
|
extractors, err := createExtractors(tc.whenLoopups, "")
|
||||||
|
if tc.expectCreateError != "" {
|
||||||
|
assert.EqualError(t, err, tc.expectCreateError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
for _, e := range extractors {
|
||||||
|
values, eErr := e(c)
|
||||||
|
assert.Equal(t, tc.expectValues, values)
|
||||||
|
if tc.expectError != "" {
|
||||||
|
assert.EqualError(t, eErr, tc.expectError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
assert.NoError(t, eErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValuesFromHeader(t *testing.T) {
|
||||||
|
exampleRequest := func(req *http.Request) {
|
||||||
|
req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==")
|
||||||
|
}
|
||||||
|
|
||||||
|
var testCases = []struct {
|
||||||
|
name string
|
||||||
|
givenRequest func(req *http.Request)
|
||||||
|
whenName string
|
||||||
|
whenValuePrefix string
|
||||||
|
expectValues []string
|
||||||
|
expectError string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "ok, single value",
|
||||||
|
givenRequest: exampleRequest,
|
||||||
|
whenName: echo.HeaderAuthorization,
|
||||||
|
whenValuePrefix: "basic ",
|
||||||
|
expectValues: []string{"dXNlcjpwYXNzd29yZA=="},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok, single value, case insensitive",
|
||||||
|
givenRequest: exampleRequest,
|
||||||
|
whenName: echo.HeaderAuthorization,
|
||||||
|
whenValuePrefix: "Basic ",
|
||||||
|
expectValues: []string{"dXNlcjpwYXNzd29yZA=="},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok, multiple value",
|
||||||
|
givenRequest: func(req *http.Request) {
|
||||||
|
req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==")
|
||||||
|
req.Header.Add(echo.HeaderAuthorization, "basic dGVzdDp0ZXN0")
|
||||||
|
},
|
||||||
|
whenName: echo.HeaderAuthorization,
|
||||||
|
whenValuePrefix: "basic ",
|
||||||
|
expectValues: []string{"dXNlcjpwYXNzd29yZA==", "dGVzdDp0ZXN0"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok, empty prefix",
|
||||||
|
givenRequest: exampleRequest,
|
||||||
|
whenName: echo.HeaderAuthorization,
|
||||||
|
whenValuePrefix: "",
|
||||||
|
expectValues: []string{"basic dXNlcjpwYXNzd29yZA=="},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nok, no matching due different prefix",
|
||||||
|
givenRequest: func(req *http.Request) {
|
||||||
|
req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==")
|
||||||
|
req.Header.Add(echo.HeaderAuthorization, "basic dGVzdDp0ZXN0")
|
||||||
|
},
|
||||||
|
whenName: echo.HeaderAuthorization,
|
||||||
|
whenValuePrefix: "Bearer ",
|
||||||
|
expectError: errHeaderExtractorValueInvalid.Error(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nok, no matching due different prefix",
|
||||||
|
givenRequest: func(req *http.Request) {
|
||||||
|
req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==")
|
||||||
|
req.Header.Add(echo.HeaderAuthorization, "basic dGVzdDp0ZXN0")
|
||||||
|
},
|
||||||
|
whenName: echo.HeaderWWWAuthenticate,
|
||||||
|
whenValuePrefix: "",
|
||||||
|
expectError: errHeaderExtractorValueMissing.Error(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nok, no headers",
|
||||||
|
givenRequest: nil,
|
||||||
|
whenName: echo.HeaderAuthorization,
|
||||||
|
whenValuePrefix: "basic ",
|
||||||
|
expectError: errHeaderExtractorValueMissing.Error(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok, prefix, cut values over extractorLimit",
|
||||||
|
givenRequest: func(req *http.Request) {
|
||||||
|
for i := 1; i <= 25; i++ {
|
||||||
|
req.Header.Add(echo.HeaderAuthorization, fmt.Sprintf("basic %v", i))
|
||||||
|
}
|
||||||
|
},
|
||||||
|
whenName: echo.HeaderAuthorization,
|
||||||
|
whenValuePrefix: "basic ",
|
||||||
|
expectValues: []string{
|
||||||
|
"1", "2", "3", "4", "5", "6", "7", "8", "9", "10",
|
||||||
|
"11", "12", "13", "14", "15", "16", "17", "18", "19", "20",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok, cut values over extractorLimit",
|
||||||
|
givenRequest: func(req *http.Request) {
|
||||||
|
for i := 1; i <= 25; i++ {
|
||||||
|
req.Header.Add(echo.HeaderAuthorization, fmt.Sprintf("%v", i))
|
||||||
|
}
|
||||||
|
},
|
||||||
|
whenName: echo.HeaderAuthorization,
|
||||||
|
whenValuePrefix: "",
|
||||||
|
expectValues: []string{
|
||||||
|
"1", "2", "3", "4", "5", "6", "7", "8", "9", "10",
|
||||||
|
"11", "12", "13", "14", "15", "16", "17", "18", "19", "20",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
e := echo.New()
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
if tc.givenRequest != nil {
|
||||||
|
tc.givenRequest(req)
|
||||||
|
}
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
|
||||||
|
extractor := valuesFromHeader(tc.whenName, tc.whenValuePrefix)
|
||||||
|
|
||||||
|
values, err := extractor(c)
|
||||||
|
assert.Equal(t, tc.expectValues, values)
|
||||||
|
if tc.expectError != "" {
|
||||||
|
assert.EqualError(t, err, tc.expectError)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValuesFromQuery(t *testing.T) {
|
||||||
|
var testCases = []struct {
|
||||||
|
name string
|
||||||
|
givenQueryPart string
|
||||||
|
whenName string
|
||||||
|
expectValues []string
|
||||||
|
expectError string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "ok, single value",
|
||||||
|
givenQueryPart: "?id=123&name=test",
|
||||||
|
whenName: "id",
|
||||||
|
expectValues: []string{"123"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok, multiple value",
|
||||||
|
givenQueryPart: "?id=123&id=456&name=test",
|
||||||
|
whenName: "id",
|
||||||
|
expectValues: []string{"123", "456"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nok, missing value",
|
||||||
|
givenQueryPart: "?id=123&name=test",
|
||||||
|
whenName: "nope",
|
||||||
|
expectError: errQueryExtractorValueMissing.Error(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok, cut values over extractorLimit",
|
||||||
|
givenQueryPart: "?name=test" +
|
||||||
|
"&id=1&id=2&id=3&id=4&id=5&id=6&id=7&id=8&id=9&id=10" +
|
||||||
|
"&id=11&id=12&id=13&id=14&id=15&id=16&id=17&id=18&id=19&id=20" +
|
||||||
|
"&id=21&id=22&id=23&id=24&id=25",
|
||||||
|
whenName: "id",
|
||||||
|
expectValues: []string{
|
||||||
|
"1", "2", "3", "4", "5", "6", "7", "8", "9", "10",
|
||||||
|
"11", "12", "13", "14", "15", "16", "17", "18", "19", "20",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
e := echo.New()
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/"+tc.givenQueryPart, nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
|
||||||
|
extractor := valuesFromQuery(tc.whenName)
|
||||||
|
|
||||||
|
values, err := extractor(c)
|
||||||
|
assert.Equal(t, tc.expectValues, values)
|
||||||
|
if tc.expectError != "" {
|
||||||
|
assert.EqualError(t, err, tc.expectError)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValuesFromParam(t *testing.T) {
|
||||||
|
examplePathParams := []pathParam{
|
||||||
|
{name: "id", value: "123"},
|
||||||
|
{name: "gid", value: "456"},
|
||||||
|
{name: "gid", value: "789"},
|
||||||
|
}
|
||||||
|
examplePathParams20 := make([]pathParam, 0)
|
||||||
|
for i := 1; i < 25; i++ {
|
||||||
|
examplePathParams20 = append(examplePathParams20, pathParam{name: "id", value: fmt.Sprintf("%v", i)})
|
||||||
|
}
|
||||||
|
|
||||||
|
var testCases = []struct {
|
||||||
|
name string
|
||||||
|
givenPathParams []pathParam
|
||||||
|
whenName string
|
||||||
|
expectValues []string
|
||||||
|
expectError string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "ok, single value",
|
||||||
|
givenPathParams: examplePathParams,
|
||||||
|
whenName: "id",
|
||||||
|
expectValues: []string{"123"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok, multiple value",
|
||||||
|
givenPathParams: examplePathParams,
|
||||||
|
whenName: "gid",
|
||||||
|
expectValues: []string{"456", "789"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nok, no values",
|
||||||
|
givenPathParams: nil,
|
||||||
|
whenName: "nope",
|
||||||
|
expectValues: nil,
|
||||||
|
expectError: errParamExtractorValueMissing.Error(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nok, no matching value",
|
||||||
|
givenPathParams: examplePathParams,
|
||||||
|
whenName: "nope",
|
||||||
|
expectValues: nil,
|
||||||
|
expectError: errParamExtractorValueMissing.Error(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok, cut values over extractorLimit",
|
||||||
|
givenPathParams: examplePathParams20,
|
||||||
|
whenName: "id",
|
||||||
|
expectValues: []string{
|
||||||
|
"1", "2", "3", "4", "5", "6", "7", "8", "9", "10",
|
||||||
|
"11", "12", "13", "14", "15", "16", "17", "18", "19", "20",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
e := echo.New()
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
if tc.givenPathParams != nil {
|
||||||
|
setPathParams(c, tc.givenPathParams)
|
||||||
|
}
|
||||||
|
|
||||||
|
extractor := valuesFromParam(tc.whenName)
|
||||||
|
|
||||||
|
values, err := extractor(c)
|
||||||
|
assert.Equal(t, tc.expectValues, values)
|
||||||
|
if tc.expectError != "" {
|
||||||
|
assert.EqualError(t, err, tc.expectError)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValuesFromCookie(t *testing.T) {
|
||||||
|
exampleRequest := func(req *http.Request) {
|
||||||
|
req.Header.Set(echo.HeaderCookie, "_csrf=token")
|
||||||
|
}
|
||||||
|
|
||||||
|
var testCases = []struct {
|
||||||
|
name string
|
||||||
|
givenRequest func(req *http.Request)
|
||||||
|
whenName string
|
||||||
|
expectValues []string
|
||||||
|
expectError string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "ok, single value",
|
||||||
|
givenRequest: exampleRequest,
|
||||||
|
whenName: "_csrf",
|
||||||
|
expectValues: []string{"token"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok, multiple value",
|
||||||
|
givenRequest: func(req *http.Request) {
|
||||||
|
req.Header.Add(echo.HeaderCookie, "_csrf=token")
|
||||||
|
req.Header.Add(echo.HeaderCookie, "_csrf=token2")
|
||||||
|
},
|
||||||
|
whenName: "_csrf",
|
||||||
|
expectValues: []string{"token", "token2"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nok, no matching cookie",
|
||||||
|
givenRequest: exampleRequest,
|
||||||
|
whenName: "xxx",
|
||||||
|
expectValues: nil,
|
||||||
|
expectError: errCookieExtractorValueMissing.Error(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nok, no cookies at all",
|
||||||
|
givenRequest: nil,
|
||||||
|
whenName: "xxx",
|
||||||
|
expectValues: nil,
|
||||||
|
expectError: errCookieExtractorValueMissing.Error(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok, cut values over extractorLimit",
|
||||||
|
givenRequest: func(req *http.Request) {
|
||||||
|
for i := 1; i < 25; i++ {
|
||||||
|
req.Header.Add(echo.HeaderCookie, fmt.Sprintf("_csrf=%v", i))
|
||||||
|
}
|
||||||
|
},
|
||||||
|
whenName: "_csrf",
|
||||||
|
expectValues: []string{
|
||||||
|
"1", "2", "3", "4", "5", "6", "7", "8", "9", "10",
|
||||||
|
"11", "12", "13", "14", "15", "16", "17", "18", "19", "20",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
e := echo.New()
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
if tc.givenRequest != nil {
|
||||||
|
tc.givenRequest(req)
|
||||||
|
}
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
|
||||||
|
extractor := valuesFromCookie(tc.whenName)
|
||||||
|
|
||||||
|
values, err := extractor(c)
|
||||||
|
assert.Equal(t, tc.expectValues, values)
|
||||||
|
if tc.expectError != "" {
|
||||||
|
assert.EqualError(t, err, tc.expectError)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValuesFromForm(t *testing.T) {
|
||||||
|
examplePostFormRequest := func(mod func(v *url.Values)) *http.Request {
|
||||||
|
f := make(url.Values)
|
||||||
|
f.Set("name", "Jon Snow")
|
||||||
|
f.Set("emails[]", "jon@labstack.com")
|
||||||
|
if mod != nil {
|
||||||
|
mod(&f)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode()))
|
||||||
|
req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm)
|
||||||
|
|
||||||
|
return req
|
||||||
|
}
|
||||||
|
exampleGetFormRequest := func(mod func(v *url.Values)) *http.Request {
|
||||||
|
f := make(url.Values)
|
||||||
|
f.Set("name", "Jon Snow")
|
||||||
|
f.Set("emails[]", "jon@labstack.com")
|
||||||
|
if mod != nil {
|
||||||
|
mod(&f)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/?"+f.Encode(), nil)
|
||||||
|
return req
|
||||||
|
}
|
||||||
|
|
||||||
|
var testCases = []struct {
|
||||||
|
name string
|
||||||
|
givenRequest *http.Request
|
||||||
|
whenName string
|
||||||
|
expectValues []string
|
||||||
|
expectError string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "ok, POST form, single value",
|
||||||
|
givenRequest: examplePostFormRequest(nil),
|
||||||
|
whenName: "emails[]",
|
||||||
|
expectValues: []string{"jon@labstack.com"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok, POST form, multiple value",
|
||||||
|
givenRequest: examplePostFormRequest(func(v *url.Values) {
|
||||||
|
v.Add("emails[]", "snow@labstack.com")
|
||||||
|
}),
|
||||||
|
whenName: "emails[]",
|
||||||
|
expectValues: []string{"jon@labstack.com", "snow@labstack.com"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok, GET form, single value",
|
||||||
|
givenRequest: exampleGetFormRequest(nil),
|
||||||
|
whenName: "emails[]",
|
||||||
|
expectValues: []string{"jon@labstack.com"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok, GET form, multiple value",
|
||||||
|
givenRequest: examplePostFormRequest(func(v *url.Values) {
|
||||||
|
v.Add("emails[]", "snow@labstack.com")
|
||||||
|
}),
|
||||||
|
whenName: "emails[]",
|
||||||
|
expectValues: []string{"jon@labstack.com", "snow@labstack.com"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nok, POST form, value missing",
|
||||||
|
givenRequest: examplePostFormRequest(nil),
|
||||||
|
whenName: "nope",
|
||||||
|
expectError: errFormExtractorValueMissing.Error(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nok, POST form, form parsing error",
|
||||||
|
givenRequest: func() *http.Request {
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/", nil)
|
||||||
|
req.Body = nil
|
||||||
|
return req
|
||||||
|
}(),
|
||||||
|
whenName: "name",
|
||||||
|
expectError: "valuesFromForm parse form failed: missing form body",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok, cut values over extractorLimit",
|
||||||
|
givenRequest: examplePostFormRequest(func(v *url.Values) {
|
||||||
|
for i := 1; i < 25; i++ {
|
||||||
|
v.Add("id[]", fmt.Sprintf("%v", i))
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
whenName: "id[]",
|
||||||
|
expectValues: []string{
|
||||||
|
"1", "2", "3", "4", "5", "6", "7", "8", "9", "10",
|
||||||
|
"11", "12", "13", "14", "15", "16", "17", "18", "19", "20",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
e := echo.New()
|
||||||
|
|
||||||
|
req := tc.givenRequest
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
|
||||||
|
extractor := valuesFromForm(tc.whenName)
|
||||||
|
|
||||||
|
values, err := extractor(c)
|
||||||
|
assert.Equal(t, tc.expectValues, values)
|
||||||
|
if tc.expectError != "" {
|
||||||
|
assert.EqualError(t, err, tc.expectError)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
@ -1,3 +1,4 @@
|
|||||||
|
//go:build go1.15
|
||||||
// +build go1.15
|
// +build go1.15
|
||||||
|
|
||||||
package middleware
|
package middleware
|
||||||
@ -5,12 +6,10 @@ package middleware
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
|
||||||
"reflect"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt"
|
"github.com/golang-jwt/jwt"
|
||||||
"github.com/labstack/echo/v4"
|
"github.com/labstack/echo/v4"
|
||||||
|
"net/http"
|
||||||
|
"reflect"
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
@ -22,7 +21,8 @@ type (
|
|||||||
// BeforeFunc defines a function which is executed just before the middleware.
|
// BeforeFunc defines a function which is executed just before the middleware.
|
||||||
BeforeFunc BeforeFunc
|
BeforeFunc BeforeFunc
|
||||||
|
|
||||||
// SuccessHandler defines a function which is executed for a valid token.
|
// SuccessHandler defines a function which is executed for a valid token before middleware chain continues with next
|
||||||
|
// middleware or handler.
|
||||||
SuccessHandler JWTSuccessHandler
|
SuccessHandler JWTSuccessHandler
|
||||||
|
|
||||||
// ErrorHandler defines a function which is executed for an invalid token.
|
// ErrorHandler defines a function which is executed for an invalid token.
|
||||||
@ -32,6 +32,13 @@ type (
|
|||||||
// ErrorHandlerWithContext is almost identical to ErrorHandler, but it's passed the current context.
|
// ErrorHandlerWithContext is almost identical to ErrorHandler, but it's passed the current context.
|
||||||
ErrorHandlerWithContext JWTErrorHandlerWithContext
|
ErrorHandlerWithContext JWTErrorHandlerWithContext
|
||||||
|
|
||||||
|
// ContinueOnIgnoredError allows the next middleware/handler to be called when ErrorHandlerWithContext decides to
|
||||||
|
// ignore the error (by returning `nil`).
|
||||||
|
// This is useful when parts of your site/api allow public access and some authorized routes provide extra functionality.
|
||||||
|
// In that case you can use ErrorHandlerWithContext to set a default public JWT token value in the request context
|
||||||
|
// and continue. Some logic down the remaining execution chain needs to check that (public) token value then.
|
||||||
|
ContinueOnIgnoredError bool
|
||||||
|
|
||||||
// Signing key to validate token.
|
// Signing key to validate token.
|
||||||
// This is one of the three options to provide a token validation key.
|
// This is one of the three options to provide a token validation key.
|
||||||
// The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey.
|
// The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey.
|
||||||
@ -61,12 +68,17 @@ type (
|
|||||||
// to extract token from the request.
|
// to extract token from the request.
|
||||||
// Optional. Default value "header:Authorization".
|
// Optional. Default value "header:Authorization".
|
||||||
// Possible values:
|
// Possible values:
|
||||||
// - "header:<name>"
|
// - "header:<name>" or "header:<name>:<cut-prefix>"
|
||||||
|
// `<cut-prefix>` is argument value to cut/trim prefix of the extracted value. This is useful if header
|
||||||
|
// value has static prefix like `Authorization: <auth-scheme> <authorisation-parameters>` where part that we
|
||||||
|
// want to cut is `<auth-scheme> ` note the space at the end.
|
||||||
|
// In case of JWT tokens `Authorization: Bearer <token>` prefix we cut is `Bearer `.
|
||||||
|
// If prefix is left empty the whole value is returned.
|
||||||
// - "query:<name>"
|
// - "query:<name>"
|
||||||
// - "param:<name>"
|
// - "param:<name>"
|
||||||
// - "cookie:<name>"
|
// - "cookie:<name>"
|
||||||
// - "form:<name>"
|
// - "form:<name>"
|
||||||
// Multiply sources example:
|
// Multiple sources example:
|
||||||
// - "header:Authorization,cookie:myowncookie"
|
// - "header:Authorization,cookie:myowncookie"
|
||||||
TokenLookup string
|
TokenLookup string
|
||||||
|
|
||||||
@ -74,7 +86,7 @@ type (
|
|||||||
// This is one of the two options to provide a token extractor.
|
// This is one of the two options to provide a token extractor.
|
||||||
// The order of precedence is user-defined TokenLookupFuncs, and TokenLookup.
|
// The order of precedence is user-defined TokenLookupFuncs, and TokenLookup.
|
||||||
// You can also provide both if you want.
|
// You can also provide both if you want.
|
||||||
TokenLookupFuncs []TokenLookupFunc
|
TokenLookupFuncs []ValuesExtractor
|
||||||
|
|
||||||
// AuthScheme to be used in the Authorization header.
|
// AuthScheme to be used in the Authorization header.
|
||||||
// Optional. Default value "Bearer".
|
// Optional. Default value "Bearer".
|
||||||
@ -100,16 +112,13 @@ type (
|
|||||||
}
|
}
|
||||||
|
|
||||||
// JWTSuccessHandler defines a function which is executed for a valid token.
|
// JWTSuccessHandler defines a function which is executed for a valid token.
|
||||||
JWTSuccessHandler func(echo.Context)
|
JWTSuccessHandler func(c echo.Context)
|
||||||
|
|
||||||
// JWTErrorHandler defines a function which is executed for an invalid token.
|
// JWTErrorHandler defines a function which is executed for an invalid token.
|
||||||
JWTErrorHandler func(error) error
|
JWTErrorHandler func(err error) error
|
||||||
|
|
||||||
// JWTErrorHandlerWithContext is almost identical to JWTErrorHandler, but it's passed the current context.
|
// JWTErrorHandlerWithContext is almost identical to JWTErrorHandler, but it's passed the current context.
|
||||||
JWTErrorHandlerWithContext func(error, echo.Context) error
|
JWTErrorHandlerWithContext func(err error, c echo.Context) error
|
||||||
|
|
||||||
// TokenLookupFunc defines a function for extracting JWT token from the given context.
|
|
||||||
TokenLookupFunc func(echo.Context) (string, error)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Algorithms
|
// Algorithms
|
||||||
@ -183,25 +192,12 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
|
|||||||
config.ParseTokenFunc = config.defaultParseToken
|
config.ParseTokenFunc = config.defaultParseToken
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize
|
extractors, err := createExtractors(config.TokenLookup, config.AuthScheme)
|
||||||
// Split sources
|
if err != nil {
|
||||||
sources := strings.Split(config.TokenLookup, ",")
|
panic(err)
|
||||||
var extractors = config.TokenLookupFuncs
|
|
||||||
for _, source := range sources {
|
|
||||||
parts := strings.Split(source, ":")
|
|
||||||
|
|
||||||
switch parts[0] {
|
|
||||||
case "query":
|
|
||||||
extractors = append(extractors, jwtFromQuery(parts[1]))
|
|
||||||
case "param":
|
|
||||||
extractors = append(extractors, jwtFromParam(parts[1]))
|
|
||||||
case "cookie":
|
|
||||||
extractors = append(extractors, jwtFromCookie(parts[1]))
|
|
||||||
case "form":
|
|
||||||
extractors = append(extractors, jwtFromForm(parts[1]))
|
|
||||||
case "header":
|
|
||||||
extractors = append(extractors, jwtFromHeader(parts[1], config.AuthScheme))
|
|
||||||
}
|
}
|
||||||
|
if len(config.TokenLookupFuncs) > 0 {
|
||||||
|
extractors = append(config.TokenLookupFuncs, extractors...)
|
||||||
}
|
}
|
||||||
|
|
||||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||||
@ -213,30 +209,21 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
|
|||||||
if config.BeforeFunc != nil {
|
if config.BeforeFunc != nil {
|
||||||
config.BeforeFunc(c)
|
config.BeforeFunc(c)
|
||||||
}
|
}
|
||||||
var auth string
|
|
||||||
var err error
|
var lastExtractorErr error
|
||||||
|
var lastTokenErr error
|
||||||
for _, extractor := range extractors {
|
for _, extractor := range extractors {
|
||||||
// Extract token from extractor, if it's not fail break the loop and
|
auths, err := extractor(c)
|
||||||
// set auth
|
|
||||||
auth, err = extractor(c)
|
|
||||||
if err == nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// If none of extractor has a token, handle error
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if config.ErrorHandler != nil {
|
lastExtractorErr = ErrJWTMissing // backwards compatibility: all extraction errors are same (unlike KeyAuth)
|
||||||
return config.ErrorHandler(err)
|
continue
|
||||||
}
|
}
|
||||||
|
for _, auth := range auths {
|
||||||
if config.ErrorHandlerWithContext != nil {
|
|
||||||
return config.ErrorHandlerWithContext(err, c)
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
token, err := config.ParseTokenFunc(auth, c)
|
token, err := config.ParseTokenFunc(auth, c)
|
||||||
if err == nil {
|
if err != nil {
|
||||||
|
lastTokenErr = err
|
||||||
|
continue
|
||||||
|
}
|
||||||
// Store user information from token into context.
|
// Store user information from token into context.
|
||||||
c.Set(config.ContextKey, token)
|
c.Set(config.ContextKey, token)
|
||||||
if config.SuccessHandler != nil {
|
if config.SuccessHandler != nil {
|
||||||
@ -244,18 +231,33 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
|
|||||||
}
|
}
|
||||||
return next(c)
|
return next(c)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
// we are here only when we did not successfully extract or parse any of the tokens
|
||||||
|
err := lastTokenErr
|
||||||
|
if err == nil { // prioritize token errors over extracting errors
|
||||||
|
err = lastExtractorErr
|
||||||
|
}
|
||||||
if config.ErrorHandler != nil {
|
if config.ErrorHandler != nil {
|
||||||
return config.ErrorHandler(err)
|
return config.ErrorHandler(err)
|
||||||
}
|
}
|
||||||
if config.ErrorHandlerWithContext != nil {
|
if config.ErrorHandlerWithContext != nil {
|
||||||
return config.ErrorHandlerWithContext(err, c)
|
tmpErr := config.ErrorHandlerWithContext(err, c)
|
||||||
|
if config.ContinueOnIgnoredError && tmpErr == nil {
|
||||||
|
return next(c)
|
||||||
}
|
}
|
||||||
|
return tmpErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// backwards compatible errors codes
|
||||||
|
if lastTokenErr != nil {
|
||||||
return &echo.HTTPError{
|
return &echo.HTTPError{
|
||||||
Code: ErrJWTInvalid.Code,
|
Code: ErrJWTInvalid.Code,
|
||||||
Message: ErrJWTInvalid.Message,
|
Message: ErrJWTInvalid.Message,
|
||||||
Internal: err,
|
Internal: err,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return err // this is lastExtractorErr value
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -296,59 +298,3 @@ func (config *JWTConfig) defaultKeyFunc(t *jwt.Token) (interface{}, error) {
|
|||||||
|
|
||||||
return config.SigningKey, nil
|
return config.SigningKey, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// jwtFromHeader returns a `TokenLookupFunc` that extracts token from the request header.
|
|
||||||
func jwtFromHeader(header string, authScheme string) TokenLookupFunc {
|
|
||||||
return func(c echo.Context) (string, error) {
|
|
||||||
auth := c.Request().Header.Get(header)
|
|
||||||
l := len(authScheme)
|
|
||||||
if len(auth) > l+1 && strings.EqualFold(auth[:l], authScheme) {
|
|
||||||
return auth[l+1:], nil
|
|
||||||
}
|
|
||||||
return "", ErrJWTMissing
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// jwtFromQuery returns a `TokenLookupFunc` that extracts token from the query string.
|
|
||||||
func jwtFromQuery(param string) TokenLookupFunc {
|
|
||||||
return func(c echo.Context) (string, error) {
|
|
||||||
token := c.QueryParam(param)
|
|
||||||
if token == "" {
|
|
||||||
return "", ErrJWTMissing
|
|
||||||
}
|
|
||||||
return token, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// jwtFromParam returns a `TokenLookupFunc` that extracts token from the url param string.
|
|
||||||
func jwtFromParam(param string) TokenLookupFunc {
|
|
||||||
return func(c echo.Context) (string, error) {
|
|
||||||
token := c.Param(param)
|
|
||||||
if token == "" {
|
|
||||||
return "", ErrJWTMissing
|
|
||||||
}
|
|
||||||
return token, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// jwtFromCookie returns a `TokenLookupFunc` that extracts token from the named cookie.
|
|
||||||
func jwtFromCookie(name string) TokenLookupFunc {
|
|
||||||
return func(c echo.Context) (string, error) {
|
|
||||||
cookie, err := c.Cookie(name)
|
|
||||||
if err != nil {
|
|
||||||
return "", ErrJWTMissing
|
|
||||||
}
|
|
||||||
return cookie.Value, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// jwtFromForm returns a `TokenLookupFunc` that extracts token from the form field.
|
|
||||||
func jwtFromForm(name string) TokenLookupFunc {
|
|
||||||
return func(c echo.Context) (string, error) {
|
|
||||||
field := c.FormValue(name)
|
|
||||||
if field == "" {
|
|
||||||
return "", ErrJWTMissing
|
|
||||||
}
|
|
||||||
return field, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
//go:build go1.15
|
||||||
// +build go1.15
|
// +build go1.15
|
||||||
|
|
||||||
package middleware
|
package middleware
|
||||||
@ -28,6 +29,26 @@ type jwtCustomClaims struct {
|
|||||||
jwtCustomInfo
|
jwtCustomInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestJWT(t *testing.T) {
|
||||||
|
e := echo.New()
|
||||||
|
|
||||||
|
e.GET("/", func(c echo.Context) error {
|
||||||
|
token := c.Get("user").(*jwt.Token)
|
||||||
|
return c.JSON(http.StatusOK, token.Claims)
|
||||||
|
})
|
||||||
|
|
||||||
|
e.Use(JWT([]byte("secret")))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.Header.Set(echo.HeaderAuthorization, "bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ")
|
||||||
|
res := httptest.NewRecorder()
|
||||||
|
|
||||||
|
e.ServeHTTP(res, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, res.Code)
|
||||||
|
assert.Equal(t, `{"admin":true,"name":"John Doe","sub":"1234567890"}`+"\n", res.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
func TestJWTRace(t *testing.T) {
|
func TestJWTRace(t *testing.T) {
|
||||||
e := echo.New()
|
e := echo.New()
|
||||||
handler := func(c echo.Context) error {
|
handler := func(c echo.Context) error {
|
||||||
@ -64,8 +85,7 @@ func TestJWTRace(t *testing.T) {
|
|||||||
assert.Equal(t, claims.Admin, true)
|
assert.Equal(t, claims.Admin, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestJWT(t *testing.T) {
|
func TestJWTConfig(t *testing.T) {
|
||||||
e := echo.New()
|
|
||||||
handler := func(c echo.Context) error {
|
handler := func(c echo.Context) error {
|
||||||
return c.String(http.StatusOK, "test")
|
return c.String(http.StatusOK, "test")
|
||||||
}
|
}
|
||||||
@ -74,7 +94,8 @@ func TestJWT(t *testing.T) {
|
|||||||
invalidKey := []byte("invalid-key")
|
invalidKey := []byte("invalid-key")
|
||||||
validAuth := DefaultJWTConfig.AuthScheme + " " + token
|
validAuth := DefaultJWTConfig.AuthScheme + " " + token
|
||||||
|
|
||||||
for _, tc := range []struct {
|
testCases := []struct {
|
||||||
|
name string
|
||||||
expPanic bool
|
expPanic bool
|
||||||
expErrCode int // 0 for Success
|
expErrCode int // 0 for Success
|
||||||
config JWTConfig
|
config JWTConfig
|
||||||
@ -82,166 +103,166 @@ func TestJWT(t *testing.T) {
|
|||||||
hdrAuth string
|
hdrAuth string
|
||||||
hdrCookie string // test.Request doesn't provide SetCookie(); use name=val
|
hdrCookie string // test.Request doesn't provide SetCookie(); use name=val
|
||||||
formValues map[string]string
|
formValues map[string]string
|
||||||
info string
|
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
|
name: "No signing key provided",
|
||||||
expPanic: true,
|
expPanic: true,
|
||||||
info: "No signing key provided",
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
name: "Unexpected signing method",
|
||||||
expErrCode: http.StatusBadRequest,
|
expErrCode: http.StatusBadRequest,
|
||||||
config: JWTConfig{
|
config: JWTConfig{
|
||||||
SigningKey: validKey,
|
SigningKey: validKey,
|
||||||
SigningMethod: "RS256",
|
SigningMethod: "RS256",
|
||||||
},
|
},
|
||||||
info: "Unexpected signing method",
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
name: "Invalid key",
|
||||||
expErrCode: http.StatusUnauthorized,
|
expErrCode: http.StatusUnauthorized,
|
||||||
hdrAuth: validAuth,
|
hdrAuth: validAuth,
|
||||||
config: JWTConfig{SigningKey: invalidKey},
|
config: JWTConfig{SigningKey: invalidKey},
|
||||||
info: "Invalid key",
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
name: "Valid JWT",
|
||||||
hdrAuth: validAuth,
|
hdrAuth: validAuth,
|
||||||
config: JWTConfig{SigningKey: validKey},
|
config: JWTConfig{SigningKey: validKey},
|
||||||
info: "Valid JWT",
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
name: "Valid JWT with custom AuthScheme",
|
||||||
hdrAuth: "Token" + " " + token,
|
hdrAuth: "Token" + " " + token,
|
||||||
config: JWTConfig{AuthScheme: "Token", SigningKey: validKey},
|
config: JWTConfig{AuthScheme: "Token", SigningKey: validKey},
|
||||||
info: "Valid JWT with custom AuthScheme",
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
name: "Valid JWT with custom claims",
|
||||||
hdrAuth: validAuth,
|
hdrAuth: validAuth,
|
||||||
config: JWTConfig{
|
config: JWTConfig{
|
||||||
Claims: &jwtCustomClaims{},
|
Claims: &jwtCustomClaims{},
|
||||||
SigningKey: []byte("secret"),
|
SigningKey: []byte("secret"),
|
||||||
},
|
},
|
||||||
info: "Valid JWT with custom claims",
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
name: "Invalid Authorization header",
|
||||||
hdrAuth: "invalid-auth",
|
hdrAuth: "invalid-auth",
|
||||||
expErrCode: http.StatusBadRequest,
|
expErrCode: http.StatusBadRequest,
|
||||||
config: JWTConfig{SigningKey: validKey},
|
config: JWTConfig{SigningKey: validKey},
|
||||||
info: "Invalid Authorization header",
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
name: "Empty header auth field",
|
||||||
config: JWTConfig{SigningKey: validKey},
|
config: JWTConfig{SigningKey: validKey},
|
||||||
expErrCode: http.StatusBadRequest,
|
expErrCode: http.StatusBadRequest,
|
||||||
info: "Empty header auth field",
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
name: "Valid query method",
|
||||||
config: JWTConfig{
|
config: JWTConfig{
|
||||||
SigningKey: validKey,
|
SigningKey: validKey,
|
||||||
TokenLookup: "query:jwt",
|
TokenLookup: "query:jwt",
|
||||||
},
|
},
|
||||||
reqURL: "/?a=b&jwt=" + token,
|
reqURL: "/?a=b&jwt=" + token,
|
||||||
info: "Valid query method",
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
name: "Invalid query param name",
|
||||||
config: JWTConfig{
|
config: JWTConfig{
|
||||||
SigningKey: validKey,
|
SigningKey: validKey,
|
||||||
TokenLookup: "query:jwt",
|
TokenLookup: "query:jwt",
|
||||||
},
|
},
|
||||||
reqURL: "/?a=b&jwtxyz=" + token,
|
reqURL: "/?a=b&jwtxyz=" + token,
|
||||||
expErrCode: http.StatusBadRequest,
|
expErrCode: http.StatusBadRequest,
|
||||||
info: "Invalid query param name",
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
name: "Invalid query param value",
|
||||||
config: JWTConfig{
|
config: JWTConfig{
|
||||||
SigningKey: validKey,
|
SigningKey: validKey,
|
||||||
TokenLookup: "query:jwt",
|
TokenLookup: "query:jwt",
|
||||||
},
|
},
|
||||||
reqURL: "/?a=b&jwt=invalid-token",
|
reqURL: "/?a=b&jwt=invalid-token",
|
||||||
expErrCode: http.StatusUnauthorized,
|
expErrCode: http.StatusUnauthorized,
|
||||||
info: "Invalid query param value",
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
name: "Empty query",
|
||||||
config: JWTConfig{
|
config: JWTConfig{
|
||||||
SigningKey: validKey,
|
SigningKey: validKey,
|
||||||
TokenLookup: "query:jwt",
|
TokenLookup: "query:jwt",
|
||||||
},
|
},
|
||||||
reqURL: "/?a=b",
|
reqURL: "/?a=b",
|
||||||
expErrCode: http.StatusBadRequest,
|
expErrCode: http.StatusBadRequest,
|
||||||
info: "Empty query",
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
name: "Valid param method",
|
||||||
config: JWTConfig{
|
config: JWTConfig{
|
||||||
SigningKey: validKey,
|
SigningKey: validKey,
|
||||||
TokenLookup: "param:jwt",
|
TokenLookup: "param:jwt",
|
||||||
},
|
},
|
||||||
reqURL: "/" + token,
|
reqURL: "/" + token,
|
||||||
info: "Valid param method",
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
name: "Valid cookie method",
|
||||||
config: JWTConfig{
|
config: JWTConfig{
|
||||||
SigningKey: validKey,
|
SigningKey: validKey,
|
||||||
TokenLookup: "cookie:jwt",
|
TokenLookup: "cookie:jwt",
|
||||||
},
|
},
|
||||||
hdrCookie: "jwt=" + token,
|
hdrCookie: "jwt=" + token,
|
||||||
info: "Valid cookie method",
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
name: "Multiple jwt lookuop",
|
||||||
config: JWTConfig{
|
config: JWTConfig{
|
||||||
SigningKey: validKey,
|
SigningKey: validKey,
|
||||||
TokenLookup: "query:jwt,cookie:jwt",
|
TokenLookup: "query:jwt,cookie:jwt",
|
||||||
},
|
},
|
||||||
hdrCookie: "jwt=" + token,
|
hdrCookie: "jwt=" + token,
|
||||||
info: "Multiple jwt lookuop",
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
name: "Invalid token with cookie method",
|
||||||
config: JWTConfig{
|
config: JWTConfig{
|
||||||
SigningKey: validKey,
|
SigningKey: validKey,
|
||||||
TokenLookup: "cookie:jwt",
|
TokenLookup: "cookie:jwt",
|
||||||
},
|
},
|
||||||
expErrCode: http.StatusUnauthorized,
|
expErrCode: http.StatusUnauthorized,
|
||||||
hdrCookie: "jwt=invalid",
|
hdrCookie: "jwt=invalid",
|
||||||
info: "Invalid token with cookie method",
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
name: "Empty cookie",
|
||||||
config: JWTConfig{
|
config: JWTConfig{
|
||||||
SigningKey: validKey,
|
SigningKey: validKey,
|
||||||
TokenLookup: "cookie:jwt",
|
TokenLookup: "cookie:jwt",
|
||||||
},
|
},
|
||||||
expErrCode: http.StatusBadRequest,
|
expErrCode: http.StatusBadRequest,
|
||||||
info: "Empty cookie",
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
name: "Valid form method",
|
||||||
config: JWTConfig{
|
config: JWTConfig{
|
||||||
SigningKey: validKey,
|
SigningKey: validKey,
|
||||||
TokenLookup: "form:jwt",
|
TokenLookup: "form:jwt",
|
||||||
},
|
},
|
||||||
formValues: map[string]string{"jwt": token},
|
formValues: map[string]string{"jwt": token},
|
||||||
info: "Valid form method",
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
name: "Invalid token with form method",
|
||||||
config: JWTConfig{
|
config: JWTConfig{
|
||||||
SigningKey: validKey,
|
SigningKey: validKey,
|
||||||
TokenLookup: "form:jwt",
|
TokenLookup: "form:jwt",
|
||||||
},
|
},
|
||||||
expErrCode: http.StatusUnauthorized,
|
expErrCode: http.StatusUnauthorized,
|
||||||
formValues: map[string]string{"jwt": "invalid"},
|
formValues: map[string]string{"jwt": "invalid"},
|
||||||
info: "Invalid token with form method",
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
name: "Empty form field",
|
||||||
config: JWTConfig{
|
config: JWTConfig{
|
||||||
SigningKey: validKey,
|
SigningKey: validKey,
|
||||||
TokenLookup: "form:jwt",
|
TokenLookup: "form:jwt",
|
||||||
},
|
},
|
||||||
expErrCode: http.StatusBadRequest,
|
expErrCode: http.StatusBadRequest,
|
||||||
info: "Empty form field",
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
name: "Valid JWT with a valid key using a user-defined KeyFunc",
|
||||||
hdrAuth: validAuth,
|
hdrAuth: validAuth,
|
||||||
config: JWTConfig{
|
config: JWTConfig{
|
||||||
KeyFunc: func(*jwt.Token) (interface{}, error) {
|
KeyFunc: func(*jwt.Token) (interface{}, error) {
|
||||||
return validKey, nil
|
return validKey, nil
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
info: "Valid JWT with a valid key using a user-defined KeyFunc",
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
name: "Valid JWT with an invalid key using a user-defined KeyFunc",
|
||||||
hdrAuth: validAuth,
|
hdrAuth: validAuth,
|
||||||
config: JWTConfig{
|
config: JWTConfig{
|
||||||
KeyFunc: func(*jwt.Token) (interface{}, error) {
|
KeyFunc: func(*jwt.Token) (interface{}, error) {
|
||||||
@ -249,9 +270,9 @@ func TestJWT(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
expErrCode: http.StatusUnauthorized,
|
expErrCode: http.StatusUnauthorized,
|
||||||
info: "Valid JWT with an invalid key using a user-defined KeyFunc",
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
name: "Token verification does not pass using a user-defined KeyFunc",
|
||||||
hdrAuth: validAuth,
|
hdrAuth: validAuth,
|
||||||
config: JWTConfig{
|
config: JWTConfig{
|
||||||
KeyFunc: func(*jwt.Token) (interface{}, error) {
|
KeyFunc: func(*jwt.Token) (interface{}, error) {
|
||||||
@ -259,14 +280,16 @@ func TestJWT(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
expErrCode: http.StatusUnauthorized,
|
expErrCode: http.StatusUnauthorized,
|
||||||
info: "Token verification does not pass using a user-defined KeyFunc",
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
name: "Valid JWT with lower case AuthScheme",
|
||||||
hdrAuth: strings.ToLower(DefaultJWTConfig.AuthScheme) + " " + token,
|
hdrAuth: strings.ToLower(DefaultJWTConfig.AuthScheme) + " " + token,
|
||||||
config: JWTConfig{SigningKey: validKey},
|
config: JWTConfig{SigningKey: validKey},
|
||||||
info: "Valid JWT with lower case AuthScheme",
|
|
||||||
},
|
},
|
||||||
} {
|
}
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
e := echo.New()
|
||||||
if tc.reqURL == "" {
|
if tc.reqURL == "" {
|
||||||
tc.reqURL = "/"
|
tc.reqURL = "/"
|
||||||
}
|
}
|
||||||
@ -296,30 +319,31 @@ func TestJWT(t *testing.T) {
|
|||||||
if tc.expPanic {
|
if tc.expPanic {
|
||||||
assert.Panics(t, func() {
|
assert.Panics(t, func() {
|
||||||
JWTWithConfig(tc.config)
|
JWTWithConfig(tc.config)
|
||||||
}, tc.info)
|
}, tc.name)
|
||||||
continue
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if tc.expErrCode != 0 {
|
if tc.expErrCode != 0 {
|
||||||
h := JWTWithConfig(tc.config)(handler)
|
h := JWTWithConfig(tc.config)(handler)
|
||||||
he := h(c).(*echo.HTTPError)
|
he := h(c).(*echo.HTTPError)
|
||||||
assert.Equal(t, tc.expErrCode, he.Code, tc.info)
|
assert.Equal(t, tc.expErrCode, he.Code, tc.name)
|
||||||
continue
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
h := JWTWithConfig(tc.config)(handler)
|
h := JWTWithConfig(tc.config)(handler)
|
||||||
if assert.NoError(t, h(c), tc.info) {
|
if assert.NoError(t, h(c), tc.name) {
|
||||||
user := c.Get("user").(*jwt.Token)
|
user := c.Get("user").(*jwt.Token)
|
||||||
switch claims := user.Claims.(type) {
|
switch claims := user.Claims.(type) {
|
||||||
case jwt.MapClaims:
|
case jwt.MapClaims:
|
||||||
assert.Equal(t, claims["name"], "John Doe", tc.info)
|
assert.Equal(t, claims["name"], "John Doe", tc.name)
|
||||||
case *jwtCustomClaims:
|
case *jwtCustomClaims:
|
||||||
assert.Equal(t, claims.Name, "John Doe", tc.info)
|
assert.Equal(t, claims.Name, "John Doe", tc.name)
|
||||||
assert.Equal(t, claims.Admin, true, tc.info)
|
assert.Equal(t, claims.Admin, true, tc.name)
|
||||||
default:
|
default:
|
||||||
panic("unexpected type of claims")
|
panic("unexpected type of claims")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -608,13 +632,14 @@ func TestJWTConfig_TokenLookupFuncs(t *testing.T) {
|
|||||||
e := echo.New()
|
e := echo.New()
|
||||||
|
|
||||||
e.GET("/", func(c echo.Context) error {
|
e.GET("/", func(c echo.Context) error {
|
||||||
return c.String(http.StatusOK, "test")
|
token := c.Get("user").(*jwt.Token)
|
||||||
|
return c.JSON(http.StatusOK, token.Claims)
|
||||||
})
|
})
|
||||||
|
|
||||||
e.Use(JWTWithConfig(JWTConfig{
|
e.Use(JWTWithConfig(JWTConfig{
|
||||||
TokenLookupFuncs: []TokenLookupFunc{
|
TokenLookupFuncs: []ValuesExtractor{
|
||||||
func(c echo.Context) (string, error) {
|
func(c echo.Context) ([]string, error) {
|
||||||
return c.Request().Header.Get("X-API-Key"), nil
|
return []string{c.Request().Header.Get("X-API-Key")}, nil
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
SigningKey: []byte("secret"),
|
SigningKey: []byte("secret"),
|
||||||
@ -626,4 +651,129 @@ func TestJWTConfig_TokenLookupFuncs(t *testing.T) {
|
|||||||
e.ServeHTTP(res, req)
|
e.ServeHTTP(res, req)
|
||||||
|
|
||||||
assert.Equal(t, http.StatusOK, res.Code)
|
assert.Equal(t, http.StatusOK, res.Code)
|
||||||
|
assert.Equal(t, `{"admin":true,"name":"John Doe","sub":"1234567890"}`+"\n", res.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJWTConfig_SuccessHandler(t *testing.T) {
|
||||||
|
var testCases = []struct {
|
||||||
|
name string
|
||||||
|
givenToken string
|
||||||
|
expectCalled bool
|
||||||
|
expectStatus int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "ok, success handler is called",
|
||||||
|
givenToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ",
|
||||||
|
expectCalled: true,
|
||||||
|
expectStatus: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nok, success handler is not called",
|
||||||
|
givenToken: "x.x.x",
|
||||||
|
expectCalled: false,
|
||||||
|
expectStatus: http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
e := echo.New()
|
||||||
|
|
||||||
|
e.GET("/", func(c echo.Context) error {
|
||||||
|
token := c.Get("user").(*jwt.Token)
|
||||||
|
return c.JSON(http.StatusOK, token.Claims)
|
||||||
|
})
|
||||||
|
|
||||||
|
wasCalled := false
|
||||||
|
e.Use(JWTWithConfig(JWTConfig{
|
||||||
|
SuccessHandler: func(c echo.Context) {
|
||||||
|
wasCalled = true
|
||||||
|
},
|
||||||
|
SigningKey: []byte("secret"),
|
||||||
|
}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.Header.Set(echo.HeaderAuthorization, "bearer "+tc.givenToken)
|
||||||
|
res := httptest.NewRecorder()
|
||||||
|
|
||||||
|
e.ServeHTTP(res, req)
|
||||||
|
|
||||||
|
assert.Equal(t, tc.expectCalled, wasCalled)
|
||||||
|
assert.Equal(t, tc.expectStatus, res.Code)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJWTConfig_ContinueOnIgnoredError(t *testing.T) {
|
||||||
|
var testCases = []struct {
|
||||||
|
name string
|
||||||
|
whenContinueOnIgnoredError bool
|
||||||
|
givenToken string
|
||||||
|
expectStatus int
|
||||||
|
expectBody string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no error handler is called",
|
||||||
|
whenContinueOnIgnoredError: true,
|
||||||
|
givenToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ",
|
||||||
|
expectStatus: http.StatusTeapot,
|
||||||
|
expectBody: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ContinueOnIgnoredError is false and error handler is called for missing token",
|
||||||
|
whenContinueOnIgnoredError: false,
|
||||||
|
givenToken: "",
|
||||||
|
// empty response with 200. This emulates previous behaviour when error handler swallowed the error
|
||||||
|
expectStatus: http.StatusOK,
|
||||||
|
expectBody: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "error handler is called for missing token",
|
||||||
|
whenContinueOnIgnoredError: true,
|
||||||
|
givenToken: "",
|
||||||
|
expectStatus: http.StatusTeapot,
|
||||||
|
expectBody: "public-token",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "error handler is called for invalid token",
|
||||||
|
whenContinueOnIgnoredError: true,
|
||||||
|
givenToken: "x.x.x",
|
||||||
|
expectStatus: http.StatusUnauthorized,
|
||||||
|
expectBody: "{\"message\":\"Unauthorized\"}\n",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
e := echo.New()
|
||||||
|
|
||||||
|
e.GET("/", func(c echo.Context) error {
|
||||||
|
testValue, _ := c.Get("test").(string)
|
||||||
|
return c.String(http.StatusTeapot, testValue)
|
||||||
|
})
|
||||||
|
|
||||||
|
e.Use(JWTWithConfig(JWTConfig{
|
||||||
|
ContinueOnIgnoredError: tc.whenContinueOnIgnoredError,
|
||||||
|
SigningKey: []byte("secret"),
|
||||||
|
ErrorHandlerWithContext: func(err error, c echo.Context) error {
|
||||||
|
if err == ErrJWTMissing {
|
||||||
|
c.Set("test", "public-token")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return echo.ErrUnauthorized
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
if tc.givenToken != "" {
|
||||||
|
req.Header.Set(echo.HeaderAuthorization, "bearer "+tc.givenToken)
|
||||||
|
}
|
||||||
|
res := httptest.NewRecorder()
|
||||||
|
|
||||||
|
e.ServeHTTP(res, req)
|
||||||
|
|
||||||
|
assert.Equal(t, tc.expectStatus, res.Code)
|
||||||
|
assert.Equal(t, tc.expectBody, res.Body.String())
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -2,11 +2,8 @@ package middleware
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/labstack/echo/v4"
|
"github.com/labstack/echo/v4"
|
||||||
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
@ -15,15 +12,21 @@ type (
|
|||||||
// Skipper defines a function to skip middleware.
|
// Skipper defines a function to skip middleware.
|
||||||
Skipper Skipper
|
Skipper Skipper
|
||||||
|
|
||||||
// KeyLookup is a string in the form of "<source>:<name>" that is used
|
// KeyLookup is a string in the form of "<source>:<name>" or "<source>:<name>,<source>:<name>" that is used
|
||||||
// to extract key from the request.
|
// to extract key from the request.
|
||||||
// Optional. Default value "header:Authorization".
|
// Optional. Default value "header:Authorization".
|
||||||
// Possible values:
|
// Possible values:
|
||||||
// - "header:<name>"
|
// - "header:<name>" or "header:<name>:<cut-prefix>"
|
||||||
|
// `<cut-prefix>` is argument value to cut/trim prefix of the extracted value. This is useful if header
|
||||||
|
// value has static prefix like `Authorization: <auth-scheme> <authorisation-parameters>` where part that we
|
||||||
|
// want to cut is `<auth-scheme> ` note the space at the end.
|
||||||
|
// In case of basic authentication `Authorization: Basic <credentials>` prefix we want to remove is `Basic `.
|
||||||
// - "query:<name>"
|
// - "query:<name>"
|
||||||
// - "form:<name>"
|
// - "form:<name>"
|
||||||
// - "cookie:<name>"
|
// - "cookie:<name>"
|
||||||
KeyLookup string `yaml:"key_lookup"`
|
// Multiple sources example:
|
||||||
|
// - "header:Authorization,header:X-Api-Key"
|
||||||
|
KeyLookup string
|
||||||
|
|
||||||
// AuthScheme to be used in the Authorization header.
|
// AuthScheme to be used in the Authorization header.
|
||||||
// Optional. Default value "Bearer".
|
// Optional. Default value "Bearer".
|
||||||
@ -36,15 +39,20 @@ type (
|
|||||||
// ErrorHandler defines a function which is executed for an invalid key.
|
// ErrorHandler defines a function which is executed for an invalid key.
|
||||||
// It may be used to define a custom error.
|
// It may be used to define a custom error.
|
||||||
ErrorHandler KeyAuthErrorHandler
|
ErrorHandler KeyAuthErrorHandler
|
||||||
|
|
||||||
|
// ContinueOnIgnoredError allows the next middleware/handler to be called when ErrorHandler decides to
|
||||||
|
// ignore the error (by returning `nil`).
|
||||||
|
// This is useful when parts of your site/api allow public access and some authorized routes provide extra functionality.
|
||||||
|
// In that case you can use ErrorHandler to set a default public key auth value in the request context
|
||||||
|
// and continue. Some logic down the remaining execution chain needs to check that (public) key auth value then.
|
||||||
|
ContinueOnIgnoredError bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// KeyAuthValidator defines a function to validate KeyAuth credentials.
|
// KeyAuthValidator defines a function to validate KeyAuth credentials.
|
||||||
KeyAuthValidator func(string, echo.Context) (bool, error)
|
KeyAuthValidator func(auth string, c echo.Context) (bool, error)
|
||||||
|
|
||||||
keyExtractor func(echo.Context) (string, error)
|
|
||||||
|
|
||||||
// KeyAuthErrorHandler defines a function which is executed for an invalid key.
|
// KeyAuthErrorHandler defines a function which is executed for an invalid key.
|
||||||
KeyAuthErrorHandler func(error, echo.Context) error
|
KeyAuthErrorHandler func(err error, c echo.Context) error
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -56,6 +64,21 @@ var (
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ErrKeyAuthMissing is error type when KeyAuth middleware is unable to extract value from lookups
|
||||||
|
type ErrKeyAuthMissing struct {
|
||||||
|
Err error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error returns errors text
|
||||||
|
func (e *ErrKeyAuthMissing) Error() string {
|
||||||
|
return e.Err.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unwrap unwraps error
|
||||||
|
func (e *ErrKeyAuthMissing) Unwrap() error {
|
||||||
|
return e.Err
|
||||||
|
}
|
||||||
|
|
||||||
// KeyAuth returns an KeyAuth middleware.
|
// KeyAuth returns an KeyAuth middleware.
|
||||||
//
|
//
|
||||||
// For valid key it calls the next handler.
|
// For valid key it calls the next handler.
|
||||||
@ -85,16 +108,9 @@ func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc {
|
|||||||
panic("echo: key-auth middleware requires a validator function")
|
panic("echo: key-auth middleware requires a validator function")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize
|
extractors, err := createExtractors(config.KeyLookup, config.AuthScheme)
|
||||||
parts := strings.Split(config.KeyLookup, ":")
|
if err != nil {
|
||||||
extractor := keyFromHeader(parts[1], config.AuthScheme)
|
panic(err)
|
||||||
switch parts[0] {
|
|
||||||
case "query":
|
|
||||||
extractor = keyFromQuery(parts[1])
|
|
||||||
case "form":
|
|
||||||
extractor = keyFromForm(parts[1])
|
|
||||||
case "cookie":
|
|
||||||
extractor = keyFromCookie(parts[1])
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||||
@ -103,79 +119,62 @@ func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc {
|
|||||||
return next(c)
|
return next(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract and verify key
|
var lastExtractorErr error
|
||||||
key, err := extractor(c)
|
var lastValidatorErr error
|
||||||
|
for _, extractor := range extractors {
|
||||||
|
keys, err := extractor(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
lastExtractorErr = err
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, key := range keys {
|
||||||
|
valid, err := config.Validator(key, c)
|
||||||
|
if err != nil {
|
||||||
|
lastValidatorErr = err
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if valid {
|
||||||
|
return next(c)
|
||||||
|
}
|
||||||
|
lastValidatorErr = errors.New("invalid key")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// we are here only when we did not successfully extract and validate any of keys
|
||||||
|
err := lastValidatorErr
|
||||||
|
if err == nil { // prioritize validator errors over extracting errors
|
||||||
|
// ugly part to preserve backwards compatible errors. someone could rely on them
|
||||||
|
if lastExtractorErr == errQueryExtractorValueMissing {
|
||||||
|
err = errors.New("missing key in the query string")
|
||||||
|
} else if lastExtractorErr == errCookieExtractorValueMissing {
|
||||||
|
err = errors.New("missing key in cookies")
|
||||||
|
} else if lastExtractorErr == errFormExtractorValueMissing {
|
||||||
|
err = errors.New("missing key in the form")
|
||||||
|
} else if lastExtractorErr == errHeaderExtractorValueMissing {
|
||||||
|
err = errors.New("missing key in request header")
|
||||||
|
} else if lastExtractorErr == errHeaderExtractorValueInvalid {
|
||||||
|
err = errors.New("invalid key in the request header")
|
||||||
|
} else {
|
||||||
|
err = lastExtractorErr
|
||||||
|
}
|
||||||
|
err = &ErrKeyAuthMissing{Err: err}
|
||||||
|
}
|
||||||
|
|
||||||
if config.ErrorHandler != nil {
|
if config.ErrorHandler != nil {
|
||||||
return config.ErrorHandler(err, c)
|
tmpErr := config.ErrorHandler(err, c)
|
||||||
|
if config.ContinueOnIgnoredError && tmpErr == nil {
|
||||||
|
return next(c)
|
||||||
|
}
|
||||||
|
return tmpErr
|
||||||
|
}
|
||||||
|
if lastValidatorErr != nil { // prioritize validator errors over extracting errors
|
||||||
|
return &echo.HTTPError{
|
||||||
|
Code: http.StatusUnauthorized,
|
||||||
|
Message: "Unauthorized",
|
||||||
|
Internal: lastValidatorErr,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||||
}
|
}
|
||||||
valid, err := config.Validator(key, c)
|
|
||||||
if err != nil {
|
|
||||||
if config.ErrorHandler != nil {
|
|
||||||
return config.ErrorHandler(err, c)
|
|
||||||
}
|
|
||||||
return &echo.HTTPError{
|
|
||||||
Code: http.StatusUnauthorized,
|
|
||||||
Message: "invalid key",
|
|
||||||
Internal: err,
|
|
||||||
}
|
|
||||||
} else if valid {
|
|
||||||
return next(c)
|
|
||||||
}
|
|
||||||
return echo.ErrUnauthorized
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// keyFromHeader returns a `keyExtractor` that extracts key from the request header.
|
|
||||||
func keyFromHeader(header string, authScheme string) keyExtractor {
|
|
||||||
return func(c echo.Context) (string, error) {
|
|
||||||
auth := c.Request().Header.Get(header)
|
|
||||||
if auth == "" {
|
|
||||||
return "", errors.New("missing key in request header")
|
|
||||||
}
|
|
||||||
if header == echo.HeaderAuthorization {
|
|
||||||
l := len(authScheme)
|
|
||||||
if len(auth) > l+1 && auth[:l] == authScheme {
|
|
||||||
return auth[l+1:], nil
|
|
||||||
}
|
|
||||||
return "", errors.New("invalid key in the request header")
|
|
||||||
}
|
|
||||||
return auth, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// keyFromQuery returns a `keyExtractor` that extracts key from the query string.
|
|
||||||
func keyFromQuery(param string) keyExtractor {
|
|
||||||
return func(c echo.Context) (string, error) {
|
|
||||||
key := c.QueryParam(param)
|
|
||||||
if key == "" {
|
|
||||||
return "", errors.New("missing key in the query string")
|
|
||||||
}
|
|
||||||
return key, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// keyFromForm returns a `keyExtractor` that extracts key from the form.
|
|
||||||
func keyFromForm(param string) keyExtractor {
|
|
||||||
return func(c echo.Context) (string, error) {
|
|
||||||
key := c.FormValue(param)
|
|
||||||
if key == "" {
|
|
||||||
return "", errors.New("missing key in the form")
|
|
||||||
}
|
|
||||||
return key, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// keyFromCookie returns a `keyExtractor` that extracts key from the form.
|
|
||||||
func keyFromCookie(cookieName string) keyExtractor {
|
|
||||||
return func(c echo.Context) (string, error) {
|
|
||||||
key, err := c.Cookie(cookieName)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("missing key in cookies: %w", err)
|
|
||||||
}
|
|
||||||
return key.Value, nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -76,7 +76,7 @@ func TestKeyAuthWithConfig(t *testing.T) {
|
|||||||
req.Header.Set(echo.HeaderAuthorization, "Bearer invalid-key")
|
req.Header.Set(echo.HeaderAuthorization, "Bearer invalid-key")
|
||||||
},
|
},
|
||||||
expectHandlerCalled: false,
|
expectHandlerCalled: false,
|
||||||
expectError: "code=401, message=Unauthorized",
|
expectError: "code=401, message=Unauthorized, internal=invalid key",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "nok, defaults, invalid scheme in header",
|
name: "nok, defaults, invalid scheme in header",
|
||||||
@ -92,6 +92,17 @@ func TestKeyAuthWithConfig(t *testing.T) {
|
|||||||
expectHandlerCalled: false,
|
expectHandlerCalled: false,
|
||||||
expectError: "code=400, message=missing key in request header",
|
expectError: "code=400, message=missing key in request header",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "ok, custom key lookup from multiple places, query and header",
|
||||||
|
givenRequest: func(req *http.Request) {
|
||||||
|
req.URL.RawQuery = "key=invalid-key"
|
||||||
|
req.Header.Set("API-Key", "valid-key")
|
||||||
|
},
|
||||||
|
whenConfig: func(conf *KeyAuthConfig) {
|
||||||
|
conf.KeyLookup = "query:key,header:API-Key"
|
||||||
|
},
|
||||||
|
expectHandlerCalled: true,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "ok, custom key lookup, header",
|
name: "ok, custom key lookup, header",
|
||||||
givenRequest: func(req *http.Request) {
|
givenRequest: func(req *http.Request) {
|
||||||
@ -179,7 +190,7 @@ func TestKeyAuthWithConfig(t *testing.T) {
|
|||||||
conf.KeyLookup = "cookie:key"
|
conf.KeyLookup = "cookie:key"
|
||||||
},
|
},
|
||||||
expectHandlerCalled: false,
|
expectHandlerCalled: false,
|
||||||
expectError: "code=400, message=missing key in cookies: http: named cookie not present",
|
expectError: "code=400, message=missing key in cookies",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "nok, custom errorHandler, error from extractor",
|
name: "nok, custom errorHandler, error from extractor",
|
||||||
@ -216,7 +227,7 @@ func TestKeyAuthWithConfig(t *testing.T) {
|
|||||||
},
|
},
|
||||||
whenConfig: func(conf *KeyAuthConfig) {},
|
whenConfig: func(conf *KeyAuthConfig) {},
|
||||||
expectHandlerCalled: false,
|
expectHandlerCalled: false,
|
||||||
expectError: "code=401, message=invalid key, internal=some user defined error",
|
expectError: "code=401, message=Unauthorized, internal=some user defined error",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -257,3 +268,109 @@ func TestKeyAuthWithConfig(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestKeyAuthWithConfig_panicsOnInvalidLookup(t *testing.T) {
|
||||||
|
assert.PanicsWithError(
|
||||||
|
t,
|
||||||
|
"extractor source for lookup could not be split into needed parts: a",
|
||||||
|
func() {
|
||||||
|
handler := func(c echo.Context) error {
|
||||||
|
return c.String(http.StatusOK, "test")
|
||||||
|
}
|
||||||
|
KeyAuthWithConfig(KeyAuthConfig{
|
||||||
|
Validator: testKeyValidator,
|
||||||
|
KeyLookup: "a",
|
||||||
|
})(handler)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKeyAuthWithConfig_panicsOnEmptyValidator(t *testing.T) {
|
||||||
|
assert.PanicsWithValue(
|
||||||
|
t,
|
||||||
|
"echo: key-auth middleware requires a validator function",
|
||||||
|
func() {
|
||||||
|
handler := func(c echo.Context) error {
|
||||||
|
return c.String(http.StatusOK, "test")
|
||||||
|
}
|
||||||
|
KeyAuthWithConfig(KeyAuthConfig{
|
||||||
|
Validator: nil,
|
||||||
|
})(handler)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKeyAuthWithConfig_ContinueOnIgnoredError(t *testing.T) {
|
||||||
|
var testCases = []struct {
|
||||||
|
name string
|
||||||
|
whenContinueOnIgnoredError bool
|
||||||
|
givenKey string
|
||||||
|
expectStatus int
|
||||||
|
expectBody string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no error handler is called",
|
||||||
|
whenContinueOnIgnoredError: true,
|
||||||
|
givenKey: "valid-key",
|
||||||
|
expectStatus: http.StatusTeapot,
|
||||||
|
expectBody: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ContinueOnIgnoredError is false and error handler is called for missing token",
|
||||||
|
whenContinueOnIgnoredError: false,
|
||||||
|
givenKey: "",
|
||||||
|
// empty response with 200. This emulates previous behaviour when error handler swallowed the error
|
||||||
|
expectStatus: http.StatusOK,
|
||||||
|
expectBody: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "error handler is called for missing token",
|
||||||
|
whenContinueOnIgnoredError: true,
|
||||||
|
givenKey: "",
|
||||||
|
expectStatus: http.StatusTeapot,
|
||||||
|
expectBody: "public-auth",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "error handler is called for invalid token",
|
||||||
|
whenContinueOnIgnoredError: true,
|
||||||
|
givenKey: "x.x.x",
|
||||||
|
expectStatus: http.StatusUnauthorized,
|
||||||
|
expectBody: "{\"message\":\"Unauthorized\"}\n",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
e := echo.New()
|
||||||
|
|
||||||
|
e.GET("/", func(c echo.Context) error {
|
||||||
|
testValue, _ := c.Get("test").(string)
|
||||||
|
return c.String(http.StatusTeapot, testValue)
|
||||||
|
})
|
||||||
|
|
||||||
|
e.Use(KeyAuthWithConfig(KeyAuthConfig{
|
||||||
|
Validator: testKeyValidator,
|
||||||
|
ErrorHandler: func(err error, c echo.Context) error {
|
||||||
|
if _, ok := err.(*ErrKeyAuthMissing); ok {
|
||||||
|
c.Set("test", "public-auth")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return echo.ErrUnauthorized
|
||||||
|
},
|
||||||
|
KeyLookup: "header:X-API-Key",
|
||||||
|
ContinueOnIgnoredError: tc.whenContinueOnIgnoredError,
|
||||||
|
}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
if tc.givenKey != "" {
|
||||||
|
req.Header.Set("X-API-Key", tc.givenKey)
|
||||||
|
}
|
||||||
|
res := httptest.NewRecorder()
|
||||||
|
|
||||||
|
e.ServeHTTP(res, req)
|
||||||
|
|
||||||
|
assert.Equal(t, tc.expectStatus, res.Code)
|
||||||
|
assert.Equal(t, tc.expectBody, res.Body.String())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -12,10 +12,10 @@ import (
|
|||||||
type (
|
type (
|
||||||
// Skipper defines a function to skip middleware. Returning true skips processing
|
// Skipper defines a function to skip middleware. Returning true skips processing
|
||||||
// the middleware.
|
// the middleware.
|
||||||
Skipper func(echo.Context) bool
|
Skipper func(c echo.Context) bool
|
||||||
|
|
||||||
// BeforeFunc defines a function which is executed just before the middleware.
|
// BeforeFunc defines a function which is executed just before the middleware.
|
||||||
BeforeFunc func(echo.Context)
|
BeforeFunc func(c echo.Context)
|
||||||
)
|
)
|
||||||
|
|
||||||
func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer {
|
func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user