1
0
mirror of https://github.com/labstack/echo.git synced 2024-12-24 20:14:31 +02:00

Added MiddlewareFunc to the list of supported middleware, #47.

Signed-off-by: Vishal Rana <vr@labstack.com>
This commit is contained in:
Vishal Rana 2015-05-10 21:09:28 -07:00
parent 8bdbf0ae0f
commit 7f92fbf3ce
3 changed files with 49 additions and 40 deletions

32
echo.go
View File

@ -349,6 +349,19 @@ func (e *Echo) RunTLSServer(server *http.Server, certFile, keyFile string) {
// wraps Middleware
func wrapM(m Middleware) MiddlewareFunc {
switch m := m.(type) {
case MiddlewareFunc:
return m
case func(HandlerFunc) HandlerFunc:
return m
case func(*Context) *HTTPError:
return func(h HandlerFunc) HandlerFunc {
return func(c *Context) *HTTPError {
if he := m(c); he != nil {
return he
}
return h(c)
}
}
case func(*Context):
return func(h HandlerFunc) HandlerFunc {
return func(c *Context) *HTTPError {
@ -359,17 +372,6 @@ func wrapM(m Middleware) MiddlewareFunc {
return nil
}
}
case func(*Context) *HTTPError:
return func(h HandlerFunc) HandlerFunc {
return func(c *Context) *HTTPError {
if he := m(c); he != nil {
return he
}
return h(c)
}
}
case func(HandlerFunc) HandlerFunc:
return m
case func(http.Handler) http.Handler:
return func(h HandlerFunc) HandlerFunc {
return func(c *Context) (he *HTTPError) {
@ -432,15 +434,15 @@ func wrapH(h Handler) HandlerFunc {
h.(http.Handler).ServeHTTP(c.Response, c.Request)
return nil
}
case func(http.ResponseWriter, *http.Request) *HTTPError:
return func(c *Context) *HTTPError {
return h(c.Response, c.Request)
}
case func(http.ResponseWriter, *http.Request):
return func(c *Context) *HTTPError {
h(c.Response, c.Request)
return nil
}
case func(http.ResponseWriter, *http.Request) *HTTPError:
return func(c *Context) *HTTPError {
return h(c.Response, c.Request)
}
default:
panic("echo: unknown handler")
}

View File

@ -54,51 +54,59 @@ func TestEchoMiddleware(t *testing.T) {
e := New()
b := new(bytes.Buffer)
// func(*echo.Context)
e.Use(func(c *Context) {
b.WriteString("a")
})
// func(*echo.Context) *HTTPError
e.Use(func(c *Context) *HTTPError {
b.WriteString("b")
return nil
})
// MiddlewareFunc
e.Use(MiddlewareFunc(func(h HandlerFunc) HandlerFunc {
return func(c *Context) *HTTPError {
b.WriteString("a")
return h(c)
}
}))
// func(echo.HandlerFunc) (echo.HandlerFunc, error)
e.Use(func(h HandlerFunc) HandlerFunc {
return func(c *Context) *HTTPError {
b.WriteString("c")
b.WriteString("b")
return h(c)
}
})
// http.HandlerFunc
e.Use(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
b.WriteString("d")
}))
// func(*echo.Context) *HTTPError
e.Use(func(c *Context) *HTTPError {
b.WriteString("c")
return nil
})
// http.Handler
e.Use(http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
b.WriteString("e")
})))
// func(*echo.Context)
e.Use(func(c *Context) {
b.WriteString("d")
})
// func(http.Handler) http.Handler
e.Use(func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
b.WriteString("f")
b.WriteString("e")
h.ServeHTTP(w, r)
})
})
// http.Handler
e.Use(http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
b.WriteString("f")
})))
// http.HandlerFunc
e.Use(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
b.WriteString("g")
}))
// func(http.ResponseWriter, *http.Request)
e.Use(func(w http.ResponseWriter, r *http.Request) {
b.WriteString("g")
b.WriteString("h")
})
// func(http.ResponseWriter, *http.Request) *HTTPError
e.Use(func(w http.ResponseWriter, r *http.Request) *HTTPError {
b.WriteString("h")
b.WriteString("i")
return nil
})
@ -110,8 +118,8 @@ func TestEchoMiddleware(t *testing.T) {
w := httptest.NewRecorder()
r, _ := http.NewRequest(GET, "/hello", nil)
e.ServeHTTP(w, r)
if b.String() != "abcdefgh" {
t.Errorf("buffer should be abcdefgh, found %s", b.String())
if b.String() != "abcdefghi" {
t.Errorf("buffer should be abcdefghi, found %s", b.String())
}
if w.Body.String() != "world" {
t.Error("body should be world")

View File

@ -535,7 +535,6 @@ func TestRouterPriority(t *testing.T) {
}
}
func TestRouterParamNames(t *testing.T) {
r := New().Router
b := new(bytes.Buffer)