1
0
mirror of https://github.com/labstack/echo.git synced 2024-12-24 20:14:31 +02:00

Added JWT middleware

Signed-off-by: Vishal Rana <vr@labstack.com>
This commit is contained in:
Vishal Rana 2016-04-25 10:58:11 -07:00
parent be825e0229
commit 02676bdb44
7 changed files with 194 additions and 35 deletions

View File

@ -37,11 +37,11 @@ func New(addr string) *Server {
} }
// WithTLS returns `Server` with provided TLS config. // WithTLS returns `Server` with provided TLS config.
func WithTLS(addr, certfile, keyfile string) *Server { func WithTLS(addr, certFile, keyFile string) *Server {
c := engine.Config{ c := engine.Config{
Address: addr, Address: addr,
TLSCertfile: certfile, TLSCertfile: certFile,
TLSKeyfile: keyfile, TLSKeyfile: keyFile,
} }
return WithConfig(c) return WithConfig(c)
} }

View File

@ -35,11 +35,11 @@ func New(addr string) *Server {
} }
// WithTLS returns `Server` instance with provided TLS config. // WithTLS returns `Server` instance with provided TLS config.
func WithTLS(addr, certfile, keyfile string) *Server { func WithTLS(addr, certFile, keyFile string) *Server {
c := engine.Config{ c := engine.Config{
Address: addr, Address: addr,
TLSCertfile: certfile, TLSCertfile: certFile,
TLSKeyfile: keyfile, TLSKeyfile: keyFile,
} }
return WithConfig(c) return WithConfig(c)
} }

6
glide.lock generated
View File

@ -1,6 +1,8 @@
hash: 44dfc8aaffca5078e71afdb209a0ef0a359a35f69fb98c7b6a2fb87a5a70e757 hash: 21820434709470e49c64df0f854d3352088ca664d193e29bc6cd434518c27a7c
updated: 2016-04-24T10:21:38.007105128-07:00 updated: 2016-04-24T11:03:22.86754619-07:00
imports: imports:
- name: github.com/dgrijalva/jwt-go
version: a2c85815a77d0f951e33ba4db5ae93629a1530af
- name: github.com/klauspost/compress - name: github.com/klauspost/compress
version: 14eb9c4951195779ecfbec34431a976de7335b0a version: 14eb9c4951195779ecfbec34431a976de7335b0a
subpackages: subpackages:

View File

@ -13,3 +13,4 @@ import:
- package: github.com/stretchr/testify - package: github.com/stretchr/testify
subpackages: subpackages:
- assert - assert
- package: github.com/dgrijalva/jwt-go

View File

@ -2,7 +2,10 @@ package middleware
import ( import (
"encoding/base64" "encoding/base64"
"fmt"
"net/http"
"github.com/dgrijalva/jwt-go"
"github.com/labstack/echo" "github.com/labstack/echo"
) )
@ -15,24 +18,62 @@ type (
// BasicAuthFunc defines a function to validate basic auth credentials. // BasicAuthFunc defines a function to validate basic auth credentials.
BasicAuthFunc func(string, string) bool BasicAuthFunc func(string, string) bool
// JWTAuthConfig defines the config for JWT auth middleware.
JWTAuthConfig struct {
// SigningKey is the key to validate token.
// Required.
SigningKey string
// SigningMethod is used to check token signing method.
// Optional, with default value as `HS256`.
SigningMethod string
// ContextKey is the key to be used for storing user information from the
// token into context.
// Optional, with default value as `user`.
ContextKey string
// Extractor is a function that extracts token from the request
// Optional, with default values as `JWTFromHeader`.
Extractor JWTExtractor
}
// JWTExtractor defines a function that takes `echo.Context` and returns either
// a token or an error.
JWTExtractor func(echo.Context) (string, error)
) )
const ( const (
basic = "Basic" basic = "Basic"
bearer = "Bearer"
)
// Algorithims
const (
AlgorithmHS256 = "HS256"
) )
var ( var (
// DefaultBasicAuthConfig is the default basic auth middleware config. // DefaultBasicAuthConfig is the default basic auth middleware config.
DefaultBasicAuthConfig = BasicAuthConfig{} DefaultBasicAuthConfig = BasicAuthConfig{}
// DefaultJWTAuthConfig is the default JWT auth middleware config.
DefaultJWTAuthConfig = JWTAuthConfig{
SigningMethod: AlgorithmHS256,
ContextKey: "user",
Extractor: JWTFromHeader,
}
) )
// BasicAuth returns an HTTP basic auth middleware. // BasicAuth returns an HTTP basic auth middleware.
// //
// For valid credentials it calls the next handler. // For valid credentials it calls the next handler.
// For invalid credentials, it sends "401 - Unauthorized" response. // For invalid credentials, it sends "401 - Unauthorized" response.
func BasicAuth(f BasicAuthFunc) echo.MiddlewareFunc { // For empty or invalid `Authorization` header, it sends "400 - Bad Request".
func BasicAuth(fn BasicAuthFunc) echo.MiddlewareFunc {
c := DefaultBasicAuthConfig c := DefaultBasicAuthConfig
c.AuthFunc = f c.AuthFunc = fn
return BasicAuthWithConfig(c) return BasicAuthWithConfig(c)
} }
@ -46,7 +87,9 @@ func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc {
if len(auth) > l+1 && auth[:l] == basic { if len(auth) > l+1 && auth[:l] == basic {
b, err := base64.StdEncoding.DecodeString(auth[l+1:]) b, err := base64.StdEncoding.DecodeString(auth[l+1:])
if err == nil { if err != nil {
return err
}
cred := string(b) cred := string(b)
for i := 0; i < len(cred); i++ { for i := 0; i < len(cred); i++ {
if cred[i] == ':' { if cred[i] == ':' {
@ -54,12 +97,85 @@ func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc {
if config.AuthFunc(cred[:i], cred[i+1:]) { if config.AuthFunc(cred[:i], cred[i+1:]) {
return next(c) return next(c)
} }
}
}
}
}
c.Response().Header().Set(echo.HeaderWWWAuthenticate, basic+" realm=Restricted") c.Response().Header().Set(echo.HeaderWWWAuthenticate, basic+" realm=Restricted")
return echo.ErrUnauthorized return echo.ErrUnauthorized
} }
} }
} }
return echo.NewHTTPError(http.StatusBadRequest, "invalid basic-auth authorization header="+auth)
}
}
}
// JWTFromHeader is a `JWTExtractor` that extracts token from the `Authorization` request
// header.
func JWTFromHeader(c echo.Context) (string, error) {
auth := c.Request().Header().Get(echo.HeaderAuthorization)
l := len(bearer)
if len(auth) > l+1 && auth[:l] == bearer {
return auth[l+1:], nil
}
return "", echo.NewHTTPError(http.StatusBadRequest, "invalid jwt authorization header="+auth)
}
// JWTFromQuery returns a `JWTExtractor` that extracts token from the provided query
// parameter.
func JWTFromQuery(param string) JWTExtractor {
return func(c echo.Context) (string, error) {
return c.QueryParam(param), nil
}
}
// JWTAuth returns a JSON Web Token (JWT) auth middleware.
//
// For valid token, it sets the user in context and calls next handler.
// For invalid token, it sends "401 - Unauthorized" response.
// For empty or invalid `Authorization` header, it sends "400 - Bad Request".
//
// See https://jwt.io/introduction
func JWTAuth(key string) echo.MiddlewareFunc {
c := DefaultJWTAuthConfig
c.SigningKey = key
return JWTAuthWithConfig(c)
}
// JWTAuthWithConfig returns a JWT auth middleware from config.
// See `JWTAuth()`.
func JWTAuthWithConfig(config JWTAuthConfig) echo.MiddlewareFunc {
// Defaults
if config.SigningKey == "" {
panic("jwt middleware requires signing key")
}
if config.SigningMethod == "" {
config.SigningMethod = DefaultJWTAuthConfig.SigningMethod
}
if config.ContextKey == "" {
config.ContextKey = DefaultJWTAuthConfig.ContextKey
}
if config.Extractor == nil {
config.Extractor = DefaultJWTAuthConfig.Extractor
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
auth, err := config.Extractor(c)
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
token, err := jwt.Parse(auth, func(t *jwt.Token) (interface{}, error) {
// Check the signing method
if t.Method.Alg() != config.SigningMethod {
return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"])
}
return []byte(config.SigningKey), nil
})
if err == nil && token.Valid {
// Store user information from token into context.
c.Set(config.ContextKey, token)
return next(c)
}
return echo.ErrUnauthorized
}
}
}

View File

@ -5,6 +5,7 @@ import (
"net/http" "net/http"
"testing" "testing"
"github.com/dgrijalva/jwt-go"
"github.com/labstack/echo" "github.com/labstack/echo"
"github.com/labstack/echo/test" "github.com/labstack/echo/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -30,10 +31,6 @@ func TestBasicAuth(t *testing.T) {
req.Header().Set(echo.HeaderAuthorization, auth) req.Header().Set(echo.HeaderAuthorization, auth)
assert.NoError(t, h(c)) assert.NoError(t, h(c))
//---------------------
// Invalid credentials
//---------------------
// Incorrect password // Incorrect password
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:password")) auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:password"))
req.Header().Set(echo.HeaderAuthorization, auth) req.Header().Set(echo.HeaderAuthorization, auth)
@ -44,13 +41,56 @@ func TestBasicAuth(t *testing.T) {
// Empty Authorization header // Empty Authorization header
req.Header().Set(echo.HeaderAuthorization, "") req.Header().Set(echo.HeaderAuthorization, "")
he = h(c).(*echo.HTTPError) he = h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusUnauthorized, he.Code) assert.Equal(t, http.StatusBadRequest, he.Code)
assert.Equal(t, basic+" realm=Restricted", res.Header().Get(echo.HeaderWWWAuthenticate))
// Invalid Authorization header // Invalid Authorization header
auth = base64.StdEncoding.EncodeToString([]byte("invalid")) auth = base64.StdEncoding.EncodeToString([]byte("invalid"))
req.Header().Set(echo.HeaderAuthorization, auth) req.Header().Set(echo.HeaderAuthorization, auth)
he = h(c).(*echo.HTTPError) he = h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusUnauthorized, he.Code) assert.Equal(t, http.StatusBadRequest, he.Code)
assert.Equal(t, basic+" realm=Restricted", res.Header().Get(echo.HeaderWWWAuthenticate)) }
func TestJWTAuth(t *testing.T) {
e := echo.New()
req := test.NewRequest(echo.GET, "/", nil)
res := test.NewResponseRecorder()
c := e.NewContext(req, res)
handler := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
config := JWTAuthConfig{}
token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ"
// No signing key provided
assert.Panics(t, func() {
JWTAuthWithConfig(config)
})
// Unexpected signing method
config.SigningKey = "secret"
config.SigningMethod = "RS256"
h := JWTAuthWithConfig(config)(handler)
he := h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusBadRequest, he.Code)
// Invalid key
auth := bearer + " " + token
req.Header().Set(echo.HeaderAuthorization, auth)
config.SigningKey = "invalid-key"
h = JWTAuthWithConfig(config)(handler)
he = h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusUnauthorized, he.Code)
// Valid JWT
h = JWTAuth("secret")(handler)
if assert.NoError(t, h(c)) {
user := c.Get("user").(*jwt.Token)
assert.Equal(t, user.Claims["name"], "John Doe")
}
// Invalid Authorization header
req.Header().Set(echo.HeaderAuthorization, "invalid-auth")
h = JWTAuth("secret")(handler)
he = h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusBadRequest, he.Code)
} }

View File

@ -30,11 +30,11 @@ func New(addr string) *Server {
return NewConfig(c) return NewConfig(c)
} }
func NewTLS(addr, certfile, keyfile string) *Server { func NewTLS(addr, certFile, keyFile string) *Server {
c := &engine.Config{ c := &engine.Config{
Address: addr, Address: addr,
TLSCertfile: certfile, TLSCertfile: certFile,
TLSKeyfile: keyfile, TLSKeyfile: keyFile,
} }
return NewConfig(c) return NewConfig(c)
} }
@ -84,10 +84,10 @@ func (s *Server) SetLogger(l *log.Logger) {
func (s *Server) Start() { func (s *Server) Start() {
s.Addr = s.config.Address s.Addr = s.config.Address
s.Handler = s s.Handler = s
certfile := s.config.TLSCertfile certFile := s.config.TLSCertfile
keyfile := s.config.TLSKeyfile keyFile := s.config.TLSKeyfile
if certfile != "" && keyfile != "" { if certFile != "" && keyFile != "" {
s.logger.Fatal(s.ListenAndServeTLS(certfile, keyfile)) s.logger.Fatal(s.ListenAndServeTLS(certFile, keyFile))
} else { } else {
s.logger.Fatal(s.ListenAndServe()) s.logger.Fatal(s.ListenAndServe())
} }