1
0
mirror of https://github.com/labstack/echo.git synced 2025-01-26 03:20:08 +02:00

JWT, KeyAuth, CSRF multivalue extractors (#2060)

* CSRF, JWT, KeyAuth middleware support for multivalue value extractors
* Add flag to JWT and KeyAuth middleware to allow continuing execution `next(c)` when error handler decides to swallow the error (returns nil).
This commit is contained in:
Martti T 2022-01-24 22:03:45 +02:00 committed by GitHub
parent 9e9924d763
commit 4a1ccdfdc5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 1564 additions and 397 deletions

View File

@ -111,10 +111,10 @@ type (
}
// MiddlewareFunc defines a function to process middleware.
MiddlewareFunc func(HandlerFunc) HandlerFunc
MiddlewareFunc func(next HandlerFunc) HandlerFunc
// HandlerFunc defines a function to serve HTTP requests.
HandlerFunc func(Context) error
HandlerFunc func(c Context) error
// HTTPErrorHandler is a centralized HTTP error handler.
HTTPErrorHandler func(error, Context)

View File

@ -2,9 +2,7 @@ package middleware
import (
"crypto/subtle"
"errors"
"net/http"
"strings"
"time"
"github.com/labstack/echo/v4"
@ -21,13 +19,15 @@ type (
TokenLength uint8 `yaml:"token_length"`
// Optional. Default value 32.
// TokenLookup is a string in the form of "<source>:<key>" that is used
// TokenLookup is a string in the form of "<source>:<name>" or "<source>:<name>,<source>:<name>" that is used
// to extract token from the request.
// Optional. Default value "header:X-CSRF-Token".
// Possible values:
// - "header:<name>"
// - "form:<name>"
// - "header:<name>" or "header:<name>:<cut-prefix>"
// - "query:<name>"
// - "form:<name>"
// Multiple sources example:
// - "header:X-CSRF-Token,query:csrf"
TokenLookup string `yaml:"token_lookup"`
// Context key to store generated CSRF token into context.
@ -62,12 +62,11 @@ type (
// Optional. Default value SameSiteDefaultMode.
CookieSameSite http.SameSite `yaml:"cookie_same_site"`
}
// csrfTokenExtractor defines a function that takes `echo.Context` and returns
// either a token or an error.
csrfTokenExtractor func(echo.Context) (string, error)
)
// ErrCSRFInvalid is returned when CSRF check fails
var ErrCSRFInvalid = echo.NewHTTPError(http.StatusForbidden, "invalid csrf token")
var (
// DefaultCSRFConfig is the default CSRF middleware config.
DefaultCSRFConfig = CSRFConfig{
@ -114,14 +113,9 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
config.CookieSecure = true
}
// Initialize
parts := strings.Split(config.TokenLookup, ":")
extractor := csrfTokenFromHeader(parts[1])
switch parts[0] {
case "form":
extractor = csrfTokenFromForm(parts[1])
case "query":
extractor = csrfTokenFromQuery(parts[1])
extractors, err := createExtractors(config.TokenLookup, "")
if err != nil {
panic(err)
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
@ -130,28 +124,50 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
return next(c)
}
req := c.Request()
k, err := c.Cookie(config.CookieName)
token := ""
// Generate token
if err != nil {
token = random.String(config.TokenLength)
if k, err := c.Cookie(config.CookieName); err != nil {
token = random.String(config.TokenLength) // Generate token
} else {
// Reuse token
token = k.Value
token = k.Value // Reuse token
}
switch req.Method {
switch c.Request().Method {
case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
default:
// Validate token only for requests which are not defined as 'safe' by RFC7231
clientToken, err := extractor(c)
var lastExtractorErr error
var lastTokenErr error
outer:
for _, extractor := range extractors {
clientTokens, err := extractor(c)
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
lastExtractorErr = err
continue
}
if !validateCSRFToken(token, clientToken) {
return echo.NewHTTPError(http.StatusForbidden, "invalid csrf token")
for _, clientToken := range clientTokens {
if validateCSRFToken(token, clientToken) {
lastTokenErr = nil
lastExtractorErr = nil
break outer
}
lastTokenErr = ErrCSRFInvalid
}
}
if lastTokenErr != nil {
return lastTokenErr
} else if lastExtractorErr != nil {
// ugly part to preserve backwards compatible errors. someone could rely on them
if lastExtractorErr == errQueryExtractorValueMissing {
lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in the query string")
} else if lastExtractorErr == errFormExtractorValueMissing {
lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in the form parameter")
} else if lastExtractorErr == errHeaderExtractorValueMissing {
lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in request header")
} else {
lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, lastExtractorErr.Error())
}
return lastExtractorErr
}
}
@ -184,38 +200,6 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
}
}
// csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the
// provided request header.
func csrfTokenFromHeader(header string) csrfTokenExtractor {
return func(c echo.Context) (string, error) {
return c.Request().Header.Get(header), nil
}
}
// csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the
// provided form parameter.
func csrfTokenFromForm(param string) csrfTokenExtractor {
return func(c echo.Context) (string, error) {
token := c.FormValue(param)
if token == "" {
return "", errors.New("missing csrf token in the form parameter")
}
return token, nil
}
}
// csrfTokenFromQuery returns a `csrfTokenExtractor` that extracts token from the
// provided query parameter.
func csrfTokenFromQuery(param string) csrfTokenExtractor {
return func(c echo.Context) (string, error) {
token := c.QueryParam(param)
if token == "" {
return "", errors.New("missing csrf token in the query string")
}
return token, nil
}
}
func validateCSRFToken(token, clientToken string) bool {
return subtle.ConstantTimeCompare([]byte(token), []byte(clientToken)) == 1
}

View File

@ -1,7 +1,6 @@
package middleware
import (
"fmt"
"net/http"
"net/http/httptest"
"net/url"
@ -13,14 +12,205 @@ import (
"github.com/stretchr/testify/assert"
)
func TestCSRF_tokenExtractors(t *testing.T) {
var testCases = []struct {
name string
whenTokenLookup string
whenCookieName string
givenCSRFCookie string
givenMethod string
givenQueryTokens map[string][]string
givenFormTokens map[string][]string
givenHeaderTokens map[string][]string
expectError string
}{
{
name: "ok, multiple token lookups sources, succeeds on last one",
whenTokenLookup: "header:X-CSRF-Token,form:csrf",
givenCSRFCookie: "token",
givenMethod: http.MethodPost,
givenHeaderTokens: map[string][]string{
echo.HeaderXCSRFToken: {"invalid_token"},
},
givenFormTokens: map[string][]string{
"csrf": {"token"},
},
},
{
name: "ok, token from POST form",
whenTokenLookup: "form:csrf",
givenCSRFCookie: "token",
givenMethod: http.MethodPost,
givenFormTokens: map[string][]string{
"csrf": {"token"},
},
},
{
name: "ok, token from POST form, second token passes",
whenTokenLookup: "form:csrf",
givenCSRFCookie: "token",
givenMethod: http.MethodPost,
givenFormTokens: map[string][]string{
"csrf": {"invalid", "token"},
},
},
{
name: "nok, invalid token from POST form",
whenTokenLookup: "form:csrf",
givenCSRFCookie: "token",
givenMethod: http.MethodPost,
givenFormTokens: map[string][]string{
"csrf": {"invalid_token"},
},
expectError: "code=403, message=invalid csrf token",
},
{
name: "nok, missing token from POST form",
whenTokenLookup: "form:csrf",
givenCSRFCookie: "token",
givenMethod: http.MethodPost,
givenFormTokens: map[string][]string{},
expectError: "code=400, message=missing csrf token in the form parameter",
},
{
name: "ok, token from POST header",
whenTokenLookup: "", // will use defaults
givenCSRFCookie: "token",
givenMethod: http.MethodPost,
givenHeaderTokens: map[string][]string{
echo.HeaderXCSRFToken: {"token"},
},
},
{
name: "ok, token from POST header, second token passes",
whenTokenLookup: "header:" + echo.HeaderXCSRFToken,
givenCSRFCookie: "token",
givenMethod: http.MethodPost,
givenHeaderTokens: map[string][]string{
echo.HeaderXCSRFToken: {"invalid", "token"},
},
},
{
name: "nok, invalid token from POST header",
whenTokenLookup: "header:" + echo.HeaderXCSRFToken,
givenCSRFCookie: "token",
givenMethod: http.MethodPost,
givenHeaderTokens: map[string][]string{
echo.HeaderXCSRFToken: {"invalid_token"},
},
expectError: "code=403, message=invalid csrf token",
},
{
name: "nok, missing token from POST header",
whenTokenLookup: "header:" + echo.HeaderXCSRFToken,
givenCSRFCookie: "token",
givenMethod: http.MethodPost,
givenHeaderTokens: map[string][]string{},
expectError: "code=400, message=missing csrf token in request header",
},
{
name: "ok, token from PUT query param",
whenTokenLookup: "query:csrf-param",
givenCSRFCookie: "token",
givenMethod: http.MethodPut,
givenQueryTokens: map[string][]string{
"csrf-param": {"token"},
},
},
{
name: "ok, token from PUT query form, second token passes",
whenTokenLookup: "query:csrf",
givenCSRFCookie: "token",
givenMethod: http.MethodPut,
givenQueryTokens: map[string][]string{
"csrf": {"invalid", "token"},
},
},
{
name: "nok, invalid token from PUT query form",
whenTokenLookup: "query:csrf",
givenCSRFCookie: "token",
givenMethod: http.MethodPut,
givenQueryTokens: map[string][]string{
"csrf": {"invalid_token"},
},
expectError: "code=403, message=invalid csrf token",
},
{
name: "nok, missing token from PUT query form",
whenTokenLookup: "query:csrf",
givenCSRFCookie: "token",
givenMethod: http.MethodPut,
givenQueryTokens: map[string][]string{},
expectError: "code=400, message=missing csrf token in the query string",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
q := make(url.Values)
for queryParam, values := range tc.givenQueryTokens {
for _, v := range values {
q.Add(queryParam, v)
}
}
f := make(url.Values)
for formKey, values := range tc.givenFormTokens {
for _, v := range values {
f.Add(formKey, v)
}
}
var req *http.Request
switch tc.givenMethod {
case http.MethodGet:
req = httptest.NewRequest(http.MethodGet, "/?"+q.Encode(), nil)
case http.MethodPost, http.MethodPut:
req = httptest.NewRequest(http.MethodPost, "/?"+q.Encode(), strings.NewReader(f.Encode()))
req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm)
}
for header, values := range tc.givenHeaderTokens {
for _, v := range values {
req.Header.Add(header, v)
}
}
if tc.givenCSRFCookie != "" {
req.Header.Set(echo.HeaderCookie, "_csrf="+tc.givenCSRFCookie)
}
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
csrf := CSRFWithConfig(CSRFConfig{
TokenLookup: tc.whenTokenLookup,
CookieName: tc.whenCookieName,
})
h := csrf(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})
err := h(c)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
})
}
}
func TestCSRF(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
csrf := CSRFWithConfig(CSRFConfig{
TokenLength: 16,
})
csrf := CSRF()
h := csrf(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})
@ -43,7 +233,7 @@ func TestCSRF(t *testing.T) {
assert.Error(t, h(c))
// Valid CSRF token
token := random.String(16)
token := random.String(32)
req.Header.Set(echo.HeaderCookie, "_csrf="+token)
req.Header.Set(echo.HeaderXCSRFToken, token)
if assert.NoError(t, h(c)) {
@ -51,38 +241,6 @@ func TestCSRF(t *testing.T) {
}
}
func TestCSRFTokenFromForm(t *testing.T) {
f := make(url.Values)
f.Set("csrf", "token")
e := echo.New()
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode()))
req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm)
c := e.NewContext(req, nil)
token, err := csrfTokenFromForm("csrf")(c)
if assert.NoError(t, err) {
assert.Equal(t, "token", token)
}
_, err = csrfTokenFromForm("invalid")(c)
assert.Error(t, err)
}
func TestCSRFTokenFromQuery(t *testing.T) {
q := make(url.Values)
q.Set("csrf", "token")
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm)
req.URL.RawQuery = q.Encode()
c := e.NewContext(req, nil)
token, err := csrfTokenFromQuery("csrf")(c)
if assert.NoError(t, err) {
assert.Equal(t, "token", token)
}
_, err = csrfTokenFromQuery("invalid")(c)
assert.Error(t, err)
csrfTokenFromQuery("csrf")
}
func TestCSRFSetSameSiteMode(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
@ -135,7 +293,6 @@ func TestCSRFWithSameSiteDefaultMode(t *testing.T) {
r := h(c)
assert.NoError(t, r)
fmt.Println(rec.Header()["Set-Cookie"])
assert.NotRegexp(t, "SameSite=", rec.Header()["Set-Cookie"])
}
@ -158,3 +315,46 @@ func TestCSRFWithSameSiteModeNone(t *testing.T) {
assert.Regexp(t, "SameSite=None", rec.Header()["Set-Cookie"])
assert.Regexp(t, "Secure", rec.Header()["Set-Cookie"])
}
func TestCSRFConfig_skipper(t *testing.T) {
var testCases = []struct {
name string
whenSkip bool
expectCookies int
}{
{
name: "do skip",
whenSkip: true,
expectCookies: 0,
},
{
name: "do not skip",
whenSkip: false,
expectCookies: 1,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
csrf := CSRFWithConfig(CSRFConfig{
Skipper: func(c echo.Context) bool {
return tc.whenSkip
},
})
h := csrf(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})
r := h(c)
assert.NoError(t, r)
cookie := rec.Header()["Set-Cookie"]
assert.Len(t, cookie, tc.expectCookies)
})
}
}

184
middleware/extractor.go Normal file
View File

@ -0,0 +1,184 @@
package middleware
import (
"errors"
"fmt"
"github.com/labstack/echo/v4"
"net/textproto"
"strings"
)
const (
// extractorLimit is arbitrary number to limit values extractor can return. this limits possible resource exhaustion
// attack vector
extractorLimit = 20
)
var errHeaderExtractorValueMissing = errors.New("missing value in request header")
var errHeaderExtractorValueInvalid = errors.New("invalid value in request header")
var errQueryExtractorValueMissing = errors.New("missing value in the query string")
var errParamExtractorValueMissing = errors.New("missing value in path params")
var errCookieExtractorValueMissing = errors.New("missing value in cookies")
var errFormExtractorValueMissing = errors.New("missing value in the form")
// ValuesExtractor defines a function for extracting values (keys/tokens) from the given context.
type ValuesExtractor func(c echo.Context) ([]string, error)
func createExtractors(lookups string, authScheme string) ([]ValuesExtractor, error) {
if lookups == "" {
return nil, nil
}
sources := strings.Split(lookups, ",")
var extractors = make([]ValuesExtractor, 0)
for _, source := range sources {
parts := strings.Split(source, ":")
if len(parts) < 2 {
return nil, fmt.Errorf("extractor source for lookup could not be split into needed parts: %v", source)
}
switch parts[0] {
case "query":
extractors = append(extractors, valuesFromQuery(parts[1]))
case "param":
extractors = append(extractors, valuesFromParam(parts[1]))
case "cookie":
extractors = append(extractors, valuesFromCookie(parts[1]))
case "form":
extractors = append(extractors, valuesFromForm(parts[1]))
case "header":
prefix := ""
if len(parts) > 2 {
prefix = parts[2]
} else if authScheme != "" && parts[1] == echo.HeaderAuthorization {
// backwards compatibility for JWT and KeyAuth:
// * we only apply this fix to Authorization as header we use and uses prefixes like "Bearer <token-value>" etc
// * previously header extractor assumed that auth-scheme/prefix had a space as suffix we need to retain that
// behaviour for default values and Authorization header.
prefix = authScheme
if !strings.HasSuffix(prefix, " ") {
prefix += " "
}
}
extractors = append(extractors, valuesFromHeader(parts[1], prefix))
}
}
return extractors, nil
}
// valuesFromHeader returns a functions that extracts values from the request header.
// valuePrefix is parameter to remove first part (prefix) of the extracted value. This is useful if header value has static
// prefix like `Authorization: <auth-scheme> <authorisation-parameters>` where part that we want to remove is `<auth-scheme> `
// note the space at the end. In case of basic authentication `Authorization: Basic <credentials>` prefix we want to remove
// is `Basic `. In case of JWT tokens `Authorization: Bearer <token>` prefix is `Bearer `.
// If prefix is left empty the whole value is returned.
func valuesFromHeader(header string, valuePrefix string) ValuesExtractor {
prefixLen := len(valuePrefix)
// standard library parses http.Request header keys in canonical form but we may provide something else so fix this
header = textproto.CanonicalMIMEHeaderKey(header)
return func(c echo.Context) ([]string, error) {
values := c.Request().Header.Values(header)
if len(values) == 0 {
return nil, errHeaderExtractorValueMissing
}
result := make([]string, 0)
for i, value := range values {
if prefixLen == 0 {
result = append(result, value)
if i >= extractorLimit-1 {
break
}
continue
}
if len(value) > prefixLen && strings.EqualFold(value[:prefixLen], valuePrefix) {
result = append(result, value[prefixLen:])
if i >= extractorLimit-1 {
break
}
}
}
if len(result) == 0 {
if prefixLen > 0 {
return nil, errHeaderExtractorValueInvalid
}
return nil, errHeaderExtractorValueMissing
}
return result, nil
}
}
// valuesFromQuery returns a function that extracts values from the query string.
func valuesFromQuery(param string) ValuesExtractor {
return func(c echo.Context) ([]string, error) {
result := c.QueryParams()[param]
if len(result) == 0 {
return nil, errQueryExtractorValueMissing
} else if len(result) > extractorLimit-1 {
result = result[:extractorLimit]
}
return result, nil
}
}
// valuesFromParam returns a function that extracts values from the url param string.
func valuesFromParam(param string) ValuesExtractor {
return func(c echo.Context) ([]string, error) {
result := make([]string, 0)
paramVales := c.ParamValues()
for i, p := range c.ParamNames() {
if param == p {
result = append(result, paramVales[i])
if i >= extractorLimit-1 {
break
}
}
}
if len(result) == 0 {
return nil, errParamExtractorValueMissing
}
return result, nil
}
}
// valuesFromCookie returns a function that extracts values from the named cookie.
func valuesFromCookie(name string) ValuesExtractor {
return func(c echo.Context) ([]string, error) {
cookies := c.Cookies()
if len(cookies) == 0 {
return nil, errCookieExtractorValueMissing
}
result := make([]string, 0)
for i, cookie := range cookies {
if name == cookie.Name {
result = append(result, cookie.Value)
if i >= extractorLimit-1 {
break
}
}
}
if len(result) == 0 {
return nil, errCookieExtractorValueMissing
}
return result, nil
}
}
// valuesFromForm returns a function that extracts values from the form field.
func valuesFromForm(name string) ValuesExtractor {
return func(c echo.Context) ([]string, error) {
if parseErr := c.Request().ParseForm(); parseErr != nil {
return nil, fmt.Errorf("valuesFromForm parse form failed: %w", parseErr)
}
values := c.Request().Form[name]
if len(values) == 0 {
return nil, errFormExtractorValueMissing
}
if len(values) > extractorLimit-1 {
values = values[:extractorLimit]
}
result := append([]string{}, values...)
return result, nil
}
}

View File

@ -0,0 +1,587 @@
package middleware
import (
"fmt"
"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
)
type pathParam struct {
name string
value string
}
func setPathParams(c echo.Context, params []pathParam) {
names := make([]string, 0, len(params))
values := make([]string, 0, len(params))
for _, pp := range params {
names = append(names, pp.name)
values = append(values, pp.value)
}
c.SetParamNames(names...)
c.SetParamValues(values...)
}
func TestCreateExtractors(t *testing.T) {
var testCases = []struct {
name string
givenRequest func() *http.Request
givenPathParams []pathParam
whenLoopups string
expectValues []string
expectCreateError string
expectError string
}{
{
name: "ok, header",
givenRequest: func() *http.Request {
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAuthorization, "Bearer token")
return req
},
whenLoopups: "header:Authorization:Bearer ",
expectValues: []string{"token"},
},
{
name: "ok, form",
givenRequest: func() *http.Request {
f := make(url.Values)
f.Set("name", "Jon Snow")
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode()))
req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm)
return req
},
whenLoopups: "form:name",
expectValues: []string{"Jon Snow"},
},
{
name: "ok, cookie",
givenRequest: func() *http.Request {
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderCookie, "_csrf=token")
return req
},
whenLoopups: "cookie:_csrf",
expectValues: []string{"token"},
},
{
name: "ok, param",
givenPathParams: []pathParam{
{name: "id", value: "123"},
},
whenLoopups: "param:id",
expectValues: []string{"123"},
},
{
name: "ok, query",
givenRequest: func() *http.Request {
req := httptest.NewRequest(http.MethodGet, "/?id=999", nil)
return req
},
whenLoopups: "query:id",
expectValues: []string{"999"},
},
{
name: "nok, invalid lookup",
whenLoopups: "query",
expectCreateError: "extractor source for lookup could not be split into needed parts: query",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
if tc.givenRequest != nil {
req = tc.givenRequest()
}
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
if tc.givenPathParams != nil {
setPathParams(c, tc.givenPathParams)
}
extractors, err := createExtractors(tc.whenLoopups, "")
if tc.expectCreateError != "" {
assert.EqualError(t, err, tc.expectCreateError)
return
}
assert.NoError(t, err)
for _, e := range extractors {
values, eErr := e(c)
assert.Equal(t, tc.expectValues, values)
if tc.expectError != "" {
assert.EqualError(t, eErr, tc.expectError)
return
}
assert.NoError(t, eErr)
}
})
}
}
func TestValuesFromHeader(t *testing.T) {
exampleRequest := func(req *http.Request) {
req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==")
}
var testCases = []struct {
name string
givenRequest func(req *http.Request)
whenName string
whenValuePrefix string
expectValues []string
expectError string
}{
{
name: "ok, single value",
givenRequest: exampleRequest,
whenName: echo.HeaderAuthorization,
whenValuePrefix: "basic ",
expectValues: []string{"dXNlcjpwYXNzd29yZA=="},
},
{
name: "ok, single value, case insensitive",
givenRequest: exampleRequest,
whenName: echo.HeaderAuthorization,
whenValuePrefix: "Basic ",
expectValues: []string{"dXNlcjpwYXNzd29yZA=="},
},
{
name: "ok, multiple value",
givenRequest: func(req *http.Request) {
req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==")
req.Header.Add(echo.HeaderAuthorization, "basic dGVzdDp0ZXN0")
},
whenName: echo.HeaderAuthorization,
whenValuePrefix: "basic ",
expectValues: []string{"dXNlcjpwYXNzd29yZA==", "dGVzdDp0ZXN0"},
},
{
name: "ok, empty prefix",
givenRequest: exampleRequest,
whenName: echo.HeaderAuthorization,
whenValuePrefix: "",
expectValues: []string{"basic dXNlcjpwYXNzd29yZA=="},
},
{
name: "nok, no matching due different prefix",
givenRequest: func(req *http.Request) {
req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==")
req.Header.Add(echo.HeaderAuthorization, "basic dGVzdDp0ZXN0")
},
whenName: echo.HeaderAuthorization,
whenValuePrefix: "Bearer ",
expectError: errHeaderExtractorValueInvalid.Error(),
},
{
name: "nok, no matching due different prefix",
givenRequest: func(req *http.Request) {
req.Header.Set(echo.HeaderAuthorization, "basic dXNlcjpwYXNzd29yZA==")
req.Header.Add(echo.HeaderAuthorization, "basic dGVzdDp0ZXN0")
},
whenName: echo.HeaderWWWAuthenticate,
whenValuePrefix: "",
expectError: errHeaderExtractorValueMissing.Error(),
},
{
name: "nok, no headers",
givenRequest: nil,
whenName: echo.HeaderAuthorization,
whenValuePrefix: "basic ",
expectError: errHeaderExtractorValueMissing.Error(),
},
{
name: "ok, prefix, cut values over extractorLimit",
givenRequest: func(req *http.Request) {
for i := 1; i <= 25; i++ {
req.Header.Add(echo.HeaderAuthorization, fmt.Sprintf("basic %v", i))
}
},
whenName: echo.HeaderAuthorization,
whenValuePrefix: "basic ",
expectValues: []string{
"1", "2", "3", "4", "5", "6", "7", "8", "9", "10",
"11", "12", "13", "14", "15", "16", "17", "18", "19", "20",
},
},
{
name: "ok, cut values over extractorLimit",
givenRequest: func(req *http.Request) {
for i := 1; i <= 25; i++ {
req.Header.Add(echo.HeaderAuthorization, fmt.Sprintf("%v", i))
}
},
whenName: echo.HeaderAuthorization,
whenValuePrefix: "",
expectValues: []string{
"1", "2", "3", "4", "5", "6", "7", "8", "9", "10",
"11", "12", "13", "14", "15", "16", "17", "18", "19", "20",
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
if tc.givenRequest != nil {
tc.givenRequest(req)
}
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
extractor := valuesFromHeader(tc.whenName, tc.whenValuePrefix)
values, err := extractor(c)
assert.Equal(t, tc.expectValues, values)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
})
}
}
func TestValuesFromQuery(t *testing.T) {
var testCases = []struct {
name string
givenQueryPart string
whenName string
expectValues []string
expectError string
}{
{
name: "ok, single value",
givenQueryPart: "?id=123&name=test",
whenName: "id",
expectValues: []string{"123"},
},
{
name: "ok, multiple value",
givenQueryPart: "?id=123&id=456&name=test",
whenName: "id",
expectValues: []string{"123", "456"},
},
{
name: "nok, missing value",
givenQueryPart: "?id=123&name=test",
whenName: "nope",
expectError: errQueryExtractorValueMissing.Error(),
},
{
name: "ok, cut values over extractorLimit",
givenQueryPart: "?name=test" +
"&id=1&id=2&id=3&id=4&id=5&id=6&id=7&id=8&id=9&id=10" +
"&id=11&id=12&id=13&id=14&id=15&id=16&id=17&id=18&id=19&id=20" +
"&id=21&id=22&id=23&id=24&id=25",
whenName: "id",
expectValues: []string{
"1", "2", "3", "4", "5", "6", "7", "8", "9", "10",
"11", "12", "13", "14", "15", "16", "17", "18", "19", "20",
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/"+tc.givenQueryPart, nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
extractor := valuesFromQuery(tc.whenName)
values, err := extractor(c)
assert.Equal(t, tc.expectValues, values)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
})
}
}
func TestValuesFromParam(t *testing.T) {
examplePathParams := []pathParam{
{name: "id", value: "123"},
{name: "gid", value: "456"},
{name: "gid", value: "789"},
}
examplePathParams20 := make([]pathParam, 0)
for i := 1; i < 25; i++ {
examplePathParams20 = append(examplePathParams20, pathParam{name: "id", value: fmt.Sprintf("%v", i)})
}
var testCases = []struct {
name string
givenPathParams []pathParam
whenName string
expectValues []string
expectError string
}{
{
name: "ok, single value",
givenPathParams: examplePathParams,
whenName: "id",
expectValues: []string{"123"},
},
{
name: "ok, multiple value",
givenPathParams: examplePathParams,
whenName: "gid",
expectValues: []string{"456", "789"},
},
{
name: "nok, no values",
givenPathParams: nil,
whenName: "nope",
expectValues: nil,
expectError: errParamExtractorValueMissing.Error(),
},
{
name: "nok, no matching value",
givenPathParams: examplePathParams,
whenName: "nope",
expectValues: nil,
expectError: errParamExtractorValueMissing.Error(),
},
{
name: "ok, cut values over extractorLimit",
givenPathParams: examplePathParams20,
whenName: "id",
expectValues: []string{
"1", "2", "3", "4", "5", "6", "7", "8", "9", "10",
"11", "12", "13", "14", "15", "16", "17", "18", "19", "20",
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
if tc.givenPathParams != nil {
setPathParams(c, tc.givenPathParams)
}
extractor := valuesFromParam(tc.whenName)
values, err := extractor(c)
assert.Equal(t, tc.expectValues, values)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
})
}
}
func TestValuesFromCookie(t *testing.T) {
exampleRequest := func(req *http.Request) {
req.Header.Set(echo.HeaderCookie, "_csrf=token")
}
var testCases = []struct {
name string
givenRequest func(req *http.Request)
whenName string
expectValues []string
expectError string
}{
{
name: "ok, single value",
givenRequest: exampleRequest,
whenName: "_csrf",
expectValues: []string{"token"},
},
{
name: "ok, multiple value",
givenRequest: func(req *http.Request) {
req.Header.Add(echo.HeaderCookie, "_csrf=token")
req.Header.Add(echo.HeaderCookie, "_csrf=token2")
},
whenName: "_csrf",
expectValues: []string{"token", "token2"},
},
{
name: "nok, no matching cookie",
givenRequest: exampleRequest,
whenName: "xxx",
expectValues: nil,
expectError: errCookieExtractorValueMissing.Error(),
},
{
name: "nok, no cookies at all",
givenRequest: nil,
whenName: "xxx",
expectValues: nil,
expectError: errCookieExtractorValueMissing.Error(),
},
{
name: "ok, cut values over extractorLimit",
givenRequest: func(req *http.Request) {
for i := 1; i < 25; i++ {
req.Header.Add(echo.HeaderCookie, fmt.Sprintf("_csrf=%v", i))
}
},
whenName: "_csrf",
expectValues: []string{
"1", "2", "3", "4", "5", "6", "7", "8", "9", "10",
"11", "12", "13", "14", "15", "16", "17", "18", "19", "20",
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
if tc.givenRequest != nil {
tc.givenRequest(req)
}
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
extractor := valuesFromCookie(tc.whenName)
values, err := extractor(c)
assert.Equal(t, tc.expectValues, values)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
})
}
}
func TestValuesFromForm(t *testing.T) {
examplePostFormRequest := func(mod func(v *url.Values)) *http.Request {
f := make(url.Values)
f.Set("name", "Jon Snow")
f.Set("emails[]", "jon@labstack.com")
if mod != nil {
mod(&f)
}
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode()))
req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm)
return req
}
exampleGetFormRequest := func(mod func(v *url.Values)) *http.Request {
f := make(url.Values)
f.Set("name", "Jon Snow")
f.Set("emails[]", "jon@labstack.com")
if mod != nil {
mod(&f)
}
req := httptest.NewRequest(http.MethodGet, "/?"+f.Encode(), nil)
return req
}
var testCases = []struct {
name string
givenRequest *http.Request
whenName string
expectValues []string
expectError string
}{
{
name: "ok, POST form, single value",
givenRequest: examplePostFormRequest(nil),
whenName: "emails[]",
expectValues: []string{"jon@labstack.com"},
},
{
name: "ok, POST form, multiple value",
givenRequest: examplePostFormRequest(func(v *url.Values) {
v.Add("emails[]", "snow@labstack.com")
}),
whenName: "emails[]",
expectValues: []string{"jon@labstack.com", "snow@labstack.com"},
},
{
name: "ok, GET form, single value",
givenRequest: exampleGetFormRequest(nil),
whenName: "emails[]",
expectValues: []string{"jon@labstack.com"},
},
{
name: "ok, GET form, multiple value",
givenRequest: examplePostFormRequest(func(v *url.Values) {
v.Add("emails[]", "snow@labstack.com")
}),
whenName: "emails[]",
expectValues: []string{"jon@labstack.com", "snow@labstack.com"},
},
{
name: "nok, POST form, value missing",
givenRequest: examplePostFormRequest(nil),
whenName: "nope",
expectError: errFormExtractorValueMissing.Error(),
},
{
name: "nok, POST form, form parsing error",
givenRequest: func() *http.Request {
req := httptest.NewRequest(http.MethodPost, "/", nil)
req.Body = nil
return req
}(),
whenName: "name",
expectError: "valuesFromForm parse form failed: missing form body",
},
{
name: "ok, cut values over extractorLimit",
givenRequest: examplePostFormRequest(func(v *url.Values) {
for i := 1; i < 25; i++ {
v.Add("id[]", fmt.Sprintf("%v", i))
}
}),
whenName: "id[]",
expectValues: []string{
"1", "2", "3", "4", "5", "6", "7", "8", "9", "10",
"11", "12", "13", "14", "15", "16", "17", "18", "19", "20",
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
req := tc.givenRequest
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
extractor := valuesFromForm(tc.whenName)
values, err := extractor(c)
assert.Equal(t, tc.expectValues, values)
if tc.expectError != "" {
assert.EqualError(t, err, tc.expectError)
} else {
assert.NoError(t, err)
}
})
}
}

View File

@ -1,3 +1,4 @@
//go:build go1.15
// +build go1.15
package middleware
@ -5,12 +6,10 @@ package middleware
import (
"errors"
"fmt"
"net/http"
"reflect"
"strings"
"github.com/golang-jwt/jwt"
"github.com/labstack/echo/v4"
"net/http"
"reflect"
)
type (
@ -22,7 +21,8 @@ type (
// BeforeFunc defines a function which is executed just before the middleware.
BeforeFunc BeforeFunc
// SuccessHandler defines a function which is executed for a valid token.
// SuccessHandler defines a function which is executed for a valid token before middleware chain continues with next
// middleware or handler.
SuccessHandler JWTSuccessHandler
// ErrorHandler defines a function which is executed for an invalid token.
@ -32,6 +32,13 @@ type (
// ErrorHandlerWithContext is almost identical to ErrorHandler, but it's passed the current context.
ErrorHandlerWithContext JWTErrorHandlerWithContext
// ContinueOnIgnoredError allows the next middleware/handler to be called when ErrorHandlerWithContext decides to
// ignore the error (by returning `nil`).
// This is useful when parts of your site/api allow public access and some authorized routes provide extra functionality.
// In that case you can use ErrorHandlerWithContext to set a default public JWT token value in the request context
// and continue. Some logic down the remaining execution chain needs to check that (public) token value then.
ContinueOnIgnoredError bool
// Signing key to validate token.
// This is one of the three options to provide a token validation key.
// The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey.
@ -61,12 +68,17 @@ type (
// to extract token from the request.
// Optional. Default value "header:Authorization".
// Possible values:
// - "header:<name>"
// - "header:<name>" or "header:<name>:<cut-prefix>"
// `<cut-prefix>` is argument value to cut/trim prefix of the extracted value. This is useful if header
// value has static prefix like `Authorization: <auth-scheme> <authorisation-parameters>` where part that we
// want to cut is `<auth-scheme> ` note the space at the end.
// In case of JWT tokens `Authorization: Bearer <token>` prefix we cut is `Bearer `.
// If prefix is left empty the whole value is returned.
// - "query:<name>"
// - "param:<name>"
// - "cookie:<name>"
// - "form:<name>"
// Multiply sources example:
// Multiple sources example:
// - "header:Authorization,cookie:myowncookie"
TokenLookup string
@ -74,7 +86,7 @@ type (
// This is one of the two options to provide a token extractor.
// The order of precedence is user-defined TokenLookupFuncs, and TokenLookup.
// You can also provide both if you want.
TokenLookupFuncs []TokenLookupFunc
TokenLookupFuncs []ValuesExtractor
// AuthScheme to be used in the Authorization header.
// Optional. Default value "Bearer".
@ -100,16 +112,13 @@ type (
}
// JWTSuccessHandler defines a function which is executed for a valid token.
JWTSuccessHandler func(echo.Context)
JWTSuccessHandler func(c echo.Context)
// JWTErrorHandler defines a function which is executed for an invalid token.
JWTErrorHandler func(error) error
JWTErrorHandler func(err error) error
// JWTErrorHandlerWithContext is almost identical to JWTErrorHandler, but it's passed the current context.
JWTErrorHandlerWithContext func(error, echo.Context) error
// TokenLookupFunc defines a function for extracting JWT token from the given context.
TokenLookupFunc func(echo.Context) (string, error)
JWTErrorHandlerWithContext func(err error, c echo.Context) error
)
// Algorithms
@ -183,25 +192,12 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
config.ParseTokenFunc = config.defaultParseToken
}
// Initialize
// Split sources
sources := strings.Split(config.TokenLookup, ",")
var extractors = config.TokenLookupFuncs
for _, source := range sources {
parts := strings.Split(source, ":")
switch parts[0] {
case "query":
extractors = append(extractors, jwtFromQuery(parts[1]))
case "param":
extractors = append(extractors, jwtFromParam(parts[1]))
case "cookie":
extractors = append(extractors, jwtFromCookie(parts[1]))
case "form":
extractors = append(extractors, jwtFromForm(parts[1]))
case "header":
extractors = append(extractors, jwtFromHeader(parts[1], config.AuthScheme))
extractors, err := createExtractors(config.TokenLookup, config.AuthScheme)
if err != nil {
panic(err)
}
if len(config.TokenLookupFuncs) > 0 {
extractors = append(config.TokenLookupFuncs, extractors...)
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
@ -213,30 +209,21 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
if config.BeforeFunc != nil {
config.BeforeFunc(c)
}
var auth string
var err error
var lastExtractorErr error
var lastTokenErr error
for _, extractor := range extractors {
// Extract token from extractor, if it's not fail break the loop and
// set auth
auth, err = extractor(c)
if err == nil {
break
}
}
// If none of extractor has a token, handle error
auths, err := extractor(c)
if err != nil {
if config.ErrorHandler != nil {
return config.ErrorHandler(err)
lastExtractorErr = ErrJWTMissing // backwards compatibility: all extraction errors are same (unlike KeyAuth)
continue
}
if config.ErrorHandlerWithContext != nil {
return config.ErrorHandlerWithContext(err, c)
}
return err
}
for _, auth := range auths {
token, err := config.ParseTokenFunc(auth, c)
if err == nil {
if err != nil {
lastTokenErr = err
continue
}
// Store user information from token into context.
c.Set(config.ContextKey, token)
if config.SuccessHandler != nil {
@ -244,18 +231,33 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
}
return next(c)
}
}
// we are here only when we did not successfully extract or parse any of the tokens
err := lastTokenErr
if err == nil { // prioritize token errors over extracting errors
err = lastExtractorErr
}
if config.ErrorHandler != nil {
return config.ErrorHandler(err)
}
if config.ErrorHandlerWithContext != nil {
return config.ErrorHandlerWithContext(err, c)
tmpErr := config.ErrorHandlerWithContext(err, c)
if config.ContinueOnIgnoredError && tmpErr == nil {
return next(c)
}
return tmpErr
}
// backwards compatible errors codes
if lastTokenErr != nil {
return &echo.HTTPError{
Code: ErrJWTInvalid.Code,
Message: ErrJWTInvalid.Message,
Internal: err,
}
}
return err // this is lastExtractorErr value
}
}
}
@ -296,59 +298,3 @@ func (config *JWTConfig) defaultKeyFunc(t *jwt.Token) (interface{}, error) {
return config.SigningKey, nil
}
// jwtFromHeader returns a `TokenLookupFunc` that extracts token from the request header.
func jwtFromHeader(header string, authScheme string) TokenLookupFunc {
return func(c echo.Context) (string, error) {
auth := c.Request().Header.Get(header)
l := len(authScheme)
if len(auth) > l+1 && strings.EqualFold(auth[:l], authScheme) {
return auth[l+1:], nil
}
return "", ErrJWTMissing
}
}
// jwtFromQuery returns a `TokenLookupFunc` that extracts token from the query string.
func jwtFromQuery(param string) TokenLookupFunc {
return func(c echo.Context) (string, error) {
token := c.QueryParam(param)
if token == "" {
return "", ErrJWTMissing
}
return token, nil
}
}
// jwtFromParam returns a `TokenLookupFunc` that extracts token from the url param string.
func jwtFromParam(param string) TokenLookupFunc {
return func(c echo.Context) (string, error) {
token := c.Param(param)
if token == "" {
return "", ErrJWTMissing
}
return token, nil
}
}
// jwtFromCookie returns a `TokenLookupFunc` that extracts token from the named cookie.
func jwtFromCookie(name string) TokenLookupFunc {
return func(c echo.Context) (string, error) {
cookie, err := c.Cookie(name)
if err != nil {
return "", ErrJWTMissing
}
return cookie.Value, nil
}
}
// jwtFromForm returns a `TokenLookupFunc` that extracts token from the form field.
func jwtFromForm(name string) TokenLookupFunc {
return func(c echo.Context) (string, error) {
field := c.FormValue(name)
if field == "" {
return "", ErrJWTMissing
}
return field, nil
}
}

View File

@ -1,3 +1,4 @@
//go:build go1.15
// +build go1.15
package middleware
@ -28,6 +29,26 @@ type jwtCustomClaims struct {
jwtCustomInfo
}
func TestJWT(t *testing.T) {
e := echo.New()
e.GET("/", func(c echo.Context) error {
token := c.Get("user").(*jwt.Token)
return c.JSON(http.StatusOK, token.Claims)
})
e.Use(JWT([]byte("secret")))
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAuthorization, "bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ")
res := httptest.NewRecorder()
e.ServeHTTP(res, req)
assert.Equal(t, http.StatusOK, res.Code)
assert.Equal(t, `{"admin":true,"name":"John Doe","sub":"1234567890"}`+"\n", res.Body.String())
}
func TestJWTRace(t *testing.T) {
e := echo.New()
handler := func(c echo.Context) error {
@ -64,8 +85,7 @@ func TestJWTRace(t *testing.T) {
assert.Equal(t, claims.Admin, true)
}
func TestJWT(t *testing.T) {
e := echo.New()
func TestJWTConfig(t *testing.T) {
handler := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
@ -74,7 +94,8 @@ func TestJWT(t *testing.T) {
invalidKey := []byte("invalid-key")
validAuth := DefaultJWTConfig.AuthScheme + " " + token
for _, tc := range []struct {
testCases := []struct {
name string
expPanic bool
expErrCode int // 0 for Success
config JWTConfig
@ -82,166 +103,166 @@ func TestJWT(t *testing.T) {
hdrAuth string
hdrCookie string // test.Request doesn't provide SetCookie(); use name=val
formValues map[string]string
info string
}{
{
name: "No signing key provided",
expPanic: true,
info: "No signing key provided",
},
{
name: "Unexpected signing method",
expErrCode: http.StatusBadRequest,
config: JWTConfig{
SigningKey: validKey,
SigningMethod: "RS256",
},
info: "Unexpected signing method",
},
{
name: "Invalid key",
expErrCode: http.StatusUnauthorized,
hdrAuth: validAuth,
config: JWTConfig{SigningKey: invalidKey},
info: "Invalid key",
},
{
name: "Valid JWT",
hdrAuth: validAuth,
config: JWTConfig{SigningKey: validKey},
info: "Valid JWT",
},
{
name: "Valid JWT with custom AuthScheme",
hdrAuth: "Token" + " " + token,
config: JWTConfig{AuthScheme: "Token", SigningKey: validKey},
info: "Valid JWT with custom AuthScheme",
},
{
name: "Valid JWT with custom claims",
hdrAuth: validAuth,
config: JWTConfig{
Claims: &jwtCustomClaims{},
SigningKey: []byte("secret"),
},
info: "Valid JWT with custom claims",
},
{
name: "Invalid Authorization header",
hdrAuth: "invalid-auth",
expErrCode: http.StatusBadRequest,
config: JWTConfig{SigningKey: validKey},
info: "Invalid Authorization header",
},
{
name: "Empty header auth field",
config: JWTConfig{SigningKey: validKey},
expErrCode: http.StatusBadRequest,
info: "Empty header auth field",
},
{
name: "Valid query method",
config: JWTConfig{
SigningKey: validKey,
TokenLookup: "query:jwt",
},
reqURL: "/?a=b&jwt=" + token,
info: "Valid query method",
},
{
name: "Invalid query param name",
config: JWTConfig{
SigningKey: validKey,
TokenLookup: "query:jwt",
},
reqURL: "/?a=b&jwtxyz=" + token,
expErrCode: http.StatusBadRequest,
info: "Invalid query param name",
},
{
name: "Invalid query param value",
config: JWTConfig{
SigningKey: validKey,
TokenLookup: "query:jwt",
},
reqURL: "/?a=b&jwt=invalid-token",
expErrCode: http.StatusUnauthorized,
info: "Invalid query param value",
},
{
name: "Empty query",
config: JWTConfig{
SigningKey: validKey,
TokenLookup: "query:jwt",
},
reqURL: "/?a=b",
expErrCode: http.StatusBadRequest,
info: "Empty query",
},
{
name: "Valid param method",
config: JWTConfig{
SigningKey: validKey,
TokenLookup: "param:jwt",
},
reqURL: "/" + token,
info: "Valid param method",
},
{
name: "Valid cookie method",
config: JWTConfig{
SigningKey: validKey,
TokenLookup: "cookie:jwt",
},
hdrCookie: "jwt=" + token,
info: "Valid cookie method",
},
{
name: "Multiple jwt lookuop",
config: JWTConfig{
SigningKey: validKey,
TokenLookup: "query:jwt,cookie:jwt",
},
hdrCookie: "jwt=" + token,
info: "Multiple jwt lookuop",
},
{
name: "Invalid token with cookie method",
config: JWTConfig{
SigningKey: validKey,
TokenLookup: "cookie:jwt",
},
expErrCode: http.StatusUnauthorized,
hdrCookie: "jwt=invalid",
info: "Invalid token with cookie method",
},
{
name: "Empty cookie",
config: JWTConfig{
SigningKey: validKey,
TokenLookup: "cookie:jwt",
},
expErrCode: http.StatusBadRequest,
info: "Empty cookie",
},
{
name: "Valid form method",
config: JWTConfig{
SigningKey: validKey,
TokenLookup: "form:jwt",
},
formValues: map[string]string{"jwt": token},
info: "Valid form method",
},
{
name: "Invalid token with form method",
config: JWTConfig{
SigningKey: validKey,
TokenLookup: "form:jwt",
},
expErrCode: http.StatusUnauthorized,
formValues: map[string]string{"jwt": "invalid"},
info: "Invalid token with form method",
},
{
name: "Empty form field",
config: JWTConfig{
SigningKey: validKey,
TokenLookup: "form:jwt",
},
expErrCode: http.StatusBadRequest,
info: "Empty form field",
},
{
name: "Valid JWT with a valid key using a user-defined KeyFunc",
hdrAuth: validAuth,
config: JWTConfig{
KeyFunc: func(*jwt.Token) (interface{}, error) {
return validKey, nil
},
},
info: "Valid JWT with a valid key using a user-defined KeyFunc",
},
{
name: "Valid JWT with an invalid key using a user-defined KeyFunc",
hdrAuth: validAuth,
config: JWTConfig{
KeyFunc: func(*jwt.Token) (interface{}, error) {
@ -249,9 +270,9 @@ func TestJWT(t *testing.T) {
},
},
expErrCode: http.StatusUnauthorized,
info: "Valid JWT with an invalid key using a user-defined KeyFunc",
},
{
name: "Token verification does not pass using a user-defined KeyFunc",
hdrAuth: validAuth,
config: JWTConfig{
KeyFunc: func(*jwt.Token) (interface{}, error) {
@ -259,14 +280,16 @@ func TestJWT(t *testing.T) {
},
},
expErrCode: http.StatusUnauthorized,
info: "Token verification does not pass using a user-defined KeyFunc",
},
{
name: "Valid JWT with lower case AuthScheme",
hdrAuth: strings.ToLower(DefaultJWTConfig.AuthScheme) + " " + token,
config: JWTConfig{SigningKey: validKey},
info: "Valid JWT with lower case AuthScheme",
},
} {
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
if tc.reqURL == "" {
tc.reqURL = "/"
}
@ -296,30 +319,31 @@ func TestJWT(t *testing.T) {
if tc.expPanic {
assert.Panics(t, func() {
JWTWithConfig(tc.config)
}, tc.info)
continue
}, tc.name)
return
}
if tc.expErrCode != 0 {
h := JWTWithConfig(tc.config)(handler)
he := h(c).(*echo.HTTPError)
assert.Equal(t, tc.expErrCode, he.Code, tc.info)
continue
assert.Equal(t, tc.expErrCode, he.Code, tc.name)
return
}
h := JWTWithConfig(tc.config)(handler)
if assert.NoError(t, h(c), tc.info) {
if assert.NoError(t, h(c), tc.name) {
user := c.Get("user").(*jwt.Token)
switch claims := user.Claims.(type) {
case jwt.MapClaims:
assert.Equal(t, claims["name"], "John Doe", tc.info)
assert.Equal(t, claims["name"], "John Doe", tc.name)
case *jwtCustomClaims:
assert.Equal(t, claims.Name, "John Doe", tc.info)
assert.Equal(t, claims.Admin, true, tc.info)
assert.Equal(t, claims.Name, "John Doe", tc.name)
assert.Equal(t, claims.Admin, true, tc.name)
default:
panic("unexpected type of claims")
}
}
})
}
}
@ -608,13 +632,14 @@ func TestJWTConfig_TokenLookupFuncs(t *testing.T) {
e := echo.New()
e.GET("/", func(c echo.Context) error {
return c.String(http.StatusOK, "test")
token := c.Get("user").(*jwt.Token)
return c.JSON(http.StatusOK, token.Claims)
})
e.Use(JWTWithConfig(JWTConfig{
TokenLookupFuncs: []TokenLookupFunc{
func(c echo.Context) (string, error) {
return c.Request().Header.Get("X-API-Key"), nil
TokenLookupFuncs: []ValuesExtractor{
func(c echo.Context) ([]string, error) {
return []string{c.Request().Header.Get("X-API-Key")}, nil
},
},
SigningKey: []byte("secret"),
@ -626,4 +651,129 @@ func TestJWTConfig_TokenLookupFuncs(t *testing.T) {
e.ServeHTTP(res, req)
assert.Equal(t, http.StatusOK, res.Code)
assert.Equal(t, `{"admin":true,"name":"John Doe","sub":"1234567890"}`+"\n", res.Body.String())
}
func TestJWTConfig_SuccessHandler(t *testing.T) {
var testCases = []struct {
name string
givenToken string
expectCalled bool
expectStatus int
}{
{
name: "ok, success handler is called",
givenToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ",
expectCalled: true,
expectStatus: http.StatusOK,
},
{
name: "nok, success handler is not called",
givenToken: "x.x.x",
expectCalled: false,
expectStatus: http.StatusUnauthorized,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
e.GET("/", func(c echo.Context) error {
token := c.Get("user").(*jwt.Token)
return c.JSON(http.StatusOK, token.Claims)
})
wasCalled := false
e.Use(JWTWithConfig(JWTConfig{
SuccessHandler: func(c echo.Context) {
wasCalled = true
},
SigningKey: []byte("secret"),
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAuthorization, "bearer "+tc.givenToken)
res := httptest.NewRecorder()
e.ServeHTTP(res, req)
assert.Equal(t, tc.expectCalled, wasCalled)
assert.Equal(t, tc.expectStatus, res.Code)
})
}
}
func TestJWTConfig_ContinueOnIgnoredError(t *testing.T) {
var testCases = []struct {
name string
whenContinueOnIgnoredError bool
givenToken string
expectStatus int
expectBody string
}{
{
name: "no error handler is called",
whenContinueOnIgnoredError: true,
givenToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ",
expectStatus: http.StatusTeapot,
expectBody: "",
},
{
name: "ContinueOnIgnoredError is false and error handler is called for missing token",
whenContinueOnIgnoredError: false,
givenToken: "",
// empty response with 200. This emulates previous behaviour when error handler swallowed the error
expectStatus: http.StatusOK,
expectBody: "",
},
{
name: "error handler is called for missing token",
whenContinueOnIgnoredError: true,
givenToken: "",
expectStatus: http.StatusTeapot,
expectBody: "public-token",
},
{
name: "error handler is called for invalid token",
whenContinueOnIgnoredError: true,
givenToken: "x.x.x",
expectStatus: http.StatusUnauthorized,
expectBody: "{\"message\":\"Unauthorized\"}\n",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
e.GET("/", func(c echo.Context) error {
testValue, _ := c.Get("test").(string)
return c.String(http.StatusTeapot, testValue)
})
e.Use(JWTWithConfig(JWTConfig{
ContinueOnIgnoredError: tc.whenContinueOnIgnoredError,
SigningKey: []byte("secret"),
ErrorHandlerWithContext: func(err error, c echo.Context) error {
if err == ErrJWTMissing {
c.Set("test", "public-token")
return nil
}
return echo.ErrUnauthorized
},
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
if tc.givenToken != "" {
req.Header.Set(echo.HeaderAuthorization, "bearer "+tc.givenToken)
}
res := httptest.NewRecorder()
e.ServeHTTP(res, req)
assert.Equal(t, tc.expectStatus, res.Code)
assert.Equal(t, tc.expectBody, res.Body.String())
})
}
}

View File

@ -2,11 +2,8 @@ package middleware
import (
"errors"
"fmt"
"net/http"
"strings"
"github.com/labstack/echo/v4"
"net/http"
)
type (
@ -15,15 +12,21 @@ type (
// Skipper defines a function to skip middleware.
Skipper Skipper
// KeyLookup is a string in the form of "<source>:<name>" that is used
// KeyLookup is a string in the form of "<source>:<name>" or "<source>:<name>,<source>:<name>" that is used
// to extract key from the request.
// Optional. Default value "header:Authorization".
// Possible values:
// - "header:<name>"
// - "header:<name>" or "header:<name>:<cut-prefix>"
// `<cut-prefix>` is argument value to cut/trim prefix of the extracted value. This is useful if header
// value has static prefix like `Authorization: <auth-scheme> <authorisation-parameters>` where part that we
// want to cut is `<auth-scheme> ` note the space at the end.
// In case of basic authentication `Authorization: Basic <credentials>` prefix we want to remove is `Basic `.
// - "query:<name>"
// - "form:<name>"
// - "cookie:<name>"
KeyLookup string `yaml:"key_lookup"`
// Multiple sources example:
// - "header:Authorization,header:X-Api-Key"
KeyLookup string
// AuthScheme to be used in the Authorization header.
// Optional. Default value "Bearer".
@ -36,15 +39,20 @@ type (
// ErrorHandler defines a function which is executed for an invalid key.
// It may be used to define a custom error.
ErrorHandler KeyAuthErrorHandler
// ContinueOnIgnoredError allows the next middleware/handler to be called when ErrorHandler decides to
// ignore the error (by returning `nil`).
// This is useful when parts of your site/api allow public access and some authorized routes provide extra functionality.
// In that case you can use ErrorHandler to set a default public key auth value in the request context
// and continue. Some logic down the remaining execution chain needs to check that (public) key auth value then.
ContinueOnIgnoredError bool
}
// KeyAuthValidator defines a function to validate KeyAuth credentials.
KeyAuthValidator func(string, echo.Context) (bool, error)
keyExtractor func(echo.Context) (string, error)
KeyAuthValidator func(auth string, c echo.Context) (bool, error)
// KeyAuthErrorHandler defines a function which is executed for an invalid key.
KeyAuthErrorHandler func(error, echo.Context) error
KeyAuthErrorHandler func(err error, c echo.Context) error
)
var (
@ -56,6 +64,21 @@ var (
}
)
// ErrKeyAuthMissing is error type when KeyAuth middleware is unable to extract value from lookups
type ErrKeyAuthMissing struct {
Err error
}
// Error returns errors text
func (e *ErrKeyAuthMissing) Error() string {
return e.Err.Error()
}
// Unwrap unwraps error
func (e *ErrKeyAuthMissing) Unwrap() error {
return e.Err
}
// KeyAuth returns an KeyAuth middleware.
//
// For valid key it calls the next handler.
@ -85,16 +108,9 @@ func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc {
panic("echo: key-auth middleware requires a validator function")
}
// Initialize
parts := strings.Split(config.KeyLookup, ":")
extractor := keyFromHeader(parts[1], config.AuthScheme)
switch parts[0] {
case "query":
extractor = keyFromQuery(parts[1])
case "form":
extractor = keyFromForm(parts[1])
case "cookie":
extractor = keyFromCookie(parts[1])
extractors, err := createExtractors(config.KeyLookup, config.AuthScheme)
if err != nil {
panic(err)
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
@ -103,79 +119,62 @@ func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc {
return next(c)
}
// Extract and verify key
key, err := extractor(c)
var lastExtractorErr error
var lastValidatorErr error
for _, extractor := range extractors {
keys, err := extractor(c)
if err != nil {
lastExtractorErr = err
continue
}
for _, key := range keys {
valid, err := config.Validator(key, c)
if err != nil {
lastValidatorErr = err
continue
}
if valid {
return next(c)
}
lastValidatorErr = errors.New("invalid key")
}
}
// we are here only when we did not successfully extract and validate any of keys
err := lastValidatorErr
if err == nil { // prioritize validator errors over extracting errors
// ugly part to preserve backwards compatible errors. someone could rely on them
if lastExtractorErr == errQueryExtractorValueMissing {
err = errors.New("missing key in the query string")
} else if lastExtractorErr == errCookieExtractorValueMissing {
err = errors.New("missing key in cookies")
} else if lastExtractorErr == errFormExtractorValueMissing {
err = errors.New("missing key in the form")
} else if lastExtractorErr == errHeaderExtractorValueMissing {
err = errors.New("missing key in request header")
} else if lastExtractorErr == errHeaderExtractorValueInvalid {
err = errors.New("invalid key in the request header")
} else {
err = lastExtractorErr
}
err = &ErrKeyAuthMissing{Err: err}
}
if config.ErrorHandler != nil {
return config.ErrorHandler(err, c)
tmpErr := config.ErrorHandler(err, c)
if config.ContinueOnIgnoredError && tmpErr == nil {
return next(c)
}
return tmpErr
}
if lastValidatorErr != nil { // prioritize validator errors over extracting errors
return &echo.HTTPError{
Code: http.StatusUnauthorized,
Message: "Unauthorized",
Internal: lastValidatorErr,
}
}
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
valid, err := config.Validator(key, c)
if err != nil {
if config.ErrorHandler != nil {
return config.ErrorHandler(err, c)
}
return &echo.HTTPError{
Code: http.StatusUnauthorized,
Message: "invalid key",
Internal: err,
}
} else if valid {
return next(c)
}
return echo.ErrUnauthorized
}
}
}
// keyFromHeader returns a `keyExtractor` that extracts key from the request header.
func keyFromHeader(header string, authScheme string) keyExtractor {
return func(c echo.Context) (string, error) {
auth := c.Request().Header.Get(header)
if auth == "" {
return "", errors.New("missing key in request header")
}
if header == echo.HeaderAuthorization {
l := len(authScheme)
if len(auth) > l+1 && auth[:l] == authScheme {
return auth[l+1:], nil
}
return "", errors.New("invalid key in the request header")
}
return auth, nil
}
}
// keyFromQuery returns a `keyExtractor` that extracts key from the query string.
func keyFromQuery(param string) keyExtractor {
return func(c echo.Context) (string, error) {
key := c.QueryParam(param)
if key == "" {
return "", errors.New("missing key in the query string")
}
return key, nil
}
}
// keyFromForm returns a `keyExtractor` that extracts key from the form.
func keyFromForm(param string) keyExtractor {
return func(c echo.Context) (string, error) {
key := c.FormValue(param)
if key == "" {
return "", errors.New("missing key in the form")
}
return key, nil
}
}
// keyFromCookie returns a `keyExtractor` that extracts key from the form.
func keyFromCookie(cookieName string) keyExtractor {
return func(c echo.Context) (string, error) {
key, err := c.Cookie(cookieName)
if err != nil {
return "", fmt.Errorf("missing key in cookies: %w", err)
}
return key.Value, nil
}
}

View File

@ -76,7 +76,7 @@ func TestKeyAuthWithConfig(t *testing.T) {
req.Header.Set(echo.HeaderAuthorization, "Bearer invalid-key")
},
expectHandlerCalled: false,
expectError: "code=401, message=Unauthorized",
expectError: "code=401, message=Unauthorized, internal=invalid key",
},
{
name: "nok, defaults, invalid scheme in header",
@ -92,6 +92,17 @@ func TestKeyAuthWithConfig(t *testing.T) {
expectHandlerCalled: false,
expectError: "code=400, message=missing key in request header",
},
{
name: "ok, custom key lookup from multiple places, query and header",
givenRequest: func(req *http.Request) {
req.URL.RawQuery = "key=invalid-key"
req.Header.Set("API-Key", "valid-key")
},
whenConfig: func(conf *KeyAuthConfig) {
conf.KeyLookup = "query:key,header:API-Key"
},
expectHandlerCalled: true,
},
{
name: "ok, custom key lookup, header",
givenRequest: func(req *http.Request) {
@ -179,7 +190,7 @@ func TestKeyAuthWithConfig(t *testing.T) {
conf.KeyLookup = "cookie:key"
},
expectHandlerCalled: false,
expectError: "code=400, message=missing key in cookies: http: named cookie not present",
expectError: "code=400, message=missing key in cookies",
},
{
name: "nok, custom errorHandler, error from extractor",
@ -216,7 +227,7 @@ func TestKeyAuthWithConfig(t *testing.T) {
},
whenConfig: func(conf *KeyAuthConfig) {},
expectHandlerCalled: false,
expectError: "code=401, message=invalid key, internal=some user defined error",
expectError: "code=401, message=Unauthorized, internal=some user defined error",
},
}
@ -257,3 +268,109 @@ func TestKeyAuthWithConfig(t *testing.T) {
})
}
}
func TestKeyAuthWithConfig_panicsOnInvalidLookup(t *testing.T) {
assert.PanicsWithError(
t,
"extractor source for lookup could not be split into needed parts: a",
func() {
handler := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
KeyAuthWithConfig(KeyAuthConfig{
Validator: testKeyValidator,
KeyLookup: "a",
})(handler)
},
)
}
func TestKeyAuthWithConfig_panicsOnEmptyValidator(t *testing.T) {
assert.PanicsWithValue(
t,
"echo: key-auth middleware requires a validator function",
func() {
handler := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
KeyAuthWithConfig(KeyAuthConfig{
Validator: nil,
})(handler)
},
)
}
func TestKeyAuthWithConfig_ContinueOnIgnoredError(t *testing.T) {
var testCases = []struct {
name string
whenContinueOnIgnoredError bool
givenKey string
expectStatus int
expectBody string
}{
{
name: "no error handler is called",
whenContinueOnIgnoredError: true,
givenKey: "valid-key",
expectStatus: http.StatusTeapot,
expectBody: "",
},
{
name: "ContinueOnIgnoredError is false and error handler is called for missing token",
whenContinueOnIgnoredError: false,
givenKey: "",
// empty response with 200. This emulates previous behaviour when error handler swallowed the error
expectStatus: http.StatusOK,
expectBody: "",
},
{
name: "error handler is called for missing token",
whenContinueOnIgnoredError: true,
givenKey: "",
expectStatus: http.StatusTeapot,
expectBody: "public-auth",
},
{
name: "error handler is called for invalid token",
whenContinueOnIgnoredError: true,
givenKey: "x.x.x",
expectStatus: http.StatusUnauthorized,
expectBody: "{\"message\":\"Unauthorized\"}\n",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
e.GET("/", func(c echo.Context) error {
testValue, _ := c.Get("test").(string)
return c.String(http.StatusTeapot, testValue)
})
e.Use(KeyAuthWithConfig(KeyAuthConfig{
Validator: testKeyValidator,
ErrorHandler: func(err error, c echo.Context) error {
if _, ok := err.(*ErrKeyAuthMissing); ok {
c.Set("test", "public-auth")
return nil
}
return echo.ErrUnauthorized
},
KeyLookup: "header:X-API-Key",
ContinueOnIgnoredError: tc.whenContinueOnIgnoredError,
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
if tc.givenKey != "" {
req.Header.Set("X-API-Key", tc.givenKey)
}
res := httptest.NewRecorder()
e.ServeHTTP(res, req)
assert.Equal(t, tc.expectStatus, res.Code)
assert.Equal(t, tc.expectBody, res.Body.String())
})
}
}

View File

@ -12,10 +12,10 @@ import (
type (
// Skipper defines a function to skip middleware. Returning true skips processing
// the middleware.
Skipper func(echo.Context) bool
Skipper func(c echo.Context) bool
// BeforeFunc defines a function which is executed just before the middleware.
BeforeFunc func(echo.Context)
BeforeFunc func(c echo.Context)
)
func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer {