mirror of
https://github.com/labstack/echo.git
synced 2025-01-12 01:22:21 +02:00
Trailing slash middleware with option to redirect
Signed-off-by: Vishal Rana <vr@labstack.com>
This commit is contained in:
parent
6c27cffc66
commit
b9aa2181b0
@ -183,8 +183,8 @@ func TestContext(t *testing.T) {
|
||||
rec = test.NewResponseRecorder()
|
||||
c = NewContext(rq, rec, e)
|
||||
assert.Equal(t, nil, c.Redirect(http.StatusMovedPermanently, "http://labstack.github.io/echo"))
|
||||
assert.Equal(t, "http://labstack.github.io/echo", rec.Header().Get(HeaderLocation))
|
||||
assert.Equal(t, http.StatusMovedPermanently, rec.Status())
|
||||
assert.Equal(t, "http://labstack.github.io/echo", rec.Header().Get(HeaderLocation))
|
||||
|
||||
// Error
|
||||
rec = test.NewResponseRecorder()
|
||||
|
@ -38,6 +38,9 @@ type (
|
||||
// URI returns the unmodified `Request-URI` sent by the client.
|
||||
URI() string
|
||||
|
||||
// SetURI sets the URI of the request.
|
||||
SetURI(string)
|
||||
|
||||
// URL returns `engine.URL`.
|
||||
URL() URL
|
||||
|
||||
|
@ -77,6 +77,11 @@ func (r *Request) URI() string {
|
||||
return string(r.RequestURI())
|
||||
}
|
||||
|
||||
// SetURI implements `engine.Request#SetURI` function.
|
||||
func (r *Request) SetURI(uri string) {
|
||||
r.Request.Header.SetRequestURI(uri)
|
||||
}
|
||||
|
||||
// Body implements `engine.Request#Body` function.
|
||||
func (r *Request) Body() io.Reader {
|
||||
return bytes.NewBuffer(r.PostBody())
|
||||
|
@ -89,6 +89,11 @@ func (r *Request) URI() string {
|
||||
return r.RequestURI
|
||||
}
|
||||
|
||||
// SetURI implements `engine.Request#SetURI` function.
|
||||
func (r *Request) SetURI(uri string) {
|
||||
r.RequestURI = uri
|
||||
}
|
||||
|
||||
// Body implements `engine.Request#Body` function.
|
||||
func (r *Request) Body() io.Reader {
|
||||
return r.Request.Body
|
||||
|
@ -4,17 +4,39 @@ import (
|
||||
"github.com/labstack/echo"
|
||||
)
|
||||
|
||||
type (
|
||||
// TrailingSlashConfig defines the config for TrailingSlash middleware.
|
||||
TrailingSlashConfig struct {
|
||||
// RedirectCode is the status code used when redirecting the request.
|
||||
// Optional but when provided the request is redirected using this code.
|
||||
RedirectCode int
|
||||
}
|
||||
)
|
||||
|
||||
// AddTrailingSlash returns a root level (before router) middleware which adds a
|
||||
// trailing slash to the request `URL#Path`.
|
||||
//
|
||||
// Usage `Echo#Pre(AddTrailingSlash())`
|
||||
func AddTrailingSlash() echo.MiddlewareFunc {
|
||||
return AddTrailingSlashWithConfig(TrailingSlashConfig{})
|
||||
}
|
||||
|
||||
// AddTrailingSlashWithConfig returns a AddTrailingSlash middleware from config.
|
||||
// See `AddTrailingSlash()`.
|
||||
func AddTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
url := c.Request().URL()
|
||||
rq := c.Request()
|
||||
url := rq.URL()
|
||||
path := url.Path()
|
||||
if path != "/" && path[len(path)-1] != '/' {
|
||||
url.SetPath(path + "/")
|
||||
path += "/"
|
||||
uri := path + "?" + url.QueryString()
|
||||
if config.RedirectCode != 0 {
|
||||
return c.Redirect(config.RedirectCode, uri)
|
||||
}
|
||||
rq.SetURI(uri)
|
||||
url.SetPath(path)
|
||||
}
|
||||
return next(c)
|
||||
}
|
||||
@ -26,13 +48,26 @@ func AddTrailingSlash() echo.MiddlewareFunc {
|
||||
//
|
||||
// Usage `Echo#Pre(RemoveTrailingSlash())`
|
||||
func RemoveTrailingSlash() echo.MiddlewareFunc {
|
||||
return RemoveTrailingSlashWithConfig(TrailingSlashConfig{})
|
||||
}
|
||||
|
||||
// RemoveTrailingSlashWithConfig returns a RemoveTrailingSlash middleware from config.
|
||||
// See `RemoveTrailingSlash()`.
|
||||
func RemoveTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
url := c.Request().URL()
|
||||
rq := c.Request()
|
||||
url := rq.URL()
|
||||
path := url.Path()
|
||||
l := len(path) - 1
|
||||
if path != "/" && path[l] == '/' {
|
||||
url.SetPath(path[:l])
|
||||
path = path[:l]
|
||||
uri := path + "?" + url.QueryString()
|
||||
if config.RedirectCode != 0 {
|
||||
return c.Redirect(config.RedirectCode, uri)
|
||||
}
|
||||
rq.SetURI(uri)
|
||||
url.SetPath(path)
|
||||
}
|
||||
return next(c)
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo"
|
||||
@ -18,6 +19,17 @@ func TestAddTrailingSlash(t *testing.T) {
|
||||
})
|
||||
h(c)
|
||||
assert.Equal(t, "/add-slash/", rq.URL().Path())
|
||||
|
||||
// With config
|
||||
rq = test.NewRequest(echo.GET, "/add-slash?key=value", nil)
|
||||
rc = test.NewResponseRecorder()
|
||||
c = echo.NewContext(rq, rc, e)
|
||||
h = AddTrailingSlashWithConfig(TrailingSlashConfig{RedirectCode: http.StatusMovedPermanently})(func(c echo.Context) error {
|
||||
return nil
|
||||
})
|
||||
h(c)
|
||||
assert.Equal(t, http.StatusMovedPermanently, rc.Status())
|
||||
assert.Equal(t, "/add-slash/?key=value", rc.Header().Get(echo.HeaderLocation))
|
||||
}
|
||||
|
||||
func TestRemoveTrailingSlash(t *testing.T) {
|
||||
@ -30,4 +42,15 @@ func TestRemoveTrailingSlash(t *testing.T) {
|
||||
})
|
||||
h(c)
|
||||
assert.Equal(t, "/remove-slash", rq.URL().Path())
|
||||
|
||||
// With config
|
||||
rq = test.NewRequest(echo.GET, "/remove-slash/?key=value", nil)
|
||||
rc = test.NewResponseRecorder()
|
||||
c = echo.NewContext(rq, rc, e)
|
||||
h = RemoveTrailingSlashWithConfig(TrailingSlashConfig{RedirectCode: http.StatusMovedPermanently})(func(c echo.Context) error {
|
||||
return nil
|
||||
})
|
||||
h(c)
|
||||
assert.Equal(t, http.StatusMovedPermanently, rc.Status())
|
||||
assert.Equal(t, "/remove-slash?key=value", rc.Header().Get(echo.HeaderLocation))
|
||||
}
|
||||
|
@ -18,6 +18,7 @@ type (
|
||||
|
||||
func NewRequest(method, url string, body io.Reader) engine.Request {
|
||||
r, _ := http.NewRequest(method, url, body)
|
||||
r.RequestURI = url
|
||||
return &Request{
|
||||
request: r,
|
||||
url: &URL{url: r.URL},
|
||||
@ -84,6 +85,10 @@ func (r *Request) URI() string {
|
||||
return r.request.RequestURI
|
||||
}
|
||||
|
||||
func (r *Request) SetURI(uri string) {
|
||||
r.request.RequestURI = uri
|
||||
}
|
||||
|
||||
func (r *Request) Body() io.Reader {
|
||||
return r.request.Body
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user