mirror of
https://github.com/labstack/echo.git
synced 2025-01-26 03:20:08 +02:00
Added method override middleware
Signed-off-by: Vishal Rana <vr@labstack.com>
This commit is contained in:
parent
7f819e4d71
commit
24a19223b0
1
echo.go
1
echo.go
@ -162,6 +162,7 @@ const (
|
||||
HeaderUpgrade = "Upgrade"
|
||||
HeaderVary = "Vary"
|
||||
HeaderWWWAuthenticate = "WWW-Authenticate"
|
||||
HeaderXHTTPMethodOverride = "X-HTTP-Method-Override"
|
||||
HeaderXForwardedFor = "X-Forwarded-For"
|
||||
HeaderXRealIP = "X-Real-IP"
|
||||
HeaderServer = "Server"
|
||||
|
69
middleware/basic_auth.go
Normal file
69
middleware/basic_auth.go
Normal file
@ -0,0 +1,69 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
|
||||
"github.com/labstack/echo"
|
||||
)
|
||||
|
||||
type (
|
||||
// BasicAuthConfig defines the config for HTTP basic auth middleware.
|
||||
BasicAuthConfig struct {
|
||||
// Validator is the function to validate basic auth credentials.
|
||||
Validator BasicAuthValidator
|
||||
}
|
||||
|
||||
// BasicAuthValidator defines a function to validate basic auth credentials.
|
||||
BasicAuthValidator func(string, string) bool
|
||||
)
|
||||
|
||||
const (
|
||||
basic = "Basic"
|
||||
)
|
||||
|
||||
var (
|
||||
// DefaultBasicAuthConfig is the default basic auth middleware config.
|
||||
DefaultBasicAuthConfig = BasicAuthConfig{}
|
||||
)
|
||||
|
||||
// BasicAuth returns an HTTP basic auth middleware.
|
||||
//
|
||||
// For valid credentials it calls the next handler.
|
||||
// For invalid credentials, it sends "401 - Unauthorized" response.
|
||||
// For empty or invalid `Authorization` header, it sends "400 - Bad Request" response.
|
||||
func BasicAuth(fn BasicAuthValidator) echo.MiddlewareFunc {
|
||||
c := DefaultBasicAuthConfig
|
||||
c.Validator = fn
|
||||
return BasicAuthWithConfig(c)
|
||||
}
|
||||
|
||||
// BasicAuthWithConfig returns an HTTP basic auth middleware from config.
|
||||
// See `BasicAuth()`.
|
||||
func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
auth := c.Request().Header().Get(echo.HeaderAuthorization)
|
||||
l := len(basic)
|
||||
|
||||
if len(auth) > l+1 && auth[:l] == basic {
|
||||
b, err := base64.StdEncoding.DecodeString(auth[l+1:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cred := string(b)
|
||||
for i := 0; i < len(cred); i++ {
|
||||
if cred[i] == ':' {
|
||||
// Verify credentials
|
||||
if config.Validator(cred[:i], cred[i+1:]) {
|
||||
return next(c)
|
||||
}
|
||||
c.Response().Header().Set(echo.HeaderWWWAuthenticate, basic+" realm=Restricted")
|
||||
return echo.ErrUnauthorized
|
||||
}
|
||||
}
|
||||
}
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "invalid basic-auth authorization header="+auth)
|
||||
}
|
||||
}
|
||||
}
|
50
middleware/basic_auth_test.go
Normal file
50
middleware/basic_auth_test.go
Normal file
@ -0,0 +1,50 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo"
|
||||
"github.com/labstack/echo/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestBasicAuth(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := test.NewRequest(echo.GET, "/", nil)
|
||||
res := test.NewResponseRecorder()
|
||||
c := e.NewContext(req, res)
|
||||
f := func(u, p string) bool {
|
||||
if u == "joe" && p == "secret" {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
h := BasicAuth(f)(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "test")
|
||||
})
|
||||
|
||||
// Valid credentials
|
||||
auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
|
||||
req.Header().Set(echo.HeaderAuthorization, auth)
|
||||
assert.NoError(t, h(c))
|
||||
|
||||
// Incorrect password
|
||||
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:password"))
|
||||
req.Header().Set(echo.HeaderAuthorization, auth)
|
||||
he := h(c).(*echo.HTTPError)
|
||||
assert.Equal(t, http.StatusUnauthorized, he.Code)
|
||||
assert.Equal(t, basic+" realm=Restricted", res.Header().Get(echo.HeaderWWWAuthenticate))
|
||||
|
||||
// Empty Authorization header
|
||||
req.Header().Set(echo.HeaderAuthorization, "")
|
||||
he = h(c).(*echo.HTTPError)
|
||||
assert.Equal(t, http.StatusBadRequest, he.Code)
|
||||
|
||||
// Invalid Authorization header
|
||||
auth = base64.StdEncoding.EncodeToString([]byte("invalid"))
|
||||
req.Header().Set(echo.HeaderAuthorization, auth)
|
||||
he = h(c).(*echo.HTTPError)
|
||||
assert.Equal(t, http.StatusBadRequest, he.Code)
|
||||
}
|
@ -1,7 +1,6 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
@ -10,15 +9,6 @@ import (
|
||||
)
|
||||
|
||||
type (
|
||||
// BasicAuthConfig defines the config for HTTP basic auth middleware.
|
||||
BasicAuthConfig struct {
|
||||
// AuthFunc is the function to validate basic auth credentials.
|
||||
AuthFunc BasicAuthFunc
|
||||
}
|
||||
|
||||
// BasicAuthFunc defines a function to validate basic auth credentials.
|
||||
BasicAuthFunc func(string, string) bool
|
||||
|
||||
// JWTAuthConfig defines the config for JWT auth middleware.
|
||||
JWTAuthConfig struct {
|
||||
// SigningKey is the key to validate token.
|
||||
@ -45,7 +35,6 @@ type (
|
||||
)
|
||||
|
||||
const (
|
||||
basic = "Basic"
|
||||
bearer = "Bearer"
|
||||
)
|
||||
|
||||
@ -55,9 +44,6 @@ const (
|
||||
)
|
||||
|
||||
var (
|
||||
// DefaultBasicAuthConfig is the default basic auth middleware config.
|
||||
DefaultBasicAuthConfig = BasicAuthConfig{}
|
||||
|
||||
// DefaultJWTAuthConfig is the default JWT auth middleware config.
|
||||
DefaultJWTAuthConfig = JWTAuthConfig{
|
||||
SigningMethod: AlgorithmHS256,
|
||||
@ -66,47 +52,6 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
// BasicAuth returns an HTTP basic auth middleware.
|
||||
//
|
||||
// For valid credentials it calls the next handler.
|
||||
// For invalid credentials, it sends "401 - Unauthorized" response.
|
||||
// For empty or invalid `Authorization` header, it sends "400 - Bad Request" response.
|
||||
func BasicAuth(fn BasicAuthFunc) echo.MiddlewareFunc {
|
||||
c := DefaultBasicAuthConfig
|
||||
c.AuthFunc = fn
|
||||
return BasicAuthWithConfig(c)
|
||||
}
|
||||
|
||||
// BasicAuthWithConfig returns an HTTP basic auth middleware from config.
|
||||
// See `BasicAuth()`.
|
||||
func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
auth := c.Request().Header().Get(echo.HeaderAuthorization)
|
||||
l := len(basic)
|
||||
|
||||
if len(auth) > l+1 && auth[:l] == basic {
|
||||
b, err := base64.StdEncoding.DecodeString(auth[l+1:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cred := string(b)
|
||||
for i := 0; i < len(cred); i++ {
|
||||
if cred[i] == ':' {
|
||||
// Verify credentials
|
||||
if config.AuthFunc(cred[:i], cred[i+1:]) {
|
||||
return next(c)
|
||||
}
|
||||
c.Response().Header().Set(echo.HeaderWWWAuthenticate, basic+" realm=Restricted")
|
||||
return echo.ErrUnauthorized
|
||||
}
|
||||
}
|
||||
}
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "invalid basic-auth authorization header="+auth)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// JWTAuth returns a JSON Web Token (JWT) auth middleware.
|
||||
//
|
||||
// For valid token, it sets the user in context and calls next handler.
|
@ -1,7 +1,6 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
@ -11,45 +10,6 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestBasicAuth(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := test.NewRequest(echo.GET, "/", nil)
|
||||
res := test.NewResponseRecorder()
|
||||
c := e.NewContext(req, res)
|
||||
f := func(u, p string) bool {
|
||||
if u == "joe" && p == "secret" {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
h := BasicAuth(f)(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "test")
|
||||
})
|
||||
|
||||
// Valid credentials
|
||||
auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
|
||||
req.Header().Set(echo.HeaderAuthorization, auth)
|
||||
assert.NoError(t, h(c))
|
||||
|
||||
// Incorrect password
|
||||
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:password"))
|
||||
req.Header().Set(echo.HeaderAuthorization, auth)
|
||||
he := h(c).(*echo.HTTPError)
|
||||
assert.Equal(t, http.StatusUnauthorized, he.Code)
|
||||
assert.Equal(t, basic+" realm=Restricted", res.Header().Get(echo.HeaderWWWAuthenticate))
|
||||
|
||||
// Empty Authorization header
|
||||
req.Header().Set(echo.HeaderAuthorization, "")
|
||||
he = h(c).(*echo.HTTPError)
|
||||
assert.Equal(t, http.StatusBadRequest, he.Code)
|
||||
|
||||
// Invalid Authorization header
|
||||
auth = base64.StdEncoding.EncodeToString([]byte("invalid"))
|
||||
req.Header().Set(echo.HeaderAuthorization, auth)
|
||||
he = h(c).(*echo.HTTPError)
|
||||
assert.Equal(t, http.StatusBadRequest, he.Code)
|
||||
}
|
||||
|
||||
func TestJWTAuth(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := test.NewRequest(echo.GET, "/", nil)
|
@ -4,36 +4,70 @@ import (
|
||||
"github.com/labstack/echo"
|
||||
)
|
||||
|
||||
const (
|
||||
HttpMethodOverrideHeader = "X-HTTP-Method-Override"
|
||||
type (
|
||||
// MethodOverrideConfig defines the config for method override middleware.
|
||||
MethodOverrideConfig struct {
|
||||
Getter MethodOverrideGetter
|
||||
}
|
||||
|
||||
// MethodOverrideGetter is a function that gets overridden method from the request
|
||||
// Optional, with default values as `MethodFromHeader(echo.HeaderXHTTPMethodOverride)`.
|
||||
MethodOverrideGetter func(echo.Context) string
|
||||
)
|
||||
|
||||
func OverrideMethod() echo.MiddlewareFunc {
|
||||
return Override()
|
||||
var (
|
||||
// DefaultMethodOverrideConfig is the default method override middleware config.
|
||||
DefaultMethodOverrideConfig = MethodOverrideConfig{
|
||||
Getter: MethodFromHeader(echo.HeaderXHTTPMethodOverride),
|
||||
}
|
||||
)
|
||||
|
||||
// MethodOverride returns a method override middleware.
|
||||
// Method override middleware checks for the overriden method from the request and
|
||||
// uses it instead of the original method.
|
||||
//
|
||||
// For security reasons, only `POST` method can be overriden.
|
||||
func MethodOverride() echo.MiddlewareFunc {
|
||||
return MethodOverrideWithConfig(DefaultMethodOverrideConfig)
|
||||
}
|
||||
|
||||
// Override checks for the X-HTTP-Method-Override header
|
||||
// or the body for parameter, `_method`
|
||||
// and uses the http method instead of Request.Method.
|
||||
// It isn't secure to override e.g a GET to a POST,
|
||||
// so only Request.Method which are POSTs are considered.
|
||||
func Override() echo.MiddlewareFunc {
|
||||
// MethodOverrideWithConfig returns a method override middleware from config.
|
||||
// See `MethodOverride()`.
|
||||
func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
originalMethod := c.Request().Method()
|
||||
|
||||
if originalMethod == "POST" {
|
||||
m := c.FormValue("_method")
|
||||
if m != "" {
|
||||
c.Request().SetMethod(m)
|
||||
}
|
||||
m = c.Request().Header().Get(HttpMethodOverrideHeader)
|
||||
req := c.Request()
|
||||
if req.Method() == echo.POST {
|
||||
m := config.Getter(c)
|
||||
if m != "" {
|
||||
c.Request().SetMethod(m)
|
||||
}
|
||||
}
|
||||
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MethodFromHeader is a `MethodOverrideGetter` that gets overridden method from
|
||||
// the request header.
|
||||
func MethodFromHeader(header string) MethodOverrideGetter {
|
||||
return func(c echo.Context) string {
|
||||
return c.Request().Header().Get(header)
|
||||
}
|
||||
}
|
||||
|
||||
// MethodFromForm is a `MethodOverrideGetter` that gets overridden method from the
|
||||
// form parameter.
|
||||
func MethodFromForm(param string) MethodOverrideGetter {
|
||||
return func(c echo.Context) string {
|
||||
return c.FormValue(param)
|
||||
}
|
||||
}
|
||||
|
||||
// MethodFromQuery is a `MethodOverrideGetter` that gets overridden method from
|
||||
// the query parameter.
|
||||
func MethodFromQuery(param string) MethodOverrideGetter {
|
||||
return func(c echo.Context) string {
|
||||
return c.QueryParam(param)
|
||||
}
|
||||
}
|
||||
|
@ -10,34 +10,40 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestOverrideMtrhod(t *testing.T) {
|
||||
func TestMethodOverride(t *testing.T) {
|
||||
e := echo.New()
|
||||
methodOverride := OverrideMethod()
|
||||
h := methodOverride(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, c.Request().Method())
|
||||
})
|
||||
m := MethodOverride()
|
||||
h := func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "Okay")
|
||||
}
|
||||
|
||||
// Override with http header
|
||||
rq := test.NewRequest(echo.POST, "/", nil)
|
||||
rq.Header().Set(HttpMethodOverrideHeader, "DELETE")
|
||||
rc := test.NewResponseRecorder()
|
||||
c := e.NewContext(rq, rc)
|
||||
h(c)
|
||||
assert.Equal(t, "DELETE", rc.Body.String())
|
||||
req := test.NewRequest(echo.POST, "/", nil)
|
||||
rec := test.NewResponseRecorder()
|
||||
req.Header().Set(echo.HeaderXHTTPMethodOverride, echo.DELETE)
|
||||
c := e.NewContext(req, rec)
|
||||
m(h)(c)
|
||||
assert.Equal(t, echo.DELETE, req.Method())
|
||||
|
||||
// Override with body parameter
|
||||
rq = test.NewRequest(echo.POST, "/", bytes.NewReader([]byte("_method=DELETE")))
|
||||
rq.Header().Set(echo.HeaderContentType, echo.MIMEApplicationForm)
|
||||
rc = test.NewResponseRecorder()
|
||||
c = e.NewContext(rq, rc)
|
||||
h(c)
|
||||
assert.Equal(t, "DELETE", rc.Body.String())
|
||||
// Override with form parameter
|
||||
m = MethodOverrideWithConfig(MethodOverrideConfig{Getter: MethodFromForm("_method")})
|
||||
req = test.NewRequest(echo.POST, "/", bytes.NewReader([]byte("_method="+echo.DELETE)))
|
||||
rec = test.NewResponseRecorder()
|
||||
req.Header().Set(echo.HeaderContentType, echo.MIMEApplicationForm)
|
||||
c = e.NewContext(req, rec)
|
||||
m(h)(c)
|
||||
assert.Equal(t, echo.DELETE, req.Method())
|
||||
|
||||
// Ignore GET
|
||||
rq = test.NewRequest(echo.GET, "/", nil)
|
||||
rq.Header().Set(HttpMethodOverrideHeader, "DELETE")
|
||||
rc = test.NewResponseRecorder()
|
||||
c = e.NewContext(rq, rc)
|
||||
h(c)
|
||||
assert.Equal(t, "GET", rc.Body.String())
|
||||
// Override with query paramter
|
||||
m = MethodOverrideWithConfig(MethodOverrideConfig{Getter: MethodFromQuery("_method")})
|
||||
req = test.NewRequest(echo.POST, "/?_method="+echo.DELETE, nil)
|
||||
rec = test.NewResponseRecorder()
|
||||
c = e.NewContext(req, rec)
|
||||
m(h)(c)
|
||||
assert.Equal(t, echo.DELETE, req.Method())
|
||||
|
||||
// Ignore `GET`
|
||||
req = test.NewRequest(echo.GET, "/", nil)
|
||||
req.Header().Set(echo.HeaderXHTTPMethodOverride, echo.DELETE)
|
||||
assert.Equal(t, echo.GET, req.Method())
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user