1
0
mirror of https://github.com/labstack/echo.git synced 2024-12-24 20:14:31 +02:00

Added MiddlewareFunc to the list of supported middleware, #47.

Signed-off-by: Vishal Rana <vr@labstack.com>
This commit is contained in:
Vishal Rana 2015-05-10 21:09:28 -07:00
parent 8bdbf0ae0f
commit 7f92fbf3ce
3 changed files with 49 additions and 40 deletions

32
echo.go
View File

@ -349,6 +349,19 @@ func (e *Echo) RunTLSServer(server *http.Server, certFile, keyFile string) {
// wraps Middleware // wraps Middleware
func wrapM(m Middleware) MiddlewareFunc { func wrapM(m Middleware) MiddlewareFunc {
switch m := m.(type) { switch m := m.(type) {
case MiddlewareFunc:
return m
case func(HandlerFunc) HandlerFunc:
return m
case func(*Context) *HTTPError:
return func(h HandlerFunc) HandlerFunc {
return func(c *Context) *HTTPError {
if he := m(c); he != nil {
return he
}
return h(c)
}
}
case func(*Context): case func(*Context):
return func(h HandlerFunc) HandlerFunc { return func(h HandlerFunc) HandlerFunc {
return func(c *Context) *HTTPError { return func(c *Context) *HTTPError {
@ -359,17 +372,6 @@ func wrapM(m Middleware) MiddlewareFunc {
return nil return nil
} }
} }
case func(*Context) *HTTPError:
return func(h HandlerFunc) HandlerFunc {
return func(c *Context) *HTTPError {
if he := m(c); he != nil {
return he
}
return h(c)
}
}
case func(HandlerFunc) HandlerFunc:
return m
case func(http.Handler) http.Handler: case func(http.Handler) http.Handler:
return func(h HandlerFunc) HandlerFunc { return func(h HandlerFunc) HandlerFunc {
return func(c *Context) (he *HTTPError) { return func(c *Context) (he *HTTPError) {
@ -432,15 +434,15 @@ func wrapH(h Handler) HandlerFunc {
h.(http.Handler).ServeHTTP(c.Response, c.Request) h.(http.Handler).ServeHTTP(c.Response, c.Request)
return nil return nil
} }
case func(http.ResponseWriter, *http.Request) *HTTPError:
return func(c *Context) *HTTPError {
return h(c.Response, c.Request)
}
case func(http.ResponseWriter, *http.Request): case func(http.ResponseWriter, *http.Request):
return func(c *Context) *HTTPError { return func(c *Context) *HTTPError {
h(c.Response, c.Request) h(c.Response, c.Request)
return nil return nil
} }
case func(http.ResponseWriter, *http.Request) *HTTPError:
return func(c *Context) *HTTPError {
return h(c.Response, c.Request)
}
default: default:
panic("echo: unknown handler") panic("echo: unknown handler")
} }

View File

@ -54,51 +54,59 @@ func TestEchoMiddleware(t *testing.T) {
e := New() e := New()
b := new(bytes.Buffer) b := new(bytes.Buffer)
// func(*echo.Context) // MiddlewareFunc
e.Use(func(c *Context) { e.Use(MiddlewareFunc(func(h HandlerFunc) HandlerFunc {
b.WriteString("a") return func(c *Context) *HTTPError {
}) b.WriteString("a")
return h(c)
// func(*echo.Context) *HTTPError }
e.Use(func(c *Context) *HTTPError { }))
b.WriteString("b")
return nil
})
// func(echo.HandlerFunc) (echo.HandlerFunc, error) // func(echo.HandlerFunc) (echo.HandlerFunc, error)
e.Use(func(h HandlerFunc) HandlerFunc { e.Use(func(h HandlerFunc) HandlerFunc {
return func(c *Context) *HTTPError { return func(c *Context) *HTTPError {
b.WriteString("c") b.WriteString("b")
return h(c) return h(c)
} }
}) })
// http.HandlerFunc // func(*echo.Context) *HTTPError
e.Use(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { e.Use(func(c *Context) *HTTPError {
b.WriteString("d") b.WriteString("c")
})) return nil
})
// http.Handler // func(*echo.Context)
e.Use(http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { e.Use(func(c *Context) {
b.WriteString("e") b.WriteString("d")
}))) })
// func(http.Handler) http.Handler // func(http.Handler) http.Handler
e.Use(func(h http.Handler) http.Handler { e.Use(func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
b.WriteString("f") b.WriteString("e")
h.ServeHTTP(w, r) h.ServeHTTP(w, r)
}) })
}) })
// http.Handler
e.Use(http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
b.WriteString("f")
})))
// http.HandlerFunc
e.Use(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
b.WriteString("g")
}))
// func(http.ResponseWriter, *http.Request) // func(http.ResponseWriter, *http.Request)
e.Use(func(w http.ResponseWriter, r *http.Request) { e.Use(func(w http.ResponseWriter, r *http.Request) {
b.WriteString("g") b.WriteString("h")
}) })
// func(http.ResponseWriter, *http.Request) *HTTPError // func(http.ResponseWriter, *http.Request) *HTTPError
e.Use(func(w http.ResponseWriter, r *http.Request) *HTTPError { e.Use(func(w http.ResponseWriter, r *http.Request) *HTTPError {
b.WriteString("h") b.WriteString("i")
return nil return nil
}) })
@ -110,8 +118,8 @@ func TestEchoMiddleware(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
r, _ := http.NewRequest(GET, "/hello", nil) r, _ := http.NewRequest(GET, "/hello", nil)
e.ServeHTTP(w, r) e.ServeHTTP(w, r)
if b.String() != "abcdefgh" { if b.String() != "abcdefghi" {
t.Errorf("buffer should be abcdefgh, found %s", b.String()) t.Errorf("buffer should be abcdefghi, found %s", b.String())
} }
if w.Body.String() != "world" { if w.Body.String() != "world" {
t.Error("body should be world") t.Error("body should be world")

View File

@ -535,7 +535,6 @@ func TestRouterPriority(t *testing.T) {
} }
} }
func TestRouterParamNames(t *testing.T) { func TestRouterParamNames(t *testing.T) {
r := New().Router r := New().Router
b := new(bytes.Buffer) b := new(bytes.Buffer)