From 86ae297e23260c553e569a620fd24127835836d7 Mon Sep 17 00:00:00 2001 From: Vishal Rana Date: Sat, 27 Aug 2016 13:03:40 -0700 Subject: [PATCH] New redirect middleware Signed-off-by: Vishal Rana --- middleware/jwt_test.go | 10 +++-- middleware/redirect.go | 73 +++++++++++++++++++++++++++++++++++++ middleware/redirect_test.go | 62 +++++++++++++++++++++++++++++++ middleware/slash_test.go | 1 - test/request.go | 1 - 5 files changed, 141 insertions(+), 6 deletions(-) create mode 100644 middleware/redirect.go create mode 100644 middleware/redirect_test.go diff --git a/middleware/jwt_test.go b/middleware/jwt_test.go index 7718d65e..d976a8c5 100644 --- a/middleware/jwt_test.go +++ b/middleware/jwt_test.go @@ -41,7 +41,10 @@ func TestJWT(t *testing.T) { hdrCookie string // test.Request doesn't provide SetCookie(); use name=val info string }{ - {expPanic: true, info: "No signing key provided"}, + { + expPanic: true, + info: "No signing key provided", + }, { expErrCode: http.StatusBadRequest, config: JWTConfig{ @@ -141,7 +144,6 @@ func TestJWT(t *testing.T) { info: "Empty cookie", }, } { - if tc.reqURL == "" { tc.reqURL = "/" } @@ -173,8 +175,8 @@ func TestJWT(t *testing.T) { case jwt.MapClaims: assert.Equal(t, claims["name"], "John Doe", tc.info) case *jwtCustomClaims: - assert.Equal(t, claims.Name, "John Doe") - assert.Equal(t, claims.Admin, true) + assert.Equal(t, claims.Name, "John Doe", tc.info) + assert.Equal(t, claims.Admin, true, tc.info) default: panic("unexpected type of claims") } diff --git a/middleware/redirect.go b/middleware/redirect.go new file mode 100644 index 00000000..8afad96d --- /dev/null +++ b/middleware/redirect.go @@ -0,0 +1,73 @@ +package middleware + +import ( + "net/http" + + "github.com/labstack/echo" +) + +// HTTPSRedirect redirects HTTP requests to HTTPS. +// For example, http://labstack.com will be redirect to https://labstack.com. +func HTTPSRedirect() echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + req := c.Request() + host := req.Host() + uri := req.URI() + if !req.IsTLS() { + return c.Redirect(http.StatusMovedPermanently, "https://"+host+uri) + } + return next(c) + } + } +} + +// HTTPSWWWRedirect redirects HTTP requests to WWW HTTPS. +// For example, http://labstack.com will be redirect to https://www.labstack.com. +func HTTPSWWWRedirect() echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + req := c.Request() + host := req.Host() + uri := req.URI() + if !req.IsTLS() && host[:3] != "www" { + return c.Redirect(http.StatusMovedPermanently, "https://www."+host+uri) + } + return next(c) + } + } +} + +// WWWRedirect redirects non WWW requests to WWW. +// For example, http://labstack.com will be redirect to http://www.labstack.com. +func WWWRedirect() echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + req := c.Request() + scheme := req.Scheme() + host := req.Host() + if host[:3] != "www" { + uri := req.URI() + return c.Redirect(http.StatusMovedPermanently, scheme+"://www."+host+uri) + } + return next(c) + } + } +} + +// NonWWWRedirect redirects WWW request to non WWW. +// For example, http://www.labstack.com will be redirect to http://labstack.com. +func NonWWWRedirect() echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + req := c.Request() + scheme := req.Scheme() + host := req.Host() + if host[:3] == "www" { + uri := req.URI() + return c.Redirect(http.StatusMovedPermanently, scheme+"://"+host[4:]+uri) + } + return next(c) + } + } +} diff --git a/middleware/redirect_test.go b/middleware/redirect_test.go new file mode 100644 index 00000000..1da64fd6 --- /dev/null +++ b/middleware/redirect_test.go @@ -0,0 +1,62 @@ +package middleware + +import ( + "net/http" + "testing" + + "github.com/labstack/echo" + "github.com/labstack/echo/test" + "github.com/stretchr/testify/assert" +) + +func TestHTTPSRedirect(t *testing.T) { + e := echo.New() + next := func(c echo.Context) (err error) { + return c.NoContent(http.StatusOK) + } + req := test.NewRequest(echo.GET, "http://labstack.com", nil) + res := test.NewResponseRecorder() + c := e.NewContext(req, res) + HTTPSRedirect()(next)(c) + assert.Equal(t, http.StatusMovedPermanently, res.Status()) + assert.Equal(t, "https://labstack.com", res.Header().Get(echo.HeaderLocation)) +} + +func TestHTTPSWWWRedirect(t *testing.T) { + e := echo.New() + next := func(c echo.Context) (err error) { + return c.NoContent(http.StatusOK) + } + req := test.NewRequest(echo.GET, "http://labstack.com", nil) + res := test.NewResponseRecorder() + c := e.NewContext(req, res) + HTTPSWWWRedirect()(next)(c) + assert.Equal(t, http.StatusMovedPermanently, res.Status()) + assert.Equal(t, "https://www.labstack.com", res.Header().Get(echo.HeaderLocation)) +} + +func TestWWWRedirect(t *testing.T) { + e := echo.New() + next := func(c echo.Context) (err error) { + return c.NoContent(http.StatusOK) + } + req := test.NewRequest(echo.GET, "http://labstack.com", nil) + res := test.NewResponseRecorder() + c := e.NewContext(req, res) + WWWRedirect()(next)(c) + assert.Equal(t, http.StatusMovedPermanently, res.Status()) + assert.Equal(t, "http://www.labstack.com", res.Header().Get(echo.HeaderLocation)) +} + +func TestNonWWWRedirect(t *testing.T) { + e := echo.New() + next := func(c echo.Context) (err error) { + return c.NoContent(http.StatusOK) + } + req := test.NewRequest(echo.GET, "http://www.labstack.com", nil) + res := test.NewResponseRecorder() + c := e.NewContext(req, res) + NonWWWRedirect()(next)(c) + assert.Equal(t, http.StatusMovedPermanently, res.Status()) + assert.Equal(t, "http://labstack.com", res.Header().Get(echo.HeaderLocation)) +} diff --git a/middleware/slash_test.go b/middleware/slash_test.go index 525633f2..084703d0 100644 --- a/middleware/slash_test.go +++ b/middleware/slash_test.go @@ -69,5 +69,4 @@ func TestRemoveTrailingSlash(t *testing.T) { }) h(c) assert.Equal(t, "", req.URL().Path()) - assert.Equal(t, "http://localhost", req.URI()) } diff --git a/test/request.go b/test/request.go index 2aaf6c4a..d3f867fd 100644 --- a/test/request.go +++ b/test/request.go @@ -26,7 +26,6 @@ const ( func NewRequest(method, url string, body io.Reader) engine.Request { r, _ := http.NewRequest(method, url, body) - r.RequestURI = url return &Request{ request: r, url: &URL{url: r.URL},