mirror of
https://github.com/labstack/echo.git
synced 2025-01-26 03:20:08 +02:00
Added method override middleware
Signed-off-by: Vishal Rana <vr@labstack.com>
This commit is contained in:
parent
7f819e4d71
commit
24a19223b0
1
echo.go
1
echo.go
@ -162,6 +162,7 @@ const (
|
|||||||
HeaderUpgrade = "Upgrade"
|
HeaderUpgrade = "Upgrade"
|
||||||
HeaderVary = "Vary"
|
HeaderVary = "Vary"
|
||||||
HeaderWWWAuthenticate = "WWW-Authenticate"
|
HeaderWWWAuthenticate = "WWW-Authenticate"
|
||||||
|
HeaderXHTTPMethodOverride = "X-HTTP-Method-Override"
|
||||||
HeaderXForwardedFor = "X-Forwarded-For"
|
HeaderXForwardedFor = "X-Forwarded-For"
|
||||||
HeaderXRealIP = "X-Real-IP"
|
HeaderXRealIP = "X-Real-IP"
|
||||||
HeaderServer = "Server"
|
HeaderServer = "Server"
|
||||||
|
69
middleware/basic_auth.go
Normal file
69
middleware/basic_auth.go
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/labstack/echo"
|
||||||
|
)
|
||||||
|
|
||||||
|
type (
|
||||||
|
// BasicAuthConfig defines the config for HTTP basic auth middleware.
|
||||||
|
BasicAuthConfig struct {
|
||||||
|
// Validator is the function to validate basic auth credentials.
|
||||||
|
Validator BasicAuthValidator
|
||||||
|
}
|
||||||
|
|
||||||
|
// BasicAuthValidator defines a function to validate basic auth credentials.
|
||||||
|
BasicAuthValidator func(string, string) bool
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
basic = "Basic"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// DefaultBasicAuthConfig is the default basic auth middleware config.
|
||||||
|
DefaultBasicAuthConfig = BasicAuthConfig{}
|
||||||
|
)
|
||||||
|
|
||||||
|
// BasicAuth returns an HTTP basic auth middleware.
|
||||||
|
//
|
||||||
|
// For valid credentials it calls the next handler.
|
||||||
|
// For invalid credentials, it sends "401 - Unauthorized" response.
|
||||||
|
// For empty or invalid `Authorization` header, it sends "400 - Bad Request" response.
|
||||||
|
func BasicAuth(fn BasicAuthValidator) echo.MiddlewareFunc {
|
||||||
|
c := DefaultBasicAuthConfig
|
||||||
|
c.Validator = fn
|
||||||
|
return BasicAuthWithConfig(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BasicAuthWithConfig returns an HTTP basic auth middleware from config.
|
||||||
|
// See `BasicAuth()`.
|
||||||
|
func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc {
|
||||||
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||||
|
return func(c echo.Context) error {
|
||||||
|
auth := c.Request().Header().Get(echo.HeaderAuthorization)
|
||||||
|
l := len(basic)
|
||||||
|
|
||||||
|
if len(auth) > l+1 && auth[:l] == basic {
|
||||||
|
b, err := base64.StdEncoding.DecodeString(auth[l+1:])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
cred := string(b)
|
||||||
|
for i := 0; i < len(cred); i++ {
|
||||||
|
if cred[i] == ':' {
|
||||||
|
// Verify credentials
|
||||||
|
if config.Validator(cred[:i], cred[i+1:]) {
|
||||||
|
return next(c)
|
||||||
|
}
|
||||||
|
c.Response().Header().Set(echo.HeaderWWWAuthenticate, basic+" realm=Restricted")
|
||||||
|
return echo.ErrUnauthorized
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return echo.NewHTTPError(http.StatusBadRequest, "invalid basic-auth authorization header="+auth)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
50
middleware/basic_auth_test.go
Normal file
50
middleware/basic_auth_test.go
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/labstack/echo"
|
||||||
|
"github.com/labstack/echo/test"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBasicAuth(t *testing.T) {
|
||||||
|
e := echo.New()
|
||||||
|
req := test.NewRequest(echo.GET, "/", nil)
|
||||||
|
res := test.NewResponseRecorder()
|
||||||
|
c := e.NewContext(req, res)
|
||||||
|
f := func(u, p string) bool {
|
||||||
|
if u == "joe" && p == "secret" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
h := BasicAuth(f)(func(c echo.Context) error {
|
||||||
|
return c.String(http.StatusOK, "test")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Valid credentials
|
||||||
|
auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
|
||||||
|
req.Header().Set(echo.HeaderAuthorization, auth)
|
||||||
|
assert.NoError(t, h(c))
|
||||||
|
|
||||||
|
// Incorrect password
|
||||||
|
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:password"))
|
||||||
|
req.Header().Set(echo.HeaderAuthorization, auth)
|
||||||
|
he := h(c).(*echo.HTTPError)
|
||||||
|
assert.Equal(t, http.StatusUnauthorized, he.Code)
|
||||||
|
assert.Equal(t, basic+" realm=Restricted", res.Header().Get(echo.HeaderWWWAuthenticate))
|
||||||
|
|
||||||
|
// Empty Authorization header
|
||||||
|
req.Header().Set(echo.HeaderAuthorization, "")
|
||||||
|
he = h(c).(*echo.HTTPError)
|
||||||
|
assert.Equal(t, http.StatusBadRequest, he.Code)
|
||||||
|
|
||||||
|
// Invalid Authorization header
|
||||||
|
auth = base64.StdEncoding.EncodeToString([]byte("invalid"))
|
||||||
|
req.Header().Set(echo.HeaderAuthorization, auth)
|
||||||
|
he = h(c).(*echo.HTTPError)
|
||||||
|
assert.Equal(t, http.StatusBadRequest, he.Code)
|
||||||
|
}
|
@ -1,7 +1,6 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/base64"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
@ -10,15 +9,6 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
// BasicAuthConfig defines the config for HTTP basic auth middleware.
|
|
||||||
BasicAuthConfig struct {
|
|
||||||
// AuthFunc is the function to validate basic auth credentials.
|
|
||||||
AuthFunc BasicAuthFunc
|
|
||||||
}
|
|
||||||
|
|
||||||
// BasicAuthFunc defines a function to validate basic auth credentials.
|
|
||||||
BasicAuthFunc func(string, string) bool
|
|
||||||
|
|
||||||
// JWTAuthConfig defines the config for JWT auth middleware.
|
// JWTAuthConfig defines the config for JWT auth middleware.
|
||||||
JWTAuthConfig struct {
|
JWTAuthConfig struct {
|
||||||
// SigningKey is the key to validate token.
|
// SigningKey is the key to validate token.
|
||||||
@ -45,7 +35,6 @@ type (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
basic = "Basic"
|
|
||||||
bearer = "Bearer"
|
bearer = "Bearer"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -55,9 +44,6 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// DefaultBasicAuthConfig is the default basic auth middleware config.
|
|
||||||
DefaultBasicAuthConfig = BasicAuthConfig{}
|
|
||||||
|
|
||||||
// DefaultJWTAuthConfig is the default JWT auth middleware config.
|
// DefaultJWTAuthConfig is the default JWT auth middleware config.
|
||||||
DefaultJWTAuthConfig = JWTAuthConfig{
|
DefaultJWTAuthConfig = JWTAuthConfig{
|
||||||
SigningMethod: AlgorithmHS256,
|
SigningMethod: AlgorithmHS256,
|
||||||
@ -66,47 +52,6 @@ var (
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
// BasicAuth returns an HTTP basic auth middleware.
|
|
||||||
//
|
|
||||||
// For valid credentials it calls the next handler.
|
|
||||||
// For invalid credentials, it sends "401 - Unauthorized" response.
|
|
||||||
// For empty or invalid `Authorization` header, it sends "400 - Bad Request" response.
|
|
||||||
func BasicAuth(fn BasicAuthFunc) echo.MiddlewareFunc {
|
|
||||||
c := DefaultBasicAuthConfig
|
|
||||||
c.AuthFunc = fn
|
|
||||||
return BasicAuthWithConfig(c)
|
|
||||||
}
|
|
||||||
|
|
||||||
// BasicAuthWithConfig returns an HTTP basic auth middleware from config.
|
|
||||||
// See `BasicAuth()`.
|
|
||||||
func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc {
|
|
||||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
|
||||||
return func(c echo.Context) error {
|
|
||||||
auth := c.Request().Header().Get(echo.HeaderAuthorization)
|
|
||||||
l := len(basic)
|
|
||||||
|
|
||||||
if len(auth) > l+1 && auth[:l] == basic {
|
|
||||||
b, err := base64.StdEncoding.DecodeString(auth[l+1:])
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
cred := string(b)
|
|
||||||
for i := 0; i < len(cred); i++ {
|
|
||||||
if cred[i] == ':' {
|
|
||||||
// Verify credentials
|
|
||||||
if config.AuthFunc(cred[:i], cred[i+1:]) {
|
|
||||||
return next(c)
|
|
||||||
}
|
|
||||||
c.Response().Header().Set(echo.HeaderWWWAuthenticate, basic+" realm=Restricted")
|
|
||||||
return echo.ErrUnauthorized
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return echo.NewHTTPError(http.StatusBadRequest, "invalid basic-auth authorization header="+auth)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// JWTAuth returns a JSON Web Token (JWT) auth middleware.
|
// JWTAuth returns a JSON Web Token (JWT) auth middleware.
|
||||||
//
|
//
|
||||||
// For valid token, it sets the user in context and calls next handler.
|
// For valid token, it sets the user in context and calls next handler.
|
@ -1,7 +1,6 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/base64"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@ -11,45 +10,6 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestBasicAuth(t *testing.T) {
|
|
||||||
e := echo.New()
|
|
||||||
req := test.NewRequest(echo.GET, "/", nil)
|
|
||||||
res := test.NewResponseRecorder()
|
|
||||||
c := e.NewContext(req, res)
|
|
||||||
f := func(u, p string) bool {
|
|
||||||
if u == "joe" && p == "secret" {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
h := BasicAuth(f)(func(c echo.Context) error {
|
|
||||||
return c.String(http.StatusOK, "test")
|
|
||||||
})
|
|
||||||
|
|
||||||
// Valid credentials
|
|
||||||
auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
|
|
||||||
req.Header().Set(echo.HeaderAuthorization, auth)
|
|
||||||
assert.NoError(t, h(c))
|
|
||||||
|
|
||||||
// Incorrect password
|
|
||||||
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:password"))
|
|
||||||
req.Header().Set(echo.HeaderAuthorization, auth)
|
|
||||||
he := h(c).(*echo.HTTPError)
|
|
||||||
assert.Equal(t, http.StatusUnauthorized, he.Code)
|
|
||||||
assert.Equal(t, basic+" realm=Restricted", res.Header().Get(echo.HeaderWWWAuthenticate))
|
|
||||||
|
|
||||||
// Empty Authorization header
|
|
||||||
req.Header().Set(echo.HeaderAuthorization, "")
|
|
||||||
he = h(c).(*echo.HTTPError)
|
|
||||||
assert.Equal(t, http.StatusBadRequest, he.Code)
|
|
||||||
|
|
||||||
// Invalid Authorization header
|
|
||||||
auth = base64.StdEncoding.EncodeToString([]byte("invalid"))
|
|
||||||
req.Header().Set(echo.HeaderAuthorization, auth)
|
|
||||||
he = h(c).(*echo.HTTPError)
|
|
||||||
assert.Equal(t, http.StatusBadRequest, he.Code)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestJWTAuth(t *testing.T) {
|
func TestJWTAuth(t *testing.T) {
|
||||||
e := echo.New()
|
e := echo.New()
|
||||||
req := test.NewRequest(echo.GET, "/", nil)
|
req := test.NewRequest(echo.GET, "/", nil)
|
@ -4,36 +4,70 @@ import (
|
|||||||
"github.com/labstack/echo"
|
"github.com/labstack/echo"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
type (
|
||||||
HttpMethodOverrideHeader = "X-HTTP-Method-Override"
|
// MethodOverrideConfig defines the config for method override middleware.
|
||||||
|
MethodOverrideConfig struct {
|
||||||
|
Getter MethodOverrideGetter
|
||||||
|
}
|
||||||
|
|
||||||
|
// MethodOverrideGetter is a function that gets overridden method from the request
|
||||||
|
// Optional, with default values as `MethodFromHeader(echo.HeaderXHTTPMethodOverride)`.
|
||||||
|
MethodOverrideGetter func(echo.Context) string
|
||||||
)
|
)
|
||||||
|
|
||||||
func OverrideMethod() echo.MiddlewareFunc {
|
var (
|
||||||
return Override()
|
// DefaultMethodOverrideConfig is the default method override middleware config.
|
||||||
|
DefaultMethodOverrideConfig = MethodOverrideConfig{
|
||||||
|
Getter: MethodFromHeader(echo.HeaderXHTTPMethodOverride),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
// MethodOverride returns a method override middleware.
|
||||||
|
// Method override middleware checks for the overriden method from the request and
|
||||||
|
// uses it instead of the original method.
|
||||||
|
//
|
||||||
|
// For security reasons, only `POST` method can be overriden.
|
||||||
|
func MethodOverride() echo.MiddlewareFunc {
|
||||||
|
return MethodOverrideWithConfig(DefaultMethodOverrideConfig)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Override checks for the X-HTTP-Method-Override header
|
// MethodOverrideWithConfig returns a method override middleware from config.
|
||||||
// or the body for parameter, `_method`
|
// See `MethodOverride()`.
|
||||||
// and uses the http method instead of Request.Method.
|
func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc {
|
||||||
// It isn't secure to override e.g a GET to a POST,
|
|
||||||
// so only Request.Method which are POSTs are considered.
|
|
||||||
func Override() echo.MiddlewareFunc {
|
|
||||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||||
return func(c echo.Context) error {
|
return func(c echo.Context) error {
|
||||||
originalMethod := c.Request().Method()
|
req := c.Request()
|
||||||
|
if req.Method() == echo.POST {
|
||||||
if originalMethod == "POST" {
|
m := config.Getter(c)
|
||||||
m := c.FormValue("_method")
|
|
||||||
if m != "" {
|
|
||||||
c.Request().SetMethod(m)
|
|
||||||
}
|
|
||||||
m = c.Request().Header().Get(HttpMethodOverrideHeader)
|
|
||||||
if m != "" {
|
if m != "" {
|
||||||
c.Request().SetMethod(m)
|
c.Request().SetMethod(m)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return next(c)
|
return next(c)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MethodFromHeader is a `MethodOverrideGetter` that gets overridden method from
|
||||||
|
// the request header.
|
||||||
|
func MethodFromHeader(header string) MethodOverrideGetter {
|
||||||
|
return func(c echo.Context) string {
|
||||||
|
return c.Request().Header().Get(header)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MethodFromForm is a `MethodOverrideGetter` that gets overridden method from the
|
||||||
|
// form parameter.
|
||||||
|
func MethodFromForm(param string) MethodOverrideGetter {
|
||||||
|
return func(c echo.Context) string {
|
||||||
|
return c.FormValue(param)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MethodFromQuery is a `MethodOverrideGetter` that gets overridden method from
|
||||||
|
// the query parameter.
|
||||||
|
func MethodFromQuery(param string) MethodOverrideGetter {
|
||||||
|
return func(c echo.Context) string {
|
||||||
|
return c.QueryParam(param)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -10,34 +10,40 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestOverrideMtrhod(t *testing.T) {
|
func TestMethodOverride(t *testing.T) {
|
||||||
e := echo.New()
|
e := echo.New()
|
||||||
methodOverride := OverrideMethod()
|
m := MethodOverride()
|
||||||
h := methodOverride(func(c echo.Context) error {
|
h := func(c echo.Context) error {
|
||||||
return c.String(http.StatusOK, c.Request().Method())
|
return c.String(http.StatusOK, "Okay")
|
||||||
})
|
}
|
||||||
|
|
||||||
// Override with http header
|
// Override with http header
|
||||||
rq := test.NewRequest(echo.POST, "/", nil)
|
req := test.NewRequest(echo.POST, "/", nil)
|
||||||
rq.Header().Set(HttpMethodOverrideHeader, "DELETE")
|
rec := test.NewResponseRecorder()
|
||||||
rc := test.NewResponseRecorder()
|
req.Header().Set(echo.HeaderXHTTPMethodOverride, echo.DELETE)
|
||||||
c := e.NewContext(rq, rc)
|
c := e.NewContext(req, rec)
|
||||||
h(c)
|
m(h)(c)
|
||||||
assert.Equal(t, "DELETE", rc.Body.String())
|
assert.Equal(t, echo.DELETE, req.Method())
|
||||||
|
|
||||||
// Override with body parameter
|
// Override with form parameter
|
||||||
rq = test.NewRequest(echo.POST, "/", bytes.NewReader([]byte("_method=DELETE")))
|
m = MethodOverrideWithConfig(MethodOverrideConfig{Getter: MethodFromForm("_method")})
|
||||||
rq.Header().Set(echo.HeaderContentType, echo.MIMEApplicationForm)
|
req = test.NewRequest(echo.POST, "/", bytes.NewReader([]byte("_method="+echo.DELETE)))
|
||||||
rc = test.NewResponseRecorder()
|
rec = test.NewResponseRecorder()
|
||||||
c = e.NewContext(rq, rc)
|
req.Header().Set(echo.HeaderContentType, echo.MIMEApplicationForm)
|
||||||
h(c)
|
c = e.NewContext(req, rec)
|
||||||
assert.Equal(t, "DELETE", rc.Body.String())
|
m(h)(c)
|
||||||
|
assert.Equal(t, echo.DELETE, req.Method())
|
||||||
|
|
||||||
// Ignore GET
|
// Override with query paramter
|
||||||
rq = test.NewRequest(echo.GET, "/", nil)
|
m = MethodOverrideWithConfig(MethodOverrideConfig{Getter: MethodFromQuery("_method")})
|
||||||
rq.Header().Set(HttpMethodOverrideHeader, "DELETE")
|
req = test.NewRequest(echo.POST, "/?_method="+echo.DELETE, nil)
|
||||||
rc = test.NewResponseRecorder()
|
rec = test.NewResponseRecorder()
|
||||||
c = e.NewContext(rq, rc)
|
c = e.NewContext(req, rec)
|
||||||
h(c)
|
m(h)(c)
|
||||||
assert.Equal(t, "GET", rc.Body.String())
|
assert.Equal(t, echo.DELETE, req.Method())
|
||||||
|
|
||||||
|
// Ignore `GET`
|
||||||
|
req = test.NewRequest(echo.GET, "/", nil)
|
||||||
|
req.Header().Set(echo.HeaderXHTTPMethodOverride, echo.DELETE)
|
||||||
|
assert.Equal(t, echo.GET, req.Method())
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user