1
0
mirror of https://github.com/labstack/echo.git synced 2025-05-31 23:19:42 +02:00

Change middleware function signature.

Signed-off-by: Vishal Rana <vr@labstack.com>
This commit is contained in:
Vishal Rana 2015-04-18 21:46:00 -07:00
parent a9c41b367c
commit f134ea3aea
10 changed files with 79 additions and 60 deletions

View File

@ -8,7 +8,7 @@ Echo is a fast HTTP router (zero memory allocation) and micro web framework in G
- Middleware - Middleware
- `func(*echo.Context)` - `func(*echo.Context)`
- `func(*echo.Context) error` - `func(*echo.Context) error`
- `func(echo.HandlerFunc) echo.HandlerFunc error` - `func(echo.HandlerFunc) echo.HandlerFunc`
- `func(http.Handler) http.Handler` - `func(http.Handler) http.Handler`
- `http.Handler` - `http.Handler`
- `http.HandlerFunc` - `http.HandlerFunc`

101
echo.go
View File

@ -8,6 +8,10 @@ import (
"net/http" "net/http"
"strings" "strings"
"sync" "sync"
"github.com/mattn/go-colorable"
"labstack.com/gommon/color"
) )
type ( type (
@ -23,7 +27,7 @@ type (
pool sync.Pool pool sync.Pool
} }
Middleware interface{} Middleware interface{}
MiddlewareFunc func(HandlerFunc) (HandlerFunc, error) MiddlewareFunc func(HandlerFunc) HandlerFunc
Handler interface{} Handler interface{}
HandlerFunc func(*Context) error HandlerFunc func(*Context) error
@ -84,25 +88,7 @@ var (
// New creates an Echo instance. // New creates an Echo instance.
func New() (e *Echo) { func New() (e *Echo) {
e = &Echo{ e = &Echo{}
maxParam: 5,
notFoundHandler: func(c *Context) error {
http.Error(c.Response, http.StatusText(http.StatusNotFound), http.StatusNotFound)
return nil
},
httpErrorHandler: func(err error, c *Context) {
http.Error(c.Response, err.Error(), http.StatusInternalServerError)
},
binder: func(r *http.Request, v interface{}) error {
ct := r.Header.Get(HeaderContentType)
if strings.HasPrefix(ct, MIMEJSON) {
return json.NewDecoder(r.Body).Decode(v)
} else if strings.HasPrefix(ct, MIMEForm) {
return nil
}
return ErrUnsupportedMediaType
},
}
e.Router = NewRouter(e) e.Router = NewRouter(e)
e.pool.New = func() interface{} { e.pool.New = func() interface{} {
return &Context{ return &Context{
@ -112,6 +98,31 @@ func New() (e *Echo) {
echo: e, // TODO: Do we need this? echo: e, // TODO: Do we need this?
} }
} }
//----------
// Defaults
//----------
e.MaxParam(5)
e.NotFoundHandler(func(c *Context) {
http.Error(c.Response, http.StatusText(http.StatusNotFound), http.StatusNotFound)
})
e.HTTPErrorHandler(func(err error, c *Context) {
if err != nil {
// TODO: Warning
log.Println(color.Yellow("echo: HTTP error handler not registered"))
http.Error(c.Response, err.Error(), http.StatusInternalServerError)
}
})
e.Binder(func(r *http.Request, v interface{}) error {
ct := r.Header.Get(HeaderContentType)
if strings.HasPrefix(ct, MIMEJSON) {
return json.NewDecoder(r.Body).Decode(v)
} else if strings.HasPrefix(ct, MIMEForm) {
return nil
}
return ErrUnsupportedMediaType
})
return return
} }
@ -248,12 +259,8 @@ func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) {
c.reset(w, r, e) c.reset(w, r, e)
// Middleware // Middleware
var err error
for i := len(e.middleware) - 1; i >= 0; i-- { for i := len(e.middleware) - 1; i >= 0; i-- {
if h, err = e.middleware[i](h); err != nil { h = e.middleware[i](h)
e.httpErrorHandler(err, c)
return
}
} }
// Handler // Handler
@ -290,50 +297,52 @@ func (e *Echo) RunTLSServer(server *http.Server, certFile, keyFile string) {
func wrapM(m Middleware) MiddlewareFunc { func wrapM(m Middleware) MiddlewareFunc {
switch m := m.(type) { switch m := m.(type) {
case func(*Context): case func(*Context):
return func(h HandlerFunc) (HandlerFunc, error) { return func(h HandlerFunc) HandlerFunc {
return func(c *Context) error { return func(c *Context) error {
m(c) m(c)
return h(c) return h(c)
}, nil }
} }
case func(*Context) error: case func(*Context) error:
return func(h HandlerFunc) (HandlerFunc, error) { return func(h HandlerFunc) HandlerFunc {
var err error
return func(c *Context) error { return func(c *Context) error {
err = m(c) if err := m(c); err != nil {
return err
}
return h(c) return h(c)
}, err }
} }
case func(HandlerFunc) (HandlerFunc, error): case func(HandlerFunc) HandlerFunc:
return MiddlewareFunc(m) return m
case func(http.Handler) http.Handler: case func(http.Handler) http.Handler:
return func(h HandlerFunc) (HandlerFunc, error) { return func(h HandlerFunc) HandlerFunc {
return func(c *Context) error { return func(c *Context) error {
m(h).ServeHTTP(c.Response, c.Request) m(h).ServeHTTP(c.Response, c.Request)
return h(c) return h(c)
}, nil }
} }
case http.Handler, http.HandlerFunc: case http.Handler, http.HandlerFunc:
return func(h HandlerFunc) (HandlerFunc, error) { return func(h HandlerFunc) HandlerFunc {
return func(c *Context) error { return func(c *Context) error {
m.(http.Handler).ServeHTTP(c.Response, c.Request) m.(http.Handler).ServeHTTP(c.Response, c.Request)
return h(c) return h(c)
}, nil }
} }
case func(http.ResponseWriter, *http.Request): case func(http.ResponseWriter, *http.Request):
return func(h HandlerFunc) (HandlerFunc, error) { return func(h HandlerFunc) HandlerFunc {
return func(c *Context) error { return func(c *Context) error {
m(c.Response, c.Request) m(c.Response, c.Request)
return h(c) return h(c)
}, nil }
} }
case func(http.ResponseWriter, *http.Request) error: case func(http.ResponseWriter, *http.Request) error:
return func(h HandlerFunc) (HandlerFunc, error) { return func(h HandlerFunc) HandlerFunc {
var err error
return func(c *Context) error { return func(c *Context) error {
err = m(c.Response, c.Request) if err := m(c.Response, c.Request); err != nil {
return err
}
return h(c) return h(c)
}, err }
} }
default: default:
panic("echo: unknown middleware") panic("echo: unknown middleware")
@ -349,7 +358,7 @@ func wrapH(h Handler) HandlerFunc {
return nil return nil
} }
case func(*Context) error: case func(*Context) error:
return HandlerFunc(h) return h
case http.Handler, http.HandlerFunc: case http.Handler, http.HandlerFunc:
return func(c *Context) error { return func(c *Context) error {
h.(http.Handler).ServeHTTP(c.Response, c.Request) h.(http.Handler).ServeHTTP(c.Response, c.Request)
@ -368,3 +377,7 @@ func wrapH(h Handler) HandlerFunc {
panic("echo: unknown handler") panic("echo: unknown handler")
} }
} }
func init() {
log.SetOutput(colorable.NewColorableStdout())
}

View File

@ -30,7 +30,7 @@ func TestEchoMaxParam(t *testing.T) {
func TestEchoIndex(t *testing.T) { func TestEchoIndex(t *testing.T) {
e := New() e := New()
e.Index("examples/public/index.html") e.Index("examples/web/public/index.html")
w := httptest.NewRecorder() w := httptest.NewRecorder()
r, _ := http.NewRequest(GET, "/", nil) r, _ := http.NewRequest(GET, "/", nil)
e.ServeHTTP(w, r) e.ServeHTTP(w, r)
@ -41,7 +41,7 @@ func TestEchoIndex(t *testing.T) {
func TestEchoStatic(t *testing.T) { func TestEchoStatic(t *testing.T) {
e := New() e := New()
e.Static("/scripts", "examples/public/scripts") e.Static("/scripts", "examples/web/public/scripts")
w := httptest.NewRecorder() w := httptest.NewRecorder()
r, _ := http.NewRequest(GET, "/scripts/main.js", nil) r, _ := http.NewRequest(GET, "/scripts/main.js", nil)
e.ServeHTTP(w, r) e.ServeHTTP(w, r)
@ -66,11 +66,11 @@ func TestEchoMiddleware(t *testing.T) {
}) })
// func(echo.HandlerFunc) (echo.HandlerFunc, error) // func(echo.HandlerFunc) (echo.HandlerFunc, error)
e.Use(func(h HandlerFunc) (HandlerFunc, error) { e.Use(func(h HandlerFunc) HandlerFunc {
return func(c *Context) error { return func(c *Context) error {
b.WriteString("c") b.WriteString("c")
return h(c) return h(c)
}, nil }
}) })
// http.HandlerFunc // http.HandlerFunc
@ -97,8 +97,9 @@ func TestEchoMiddleware(t *testing.T) {
}) })
// func(http.ResponseWriter, *http.Request) error // func(http.ResponseWriter, *http.Request) error
e.Use(func(w http.ResponseWriter, r *http.Request) { e.Use(func(w http.ResponseWriter, r *http.Request) error {
b.WriteString("h") b.WriteString("h")
return nil
}) })
// Route // Route

View File

@ -12,12 +12,11 @@ type (
user struct { user struct {
ID int ID int
Name string Name string
Age int
} }
) )
var ( var (
users = map[int]user{} users = map[int]*user{}
seq = 1 seq = 1
) )
@ -32,7 +31,7 @@ func createUser(c *echo.Context) error {
if err := c.Bind(u); err != nil { if err := c.Bind(u); err != nil {
return err return err
} }
users[u.ID] = *u users[u.ID] = u
seq++ seq++
return c.JSON(http.StatusCreated, u) return c.JSON(http.StatusCreated, u)
} }
@ -43,9 +42,13 @@ func getUser(c *echo.Context) error {
} }
func updateUser(c *echo.Context) error { func updateUser(c *echo.Context) error {
// id, _ := strconv.Atoi(c.Param("id")) u := new(user)
// users[id] if err := c.Bind(u); err != nil {
return c.NoContent(http.StatusNoContent) return err
}
id, _ := strconv.Atoi(c.Param("id"))
users[id].Name = u.Name
return c.JSON(http.StatusOK, users[id])
} }
func deleteUser(c *echo.Context) error { func deleteUser(c *echo.Context) error {
@ -63,9 +66,9 @@ func main() {
// Routes // Routes
e.Post("/users", createUser) e.Post("/users", createUser)
e.Get("/users/:id", getUser) e.Get("/users/:id", getUser)
e.Put("/users/:id", updateUser) e.Patch("/users/:id", updateUser)
e.Delete("/users/:id", deleteUser) e.Delete("/users/:id", deleteUser)
// Start server // Start server
e.Run(":8080") e.Run(":4444")
} }

View File

@ -22,5 +22,5 @@ func main() {
e.Get("/", hello) e.Get("/", hello)
// Start server // Start server
e.Run(":8080") e.Run(":4444")
} }

View File

@ -3,6 +3,8 @@ package echo
import ( import (
"log" "log"
"net/http" "net/http"
"labstack.com/gommon/color"
) )
type ( type (
@ -21,7 +23,7 @@ func (r *response) Header() http.Header {
func (r *response) WriteHeader(n int) { func (r *response) WriteHeader(n int) {
if r.committed { if r.committed {
// TODO: Warning // TODO: Warning
log.Println("echo: response already committed") log.Println(color.Yellow("echo: response already committed"))
return return
} }
r.status = n r.status = n