mirror of
https://github.com/labstack/echo.git
synced 2024-12-24 20:14:31 +02:00
Added JWT middleware
Signed-off-by: Vishal Rana <vr@labstack.com>
This commit is contained in:
parent
be825e0229
commit
02676bdb44
@ -37,11 +37,11 @@ func New(addr string) *Server {
|
||||
}
|
||||
|
||||
// WithTLS returns `Server` with provided TLS config.
|
||||
func WithTLS(addr, certfile, keyfile string) *Server {
|
||||
func WithTLS(addr, certFile, keyFile string) *Server {
|
||||
c := engine.Config{
|
||||
Address: addr,
|
||||
TLSCertfile: certfile,
|
||||
TLSKeyfile: keyfile,
|
||||
TLSCertfile: certFile,
|
||||
TLSKeyfile: keyFile,
|
||||
}
|
||||
return WithConfig(c)
|
||||
}
|
||||
|
@ -35,11 +35,11 @@ func New(addr string) *Server {
|
||||
}
|
||||
|
||||
// WithTLS returns `Server` instance with provided TLS config.
|
||||
func WithTLS(addr, certfile, keyfile string) *Server {
|
||||
func WithTLS(addr, certFile, keyFile string) *Server {
|
||||
c := engine.Config{
|
||||
Address: addr,
|
||||
TLSCertfile: certfile,
|
||||
TLSKeyfile: keyfile,
|
||||
TLSCertfile: certFile,
|
||||
TLSKeyfile: keyFile,
|
||||
}
|
||||
return WithConfig(c)
|
||||
}
|
||||
|
6
glide.lock
generated
6
glide.lock
generated
@ -1,6 +1,8 @@
|
||||
hash: 44dfc8aaffca5078e71afdb209a0ef0a359a35f69fb98c7b6a2fb87a5a70e757
|
||||
updated: 2016-04-24T10:21:38.007105128-07:00
|
||||
hash: 21820434709470e49c64df0f854d3352088ca664d193e29bc6cd434518c27a7c
|
||||
updated: 2016-04-24T11:03:22.86754619-07:00
|
||||
imports:
|
||||
- name: github.com/dgrijalva/jwt-go
|
||||
version: a2c85815a77d0f951e33ba4db5ae93629a1530af
|
||||
- name: github.com/klauspost/compress
|
||||
version: 14eb9c4951195779ecfbec34431a976de7335b0a
|
||||
subpackages:
|
||||
|
@ -13,3 +13,4 @@ import:
|
||||
- package: github.com/stretchr/testify
|
||||
subpackages:
|
||||
- assert
|
||||
- package: github.com/dgrijalva/jwt-go
|
||||
|
@ -2,7 +2,10 @@ package middleware
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
"github.com/labstack/echo"
|
||||
)
|
||||
|
||||
@ -15,24 +18,62 @@ type (
|
||||
|
||||
// BasicAuthFunc defines a function to validate basic auth credentials.
|
||||
BasicAuthFunc func(string, string) bool
|
||||
|
||||
// JWTAuthConfig defines the config for JWT auth middleware.
|
||||
JWTAuthConfig struct {
|
||||
// SigningKey is the key to validate token.
|
||||
// Required.
|
||||
SigningKey string
|
||||
|
||||
// SigningMethod is used to check token signing method.
|
||||
// Optional, with default value as `HS256`.
|
||||
SigningMethod string
|
||||
|
||||
// ContextKey is the key to be used for storing user information from the
|
||||
// token into context.
|
||||
// Optional, with default value as `user`.
|
||||
ContextKey string
|
||||
|
||||
// Extractor is a function that extracts token from the request
|
||||
// Optional, with default values as `JWTFromHeader`.
|
||||
Extractor JWTExtractor
|
||||
}
|
||||
|
||||
// JWTExtractor defines a function that takes `echo.Context` and returns either
|
||||
// a token or an error.
|
||||
JWTExtractor func(echo.Context) (string, error)
|
||||
)
|
||||
|
||||
const (
|
||||
basic = "Basic"
|
||||
basic = "Basic"
|
||||
bearer = "Bearer"
|
||||
)
|
||||
|
||||
// Algorithims
|
||||
const (
|
||||
AlgorithmHS256 = "HS256"
|
||||
)
|
||||
|
||||
var (
|
||||
// DefaultBasicAuthConfig is the default basic auth middleware config.
|
||||
DefaultBasicAuthConfig = BasicAuthConfig{}
|
||||
|
||||
// DefaultJWTAuthConfig is the default JWT auth middleware config.
|
||||
DefaultJWTAuthConfig = JWTAuthConfig{
|
||||
SigningMethod: AlgorithmHS256,
|
||||
ContextKey: "user",
|
||||
Extractor: JWTFromHeader,
|
||||
}
|
||||
)
|
||||
|
||||
// BasicAuth returns an HTTP basic auth middleware.
|
||||
//
|
||||
// For valid credentials it calls the next handler.
|
||||
// For invalid credentials, it sends "401 - Unauthorized" response.
|
||||
func BasicAuth(f BasicAuthFunc) echo.MiddlewareFunc {
|
||||
// For empty or invalid `Authorization` header, it sends "400 - Bad Request".
|
||||
func BasicAuth(fn BasicAuthFunc) echo.MiddlewareFunc {
|
||||
c := DefaultBasicAuthConfig
|
||||
c.AuthFunc = f
|
||||
c.AuthFunc = fn
|
||||
return BasicAuthWithConfig(c)
|
||||
}
|
||||
|
||||
@ -46,19 +87,94 @@ func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc {
|
||||
|
||||
if len(auth) > l+1 && auth[:l] == basic {
|
||||
b, err := base64.StdEncoding.DecodeString(auth[l+1:])
|
||||
if err == nil {
|
||||
cred := string(b)
|
||||
for i := 0; i < len(cred); i++ {
|
||||
if cred[i] == ':' {
|
||||
// Verify credentials
|
||||
if config.AuthFunc(cred[:i], cred[i+1:]) {
|
||||
return next(c)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cred := string(b)
|
||||
for i := 0; i < len(cred); i++ {
|
||||
if cred[i] == ':' {
|
||||
// Verify credentials
|
||||
if config.AuthFunc(cred[:i], cred[i+1:]) {
|
||||
return next(c)
|
||||
}
|
||||
c.Response().Header().Set(echo.HeaderWWWAuthenticate, basic+" realm=Restricted")
|
||||
return echo.ErrUnauthorized
|
||||
}
|
||||
}
|
||||
}
|
||||
c.Response().Header().Set(echo.HeaderWWWAuthenticate, basic+" realm=Restricted")
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "invalid basic-auth authorization header="+auth)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// JWTFromHeader is a `JWTExtractor` that extracts token from the `Authorization` request
|
||||
// header.
|
||||
func JWTFromHeader(c echo.Context) (string, error) {
|
||||
auth := c.Request().Header().Get(echo.HeaderAuthorization)
|
||||
l := len(bearer)
|
||||
if len(auth) > l+1 && auth[:l] == bearer {
|
||||
return auth[l+1:], nil
|
||||
}
|
||||
return "", echo.NewHTTPError(http.StatusBadRequest, "invalid jwt authorization header="+auth)
|
||||
}
|
||||
|
||||
// JWTFromQuery returns a `JWTExtractor` that extracts token from the provided query
|
||||
// parameter.
|
||||
func JWTFromQuery(param string) JWTExtractor {
|
||||
return func(c echo.Context) (string, error) {
|
||||
return c.QueryParam(param), nil
|
||||
}
|
||||
}
|
||||
|
||||
// JWTAuth returns a JSON Web Token (JWT) auth middleware.
|
||||
//
|
||||
// For valid token, it sets the user in context and calls next handler.
|
||||
// For invalid token, it sends "401 - Unauthorized" response.
|
||||
// For empty or invalid `Authorization` header, it sends "400 - Bad Request".
|
||||
//
|
||||
// See https://jwt.io/introduction
|
||||
func JWTAuth(key string) echo.MiddlewareFunc {
|
||||
c := DefaultJWTAuthConfig
|
||||
c.SigningKey = key
|
||||
return JWTAuthWithConfig(c)
|
||||
}
|
||||
|
||||
// JWTAuthWithConfig returns a JWT auth middleware from config.
|
||||
// See `JWTAuth()`.
|
||||
func JWTAuthWithConfig(config JWTAuthConfig) echo.MiddlewareFunc {
|
||||
// Defaults
|
||||
if config.SigningKey == "" {
|
||||
panic("jwt middleware requires signing key")
|
||||
}
|
||||
if config.SigningMethod == "" {
|
||||
config.SigningMethod = DefaultJWTAuthConfig.SigningMethod
|
||||
}
|
||||
if config.ContextKey == "" {
|
||||
config.ContextKey = DefaultJWTAuthConfig.ContextKey
|
||||
}
|
||||
if config.Extractor == nil {
|
||||
config.Extractor = DefaultJWTAuthConfig.Extractor
|
||||
}
|
||||
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
auth, err := config.Extractor(c)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
token, err := jwt.Parse(auth, func(t *jwt.Token) (interface{}, error) {
|
||||
// Check the signing method
|
||||
if t.Method.Alg() != config.SigningMethod {
|
||||
return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"])
|
||||
}
|
||||
return []byte(config.SigningKey), nil
|
||||
|
||||
})
|
||||
if err == nil && token.Valid {
|
||||
// Store user information from token into context.
|
||||
c.Set(config.ContextKey, token)
|
||||
return next(c)
|
||||
}
|
||||
return echo.ErrUnauthorized
|
||||
}
|
||||
}
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
"github.com/labstack/echo"
|
||||
"github.com/labstack/echo/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@ -30,10 +31,6 @@ func TestBasicAuth(t *testing.T) {
|
||||
req.Header().Set(echo.HeaderAuthorization, auth)
|
||||
assert.NoError(t, h(c))
|
||||
|
||||
//---------------------
|
||||
// Invalid credentials
|
||||
//---------------------
|
||||
|
||||
// Incorrect password
|
||||
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:password"))
|
||||
req.Header().Set(echo.HeaderAuthorization, auth)
|
||||
@ -44,13 +41,56 @@ func TestBasicAuth(t *testing.T) {
|
||||
// Empty Authorization header
|
||||
req.Header().Set(echo.HeaderAuthorization, "")
|
||||
he = h(c).(*echo.HTTPError)
|
||||
assert.Equal(t, http.StatusUnauthorized, he.Code)
|
||||
assert.Equal(t, basic+" realm=Restricted", res.Header().Get(echo.HeaderWWWAuthenticate))
|
||||
assert.Equal(t, http.StatusBadRequest, he.Code)
|
||||
|
||||
// Invalid Authorization header
|
||||
auth = base64.StdEncoding.EncodeToString([]byte("invalid"))
|
||||
req.Header().Set(echo.HeaderAuthorization, auth)
|
||||
he = h(c).(*echo.HTTPError)
|
||||
assert.Equal(t, http.StatusUnauthorized, he.Code)
|
||||
assert.Equal(t, basic+" realm=Restricted", res.Header().Get(echo.HeaderWWWAuthenticate))
|
||||
assert.Equal(t, http.StatusBadRequest, he.Code)
|
||||
}
|
||||
|
||||
func TestJWTAuth(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := test.NewRequest(echo.GET, "/", nil)
|
||||
res := test.NewResponseRecorder()
|
||||
c := e.NewContext(req, res)
|
||||
handler := func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "test")
|
||||
}
|
||||
config := JWTAuthConfig{}
|
||||
token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ"
|
||||
|
||||
// No signing key provided
|
||||
assert.Panics(t, func() {
|
||||
JWTAuthWithConfig(config)
|
||||
})
|
||||
|
||||
// Unexpected signing method
|
||||
config.SigningKey = "secret"
|
||||
config.SigningMethod = "RS256"
|
||||
h := JWTAuthWithConfig(config)(handler)
|
||||
he := h(c).(*echo.HTTPError)
|
||||
assert.Equal(t, http.StatusBadRequest, he.Code)
|
||||
|
||||
// Invalid key
|
||||
auth := bearer + " " + token
|
||||
req.Header().Set(echo.HeaderAuthorization, auth)
|
||||
config.SigningKey = "invalid-key"
|
||||
h = JWTAuthWithConfig(config)(handler)
|
||||
he = h(c).(*echo.HTTPError)
|
||||
assert.Equal(t, http.StatusUnauthorized, he.Code)
|
||||
|
||||
// Valid JWT
|
||||
h = JWTAuth("secret")(handler)
|
||||
if assert.NoError(t, h(c)) {
|
||||
user := c.Get("user").(*jwt.Token)
|
||||
assert.Equal(t, user.Claims["name"], "John Doe")
|
||||
}
|
||||
|
||||
// Invalid Authorization header
|
||||
req.Header().Set(echo.HeaderAuthorization, "invalid-auth")
|
||||
h = JWTAuth("secret")(handler)
|
||||
he = h(c).(*echo.HTTPError)
|
||||
assert.Equal(t, http.StatusBadRequest, he.Code)
|
||||
}
|
||||
|
@ -30,11 +30,11 @@ func New(addr string) *Server {
|
||||
return NewConfig(c)
|
||||
}
|
||||
|
||||
func NewTLS(addr, certfile, keyfile string) *Server {
|
||||
func NewTLS(addr, certFile, keyFile string) *Server {
|
||||
c := &engine.Config{
|
||||
Address: addr,
|
||||
TLSCertfile: certfile,
|
||||
TLSKeyfile: keyfile,
|
||||
TLSCertfile: certFile,
|
||||
TLSKeyfile: keyFile,
|
||||
}
|
||||
return NewConfig(c)
|
||||
}
|
||||
@ -84,10 +84,10 @@ func (s *Server) SetLogger(l *log.Logger) {
|
||||
func (s *Server) Start() {
|
||||
s.Addr = s.config.Address
|
||||
s.Handler = s
|
||||
certfile := s.config.TLSCertfile
|
||||
keyfile := s.config.TLSKeyfile
|
||||
if certfile != "" && keyfile != "" {
|
||||
s.logger.Fatal(s.ListenAndServeTLS(certfile, keyfile))
|
||||
certFile := s.config.TLSCertfile
|
||||
keyFile := s.config.TLSKeyfile
|
||||
if certFile != "" && keyFile != "" {
|
||||
s.logger.Fatal(s.ListenAndServeTLS(certFile, keyFile))
|
||||
} else {
|
||||
s.logger.Fatal(s.ListenAndServe())
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user