From c654c422c48aa416e4fb30f1a99155ecbd1b24ca Mon Sep 17 00:00:00 2001 From: Vishal Rana Date: Mon, 6 Jun 2016 22:27:36 -0700 Subject: [PATCH] More coverage and better response adapter for standard/response Signed-off-by: Vishal Rana --- engine/fasthttp/response.go | 2 +- engine/standard/response.go | 36 +++++++---------- engine/standard/response_test.go | 11 +++--- engine/standard/server.go | 4 +- middleware/compress_test.go | 18 +++------ middleware/csrf_test.go | 49 ++++++++++++++++++++++-- middleware/static_test.go | 66 ++++++++++++++++++++++++++++++++ 7 files changed, 139 insertions(+), 47 deletions(-) diff --git a/engine/fasthttp/response.go b/engine/fasthttp/response.go index 7b5c7a90..457a3a45 100644 --- a/engine/fasthttp/response.go +++ b/engine/fasthttp/response.go @@ -52,7 +52,7 @@ func (r *Response) WriteHeader(code int) { // Write implements `engine.Response#Write` function. func (r *Response) Write(b []byte) (n int, err error) { - if !r.Committed() { + if !r.committed { r.WriteHeader(http.StatusOK) } n, err = r.writer.Write(b) diff --git a/engine/standard/response.go b/engine/standard/response.go index b0c43e53..8b24d5eb 100644 --- a/engine/standard/response.go +++ b/engine/standard/response.go @@ -14,6 +14,7 @@ type ( // Response implements `engine.Response`. Response struct { http.ResponseWriter + adapter *responseAdapter header engine.Header status int size int64 @@ -23,19 +24,20 @@ type ( } responseAdapter struct { - http.ResponseWriter - response *Response + *Response } ) // NewResponse returns `Response` instance. -func NewResponse(w http.ResponseWriter, l log.Logger) *Response { - return &Response{ +func NewResponse(w http.ResponseWriter, l log.Logger) (r *Response) { + r = &Response{ ResponseWriter: w, header: &Header{Header: w.Header()}, writer: w, logger: l, } + r.adapter = &responseAdapter{Response: r} + return } // Header implements `engine.Response#Header` function. @@ -56,7 +58,7 @@ func (r *Response) WriteHeader(code int) { // Write implements `engine.Response#Write` function. func (r *Response) Write(b []byte) (n int, err error) { - if !r.Committed() { + if !r.committed { r.WriteHeader(http.StatusOK) } n, err = r.writer.Write(b) @@ -126,7 +128,8 @@ func (r *Response) CloseNotify() <-chan bool { } func (r *Response) reset(w http.ResponseWriter, a *responseAdapter, h engine.Header) { - r.ResponseWriter = a + r.ResponseWriter = w + r.adapter = a r.header = h r.status = http.StatusOK r.size = 0 @@ -134,23 +137,10 @@ func (r *Response) reset(w http.ResponseWriter, a *responseAdapter, h engine.Hea r.writer = w } -func (a *responseAdapter) Write(b []byte) (n int, err error) { - return a.response.Write(b) +func (r *responseAdapter) Header() http.Header { + return r.ResponseWriter.Header() } -func (a *responseAdapter) Flush() { - a.ResponseWriter.(http.Flusher).Flush() -} - -func (a *responseAdapter) Hijack() (net.Conn, *bufio.ReadWriter, error) { - return a.ResponseWriter.(http.Hijacker).Hijack() -} - -func (a *responseAdapter) CloseNotify() <-chan bool { - return a.ResponseWriter.(http.CloseNotifier).CloseNotify() -} - -func (a *responseAdapter) reset(w http.ResponseWriter, r *Response) { - a.ResponseWriter = w - a.response = r +func (r *responseAdapter) reset(res *Response) { + r.Response = res } diff --git a/engine/standard/response_test.go b/engine/standard/response_test.go index 0a61de50..d579c466 100644 --- a/engine/standard/response_test.go +++ b/engine/standard/response_test.go @@ -1,16 +1,17 @@ package standard import ( - "github.com/labstack/gommon/log" - "github.com/stretchr/testify/assert" "io/ioutil" "net/http" "net/http/httptest" "testing" "time" + + "github.com/labstack/gommon/log" + "github.com/stretchr/testify/assert" ) -func TestResponse_WriteHeader(t *testing.T) { +func TestResponseWriteHeader(t *testing.T) { recorder := httptest.NewRecorder() resp := NewResponse(recorder, log.New("echo")) @@ -20,7 +21,7 @@ func TestResponse_WriteHeader(t *testing.T) { assert.True(t, resp.Committed()) } -func TestResponse_Write(t *testing.T) { +func TestResponseWrite(t *testing.T) { recorder := httptest.NewRecorder() resp := NewResponse(recorder, log.New("echo")) resp.Write([]byte("Hello")) @@ -32,7 +33,7 @@ func TestResponse_Write(t *testing.T) { assert.True(t, recorder.Flushed) } -func TestResponse_SetCookie(t *testing.T) { +func TestResponseSetCookie(t *testing.T) { recorder := httptest.NewRecorder() resp := NewResponse(recorder, log.New("echo")) diff --git a/engine/standard/server.go b/engine/standard/server.go index 61140760..efe2096e 100644 --- a/engine/standard/server.go +++ b/engine/standard/server.go @@ -130,7 +130,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Response res := s.pool.response.Get().(*Response) resAdpt := s.pool.responseAdapter.Get().(*responseAdapter) - resAdpt.reset(w, res) + resAdpt.reset(res) resHdr := s.pool.header.Get().(*Header) resHdr.reset(w.Header()) res.reset(w, resAdpt, resHdr) @@ -150,7 +150,7 @@ func WrapHandler(h http.Handler) echo.HandlerFunc { return func(c echo.Context) error { req := c.Request().(*Request) res := c.Response().(*Response) - h.ServeHTTP(res.ResponseWriter, req.Request) + h.ServeHTTP(res.adapter, req.Request) return nil } } diff --git a/middleware/compress_test.go b/middleware/compress_test.go index 6899dc52..7bf0519b 100644 --- a/middleware/compress_test.go +++ b/middleware/compress_test.go @@ -3,7 +3,6 @@ package middleware import ( "bytes" "compress/gzip" - "io/ioutil" "net/http" "testing" @@ -52,13 +51,10 @@ func TestGzipNoContent(t *testing.T) { h := Gzip()(func(c echo.Context) error { return c.NoContent(http.StatusOK) }) - h(c) - - assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding)) - assert.Empty(t, rec.Header().Get(echo.HeaderContentType)) - b, err := ioutil.ReadAll(rec.Body) - if assert.NoError(t, err) { - assert.Equal(t, 0, len(b)) + if assert.NoError(t, h(c)) { + assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding)) + assert.Empty(t, rec.Header().Get(echo.HeaderContentType)) + assert.Equal(t, 0, len(rec.Body.Bytes())) } } @@ -71,10 +67,6 @@ func TestGzipErrorReturned(t *testing.T) { req := test.NewRequest(echo.GET, "/", nil) rec := test.NewResponseRecorder() e.ServeHTTP(req, rec) - assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding)) - b, err := ioutil.ReadAll(rec.Body) - if assert.NoError(t, err) { - assert.Equal(t, "error", string(b)) - } + assert.Equal(t, "error", rec.Body.String()) } diff --git a/middleware/csrf_test.go b/middleware/csrf_test.go index c69d828b..5fa67218 100644 --- a/middleware/csrf_test.go +++ b/middleware/csrf_test.go @@ -2,6 +2,8 @@ package middleware import ( "net/http" + "net/url" + "strings" "testing" "github.com/labstack/echo" @@ -14,11 +16,20 @@ func TestCSRF(t *testing.T) { req := test.NewRequest(echo.GET, "/", nil) rec := test.NewResponseRecorder() c := e.NewContext(req, rec) - csrf := CSRF([]byte("secret")) + csrf := CSRFWithConfig(CSRFConfig{ + Secret: []byte("secret"), + CookiePath: "/", + CookieDomain: "labstack.com", + }) h := csrf(func(c echo.Context) error { return c.String(http.StatusOK, "test") }) + // No secret + assert.Panics(t, func() { + CSRF(nil) + }) + // Generate CSRF token h(c) assert.Contains(t, rec.Header().Get(echo.HeaderSetCookie), "csrf") @@ -35,6 +46,38 @@ func TestCSRF(t *testing.T) { salt, _ := generateSalt(8) token := generateCSRFToken([]byte("secret"), salt) req.Header().Set(echo.HeaderXCSRFToken, token) - h(c) - assert.Equal(t, http.StatusOK, rec.Status()) + if assert.NoError(t, h(c)) { + assert.Equal(t, http.StatusOK, rec.Status()) + } +} + +func TestCSRFTokenFromForm(t *testing.T) { + f := make(url.Values) + f.Set("csrf", "token") + e := echo.New() + req := test.NewRequest(echo.POST, "/", strings.NewReader(f.Encode())) + req.Header().Add(echo.HeaderContentType, echo.MIMEApplicationForm) + c := e.NewContext(req, nil) + token, err := csrfTokenFromForm("csrf")(c) + if assert.NoError(t, err) { + assert.Equal(t, "token", token) + } + token, err = csrfTokenFromForm("invalid")(c) + assert.Error(t, err) +} + +func TestCSRFTokenFromQuery(t *testing.T) { + q := make(url.Values) + q.Set("csrf", "token") + e := echo.New() + req := test.NewRequest(echo.GET, "/?"+q.Encode(), nil) + req.Header().Add(echo.HeaderContentType, echo.MIMEApplicationForm) + c := e.NewContext(req, nil) + token, err := csrfTokenFromQuery("csrf")(c) + if assert.NoError(t, err) { + assert.Equal(t, "token", token) + } + token, err = csrfTokenFromQuery("invalid")(c) + assert.Error(t, err) + csrfTokenFromQuery("csrf") } diff --git a/middleware/static_test.go b/middleware/static_test.go index c870d7c1..b1ce6649 100644 --- a/middleware/static_test.go +++ b/middleware/static_test.go @@ -1 +1,67 @@ package middleware + +import ( + "net/http" + "testing" + + "github.com/labstack/echo" + "github.com/labstack/echo/test" + "github.com/stretchr/testify/assert" +) + +func TestStatic(t *testing.T) { + e := echo.New() + req := test.NewRequest(echo.GET, "/", nil) + rec := test.NewResponseRecorder() + c := e.NewContext(req, rec) + h := Static("../_fixture")(func(c echo.Context) error { + return echo.ErrNotFound + }) + + // Directory + if assert.NoError(t, h(c)) { + assert.Contains(t, rec.Body.String(), "Echo") + } + + // HTML5 mode + req = test.NewRequest(echo.GET, "/client", nil) + rec = test.NewResponseRecorder() + c = e.NewContext(req, rec) + static := StaticWithConfig(StaticConfig{ + Root: "../_fixture", + HTML5: true, + }) + h = static(func(c echo.Context) error { + return echo.ErrNotFound + }) + if assert.NoError(t, h(c)) { + assert.Equal(t, http.StatusOK, rec.Status()) + } + + // Browse + req = test.NewRequest(echo.GET, "/", nil) + rec = test.NewResponseRecorder() + c = e.NewContext(req, rec) + static = StaticWithConfig(StaticConfig{ + Root: "../_fixture/images", + Browse: true, + }) + h = static(func(c echo.Context) error { + return echo.ErrNotFound + }) + if assert.NoError(t, h(c)) { + assert.Contains(t, rec.Body.String(), "walle") + } + + // Not found + req = test.NewRequest(echo.GET, "/not-found", nil) + rec = test.NewResponseRecorder() + c = e.NewContext(req, rec) + static = StaticWithConfig(StaticConfig{ + Root: "../_fixture/images", + }) + h = static(func(c echo.Context) error { + return echo.ErrNotFound + }) + assert.Error(t, h(c)) +}