diff --git a/context.go b/context.go
index d878a15c..256161d6 100644
--- a/context.go
+++ b/context.go
@@ -70,6 +70,16 @@ func (c *Context) Param(name string) (value string) {
return
}
+// Get retrieves data from the context.
+func (c *Context) Get(key string) interface{} {
+ return c.store[key]
+}
+
+// Set saves data in the context.
+func (c *Context) Set(key string, val interface{}) {
+ c.store[key] = val
+}
+
// Bind binds the request body into specified type v. Default binder does it
// based on Content-Type header.
func (c *Context) Bind(i interface{}) error {
@@ -82,21 +92,21 @@ func (c *Context) Render(code int, name string, data interface{}) error {
if c.echo.renderer == nil {
return RendererNotRegistered
}
- c.response.Header().Set(ContentType, TextHTML+"; charset=utf-8")
+ c.response.Header().Set(ContentType, TextHTML)
c.response.WriteHeader(code)
return c.echo.renderer.Render(c.response, name, data)
}
// JSON sends an application/json response with status code.
func (c *Context) JSON(code int, i interface{}) error {
- c.response.Header().Set(ContentType, ApplicationJSON+"; charset=utf-8")
+ c.response.Header().Set(ContentType, ApplicationJSON)
c.response.WriteHeader(code)
return json.NewEncoder(c.response).Encode(i)
}
// String sends a text/plain response with status code.
func (c *Context) String(code int, s string) error {
- c.response.Header().Set(ContentType, TextPlain+"; charset=utf-8")
+ c.response.Header().Set(ContentType, TextPlain)
c.response.WriteHeader(code)
_, err := c.response.Write([]byte(s))
return err
@@ -104,7 +114,7 @@ func (c *Context) String(code int, s string) error {
// HTML sends a text/html response with status code.
func (c *Context) HTML(code int, html string) error {
- c.response.Header().Set(ContentType, TextHTML+"; charset=utf-8")
+ c.response.Header().Set(ContentType, TextHTML)
c.response.WriteHeader(code)
_, err := c.response.Write([]byte(html))
return err
@@ -126,16 +136,6 @@ func (c *Context) Error(err error) {
c.echo.httpErrorHandler(err, c)
}
-// Get retrieves data from the context.
-func (c *Context) Get(key string) interface{} {
- return c.store[key]
-}
-
-// Set saves data in the context.
-func (c *Context) Set(key string, val interface{}) {
- c.store[key] = val
-}
-
func (c *Context) reset(r *http.Request, w http.ResponseWriter, e *Echo) {
c.request = r
c.response.reset(w)
diff --git a/context_test.go b/context_test.go
index 9bd0506b..5d3c08ee 100644
--- a/context_test.go
+++ b/context_test.go
@@ -1,8 +1,6 @@
package echo
import (
- "bytes"
- "encoding/json"
"errors"
"io"
"net/http"
@@ -10,6 +8,8 @@ import (
"testing"
"text/template"
+ "strings"
+
"github.com/stretchr/testify/assert"
)
@@ -24,47 +24,20 @@ func (t *Template) Render(w io.Writer, name string, data interface{}) error {
}
func TestContext(t *testing.T) {
- b, _ := json.Marshal(u1)
- r, _ := http.NewRequest(POST, "/users/1", bytes.NewReader(b))
- c := NewContext(r, NewResponse(httptest.NewRecorder()), New())
+ usr := `{"id":"1","name":"Joe"}`
+ req, _ := http.NewRequest(POST, "/", strings.NewReader(usr))
+ rec := httptest.NewRecorder()
+ c := NewContext(req, NewResponse(rec), New())
// Request
- assert.NotEmpty(t, c.Request())
+ assert.NotNil(t, c.Request())
// Response
- assert.NotEmpty(t, c.Response())
+ assert.NotNil(t, c.Response())
// Socket
assert.Nil(t, c.Socket())
- //------
- // Bind
- //------
-
- // JSON
- r.Header.Set(ContentType, ApplicationJSON)
- u2 := new(user)
- if he := c.Bind(u2); he != nil {
- t.Errorf("bind %#v", he)
- }
- verifyUser(u2, t)
-
- // FORM
- r.Header.Set(ContentType, ApplicationForm)
- u2 = new(user)
- if he := c.Bind(u2); he != nil {
- t.Errorf("bind %#v", he)
- }
- // TODO: add verification
-
- // Unsupported
- r.Header.Set(ContentType, "")
- u2 = new(user)
- if he := c.Bind(u2); he == nil {
- t.Errorf("bind %#v", he)
- }
- // TODO: add verification
-
//-------
// Param
//-------
@@ -72,69 +45,112 @@ func TestContext(t *testing.T) {
// By id
c.pnames = []string{"id"}
c.pvalues = []string{"1"}
- if c.P(0) != "1" {
- t.Error("param id should be 1")
- }
+ assert.Equal(t, "1", c.P(0))
// By name
- if c.Param("id") != "1" {
- t.Error("param id should be 1")
- }
+ assert.Equal(t, "1", c.Param("id"))
// Store
- c.Set("user", u1.Name)
- n := c.Get("user")
- if n != u1.Name {
- t.Error("user name should be Joe")
- }
+ c.Set("user", "Joe")
+ assert.Equal(t, "Joe", c.Get("user"))
- // Render
- tpl := &Template{
- templates: template.Must(template.New("hello").Parse("{{.}}")),
- }
- c.echo.renderer = tpl
- if he := c.Render(http.StatusOK, "hello", "Joe"); he != nil {
- t.Errorf("render %#v", he.Error)
- }
- c.echo.renderer = nil
- if he := c.Render(http.StatusOK, "hello", "Joe"); he.Error == nil {
- t.Error("render should error out")
- }
+ //------
+ // Bind
+ //------
// JSON
- r.Header.Set(Accept, ApplicationJSON)
- c.response.committed = false
- if he := c.JSON(http.StatusOK, u1); he != nil {
- t.Errorf("json %#v", he)
+ testBind(t, c, ApplicationJSON)
+
+ // TODO: Form
+ c.request.Header.Set(ContentType, ApplicationForm)
+ u := new(user)
+ err := c.Bind(u)
+ assert.NoError(t, err)
+
+ // Unsupported
+ c.request.Header.Set(ContentType, "")
+ u = new(user)
+ err = c.Bind(u)
+ assert.Error(t, err)
+
+ //--------
+ // Render
+ //--------
+
+ tpl := &Template{
+ templates: template.Must(template.New("hello").Parse("Hello, {{.}}!")),
+ }
+ c.echo.renderer = tpl
+ err = c.Render(http.StatusOK, "hello", "Joe")
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, "Hello, Joe!", rec.Body.String())
+ }
+
+ c.echo.renderer = nil
+ err = c.Render(http.StatusOK, "hello", "Joe")
+ assert.Error(t, err)
+
+ // JSON
+ req.Header.Set(Accept, ApplicationJSON)
+ rec = httptest.NewRecorder()
+ c = NewContext(req, NewResponse(rec), New())
+ err = c.JSON(http.StatusOK, user{"1", "Joe"})
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, ApplicationJSON, rec.Header().Get(ContentType))
+ assert.Equal(t, usr, strings.TrimSpace(rec.Body.String()))
}
// String
- r.Header.Set(Accept, TextPlain)
- c.response.committed = false
- if he := c.String(http.StatusOK, "Hello, World!"); he != nil {
- t.Errorf("string %#v", he.Error)
+ req.Header.Set(Accept, TextPlain)
+ rec = httptest.NewRecorder()
+ c = NewContext(req, NewResponse(rec), New())
+ err = c.String(http.StatusOK, "Hello, World!")
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, TextPlain, rec.Header().Get(ContentType))
+ assert.Equal(t, "Hello, World!", rec.Body.String())
}
// HTML
- r.Header.Set(Accept, TextHTML)
- c.response.committed = false
- if he := c.HTML(http.StatusOK, "Hello, World!"); he != nil {
- t.Errorf("html %v", he.Error)
+ req.Header.Set(Accept, TextHTML)
+ rec = httptest.NewRecorder()
+ c = NewContext(req, NewResponse(rec), New())
+ err = c.HTML(http.StatusOK, "Hello, World!")
+ if assert.NoError(t, err) {
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, TextHTML, rec.Header().Get(ContentType))
+ assert.Equal(t, "Hello, World!", rec.Body.String())
}
// NoContent
+ rec = httptest.NewRecorder()
+ c = NewContext(req, NewResponse(rec), New())
c.NoContent(http.StatusOK)
assert.Equal(t, http.StatusOK, c.response.status)
// Redirect
- c.response.committed = false
+ rec = httptest.NewRecorder()
+ c = NewContext(req, NewResponse(rec), New())
c.Redirect(http.StatusMovedPermanently, "http://labstack.github.io/echo")
// Error
- c.response.committed = false
+ rec = httptest.NewRecorder()
+ c = NewContext(req, NewResponse(rec), New())
c.Error(errors.New("error"))
assert.Equal(t, http.StatusInternalServerError, c.response.status)
// reset
- c.reset(r, NewResponse(httptest.NewRecorder()), New())
+ c.reset(req, NewResponse(httptest.NewRecorder()), New())
+}
+
+func testBind(t *testing.T, c *Context, ct string) {
+ c.request.Header.Set(ContentType, ct)
+ u := new(user)
+ err := c.Bind(u)
+ if assert.NoError(t, err) {
+ assert.Equal(t, "1", u.ID)
+ assert.Equal(t, "Joe", u.Name)
+ }
}
diff --git a/echo.go b/echo.go
index e9cd3646..82b500e0 100644
--- a/echo.go
+++ b/echo.go
@@ -489,8 +489,6 @@ func wrapMiddleware(m Middleware) MiddlewareFunc {
}
case http.Handler:
return wrapHTTPHandlerFuncMW(m.ServeHTTP)
- case http.HandlerFunc:
- return wrapHTTPHandlerFuncMW(m)
case func(http.ResponseWriter, *http.Request):
return wrapHTTPHandlerFuncMW(m)
default:
diff --git a/echo_test.go b/echo_test.go
index a7e0f8b0..79230485 100644
--- a/echo_test.go
+++ b/echo_test.go
@@ -10,6 +10,7 @@ import (
"reflect"
"strings"
+ "github.com/stretchr/testify/assert"
"golang.org/x/net/websocket"
)
@@ -20,11 +21,6 @@ type (
}
)
-var u1 = user{
- ID: "1",
- Name: "Joe",
-}
-
// TODO: Improve me!
func TestEchoMaxParam(t *testing.T) {
e := New()
@@ -71,7 +67,7 @@ func TestEchoMiddleware(t *testing.T) {
e := New()
b := new(bytes.Buffer)
- // MiddlewareFunc
+ // echo.MiddlewareFunc
e.Use(MiddlewareFunc(func(h HandlerFunc) HandlerFunc {
return func(c *Context) error {
b.WriteString("a")
@@ -87,33 +83,44 @@ func TestEchoMiddleware(t *testing.T) {
}
})
+ // echo.HandlerFunc
+ e.Use(HandlerFunc(func(c *Context) error {
+ b.WriteString("c")
+ return nil
+ }))
+
// func(*echo.Context) error
e.Use(func(c *Context) error {
- b.WriteString("c")
+ b.WriteString("d")
return nil
})
// func(http.Handler) http.Handler
e.Use(func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- b.WriteString("d")
+ b.WriteString("e")
h.ServeHTTP(w, r)
})
})
// http.Handler
e.Use(http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- b.WriteString("e")
+ b.WriteString("f")
})))
// http.HandlerFunc
e.Use(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- b.WriteString("f")
+ b.WriteString("g")
}))
// func(http.ResponseWriter, *http.Request)
e.Use(func(w http.ResponseWriter, r *http.Request) {
- b.WriteString("g")
+ b.WriteString("h")
+ })
+
+ // Unknown
+ assert.Panics(t, func() {
+ e.Use(nil)
})
// Route
@@ -124,8 +131,8 @@ func TestEchoMiddleware(t *testing.T) {
w := httptest.NewRecorder()
r, _ := http.NewRequest(GET, "/hello", nil)
e.ServeHTTP(w, r)
- if b.String() != "abcdefg" {
- t.Errorf("buffer should be abcdefghi, found %s", b.String())
+ if b.String() != "abcdefgh" {
+ t.Errorf("buffer should be abcdefgh, found %s", b.String())
}
if w.Body.String() != "world" {
t.Error("body should be world")
@@ -178,6 +185,11 @@ func TestEchoHandler(t *testing.T) {
if w.Body.String() != "4" {
t.Error("body should be 4")
}
+
+ // Unknown
+ assert.Panics(t, func() {
+ e.Get("/5", nil)
+ })
}
func TestEchoGroup(t *testing.T) {
@@ -240,50 +252,50 @@ func TestEchoGroup(t *testing.T) {
func TestEchoConnect(t *testing.T) {
e := New()
- testMethod(CONNECT, "/", e, nil, t)
+ testMethod(t, e, nil, CONNECT, "/")
}
func TestEchoDelete(t *testing.T) {
e := New()
- testMethod(DELETE, "/", e, nil, t)
+ testMethod(t, e, nil, DELETE, "/")
}
func TestEchoGet(t *testing.T) {
e := New()
- testMethod(GET, "/", e, nil, t)
+ testMethod(t, e, nil, GET, "/")
}
func TestEchoHead(t *testing.T) {
e := New()
- testMethod(HEAD, "/", e, nil, t)
+ testMethod(t, e, nil, HEAD, "/")
}
func TestEchoOptions(t *testing.T) {
e := New()
- testMethod(OPTIONS, "/", e, nil, t)
+ testMethod(t, e, nil, OPTIONS, "/")
}
func TestEchoPatch(t *testing.T) {
e := New()
- testMethod(PATCH, "/", e, nil, t)
+ testMethod(t, e, nil, PATCH, "/")
}
func TestEchoPost(t *testing.T) {
e := New()
- testMethod(POST, "/", e, nil, t)
+ testMethod(t, e, nil, POST, "/")
}
func TestEchoPut(t *testing.T) {
e := New()
- testMethod(PUT, "/", e, nil, t)
+ testMethod(t, e, nil, PUT, "/")
}
func TestEchoTrace(t *testing.T) {
e := New()
- testMethod(TRACE, "/", e, nil, t)
+ testMethod(t, e, nil, TRACE, "/")
}
-func testMethod(method, path string, e *Echo, g *Group, t *testing.T) {
+func testMethod(t *testing.T, e *Echo, g *Group, method, path string) {
m := fmt.Sprintf("%c%s", method[0], strings.ToLower(method[1:]))
p := reflect.ValueOf(path)
h := reflect.ValueOf(func(c *Context) error {
@@ -375,11 +387,11 @@ func TestEchoNotFound(t *testing.T) {
}
}
-func verifyUser(u2 *user, t *testing.T) {
- if u2.ID != u1.ID {
- t.Errorf("user id should be %s, found %s", u1.ID, u2.ID)
- }
- if u2.Name != u1.Name {
- t.Errorf("user name should be %s, found %s", u1.Name, u2.Name)
- }
-}
+//func verifyUser(u2 *user, t *testing.T) {
+// if u2.ID != u1.ID {
+// t.Errorf("user id should be %s, found %s", u1.ID, u2.ID)
+// }
+// if u2.Name != u1.Name {
+// t.Errorf("user name should be %s, found %s", u1.Name, u2.Name)
+// }
+//}
diff --git a/group_test.go b/group_test.go
index f78d63c1..3a948679 100644
--- a/group_test.go
+++ b/group_test.go
@@ -4,45 +4,45 @@ import "testing"
func TestGroupConnect(t *testing.T) {
g := New().Group("/group")
- testMethod(CONNECT, "/", &g.echo, g, t)
+ testMethod(t, &g.echo, g, CONNECT, "/")
}
func TestGroupDelete(t *testing.T) {
g := New().Group("/group")
- testMethod(DELETE, "/", &g.echo, g, t)
+ testMethod(t, &g.echo, g, DELETE, "/")
}
func TestGroupGet(t *testing.T) {
g := New().Group("/group")
- testMethod(GET, "/", &g.echo, g, t)
+ testMethod(t, &g.echo, g, GET, "/")
}
func TestGroupHead(t *testing.T) {
g := New().Group("/group")
- testMethod(HEAD, "/", &g.echo, g, t)
+ testMethod(t, &g.echo, g, HEAD, "/")
}
func TestGroupOptions(t *testing.T) {
g := New().Group("/group")
- testMethod(OPTIONS, "/", &g.echo, g, t)
+ testMethod(t, &g.echo, g, OPTIONS, "/")
}
func TestGroupPatch(t *testing.T) {
g := New().Group("/group")
- testMethod(PATCH, "/", &g.echo, g, t)
+ testMethod(t, &g.echo, g, PATCH, "/")
}
func TestGroupPost(t *testing.T) {
g := New().Group("/group")
- testMethod(POST, "/", &g.echo, g, t)
+ testMethod(t, &g.echo, g, POST, "/")
}
func TestGroupPut(t *testing.T) {
g := New().Group("/group")
- testMethod(PUT, "/", &g.echo, g, t)
+ testMethod(t, &g.echo, g, PUT, "/")
}
func TestGroupTrace(t *testing.T) {
g := New().Group("/group")
- testMethod(TRACE, "/", &g.echo, g, t)
+ testMethod(t, &g.echo, g, TRACE, "/")
}
diff --git a/middleware/auth_test.go b/middleware/auth_test.go
index d4d77165..377f6855 100644
--- a/middleware/auth_test.go
+++ b/middleware/auth_test.go
@@ -6,12 +6,14 @@ import (
"testing"
"github.com/labstack/echo"
+ "net/http/httptest"
+ "github.com/stretchr/testify/assert"
)
func TestBasicAuth(t *testing.T) {
- req, _ := http.NewRequest(echo.POST, "/", nil)
- res := &echo.Response{}
- c := echo.NewContext(req, res, echo.New())
+ req, _ := http.NewRequest(echo.GET, "/", nil)
+ rec := httptest.NewRecorder()
+ c := echo.NewContext(req, echo.NewResponse(rec), echo.New())
fn := func(u, p string) bool {
if u == "joe" && p == "secret" {
return true
@@ -26,16 +28,12 @@ func TestBasicAuth(t *testing.T) {
auth := Basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
req.Header.Set(echo.Authorization, auth)
- if ba(c) != nil {
- t.Error("expected `pass`")
- }
+ assert.NoError(t, ba(c))
// Case insensitive
auth = "basic " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
req.Header.Set(echo.Authorization, auth)
- if ba(c) != nil {
- t.Error("expected `pass`, with case insensitive header.")
- }
+ assert.NoError(t, ba(c))
//---------------------
// Invalid credentials
@@ -46,41 +44,30 @@ func TestBasicAuth(t *testing.T) {
req.Header.Set(echo.Authorization, auth)
ba = BasicAuth(fn)
he := ba(c).(*echo.HTTPError)
- if ba(c) == nil {
- t.Error("expected `fail`, with incorrect password.")
- } else if he.Code() != http.StatusUnauthorized {
- t.Errorf("expected status `401`, got %d", he.Code())
- }
+ assert.Equal(t, http.StatusUnauthorized, he.Code())
// Empty Authorization header
req.Header.Set(echo.Authorization, "")
ba = BasicAuth(fn)
he = ba(c).(*echo.HTTPError)
- if he == nil {
- t.Error("expected `fail`, with empty Authorization header.")
- } else if he.Code() != http.StatusBadRequest {
- t.Errorf("expected status `400`, got %d", he.Code())
- }
+ assert.Equal(t, http.StatusBadRequest, he.Code())
// Invalid Authorization header
auth = base64.StdEncoding.EncodeToString([]byte(" :secret"))
req.Header.Set(echo.Authorization, auth)
ba = BasicAuth(fn)
he = ba(c).(*echo.HTTPError)
- if he == nil {
- t.Error("expected `fail`, with invalid Authorization header.")
- } else if he.Code() != http.StatusBadRequest {
- t.Errorf("expected status `400`, got %d", he.Code())
- }
+ assert.Equal(t, http.StatusBadRequest, he.Code())
// Invalid scheme
- auth = "Ace " + base64.StdEncoding.EncodeToString([]byte(" :secret"))
+ auth = "Base " + base64.StdEncoding.EncodeToString([]byte(" :secret"))
req.Header.Set(echo.Authorization, auth)
ba = BasicAuth(fn)
he = ba(c).(*echo.HTTPError)
- if he == nil {
- t.Error("expected `fail`, with invalid scheme.")
- } else if he.Code() != http.StatusBadRequest {
- t.Errorf("expected status `400`, got %d", he.Code())
- }
+ assert.Equal(t, http.StatusBadRequest, he.Code())
+
+ // WebSocket
+ c.Request().Header.Set(echo.Upgrade, echo.WebSocket)
+ ba = BasicAuth(fn)
+ assert.NoError(t, ba(c))
}
diff --git a/middleware/compress_test.go b/middleware/compress_test.go
index 90417ac7..85d7fe94 100644
--- a/middleware/compress_test.go
+++ b/middleware/compress_test.go
@@ -9,43 +9,34 @@ import (
"bytes"
"github.com/labstack/echo"
+ "github.com/stretchr/testify/assert"
)
func TestGzip(t *testing.T) {
// Empty Accept-Encoding header
req, _ := http.NewRequest(echo.GET, "/", nil)
- w := httptest.NewRecorder()
- res := echo.NewResponse(w)
- c := echo.NewContext(req, res, echo.New())
+ rec := httptest.NewRecorder()
+ c := echo.NewContext(req, echo.NewResponse(rec), echo.New())
h := func(c *echo.Context) error {
return c.String(http.StatusOK, "test")
}
Gzip()(h)(c)
- s := w.Body.String()
- if s != "test" {
- t.Errorf("expected `test`, with empty Accept-Encoding header, got %s.", s)
- }
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, "test", rec.Body.String())
- // Content-Encoding header
+ // With Accept-Encoding header
+ req, _ = http.NewRequest(echo.GET, "/", nil)
req.Header.Set(echo.AcceptEncoding, "gzip")
- w = httptest.NewRecorder()
- c.Response().SetWriter(w)
+ rec = httptest.NewRecorder()
+ c = echo.NewContext(req, echo.NewResponse(rec), echo.New())
Gzip()(h)(c)
- ce := w.Header().Get(echo.ContentEncoding)
- if ce != "gzip" {
- t.Errorf("expected Content-Encoding header `gzip`, got %d.", ce)
- }
-
- // Body
- r, err := gzip.NewReader(w.Body)
+ assert.Equal(t, http.StatusOK, rec.Code)
+ assert.Equal(t, "gzip", rec.Header().Get(echo.ContentEncoding))
+ r, err := gzip.NewReader(rec.Body)
defer r.Close()
- if err != nil {
- t.Fatal(err)
- }
- buf := new(bytes.Buffer)
- buf.ReadFrom(r)
- s = buf.String()
- if s != "test" {
- t.Errorf("expected body `test`, got %s.", s)
+ if assert.NoError(t, err) {
+ buf := new(bytes.Buffer)
+ buf.ReadFrom(r)
+ assert.Equal(t, "test", buf.String())
}
}
diff --git a/middleware/logger.go b/middleware/logger.go
index 92bc8ed8..19d1301d 100644
--- a/middleware/logger.go
+++ b/middleware/logger.go
@@ -15,7 +15,7 @@ func Logger() echo.MiddlewareFunc {
if err := h(c); err != nil {
c.Error(err)
}
- end := time.Now()
+ stop := time.Now()
method := c.Request().Method
path := c.Request().URL.Path
if path == "" {
@@ -34,7 +34,7 @@ func Logger() echo.MiddlewareFunc {
code = color.Cyan(n)
}
- log.Printf("%s %s %s %s %d", method, path, code, end.Sub(start), size)
+ log.Printf("%s %s %s %s %d", method, path, code, stop.Sub(start), size)
return nil
}
}
diff --git a/middleware/logger_test.go b/middleware/logger_test.go
index 0c341e67..ee4809f1 100644
--- a/middleware/logger_test.go
+++ b/middleware/logger_test.go
@@ -5,14 +5,14 @@ import (
"net/http"
"net/http/httptest"
"testing"
+ "errors"
)
func TestLogger(t *testing.T) {
e := echo.New()
req, _ := http.NewRequest(echo.GET, "/", nil)
- w := httptest.NewRecorder()
- res := echo.NewResponse(w)
- c := echo.NewContext(req, res, e)
+ rec := httptest.NewRecorder()
+ c := echo.NewContext(req, echo.NewResponse(rec), e)
// Status 2xx
h := func(c *echo.Context) error {
@@ -20,17 +20,28 @@ func TestLogger(t *testing.T) {
}
Logger()(h)(c)
+ // Status 3xx
+ rec = httptest.NewRecorder()
+ c = echo.NewContext(req, echo.NewResponse(rec), e)
+ h = func(c *echo.Context) error {
+ return c.String(http.StatusTemporaryRedirect, "test")
+ }
+ Logger()(h)(c)
+
// Status 4xx
- c = echo.NewContext(req, echo.NewResponse(w), e)
+ rec = httptest.NewRecorder()
+ c = echo.NewContext(req, echo.NewResponse(rec), e)
h = func(c *echo.Context) error {
return c.String(http.StatusNotFound, "test")
}
Logger()(h)(c)
- // Status 5xx
- c = echo.NewContext(req, echo.NewResponse(w), e)
+ // Status 5xx with empty path
+ req, _ = http.NewRequest(echo.GET, "", nil)
+ rec = httptest.NewRecorder()
+ c = echo.NewContext(req, echo.NewResponse(rec), e)
h = func(c *echo.Context) error {
- return c.String(http.StatusInternalServerError, "test")
+ return errors.New("error")
}
Logger()(h)(c)
}
diff --git a/middleware/recover_test.go b/middleware/recover_test.go
index d60ac325..003c7519 100644
--- a/middleware/recover_test.go
+++ b/middleware/recover_test.go
@@ -1,33 +1,24 @@
package middleware
import (
- "github.com/labstack/echo"
"net/http"
"net/http/httptest"
- "strings"
"testing"
+
+ "github.com/labstack/echo"
+ "github.com/stretchr/testify/assert"
)
func TestRecover(t *testing.T) {
e := echo.New()
e.SetDebug(true)
req, _ := http.NewRequest(echo.GET, "/", nil)
- w := httptest.NewRecorder()
- res := echo.NewResponse(w)
- c := echo.NewContext(req, res, e)
+ rec := httptest.NewRecorder()
+ c := echo.NewContext(req, echo.NewResponse(rec), e)
h := func(c *echo.Context) error {
panic("test")
}
-
- // Status
Recover()(h)(c)
- if w.Code != http.StatusInternalServerError {
- t.Errorf("expected status `500`, got %d.", w.Code)
- }
-
- // Body
- s := w.Body.String()
- if !strings.Contains(s, "panic recover") {
- t.Error("expected body contains `panice recover`.")
- }
+ assert.Equal(t, http.StatusInternalServerError, rec.Code)
+ assert.Contains(t, rec.Body.String(), "panic recover")
}
diff --git a/middleware/slash_test.go b/middleware/slash_test.go
index 67fe72bf..59eaf238 100644
--- a/middleware/slash_test.go
+++ b/middleware/slash_test.go
@@ -6,33 +6,22 @@ import (
"testing"
"github.com/labstack/echo"
+ "github.com/stretchr/testify/assert"
)
func TestStripTrailingSlash(t *testing.T) {
req, _ := http.NewRequest(echo.GET, "/users/", nil)
- res := echo.NewResponse(httptest.NewRecorder())
- c := echo.NewContext(req, res, echo.New())
+ rec := httptest.NewRecorder()
+ c := echo.NewContext(req, echo.NewResponse(rec), echo.New())
StripTrailingSlash()(c)
- p := c.Request().URL.Path
- if p != "/users" {
- t.Errorf("expected path `/users` got, %s.", p)
- }
+ assert.Equal(t, "/users", c.Request().URL.Path)
}
func TestRedirectToSlash(t *testing.T) {
req, _ := http.NewRequest(echo.GET, "/users", nil)
- res := echo.NewResponse(httptest.NewRecorder())
- c := echo.NewContext(req, res, echo.New())
+ rec := httptest.NewRecorder()
+ c := echo.NewContext(req, echo.NewResponse(rec), echo.New())
RedirectToSlash(RedirectToSlashOptions{Code: http.StatusTemporaryRedirect})(c)
-
- // Status code
- if res.Status() != http.StatusTemporaryRedirect {
- t.Errorf("expected status `307`, got %d.", res.Status())
- }
-
- // Location header
- l := c.Response().Header().Get("Location")
- if l != "/users/" {
- t.Errorf("expected Location header `/users/`, got %s.", l)
- }
+ assert.Equal(t, http.StatusTemporaryRedirect, rec.Code)
+ assert.Equal(t, "/users/", c.Response().Header().Get("Location"))
}
diff --git a/response_test.go b/response_test.go
index fbc21bef..7078eeb9 100644
--- a/response_test.go
+++ b/response_test.go
@@ -24,11 +24,12 @@ func TestResponse(t *testing.T) {
r.WriteHeader(http.StatusOK)
assert.Equal(t, http.StatusOK, r.status)
- // committed
+ // Committed
assert.True(t, r.committed)
- // Response already committed
- r.WriteHeader(http.StatusOK)
+ // Already committed
+ r.WriteHeader(http.StatusTeapot)
+ assert.NotEqual(t, http.StatusTeapot, r.Status())
// Status
r.status = http.StatusOK
@@ -43,7 +44,7 @@ func TestResponse(t *testing.T) {
r.Flush()
// Size
- assert.Equal(t, int64(len(s)), r.Size())
+ assert.Len(t, s, int(r.Size()))
// Hijack
assert.Panics(t, func() {
diff --git a/router.go b/router.go
index b44c61cc..af9b4aea 100644
--- a/router.go
+++ b/router.go
@@ -150,9 +150,6 @@ func newNode(t ntype, pre string, p *node, c children, h HandlerFunc, pnames []s
}
}
-func (n *node) addChild(c *node) {
-}
-
func (n *node) findChild(l byte) *node {
for _, c := range n.children {
if c.label == l {