From 23176c639e6838cd5413cb9415fa8fdb93f546ba Mon Sep 17 00:00:00 2001 From: Vikram Sreekumar Date: Sun, 18 Dec 2016 16:08:46 +0530 Subject: [PATCH] jwt-authscheme: support for custom jwt auth scheme - added "AuthScheme" in the JWTConfig and set default value to "Bearer". - added test case for validating JWT Auth with a custom auth scheme. --- middleware/jwt.go | 16 ++++++++++++---- middleware/jwt_test.go | 5 +++++ 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/middleware/jwt.go b/middleware/jwt.go index 9f0782c5..ff9c79ce 100644 --- a/middleware/jwt.go +++ b/middleware/jwt.go @@ -14,6 +14,10 @@ import ( type ( // JWTConfig defines the config for JWT middleware. JWTConfig struct { + // AuthScheme to define custom bearer variable in the Authorization header. + // Optional. Default value "Bearer" + AuthScheme string + // Skipper defines a function to skip middleware. Skipper Skipper @@ -60,6 +64,7 @@ const ( var ( // DefaultJWTConfig is the default JWT auth middleware config. DefaultJWTConfig = JWTConfig{ + AuthScheme: bearer, Skipper: defaultSkipper, SigningMethod: AlgorithmHS256, ContextKey: "user", @@ -86,6 +91,9 @@ func JWT(key []byte) echo.MiddlewareFunc { // See: `JWT()`. func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { // Defaults + if config.AuthScheme == "" { + config.AuthScheme = DefaultJWTConfig.AuthScheme + } if config.Skipper == nil { config.Skipper = DefaultJWTConfig.Skipper } @@ -114,7 +122,7 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { // Initialize parts := strings.Split(config.TokenLookup, ":") - extractor := jwtFromHeader(parts[1]) + extractor := jwtFromHeader(parts[1], config.AuthScheme) switch parts[0] { case "query": extractor = jwtFromQuery(parts[1]) @@ -151,11 +159,11 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { } // jwtFromHeader returns a `jwtExtractor` that extracts token from request header. -func jwtFromHeader(header string) jwtExtractor { +func jwtFromHeader(header string, authScheme string) jwtExtractor { return func(c echo.Context) (string, error) { auth := c.Request().Header.Get(header) - l := len(bearer) - if len(auth) > l+1 && auth[:l] == bearer { + l := len(authScheme) + if len(auth) > l+1 && auth[:l] == authScheme { return auth[l+1:], nil } return "", errors.New("empty or invalid jwt in request header") diff --git a/middleware/jwt_test.go b/middleware/jwt_test.go index a7241c87..89c5c05c 100644 --- a/middleware/jwt_test.go +++ b/middleware/jwt_test.go @@ -64,6 +64,11 @@ func TestJWT(t *testing.T) { config: JWTConfig{SigningKey: validKey}, info: "Valid JWT", }, + { + hdrAuth: "Token" + " " + token, + config: JWTConfig{AuthScheme: "Token", SigningKey: validKey}, + info: "Valid JWT with custom AuthScheme", + }, { hdrAuth: validAuth, config: JWTConfig{