1
0
mirror of https://github.com/labstack/echo.git synced 2024-12-24 20:14:31 +02:00

Merge branch 'coderhaoxin-method'

This commit is contained in:
Vishal Rana 2016-04-27 21:09:01 -07:00
commit 96824ff627
7 changed files with 242 additions and 95 deletions

View File

@ -162,6 +162,7 @@ const (
HeaderUpgrade = "Upgrade"
HeaderVary = "Vary"
HeaderWWWAuthenticate = "WWW-Authenticate"
HeaderXHTTPMethodOverride = "X-HTTP-Method-Override"
HeaderXForwardedFor = "X-Forwarded-For"
HeaderXRealIP = "X-Real-IP"
HeaderServer = "Server"

69
middleware/basic_auth.go Normal file
View File

@ -0,0 +1,69 @@
package middleware
import (
"encoding/base64"
"net/http"
"github.com/labstack/echo"
)
type (
// BasicAuthConfig defines the config for HTTP basic auth middleware.
BasicAuthConfig struct {
// Validator is the function to validate basic auth credentials.
Validator BasicAuthValidator
}
// BasicAuthValidator defines a function to validate basic auth credentials.
BasicAuthValidator func(string, string) bool
)
const (
basic = "Basic"
)
var (
// DefaultBasicAuthConfig is the default basic auth middleware config.
DefaultBasicAuthConfig = BasicAuthConfig{}
)
// BasicAuth returns an HTTP basic auth middleware.
//
// For valid credentials it calls the next handler.
// For invalid credentials, it sends "401 - Unauthorized" response.
// For empty or invalid `Authorization` header, it sends "400 - Bad Request" response.
func BasicAuth(fn BasicAuthValidator) echo.MiddlewareFunc {
c := DefaultBasicAuthConfig
c.Validator = fn
return BasicAuthWithConfig(c)
}
// BasicAuthWithConfig returns an HTTP basic auth middleware from config.
// See `BasicAuth()`.
func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
auth := c.Request().Header().Get(echo.HeaderAuthorization)
l := len(basic)
if len(auth) > l+1 && auth[:l] == basic {
b, err := base64.StdEncoding.DecodeString(auth[l+1:])
if err != nil {
return err
}
cred := string(b)
for i := 0; i < len(cred); i++ {
if cred[i] == ':' {
// Verify credentials
if config.Validator(cred[:i], cred[i+1:]) {
return next(c)
}
c.Response().Header().Set(echo.HeaderWWWAuthenticate, basic+" realm=Restricted")
return echo.ErrUnauthorized
}
}
}
return echo.NewHTTPError(http.StatusBadRequest, "invalid basic-auth authorization header="+auth)
}
}
}

View File

@ -0,0 +1,50 @@
package middleware
import (
"encoding/base64"
"net/http"
"testing"
"github.com/labstack/echo"
"github.com/labstack/echo/test"
"github.com/stretchr/testify/assert"
)
func TestBasicAuth(t *testing.T) {
e := echo.New()
req := test.NewRequest(echo.GET, "/", nil)
res := test.NewResponseRecorder()
c := e.NewContext(req, res)
f := func(u, p string) bool {
if u == "joe" && p == "secret" {
return true
}
return false
}
h := BasicAuth(f)(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})
// Valid credentials
auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
req.Header().Set(echo.HeaderAuthorization, auth)
assert.NoError(t, h(c))
// Incorrect password
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:password"))
req.Header().Set(echo.HeaderAuthorization, auth)
he := h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusUnauthorized, he.Code)
assert.Equal(t, basic+" realm=Restricted", res.Header().Get(echo.HeaderWWWAuthenticate))
// Empty Authorization header
req.Header().Set(echo.HeaderAuthorization, "")
he = h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusBadRequest, he.Code)
// Invalid Authorization header
auth = base64.StdEncoding.EncodeToString([]byte("invalid"))
req.Header().Set(echo.HeaderAuthorization, auth)
he = h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusBadRequest, he.Code)
}

View File

@ -1,7 +1,6 @@
package middleware
import (
"encoding/base64"
"fmt"
"net/http"
@ -10,15 +9,6 @@ import (
)
type (
// BasicAuthConfig defines the config for HTTP basic auth middleware.
BasicAuthConfig struct {
// AuthFunc is the function to validate basic auth credentials.
AuthFunc BasicAuthFunc
}
// BasicAuthFunc defines a function to validate basic auth credentials.
BasicAuthFunc func(string, string) bool
// JWTAuthConfig defines the config for JWT auth middleware.
JWTAuthConfig struct {
// SigningKey is the key to validate token.
@ -45,7 +35,6 @@ type (
)
const (
basic = "Basic"
bearer = "Bearer"
)
@ -55,9 +44,6 @@ const (
)
var (
// DefaultBasicAuthConfig is the default basic auth middleware config.
DefaultBasicAuthConfig = BasicAuthConfig{}
// DefaultJWTAuthConfig is the default JWT auth middleware config.
DefaultJWTAuthConfig = JWTAuthConfig{
SigningMethod: AlgorithmHS256,
@ -66,47 +52,6 @@ var (
}
)
// BasicAuth returns an HTTP basic auth middleware.
//
// For valid credentials it calls the next handler.
// For invalid credentials, it sends "401 - Unauthorized" response.
// For empty or invalid `Authorization` header, it sends "400 - Bad Request" response.
func BasicAuth(fn BasicAuthFunc) echo.MiddlewareFunc {
c := DefaultBasicAuthConfig
c.AuthFunc = fn
return BasicAuthWithConfig(c)
}
// BasicAuthWithConfig returns an HTTP basic auth middleware from config.
// See `BasicAuth()`.
func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
auth := c.Request().Header().Get(echo.HeaderAuthorization)
l := len(basic)
if len(auth) > l+1 && auth[:l] == basic {
b, err := base64.StdEncoding.DecodeString(auth[l+1:])
if err != nil {
return err
}
cred := string(b)
for i := 0; i < len(cred); i++ {
if cred[i] == ':' {
// Verify credentials
if config.AuthFunc(cred[:i], cred[i+1:]) {
return next(c)
}
c.Response().Header().Set(echo.HeaderWWWAuthenticate, basic+" realm=Restricted")
return echo.ErrUnauthorized
}
}
}
return echo.NewHTTPError(http.StatusBadRequest, "invalid basic-auth authorization header="+auth)
}
}
}
// JWTAuth returns a JSON Web Token (JWT) auth middleware.
//
// For valid token, it sets the user in context and calls next handler.

View File

@ -1,7 +1,6 @@
package middleware
import (
"encoding/base64"
"net/http"
"testing"
@ -11,45 +10,6 @@ import (
"github.com/stretchr/testify/assert"
)
func TestBasicAuth(t *testing.T) {
e := echo.New()
req := test.NewRequest(echo.GET, "/", nil)
res := test.NewResponseRecorder()
c := e.NewContext(req, res)
f := func(u, p string) bool {
if u == "joe" && p == "secret" {
return true
}
return false
}
h := BasicAuth(f)(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})
// Valid credentials
auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
req.Header().Set(echo.HeaderAuthorization, auth)
assert.NoError(t, h(c))
// Incorrect password
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:password"))
req.Header().Set(echo.HeaderAuthorization, auth)
he := h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusUnauthorized, he.Code)
assert.Equal(t, basic+" realm=Restricted", res.Header().Get(echo.HeaderWWWAuthenticate))
// Empty Authorization header
req.Header().Set(echo.HeaderAuthorization, "")
he = h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusBadRequest, he.Code)
// Invalid Authorization header
auth = base64.StdEncoding.EncodeToString([]byte("invalid"))
req.Header().Set(echo.HeaderAuthorization, auth)
he = h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusBadRequest, he.Code)
}
func TestJWTAuth(t *testing.T) {
e := echo.New()
req := test.NewRequest(echo.GET, "/", nil)

View File

@ -0,0 +1,73 @@
package middleware
import (
"github.com/labstack/echo"
)
type (
// MethodOverrideConfig defines the config for method override middleware.
MethodOverrideConfig struct {
Getter MethodOverrideGetter
}
// MethodOverrideGetter is a function that gets overridden method from the request
// Optional, with default values as `MethodFromHeader(echo.HeaderXHTTPMethodOverride)`.
MethodOverrideGetter func(echo.Context) string
)
var (
// DefaultMethodOverrideConfig is the default method override middleware config.
DefaultMethodOverrideConfig = MethodOverrideConfig{
Getter: MethodFromHeader(echo.HeaderXHTTPMethodOverride),
}
)
// MethodOverride returns a method override middleware.
// Method override middleware checks for the overriden method from the request and
// uses it instead of the original method.
//
// For security reasons, only `POST` method can be overriden.
func MethodOverride() echo.MiddlewareFunc {
return MethodOverrideWithConfig(DefaultMethodOverrideConfig)
}
// MethodOverrideWithConfig returns a method override middleware from config.
// See `MethodOverride()`.
func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
req := c.Request()
if req.Method() == echo.POST {
m := config.Getter(c)
if m != "" {
c.Request().SetMethod(m)
}
}
return next(c)
}
}
}
// MethodFromHeader is a `MethodOverrideGetter` that gets overridden method from
// the request header.
func MethodFromHeader(header string) MethodOverrideGetter {
return func(c echo.Context) string {
return c.Request().Header().Get(header)
}
}
// MethodFromForm is a `MethodOverrideGetter` that gets overridden method from the
// form parameter.
func MethodFromForm(param string) MethodOverrideGetter {
return func(c echo.Context) string {
return c.FormValue(param)
}
}
// MethodFromQuery is a `MethodOverrideGetter` that gets overridden method from
// the query parameter.
func MethodFromQuery(param string) MethodOverrideGetter {
return func(c echo.Context) string {
return c.QueryParam(param)
}
}

View File

@ -0,0 +1,49 @@
package middleware
import (
"bytes"
"net/http"
"testing"
"github.com/labstack/echo"
"github.com/labstack/echo/test"
"github.com/stretchr/testify/assert"
)
func TestMethodOverride(t *testing.T) {
e := echo.New()
m := MethodOverride()
h := func(c echo.Context) error {
return c.String(http.StatusOK, "Okay")
}
// Override with http header
req := test.NewRequest(echo.POST, "/", nil)
rec := test.NewResponseRecorder()
req.Header().Set(echo.HeaderXHTTPMethodOverride, echo.DELETE)
c := e.NewContext(req, rec)
m(h)(c)
assert.Equal(t, echo.DELETE, req.Method())
// Override with form parameter
m = MethodOverrideWithConfig(MethodOverrideConfig{Getter: MethodFromForm("_method")})
req = test.NewRequest(echo.POST, "/", bytes.NewReader([]byte("_method="+echo.DELETE)))
rec = test.NewResponseRecorder()
req.Header().Set(echo.HeaderContentType, echo.MIMEApplicationForm)
c = e.NewContext(req, rec)
m(h)(c)
assert.Equal(t, echo.DELETE, req.Method())
// Override with query paramter
m = MethodOverrideWithConfig(MethodOverrideConfig{Getter: MethodFromQuery("_method")})
req = test.NewRequest(echo.POST, "/?_method="+echo.DELETE, nil)
rec = test.NewResponseRecorder()
c = e.NewContext(req, rec)
m(h)(c)
assert.Equal(t, echo.DELETE, req.Method())
// Ignore `GET`
req = test.NewRequest(echo.GET, "/", nil)
req.Header().Set(echo.HeaderXHTTPMethodOverride, echo.DELETE)
assert.Equal(t, echo.GET, req.Method())
}