diff --git a/.travis.yml b/.travis.yml index 675044e9..27f892a5 100644 --- a/.travis.yml +++ b/.travis.yml @@ -7,6 +7,6 @@ before_install: - go get golang.org/x/tools/cmd/cover script: - go test -coverprofile=echo.coverprofile - - go test -coverprofile=mw.coverprofile ./middleware + - go test -coverprofile=middleware.coverprofile ./middleware - $HOME/gopath/bin/gover - $HOME/gopath/bin/goveralls -coverprofile=gover.coverprofile -service=travis-ci diff --git a/echo.go b/echo.go index 67d771aa..8bad0190 100644 --- a/echo.go +++ b/echo.go @@ -37,15 +37,12 @@ func New() (b *Echo) { maxParam: 5, notFoundHandler: func(c *Context) { http.Error(c.Response, http.StatusText(http.StatusNotFound), http.StatusNotFound) - // c.Halt() }, methodNotAllowedHandler: func(c *Context) { http.Error(c.Response, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) - // c.Halt() }, internalServerErrorHandler: func(c *Context) { http.Error(c.Response, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) - // c.Halt() }, } b.Router = NewRouter(b) @@ -54,8 +51,7 @@ func New() (b *Echo) { Response: &response{}, params: make(Params, b.maxParam), store: make(store), - // i: -1, - echo: b, + echo: b, } } return @@ -194,7 +190,7 @@ func wrapM(m Middleware) MiddlewareFunc { switch m := m.(type) { case func(HandlerFunc) HandlerFunc: return MiddlewareFunc(m) - case http.HandlerFunc, func(http.ResponseWriter, *http.Request), http.Handler: + case http.HandlerFunc, http.Handler: return func(h HandlerFunc) HandlerFunc { return func(c *Context) { m.(http.Handler).ServeHTTP(c.Response, c.Request) diff --git a/echo_test.go b/echo_test.go index 7bee0de1..bd216e40 100644 --- a/echo_test.go +++ b/echo_test.go @@ -32,8 +32,8 @@ func TestEchoMaxParam(t *testing.T) { func TestEchoIndex(t *testing.T) { b := New() b.Index("example/public/index.html") - r, _ := http.NewRequest("GET", "/", nil) w := httptest.NewRecorder() + r, _ := http.NewRequest("GET", "/", nil) b.ServeHTTP(w, r) if w.Code != 200 { t.Errorf("status code should be 200, found %d", w.Code) @@ -43,14 +43,67 @@ func TestEchoIndex(t *testing.T) { func TestEchoStatic(t *testing.T) { b := New() b.Static("/js", "example/public/js") - r, _ := http.NewRequest("GET", "/js/main.js", nil) w := httptest.NewRecorder() + r, _ := http.NewRequest("GET", "/js/main.js", nil) b.ServeHTTP(w, r) if w.Code != 200 { t.Errorf("status code should be 200, found %d", w.Code) } } +func TestEchoMiddleware(t *testing.T) { + b := New() + + // func(HandlerFunc) HandlerFunc + b.Use(func(h HandlerFunc) HandlerFunc { + return HandlerFunc(func(c *Context) { + c.Request.Header.Set("a", "1") + h(c) + }) + }) + + // http.HandlerFunc + b.Use(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.Header.Set("b", "2") + })) + + // http.Handler + b.Use(http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.Header.Set("c", "3") + }))) + + // func(http.Handler) http.Handler + b.Use(func(http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.Header.Set("d", "4") + }) + }) + + // Route + b.Get("/users", func(c *Context) { + h := c.Request.Header.Get("a") + if h != "1" { + t.Errorf("header a should be 1, found %s", h) + } + h = c.Request.Header.Get("b") + if h != "2" { + t.Errorf("header b should be 2, found %s", h) + } + h = c.Request.Header.Get("c") + if h != "3" { + t.Errorf("header c should be 3, found %s", h) + } + h = c.Request.Header.Get("d") + if h != "4" { + t.Errorf("header d should be 4, found %s", h) + } + }) + + w := httptest.NewRecorder() + r, _ := http.NewRequest("GET", "/users", nil) + b.ServeHTTP(w, r) +} + func verifyUser(rd io.Reader, t *testing.T) { var l int64 err := binary.Read(rd, binary.BigEndian, &l) // Body length diff --git a/response.go b/response.go index 11924df4..9f03e9c3 100644 --- a/response.go +++ b/response.go @@ -17,7 +17,6 @@ type ( } ) - func (r *response) WriteHeader(n int) { // TODO: fix when halted. if r.committed { @@ -53,7 +52,7 @@ func (r *response) Flusher() { func (r *response) Hijack() (net.Conn, *bufio.ReadWriter, error) { h, ok := r.ResponseWriter.(http.Hijacker) if !ok { - return nil, nil, errors.New("hijacker interface not supported") + return nil, nil, errors.New("bolt: hijacker interface not supported") } return h.Hijack() }