1
0
mirror of https://github.com/labstack/echo.git synced 2024-12-24 20:14:31 +02:00

Middleware interface

Signed-off-by: Vishal Rana <vr@labstack.com>
This commit is contained in:
Vishal Rana 2016-02-07 23:02:37 -08:00
parent f27de9a804
commit 65fcca2ce3
8 changed files with 74 additions and 54 deletions

8
glide.lock generated
View File

@ -1,5 +1,5 @@
hash: 57f60c6f58e40de3ea10123144095336abe2a4024e256feea9721e4e89ad38be
updated: 2016-02-04T19:09:01.247772095-08:00
updated: 2016-02-07T21:37:59.770272986-08:00
imports:
- name: github.com/labstack/gommon
version: 28bf9248c3e81039c18bff6a1cf2a7f0841caf35
@ -11,15 +11,15 @@ imports:
- name: github.com/mattn/go-isatty
version: 56b76bdf51f7708750eac80fa38b952bb9f32639
- name: github.com/valyala/fasthttp
version: df213349e25e909911691ffe752149440d969641
version: 338fe46307aad35771b73f3f146dec3851c23f7f
- name: golang.org/x/crypto
version: 1f22c0103821b9390939b6776727195525381532
- name: golang.org/x/net
version: 493a26246902f2887349f625a5f846bf0286af49
version: 7f88271ea9913b72aca44fa7fc8af919eacc17ce
subpackages:
- /context
- http2
- websocket
- name: golang.org/x/text
version: cd1d59e467ef512633026c0f15282e108e41d453
version: 2ea5e055772cf5daa0b1478d6e88c8d0c3d4cb79
devImports: []

View File

@ -12,38 +12,40 @@ type (
)
const (
Basic = "Basic"
basic = "Basic"
)
// BasicAuth returns an HTTP basic authentication middleware.
//
// For valid credentials it calls the next handler.
// For invalid credentials, it sends "401 - Unauthorized" response.
func BasicAuth(fn BasicValidateFunc) echo.HandlerFunc {
return func(c echo.Context) error {
// Skip WebSocket
if (c.Request().Header().Get(echo.Upgrade)) == echo.WebSocket {
return nil
}
func BasicAuth(fn BasicValidateFunc) MiddlewareFunc {
return func(h echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
// Skip WebSocket
if (c.Request().Header().Get(echo.Upgrade)) == echo.WebSocket {
return nil
}
auth := c.Request().Header().Get(echo.Authorization)
l := len(Basic)
auth := c.Request().Header().Get(echo.Authorization)
l := len(basic)
if len(auth) > l+1 && auth[:l] == Basic {
b, err := base64.StdEncoding.DecodeString(auth[l+1:])
if err == nil {
cred := string(b)
for i := 0; i < len(cred); i++ {
if cred[i] == ':' {
// Verify credentials
if fn(cred[:i], cred[i+1:]) {
return nil
if len(auth) > l+1 && auth[:l] == basic {
b, err := base64.StdEncoding.DecodeString(auth[l+1:])
if err == nil {
cred := string(b)
for i := 0; i < len(cred); i++ {
if cred[i] == ':' {
// Verify credentials
if fn(cred[:i], cred[i+1:]) {
return nil
}
}
}
}
}
c.Response().Header().Set(echo.WWWAuthenticate, basic+" realm=Restricted")
return echo.NewHTTPError(http.StatusUnauthorized)
}
c.Response().Header().Set(echo.WWWAuthenticate, Basic+" realm=Restricted")
return echo.NewHTTPError(http.StatusUnauthorized)
}
}

View File

@ -21,38 +21,41 @@ func TestBasicAuth(t *testing.T) {
}
return false
}
ba := BasicAuth(fn)
h := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
mw := BasicAuth(fn)(h)
// Valid credentials
auth := Basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
req.Header().Set(echo.Authorization, auth)
assert.NoError(t, ba(c))
assert.NoError(t, mw(c))
//---------------------
// Invalid credentials
//---------------------
// Incorrect password
auth = Basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:password"))
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:password"))
req.Header().Set(echo.Authorization, auth)
he := ba(c).(*echo.HTTPError)
he := mw(c).(*echo.HTTPError)
assert.Equal(t, http.StatusUnauthorized, he.Code())
assert.Equal(t, Basic+" realm=Restricted", res.Header().Get(echo.WWWAuthenticate))
assert.Equal(t, basic+" realm=Restricted", res.Header().Get(echo.WWWAuthenticate))
// Empty Authorization header
req.Header().Set(echo.Authorization, "")
he = ba(c).(*echo.HTTPError)
he = mw(c).(*echo.HTTPError)
assert.Equal(t, http.StatusUnauthorized, he.Code())
assert.Equal(t, Basic+" realm=Restricted", res.Header().Get(echo.WWWAuthenticate))
assert.Equal(t, basic+" realm=Restricted", res.Header().Get(echo.WWWAuthenticate))
// Invalid Authorization header
auth = base64.StdEncoding.EncodeToString([]byte("invalid"))
req.Header().Set(echo.Authorization, auth)
he = ba(c).(*echo.HTTPError)
he = mw(c).(*echo.HTTPError)
assert.Equal(t, http.StatusUnauthorized, he.Code())
assert.Equal(t, Basic+" realm=Restricted", res.Header().Get(echo.WWWAuthenticate))
assert.Equal(t, basic+" realm=Restricted", res.Header().Get(echo.WWWAuthenticate))
// WebSocket
c.Request().Header().Set(echo.Upgrade, echo.WebSocket)
assert.NoError(t, ba(c))
assert.NoError(t, mw(c))
}

View File

@ -48,10 +48,10 @@ var writerPool = sync.Pool{
// Gzip returns a middleware which compresses HTTP response using gzip compression
// scheme.
func Gzip() echo.MiddlewareFunc {
scheme := "gzip"
func Gzip() MiddlewareFunc {
return func(h echo.HandlerFunc) echo.HandlerFunc {
scheme := "gzip"
return func(c echo.Context) error {
c.Response().Header().Add(echo.Vary, echo.AcceptEncoding)
if strings.Contains(c.Request().Header().Get(echo.AcceptEncoding), scheme) {

View File

@ -8,7 +8,7 @@ import (
"github.com/labstack/gommon/color"
)
func Logger() echo.MiddlewareFunc {
func Log() MiddlewareFunc {
return func(h echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
req := c.Request()

View File

@ -12,18 +12,19 @@ import (
"github.com/stretchr/testify/assert"
)
func TestLogger(t *testing.T) {
func TestLog(t *testing.T) {
// Note: Just for the test coverage, not a real test.
e := echo.New()
req := test.NewRequest(echo.GET, "/", nil)
rec := test.NewResponseRecorder()
c := echo.NewContext(req, rec, e)
// Status 2xx
h := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
Logger()(h)(c)
mw := Log()(h)
// Status 2xx
mw(c)
// Status 3xx
rec = test.NewResponseRecorder()
@ -31,7 +32,7 @@ func TestLogger(t *testing.T) {
h = func(c echo.Context) error {
return c.String(http.StatusTemporaryRedirect, "test")
}
Logger()(h)(c)
mw(c)
// Status 4xx
rec = test.NewResponseRecorder()
@ -39,7 +40,7 @@ func TestLogger(t *testing.T) {
h = func(c echo.Context) error {
return c.String(http.StatusNotFound, "test")
}
Logger()(h)(c)
mw(c)
// Status 5xx with empty path
req = test.NewRequest(echo.GET, "", nil)
@ -48,10 +49,10 @@ func TestLogger(t *testing.T) {
h = func(c echo.Context) error {
return errors.New("error")
}
Logger()(h)(c)
mw(c)
}
func TestLoggerIPAddress(t *testing.T) {
func TestLogIPAddress(t *testing.T) {
e := echo.New()
req := test.NewRequest(echo.GET, "/", nil)
rec := test.NewResponseRecorder()
@ -62,23 +63,22 @@ func TestLoggerIPAddress(t *testing.T) {
h := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}
mw := Logger()
mw := Log()(h)
// With X-Real-IP
req.Header().Add(echo.XRealIP, ip)
mw(h)(c)
mw(c)
assert.Contains(t, buf.String(), ip)
// With X-Forwarded-For
buf.Reset()
req.Header().Del(echo.XRealIP)
req.Header().Add(echo.XForwardedFor, ip)
mw(h)(c)
mw(c)
assert.Contains(t, buf.String(), ip)
// with req.RemoteAddr
buf.Reset()
mw(h)(c)
mw(c)
assert.Contains(t, buf.String(), ip)
}

15
middleware/middleware.go Normal file
View File

@ -0,0 +1,15 @@
package middleware
import "github.com/labstack/echo"
type (
Middleware interface {
Process(echo.HandlerFunc) echo.HandlerFunc
}
MiddlewareFunc func(echo.HandlerFunc) echo.HandlerFunc
)
func (f MiddlewareFunc) Process(h echo.HandlerFunc) echo.HandlerFunc {
return f(h)
}

View File

@ -10,9 +10,9 @@ import (
// Recover returns a middleware which recovers from panics anywhere in the chain
// and handles the control to the centralized HTTPErrorHandler.
func Recover() echo.MiddlewareFunc {
// TODO: Provide better stack trace `https://github.com/go-errors/errors` `https://github.com/docker/libcontainer/tree/master/stacktrace`
func Recover() MiddlewareFunc {
return func(h echo.HandlerFunc) echo.HandlerFunc {
// TODO: Provide better stack trace `https://github.com/go-errors/errors` `https://github.com/docker/libcontainer/tree/master/stacktrace`
return func(c echo.Context) error {
defer func() {
if err := recover(); err != nil {