1
0
mirror of https://github.com/labstack/echo.git synced 2024-11-28 08:38:39 +02:00

Adds JWTConfig.ParseTokenFunc to JWT middleware to allow different libraries implementing JWT parsing.

This commit is contained in:
toimtoimtoim 2021-06-06 21:36:41 +03:00 committed by Martti T
parent fdacff0d93
commit 1ac4a8f3d0
2 changed files with 228 additions and 12 deletions

View File

@ -1,6 +1,7 @@
package middleware
import (
"errors"
"fmt"
"net/http"
"reflect"
@ -49,7 +50,8 @@ type (
// Optional. Default value "user".
ContextKey string
// Claims are extendable claims data defining token content.
// Claims are extendable claims data defining token content. Used by default ParseTokenFunc implementation.
// Not used if custom ParseTokenFunc is set.
// Optional. Default value jwt.MapClaims
Claims jwt.Claims
@ -74,13 +76,20 @@ type (
// KeyFunc defines a user-defined function that supplies the public key for a token validation.
// The function shall take care of verifying the signing algorithm and selecting the proper key.
// A user-defined KeyFunc can be useful if tokens are issued by an external party.
// Used by default ParseTokenFunc implementation.
//
// When a user-defined KeyFunc is provided, SigningKey, SigningKeys, and SigningMethod are ignored.
// This is one of the three options to provide a token validation key.
// The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey.
// Required if neither SigningKeys nor SigningKey is provided.
// Not used if custom ParseTokenFunc is set.
// Default to an internal implementation verifying the signing algorithm and selecting the proper key.
KeyFunc jwt.Keyfunc
// ParseTokenFunc defines a user-defined function that parses token from given auth. Returns an error when token
// parsing fails or parsed token is invalid.
// Defaults to implementation using `github.com/dgrijalva/jwt-go` as JWT implementation library
ParseTokenFunc func(auth string, c echo.Context) (interface{}, error)
}
// JWTSuccessHandler defines a function which is executed for a valid token.
@ -140,7 +149,7 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
if config.Skipper == nil {
config.Skipper = DefaultJWTConfig.Skipper
}
if config.SigningKey == nil && len(config.SigningKeys) == 0 && config.KeyFunc == nil {
if config.SigningKey == nil && len(config.SigningKeys) == 0 && config.KeyFunc == nil && config.ParseTokenFunc == nil {
panic("echo: jwt middleware requires signing key")
}
if config.SigningMethod == "" {
@ -161,6 +170,9 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
if config.KeyFunc == nil {
config.KeyFunc = config.defaultKeyFunc
}
if config.ParseTokenFunc == nil {
config.ParseTokenFunc = config.defaultParseToken
}
// Initialize
// Split sources
@ -214,16 +226,8 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
return err
}
token := new(jwt.Token)
// Issue #647, #656
if _, ok := config.Claims.(jwt.MapClaims); ok {
token, err = jwt.Parse(auth, config.KeyFunc)
} else {
t := reflect.ValueOf(config.Claims).Type().Elem()
claims := reflect.New(t).Interface().(jwt.Claims)
token, err = jwt.ParseWithClaims(auth, claims, config.KeyFunc)
}
if err == nil && token.Valid {
token, err := config.ParseTokenFunc(auth, c)
if err == nil {
// Store user information from token into context.
c.Set(config.ContextKey, token)
if config.SuccessHandler != nil {
@ -246,6 +250,26 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
}
}
func (config *JWTConfig) defaultParseToken(auth string, c echo.Context) (interface{}, error) {
token := new(jwt.Token)
var err error
// Issue #647, #656
if _, ok := config.Claims.(jwt.MapClaims); ok {
token, err = jwt.Parse(auth, config.KeyFunc)
} else {
t := reflect.ValueOf(config.Claims).Type().Elem()
claims := reflect.New(t).Interface().(jwt.Claims)
token, err = jwt.ParseWithClaims(auth, claims, config.KeyFunc)
}
if err != nil {
return nil, err
}
if !token.Valid {
return nil, errors.New("invalid token")
}
return token, nil
}
// defaultKeyFunc returns a signing key of the given token.
func (config *JWTConfig) defaultKeyFunc(t *jwt.Token) (interface{}, error) {
// Check the signing method

View File

@ -2,6 +2,7 @@ package middleware
import (
"errors"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
@ -404,3 +405,194 @@ func TestJWTwithKID(t *testing.T) {
}
}
}
func TestJWTConfig_skipper(t *testing.T) {
e := echo.New()
e.Use(JWTWithConfig(JWTConfig{
Skipper: func(context echo.Context) bool {
return true // skip everything
},
SigningKey: []byte("secret"),
}))
isCalled := false
e.GET("/", func(c echo.Context) error {
isCalled = true
return c.String(http.StatusTeapot, "test")
})
req := httptest.NewRequest(http.MethodGet, "/", nil)
res := httptest.NewRecorder()
e.ServeHTTP(res, req)
assert.Equal(t, http.StatusTeapot, res.Code)
assert.True(t, isCalled)
}
func TestJWTConfig_BeforeFunc(t *testing.T) {
e := echo.New()
e.GET("/", func(c echo.Context) error {
return c.String(http.StatusTeapot, "test")
})
isCalled := false
e.Use(JWTWithConfig(JWTConfig{
BeforeFunc: func(context echo.Context) {
isCalled = true
},
SigningKey: []byte("secret"),
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ")
res := httptest.NewRecorder()
e.ServeHTTP(res, req)
assert.Equal(t, http.StatusTeapot, res.Code)
assert.True(t, isCalled)
}
func TestJWTConfig_extractorErrorHandling(t *testing.T) {
var testCases = []struct {
name string
given JWTConfig
expectStatusCode int
}{
{
name: "ok, ErrorHandler is executed",
given: JWTConfig{
SigningKey: []byte("secret"),
ErrorHandler: func(err error) error {
return echo.NewHTTPError(http.StatusTeapot, "custom_error")
},
},
expectStatusCode: http.StatusTeapot,
},
{
name: "ok, ErrorHandlerWithContext is executed",
given: JWTConfig{
SigningKey: []byte("secret"),
ErrorHandlerWithContext: func(err error, context echo.Context) error {
return echo.NewHTTPError(http.StatusTeapot, "custom_error")
},
},
expectStatusCode: http.StatusTeapot,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
e.GET("/", func(c echo.Context) error {
return c.String(http.StatusNotImplemented, "should not end up here")
})
e.Use(JWTWithConfig(tc.given))
req := httptest.NewRequest(http.MethodGet, "/", nil)
res := httptest.NewRecorder()
e.ServeHTTP(res, req)
assert.Equal(t, tc.expectStatusCode, res.Code)
})
}
}
func TestJWTConfig_parseTokenErrorHandling(t *testing.T) {
var testCases = []struct {
name string
given JWTConfig
expectErr string
}{
{
name: "ok, ErrorHandler is executed",
given: JWTConfig{
SigningKey: []byte("secret"),
ErrorHandler: func(err error) error {
return echo.NewHTTPError(http.StatusTeapot, "ErrorHandler: "+err.Error())
},
},
expectErr: "{\"message\":\"ErrorHandler: parsing failed\"}\n",
},
{
name: "ok, ErrorHandlerWithContext is executed",
given: JWTConfig{
SigningKey: []byte("secret"),
ErrorHandlerWithContext: func(err error, context echo.Context) error {
return echo.NewHTTPError(http.StatusTeapot, "ErrorHandlerWithContext: "+err.Error())
},
},
expectErr: "{\"message\":\"ErrorHandlerWithContext: parsing failed\"}\n",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
//e.Debug = true
e.GET("/", func(c echo.Context) error {
return c.String(http.StatusNotImplemented, "should not end up here")
})
config := tc.given
parseTokenCalled := false
config.ParseTokenFunc = func(auth string, c echo.Context) (interface{}, error) {
parseTokenCalled = true
return nil, errors.New("parsing failed")
}
e.Use(JWTWithConfig(config))
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ")
res := httptest.NewRecorder()
e.ServeHTTP(res, req)
assert.Equal(t, http.StatusTeapot, res.Code)
assert.Equal(t, tc.expectErr, res.Body.String())
assert.True(t, parseTokenCalled)
})
}
}
func TestJWTConfig_custom_ParseTokenFunc_Keyfunc(t *testing.T) {
e := echo.New()
e.GET("/", func(c echo.Context) error {
return c.String(http.StatusTeapot, "test")
})
// example of minimal custom ParseTokenFunc implementation. Allows you to use different versions of `github.com/dgrijalva/jwt-go`
// with current JWT middleware
signingKey := []byte("secret")
config := JWTConfig{
ParseTokenFunc: func(auth string, c echo.Context) (interface{}, error) {
keyFunc := func(t *jwt.Token) (interface{}, error) {
if t.Method.Alg() != "HS256" {
return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"])
}
return signingKey, nil
}
// claims are of type `jwt.MapClaims` when token is created with `jwt.Parse`
token, err := jwt.Parse(auth, keyFunc)
if err != nil {
return nil, err
}
if !token.Valid {
return nil, errors.New("invalid token")
}
return token, nil
},
}
e.Use(JWTWithConfig(config))
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ")
res := httptest.NewRecorder()
e.ServeHTTP(res, req)
assert.Equal(t, http.StatusTeapot, res.Code)
}