From 02676bdb446eac0e6fd775ac9e3c14b00cb62b65 Mon Sep 17 00:00:00 2001 From: Vishal Rana Date: Mon, 25 Apr 2016 10:58:11 -0700 Subject: [PATCH] Added JWT middleware Signed-off-by: Vishal Rana --- engine/fasthttp/server.go | 6 +- engine/standard/server.go | 6 +- glide.lock | 6 +- glide.yaml | 1 + middleware/auth.go | 140 ++++++++++++++++++++++++++++++++++---- middleware/auth_test.go | 56 ++++++++++++--- test/server.go | 14 ++-- 7 files changed, 194 insertions(+), 35 deletions(-) diff --git a/engine/fasthttp/server.go b/engine/fasthttp/server.go index c89bca99..05dd745b 100644 --- a/engine/fasthttp/server.go +++ b/engine/fasthttp/server.go @@ -37,11 +37,11 @@ func New(addr string) *Server { } // WithTLS returns `Server` with provided TLS config. -func WithTLS(addr, certfile, keyfile string) *Server { +func WithTLS(addr, certFile, keyFile string) *Server { c := engine.Config{ Address: addr, - TLSCertfile: certfile, - TLSKeyfile: keyfile, + TLSCertfile: certFile, + TLSKeyfile: keyFile, } return WithConfig(c) } diff --git a/engine/standard/server.go b/engine/standard/server.go index 51ea4992..1c05fd15 100644 --- a/engine/standard/server.go +++ b/engine/standard/server.go @@ -35,11 +35,11 @@ func New(addr string) *Server { } // WithTLS returns `Server` instance with provided TLS config. -func WithTLS(addr, certfile, keyfile string) *Server { +func WithTLS(addr, certFile, keyFile string) *Server { c := engine.Config{ Address: addr, - TLSCertfile: certfile, - TLSKeyfile: keyfile, + TLSCertfile: certFile, + TLSKeyfile: keyFile, } return WithConfig(c) } diff --git a/glide.lock b/glide.lock index f19a9521..a60e5336 100644 --- a/glide.lock +++ b/glide.lock @@ -1,6 +1,8 @@ -hash: 44dfc8aaffca5078e71afdb209a0ef0a359a35f69fb98c7b6a2fb87a5a70e757 -updated: 2016-04-24T10:21:38.007105128-07:00 +hash: 21820434709470e49c64df0f854d3352088ca664d193e29bc6cd434518c27a7c +updated: 2016-04-24T11:03:22.86754619-07:00 imports: +- name: github.com/dgrijalva/jwt-go + version: a2c85815a77d0f951e33ba4db5ae93629a1530af - name: github.com/klauspost/compress version: 14eb9c4951195779ecfbec34431a976de7335b0a subpackages: diff --git a/glide.yaml b/glide.yaml index 1e45996c..39f791ee 100644 --- a/glide.yaml +++ b/glide.yaml @@ -13,3 +13,4 @@ import: - package: github.com/stretchr/testify subpackages: - assert +- package: github.com/dgrijalva/jwt-go diff --git a/middleware/auth.go b/middleware/auth.go index 8800b16e..2a5c6b0b 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -2,7 +2,10 @@ package middleware import ( "encoding/base64" + "fmt" + "net/http" + "github.com/dgrijalva/jwt-go" "github.com/labstack/echo" ) @@ -15,24 +18,62 @@ type ( // BasicAuthFunc defines a function to validate basic auth credentials. BasicAuthFunc func(string, string) bool + + // JWTAuthConfig defines the config for JWT auth middleware. + JWTAuthConfig struct { + // SigningKey is the key to validate token. + // Required. + SigningKey string + + // SigningMethod is used to check token signing method. + // Optional, with default value as `HS256`. + SigningMethod string + + // ContextKey is the key to be used for storing user information from the + // token into context. + // Optional, with default value as `user`. + ContextKey string + + // Extractor is a function that extracts token from the request + // Optional, with default values as `JWTFromHeader`. + Extractor JWTExtractor + } + + // JWTExtractor defines a function that takes `echo.Context` and returns either + // a token or an error. + JWTExtractor func(echo.Context) (string, error) ) const ( - basic = "Basic" + basic = "Basic" + bearer = "Bearer" +) + +// Algorithims +const ( + AlgorithmHS256 = "HS256" ) var ( // DefaultBasicAuthConfig is the default basic auth middleware config. DefaultBasicAuthConfig = BasicAuthConfig{} + + // DefaultJWTAuthConfig is the default JWT auth middleware config. + DefaultJWTAuthConfig = JWTAuthConfig{ + SigningMethod: AlgorithmHS256, + ContextKey: "user", + Extractor: JWTFromHeader, + } ) // BasicAuth returns an HTTP basic auth middleware. // // For valid credentials it calls the next handler. // For invalid credentials, it sends "401 - Unauthorized" response. -func BasicAuth(f BasicAuthFunc) echo.MiddlewareFunc { +// For empty or invalid `Authorization` header, it sends "400 - Bad Request". +func BasicAuth(fn BasicAuthFunc) echo.MiddlewareFunc { c := DefaultBasicAuthConfig - c.AuthFunc = f + c.AuthFunc = fn return BasicAuthWithConfig(c) } @@ -46,19 +87,94 @@ func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc { if len(auth) > l+1 && auth[:l] == basic { b, err := base64.StdEncoding.DecodeString(auth[l+1:]) - if err == nil { - cred := string(b) - for i := 0; i < len(cred); i++ { - if cred[i] == ':' { - // Verify credentials - if config.AuthFunc(cred[:i], cred[i+1:]) { - return next(c) - } + if err != nil { + return err + } + cred := string(b) + for i := 0; i < len(cred); i++ { + if cred[i] == ':' { + // Verify credentials + if config.AuthFunc(cred[:i], cred[i+1:]) { + return next(c) } + c.Response().Header().Set(echo.HeaderWWWAuthenticate, basic+" realm=Restricted") + return echo.ErrUnauthorized } } } - c.Response().Header().Set(echo.HeaderWWWAuthenticate, basic+" realm=Restricted") + return echo.NewHTTPError(http.StatusBadRequest, "invalid basic-auth authorization header="+auth) + } + } +} + +// JWTFromHeader is a `JWTExtractor` that extracts token from the `Authorization` request +// header. +func JWTFromHeader(c echo.Context) (string, error) { + auth := c.Request().Header().Get(echo.HeaderAuthorization) + l := len(bearer) + if len(auth) > l+1 && auth[:l] == bearer { + return auth[l+1:], nil + } + return "", echo.NewHTTPError(http.StatusBadRequest, "invalid jwt authorization header="+auth) +} + +// JWTFromQuery returns a `JWTExtractor` that extracts token from the provided query +// parameter. +func JWTFromQuery(param string) JWTExtractor { + return func(c echo.Context) (string, error) { + return c.QueryParam(param), nil + } +} + +// JWTAuth returns a JSON Web Token (JWT) auth middleware. +// +// For valid token, it sets the user in context and calls next handler. +// For invalid token, it sends "401 - Unauthorized" response. +// For empty or invalid `Authorization` header, it sends "400 - Bad Request". +// +// See https://jwt.io/introduction +func JWTAuth(key string) echo.MiddlewareFunc { + c := DefaultJWTAuthConfig + c.SigningKey = key + return JWTAuthWithConfig(c) +} + +// JWTAuthWithConfig returns a JWT auth middleware from config. +// See `JWTAuth()`. +func JWTAuthWithConfig(config JWTAuthConfig) echo.MiddlewareFunc { + // Defaults + if config.SigningKey == "" { + panic("jwt middleware requires signing key") + } + if config.SigningMethod == "" { + config.SigningMethod = DefaultJWTAuthConfig.SigningMethod + } + if config.ContextKey == "" { + config.ContextKey = DefaultJWTAuthConfig.ContextKey + } + if config.Extractor == nil { + config.Extractor = DefaultJWTAuthConfig.Extractor + } + + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + auth, err := config.Extractor(c) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + token, err := jwt.Parse(auth, func(t *jwt.Token) (interface{}, error) { + // Check the signing method + if t.Method.Alg() != config.SigningMethod { + return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"]) + } + return []byte(config.SigningKey), nil + + }) + if err == nil && token.Valid { + // Store user information from token into context. + c.Set(config.ContextKey, token) + return next(c) + } return echo.ErrUnauthorized } } diff --git a/middleware/auth_test.go b/middleware/auth_test.go index 4cc8793b..aaaa1e24 100644 --- a/middleware/auth_test.go +++ b/middleware/auth_test.go @@ -5,6 +5,7 @@ import ( "net/http" "testing" + "github.com/dgrijalva/jwt-go" "github.com/labstack/echo" "github.com/labstack/echo/test" "github.com/stretchr/testify/assert" @@ -30,10 +31,6 @@ func TestBasicAuth(t *testing.T) { req.Header().Set(echo.HeaderAuthorization, auth) assert.NoError(t, h(c)) - //--------------------- - // Invalid credentials - //--------------------- - // Incorrect password auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:password")) req.Header().Set(echo.HeaderAuthorization, auth) @@ -44,13 +41,56 @@ func TestBasicAuth(t *testing.T) { // Empty Authorization header req.Header().Set(echo.HeaderAuthorization, "") he = h(c).(*echo.HTTPError) - assert.Equal(t, http.StatusUnauthorized, he.Code) - assert.Equal(t, basic+" realm=Restricted", res.Header().Get(echo.HeaderWWWAuthenticate)) + assert.Equal(t, http.StatusBadRequest, he.Code) // Invalid Authorization header auth = base64.StdEncoding.EncodeToString([]byte("invalid")) req.Header().Set(echo.HeaderAuthorization, auth) he = h(c).(*echo.HTTPError) - assert.Equal(t, http.StatusUnauthorized, he.Code) - assert.Equal(t, basic+" realm=Restricted", res.Header().Get(echo.HeaderWWWAuthenticate)) + assert.Equal(t, http.StatusBadRequest, he.Code) +} + +func TestJWTAuth(t *testing.T) { + e := echo.New() + req := test.NewRequest(echo.GET, "/", nil) + res := test.NewResponseRecorder() + c := e.NewContext(req, res) + handler := func(c echo.Context) error { + return c.String(http.StatusOK, "test") + } + config := JWTAuthConfig{} + token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ" + + // No signing key provided + assert.Panics(t, func() { + JWTAuthWithConfig(config) + }) + + // Unexpected signing method + config.SigningKey = "secret" + config.SigningMethod = "RS256" + h := JWTAuthWithConfig(config)(handler) + he := h(c).(*echo.HTTPError) + assert.Equal(t, http.StatusBadRequest, he.Code) + + // Invalid key + auth := bearer + " " + token + req.Header().Set(echo.HeaderAuthorization, auth) + config.SigningKey = "invalid-key" + h = JWTAuthWithConfig(config)(handler) + he = h(c).(*echo.HTTPError) + assert.Equal(t, http.StatusUnauthorized, he.Code) + + // Valid JWT + h = JWTAuth("secret")(handler) + if assert.NoError(t, h(c)) { + user := c.Get("user").(*jwt.Token) + assert.Equal(t, user.Claims["name"], "John Doe") + } + + // Invalid Authorization header + req.Header().Set(echo.HeaderAuthorization, "invalid-auth") + h = JWTAuth("secret")(handler) + he = h(c).(*echo.HTTPError) + assert.Equal(t, http.StatusBadRequest, he.Code) } diff --git a/test/server.go b/test/server.go index 444cb37a..5c286b4d 100644 --- a/test/server.go +++ b/test/server.go @@ -30,11 +30,11 @@ func New(addr string) *Server { return NewConfig(c) } -func NewTLS(addr, certfile, keyfile string) *Server { +func NewTLS(addr, certFile, keyFile string) *Server { c := &engine.Config{ Address: addr, - TLSCertfile: certfile, - TLSKeyfile: keyfile, + TLSCertfile: certFile, + TLSKeyfile: keyFile, } return NewConfig(c) } @@ -84,10 +84,10 @@ func (s *Server) SetLogger(l *log.Logger) { func (s *Server) Start() { s.Addr = s.config.Address s.Handler = s - certfile := s.config.TLSCertfile - keyfile := s.config.TLSKeyfile - if certfile != "" && keyfile != "" { - s.logger.Fatal(s.ListenAndServeTLS(certfile, keyfile)) + certFile := s.config.TLSCertfile + keyFile := s.config.TLSKeyfile + if certFile != "" && keyFile != "" { + s.logger.Fatal(s.ListenAndServeTLS(certFile, keyFile)) } else { s.logger.Fatal(s.ListenAndServe()) }