1
0
mirror of https://github.com/labstack/echo.git synced 2024-12-24 20:14:31 +02:00

Added key auth middleware

Signed-off-by: Vishal Rana <vr@labstack.com>
This commit is contained in:
Vishal Rana 2017-01-02 20:12:06 -08:00
parent ab203bf19c
commit 412823eabb
11 changed files with 319 additions and 57 deletions

12
bind.go
View File

@ -38,15 +38,15 @@ func (b *DefaultBinder) Bind(i interface{}, c Context) (err error) {
}
ctype := req.Header.Get(HeaderContentType)
if req.ContentLength == 0 {
return NewHTTPError(http.StatusBadRequest, "request body can't be empty")
return NewHTTPError(http.StatusBadRequest, "Request body can't be empty")
}
switch {
case strings.HasPrefix(ctype, MIMEApplicationJSON):
if err = json.NewDecoder(req.Body).Decode(i); err != nil {
if ute, ok := err.(*json.UnmarshalTypeError); ok {
return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("unmarshal type error: expected=%v, got=%v, offset=%v", ute.Type, ute.Value, ute.Offset))
return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Unmarshal type error: expected=%v, got=%v, offset=%v", ute.Type, ute.Value, ute.Offset))
} else if se, ok := err.(*json.SyntaxError); ok {
return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("syntax error: offset=%v, error=%v", se.Offset, se.Error()))
return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Syntax error: offset=%v, error=%v", se.Offset, se.Error()))
} else {
return NewHTTPError(http.StatusBadRequest, err.Error())
}
@ -54,9 +54,9 @@ func (b *DefaultBinder) Bind(i interface{}, c Context) (err error) {
case strings.HasPrefix(ctype, MIMEApplicationXML):
if err = xml.NewDecoder(req.Body).Decode(i); err != nil {
if ute, ok := err.(*xml.UnsupportedTypeError); ok {
return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("unsupported type error: type=%v, error=%v", ute.Type, ute.Error()))
return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Unsupported type error: type=%v, error=%v", ute.Type, ute.Error()))
} else if se, ok := err.(*xml.SyntaxError); ok {
return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("syntax error: line=%v, error=%v", se.Line, se.Error()))
return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Syntax error: line=%v, error=%v", se.Line, se.Error()))
} else {
return NewHTTPError(http.StatusBadRequest, err.Error())
}
@ -80,7 +80,7 @@ func (b *DefaultBinder) bindData(ptr interface{}, data map[string][]string, tag
val := reflect.ValueOf(ptr).Elem()
if typ.Kind() != reflect.Struct {
return errors.New("binding element must be a struct")
return errors.New("Binding element must be a struct")
}
for i := 0; i < typ.NumField(); i++ {

View File

@ -221,10 +221,10 @@ var (
ErrUnauthorized = NewHTTPError(http.StatusUnauthorized)
ErrMethodNotAllowed = NewHTTPError(http.StatusMethodNotAllowed)
ErrStatusRequestEntityTooLarge = NewHTTPError(http.StatusRequestEntityTooLarge)
ErrValidatorNotRegistered = errors.New("validator not registered")
ErrRendererNotRegistered = errors.New("renderer not registered")
ErrInvalidRedirectCode = errors.New("invalid redirect status code")
ErrCookieNotFound = errors.New("cookie not found")
ErrValidatorNotRegistered = errors.New("Validator not registered")
ErrRendererNotRegistered = errors.New("Renderer not registered")
ErrInvalidRedirectCode = errors.New("Invalid redirect status code")
ErrCookieNotFound = errors.New("Cookie not found")
)
// Error handlers

View File

@ -2,6 +2,7 @@ package middleware
import (
"encoding/base64"
"net/http"
"github.com/labstack/echo"
)
@ -36,7 +37,7 @@ var (
//
// For valid credentials it calls the next handler.
// For invalid credentials, it sends "401 - Unauthorized" response.
// For empty or invalid `Authorization` header, it sends "400 - Bad Request" response.
// For missing or invalid `Authorization` header, it sends "400 - Bad Request" response.
func BasicAuth(fn BasicAuthValidator) echo.MiddlewareFunc {
c := DefaultBasicAuthConfig
c.Validator = fn
@ -48,7 +49,7 @@ func BasicAuth(fn BasicAuthValidator) echo.MiddlewareFunc {
func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc {
// Defaults
if config.Validator == nil {
panic("basic-auth middleware requires validator function")
panic("basic-auth middleware requires a validator function")
}
if config.Skipper == nil {
config.Skipper = DefaultBasicAuthConfig.Skipper
@ -61,6 +62,9 @@ func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc {
}
auth := c.Request().Header.Get(echo.HeaderAuthorization)
if auth == "" {
return echo.NewHTTPError(http.StatusBadRequest, "Missing authorization header")
}
l := len(basic)
if len(auth) > l+1 && auth[:l] == basic {
@ -77,7 +81,10 @@ func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc {
}
}
}
} else {
return echo.NewHTTPError(http.StatusBadRequest, "Invalid authorization header")
}
// Need to return `401` for browsers to pop-up login box.
c.Response().Header().Set(echo.HeaderWWWAuthenticate, basic+" realm=Restricted")
return echo.ErrUnauthorized

View File

@ -30,21 +30,21 @@ func TestBasicAuth(t *testing.T) {
req.Header.Set(echo.HeaderAuthorization, auth)
assert.NoError(t, h(c))
// Incorrect password
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:password"))
// Invalid credentials
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password"))
req.Header.Set(echo.HeaderAuthorization, auth)
he := h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusUnauthorized, he.Code)
assert.Equal(t, basic+" realm=Restricted", res.Header().Get(echo.HeaderWWWAuthenticate))
// Empty Authorization header
req.Header.Set(echo.HeaderAuthorization, "")
// Missing Authorization header
req.Header.Del(echo.HeaderAuthorization)
he = h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusUnauthorized, he.Code)
assert.Equal(t, http.StatusBadRequest, he.Code)
// Invalid Authorization header
auth = base64.StdEncoding.EncodeToString([]byte("invalid"))
req.Header.Set(echo.HeaderAuthorization, auth)
he = h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusUnauthorized, he.Code)
assert.Equal(t, http.StatusBadRequest, he.Code)
}

View File

@ -143,7 +143,7 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
return err
}
if !validateCSRFToken(token, clientToken) {
return echo.NewHTTPError(http.StatusForbidden, "csrf token is invalid")
return echo.NewHTTPError(http.StatusForbidden, "CSRF token is invalid")
}
}
@ -187,7 +187,7 @@ func csrfTokenFromForm(param string) csrfTokenExtractor {
return func(c echo.Context) (string, error) {
token := c.FormValue(param)
if token == "" {
return "", errors.New("empty csrf token in form param")
return "", errors.New("Missing csrf token in form param")
}
return token, nil
}
@ -199,7 +199,7 @@ func csrfTokenFromQuery(param string) csrfTokenExtractor {
return func(c echo.Context) (string, error) {
token := c.QueryParam(param)
if token == "" {
return "", errors.New("empty csrf token in query param")
return "", errors.New("Missing csrf token in query param")
}
return token, nil
}

View File

@ -14,24 +14,20 @@ import (
type (
// JWTConfig defines the config for JWT middleware.
JWTConfig struct {
// AuthScheme to be used in the Authorization header.
// Optional. Default value "Bearer".
AuthScheme string
// Skipper defines a function to skip middleware.
Skipper Skipper
// Signing key to validate token.
// Required.
SigningKey interface{} `json:"signing_key"`
SigningKey interface{}
// Signing method, used to check token signing method.
// Optional. Default value HS256.
SigningMethod string `json:"signing_method"`
SigningMethod string
// Context key to store user information from the token into context.
// Optional. Default value "user".
ContextKey string `json:"context_key"`
ContextKey string
// Claims are extendable claims data defining token content.
// Optional. Default value jwt.MapClaims
@ -44,7 +40,11 @@ type (
// - "header:<name>"
// - "query:<name>"
// - "cookie:<name>"
TokenLookup string `json:"token_lookup"`
TokenLookup string
// AuthScheme to be used in the Authorization header.
// Optional. Default value "Bearer".
AuthScheme string
keyFunc jwt.Keyfunc
}
@ -60,11 +60,11 @@ const (
var (
// DefaultJWTConfig is the default JWT auth middleware config.
DefaultJWTConfig = JWTConfig{
AuthScheme: "Bearer",
Skipper: defaultSkipper,
SigningMethod: AlgorithmHS256,
ContextKey: "user",
TokenLookup: "header:" + echo.HeaderAuthorization,
AuthScheme: "Bearer",
Claims: jwt.MapClaims{},
}
)
@ -73,7 +73,7 @@ var (
//
// For valid token, it sets the user in context and calls next handler.
// For invalid token, it returns "401 - Unauthorized" error.
// For empty token, it returns "400 - Bad Request" error.
// For missing token, it returns "400 - Bad Request" error.
//
// See: https://jwt.io/introduction
// See `JWTConfig.TokenLookup`
@ -87,9 +87,6 @@ func JWT(key []byte) echo.MiddlewareFunc {
// See: `JWT()`.
func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
// Defaults
if config.AuthScheme == "" {
config.AuthScheme = DefaultJWTConfig.AuthScheme
}
if config.Skipper == nil {
config.Skipper = DefaultJWTConfig.Skipper
}
@ -108,6 +105,9 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
if config.TokenLookup == "" {
config.TokenLookup = DefaultJWTConfig.TokenLookup
}
if config.AuthScheme == "" {
config.AuthScheme = DefaultJWTConfig.AuthScheme
}
config.keyFunc = func(t *jwt.Token) (interface{}, error) {
// Check the signing method
if t.Method.Alg() != config.SigningMethod {
@ -154,7 +154,7 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
}
}
// jwtFromHeader returns a `jwtExtractor` that extracts token from request header.
// jwtFromHeader returns a `jwtExtractor` that extracts token from the request header.
func jwtFromHeader(header string, authScheme string) jwtExtractor {
return func(c echo.Context) (string, error) {
auth := c.Request().Header.Get(header)
@ -162,28 +162,27 @@ func jwtFromHeader(header string, authScheme string) jwtExtractor {
if len(auth) > l+1 && auth[:l] == authScheme {
return auth[l+1:], nil
}
return "", errors.New("empty or invalid jwt in request header")
return "", errors.New("Missing or invalid jwt in request header")
}
}
// jwtFromQuery returns a `jwtExtractor` that extracts token from query string.
// jwtFromQuery returns a `jwtExtractor` that extracts token from the query string.
func jwtFromQuery(param string) jwtExtractor {
return func(c echo.Context) (string, error) {
token := c.QueryParam(param)
var err error
if token == "" {
return "", errors.New("empty jwt in query string")
return "", errors.New("Missing jwt in query string")
}
return token, err
return token, nil
}
}
// jwtFromCookie returns a `jwtExtractor` that extracts token from named cookie.
// jwtFromCookie returns a `jwtExtractor` that extracts token from the named cookie.
func jwtFromCookie(name string) jwtExtractor {
return func(c echo.Context) (string, error) {
cookie, err := c.Cookie(name)
if err != nil {
return "", errors.New("empty jwt in cookie")
return "", errors.New("Missing jwt in cookie")
}
return cookie.Value, nil
}

131
middleware/key_auth.go Normal file
View File

@ -0,0 +1,131 @@
package middleware
import (
"errors"
"net/http"
"strings"
"github.com/labstack/echo"
)
type (
// KeyAuthConfig defines the config for KeyAuth middleware.
KeyAuthConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// KeyLookup is a string in the form of "<source>:<name>" that is used
// to extract key from the request.
// Optional. Default value "header:Authorization".
// Possible values:
// - "header:<name>"
// - "query:<name>"
KeyLookup string `json:"key_lookup"`
// AuthScheme to be used in the Authorization header.
// Optional. Default value "Bearer".
AuthScheme string
// Validator is a function to validate key.
// Required.
Validator KeyAuthValidator
}
// KeyAuthValidator defines a function to validate KeyAuth credentials.
KeyAuthValidator func(string) bool
keyExtractor func(echo.Context) (string, error)
)
var (
// DefaultKeyAuthConfig is the default KeyAuth middleware config.
DefaultKeyAuthConfig = KeyAuthConfig{
Skipper: defaultSkipper,
KeyLookup: "header:" + echo.HeaderAuthorization,
AuthScheme: "Bearer",
}
)
// KeyAuth returns an KeyAuth middleware.
//
// For valid key it calls the next handler.
// For invalid key, it sends "401 - Unauthorized" response.
// For missing key, it sends "400 - Bad Request" response.
func KeyAuth(fn KeyAuthValidator) echo.MiddlewareFunc {
c := DefaultKeyAuthConfig
c.Validator = fn
return KeyAuthWithConfig(c)
}
// KeyAuthWithConfig returns an KeyAuth middleware with config.
// See `KeyAuth()`.
func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultKeyAuthConfig.Skipper
}
// Defaults
if config.AuthScheme == "" {
config.AuthScheme = DefaultKeyAuthConfig.AuthScheme
}
if config.KeyLookup == "" {
config.KeyLookup = DefaultKeyAuthConfig.KeyLookup
}
if config.Validator == nil {
panic("key-auth middleware requires a validator function")
}
// Initialize
parts := strings.Split(config.KeyLookup, ":")
extractor := keyFromHeader(parts[1], config.AuthScheme)
switch parts[0] {
case "query":
extractor = keyFromQuery(parts[1])
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
// Extract and verify key
key, err := extractor(c)
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())
}
if config.Validator(key) {
return next(c)
}
return echo.ErrUnauthorized
}
}
}
// keyFromHeader returns a `keyExtractor` that extracts key from the request header.
func keyFromHeader(header string, authScheme string) keyExtractor {
return func(c echo.Context) (string, error) {
auth := c.Request().Header.Get(header)
if header == echo.HeaderAuthorization {
l := len(authScheme)
if len(auth) > l+1 && auth[:l] == authScheme {
return auth[l+1:], nil
}
} else {
return auth, nil
}
return "", errors.New("Missing or invalid key in request header")
}
}
// keyFromQuery returns a `keyExtractor` that extracts key from the query string.
func keyFromQuery(param string) keyExtractor {
return func(c echo.Context) (string, error) {
key := c.QueryParam(param)
if key == "" {
return "", errors.New("Missing key in query string")
}
return key, nil
}
}

View File

@ -0,0 +1,59 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/labstack/echo"
"github.com/stretchr/testify/assert"
)
func TestKeyAuth(t *testing.T) {
e := echo.New()
req, _ := http.NewRequest(echo.GET, "/", nil)
res := httptest.NewRecorder()
c := e.NewContext(req, res)
config := KeyAuthConfig{
Validator: func(key string) bool {
return key == "valid-key"
},
}
h := KeyAuthWithConfig(config)(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})
// Valid key
auth := DefaultKeyAuthConfig.AuthScheme + " " + "valid-key"
req.Header.Set(echo.HeaderAuthorization, auth)
assert.NoError(t, h(c))
// Invalid key
auth = DefaultKeyAuthConfig.AuthScheme + " " + "invalid-key"
req.Header.Set(echo.HeaderAuthorization, auth)
he := h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusUnauthorized, he.Code)
// Missing Authorization header
req.Header.Del(echo.HeaderAuthorization)
he = h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusBadRequest, he.Code)
// Key from custom header
config.KeyLookup = "header:API-Key"
h = KeyAuthWithConfig(config)(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})
req.Header.Set("API-Key", "valid-key")
assert.NoError(t, h(c))
// Key from query string
config.KeyLookup = "query:key"
h = KeyAuthWithConfig(config)(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})
q := req.URL.Query()
q.Add("key", "valid-key")
req.URL.RawQuery = q.Encode()
assert.NoError(t, h(c))
}

View File

@ -11,12 +11,11 @@ Basic auth middleware provides an HTTP basic authentication.
- For valid credentials it calls the next handler.
- For invalid credentials, it sends "401 - Unauthorized" response.
- For empty or invalid `Authorization` header, it sends "400 - Bad Request" response.
- For missing or invalid `Authorization` header, it sends "400 - Bad Request" response.
*Usage*
```go
e := echo.New()
e.Use(middleware.BasicAuth(func(username, password string) bool {
if username == "joe" && password == "secret" {
return true
@ -30,9 +29,7 @@ e.Use(middleware.BasicAuth(func(username, password string) bool {
*Usage*
```go
e := echo.New()
e.Use(middleware.BasicAuthWithConfig(middleware.BasicAuthConfig{},
}))
e.Use(middleware.BasicAuthWithConfig(middleware.BasicAuthConfig{}}))
```
## Configuration

View File

@ -11,7 +11,7 @@ JWT provides a JSON Web Token (JWT) authentication middleware.
- For valid token, it sets the user in context and calls next handler.
- For invalid token, it sends "401 - Unauthorized" response.
- For empty or invalid `Authorization` header, it sends "400 - Bad Request".
- For missing or invalid `Authorization` header, it sends "400 - Bad Request".
*Usage*
@ -22,7 +22,6 @@ JWT provides a JSON Web Token (JWT) authentication middleware.
*Usage*
```go
e := echo.New()
e.Use(middleware.JWTWithConfig(middleware.JWTConfig{
SigningKey: []byte("secret"),
TokenLookup: "query:token",
@ -34,24 +33,20 @@ e.Use(middleware.JWTWithConfig(middleware.JWTConfig{
```go
// JWTConfig defines the config for JWT middleware.
JWTConfig struct {
// AuthScheme to be used in the Authorization header.
// Optional. Default value "Bearer".
AuthScheme string
// Skipper defines a function to skip middleware.
Skipper Skipper
// Signing key to validate token.
// Required.
SigningKey interface{} `json:"signing_key"`
SigningKey interface{}
// Signing method, used to check token signing method.
// Optional. Default value HS256.
SigningMethod string `json:"signing_method"`
SigningMethod string
// Context key to store user information from the token into context.
// Optional. Default value "user".
ContextKey string `json:"context_key"`
ContextKey string
// Claims are extendable claims data defining token content.
// Optional. Default value jwt.MapClaims
@ -64,7 +59,11 @@ JWTConfig struct {
// - "header:<name>"
// - "query:<name>"
// - "cookie:<name>"
TokenLookup string `json:"token_lookup"`
TokenLookup string
// AuthScheme to be used in the Authorization header.
// Optional. Default value "Bearer".
AuthScheme string
}
```
@ -76,6 +75,7 @@ DefaultJWTConfig = JWTConfig{
SigningMethod: AlgorithmHS256,
ContextKey: "user",
TokenLookup: "header:" + echo.HeaderAuthorization,
AuthScheme: "Bearer",
Claims: jwt.MapClaims{},
}
```

View File

@ -0,0 +1,69 @@
+++
title = "Key Auth Middleware"
description = "Key auth middleware for Echo"
[menu.main]
name = "Key Auth"
parent = "middleware"
weight = 5
+++
Key auth middleware provides a key based authentication.
- For valid key it calls the next handler.
- For invalid key, it sends "401 - Unauthorized" response.
- For missing key, it sends "400 - Bad Request" response.
*Usage*
```go
e.Use(middleware.KeyAuth(func(key string) bool {
return key == "valid-key"
}))
```
## Custom Configuration
*Usage*
```go
e := echo.New()
e.Use(middleware.KeyAuthWithConfig(middleware.KeyAuthConfig{
KeyLookup: "query:api-key",
}))
```
## Configuration
```go
// KeyAuthConfig defines the config for KeyAuth middleware.
KeyAuthConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// KeyLookup is a string in the form of "<source>:<name>" that is used
// to extract key from the request.
// Optional. Default value "header:Authorization".
// Possible values:
// - "header:<name>"
// - "query:<name>"
KeyLookup string `json:"key_lookup"`
// AuthScheme to be used in the Authorization header.
// Optional. Default value "Bearer".
AuthScheme string
// Validator is a function to validate key.
// Required.
Validator KeyAuthValidator
}
```
*Default Configuration*
```go
DefaultKeyAuthConfig = KeyAuthConfig{
Skipper: defaultSkipper,
KeyLookup: "header:" + echo.HeaderAuthorization,
AuthScheme: "Bearer",
}
```