From 412823eabb0f1e4c0b93fcf7a26232099469938d Mon Sep 17 00:00:00 2001 From: Vishal Rana Date: Mon, 2 Jan 2017 20:12:06 -0800 Subject: [PATCH] Added key auth middleware Signed-off-by: Vishal Rana --- bind.go | 12 +-- echo.go | 8 +- middleware/basic_auth.go | 11 +- middleware/basic_auth_test.go | 12 +-- middleware/csrf.go | 6 +- middleware/jwt.go | 41 ++++--- middleware/key_auth.go | 131 +++++++++++++++++++++++ middleware/key_auth_test.go | 59 ++++++++++ website/content/middleware/basic-auth.md | 7 +- website/content/middleware/jwt.md | 20 ++-- website/content/middleware/key-auth.md | 69 ++++++++++++ 11 files changed, 319 insertions(+), 57 deletions(-) create mode 100644 middleware/key_auth.go create mode 100644 middleware/key_auth_test.go create mode 100644 website/content/middleware/key-auth.md diff --git a/bind.go b/bind.go index 027d9954..05bbbb5b 100644 --- a/bind.go +++ b/bind.go @@ -38,15 +38,15 @@ func (b *DefaultBinder) Bind(i interface{}, c Context) (err error) { } ctype := req.Header.Get(HeaderContentType) if req.ContentLength == 0 { - return NewHTTPError(http.StatusBadRequest, "request body can't be empty") + return NewHTTPError(http.StatusBadRequest, "Request body can't be empty") } switch { case strings.HasPrefix(ctype, MIMEApplicationJSON): if err = json.NewDecoder(req.Body).Decode(i); err != nil { if ute, ok := err.(*json.UnmarshalTypeError); ok { - return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("unmarshal type error: expected=%v, got=%v, offset=%v", ute.Type, ute.Value, ute.Offset)) + return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Unmarshal type error: expected=%v, got=%v, offset=%v", ute.Type, ute.Value, ute.Offset)) } else if se, ok := err.(*json.SyntaxError); ok { - return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("syntax error: offset=%v, error=%v", se.Offset, se.Error())) + return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Syntax error: offset=%v, error=%v", se.Offset, se.Error())) } else { return NewHTTPError(http.StatusBadRequest, err.Error()) } @@ -54,9 +54,9 @@ func (b *DefaultBinder) Bind(i interface{}, c Context) (err error) { case strings.HasPrefix(ctype, MIMEApplicationXML): if err = xml.NewDecoder(req.Body).Decode(i); err != nil { if ute, ok := err.(*xml.UnsupportedTypeError); ok { - return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("unsupported type error: type=%v, error=%v", ute.Type, ute.Error())) + return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Unsupported type error: type=%v, error=%v", ute.Type, ute.Error())) } else if se, ok := err.(*xml.SyntaxError); ok { - return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("syntax error: line=%v, error=%v", se.Line, se.Error())) + return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Syntax error: line=%v, error=%v", se.Line, se.Error())) } else { return NewHTTPError(http.StatusBadRequest, err.Error()) } @@ -80,7 +80,7 @@ func (b *DefaultBinder) bindData(ptr interface{}, data map[string][]string, tag val := reflect.ValueOf(ptr).Elem() if typ.Kind() != reflect.Struct { - return errors.New("binding element must be a struct") + return errors.New("Binding element must be a struct") } for i := 0; i < typ.NumField(); i++ { diff --git a/echo.go b/echo.go index 92bbf1f5..d8bb2dbf 100644 --- a/echo.go +++ b/echo.go @@ -221,10 +221,10 @@ var ( ErrUnauthorized = NewHTTPError(http.StatusUnauthorized) ErrMethodNotAllowed = NewHTTPError(http.StatusMethodNotAllowed) ErrStatusRequestEntityTooLarge = NewHTTPError(http.StatusRequestEntityTooLarge) - ErrValidatorNotRegistered = errors.New("validator not registered") - ErrRendererNotRegistered = errors.New("renderer not registered") - ErrInvalidRedirectCode = errors.New("invalid redirect status code") - ErrCookieNotFound = errors.New("cookie not found") + ErrValidatorNotRegistered = errors.New("Validator not registered") + ErrRendererNotRegistered = errors.New("Renderer not registered") + ErrInvalidRedirectCode = errors.New("Invalid redirect status code") + ErrCookieNotFound = errors.New("Cookie not found") ) // Error handlers diff --git a/middleware/basic_auth.go b/middleware/basic_auth.go index f24fdc62..59ce4792 100644 --- a/middleware/basic_auth.go +++ b/middleware/basic_auth.go @@ -2,6 +2,7 @@ package middleware import ( "encoding/base64" + "net/http" "github.com/labstack/echo" ) @@ -36,7 +37,7 @@ var ( // // For valid credentials it calls the next handler. // For invalid credentials, it sends "401 - Unauthorized" response. -// For empty or invalid `Authorization` header, it sends "400 - Bad Request" response. +// For missing or invalid `Authorization` header, it sends "400 - Bad Request" response. func BasicAuth(fn BasicAuthValidator) echo.MiddlewareFunc { c := DefaultBasicAuthConfig c.Validator = fn @@ -48,7 +49,7 @@ func BasicAuth(fn BasicAuthValidator) echo.MiddlewareFunc { func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc { // Defaults if config.Validator == nil { - panic("basic-auth middleware requires validator function") + panic("basic-auth middleware requires a validator function") } if config.Skipper == nil { config.Skipper = DefaultBasicAuthConfig.Skipper @@ -61,6 +62,9 @@ func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc { } auth := c.Request().Header.Get(echo.HeaderAuthorization) + if auth == "" { + return echo.NewHTTPError(http.StatusBadRequest, "Missing authorization header") + } l := len(basic) if len(auth) > l+1 && auth[:l] == basic { @@ -77,7 +81,10 @@ func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc { } } } + } else { + return echo.NewHTTPError(http.StatusBadRequest, "Invalid authorization header") } + // Need to return `401` for browsers to pop-up login box. c.Response().Header().Set(echo.HeaderWWWAuthenticate, basic+" realm=Restricted") return echo.ErrUnauthorized diff --git a/middleware/basic_auth_test.go b/middleware/basic_auth_test.go index 85c00db8..a59790ee 100644 --- a/middleware/basic_auth_test.go +++ b/middleware/basic_auth_test.go @@ -30,21 +30,21 @@ func TestBasicAuth(t *testing.T) { req.Header.Set(echo.HeaderAuthorization, auth) assert.NoError(t, h(c)) - // Incorrect password - auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:password")) + // Invalid credentials + auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password")) req.Header.Set(echo.HeaderAuthorization, auth) he := h(c).(*echo.HTTPError) assert.Equal(t, http.StatusUnauthorized, he.Code) assert.Equal(t, basic+" realm=Restricted", res.Header().Get(echo.HeaderWWWAuthenticate)) - // Empty Authorization header - req.Header.Set(echo.HeaderAuthorization, "") + // Missing Authorization header + req.Header.Del(echo.HeaderAuthorization) he = h(c).(*echo.HTTPError) - assert.Equal(t, http.StatusUnauthorized, he.Code) + assert.Equal(t, http.StatusBadRequest, he.Code) // Invalid Authorization header auth = base64.StdEncoding.EncodeToString([]byte("invalid")) req.Header.Set(echo.HeaderAuthorization, auth) he = h(c).(*echo.HTTPError) - assert.Equal(t, http.StatusUnauthorized, he.Code) + assert.Equal(t, http.StatusBadRequest, he.Code) } diff --git a/middleware/csrf.go b/middleware/csrf.go index 6d9b18ed..40989610 100644 --- a/middleware/csrf.go +++ b/middleware/csrf.go @@ -143,7 +143,7 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { return err } if !validateCSRFToken(token, clientToken) { - return echo.NewHTTPError(http.StatusForbidden, "csrf token is invalid") + return echo.NewHTTPError(http.StatusForbidden, "CSRF token is invalid") } } @@ -187,7 +187,7 @@ func csrfTokenFromForm(param string) csrfTokenExtractor { return func(c echo.Context) (string, error) { token := c.FormValue(param) if token == "" { - return "", errors.New("empty csrf token in form param") + return "", errors.New("Missing csrf token in form param") } return token, nil } @@ -199,7 +199,7 @@ func csrfTokenFromQuery(param string) csrfTokenExtractor { return func(c echo.Context) (string, error) { token := c.QueryParam(param) if token == "" { - return "", errors.New("empty csrf token in query param") + return "", errors.New("Missing csrf token in query param") } return token, nil } diff --git a/middleware/jwt.go b/middleware/jwt.go index cb1daa8c..7867dfef 100644 --- a/middleware/jwt.go +++ b/middleware/jwt.go @@ -14,24 +14,20 @@ import ( type ( // JWTConfig defines the config for JWT middleware. JWTConfig struct { - // AuthScheme to be used in the Authorization header. - // Optional. Default value "Bearer". - AuthScheme string - // Skipper defines a function to skip middleware. Skipper Skipper // Signing key to validate token. // Required. - SigningKey interface{} `json:"signing_key"` + SigningKey interface{} // Signing method, used to check token signing method. // Optional. Default value HS256. - SigningMethod string `json:"signing_method"` + SigningMethod string // Context key to store user information from the token into context. // Optional. Default value "user". - ContextKey string `json:"context_key"` + ContextKey string // Claims are extendable claims data defining token content. // Optional. Default value jwt.MapClaims @@ -44,7 +40,11 @@ type ( // - "header:" // - "query:" // - "cookie:" - TokenLookup string `json:"token_lookup"` + TokenLookup string + + // AuthScheme to be used in the Authorization header. + // Optional. Default value "Bearer". + AuthScheme string keyFunc jwt.Keyfunc } @@ -60,11 +60,11 @@ const ( var ( // DefaultJWTConfig is the default JWT auth middleware config. DefaultJWTConfig = JWTConfig{ - AuthScheme: "Bearer", Skipper: defaultSkipper, SigningMethod: AlgorithmHS256, ContextKey: "user", TokenLookup: "header:" + echo.HeaderAuthorization, + AuthScheme: "Bearer", Claims: jwt.MapClaims{}, } ) @@ -73,7 +73,7 @@ var ( // // For valid token, it sets the user in context and calls next handler. // For invalid token, it returns "401 - Unauthorized" error. -// For empty token, it returns "400 - Bad Request" error. +// For missing token, it returns "400 - Bad Request" error. // // See: https://jwt.io/introduction // See `JWTConfig.TokenLookup` @@ -87,9 +87,6 @@ func JWT(key []byte) echo.MiddlewareFunc { // See: `JWT()`. func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { // Defaults - if config.AuthScheme == "" { - config.AuthScheme = DefaultJWTConfig.AuthScheme - } if config.Skipper == nil { config.Skipper = DefaultJWTConfig.Skipper } @@ -108,6 +105,9 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { if config.TokenLookup == "" { config.TokenLookup = DefaultJWTConfig.TokenLookup } + if config.AuthScheme == "" { + config.AuthScheme = DefaultJWTConfig.AuthScheme + } config.keyFunc = func(t *jwt.Token) (interface{}, error) { // Check the signing method if t.Method.Alg() != config.SigningMethod { @@ -154,7 +154,7 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { } } -// jwtFromHeader returns a `jwtExtractor` that extracts token from request header. +// jwtFromHeader returns a `jwtExtractor` that extracts token from the request header. func jwtFromHeader(header string, authScheme string) jwtExtractor { return func(c echo.Context) (string, error) { auth := c.Request().Header.Get(header) @@ -162,28 +162,27 @@ func jwtFromHeader(header string, authScheme string) jwtExtractor { if len(auth) > l+1 && auth[:l] == authScheme { return auth[l+1:], nil } - return "", errors.New("empty or invalid jwt in request header") + return "", errors.New("Missing or invalid jwt in request header") } } -// jwtFromQuery returns a `jwtExtractor` that extracts token from query string. +// jwtFromQuery returns a `jwtExtractor` that extracts token from the query string. func jwtFromQuery(param string) jwtExtractor { return func(c echo.Context) (string, error) { token := c.QueryParam(param) - var err error if token == "" { - return "", errors.New("empty jwt in query string") + return "", errors.New("Missing jwt in query string") } - return token, err + return token, nil } } -// jwtFromCookie returns a `jwtExtractor` that extracts token from named cookie. +// jwtFromCookie returns a `jwtExtractor` that extracts token from the named cookie. func jwtFromCookie(name string) jwtExtractor { return func(c echo.Context) (string, error) { cookie, err := c.Cookie(name) if err != nil { - return "", errors.New("empty jwt in cookie") + return "", errors.New("Missing jwt in cookie") } return cookie.Value, nil } diff --git a/middleware/key_auth.go b/middleware/key_auth.go new file mode 100644 index 00000000..bab987ee --- /dev/null +++ b/middleware/key_auth.go @@ -0,0 +1,131 @@ +package middleware + +import ( + "errors" + "net/http" + "strings" + + "github.com/labstack/echo" +) + +type ( + // KeyAuthConfig defines the config for KeyAuth middleware. + KeyAuthConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // KeyLookup is a string in the form of ":" that is used + // to extract key from the request. + // Optional. Default value "header:Authorization". + // Possible values: + // - "header:" + // - "query:" + KeyLookup string `json:"key_lookup"` + + // AuthScheme to be used in the Authorization header. + // Optional. Default value "Bearer". + AuthScheme string + + // Validator is a function to validate key. + // Required. + Validator KeyAuthValidator + } + + // KeyAuthValidator defines a function to validate KeyAuth credentials. + KeyAuthValidator func(string) bool + + keyExtractor func(echo.Context) (string, error) +) + +var ( + // DefaultKeyAuthConfig is the default KeyAuth middleware config. + DefaultKeyAuthConfig = KeyAuthConfig{ + Skipper: defaultSkipper, + KeyLookup: "header:" + echo.HeaderAuthorization, + AuthScheme: "Bearer", + } +) + +// KeyAuth returns an KeyAuth middleware. +// +// For valid key it calls the next handler. +// For invalid key, it sends "401 - Unauthorized" response. +// For missing key, it sends "400 - Bad Request" response. +func KeyAuth(fn KeyAuthValidator) echo.MiddlewareFunc { + c := DefaultKeyAuthConfig + c.Validator = fn + return KeyAuthWithConfig(c) +} + +// KeyAuthWithConfig returns an KeyAuth middleware with config. +// See `KeyAuth()`. +func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc { + // Defaults + if config.Skipper == nil { + config.Skipper = DefaultKeyAuthConfig.Skipper + } + // Defaults + if config.AuthScheme == "" { + config.AuthScheme = DefaultKeyAuthConfig.AuthScheme + } + if config.KeyLookup == "" { + config.KeyLookup = DefaultKeyAuthConfig.KeyLookup + } + if config.Validator == nil { + panic("key-auth middleware requires a validator function") + } + + // Initialize + parts := strings.Split(config.KeyLookup, ":") + extractor := keyFromHeader(parts[1], config.AuthScheme) + switch parts[0] { + case "query": + extractor = keyFromQuery(parts[1]) + } + + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if config.Skipper(c) { + return next(c) + } + + // Extract and verify key + key, err := extractor(c) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + } + if config.Validator(key) { + return next(c) + } + + return echo.ErrUnauthorized + } + } +} + +// keyFromHeader returns a `keyExtractor` that extracts key from the request header. +func keyFromHeader(header string, authScheme string) keyExtractor { + return func(c echo.Context) (string, error) { + auth := c.Request().Header.Get(header) + if header == echo.HeaderAuthorization { + l := len(authScheme) + if len(auth) > l+1 && auth[:l] == authScheme { + return auth[l+1:], nil + } + } else { + return auth, nil + } + return "", errors.New("Missing or invalid key in request header") + } +} + +// keyFromQuery returns a `keyExtractor` that extracts key from the query string. +func keyFromQuery(param string) keyExtractor { + return func(c echo.Context) (string, error) { + key := c.QueryParam(param) + if key == "" { + return "", errors.New("Missing key in query string") + } + return key, nil + } +} diff --git a/middleware/key_auth_test.go b/middleware/key_auth_test.go new file mode 100644 index 00000000..0f0e9762 --- /dev/null +++ b/middleware/key_auth_test.go @@ -0,0 +1,59 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/labstack/echo" + "github.com/stretchr/testify/assert" +) + +func TestKeyAuth(t *testing.T) { + e := echo.New() + req, _ := http.NewRequest(echo.GET, "/", nil) + res := httptest.NewRecorder() + c := e.NewContext(req, res) + config := KeyAuthConfig{ + Validator: func(key string) bool { + return key == "valid-key" + }, + } + h := KeyAuthWithConfig(config)(func(c echo.Context) error { + return c.String(http.StatusOK, "test") + }) + + // Valid key + auth := DefaultKeyAuthConfig.AuthScheme + " " + "valid-key" + req.Header.Set(echo.HeaderAuthorization, auth) + assert.NoError(t, h(c)) + + // Invalid key + auth = DefaultKeyAuthConfig.AuthScheme + " " + "invalid-key" + req.Header.Set(echo.HeaderAuthorization, auth) + he := h(c).(*echo.HTTPError) + assert.Equal(t, http.StatusUnauthorized, he.Code) + + // Missing Authorization header + req.Header.Del(echo.HeaderAuthorization) + he = h(c).(*echo.HTTPError) + assert.Equal(t, http.StatusBadRequest, he.Code) + + // Key from custom header + config.KeyLookup = "header:API-Key" + h = KeyAuthWithConfig(config)(func(c echo.Context) error { + return c.String(http.StatusOK, "test") + }) + req.Header.Set("API-Key", "valid-key") + assert.NoError(t, h(c)) + + // Key from query string + config.KeyLookup = "query:key" + h = KeyAuthWithConfig(config)(func(c echo.Context) error { + return c.String(http.StatusOK, "test") + }) + q := req.URL.Query() + q.Add("key", "valid-key") + req.URL.RawQuery = q.Encode() + assert.NoError(t, h(c)) +} diff --git a/website/content/middleware/basic-auth.md b/website/content/middleware/basic-auth.md index 8fa6861d..59fb906f 100644 --- a/website/content/middleware/basic-auth.md +++ b/website/content/middleware/basic-auth.md @@ -11,12 +11,11 @@ Basic auth middleware provides an HTTP basic authentication. - For valid credentials it calls the next handler. - For invalid credentials, it sends "401 - Unauthorized" response. -- For empty or invalid `Authorization` header, it sends "400 - Bad Request" response. +- For missing or invalid `Authorization` header, it sends "400 - Bad Request" response. *Usage* ```go -e := echo.New() e.Use(middleware.BasicAuth(func(username, password string) bool { if username == "joe" && password == "secret" { return true @@ -30,9 +29,7 @@ e.Use(middleware.BasicAuth(func(username, password string) bool { *Usage* ```go -e := echo.New() -e.Use(middleware.BasicAuthWithConfig(middleware.BasicAuthConfig{}, -})) +e.Use(middleware.BasicAuthWithConfig(middleware.BasicAuthConfig{}})) ``` ## Configuration diff --git a/website/content/middleware/jwt.md b/website/content/middleware/jwt.md index c234c82a..50a0061e 100644 --- a/website/content/middleware/jwt.md +++ b/website/content/middleware/jwt.md @@ -11,7 +11,7 @@ JWT provides a JSON Web Token (JWT) authentication middleware. - For valid token, it sets the user in context and calls next handler. - For invalid token, it sends "401 - Unauthorized" response. -- For empty or invalid `Authorization` header, it sends "400 - Bad Request". +- For missing or invalid `Authorization` header, it sends "400 - Bad Request". *Usage* @@ -22,7 +22,6 @@ JWT provides a JSON Web Token (JWT) authentication middleware. *Usage* ```go -e := echo.New() e.Use(middleware.JWTWithConfig(middleware.JWTConfig{ SigningKey: []byte("secret"), TokenLookup: "query:token", @@ -34,24 +33,20 @@ e.Use(middleware.JWTWithConfig(middleware.JWTConfig{ ```go // JWTConfig defines the config for JWT middleware. JWTConfig struct { - // AuthScheme to be used in the Authorization header. - // Optional. Default value "Bearer". - AuthScheme string - // Skipper defines a function to skip middleware. Skipper Skipper // Signing key to validate token. // Required. - SigningKey interface{} `json:"signing_key"` + SigningKey interface{} // Signing method, used to check token signing method. // Optional. Default value HS256. - SigningMethod string `json:"signing_method"` + SigningMethod string // Context key to store user information from the token into context. // Optional. Default value "user". - ContextKey string `json:"context_key"` + ContextKey string // Claims are extendable claims data defining token content. // Optional. Default value jwt.MapClaims @@ -64,7 +59,11 @@ JWTConfig struct { // - "header:" // - "query:" // - "cookie:" - TokenLookup string `json:"token_lookup"` + TokenLookup string + + // AuthScheme to be used in the Authorization header. + // Optional. Default value "Bearer". + AuthScheme string } ``` @@ -76,6 +75,7 @@ DefaultJWTConfig = JWTConfig{ SigningMethod: AlgorithmHS256, ContextKey: "user", TokenLookup: "header:" + echo.HeaderAuthorization, + AuthScheme: "Bearer", Claims: jwt.MapClaims{}, } ``` diff --git a/website/content/middleware/key-auth.md b/website/content/middleware/key-auth.md new file mode 100644 index 00000000..84a58a2e --- /dev/null +++ b/website/content/middleware/key-auth.md @@ -0,0 +1,69 @@ ++++ +title = "Key Auth Middleware" +description = "Key auth middleware for Echo" +[menu.main] + name = "Key Auth" + parent = "middleware" + weight = 5 ++++ + +Key auth middleware provides a key based authentication. + +- For valid key it calls the next handler. +- For invalid key, it sends "401 - Unauthorized" response. +- For missing key, it sends "400 - Bad Request" response. + +*Usage* + +```go +e.Use(middleware.KeyAuth(func(key string) bool { + return key == "valid-key" +})) +``` + +## Custom Configuration + +*Usage* + +```go +e := echo.New() +e.Use(middleware.KeyAuthWithConfig(middleware.KeyAuthConfig{ + KeyLookup: "query:api-key", +})) +``` + +## Configuration + +```go +// KeyAuthConfig defines the config for KeyAuth middleware. +KeyAuthConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // KeyLookup is a string in the form of ":" that is used + // to extract key from the request. + // Optional. Default value "header:Authorization". + // Possible values: + // - "header:" + // - "query:" + KeyLookup string `json:"key_lookup"` + + // AuthScheme to be used in the Authorization header. + // Optional. Default value "Bearer". + AuthScheme string + + // Validator is a function to validate key. + // Required. + Validator KeyAuthValidator +} +``` + +*Default Configuration* + +```go +DefaultKeyAuthConfig = KeyAuthConfig{ + Skipper: defaultSkipper, + KeyLookup: "header:" + echo.HeaderAuthorization, + AuthScheme: "Bearer", +} +```