mirror of
https://github.com/labstack/echo.git
synced 2025-01-01 22:09:21 +02:00
Adds JWTConfig.ParseTokenFunc to JWT middleware to allow different libraries implementing JWT parsing.
This commit is contained in:
parent
fdacff0d93
commit
1ac4a8f3d0
@ -1,6 +1,7 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
@ -49,7 +50,8 @@ type (
|
||||
// Optional. Default value "user".
|
||||
ContextKey string
|
||||
|
||||
// Claims are extendable claims data defining token content.
|
||||
// Claims are extendable claims data defining token content. Used by default ParseTokenFunc implementation.
|
||||
// Not used if custom ParseTokenFunc is set.
|
||||
// Optional. Default value jwt.MapClaims
|
||||
Claims jwt.Claims
|
||||
|
||||
@ -74,13 +76,20 @@ type (
|
||||
// KeyFunc defines a user-defined function that supplies the public key for a token validation.
|
||||
// The function shall take care of verifying the signing algorithm and selecting the proper key.
|
||||
// A user-defined KeyFunc can be useful if tokens are issued by an external party.
|
||||
// Used by default ParseTokenFunc implementation.
|
||||
//
|
||||
// When a user-defined KeyFunc is provided, SigningKey, SigningKeys, and SigningMethod are ignored.
|
||||
// This is one of the three options to provide a token validation key.
|
||||
// The order of precedence is a user-defined KeyFunc, SigningKeys and SigningKey.
|
||||
// Required if neither SigningKeys nor SigningKey is provided.
|
||||
// Not used if custom ParseTokenFunc is set.
|
||||
// Default to an internal implementation verifying the signing algorithm and selecting the proper key.
|
||||
KeyFunc jwt.Keyfunc
|
||||
|
||||
// ParseTokenFunc defines a user-defined function that parses token from given auth. Returns an error when token
|
||||
// parsing fails or parsed token is invalid.
|
||||
// Defaults to implementation using `github.com/dgrijalva/jwt-go` as JWT implementation library
|
||||
ParseTokenFunc func(auth string, c echo.Context) (interface{}, error)
|
||||
}
|
||||
|
||||
// JWTSuccessHandler defines a function which is executed for a valid token.
|
||||
@ -140,7 +149,7 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
|
||||
if config.Skipper == nil {
|
||||
config.Skipper = DefaultJWTConfig.Skipper
|
||||
}
|
||||
if config.SigningKey == nil && len(config.SigningKeys) == 0 && config.KeyFunc == nil {
|
||||
if config.SigningKey == nil && len(config.SigningKeys) == 0 && config.KeyFunc == nil && config.ParseTokenFunc == nil {
|
||||
panic("echo: jwt middleware requires signing key")
|
||||
}
|
||||
if config.SigningMethod == "" {
|
||||
@ -161,6 +170,9 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
|
||||
if config.KeyFunc == nil {
|
||||
config.KeyFunc = config.defaultKeyFunc
|
||||
}
|
||||
if config.ParseTokenFunc == nil {
|
||||
config.ParseTokenFunc = config.defaultParseToken
|
||||
}
|
||||
|
||||
// Initialize
|
||||
// Split sources
|
||||
@ -214,16 +226,8 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
|
||||
return err
|
||||
}
|
||||
|
||||
token := new(jwt.Token)
|
||||
// Issue #647, #656
|
||||
if _, ok := config.Claims.(jwt.MapClaims); ok {
|
||||
token, err = jwt.Parse(auth, config.KeyFunc)
|
||||
} else {
|
||||
t := reflect.ValueOf(config.Claims).Type().Elem()
|
||||
claims := reflect.New(t).Interface().(jwt.Claims)
|
||||
token, err = jwt.ParseWithClaims(auth, claims, config.KeyFunc)
|
||||
}
|
||||
if err == nil && token.Valid {
|
||||
token, err := config.ParseTokenFunc(auth, c)
|
||||
if err == nil {
|
||||
// Store user information from token into context.
|
||||
c.Set(config.ContextKey, token)
|
||||
if config.SuccessHandler != nil {
|
||||
@ -246,6 +250,26 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func (config *JWTConfig) defaultParseToken(auth string, c echo.Context) (interface{}, error) {
|
||||
token := new(jwt.Token)
|
||||
var err error
|
||||
// Issue #647, #656
|
||||
if _, ok := config.Claims.(jwt.MapClaims); ok {
|
||||
token, err = jwt.Parse(auth, config.KeyFunc)
|
||||
} else {
|
||||
t := reflect.ValueOf(config.Claims).Type().Elem()
|
||||
claims := reflect.New(t).Interface().(jwt.Claims)
|
||||
token, err = jwt.ParseWithClaims(auth, claims, config.KeyFunc)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !token.Valid {
|
||||
return nil, errors.New("invalid token")
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// defaultKeyFunc returns a signing key of the given token.
|
||||
func (config *JWTConfig) defaultKeyFunc(t *jwt.Token) (interface{}, error) {
|
||||
// Check the signing method
|
||||
|
@ -2,6 +2,7 @@ package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
@ -404,3 +405,194 @@ func TestJWTwithKID(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWTConfig_skipper(t *testing.T) {
|
||||
e := echo.New()
|
||||
|
||||
e.Use(JWTWithConfig(JWTConfig{
|
||||
Skipper: func(context echo.Context) bool {
|
||||
return true // skip everything
|
||||
},
|
||||
SigningKey: []byte("secret"),
|
||||
}))
|
||||
|
||||
isCalled := false
|
||||
e.GET("/", func(c echo.Context) error {
|
||||
isCalled = true
|
||||
return c.String(http.StatusTeapot, "test")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
res := httptest.NewRecorder()
|
||||
e.ServeHTTP(res, req)
|
||||
|
||||
assert.Equal(t, http.StatusTeapot, res.Code)
|
||||
assert.True(t, isCalled)
|
||||
}
|
||||
|
||||
func TestJWTConfig_BeforeFunc(t *testing.T) {
|
||||
e := echo.New()
|
||||
e.GET("/", func(c echo.Context) error {
|
||||
return c.String(http.StatusTeapot, "test")
|
||||
})
|
||||
|
||||
isCalled := false
|
||||
e.Use(JWTWithConfig(JWTConfig{
|
||||
BeforeFunc: func(context echo.Context) {
|
||||
isCalled = true
|
||||
},
|
||||
SigningKey: []byte("secret"),
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ")
|
||||
res := httptest.NewRecorder()
|
||||
e.ServeHTTP(res, req)
|
||||
|
||||
assert.Equal(t, http.StatusTeapot, res.Code)
|
||||
assert.True(t, isCalled)
|
||||
}
|
||||
|
||||
func TestJWTConfig_extractorErrorHandling(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
given JWTConfig
|
||||
expectStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "ok, ErrorHandler is executed",
|
||||
given: JWTConfig{
|
||||
SigningKey: []byte("secret"),
|
||||
ErrorHandler: func(err error) error {
|
||||
return echo.NewHTTPError(http.StatusTeapot, "custom_error")
|
||||
},
|
||||
},
|
||||
expectStatusCode: http.StatusTeapot,
|
||||
},
|
||||
{
|
||||
name: "ok, ErrorHandlerWithContext is executed",
|
||||
given: JWTConfig{
|
||||
SigningKey: []byte("secret"),
|
||||
ErrorHandlerWithContext: func(err error, context echo.Context) error {
|
||||
return echo.NewHTTPError(http.StatusTeapot, "custom_error")
|
||||
},
|
||||
},
|
||||
expectStatusCode: http.StatusTeapot,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := echo.New()
|
||||
e.GET("/", func(c echo.Context) error {
|
||||
return c.String(http.StatusNotImplemented, "should not end up here")
|
||||
})
|
||||
|
||||
e.Use(JWTWithConfig(tc.given))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
res := httptest.NewRecorder()
|
||||
e.ServeHTTP(res, req)
|
||||
|
||||
assert.Equal(t, tc.expectStatusCode, res.Code)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWTConfig_parseTokenErrorHandling(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
given JWTConfig
|
||||
expectErr string
|
||||
}{
|
||||
{
|
||||
name: "ok, ErrorHandler is executed",
|
||||
given: JWTConfig{
|
||||
SigningKey: []byte("secret"),
|
||||
ErrorHandler: func(err error) error {
|
||||
return echo.NewHTTPError(http.StatusTeapot, "ErrorHandler: "+err.Error())
|
||||
},
|
||||
},
|
||||
expectErr: "{\"message\":\"ErrorHandler: parsing failed\"}\n",
|
||||
},
|
||||
{
|
||||
name: "ok, ErrorHandlerWithContext is executed",
|
||||
given: JWTConfig{
|
||||
SigningKey: []byte("secret"),
|
||||
ErrorHandlerWithContext: func(err error, context echo.Context) error {
|
||||
return echo.NewHTTPError(http.StatusTeapot, "ErrorHandlerWithContext: "+err.Error())
|
||||
},
|
||||
},
|
||||
expectErr: "{\"message\":\"ErrorHandlerWithContext: parsing failed\"}\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := echo.New()
|
||||
//e.Debug = true
|
||||
e.GET("/", func(c echo.Context) error {
|
||||
return c.String(http.StatusNotImplemented, "should not end up here")
|
||||
})
|
||||
|
||||
config := tc.given
|
||||
parseTokenCalled := false
|
||||
config.ParseTokenFunc = func(auth string, c echo.Context) (interface{}, error) {
|
||||
parseTokenCalled = true
|
||||
return nil, errors.New("parsing failed")
|
||||
}
|
||||
e.Use(JWTWithConfig(config))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ")
|
||||
res := httptest.NewRecorder()
|
||||
|
||||
e.ServeHTTP(res, req)
|
||||
|
||||
assert.Equal(t, http.StatusTeapot, res.Code)
|
||||
assert.Equal(t, tc.expectErr, res.Body.String())
|
||||
assert.True(t, parseTokenCalled)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWTConfig_custom_ParseTokenFunc_Keyfunc(t *testing.T) {
|
||||
e := echo.New()
|
||||
e.GET("/", func(c echo.Context) error {
|
||||
return c.String(http.StatusTeapot, "test")
|
||||
})
|
||||
|
||||
// example of minimal custom ParseTokenFunc implementation. Allows you to use different versions of `github.com/dgrijalva/jwt-go`
|
||||
// with current JWT middleware
|
||||
signingKey := []byte("secret")
|
||||
|
||||
config := JWTConfig{
|
||||
ParseTokenFunc: func(auth string, c echo.Context) (interface{}, error) {
|
||||
keyFunc := func(t *jwt.Token) (interface{}, error) {
|
||||
if t.Method.Alg() != "HS256" {
|
||||
return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"])
|
||||
}
|
||||
return signingKey, nil
|
||||
}
|
||||
|
||||
// claims are of type `jwt.MapClaims` when token is created with `jwt.Parse`
|
||||
token, err := jwt.Parse(auth, keyFunc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !token.Valid {
|
||||
return nil, errors.New("invalid token")
|
||||
}
|
||||
return token, nil
|
||||
},
|
||||
}
|
||||
|
||||
e.Use(JWTWithConfig(config))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set(echo.HeaderAuthorization, DefaultJWTConfig.AuthScheme+" eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ")
|
||||
res := httptest.NewRecorder()
|
||||
e.ServeHTTP(res, req)
|
||||
|
||||
assert.Equal(t, http.StatusTeapot, res.Code)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user