mirror of
https://github.com/labstack/echo.git
synced 2024-12-24 20:14:31 +02:00
Added JWT middleware
Signed-off-by: Vishal Rana <vr@labstack.com>
This commit is contained in:
parent
be825e0229
commit
02676bdb44
@ -37,11 +37,11 @@ func New(addr string) *Server {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// WithTLS returns `Server` with provided TLS config.
|
// WithTLS returns `Server` with provided TLS config.
|
||||||
func WithTLS(addr, certfile, keyfile string) *Server {
|
func WithTLS(addr, certFile, keyFile string) *Server {
|
||||||
c := engine.Config{
|
c := engine.Config{
|
||||||
Address: addr,
|
Address: addr,
|
||||||
TLSCertfile: certfile,
|
TLSCertfile: certFile,
|
||||||
TLSKeyfile: keyfile,
|
TLSKeyfile: keyFile,
|
||||||
}
|
}
|
||||||
return WithConfig(c)
|
return WithConfig(c)
|
||||||
}
|
}
|
||||||
|
@ -35,11 +35,11 @@ func New(addr string) *Server {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// WithTLS returns `Server` instance with provided TLS config.
|
// WithTLS returns `Server` instance with provided TLS config.
|
||||||
func WithTLS(addr, certfile, keyfile string) *Server {
|
func WithTLS(addr, certFile, keyFile string) *Server {
|
||||||
c := engine.Config{
|
c := engine.Config{
|
||||||
Address: addr,
|
Address: addr,
|
||||||
TLSCertfile: certfile,
|
TLSCertfile: certFile,
|
||||||
TLSKeyfile: keyfile,
|
TLSKeyfile: keyFile,
|
||||||
}
|
}
|
||||||
return WithConfig(c)
|
return WithConfig(c)
|
||||||
}
|
}
|
||||||
|
6
glide.lock
generated
6
glide.lock
generated
@ -1,6 +1,8 @@
|
|||||||
hash: 44dfc8aaffca5078e71afdb209a0ef0a359a35f69fb98c7b6a2fb87a5a70e757
|
hash: 21820434709470e49c64df0f854d3352088ca664d193e29bc6cd434518c27a7c
|
||||||
updated: 2016-04-24T10:21:38.007105128-07:00
|
updated: 2016-04-24T11:03:22.86754619-07:00
|
||||||
imports:
|
imports:
|
||||||
|
- name: github.com/dgrijalva/jwt-go
|
||||||
|
version: a2c85815a77d0f951e33ba4db5ae93629a1530af
|
||||||
- name: github.com/klauspost/compress
|
- name: github.com/klauspost/compress
|
||||||
version: 14eb9c4951195779ecfbec34431a976de7335b0a
|
version: 14eb9c4951195779ecfbec34431a976de7335b0a
|
||||||
subpackages:
|
subpackages:
|
||||||
|
@ -13,3 +13,4 @@ import:
|
|||||||
- package: github.com/stretchr/testify
|
- package: github.com/stretchr/testify
|
||||||
subpackages:
|
subpackages:
|
||||||
- assert
|
- assert
|
||||||
|
- package: github.com/dgrijalva/jwt-go
|
||||||
|
@ -2,7 +2,10 @@ package middleware
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/dgrijalva/jwt-go"
|
||||||
"github.com/labstack/echo"
|
"github.com/labstack/echo"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -15,24 +18,62 @@ type (
|
|||||||
|
|
||||||
// BasicAuthFunc defines a function to validate basic auth credentials.
|
// BasicAuthFunc defines a function to validate basic auth credentials.
|
||||||
BasicAuthFunc func(string, string) bool
|
BasicAuthFunc func(string, string) bool
|
||||||
|
|
||||||
|
// JWTAuthConfig defines the config for JWT auth middleware.
|
||||||
|
JWTAuthConfig struct {
|
||||||
|
// SigningKey is the key to validate token.
|
||||||
|
// Required.
|
||||||
|
SigningKey string
|
||||||
|
|
||||||
|
// SigningMethod is used to check token signing method.
|
||||||
|
// Optional, with default value as `HS256`.
|
||||||
|
SigningMethod string
|
||||||
|
|
||||||
|
// ContextKey is the key to be used for storing user information from the
|
||||||
|
// token into context.
|
||||||
|
// Optional, with default value as `user`.
|
||||||
|
ContextKey string
|
||||||
|
|
||||||
|
// Extractor is a function that extracts token from the request
|
||||||
|
// Optional, with default values as `JWTFromHeader`.
|
||||||
|
Extractor JWTExtractor
|
||||||
|
}
|
||||||
|
|
||||||
|
// JWTExtractor defines a function that takes `echo.Context` and returns either
|
||||||
|
// a token or an error.
|
||||||
|
JWTExtractor func(echo.Context) (string, error)
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
basic = "Basic"
|
basic = "Basic"
|
||||||
|
bearer = "Bearer"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Algorithims
|
||||||
|
const (
|
||||||
|
AlgorithmHS256 = "HS256"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// DefaultBasicAuthConfig is the default basic auth middleware config.
|
// DefaultBasicAuthConfig is the default basic auth middleware config.
|
||||||
DefaultBasicAuthConfig = BasicAuthConfig{}
|
DefaultBasicAuthConfig = BasicAuthConfig{}
|
||||||
|
|
||||||
|
// DefaultJWTAuthConfig is the default JWT auth middleware config.
|
||||||
|
DefaultJWTAuthConfig = JWTAuthConfig{
|
||||||
|
SigningMethod: AlgorithmHS256,
|
||||||
|
ContextKey: "user",
|
||||||
|
Extractor: JWTFromHeader,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
// BasicAuth returns an HTTP basic auth middleware.
|
// BasicAuth returns an HTTP basic auth middleware.
|
||||||
//
|
//
|
||||||
// For valid credentials it calls the next handler.
|
// For valid credentials it calls the next handler.
|
||||||
// For invalid credentials, it sends "401 - Unauthorized" response.
|
// For invalid credentials, it sends "401 - Unauthorized" response.
|
||||||
func BasicAuth(f BasicAuthFunc) echo.MiddlewareFunc {
|
// For empty or invalid `Authorization` header, it sends "400 - Bad Request".
|
||||||
|
func BasicAuth(fn BasicAuthFunc) echo.MiddlewareFunc {
|
||||||
c := DefaultBasicAuthConfig
|
c := DefaultBasicAuthConfig
|
||||||
c.AuthFunc = f
|
c.AuthFunc = fn
|
||||||
return BasicAuthWithConfig(c)
|
return BasicAuthWithConfig(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -46,7 +87,9 @@ func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc {
|
|||||||
|
|
||||||
if len(auth) > l+1 && auth[:l] == basic {
|
if len(auth) > l+1 && auth[:l] == basic {
|
||||||
b, err := base64.StdEncoding.DecodeString(auth[l+1:])
|
b, err := base64.StdEncoding.DecodeString(auth[l+1:])
|
||||||
if err == nil {
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
cred := string(b)
|
cred := string(b)
|
||||||
for i := 0; i < len(cred); i++ {
|
for i := 0; i < len(cred); i++ {
|
||||||
if cred[i] == ':' {
|
if cred[i] == ':' {
|
||||||
@ -54,12 +97,85 @@ func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc {
|
|||||||
if config.AuthFunc(cred[:i], cred[i+1:]) {
|
if config.AuthFunc(cred[:i], cred[i+1:]) {
|
||||||
return next(c)
|
return next(c)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
c.Response().Header().Set(echo.HeaderWWWAuthenticate, basic+" realm=Restricted")
|
c.Response().Header().Set(echo.HeaderWWWAuthenticate, basic+" realm=Restricted")
|
||||||
return echo.ErrUnauthorized
|
return echo.ErrUnauthorized
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return echo.NewHTTPError(http.StatusBadRequest, "invalid basic-auth authorization header="+auth)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// JWTFromHeader is a `JWTExtractor` that extracts token from the `Authorization` request
|
||||||
|
// header.
|
||||||
|
func JWTFromHeader(c echo.Context) (string, error) {
|
||||||
|
auth := c.Request().Header().Get(echo.HeaderAuthorization)
|
||||||
|
l := len(bearer)
|
||||||
|
if len(auth) > l+1 && auth[:l] == bearer {
|
||||||
|
return auth[l+1:], nil
|
||||||
|
}
|
||||||
|
return "", echo.NewHTTPError(http.StatusBadRequest, "invalid jwt authorization header="+auth)
|
||||||
|
}
|
||||||
|
|
||||||
|
// JWTFromQuery returns a `JWTExtractor` that extracts token from the provided query
|
||||||
|
// parameter.
|
||||||
|
func JWTFromQuery(param string) JWTExtractor {
|
||||||
|
return func(c echo.Context) (string, error) {
|
||||||
|
return c.QueryParam(param), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// JWTAuth returns a JSON Web Token (JWT) auth middleware.
|
||||||
|
//
|
||||||
|
// For valid token, it sets the user in context and calls next handler.
|
||||||
|
// For invalid token, it sends "401 - Unauthorized" response.
|
||||||
|
// For empty or invalid `Authorization` header, it sends "400 - Bad Request".
|
||||||
|
//
|
||||||
|
// See https://jwt.io/introduction
|
||||||
|
func JWTAuth(key string) echo.MiddlewareFunc {
|
||||||
|
c := DefaultJWTAuthConfig
|
||||||
|
c.SigningKey = key
|
||||||
|
return JWTAuthWithConfig(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// JWTAuthWithConfig returns a JWT auth middleware from config.
|
||||||
|
// See `JWTAuth()`.
|
||||||
|
func JWTAuthWithConfig(config JWTAuthConfig) echo.MiddlewareFunc {
|
||||||
|
// Defaults
|
||||||
|
if config.SigningKey == "" {
|
||||||
|
panic("jwt middleware requires signing key")
|
||||||
|
}
|
||||||
|
if config.SigningMethod == "" {
|
||||||
|
config.SigningMethod = DefaultJWTAuthConfig.SigningMethod
|
||||||
|
}
|
||||||
|
if config.ContextKey == "" {
|
||||||
|
config.ContextKey = DefaultJWTAuthConfig.ContextKey
|
||||||
|
}
|
||||||
|
if config.Extractor == nil {
|
||||||
|
config.Extractor = DefaultJWTAuthConfig.Extractor
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||||
|
return func(c echo.Context) error {
|
||||||
|
auth, err := config.Extractor(c)
|
||||||
|
if err != nil {
|
||||||
|
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
|
||||||
|
}
|
||||||
|
token, err := jwt.Parse(auth, func(t *jwt.Token) (interface{}, error) {
|
||||||
|
// Check the signing method
|
||||||
|
if t.Method.Alg() != config.SigningMethod {
|
||||||
|
return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"])
|
||||||
|
}
|
||||||
|
return []byte(config.SigningKey), nil
|
||||||
|
|
||||||
|
})
|
||||||
|
if err == nil && token.Valid {
|
||||||
|
// Store user information from token into context.
|
||||||
|
c.Set(config.ContextKey, token)
|
||||||
|
return next(c)
|
||||||
|
}
|
||||||
|
return echo.ErrUnauthorized
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/dgrijalva/jwt-go"
|
||||||
"github.com/labstack/echo"
|
"github.com/labstack/echo"
|
||||||
"github.com/labstack/echo/test"
|
"github.com/labstack/echo/test"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@ -30,10 +31,6 @@ func TestBasicAuth(t *testing.T) {
|
|||||||
req.Header().Set(echo.HeaderAuthorization, auth)
|
req.Header().Set(echo.HeaderAuthorization, auth)
|
||||||
assert.NoError(t, h(c))
|
assert.NoError(t, h(c))
|
||||||
|
|
||||||
//---------------------
|
|
||||||
// Invalid credentials
|
|
||||||
//---------------------
|
|
||||||
|
|
||||||
// Incorrect password
|
// Incorrect password
|
||||||
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:password"))
|
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:password"))
|
||||||
req.Header().Set(echo.HeaderAuthorization, auth)
|
req.Header().Set(echo.HeaderAuthorization, auth)
|
||||||
@ -44,13 +41,56 @@ func TestBasicAuth(t *testing.T) {
|
|||||||
// Empty Authorization header
|
// Empty Authorization header
|
||||||
req.Header().Set(echo.HeaderAuthorization, "")
|
req.Header().Set(echo.HeaderAuthorization, "")
|
||||||
he = h(c).(*echo.HTTPError)
|
he = h(c).(*echo.HTTPError)
|
||||||
assert.Equal(t, http.StatusUnauthorized, he.Code)
|
assert.Equal(t, http.StatusBadRequest, he.Code)
|
||||||
assert.Equal(t, basic+" realm=Restricted", res.Header().Get(echo.HeaderWWWAuthenticate))
|
|
||||||
|
|
||||||
// Invalid Authorization header
|
// Invalid Authorization header
|
||||||
auth = base64.StdEncoding.EncodeToString([]byte("invalid"))
|
auth = base64.StdEncoding.EncodeToString([]byte("invalid"))
|
||||||
req.Header().Set(echo.HeaderAuthorization, auth)
|
req.Header().Set(echo.HeaderAuthorization, auth)
|
||||||
he = h(c).(*echo.HTTPError)
|
he = h(c).(*echo.HTTPError)
|
||||||
assert.Equal(t, http.StatusUnauthorized, he.Code)
|
assert.Equal(t, http.StatusBadRequest, he.Code)
|
||||||
assert.Equal(t, basic+" realm=Restricted", res.Header().Get(echo.HeaderWWWAuthenticate))
|
}
|
||||||
|
|
||||||
|
func TestJWTAuth(t *testing.T) {
|
||||||
|
e := echo.New()
|
||||||
|
req := test.NewRequest(echo.GET, "/", nil)
|
||||||
|
res := test.NewResponseRecorder()
|
||||||
|
c := e.NewContext(req, res)
|
||||||
|
handler := func(c echo.Context) error {
|
||||||
|
return c.String(http.StatusOK, "test")
|
||||||
|
}
|
||||||
|
config := JWTAuthConfig{}
|
||||||
|
token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ"
|
||||||
|
|
||||||
|
// No signing key provided
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
JWTAuthWithConfig(config)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Unexpected signing method
|
||||||
|
config.SigningKey = "secret"
|
||||||
|
config.SigningMethod = "RS256"
|
||||||
|
h := JWTAuthWithConfig(config)(handler)
|
||||||
|
he := h(c).(*echo.HTTPError)
|
||||||
|
assert.Equal(t, http.StatusBadRequest, he.Code)
|
||||||
|
|
||||||
|
// Invalid key
|
||||||
|
auth := bearer + " " + token
|
||||||
|
req.Header().Set(echo.HeaderAuthorization, auth)
|
||||||
|
config.SigningKey = "invalid-key"
|
||||||
|
h = JWTAuthWithConfig(config)(handler)
|
||||||
|
he = h(c).(*echo.HTTPError)
|
||||||
|
assert.Equal(t, http.StatusUnauthorized, he.Code)
|
||||||
|
|
||||||
|
// Valid JWT
|
||||||
|
h = JWTAuth("secret")(handler)
|
||||||
|
if assert.NoError(t, h(c)) {
|
||||||
|
user := c.Get("user").(*jwt.Token)
|
||||||
|
assert.Equal(t, user.Claims["name"], "John Doe")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Invalid Authorization header
|
||||||
|
req.Header().Set(echo.HeaderAuthorization, "invalid-auth")
|
||||||
|
h = JWTAuth("secret")(handler)
|
||||||
|
he = h(c).(*echo.HTTPError)
|
||||||
|
assert.Equal(t, http.StatusBadRequest, he.Code)
|
||||||
}
|
}
|
||||||
|
@ -30,11 +30,11 @@ func New(addr string) *Server {
|
|||||||
return NewConfig(c)
|
return NewConfig(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTLS(addr, certfile, keyfile string) *Server {
|
func NewTLS(addr, certFile, keyFile string) *Server {
|
||||||
c := &engine.Config{
|
c := &engine.Config{
|
||||||
Address: addr,
|
Address: addr,
|
||||||
TLSCertfile: certfile,
|
TLSCertfile: certFile,
|
||||||
TLSKeyfile: keyfile,
|
TLSKeyfile: keyFile,
|
||||||
}
|
}
|
||||||
return NewConfig(c)
|
return NewConfig(c)
|
||||||
}
|
}
|
||||||
@ -84,10 +84,10 @@ func (s *Server) SetLogger(l *log.Logger) {
|
|||||||
func (s *Server) Start() {
|
func (s *Server) Start() {
|
||||||
s.Addr = s.config.Address
|
s.Addr = s.config.Address
|
||||||
s.Handler = s
|
s.Handler = s
|
||||||
certfile := s.config.TLSCertfile
|
certFile := s.config.TLSCertfile
|
||||||
keyfile := s.config.TLSKeyfile
|
keyFile := s.config.TLSKeyfile
|
||||||
if certfile != "" && keyfile != "" {
|
if certFile != "" && keyFile != "" {
|
||||||
s.logger.Fatal(s.ListenAndServeTLS(certfile, keyfile))
|
s.logger.Fatal(s.ListenAndServeTLS(certFile, keyFile))
|
||||||
} else {
|
} else {
|
||||||
s.logger.Fatal(s.ListenAndServe())
|
s.logger.Fatal(s.ListenAndServe())
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user