1
0
mirror of https://github.com/labstack/echo.git synced 2025-05-29 23:17:34 +02:00

Change middleware function signature.

Signed-off-by: Vishal Rana <vr@labstack.com>
This commit is contained in:
Vishal Rana 2015-04-18 21:46:00 -07:00
parent a9c41b367c
commit f134ea3aea
10 changed files with 79 additions and 60 deletions

View File

@ -8,7 +8,7 @@ Echo is a fast HTTP router (zero memory allocation) and micro web framework in G
- Middleware
- `func(*echo.Context)`
- `func(*echo.Context) error`
- `func(echo.HandlerFunc) echo.HandlerFunc error`
- `func(echo.HandlerFunc) echo.HandlerFunc`
- `func(http.Handler) http.Handler`
- `http.Handler`
- `http.HandlerFunc`

101
echo.go
View File

@ -8,6 +8,10 @@ import (
"net/http"
"strings"
"sync"
"github.com/mattn/go-colorable"
"labstack.com/gommon/color"
)
type (
@ -23,7 +27,7 @@ type (
pool sync.Pool
}
Middleware interface{}
MiddlewareFunc func(HandlerFunc) (HandlerFunc, error)
MiddlewareFunc func(HandlerFunc) HandlerFunc
Handler interface{}
HandlerFunc func(*Context) error
@ -84,25 +88,7 @@ var (
// New creates an Echo instance.
func New() (e *Echo) {
e = &Echo{
maxParam: 5,
notFoundHandler: func(c *Context) error {
http.Error(c.Response, http.StatusText(http.StatusNotFound), http.StatusNotFound)
return nil
},
httpErrorHandler: func(err error, c *Context) {
http.Error(c.Response, err.Error(), http.StatusInternalServerError)
},
binder: func(r *http.Request, v interface{}) error {
ct := r.Header.Get(HeaderContentType)
if strings.HasPrefix(ct, MIMEJSON) {
return json.NewDecoder(r.Body).Decode(v)
} else if strings.HasPrefix(ct, MIMEForm) {
return nil
}
return ErrUnsupportedMediaType
},
}
e = &Echo{}
e.Router = NewRouter(e)
e.pool.New = func() interface{} {
return &Context{
@ -112,6 +98,31 @@ func New() (e *Echo) {
echo: e, // TODO: Do we need this?
}
}
//----------
// Defaults
//----------
e.MaxParam(5)
e.NotFoundHandler(func(c *Context) {
http.Error(c.Response, http.StatusText(http.StatusNotFound), http.StatusNotFound)
})
e.HTTPErrorHandler(func(err error, c *Context) {
if err != nil {
// TODO: Warning
log.Println(color.Yellow("echo: HTTP error handler not registered"))
http.Error(c.Response, err.Error(), http.StatusInternalServerError)
}
})
e.Binder(func(r *http.Request, v interface{}) error {
ct := r.Header.Get(HeaderContentType)
if strings.HasPrefix(ct, MIMEJSON) {
return json.NewDecoder(r.Body).Decode(v)
} else if strings.HasPrefix(ct, MIMEForm) {
return nil
}
return ErrUnsupportedMediaType
})
return
}
@ -248,12 +259,8 @@ func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) {
c.reset(w, r, e)
// Middleware
var err error
for i := len(e.middleware) - 1; i >= 0; i-- {
if h, err = e.middleware[i](h); err != nil {
e.httpErrorHandler(err, c)
return
}
h = e.middleware[i](h)
}
// Handler
@ -290,50 +297,52 @@ func (e *Echo) RunTLSServer(server *http.Server, certFile, keyFile string) {
func wrapM(m Middleware) MiddlewareFunc {
switch m := m.(type) {
case func(*Context):
return func(h HandlerFunc) (HandlerFunc, error) {
return func(h HandlerFunc) HandlerFunc {
return func(c *Context) error {
m(c)
return h(c)
}, nil
}
}
case func(*Context) error:
return func(h HandlerFunc) (HandlerFunc, error) {
var err error
return func(h HandlerFunc) HandlerFunc {
return func(c *Context) error {
err = m(c)
if err := m(c); err != nil {
return err
}
return h(c)
}, err
}
}
case func(HandlerFunc) (HandlerFunc, error):
return MiddlewareFunc(m)
case func(HandlerFunc) HandlerFunc:
return m
case func(http.Handler) http.Handler:
return func(h HandlerFunc) (HandlerFunc, error) {
return func(h HandlerFunc) HandlerFunc {
return func(c *Context) error {
m(h).ServeHTTP(c.Response, c.Request)
return h(c)
}, nil
}
}
case http.Handler, http.HandlerFunc:
return func(h HandlerFunc) (HandlerFunc, error) {
return func(h HandlerFunc) HandlerFunc {
return func(c *Context) error {
m.(http.Handler).ServeHTTP(c.Response, c.Request)
return h(c)
}, nil
}
}
case func(http.ResponseWriter, *http.Request):
return func(h HandlerFunc) (HandlerFunc, error) {
return func(h HandlerFunc) HandlerFunc {
return func(c *Context) error {
m(c.Response, c.Request)
return h(c)
}, nil
}
}
case func(http.ResponseWriter, *http.Request) error:
return func(h HandlerFunc) (HandlerFunc, error) {
var err error
return func(h HandlerFunc) HandlerFunc {
return func(c *Context) error {
err = m(c.Response, c.Request)
if err := m(c.Response, c.Request); err != nil {
return err
}
return h(c)
}, err
}
}
default:
panic("echo: unknown middleware")
@ -349,7 +358,7 @@ func wrapH(h Handler) HandlerFunc {
return nil
}
case func(*Context) error:
return HandlerFunc(h)
return h
case http.Handler, http.HandlerFunc:
return func(c *Context) error {
h.(http.Handler).ServeHTTP(c.Response, c.Request)
@ -368,3 +377,7 @@ func wrapH(h Handler) HandlerFunc {
panic("echo: unknown handler")
}
}
func init() {
log.SetOutput(colorable.NewColorableStdout())
}

View File

@ -30,7 +30,7 @@ func TestEchoMaxParam(t *testing.T) {
func TestEchoIndex(t *testing.T) {
e := New()
e.Index("examples/public/index.html")
e.Index("examples/web/public/index.html")
w := httptest.NewRecorder()
r, _ := http.NewRequest(GET, "/", nil)
e.ServeHTTP(w, r)
@ -41,7 +41,7 @@ func TestEchoIndex(t *testing.T) {
func TestEchoStatic(t *testing.T) {
e := New()
e.Static("/scripts", "examples/public/scripts")
e.Static("/scripts", "examples/web/public/scripts")
w := httptest.NewRecorder()
r, _ := http.NewRequest(GET, "/scripts/main.js", nil)
e.ServeHTTP(w, r)
@ -66,11 +66,11 @@ func TestEchoMiddleware(t *testing.T) {
})
// func(echo.HandlerFunc) (echo.HandlerFunc, error)
e.Use(func(h HandlerFunc) (HandlerFunc, error) {
e.Use(func(h HandlerFunc) HandlerFunc {
return func(c *Context) error {
b.WriteString("c")
return h(c)
}, nil
}
})
// http.HandlerFunc
@ -97,8 +97,9 @@ func TestEchoMiddleware(t *testing.T) {
})
// func(http.ResponseWriter, *http.Request) error
e.Use(func(w http.ResponseWriter, r *http.Request) {
e.Use(func(w http.ResponseWriter, r *http.Request) error {
b.WriteString("h")
return nil
})
// Route

View File

@ -12,12 +12,11 @@ type (
user struct {
ID int
Name string
Age int
}
)
var (
users = map[int]user{}
users = map[int]*user{}
seq = 1
)
@ -32,7 +31,7 @@ func createUser(c *echo.Context) error {
if err := c.Bind(u); err != nil {
return err
}
users[u.ID] = *u
users[u.ID] = u
seq++
return c.JSON(http.StatusCreated, u)
}
@ -43,9 +42,13 @@ func getUser(c *echo.Context) error {
}
func updateUser(c *echo.Context) error {
// id, _ := strconv.Atoi(c.Param("id"))
// users[id]
return c.NoContent(http.StatusNoContent)
u := new(user)
if err := c.Bind(u); err != nil {
return err
}
id, _ := strconv.Atoi(c.Param("id"))
users[id].Name = u.Name
return c.JSON(http.StatusOK, users[id])
}
func deleteUser(c *echo.Context) error {
@ -63,9 +66,9 @@ func main() {
// Routes
e.Post("/users", createUser)
e.Get("/users/:id", getUser)
e.Put("/users/:id", updateUser)
e.Patch("/users/:id", updateUser)
e.Delete("/users/:id", deleteUser)
// Start server
e.Run(":8080")
e.Run(":4444")
}

View File

@ -22,5 +22,5 @@ func main() {
e.Get("/", hello)
// Start server
e.Run(":8080")
e.Run(":4444")
}

View File

@ -3,6 +3,8 @@ package echo
import (
"log"
"net/http"
"labstack.com/gommon/color"
)
type (
@ -21,7 +23,7 @@ func (r *response) Header() http.Header {
func (r *response) WriteHeader(n int) {
if r.committed {
// TODO: Warning
log.Println("echo: response already committed")
log.Println(color.Yellow("echo: response already committed"))
return
}
r.status = n