1
0
mirror of https://github.com/labstack/echo.git synced 2025-01-12 01:22:21 +02:00

Added method override middleware

Signed-off-by: Vishal Rana <vr@labstack.com>
This commit is contained in:
Vishal Rana 2016-04-27 21:08:06 -07:00
parent 7f819e4d71
commit 24a19223b0
7 changed files with 204 additions and 139 deletions

View File

@ -162,6 +162,7 @@ const (
HeaderUpgrade = "Upgrade" HeaderUpgrade = "Upgrade"
HeaderVary = "Vary" HeaderVary = "Vary"
HeaderWWWAuthenticate = "WWW-Authenticate" HeaderWWWAuthenticate = "WWW-Authenticate"
HeaderXHTTPMethodOverride = "X-HTTP-Method-Override"
HeaderXForwardedFor = "X-Forwarded-For" HeaderXForwardedFor = "X-Forwarded-For"
HeaderXRealIP = "X-Real-IP" HeaderXRealIP = "X-Real-IP"
HeaderServer = "Server" HeaderServer = "Server"

69
middleware/basic_auth.go Normal file
View File

@ -0,0 +1,69 @@
package middleware
import (
"encoding/base64"
"net/http"
"github.com/labstack/echo"
)
type (
// BasicAuthConfig defines the config for HTTP basic auth middleware.
BasicAuthConfig struct {
// Validator is the function to validate basic auth credentials.
Validator BasicAuthValidator
}
// BasicAuthValidator defines a function to validate basic auth credentials.
BasicAuthValidator func(string, string) bool
)
const (
basic = "Basic"
)
var (
// DefaultBasicAuthConfig is the default basic auth middleware config.
DefaultBasicAuthConfig = BasicAuthConfig{}
)
// BasicAuth returns an HTTP basic auth middleware.
//
// For valid credentials it calls the next handler.
// For invalid credentials, it sends "401 - Unauthorized" response.
// For empty or invalid `Authorization` header, it sends "400 - Bad Request" response.
func BasicAuth(fn BasicAuthValidator) echo.MiddlewareFunc {
c := DefaultBasicAuthConfig
c.Validator = fn
return BasicAuthWithConfig(c)
}
// BasicAuthWithConfig returns an HTTP basic auth middleware from config.
// See `BasicAuth()`.
func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
auth := c.Request().Header().Get(echo.HeaderAuthorization)
l := len(basic)
if len(auth) > l+1 && auth[:l] == basic {
b, err := base64.StdEncoding.DecodeString(auth[l+1:])
if err != nil {
return err
}
cred := string(b)
for i := 0; i < len(cred); i++ {
if cred[i] == ':' {
// Verify credentials
if config.Validator(cred[:i], cred[i+1:]) {
return next(c)
}
c.Response().Header().Set(echo.HeaderWWWAuthenticate, basic+" realm=Restricted")
return echo.ErrUnauthorized
}
}
}
return echo.NewHTTPError(http.StatusBadRequest, "invalid basic-auth authorization header="+auth)
}
}
}

View File

@ -0,0 +1,50 @@
package middleware
import (
"encoding/base64"
"net/http"
"testing"
"github.com/labstack/echo"
"github.com/labstack/echo/test"
"github.com/stretchr/testify/assert"
)
func TestBasicAuth(t *testing.T) {
e := echo.New()
req := test.NewRequest(echo.GET, "/", nil)
res := test.NewResponseRecorder()
c := e.NewContext(req, res)
f := func(u, p string) bool {
if u == "joe" && p == "secret" {
return true
}
return false
}
h := BasicAuth(f)(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})
// Valid credentials
auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
req.Header().Set(echo.HeaderAuthorization, auth)
assert.NoError(t, h(c))
// Incorrect password
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:password"))
req.Header().Set(echo.HeaderAuthorization, auth)
he := h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusUnauthorized, he.Code)
assert.Equal(t, basic+" realm=Restricted", res.Header().Get(echo.HeaderWWWAuthenticate))
// Empty Authorization header
req.Header().Set(echo.HeaderAuthorization, "")
he = h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusBadRequest, he.Code)
// Invalid Authorization header
auth = base64.StdEncoding.EncodeToString([]byte("invalid"))
req.Header().Set(echo.HeaderAuthorization, auth)
he = h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusBadRequest, he.Code)
}

View File

@ -1,7 +1,6 @@
package middleware package middleware
import ( import (
"encoding/base64"
"fmt" "fmt"
"net/http" "net/http"
@ -10,15 +9,6 @@ import (
) )
type ( type (
// BasicAuthConfig defines the config for HTTP basic auth middleware.
BasicAuthConfig struct {
// AuthFunc is the function to validate basic auth credentials.
AuthFunc BasicAuthFunc
}
// BasicAuthFunc defines a function to validate basic auth credentials.
BasicAuthFunc func(string, string) bool
// JWTAuthConfig defines the config for JWT auth middleware. // JWTAuthConfig defines the config for JWT auth middleware.
JWTAuthConfig struct { JWTAuthConfig struct {
// SigningKey is the key to validate token. // SigningKey is the key to validate token.
@ -45,7 +35,6 @@ type (
) )
const ( const (
basic = "Basic"
bearer = "Bearer" bearer = "Bearer"
) )
@ -55,9 +44,6 @@ const (
) )
var ( var (
// DefaultBasicAuthConfig is the default basic auth middleware config.
DefaultBasicAuthConfig = BasicAuthConfig{}
// DefaultJWTAuthConfig is the default JWT auth middleware config. // DefaultJWTAuthConfig is the default JWT auth middleware config.
DefaultJWTAuthConfig = JWTAuthConfig{ DefaultJWTAuthConfig = JWTAuthConfig{
SigningMethod: AlgorithmHS256, SigningMethod: AlgorithmHS256,
@ -66,47 +52,6 @@ var (
} }
) )
// BasicAuth returns an HTTP basic auth middleware.
//
// For valid credentials it calls the next handler.
// For invalid credentials, it sends "401 - Unauthorized" response.
// For empty or invalid `Authorization` header, it sends "400 - Bad Request" response.
func BasicAuth(fn BasicAuthFunc) echo.MiddlewareFunc {
c := DefaultBasicAuthConfig
c.AuthFunc = fn
return BasicAuthWithConfig(c)
}
// BasicAuthWithConfig returns an HTTP basic auth middleware from config.
// See `BasicAuth()`.
func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
auth := c.Request().Header().Get(echo.HeaderAuthorization)
l := len(basic)
if len(auth) > l+1 && auth[:l] == basic {
b, err := base64.StdEncoding.DecodeString(auth[l+1:])
if err != nil {
return err
}
cred := string(b)
for i := 0; i < len(cred); i++ {
if cred[i] == ':' {
// Verify credentials
if config.AuthFunc(cred[:i], cred[i+1:]) {
return next(c)
}
c.Response().Header().Set(echo.HeaderWWWAuthenticate, basic+" realm=Restricted")
return echo.ErrUnauthorized
}
}
}
return echo.NewHTTPError(http.StatusBadRequest, "invalid basic-auth authorization header="+auth)
}
}
}
// JWTAuth returns a JSON Web Token (JWT) auth middleware. // JWTAuth returns a JSON Web Token (JWT) auth middleware.
// //
// For valid token, it sets the user in context and calls next handler. // For valid token, it sets the user in context and calls next handler.

View File

@ -1,7 +1,6 @@
package middleware package middleware
import ( import (
"encoding/base64"
"net/http" "net/http"
"testing" "testing"
@ -11,45 +10,6 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestBasicAuth(t *testing.T) {
e := echo.New()
req := test.NewRequest(echo.GET, "/", nil)
res := test.NewResponseRecorder()
c := e.NewContext(req, res)
f := func(u, p string) bool {
if u == "joe" && p == "secret" {
return true
}
return false
}
h := BasicAuth(f)(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})
// Valid credentials
auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
req.Header().Set(echo.HeaderAuthorization, auth)
assert.NoError(t, h(c))
// Incorrect password
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:password"))
req.Header().Set(echo.HeaderAuthorization, auth)
he := h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusUnauthorized, he.Code)
assert.Equal(t, basic+" realm=Restricted", res.Header().Get(echo.HeaderWWWAuthenticate))
// Empty Authorization header
req.Header().Set(echo.HeaderAuthorization, "")
he = h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusBadRequest, he.Code)
// Invalid Authorization header
auth = base64.StdEncoding.EncodeToString([]byte("invalid"))
req.Header().Set(echo.HeaderAuthorization, auth)
he = h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusBadRequest, he.Code)
}
func TestJWTAuth(t *testing.T) { func TestJWTAuth(t *testing.T) {
e := echo.New() e := echo.New()
req := test.NewRequest(echo.GET, "/", nil) req := test.NewRequest(echo.GET, "/", nil)

View File

@ -4,36 +4,70 @@ import (
"github.com/labstack/echo" "github.com/labstack/echo"
) )
const ( type (
HttpMethodOverrideHeader = "X-HTTP-Method-Override" // MethodOverrideConfig defines the config for method override middleware.
MethodOverrideConfig struct {
Getter MethodOverrideGetter
}
// MethodOverrideGetter is a function that gets overridden method from the request
// Optional, with default values as `MethodFromHeader(echo.HeaderXHTTPMethodOverride)`.
MethodOverrideGetter func(echo.Context) string
) )
func OverrideMethod() echo.MiddlewareFunc { var (
return Override() // DefaultMethodOverrideConfig is the default method override middleware config.
DefaultMethodOverrideConfig = MethodOverrideConfig{
Getter: MethodFromHeader(echo.HeaderXHTTPMethodOverride),
}
)
// MethodOverride returns a method override middleware.
// Method override middleware checks for the overriden method from the request and
// uses it instead of the original method.
//
// For security reasons, only `POST` method can be overriden.
func MethodOverride() echo.MiddlewareFunc {
return MethodOverrideWithConfig(DefaultMethodOverrideConfig)
} }
// Override checks for the X-HTTP-Method-Override header // MethodOverrideWithConfig returns a method override middleware from config.
// or the body for parameter, `_method` // See `MethodOverride()`.
// and uses the http method instead of Request.Method. func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc {
// It isn't secure to override e.g a GET to a POST,
// so only Request.Method which are POSTs are considered.
func Override() echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
originalMethod := c.Request().Method() req := c.Request()
if req.Method() == echo.POST {
if originalMethod == "POST" { m := config.Getter(c)
m := c.FormValue("_method")
if m != "" {
c.Request().SetMethod(m)
}
m = c.Request().Header().Get(HttpMethodOverrideHeader)
if m != "" { if m != "" {
c.Request().SetMethod(m) c.Request().SetMethod(m)
} }
} }
return next(c) return next(c)
} }
} }
} }
// MethodFromHeader is a `MethodOverrideGetter` that gets overridden method from
// the request header.
func MethodFromHeader(header string) MethodOverrideGetter {
return func(c echo.Context) string {
return c.Request().Header().Get(header)
}
}
// MethodFromForm is a `MethodOverrideGetter` that gets overridden method from the
// form parameter.
func MethodFromForm(param string) MethodOverrideGetter {
return func(c echo.Context) string {
return c.FormValue(param)
}
}
// MethodFromQuery is a `MethodOverrideGetter` that gets overridden method from
// the query parameter.
func MethodFromQuery(param string) MethodOverrideGetter {
return func(c echo.Context) string {
return c.QueryParam(param)
}
}

View File

@ -10,34 +10,40 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestOverrideMtrhod(t *testing.T) { func TestMethodOverride(t *testing.T) {
e := echo.New() e := echo.New()
methodOverride := OverrideMethod() m := MethodOverride()
h := methodOverride(func(c echo.Context) error { h := func(c echo.Context) error {
return c.String(http.StatusOK, c.Request().Method()) return c.String(http.StatusOK, "Okay")
}) }
// Override with http header // Override with http header
rq := test.NewRequest(echo.POST, "/", nil) req := test.NewRequest(echo.POST, "/", nil)
rq.Header().Set(HttpMethodOverrideHeader, "DELETE") rec := test.NewResponseRecorder()
rc := test.NewResponseRecorder() req.Header().Set(echo.HeaderXHTTPMethodOverride, echo.DELETE)
c := e.NewContext(rq, rc) c := e.NewContext(req, rec)
h(c) m(h)(c)
assert.Equal(t, "DELETE", rc.Body.String()) assert.Equal(t, echo.DELETE, req.Method())
// Override with body parameter // Override with form parameter
rq = test.NewRequest(echo.POST, "/", bytes.NewReader([]byte("_method=DELETE"))) m = MethodOverrideWithConfig(MethodOverrideConfig{Getter: MethodFromForm("_method")})
rq.Header().Set(echo.HeaderContentType, echo.MIMEApplicationForm) req = test.NewRequest(echo.POST, "/", bytes.NewReader([]byte("_method="+echo.DELETE)))
rc = test.NewResponseRecorder() rec = test.NewResponseRecorder()
c = e.NewContext(rq, rc) req.Header().Set(echo.HeaderContentType, echo.MIMEApplicationForm)
h(c) c = e.NewContext(req, rec)
assert.Equal(t, "DELETE", rc.Body.String()) m(h)(c)
assert.Equal(t, echo.DELETE, req.Method())
// Ignore GET // Override with query paramter
rq = test.NewRequest(echo.GET, "/", nil) m = MethodOverrideWithConfig(MethodOverrideConfig{Getter: MethodFromQuery("_method")})
rq.Header().Set(HttpMethodOverrideHeader, "DELETE") req = test.NewRequest(echo.POST, "/?_method="+echo.DELETE, nil)
rc = test.NewResponseRecorder() rec = test.NewResponseRecorder()
c = e.NewContext(rq, rc) c = e.NewContext(req, rec)
h(c) m(h)(c)
assert.Equal(t, "GET", rc.Body.String()) assert.Equal(t, echo.DELETE, req.Method())
// Ignore `GET`
req = test.NewRequest(echo.GET, "/", nil)
req.Header().Set(echo.HeaderXHTTPMethodOverride, echo.DELETE)
assert.Equal(t, echo.GET, req.Method())
} }