mirror of
https://github.com/labstack/echo.git
synced 2024-12-24 20:14:31 +02:00
Middleware interface
Signed-off-by: Vishal Rana <vr@labstack.com>
This commit is contained in:
parent
f27de9a804
commit
65fcca2ce3
8
glide.lock
generated
8
glide.lock
generated
@ -1,5 +1,5 @@
|
||||
hash: 57f60c6f58e40de3ea10123144095336abe2a4024e256feea9721e4e89ad38be
|
||||
updated: 2016-02-04T19:09:01.247772095-08:00
|
||||
updated: 2016-02-07T21:37:59.770272986-08:00
|
||||
imports:
|
||||
- name: github.com/labstack/gommon
|
||||
version: 28bf9248c3e81039c18bff6a1cf2a7f0841caf35
|
||||
@ -11,15 +11,15 @@ imports:
|
||||
- name: github.com/mattn/go-isatty
|
||||
version: 56b76bdf51f7708750eac80fa38b952bb9f32639
|
||||
- name: github.com/valyala/fasthttp
|
||||
version: df213349e25e909911691ffe752149440d969641
|
||||
version: 338fe46307aad35771b73f3f146dec3851c23f7f
|
||||
- name: golang.org/x/crypto
|
||||
version: 1f22c0103821b9390939b6776727195525381532
|
||||
- name: golang.org/x/net
|
||||
version: 493a26246902f2887349f625a5f846bf0286af49
|
||||
version: 7f88271ea9913b72aca44fa7fc8af919eacc17ce
|
||||
subpackages:
|
||||
- /context
|
||||
- http2
|
||||
- websocket
|
||||
- name: golang.org/x/text
|
||||
version: cd1d59e467ef512633026c0f15282e108e41d453
|
||||
version: 2ea5e055772cf5daa0b1478d6e88c8d0c3d4cb79
|
||||
devImports: []
|
||||
|
@ -12,38 +12,40 @@ type (
|
||||
)
|
||||
|
||||
const (
|
||||
Basic = "Basic"
|
||||
basic = "Basic"
|
||||
)
|
||||
|
||||
// BasicAuth returns an HTTP basic authentication middleware.
|
||||
//
|
||||
// For valid credentials it calls the next handler.
|
||||
// For invalid credentials, it sends "401 - Unauthorized" response.
|
||||
func BasicAuth(fn BasicValidateFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
// Skip WebSocket
|
||||
if (c.Request().Header().Get(echo.Upgrade)) == echo.WebSocket {
|
||||
return nil
|
||||
}
|
||||
func BasicAuth(fn BasicValidateFunc) MiddlewareFunc {
|
||||
return func(h echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
// Skip WebSocket
|
||||
if (c.Request().Header().Get(echo.Upgrade)) == echo.WebSocket {
|
||||
return nil
|
||||
}
|
||||
|
||||
auth := c.Request().Header().Get(echo.Authorization)
|
||||
l := len(Basic)
|
||||
auth := c.Request().Header().Get(echo.Authorization)
|
||||
l := len(basic)
|
||||
|
||||
if len(auth) > l+1 && auth[:l] == Basic {
|
||||
b, err := base64.StdEncoding.DecodeString(auth[l+1:])
|
||||
if err == nil {
|
||||
cred := string(b)
|
||||
for i := 0; i < len(cred); i++ {
|
||||
if cred[i] == ':' {
|
||||
// Verify credentials
|
||||
if fn(cred[:i], cred[i+1:]) {
|
||||
return nil
|
||||
if len(auth) > l+1 && auth[:l] == basic {
|
||||
b, err := base64.StdEncoding.DecodeString(auth[l+1:])
|
||||
if err == nil {
|
||||
cred := string(b)
|
||||
for i := 0; i < len(cred); i++ {
|
||||
if cred[i] == ':' {
|
||||
// Verify credentials
|
||||
if fn(cred[:i], cred[i+1:]) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
c.Response().Header().Set(echo.WWWAuthenticate, basic+" realm=Restricted")
|
||||
return echo.NewHTTPError(http.StatusUnauthorized)
|
||||
}
|
||||
c.Response().Header().Set(echo.WWWAuthenticate, Basic+" realm=Restricted")
|
||||
return echo.NewHTTPError(http.StatusUnauthorized)
|
||||
}
|
||||
}
|
||||
|
@ -21,38 +21,41 @@ func TestBasicAuth(t *testing.T) {
|
||||
}
|
||||
return false
|
||||
}
|
||||
ba := BasicAuth(fn)
|
||||
h := func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "test")
|
||||
}
|
||||
mw := BasicAuth(fn)(h)
|
||||
|
||||
// Valid credentials
|
||||
auth := Basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
|
||||
auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
|
||||
req.Header().Set(echo.Authorization, auth)
|
||||
assert.NoError(t, ba(c))
|
||||
assert.NoError(t, mw(c))
|
||||
|
||||
//---------------------
|
||||
// Invalid credentials
|
||||
//---------------------
|
||||
|
||||
// Incorrect password
|
||||
auth = Basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:password"))
|
||||
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:password"))
|
||||
req.Header().Set(echo.Authorization, auth)
|
||||
he := ba(c).(*echo.HTTPError)
|
||||
he := mw(c).(*echo.HTTPError)
|
||||
assert.Equal(t, http.StatusUnauthorized, he.Code())
|
||||
assert.Equal(t, Basic+" realm=Restricted", res.Header().Get(echo.WWWAuthenticate))
|
||||
assert.Equal(t, basic+" realm=Restricted", res.Header().Get(echo.WWWAuthenticate))
|
||||
|
||||
// Empty Authorization header
|
||||
req.Header().Set(echo.Authorization, "")
|
||||
he = ba(c).(*echo.HTTPError)
|
||||
he = mw(c).(*echo.HTTPError)
|
||||
assert.Equal(t, http.StatusUnauthorized, he.Code())
|
||||
assert.Equal(t, Basic+" realm=Restricted", res.Header().Get(echo.WWWAuthenticate))
|
||||
assert.Equal(t, basic+" realm=Restricted", res.Header().Get(echo.WWWAuthenticate))
|
||||
|
||||
// Invalid Authorization header
|
||||
auth = base64.StdEncoding.EncodeToString([]byte("invalid"))
|
||||
req.Header().Set(echo.Authorization, auth)
|
||||
he = ba(c).(*echo.HTTPError)
|
||||
he = mw(c).(*echo.HTTPError)
|
||||
assert.Equal(t, http.StatusUnauthorized, he.Code())
|
||||
assert.Equal(t, Basic+" realm=Restricted", res.Header().Get(echo.WWWAuthenticate))
|
||||
assert.Equal(t, basic+" realm=Restricted", res.Header().Get(echo.WWWAuthenticate))
|
||||
|
||||
// WebSocket
|
||||
c.Request().Header().Set(echo.Upgrade, echo.WebSocket)
|
||||
assert.NoError(t, ba(c))
|
||||
assert.NoError(t, mw(c))
|
||||
}
|
||||
|
@ -48,10 +48,10 @@ var writerPool = sync.Pool{
|
||||
|
||||
// Gzip returns a middleware which compresses HTTP response using gzip compression
|
||||
// scheme.
|
||||
func Gzip() echo.MiddlewareFunc {
|
||||
scheme := "gzip"
|
||||
|
||||
func Gzip() MiddlewareFunc {
|
||||
return func(h echo.HandlerFunc) echo.HandlerFunc {
|
||||
scheme := "gzip"
|
||||
|
||||
return func(c echo.Context) error {
|
||||
c.Response().Header().Add(echo.Vary, echo.AcceptEncoding)
|
||||
if strings.Contains(c.Request().Header().Get(echo.AcceptEncoding), scheme) {
|
||||
|
@ -8,7 +8,7 @@ import (
|
||||
"github.com/labstack/gommon/color"
|
||||
)
|
||||
|
||||
func Logger() echo.MiddlewareFunc {
|
||||
func Log() MiddlewareFunc {
|
||||
return func(h echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
req := c.Request()
|
@ -12,18 +12,19 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestLogger(t *testing.T) {
|
||||
func TestLog(t *testing.T) {
|
||||
// Note: Just for the test coverage, not a real test.
|
||||
e := echo.New()
|
||||
req := test.NewRequest(echo.GET, "/", nil)
|
||||
rec := test.NewResponseRecorder()
|
||||
c := echo.NewContext(req, rec, e)
|
||||
|
||||
// Status 2xx
|
||||
h := func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "test")
|
||||
}
|
||||
Logger()(h)(c)
|
||||
mw := Log()(h)
|
||||
|
||||
// Status 2xx
|
||||
mw(c)
|
||||
|
||||
// Status 3xx
|
||||
rec = test.NewResponseRecorder()
|
||||
@ -31,7 +32,7 @@ func TestLogger(t *testing.T) {
|
||||
h = func(c echo.Context) error {
|
||||
return c.String(http.StatusTemporaryRedirect, "test")
|
||||
}
|
||||
Logger()(h)(c)
|
||||
mw(c)
|
||||
|
||||
// Status 4xx
|
||||
rec = test.NewResponseRecorder()
|
||||
@ -39,7 +40,7 @@ func TestLogger(t *testing.T) {
|
||||
h = func(c echo.Context) error {
|
||||
return c.String(http.StatusNotFound, "test")
|
||||
}
|
||||
Logger()(h)(c)
|
||||
mw(c)
|
||||
|
||||
// Status 5xx with empty path
|
||||
req = test.NewRequest(echo.GET, "", nil)
|
||||
@ -48,10 +49,10 @@ func TestLogger(t *testing.T) {
|
||||
h = func(c echo.Context) error {
|
||||
return errors.New("error")
|
||||
}
|
||||
Logger()(h)(c)
|
||||
mw(c)
|
||||
}
|
||||
|
||||
func TestLoggerIPAddress(t *testing.T) {
|
||||
func TestLogIPAddress(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := test.NewRequest(echo.GET, "/", nil)
|
||||
rec := test.NewResponseRecorder()
|
||||
@ -62,23 +63,22 @@ func TestLoggerIPAddress(t *testing.T) {
|
||||
h := func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "test")
|
||||
}
|
||||
|
||||
mw := Logger()
|
||||
mw := Log()(h)
|
||||
|
||||
// With X-Real-IP
|
||||
req.Header().Add(echo.XRealIP, ip)
|
||||
mw(h)(c)
|
||||
mw(c)
|
||||
assert.Contains(t, buf.String(), ip)
|
||||
|
||||
// With X-Forwarded-For
|
||||
buf.Reset()
|
||||
req.Header().Del(echo.XRealIP)
|
||||
req.Header().Add(echo.XForwardedFor, ip)
|
||||
mw(h)(c)
|
||||
mw(c)
|
||||
assert.Contains(t, buf.String(), ip)
|
||||
|
||||
// with req.RemoteAddr
|
||||
buf.Reset()
|
||||
mw(h)(c)
|
||||
mw(c)
|
||||
assert.Contains(t, buf.String(), ip)
|
||||
}
|
15
middleware/middleware.go
Normal file
15
middleware/middleware.go
Normal file
@ -0,0 +1,15 @@
|
||||
package middleware
|
||||
|
||||
import "github.com/labstack/echo"
|
||||
|
||||
type (
|
||||
Middleware interface {
|
||||
Process(echo.HandlerFunc) echo.HandlerFunc
|
||||
}
|
||||
|
||||
MiddlewareFunc func(echo.HandlerFunc) echo.HandlerFunc
|
||||
)
|
||||
|
||||
func (f MiddlewareFunc) Process(h echo.HandlerFunc) echo.HandlerFunc {
|
||||
return f(h)
|
||||
}
|
@ -10,9 +10,9 @@ import (
|
||||
|
||||
// Recover returns a middleware which recovers from panics anywhere in the chain
|
||||
// and handles the control to the centralized HTTPErrorHandler.
|
||||
func Recover() echo.MiddlewareFunc {
|
||||
// TODO: Provide better stack trace `https://github.com/go-errors/errors` `https://github.com/docker/libcontainer/tree/master/stacktrace`
|
||||
func Recover() MiddlewareFunc {
|
||||
return func(h echo.HandlerFunc) echo.HandlerFunc {
|
||||
// TODO: Provide better stack trace `https://github.com/go-errors/errors` `https://github.com/docker/libcontainer/tree/master/stacktrace`
|
||||
return func(c echo.Context) error {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
|
Loading…
Reference in New Issue
Block a user