1
0
mirror of https://github.com/labstack/echo.git synced 2025-01-26 03:20:08 +02:00

Trailing slash middleware with option to redirect

Signed-off-by: Vishal Rana <vr@labstack.com>
This commit is contained in:
Vishal Rana 2016-04-12 22:39:29 -07:00
parent 6c27cffc66
commit b9aa2181b0
7 changed files with 81 additions and 5 deletions

View File

@ -183,8 +183,8 @@ func TestContext(t *testing.T) {
rec = test.NewResponseRecorder() rec = test.NewResponseRecorder()
c = NewContext(rq, rec, e) c = NewContext(rq, rec, e)
assert.Equal(t, nil, c.Redirect(http.StatusMovedPermanently, "http://labstack.github.io/echo")) assert.Equal(t, nil, c.Redirect(http.StatusMovedPermanently, "http://labstack.github.io/echo"))
assert.Equal(t, "http://labstack.github.io/echo", rec.Header().Get(HeaderLocation))
assert.Equal(t, http.StatusMovedPermanently, rec.Status()) assert.Equal(t, http.StatusMovedPermanently, rec.Status())
assert.Equal(t, "http://labstack.github.io/echo", rec.Header().Get(HeaderLocation))
// Error // Error
rec = test.NewResponseRecorder() rec = test.NewResponseRecorder()

View File

@ -38,6 +38,9 @@ type (
// URI returns the unmodified `Request-URI` sent by the client. // URI returns the unmodified `Request-URI` sent by the client.
URI() string URI() string
// SetURI sets the URI of the request.
SetURI(string)
// URL returns `engine.URL`. // URL returns `engine.URL`.
URL() URL URL() URL

View File

@ -77,6 +77,11 @@ func (r *Request) URI() string {
return string(r.RequestURI()) return string(r.RequestURI())
} }
// SetURI implements `engine.Request#SetURI` function.
func (r *Request) SetURI(uri string) {
r.Request.Header.SetRequestURI(uri)
}
// Body implements `engine.Request#Body` function. // Body implements `engine.Request#Body` function.
func (r *Request) Body() io.Reader { func (r *Request) Body() io.Reader {
return bytes.NewBuffer(r.PostBody()) return bytes.NewBuffer(r.PostBody())

View File

@ -89,6 +89,11 @@ func (r *Request) URI() string {
return r.RequestURI return r.RequestURI
} }
// SetURI implements `engine.Request#SetURI` function.
func (r *Request) SetURI(uri string) {
r.RequestURI = uri
}
// Body implements `engine.Request#Body` function. // Body implements `engine.Request#Body` function.
func (r *Request) Body() io.Reader { func (r *Request) Body() io.Reader {
return r.Request.Body return r.Request.Body

View File

@ -4,17 +4,39 @@ import (
"github.com/labstack/echo" "github.com/labstack/echo"
) )
type (
// TrailingSlashConfig defines the config for TrailingSlash middleware.
TrailingSlashConfig struct {
// RedirectCode is the status code used when redirecting the request.
// Optional but when provided the request is redirected using this code.
RedirectCode int
}
)
// AddTrailingSlash returns a root level (before router) middleware which adds a // AddTrailingSlash returns a root level (before router) middleware which adds a
// trailing slash to the request `URL#Path`. // trailing slash to the request `URL#Path`.
// //
// Usage `Echo#Pre(AddTrailingSlash())` // Usage `Echo#Pre(AddTrailingSlash())`
func AddTrailingSlash() echo.MiddlewareFunc { func AddTrailingSlash() echo.MiddlewareFunc {
return AddTrailingSlashWithConfig(TrailingSlashConfig{})
}
// AddTrailingSlashWithConfig returns a AddTrailingSlash middleware from config.
// See `AddTrailingSlash()`.
func AddTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
url := c.Request().URL() rq := c.Request()
url := rq.URL()
path := url.Path() path := url.Path()
if path != "/" && path[len(path)-1] != '/' { if path != "/" && path[len(path)-1] != '/' {
url.SetPath(path + "/") path += "/"
uri := path + "?" + url.QueryString()
if config.RedirectCode != 0 {
return c.Redirect(config.RedirectCode, uri)
}
rq.SetURI(uri)
url.SetPath(path)
} }
return next(c) return next(c)
} }
@ -26,13 +48,26 @@ func AddTrailingSlash() echo.MiddlewareFunc {
// //
// Usage `Echo#Pre(RemoveTrailingSlash())` // Usage `Echo#Pre(RemoveTrailingSlash())`
func RemoveTrailingSlash() echo.MiddlewareFunc { func RemoveTrailingSlash() echo.MiddlewareFunc {
return RemoveTrailingSlashWithConfig(TrailingSlashConfig{})
}
// RemoveTrailingSlashWithConfig returns a RemoveTrailingSlash middleware from config.
// See `RemoveTrailingSlash()`.
func RemoveTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
url := c.Request().URL() rq := c.Request()
url := rq.URL()
path := url.Path() path := url.Path()
l := len(path) - 1 l := len(path) - 1
if path != "/" && path[l] == '/' { if path != "/" && path[l] == '/' {
url.SetPath(path[:l]) path = path[:l]
uri := path + "?" + url.QueryString()
if config.RedirectCode != 0 {
return c.Redirect(config.RedirectCode, uri)
}
rq.SetURI(uri)
url.SetPath(path)
} }
return next(c) return next(c)
} }

View File

@ -1,6 +1,7 @@
package middleware package middleware
import ( import (
"net/http"
"testing" "testing"
"github.com/labstack/echo" "github.com/labstack/echo"
@ -18,6 +19,17 @@ func TestAddTrailingSlash(t *testing.T) {
}) })
h(c) h(c)
assert.Equal(t, "/add-slash/", rq.URL().Path()) assert.Equal(t, "/add-slash/", rq.URL().Path())
// With config
rq = test.NewRequest(echo.GET, "/add-slash?key=value", nil)
rc = test.NewResponseRecorder()
c = echo.NewContext(rq, rc, e)
h = AddTrailingSlashWithConfig(TrailingSlashConfig{RedirectCode: http.StatusMovedPermanently})(func(c echo.Context) error {
return nil
})
h(c)
assert.Equal(t, http.StatusMovedPermanently, rc.Status())
assert.Equal(t, "/add-slash/?key=value", rc.Header().Get(echo.HeaderLocation))
} }
func TestRemoveTrailingSlash(t *testing.T) { func TestRemoveTrailingSlash(t *testing.T) {
@ -30,4 +42,15 @@ func TestRemoveTrailingSlash(t *testing.T) {
}) })
h(c) h(c)
assert.Equal(t, "/remove-slash", rq.URL().Path()) assert.Equal(t, "/remove-slash", rq.URL().Path())
// With config
rq = test.NewRequest(echo.GET, "/remove-slash/?key=value", nil)
rc = test.NewResponseRecorder()
c = echo.NewContext(rq, rc, e)
h = RemoveTrailingSlashWithConfig(TrailingSlashConfig{RedirectCode: http.StatusMovedPermanently})(func(c echo.Context) error {
return nil
})
h(c)
assert.Equal(t, http.StatusMovedPermanently, rc.Status())
assert.Equal(t, "/remove-slash?key=value", rc.Header().Get(echo.HeaderLocation))
} }

View File

@ -18,6 +18,7 @@ type (
func NewRequest(method, url string, body io.Reader) engine.Request { func NewRequest(method, url string, body io.Reader) engine.Request {
r, _ := http.NewRequest(method, url, body) r, _ := http.NewRequest(method, url, body)
r.RequestURI = url
return &Request{ return &Request{
request: r, request: r,
url: &URL{url: r.URL}, url: &URL{url: r.URL},
@ -84,6 +85,10 @@ func (r *Request) URI() string {
return r.request.RequestURI return r.request.RequestURI
} }
func (r *Request) SetURI(uri string) {
r.request.RequestURI = uri
}
func (r *Request) Body() io.Reader { func (r *Request) Body() io.Reader {
return r.request.Body return r.request.Body
} }