1
0
mirror of https://github.com/labstack/echo.git synced 2025-06-04 23:37:45 +02:00

Add custom jwt extractor to jwt config

This commit is contained in:
Rashad Ansari 2021-08-19 15:00:07 +02:00 committed by Martti T
parent 6b5e62b27e
commit 4fffee2ec8
2 changed files with 52 additions and 21 deletions

View File

@ -68,9 +68,14 @@ type (
// - "form:<name>" // - "form:<name>"
// Multiply sources example: // Multiply sources example:
// - "header: Authorization,cookie: myowncookie" // - "header: Authorization,cookie: myowncookie"
TokenLookup string TokenLookup string
// TokenLookupFuncs defines a list of user-defined functions that extract JWT token from the given context.
// This is one of the two options to provide a token extractor.
// The order of precedence is user-defined TokenLookupFuncs, and TokenLookup.
// You can also provide both if you want.
TokenLookupFuncs []TokenLookupFunc
// AuthScheme to be used in the Authorization header. // AuthScheme to be used in the Authorization header.
// Optional. Default value "Bearer". // Optional. Default value "Bearer".
AuthScheme string AuthScheme string
@ -103,7 +108,8 @@ type (
// JWTErrorHandlerWithContext is almost identical to JWTErrorHandler, but it's passed the current context. // JWTErrorHandlerWithContext is almost identical to JWTErrorHandler, but it's passed the current context.
JWTErrorHandlerWithContext func(error, echo.Context) error JWTErrorHandlerWithContext func(error, echo.Context) error
jwtExtractor func(echo.Context) (string, error) // TokenLookupFunc defines a function for extracting JWT token from the given context.
TokenLookupFunc func(echo.Context) (string, error)
) )
// Algorithms // Algorithms
@ -120,13 +126,14 @@ var (
var ( var (
// DefaultJWTConfig is the default JWT auth middleware config. // DefaultJWTConfig is the default JWT auth middleware config.
DefaultJWTConfig = JWTConfig{ DefaultJWTConfig = JWTConfig{
Skipper: DefaultSkipper, Skipper: DefaultSkipper,
SigningMethod: AlgorithmHS256, SigningMethod: AlgorithmHS256,
ContextKey: "user", ContextKey: "user",
TokenLookup: "header:" + echo.HeaderAuthorization, TokenLookup: "header:" + echo.HeaderAuthorization,
AuthScheme: "Bearer", TokenLookupFuncs: nil,
Claims: jwt.MapClaims{}, AuthScheme: "Bearer",
KeyFunc: nil, Claims: jwt.MapClaims{},
KeyFunc: nil,
} }
) )
@ -163,7 +170,7 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
if config.Claims == nil { if config.Claims == nil {
config.Claims = DefaultJWTConfig.Claims config.Claims = DefaultJWTConfig.Claims
} }
if config.TokenLookup == "" { if config.TokenLookup == "" && len(config.TokenLookupFuncs) == 0 {
config.TokenLookup = DefaultJWTConfig.TokenLookup config.TokenLookup = DefaultJWTConfig.TokenLookup
} }
if config.AuthScheme == "" { if config.AuthScheme == "" {
@ -179,7 +186,7 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
// Initialize // Initialize
// Split sources // Split sources
sources := strings.Split(config.TokenLookup, ",") sources := strings.Split(config.TokenLookup, ",")
var extractors []jwtExtractor var extractors = config.TokenLookupFuncs
for _, source := range sources { for _, source := range sources {
parts := strings.Split(source, ":") parts := strings.Split(source, ":")
@ -290,8 +297,8 @@ func (config *JWTConfig) defaultKeyFunc(t *jwt.Token) (interface{}, error) {
return config.SigningKey, nil return config.SigningKey, nil
} }
// jwtFromHeader returns a `jwtExtractor` that extracts token from the request header. // jwtFromHeader returns a `TokenLookupFunc` that extracts token from the request header.
func jwtFromHeader(header string, authScheme string) jwtExtractor { func jwtFromHeader(header string, authScheme string) TokenLookupFunc {
return func(c echo.Context) (string, error) { return func(c echo.Context) (string, error) {
auth := c.Request().Header.Get(header) auth := c.Request().Header.Get(header)
l := len(authScheme) l := len(authScheme)
@ -302,8 +309,8 @@ func jwtFromHeader(header string, authScheme string) jwtExtractor {
} }
} }
// jwtFromQuery returns a `jwtExtractor` that extracts token from the query string. // jwtFromQuery returns a `TokenLookupFunc` that extracts token from the query string.
func jwtFromQuery(param string) jwtExtractor { func jwtFromQuery(param string) TokenLookupFunc {
return func(c echo.Context) (string, error) { return func(c echo.Context) (string, error) {
token := c.QueryParam(param) token := c.QueryParam(param)
if token == "" { if token == "" {
@ -313,8 +320,8 @@ func jwtFromQuery(param string) jwtExtractor {
} }
} }
// jwtFromParam returns a `jwtExtractor` that extracts token from the url param string. // jwtFromParam returns a `TokenLookupFunc` that extracts token from the url param string.
func jwtFromParam(param string) jwtExtractor { func jwtFromParam(param string) TokenLookupFunc {
return func(c echo.Context) (string, error) { return func(c echo.Context) (string, error) {
token := c.Param(param) token := c.Param(param)
if token == "" { if token == "" {
@ -324,8 +331,8 @@ func jwtFromParam(param string) jwtExtractor {
} }
} }
// jwtFromCookie returns a `jwtExtractor` that extracts token from the named cookie. // jwtFromCookie returns a `TokenLookupFunc` that extracts token from the named cookie.
func jwtFromCookie(name string) jwtExtractor { func jwtFromCookie(name string) TokenLookupFunc {
return func(c echo.Context) (string, error) { return func(c echo.Context) (string, error) {
cookie, err := c.Cookie(name) cookie, err := c.Cookie(name)
if err != nil { if err != nil {
@ -335,8 +342,8 @@ func jwtFromCookie(name string) jwtExtractor {
} }
} }
// jwtFromForm returns a `jwtExtractor` that extracts token from the form field. // jwtFromForm returns a `TokenLookupFunc` that extracts token from the form field.
func jwtFromForm(name string) jwtExtractor { func jwtFromForm(name string) TokenLookupFunc {
return func(c echo.Context) (string, error) { return func(c echo.Context) (string, error) {
field := c.FormValue(name) field := c.FormValue(name)
if field == "" { if field == "" {

View File

@ -603,3 +603,27 @@ func TestJWTConfig_custom_ParseTokenFunc_Keyfunc(t *testing.T) {
assert.Equal(t, http.StatusTeapot, res.Code) assert.Equal(t, http.StatusTeapot, res.Code)
} }
func TestJWTConfig_TokenLookupFuncs(t *testing.T) {
e := echo.New()
e.GET("/", func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})
e.Use(JWTWithConfig(JWTConfig{
TokenLookupFuncs: []TokenLookupFunc{
func(c echo.Context) (string, error) {
return c.Request().Header.Get("X-API-Key"), nil
},
},
SigningKey: []byte("secret"),
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("X-API-Key", "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ")
res := httptest.NewRecorder()
e.ServeHTTP(res, req)
assert.Equal(t, http.StatusOK, res.Code)
}