diff --git a/README.md b/README.md index 27d0d3f5..3492f917 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,7 @@ Echo is a fast HTTP router (zero memory allocation) and micro web framework in Go. ## Features + - Fast :rocket: router which smartly resolves conflicting routes - Extensible middleware/handler, supports: - Middleware @@ -21,8 +22,11 @@ Echo is a fast HTTP router (zero memory allocation) and micro web framework in G - Serve static files, including index. ## Benchmark + Based on [julienschmidt/go-http-routing-benchmark] (https://github.com/vishr/go-http-routing-benchmark), April 1, 2015 + ##### [GitHub API](http://developer.github.com/v3) + > Echo: 42728 ns/op, 0 B/op, 0 allocs/op ``` @@ -54,9 +58,11 @@ BenchmarkZeus_GithubAll 2000 752907 ns/op 300688 B/op 2648 all ``` ## Installation + ```go get github.com/labstack/echo``` ## Example + [labstack/echo/example](https://github.com/labstack/echo/tree/master/example) ```go @@ -198,4 +204,5 @@ func init() { } ``` ## License + [MIT](https://github.com/labstack/echo/blob/master/LICENSE) diff --git a/context.go b/context.go index df5b34be..aed7b443 100644 --- a/context.go +++ b/context.go @@ -40,13 +40,13 @@ func (c *Context) Bind(v interface{}) error { } // Render invokes the registered HTML template renderer and sends a text/html -// response. -func (c *Context) Render(name string, data interface{}) error { +// response with status code. +func (c *Context) Render(code int, name string, data interface{}) error { if c.echo.renderer == nil { return ErrNoRenderer } c.Response.Header().Set(HeaderContentType, MIMEHTML+"; charset=utf-8") - c.Response.WriteHeader(http.StatusOK) + c.Response.WriteHeader(code) return c.echo.renderer.Render(c.Response, name, data) } @@ -73,6 +73,12 @@ func (c *Context) HTML(code int, html string) (err error) { return } +// NoContent sends a response with no body and a status code. +func (c *Context) NoContent(code int) error { + c.Response.WriteHeader(code) + return nil +} + // func (c *Context) File(code int, file, name string) { // } diff --git a/context_test.go b/context_test.go index d8d6c302..936c0bfa 100644 --- a/context_test.go +++ b/context_test.go @@ -31,9 +31,10 @@ func TestContext(t *testing.T) { echo: New(), } - //**********// - // Bind // - //**********// + //------ + // Bind + //------ + // JSON r.Header.Set(HeaderContentType, MIMEJSON) u2 := new(user) @@ -58,9 +59,10 @@ func TestContext(t *testing.T) { } // TODO: add verification - //***********// - // Param // - //***********// + //------- + // Param + //------- + // By id c.params = Params{{"id", "1"}} if c.P(0) != "1" { @@ -84,11 +86,11 @@ func TestContext(t *testing.T) { templates: template.Must(template.New("hello").Parse("{{.}}")), } c.echo.renderer = tpl - if err := c.Render("hello", "Joe"); err != nil { + if err := c.Render(http.StatusOK, "hello", "Joe"); err != nil { t.Errorf("render %v", err) } c.echo.renderer = nil - if err := c.Render("hello", "Joe"); err == nil { + if err := c.Render(http.StatusOK, "hello", "Joe"); err == nil { t.Error("render should error out") } diff --git a/echo.go b/echo.go index 7e42d920..5e8a7158 100644 --- a/echo.go +++ b/echo.go @@ -12,20 +12,25 @@ import ( type ( Echo struct { - Router *router - prefix string - middleware []MiddlewareFunc - maxParam byte - notFoundHandler HandlerFunc - binder BindFunc - renderer Renderer - pool sync.Pool + Router *router + prefix string + middleware []MiddlewareFunc + maxParam byte + notFoundHandler HandlerFunc + httpErrorHandler HTTPErrorHandler + binder BindFunc + renderer Renderer + pool sync.Pool } Middleware interface{} - MiddlewareFunc func(HandlerFunc) HandlerFunc + MiddlewareFunc func(HandlerFunc) (HandlerFunc, error) Handler interface{} - HandlerFunc func(*Context) - BindFunc func(r *http.Request, v interface{}) error + HandlerFunc func(*Context) error + + // HTTPErrorHandler is a centralized HTTP error handler. + HTTPErrorHandler func(error, *Context) + + BindFunc func(*http.Request, interface{}) error // Renderer is the interface that wraps the Render method. // @@ -81,8 +86,12 @@ var ( func New() (e *Echo) { e = &Echo{ maxParam: 5, - notFoundHandler: func(c *Context) { + notFoundHandler: func(c *Context) error { http.Error(c.Response, http.StatusText(http.StatusNotFound), http.StatusNotFound) + return nil + }, + httpErrorHandler: func(err error, c *Context) { + http.Error(c.Response, err.Error(), http.StatusInternalServerError) }, binder: func(r *http.Request, v interface{}) error { ct := r.Header.Get(HeaderContentType) @@ -133,6 +142,11 @@ func (e *Echo) NotFoundHandler(h Handler) { e.notFoundHandler = wrapH(h) } +// HTTPErrorHandler registers an HTTP error handler. +func (e *Echo) HTTPErrorHandler(h HTTPErrorHandler) { + e.httpErrorHandler = h +} + // Binder registers a custom binder. It's invoked by Context.Bind API. func (e *Echo) Binder(b BindFunc) { e.binder = b @@ -203,15 +217,17 @@ func (e *Echo) add(method, path string, h Handler) { // Static serves static files. func (e *Echo) Static(path, root string) { fs := http.StripPrefix(path, http.FileServer(http.Dir(root))) - e.Get(path+"/*", func(c *Context) { + e.Get(path+"/*", func(c *Context) error { fs.ServeHTTP(c.Response, c.Request) + return nil }) } // ServeFile serves a file. func (e *Echo) ServeFile(path, file string) { - e.Get(path, func(c *Context) { + e.Get(path, func(c *Context) error { http.ServeFile(c.Response, c.Request, file) + return nil }) } @@ -230,12 +246,21 @@ func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) { h = e.notFoundHandler } c.reset(w, r, e) + // Middleware + var err error for i := len(e.middleware) - 1; i >= 0; i-- { - h = e.middleware[i](h) + if h, err = e.middleware[i](h); err != nil { + e.httpErrorHandler(err, c) + return + } } + // Handler - h(c) + if err := h(c); err != nil { + e.httpErrorHandler(err, c) + } + e.pool.Put(c) } @@ -265,34 +290,50 @@ func (e *Echo) RunTLSServer(server *http.Server, certFile, keyFile string) { func wrapM(m Middleware) MiddlewareFunc { switch m := m.(type) { case func(*Context): - return func(h HandlerFunc) HandlerFunc { - return func(c *Context) { + return func(h HandlerFunc) (HandlerFunc, error) { + return func(c *Context) error { m(c) - h(c) - } + return h(c) + }, nil } - case func(HandlerFunc) HandlerFunc: + case func(*Context) error: + return func(h HandlerFunc) (HandlerFunc, error) { + var err error + return func(c *Context) error { + err = m(c) + return h(c) + }, err + } + case func(HandlerFunc) (HandlerFunc, error): return MiddlewareFunc(m) case func(http.Handler) http.Handler: - return func(h HandlerFunc) HandlerFunc { - return func(c *Context) { + return func(h HandlerFunc) (HandlerFunc, error) { + return func(c *Context) error { m(h).ServeHTTP(c.Response, c.Request) - h(c) - } + return h(c) + }, nil } case http.Handler, http.HandlerFunc: - return func(h HandlerFunc) HandlerFunc { - return func(c *Context) { + return func(h HandlerFunc) (HandlerFunc, error) { + return func(c *Context) error { m.(http.Handler).ServeHTTP(c.Response, c.Request) - h(c) - } + return h(c) + }, nil } case func(http.ResponseWriter, *http.Request): - return func(h HandlerFunc) HandlerFunc { - return func(c *Context) { + return func(h HandlerFunc) (HandlerFunc, error) { + return func(c *Context) error { m(c.Response, c.Request) - h(c) - } + return h(c) + }, nil + } + case func(http.ResponseWriter, *http.Request) error: + return func(h HandlerFunc) (HandlerFunc, error) { + var err error + return func(c *Context) error { + err = m(c.Response, c.Request) + return h(c) + }, err } default: panic("echo: unknown middleware") @@ -303,14 +344,25 @@ func wrapM(m Middleware) MiddlewareFunc { func wrapH(h Handler) HandlerFunc { switch h := h.(type) { case func(*Context): + return func(c *Context) error { + h(c) + return nil + } + case func(*Context) error: return HandlerFunc(h) case http.Handler, http.HandlerFunc: - return func(c *Context) { + return func(c *Context) error { h.(http.Handler).ServeHTTP(c.Response, c.Request) + return nil } case func(http.ResponseWriter, *http.Request): - return func(c *Context) { + return func(c *Context) error { h(c.Response, c.Request) + return nil + } + case func(http.ResponseWriter, *http.Request) error: + return func(c *Context) error { + return h(c.Response, c.Request) } default: panic("echo: unknown handler") diff --git a/echo_test.go b/echo_test.go index 613754f9..7f77f8c1 100644 --- a/echo_test.go +++ b/echo_test.go @@ -30,7 +30,7 @@ func TestEchoMaxParam(t *testing.T) { func TestEchoIndex(t *testing.T) { e := New() - e.Index("example/public/index.html") + e.Index("examples/public/index.html") w := httptest.NewRecorder() r, _ := http.NewRequest(GET, "/", nil) e.ServeHTTP(w, r) @@ -41,7 +41,7 @@ func TestEchoIndex(t *testing.T) { func TestEchoStatic(t *testing.T) { e := New() - e.Static("/scripts", "example/public/scripts") + e.Static("/scripts", "examples/public/scripts") w := httptest.NewRecorder() r, _ := http.NewRequest(GET, "/scripts/main.js", nil) e.ServeHTTP(w, r) @@ -59,35 +59,46 @@ func TestEchoMiddleware(t *testing.T) { b.WriteString("a") }) - // func(echo.HandlerFunc) echo.HandlerFunc - e.Use(func(h HandlerFunc) HandlerFunc { - return HandlerFunc(func(c *Context) { - b.WriteString("b") - h(c) - }) + // func(*echo.Context) error + e.Use(func(c *Context) error { + b.WriteString("b") + return nil + }) + + // func(echo.HandlerFunc) (echo.HandlerFunc, error) + e.Use(func(h HandlerFunc) (HandlerFunc, error) { + return func(c *Context) error { + b.WriteString("c") + return h(c) + }, nil }) // http.HandlerFunc e.Use(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - b.WriteString("c") + b.WriteString("d") })) // http.Handler e.Use(http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - b.WriteString("d") + b.WriteString("e") }))) // func(http.Handler) http.Handler e.Use(func(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - b.WriteString("e") + b.WriteString("f") h.ServeHTTP(w, r) }) }) // func(http.ResponseWriter, *http.Request) e.Use(func(w http.ResponseWriter, r *http.Request) { - b.WriteString("f") + b.WriteString("g") + }) + + // func(http.ResponseWriter, *http.Request) error + e.Use(func(w http.ResponseWriter, r *http.Request) { + b.WriteString("h") }) // Route @@ -98,8 +109,8 @@ func TestEchoMiddleware(t *testing.T) { w := httptest.NewRecorder() r, _ := http.NewRequest(GET, "/hello", nil) e.ServeHTTP(w, r) - if b.String() != "abcdef" { - t.Errorf("buffer should be abcdef, found %s", b.String()) + if b.String() != "abcdefgh" { + t.Errorf("buffer should be abcdefgh, found %s", b.String()) } if w.Body.String() != "world" { t.Error("body should be world") @@ -120,10 +131,10 @@ func TestEchoHandler(t *testing.T) { t.Error("body should be 1") } - // http.Handler/http.HandlerFunc - e.Get("/2", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("2")) - })) + // func(*echo.Context) error + e.Get("/2", func(c *Context) { + c.String(http.StatusOK, "2") + }) w = httptest.NewRecorder() r, _ = http.NewRequest(GET, "/2", nil) e.ServeHTTP(w, r) @@ -131,16 +142,39 @@ func TestEchoHandler(t *testing.T) { t.Error("body should be 2") } - // func(http.ResponseWriter, *http.Request) - e.Get("/3", func(w http.ResponseWriter, r *http.Request) { + // http.Handler/http.HandlerFunc + e.Get("/3", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("3")) - }) + })) w = httptest.NewRecorder() r, _ = http.NewRequest(GET, "/3", nil) e.ServeHTTP(w, r) if w.Body.String() != "3" { t.Error("body should be 3") } + + // func(http.ResponseWriter, *http.Request) + e.Get("/4", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("4")) + }) + w = httptest.NewRecorder() + r, _ = http.NewRequest(GET, "/4", nil) + e.ServeHTTP(w, r) + if w.Body.String() != "4" { + t.Error("body should be 4") + } + + // func(http.ResponseWriter, *http.Request) error + e.Get("/5", func(w http.ResponseWriter, r *http.Request) error { + w.Write([]byte("5")) + return nil + }) + w = httptest.NewRecorder() + r, _ = http.NewRequest(GET, "/5", nil) + e.ServeHTTP(w, r) + if w.Body.String() != "5" { + t.Error("body should be 5") + } } func TestEchoGroup(t *testing.T) { @@ -187,15 +221,16 @@ func TestEchoGroup(t *testing.T) { func TestEchoMethod(t *testing.T) { e := New() - e.Connect("/", func(*Context) {}) - e.Delete("/", func(*Context) {}) - e.Get("/", func(*Context) {}) - e.Head("/", func(*Context) {}) - e.Options("/", func(*Context) {}) - e.Patch("/", func(*Context) {}) - e.Post("/", func(*Context) {}) - e.Put("/", func(*Context) {}) - e.Trace("/", func(*Context) {}) + h := func(*Context) {} + e.Connect("/", h) + e.Delete("/", h) + e.Get("/", h) + e.Head("/", h) + e.Options("/", h) + e.Patch("/", h) + e.Post("/", h) + e.Put("/", h) + e.Trace("/", h) } func TestEchoNotFound(t *testing.T) { diff --git a/examples/crud/server.go b/examples/crud/server.go new file mode 100644 index 00000000..5a880189 --- /dev/null +++ b/examples/crud/server.go @@ -0,0 +1,71 @@ +package main + +import ( + "net/http" + "strconv" + + "github.com/labstack/echo" + mw "github.com/labstack/echo/middleware" +) + +type ( + user struct { + ID int + Name string + Age int + } +) + +var ( + users = map[int]user{} + seq = 1 +) + +//---------- +// Handlers +//---------- + +func createUser(c *echo.Context) error { + u := &user{ + ID: seq, + } + if err := c.Bind(u); err != nil { + return err + } + users[u.ID] = *u + seq++ + return c.JSON(http.StatusCreated, u) +} + +func getUser(c *echo.Context) error { + id, _ := strconv.Atoi(c.Param("id")) + return c.JSON(http.StatusOK, users[id]) +} + +func updateUser(c *echo.Context) error { + // id, _ := strconv.Atoi(c.Param("id")) + // users[id] + return c.NoContent(http.StatusNoContent) +} + +func deleteUser(c *echo.Context) error { + id, _ := strconv.Atoi(c.Param("id")) + delete(users, id) + return c.NoContent(http.StatusNoContent) +} + +func main() { + e := echo.New() + + // Middleware + e.Use(mw.Logger) + + // Routes + e.Post("/users", createUser) + e.Get("/users/:id", getUser) + e.Put("/users/:id", updateUser) + e.Delete("/users/:id", deleteUser) + + // Start server + e.Run(":8080") +} diff --git a/examples/hello/server.go b/examples/hello/server.go new file mode 100644 index 00000000..f3adbd21 --- /dev/null +++ b/examples/hello/server.go @@ -0,0 +1,26 @@ +package main + +import ( + "net/http" + + "github.com/labstack/echo" + mw "github.com/labstack/echo/middleware" +) + +// Handler +func hello(c *echo.Context) { + c.String(http.StatusOK, "Hello, World!\n") +} + +func main() { + e := echo.New() + + // Middleware + e.Use(mw.Logger) + + // Routes + e.Get("/", hello) + + // Start server + e.Run(":8080") +} diff --git a/example/main.go b/examples/main.go similarity index 90% rename from example/main.go rename to examples/main.go index 06dbe20b..d2094059 100644 --- a/example/main.go +++ b/examples/main.go @@ -60,14 +60,14 @@ func getUser(c *echo.Context) { func main() { e := echo.New() - //*************************// - // Built-in middleware // - //*************************// + //--------------------- + // Built-in middleware + //--------------------- e.Use(mw.Logger) - //****************************// - // Third-party middleware // - //****************************// + //------------------------ + // Third-party middleware + //------------------------ // https://github.com/rs/cors e.Use(cors.Default().Handler) @@ -85,9 +85,9 @@ func main() { // Serve static files e.Static("/scripts", "public/scripts") - //************// - // Routes // - //************// + //-------- + // Routes + //-------- e.Post("/users", createUser) e.Get("/users", getUsers) e.Get("/users/:id", getUser) diff --git a/example/public/index.html b/examples/public/index.html similarity index 100% rename from example/public/index.html rename to examples/public/index.html diff --git a/example/public/scripts/main.js b/examples/public/scripts/main.js similarity index 100% rename from example/public/scripts/main.js rename to examples/public/scripts/main.js diff --git a/example/public/views/welcome.html b/examples/public/views/welcome.html similarity index 100% rename from example/public/views/welcome.html rename to examples/public/views/welcome.html diff --git a/middleware/auth.go b/middleware/auth.go index 8513d94d..cf978092 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -1,69 +1,70 @@ package middleware -import ( - "encoding/base64" - "errors" - "strings" - - "github.com/dgrijalva/jwt-go" - "github.com/labstack/echo" -) - -type ( - BasicAuthFunc func(string, string) bool - AuthorizedHandler echo.HandlerFunc - UnauthorizedHandler func(*echo.Context, error) - JwtKeyFunc func(string) ([]byte, error) - Claims map[string]interface{} -) - -var ( - ErrBasicAuth = errors.New("echo: basic auth error") - ErrJwtAuth = errors.New("echo: jwt auth error") -) - -func BasicAuth(ah AuthorizedHandler, uah UnauthorizedHandler, fn BasicAuthFunc) echo.HandlerFunc { - return func(c *echo.Context) { - auth := strings.Fields(c.Request.Header.Get("Authorization")) - if len(auth) == 2 { - scheme := auth[0] - s, err := base64.StdEncoding.DecodeString(auth[1]) - if err != nil { - uah(c, err) - return - } - cred := strings.Split(string(s), ":") - if scheme == "Basic" && len(cred) == 2 { - if ok := fn(cred[0], cred[1]); ok { - ah(c) - return - } - } - } - uah(c, ErrBasicAuth) - } -} - -func JwtAuth(ah AuthorizedHandler, uah UnauthorizedHandler, fn JwtKeyFunc) echo.HandlerFunc { - return func(c *echo.Context) { - auth := strings.Fields(c.Request.Header.Get("Authorization")) - if len(auth) == 2 { - t, err := jwt.Parse(auth[1], func(token *jwt.Token) (interface{}, error) { - if kid := token.Header["kid"]; kid != nil { - return fn(kid.(string)) - } - return fn("") - }) - if t.Valid { - c.Set("claims", Claims(t.Claims)) - ah(c) - // c.Next() - } else { - // TODO: capture errors - uah(c, err) - } - return - } - uah(c, ErrJwtAuth) - } -} +// +// import ( +// "encoding/base64" +// "errors" +// "strings" +// +// "github.com/dgrijalva/jwt-go" +// "github.com/labstack/echo" +// ) +// +// type ( +// BasicAuthFunc func(string, string) bool +// AuthorizedHandler echo.HandlerFunc +// UnauthorizedHandler func(*echo.Context, error) +// JwtKeyFunc func(string) ([]byte, error) +// Claims map[string]interface{} +// ) +// +// var ( +// ErrBasicAuth = errors.New("echo: basic auth error") +// ErrJwtAuth = errors.New("echo: jwt auth error") +// ) +// +// func BasicAuth(ah AuthorizedHandler, uah UnauthorizedHandler, fn BasicAuthFunc) echo.HandlerFunc { +// return func(c *echo.Context) { +// auth := strings.Fields(c.Request.Header.Get("Authorization")) +// if len(auth) == 2 { +// scheme := auth[0] +// s, err := base64.StdEncoding.DecodeString(auth[1]) +// if err != nil { +// uah(c, err) +// return +// } +// cred := strings.Split(string(s), ":") +// if scheme == "Basic" && len(cred) == 2 { +// if ok := fn(cred[0], cred[1]); ok { +// ah(c) +// return +// } +// } +// } +// uah(c, ErrBasicAuth) +// } +// } +// +// func JwtAuth(ah AuthorizedHandler, uah UnauthorizedHandler, fn JwtKeyFunc) echo.HandlerFunc { +// return func(c *echo.Context) { +// auth := strings.Fields(c.Request.Header.Get("Authorization")) +// if len(auth) == 2 { +// t, err := jwt.Parse(auth[1], func(token *jwt.Token) (interface{}, error) { +// if kid := token.Header["kid"]; kid != nil { +// return fn(kid.(string)) +// } +// return fn("") +// }) +// if t.Valid { +// c.Set("claims", Claims(t.Claims)) +// ah(c) +// // c.Next() +// } else { +// // TODO: capture errors +// uah(c, err) +// } +// return +// } +// uah(c, ErrJwtAuth) +// } +// } diff --git a/middleware/logger.go b/middleware/logger.go index be6600df..3429542f 100644 --- a/middleware/logger.go +++ b/middleware/logger.go @@ -9,9 +9,9 @@ import ( "github.com/mattn/go-colorable" ) -func Logger(h echo.HandlerFunc) echo.HandlerFunc { +func Logger(h echo.HandlerFunc) (echo.HandlerFunc, error) { log.SetOutput(colorable.NewColorableStdout()) - return echo.HandlerFunc(func(c *echo.Context) { + return func(c *echo.Context) error { start := time.Now() h(c) end := time.Now() @@ -30,5 +30,6 @@ func Logger(h echo.HandlerFunc) echo.HandlerFunc { } log.Printf("%s %s %s %s", m, p, col(s), end.Sub(start)) - }) + return nil + }, nil } diff --git a/response_test.go b/response_test.go index 2d7a8831..fce5e977 100644 --- a/response_test.go +++ b/response_test.go @@ -7,21 +7,34 @@ import ( ) func TestResponse(t *testing.T) { - e := New() - e.Get("/hello", func(c *Context) { - c.String(http.StatusOK, "world") + r := &response{Writer: httptest.NewRecorder()} - // Status - if c.Response.Status() != http.StatusOK { - t.Error("status code should be 200") - } + // Header + if r.Header() == nil { + t.Error("header should not be nil") + } - // Size - if c.Response.Status() != http.StatusOK { - t.Error("size should be 5") - } - }) - w := httptest.NewRecorder() - r, _ := http.NewRequest("GET", "/hello", nil) - e.ServeHTTP(w, r) + // WriteHeader + r.WriteHeader(http.StatusOK) + if r.status != http.StatusOK { + t.Errorf("status should be %d", http.StatusOK) + } + if r.committed != true { + t.Error("response should be true") + } + // Response already committed + r.WriteHeader(http.StatusOK) + + // Status + r.status = http.StatusOK + if r.Status() != http.StatusOK { + t.Errorf("status should be %d", http.StatusOK) + } + + // Write & Size + s := "echo" + r.Write([]byte(s)) + if r.Size() != len(s) { + t.Errorf("size should be %d", len(s)) + } } diff --git a/router_test.go b/router_test.go index 5e566f5a..7cf11fee 100644 --- a/router_test.go +++ b/router_test.go @@ -283,8 +283,9 @@ func TestRouterStatic(t *testing.T) { r := New().Router b := new(bytes.Buffer) path := "/folders/a/files/echo.gif" - r.Add(GET, path, func(*Context) { + r.Add(GET, path, func(*Context) error { b.WriteString(path) + return nil }, nil) h, _ := r.Find(GET, path, params) if h == nil { @@ -298,7 +299,9 @@ func TestRouterStatic(t *testing.T) { func TestRouterParam(t *testing.T) { r := New().Router - r.Add(GET, "/users/:id", func(c *Context) {}, nil) + r.Add(GET, "/users/:id", func(c *Context) error { + return nil + }, nil) h, _ := r.Find(GET, "/users/1", params) if h == nil { t.Fatal("handler not found") @@ -310,7 +313,9 @@ func TestRouterParam(t *testing.T) { func TestRouterTwoParam(t *testing.T) { r := New().Router - r.Add(GET, "/users/:uid/files/:fid", func(*Context) {}, nil) + r.Add(GET, "/users/:uid/files/:fid", func(*Context) error { + return nil + }, nil) h, _ := r.Find(GET, "/users/1/files/1", params) if h == nil { t.Fatal("handler not found") @@ -325,7 +330,9 @@ func TestRouterTwoParam(t *testing.T) { func TestRouterCatchAll(t *testing.T) { r := New().Router - r.Add(GET, "/static/*", func(*Context) {}, nil) + r.Add(GET, "/static/*", func(*Context) error { + return nil + }, nil) h, _ := r.Find(GET, "/static/echo.gif", params) if h == nil { t.Fatal("handler not found") @@ -337,7 +344,9 @@ func TestRouterCatchAll(t *testing.T) { func TestRouterMicroParam(t *testing.T) { r := New().Router - r.Add(GET, "/:a/:b/:c", func(c *Context) {}, nil) + r.Add(GET, "/:a/:b/:c", func(c *Context) error { + return nil + }, nil) h, _ := r.Find(GET, "/1/2/3", params) if h == nil { t.Fatal("handler not found") @@ -358,10 +367,13 @@ func TestRouterMultiRoute(t *testing.T) { b := new(bytes.Buffer) // Routes - r.Add(GET, "/users", func(*Context) { + r.Add(GET, "/users", func(*Context) error { b.WriteString("/users") + return nil + }, nil) + r.Add(GET, "/users/:id", func(c *Context) error { + return nil }, nil) - r.Add(GET, "/users/:id", func(c *Context) {}, nil) // Route > /users h, _ := r.Find(GET, "/users", params) @@ -394,19 +406,26 @@ func TestRouterConflictingRoute(t *testing.T) { b := new(bytes.Buffer) // Routes - r.Add(GET, "/users", func(*Context) { + r.Add(GET, "/users", func(*Context) error { b.WriteString("/users") + return nil }, nil) - r.Add(GET, "/users/new", func(*Context) { + r.Add(GET, "/users/new", func(*Context) error { b.Reset() b.WriteString("/users/new") + return nil }, nil) - r.Add(GET, "/users/:id", func(c *Context) {}, nil) - r.Add(GET, "/users/new/moon", func(*Context) { + r.Add(GET, "/users/:id", func(c *Context) error { + return nil + }, nil) + r.Add(GET, "/users/new/moon", func(*Context) error { b.Reset() b.WriteString("/users/new/moon") + return nil + }, nil) + r.Add(GET, "/users/new/:id", func(*Context) error { + return nil }, nil) - r.Add(GET, "/users/new/:id", func(*Context) {}, nil) // Route > /users h, _ := r.Find(GET, "/users", params) @@ -495,7 +514,7 @@ func TestRouterConflictingRoute(t *testing.T) { func TestRouterAPI(t *testing.T) { r := New().Router for _, route := range api { - r.Add(route.method, route.path, func(c *Context) { + r.Add(route.method, route.path, func(c *Context) error { for _, p := range c.params { if p.Name != "" { if ":"+p.Name != p.Value { @@ -503,6 +522,7 @@ func TestRouterAPI(t *testing.T) { } } } + return nil }, nil) h, _ := r.Find(route.method, route.path, params) @@ -514,7 +534,9 @@ func TestRouterAPI(t *testing.T) { func TestRouterServeHTTP(t *testing.T) { r := New().Router - r.Add(GET, "/users", func(*Context) {}, nil) + r.Add(GET, "/users", func(*Context) error { + return nil + }, nil) // OK req, _ := http.NewRequest(GET, "/users", nil)