From f5a385b547487de69f298f2df1b56bb3085b11f1 Mon Sep 17 00:00:00 2001 From: Vishal Rana Date: Thu, 14 May 2015 16:25:49 -0700 Subject: [PATCH] Middleware with options Signed-off-by: Vishal Rana --- echo_test.go | 8 ++++---- middleware/slash.go | 23 ++++++++++++++++++++--- middleware/slash_test.go | 7 +++---- 3 files changed, 27 insertions(+), 11 deletions(-) diff --git a/echo_test.go b/echo_test.go index 0ee49837..763d60ad 100644 --- a/echo_test.go +++ b/echo_test.go @@ -34,7 +34,7 @@ func TestEchoIndex(t *testing.T) { w := httptest.NewRecorder() r, _ := http.NewRequest(GET, "/", nil) e.ServeHTTP(w, r) - if w.Code != 200 { + if w.Code != http.StatusOK { t.Errorf("status code should be 200, found %d", w.Code) } } @@ -45,7 +45,7 @@ func TestEchoFavicon(t *testing.T) { w := httptest.NewRecorder() r, _ := http.NewRequest(GET, "/favicon.ico", nil) e.ServeHTTP(w, r) - if w.Code != 200 { + if w.Code != http.StatusOK { t.Errorf("status code should be 200, found %d", w.Code) } } @@ -56,7 +56,7 @@ func TestEchoStatic(t *testing.T) { w := httptest.NewRecorder() r, _ := http.NewRequest(GET, "/scripts/main.js", nil) e.ServeHTTP(w, r) - if w.Code != 200 { + if w.Code != http.StatusOK { t.Errorf("status code should be 200, found %d", w.Code) } } @@ -227,7 +227,7 @@ func TestEchoGroup(t *testing.T) { w = httptest.NewRecorder() r, _ = http.NewRequest(GET, "/group3/group4/home", nil) e.ServeHTTP(w, r) - if w.Code != 200 { + if w.Code != http.StatusOK { t.Errorf("status code should be 200, found %d", w.Code) } } diff --git a/middleware/slash.go b/middleware/slash.go index c343c8af..70b0eed8 100644 --- a/middleware/slash.go +++ b/middleware/slash.go @@ -1,6 +1,15 @@ package middleware -import "github.com/labstack/echo" +import ( + "github.com/labstack/echo" + "net/http" +) + +type ( + RedirectToSlashOptions struct { + Code int + } +) // StripTrailingSlash removes trailing slash from request path. func StripTrailingSlash() echo.HandlerFunc { @@ -15,8 +24,16 @@ func StripTrailingSlash() echo.HandlerFunc { } // RedirectToSlash redirects requests without trailing slash path to trailing slash -// path, with status code. -func RedirectToSlash(code int) echo.HandlerFunc { +// path, with . +func RedirectToSlash(opts ...RedirectToSlashOptions) echo.HandlerFunc { + code := http.StatusMovedPermanently + + for _, o := range opts { + if o.Code != 0 { + code = o.Code + } + } + return func(c *echo.Context) (he *echo.HTTPError) { p := c.Request.URL.Path l := len(p) diff --git a/middleware/slash_test.go b/middleware/slash_test.go index 4f0f55fc..83f24552 100644 --- a/middleware/slash_test.go +++ b/middleware/slash_test.go @@ -22,10 +22,9 @@ func TestRedirectToSlash(t *testing.T) { req, _ := http.NewRequest(echo.GET, "/users", nil) res := &echo.Response{Writer: httptest.NewRecorder()} c := echo.NewContext(req, res, echo.New()) - RedirectToSlash(301)(c) - println(c.Response.Header().Get("Location")) - if res.Status() != 301 { - t.Errorf("status code should be 301, found %d", res.Status()) + RedirectToSlash(RedirectToSlashOptions{Code: http.StatusTemporaryRedirect})(c) + if res.Status() != http.StatusTemporaryRedirect { + t.Errorf("status code should be 307, found %d", res.Status()) } if c.Response.Header().Get("Location") != "/users/" { t.Error("Location header should be /users/")