diff --git a/context.go b/context.go index d878a15c..256161d6 100644 --- a/context.go +++ b/context.go @@ -70,6 +70,16 @@ func (c *Context) Param(name string) (value string) { return } +// Get retrieves data from the context. +func (c *Context) Get(key string) interface{} { + return c.store[key] +} + +// Set saves data in the context. +func (c *Context) Set(key string, val interface{}) { + c.store[key] = val +} + // Bind binds the request body into specified type v. Default binder does it // based on Content-Type header. func (c *Context) Bind(i interface{}) error { @@ -82,21 +92,21 @@ func (c *Context) Render(code int, name string, data interface{}) error { if c.echo.renderer == nil { return RendererNotRegistered } - c.response.Header().Set(ContentType, TextHTML+"; charset=utf-8") + c.response.Header().Set(ContentType, TextHTML) c.response.WriteHeader(code) return c.echo.renderer.Render(c.response, name, data) } // JSON sends an application/json response with status code. func (c *Context) JSON(code int, i interface{}) error { - c.response.Header().Set(ContentType, ApplicationJSON+"; charset=utf-8") + c.response.Header().Set(ContentType, ApplicationJSON) c.response.WriteHeader(code) return json.NewEncoder(c.response).Encode(i) } // String sends a text/plain response with status code. func (c *Context) String(code int, s string) error { - c.response.Header().Set(ContentType, TextPlain+"; charset=utf-8") + c.response.Header().Set(ContentType, TextPlain) c.response.WriteHeader(code) _, err := c.response.Write([]byte(s)) return err @@ -104,7 +114,7 @@ func (c *Context) String(code int, s string) error { // HTML sends a text/html response with status code. func (c *Context) HTML(code int, html string) error { - c.response.Header().Set(ContentType, TextHTML+"; charset=utf-8") + c.response.Header().Set(ContentType, TextHTML) c.response.WriteHeader(code) _, err := c.response.Write([]byte(html)) return err @@ -126,16 +136,6 @@ func (c *Context) Error(err error) { c.echo.httpErrorHandler(err, c) } -// Get retrieves data from the context. -func (c *Context) Get(key string) interface{} { - return c.store[key] -} - -// Set saves data in the context. -func (c *Context) Set(key string, val interface{}) { - c.store[key] = val -} - func (c *Context) reset(r *http.Request, w http.ResponseWriter, e *Echo) { c.request = r c.response.reset(w) diff --git a/context_test.go b/context_test.go index 9bd0506b..5d3c08ee 100644 --- a/context_test.go +++ b/context_test.go @@ -1,8 +1,6 @@ package echo import ( - "bytes" - "encoding/json" "errors" "io" "net/http" @@ -10,6 +8,8 @@ import ( "testing" "text/template" + "strings" + "github.com/stretchr/testify/assert" ) @@ -24,47 +24,20 @@ func (t *Template) Render(w io.Writer, name string, data interface{}) error { } func TestContext(t *testing.T) { - b, _ := json.Marshal(u1) - r, _ := http.NewRequest(POST, "/users/1", bytes.NewReader(b)) - c := NewContext(r, NewResponse(httptest.NewRecorder()), New()) + usr := `{"id":"1","name":"Joe"}` + req, _ := http.NewRequest(POST, "/", strings.NewReader(usr)) + rec := httptest.NewRecorder() + c := NewContext(req, NewResponse(rec), New()) // Request - assert.NotEmpty(t, c.Request()) + assert.NotNil(t, c.Request()) // Response - assert.NotEmpty(t, c.Response()) + assert.NotNil(t, c.Response()) // Socket assert.Nil(t, c.Socket()) - //------ - // Bind - //------ - - // JSON - r.Header.Set(ContentType, ApplicationJSON) - u2 := new(user) - if he := c.Bind(u2); he != nil { - t.Errorf("bind %#v", he) - } - verifyUser(u2, t) - - // FORM - r.Header.Set(ContentType, ApplicationForm) - u2 = new(user) - if he := c.Bind(u2); he != nil { - t.Errorf("bind %#v", he) - } - // TODO: add verification - - // Unsupported - r.Header.Set(ContentType, "") - u2 = new(user) - if he := c.Bind(u2); he == nil { - t.Errorf("bind %#v", he) - } - // TODO: add verification - //------- // Param //------- @@ -72,69 +45,112 @@ func TestContext(t *testing.T) { // By id c.pnames = []string{"id"} c.pvalues = []string{"1"} - if c.P(0) != "1" { - t.Error("param id should be 1") - } + assert.Equal(t, "1", c.P(0)) // By name - if c.Param("id") != "1" { - t.Error("param id should be 1") - } + assert.Equal(t, "1", c.Param("id")) // Store - c.Set("user", u1.Name) - n := c.Get("user") - if n != u1.Name { - t.Error("user name should be Joe") - } + c.Set("user", "Joe") + assert.Equal(t, "Joe", c.Get("user")) - // Render - tpl := &Template{ - templates: template.Must(template.New("hello").Parse("{{.}}")), - } - c.echo.renderer = tpl - if he := c.Render(http.StatusOK, "hello", "Joe"); he != nil { - t.Errorf("render %#v", he.Error) - } - c.echo.renderer = nil - if he := c.Render(http.StatusOK, "hello", "Joe"); he.Error == nil { - t.Error("render should error out") - } + //------ + // Bind + //------ // JSON - r.Header.Set(Accept, ApplicationJSON) - c.response.committed = false - if he := c.JSON(http.StatusOK, u1); he != nil { - t.Errorf("json %#v", he) + testBind(t, c, ApplicationJSON) + + // TODO: Form + c.request.Header.Set(ContentType, ApplicationForm) + u := new(user) + err := c.Bind(u) + assert.NoError(t, err) + + // Unsupported + c.request.Header.Set(ContentType, "") + u = new(user) + err = c.Bind(u) + assert.Error(t, err) + + //-------- + // Render + //-------- + + tpl := &Template{ + templates: template.Must(template.New("hello").Parse("Hello, {{.}}!")), + } + c.echo.renderer = tpl + err = c.Render(http.StatusOK, "hello", "Joe") + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "Hello, Joe!", rec.Body.String()) + } + + c.echo.renderer = nil + err = c.Render(http.StatusOK, "hello", "Joe") + assert.Error(t, err) + + // JSON + req.Header.Set(Accept, ApplicationJSON) + rec = httptest.NewRecorder() + c = NewContext(req, NewResponse(rec), New()) + err = c.JSON(http.StatusOK, user{"1", "Joe"}) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, ApplicationJSON, rec.Header().Get(ContentType)) + assert.Equal(t, usr, strings.TrimSpace(rec.Body.String())) } // String - r.Header.Set(Accept, TextPlain) - c.response.committed = false - if he := c.String(http.StatusOK, "Hello, World!"); he != nil { - t.Errorf("string %#v", he.Error) + req.Header.Set(Accept, TextPlain) + rec = httptest.NewRecorder() + c = NewContext(req, NewResponse(rec), New()) + err = c.String(http.StatusOK, "Hello, World!") + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, TextPlain, rec.Header().Get(ContentType)) + assert.Equal(t, "Hello, World!", rec.Body.String()) } // HTML - r.Header.Set(Accept, TextHTML) - c.response.committed = false - if he := c.HTML(http.StatusOK, "Hello, World!"); he != nil { - t.Errorf("html %v", he.Error) + req.Header.Set(Accept, TextHTML) + rec = httptest.NewRecorder() + c = NewContext(req, NewResponse(rec), New()) + err = c.HTML(http.StatusOK, "Hello, World!") + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, TextHTML, rec.Header().Get(ContentType)) + assert.Equal(t, "Hello, World!", rec.Body.String()) } // NoContent + rec = httptest.NewRecorder() + c = NewContext(req, NewResponse(rec), New()) c.NoContent(http.StatusOK) assert.Equal(t, http.StatusOK, c.response.status) // Redirect - c.response.committed = false + rec = httptest.NewRecorder() + c = NewContext(req, NewResponse(rec), New()) c.Redirect(http.StatusMovedPermanently, "http://labstack.github.io/echo") // Error - c.response.committed = false + rec = httptest.NewRecorder() + c = NewContext(req, NewResponse(rec), New()) c.Error(errors.New("error")) assert.Equal(t, http.StatusInternalServerError, c.response.status) // reset - c.reset(r, NewResponse(httptest.NewRecorder()), New()) + c.reset(req, NewResponse(httptest.NewRecorder()), New()) +} + +func testBind(t *testing.T, c *Context, ct string) { + c.request.Header.Set(ContentType, ct) + u := new(user) + err := c.Bind(u) + if assert.NoError(t, err) { + assert.Equal(t, "1", u.ID) + assert.Equal(t, "Joe", u.Name) + } } diff --git a/echo.go b/echo.go index e9cd3646..82b500e0 100644 --- a/echo.go +++ b/echo.go @@ -489,8 +489,6 @@ func wrapMiddleware(m Middleware) MiddlewareFunc { } case http.Handler: return wrapHTTPHandlerFuncMW(m.ServeHTTP) - case http.HandlerFunc: - return wrapHTTPHandlerFuncMW(m) case func(http.ResponseWriter, *http.Request): return wrapHTTPHandlerFuncMW(m) default: diff --git a/echo_test.go b/echo_test.go index a7e0f8b0..79230485 100644 --- a/echo_test.go +++ b/echo_test.go @@ -10,6 +10,7 @@ import ( "reflect" "strings" + "github.com/stretchr/testify/assert" "golang.org/x/net/websocket" ) @@ -20,11 +21,6 @@ type ( } ) -var u1 = user{ - ID: "1", - Name: "Joe", -} - // TODO: Improve me! func TestEchoMaxParam(t *testing.T) { e := New() @@ -71,7 +67,7 @@ func TestEchoMiddleware(t *testing.T) { e := New() b := new(bytes.Buffer) - // MiddlewareFunc + // echo.MiddlewareFunc e.Use(MiddlewareFunc(func(h HandlerFunc) HandlerFunc { return func(c *Context) error { b.WriteString("a") @@ -87,33 +83,44 @@ func TestEchoMiddleware(t *testing.T) { } }) + // echo.HandlerFunc + e.Use(HandlerFunc(func(c *Context) error { + b.WriteString("c") + return nil + })) + // func(*echo.Context) error e.Use(func(c *Context) error { - b.WriteString("c") + b.WriteString("d") return nil }) // func(http.Handler) http.Handler e.Use(func(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - b.WriteString("d") + b.WriteString("e") h.ServeHTTP(w, r) }) }) // http.Handler e.Use(http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - b.WriteString("e") + b.WriteString("f") }))) // http.HandlerFunc e.Use(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - b.WriteString("f") + b.WriteString("g") })) // func(http.ResponseWriter, *http.Request) e.Use(func(w http.ResponseWriter, r *http.Request) { - b.WriteString("g") + b.WriteString("h") + }) + + // Unknown + assert.Panics(t, func() { + e.Use(nil) }) // Route @@ -124,8 +131,8 @@ func TestEchoMiddleware(t *testing.T) { w := httptest.NewRecorder() r, _ := http.NewRequest(GET, "/hello", nil) e.ServeHTTP(w, r) - if b.String() != "abcdefg" { - t.Errorf("buffer should be abcdefghi, found %s", b.String()) + if b.String() != "abcdefgh" { + t.Errorf("buffer should be abcdefgh, found %s", b.String()) } if w.Body.String() != "world" { t.Error("body should be world") @@ -178,6 +185,11 @@ func TestEchoHandler(t *testing.T) { if w.Body.String() != "4" { t.Error("body should be 4") } + + // Unknown + assert.Panics(t, func() { + e.Get("/5", nil) + }) } func TestEchoGroup(t *testing.T) { @@ -240,50 +252,50 @@ func TestEchoGroup(t *testing.T) { func TestEchoConnect(t *testing.T) { e := New() - testMethod(CONNECT, "/", e, nil, t) + testMethod(t, e, nil, CONNECT, "/") } func TestEchoDelete(t *testing.T) { e := New() - testMethod(DELETE, "/", e, nil, t) + testMethod(t, e, nil, DELETE, "/") } func TestEchoGet(t *testing.T) { e := New() - testMethod(GET, "/", e, nil, t) + testMethod(t, e, nil, GET, "/") } func TestEchoHead(t *testing.T) { e := New() - testMethod(HEAD, "/", e, nil, t) + testMethod(t, e, nil, HEAD, "/") } func TestEchoOptions(t *testing.T) { e := New() - testMethod(OPTIONS, "/", e, nil, t) + testMethod(t, e, nil, OPTIONS, "/") } func TestEchoPatch(t *testing.T) { e := New() - testMethod(PATCH, "/", e, nil, t) + testMethod(t, e, nil, PATCH, "/") } func TestEchoPost(t *testing.T) { e := New() - testMethod(POST, "/", e, nil, t) + testMethod(t, e, nil, POST, "/") } func TestEchoPut(t *testing.T) { e := New() - testMethod(PUT, "/", e, nil, t) + testMethod(t, e, nil, PUT, "/") } func TestEchoTrace(t *testing.T) { e := New() - testMethod(TRACE, "/", e, nil, t) + testMethod(t, e, nil, TRACE, "/") } -func testMethod(method, path string, e *Echo, g *Group, t *testing.T) { +func testMethod(t *testing.T, e *Echo, g *Group, method, path string) { m := fmt.Sprintf("%c%s", method[0], strings.ToLower(method[1:])) p := reflect.ValueOf(path) h := reflect.ValueOf(func(c *Context) error { @@ -375,11 +387,11 @@ func TestEchoNotFound(t *testing.T) { } } -func verifyUser(u2 *user, t *testing.T) { - if u2.ID != u1.ID { - t.Errorf("user id should be %s, found %s", u1.ID, u2.ID) - } - if u2.Name != u1.Name { - t.Errorf("user name should be %s, found %s", u1.Name, u2.Name) - } -} +//func verifyUser(u2 *user, t *testing.T) { +// if u2.ID != u1.ID { +// t.Errorf("user id should be %s, found %s", u1.ID, u2.ID) +// } +// if u2.Name != u1.Name { +// t.Errorf("user name should be %s, found %s", u1.Name, u2.Name) +// } +//} diff --git a/group_test.go b/group_test.go index f78d63c1..3a948679 100644 --- a/group_test.go +++ b/group_test.go @@ -4,45 +4,45 @@ import "testing" func TestGroupConnect(t *testing.T) { g := New().Group("/group") - testMethod(CONNECT, "/", &g.echo, g, t) + testMethod(t, &g.echo, g, CONNECT, "/") } func TestGroupDelete(t *testing.T) { g := New().Group("/group") - testMethod(DELETE, "/", &g.echo, g, t) + testMethod(t, &g.echo, g, DELETE, "/") } func TestGroupGet(t *testing.T) { g := New().Group("/group") - testMethod(GET, "/", &g.echo, g, t) + testMethod(t, &g.echo, g, GET, "/") } func TestGroupHead(t *testing.T) { g := New().Group("/group") - testMethod(HEAD, "/", &g.echo, g, t) + testMethod(t, &g.echo, g, HEAD, "/") } func TestGroupOptions(t *testing.T) { g := New().Group("/group") - testMethod(OPTIONS, "/", &g.echo, g, t) + testMethod(t, &g.echo, g, OPTIONS, "/") } func TestGroupPatch(t *testing.T) { g := New().Group("/group") - testMethod(PATCH, "/", &g.echo, g, t) + testMethod(t, &g.echo, g, PATCH, "/") } func TestGroupPost(t *testing.T) { g := New().Group("/group") - testMethod(POST, "/", &g.echo, g, t) + testMethod(t, &g.echo, g, POST, "/") } func TestGroupPut(t *testing.T) { g := New().Group("/group") - testMethod(PUT, "/", &g.echo, g, t) + testMethod(t, &g.echo, g, PUT, "/") } func TestGroupTrace(t *testing.T) { g := New().Group("/group") - testMethod(TRACE, "/", &g.echo, g, t) + testMethod(t, &g.echo, g, TRACE, "/") } diff --git a/middleware/auth_test.go b/middleware/auth_test.go index d4d77165..377f6855 100644 --- a/middleware/auth_test.go +++ b/middleware/auth_test.go @@ -6,12 +6,14 @@ import ( "testing" "github.com/labstack/echo" + "net/http/httptest" + "github.com/stretchr/testify/assert" ) func TestBasicAuth(t *testing.T) { - req, _ := http.NewRequest(echo.POST, "/", nil) - res := &echo.Response{} - c := echo.NewContext(req, res, echo.New()) + req, _ := http.NewRequest(echo.GET, "/", nil) + rec := httptest.NewRecorder() + c := echo.NewContext(req, echo.NewResponse(rec), echo.New()) fn := func(u, p string) bool { if u == "joe" && p == "secret" { return true @@ -26,16 +28,12 @@ func TestBasicAuth(t *testing.T) { auth := Basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) req.Header.Set(echo.Authorization, auth) - if ba(c) != nil { - t.Error("expected `pass`") - } + assert.NoError(t, ba(c)) // Case insensitive auth = "basic " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) req.Header.Set(echo.Authorization, auth) - if ba(c) != nil { - t.Error("expected `pass`, with case insensitive header.") - } + assert.NoError(t, ba(c)) //--------------------- // Invalid credentials @@ -46,41 +44,30 @@ func TestBasicAuth(t *testing.T) { req.Header.Set(echo.Authorization, auth) ba = BasicAuth(fn) he := ba(c).(*echo.HTTPError) - if ba(c) == nil { - t.Error("expected `fail`, with incorrect password.") - } else if he.Code() != http.StatusUnauthorized { - t.Errorf("expected status `401`, got %d", he.Code()) - } + assert.Equal(t, http.StatusUnauthorized, he.Code()) // Empty Authorization header req.Header.Set(echo.Authorization, "") ba = BasicAuth(fn) he = ba(c).(*echo.HTTPError) - if he == nil { - t.Error("expected `fail`, with empty Authorization header.") - } else if he.Code() != http.StatusBadRequest { - t.Errorf("expected status `400`, got %d", he.Code()) - } + assert.Equal(t, http.StatusBadRequest, he.Code()) // Invalid Authorization header auth = base64.StdEncoding.EncodeToString([]byte(" :secret")) req.Header.Set(echo.Authorization, auth) ba = BasicAuth(fn) he = ba(c).(*echo.HTTPError) - if he == nil { - t.Error("expected `fail`, with invalid Authorization header.") - } else if he.Code() != http.StatusBadRequest { - t.Errorf("expected status `400`, got %d", he.Code()) - } + assert.Equal(t, http.StatusBadRequest, he.Code()) // Invalid scheme - auth = "Ace " + base64.StdEncoding.EncodeToString([]byte(" :secret")) + auth = "Base " + base64.StdEncoding.EncodeToString([]byte(" :secret")) req.Header.Set(echo.Authorization, auth) ba = BasicAuth(fn) he = ba(c).(*echo.HTTPError) - if he == nil { - t.Error("expected `fail`, with invalid scheme.") - } else if he.Code() != http.StatusBadRequest { - t.Errorf("expected status `400`, got %d", he.Code()) - } + assert.Equal(t, http.StatusBadRequest, he.Code()) + + // WebSocket + c.Request().Header.Set(echo.Upgrade, echo.WebSocket) + ba = BasicAuth(fn) + assert.NoError(t, ba(c)) } diff --git a/middleware/compress_test.go b/middleware/compress_test.go index 90417ac7..85d7fe94 100644 --- a/middleware/compress_test.go +++ b/middleware/compress_test.go @@ -9,43 +9,34 @@ import ( "bytes" "github.com/labstack/echo" + "github.com/stretchr/testify/assert" ) func TestGzip(t *testing.T) { // Empty Accept-Encoding header req, _ := http.NewRequest(echo.GET, "/", nil) - w := httptest.NewRecorder() - res := echo.NewResponse(w) - c := echo.NewContext(req, res, echo.New()) + rec := httptest.NewRecorder() + c := echo.NewContext(req, echo.NewResponse(rec), echo.New()) h := func(c *echo.Context) error { return c.String(http.StatusOK, "test") } Gzip()(h)(c) - s := w.Body.String() - if s != "test" { - t.Errorf("expected `test`, with empty Accept-Encoding header, got %s.", s) - } + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "test", rec.Body.String()) - // Content-Encoding header + // With Accept-Encoding header + req, _ = http.NewRequest(echo.GET, "/", nil) req.Header.Set(echo.AcceptEncoding, "gzip") - w = httptest.NewRecorder() - c.Response().SetWriter(w) + rec = httptest.NewRecorder() + c = echo.NewContext(req, echo.NewResponse(rec), echo.New()) Gzip()(h)(c) - ce := w.Header().Get(echo.ContentEncoding) - if ce != "gzip" { - t.Errorf("expected Content-Encoding header `gzip`, got %d.", ce) - } - - // Body - r, err := gzip.NewReader(w.Body) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "gzip", rec.Header().Get(echo.ContentEncoding)) + r, err := gzip.NewReader(rec.Body) defer r.Close() - if err != nil { - t.Fatal(err) - } - buf := new(bytes.Buffer) - buf.ReadFrom(r) - s = buf.String() - if s != "test" { - t.Errorf("expected body `test`, got %s.", s) + if assert.NoError(t, err) { + buf := new(bytes.Buffer) + buf.ReadFrom(r) + assert.Equal(t, "test", buf.String()) } } diff --git a/middleware/logger.go b/middleware/logger.go index 92bc8ed8..19d1301d 100644 --- a/middleware/logger.go +++ b/middleware/logger.go @@ -15,7 +15,7 @@ func Logger() echo.MiddlewareFunc { if err := h(c); err != nil { c.Error(err) } - end := time.Now() + stop := time.Now() method := c.Request().Method path := c.Request().URL.Path if path == "" { @@ -34,7 +34,7 @@ func Logger() echo.MiddlewareFunc { code = color.Cyan(n) } - log.Printf("%s %s %s %s %d", method, path, code, end.Sub(start), size) + log.Printf("%s %s %s %s %d", method, path, code, stop.Sub(start), size) return nil } } diff --git a/middleware/logger_test.go b/middleware/logger_test.go index 0c341e67..ee4809f1 100644 --- a/middleware/logger_test.go +++ b/middleware/logger_test.go @@ -5,14 +5,14 @@ import ( "net/http" "net/http/httptest" "testing" + "errors" ) func TestLogger(t *testing.T) { e := echo.New() req, _ := http.NewRequest(echo.GET, "/", nil) - w := httptest.NewRecorder() - res := echo.NewResponse(w) - c := echo.NewContext(req, res, e) + rec := httptest.NewRecorder() + c := echo.NewContext(req, echo.NewResponse(rec), e) // Status 2xx h := func(c *echo.Context) error { @@ -20,17 +20,28 @@ func TestLogger(t *testing.T) { } Logger()(h)(c) + // Status 3xx + rec = httptest.NewRecorder() + c = echo.NewContext(req, echo.NewResponse(rec), e) + h = func(c *echo.Context) error { + return c.String(http.StatusTemporaryRedirect, "test") + } + Logger()(h)(c) + // Status 4xx - c = echo.NewContext(req, echo.NewResponse(w), e) + rec = httptest.NewRecorder() + c = echo.NewContext(req, echo.NewResponse(rec), e) h = func(c *echo.Context) error { return c.String(http.StatusNotFound, "test") } Logger()(h)(c) - // Status 5xx - c = echo.NewContext(req, echo.NewResponse(w), e) + // Status 5xx with empty path + req, _ = http.NewRequest(echo.GET, "", nil) + rec = httptest.NewRecorder() + c = echo.NewContext(req, echo.NewResponse(rec), e) h = func(c *echo.Context) error { - return c.String(http.StatusInternalServerError, "test") + return errors.New("error") } Logger()(h)(c) } diff --git a/middleware/recover_test.go b/middleware/recover_test.go index d60ac325..003c7519 100644 --- a/middleware/recover_test.go +++ b/middleware/recover_test.go @@ -1,33 +1,24 @@ package middleware import ( - "github.com/labstack/echo" "net/http" "net/http/httptest" - "strings" "testing" + + "github.com/labstack/echo" + "github.com/stretchr/testify/assert" ) func TestRecover(t *testing.T) { e := echo.New() e.SetDebug(true) req, _ := http.NewRequest(echo.GET, "/", nil) - w := httptest.NewRecorder() - res := echo.NewResponse(w) - c := echo.NewContext(req, res, e) + rec := httptest.NewRecorder() + c := echo.NewContext(req, echo.NewResponse(rec), e) h := func(c *echo.Context) error { panic("test") } - - // Status Recover()(h)(c) - if w.Code != http.StatusInternalServerError { - t.Errorf("expected status `500`, got %d.", w.Code) - } - - // Body - s := w.Body.String() - if !strings.Contains(s, "panic recover") { - t.Error("expected body contains `panice recover`.") - } + assert.Equal(t, http.StatusInternalServerError, rec.Code) + assert.Contains(t, rec.Body.String(), "panic recover") } diff --git a/middleware/slash_test.go b/middleware/slash_test.go index 67fe72bf..59eaf238 100644 --- a/middleware/slash_test.go +++ b/middleware/slash_test.go @@ -6,33 +6,22 @@ import ( "testing" "github.com/labstack/echo" + "github.com/stretchr/testify/assert" ) func TestStripTrailingSlash(t *testing.T) { req, _ := http.NewRequest(echo.GET, "/users/", nil) - res := echo.NewResponse(httptest.NewRecorder()) - c := echo.NewContext(req, res, echo.New()) + rec := httptest.NewRecorder() + c := echo.NewContext(req, echo.NewResponse(rec), echo.New()) StripTrailingSlash()(c) - p := c.Request().URL.Path - if p != "/users" { - t.Errorf("expected path `/users` got, %s.", p) - } + assert.Equal(t, "/users", c.Request().URL.Path) } func TestRedirectToSlash(t *testing.T) { req, _ := http.NewRequest(echo.GET, "/users", nil) - res := echo.NewResponse(httptest.NewRecorder()) - c := echo.NewContext(req, res, echo.New()) + rec := httptest.NewRecorder() + c := echo.NewContext(req, echo.NewResponse(rec), echo.New()) RedirectToSlash(RedirectToSlashOptions{Code: http.StatusTemporaryRedirect})(c) - - // Status code - if res.Status() != http.StatusTemporaryRedirect { - t.Errorf("expected status `307`, got %d.", res.Status()) - } - - // Location header - l := c.Response().Header().Get("Location") - if l != "/users/" { - t.Errorf("expected Location header `/users/`, got %s.", l) - } + assert.Equal(t, http.StatusTemporaryRedirect, rec.Code) + assert.Equal(t, "/users/", c.Response().Header().Get("Location")) } diff --git a/response_test.go b/response_test.go index fbc21bef..7078eeb9 100644 --- a/response_test.go +++ b/response_test.go @@ -24,11 +24,12 @@ func TestResponse(t *testing.T) { r.WriteHeader(http.StatusOK) assert.Equal(t, http.StatusOK, r.status) - // committed + // Committed assert.True(t, r.committed) - // Response already committed - r.WriteHeader(http.StatusOK) + // Already committed + r.WriteHeader(http.StatusTeapot) + assert.NotEqual(t, http.StatusTeapot, r.Status()) // Status r.status = http.StatusOK @@ -43,7 +44,7 @@ func TestResponse(t *testing.T) { r.Flush() // Size - assert.Equal(t, int64(len(s)), r.Size()) + assert.Len(t, s, int(r.Size())) // Hijack assert.Panics(t, func() { diff --git a/router.go b/router.go index b44c61cc..af9b4aea 100644 --- a/router.go +++ b/router.go @@ -150,9 +150,6 @@ func newNode(t ntype, pre string, p *node, c children, h HandlerFunc, pnames []s } } -func (n *node) addChild(c *node) { -} - func (n *node) findChild(l byte) *node { for _, c := range n.children { if c.label == l {