2016-03-31 21:17:18 +02:00
|
|
|
package middleware
|
|
|
|
|
|
|
|
import (
|
|
|
|
"github.com/labstack/echo"
|
|
|
|
)
|
|
|
|
|
2016-04-13 07:39:29 +02:00
|
|
|
type (
|
|
|
|
// TrailingSlashConfig defines the config for TrailingSlash middleware.
|
|
|
|
TrailingSlashConfig struct {
|
2016-07-27 18:34:44 +02:00
|
|
|
// Skipper defines a function to skip middleware.
|
|
|
|
Skipper Skipper
|
|
|
|
|
2016-05-10 20:52:04 +02:00
|
|
|
// Status code to be used when redirecting the request.
|
2016-04-20 16:32:51 +02:00
|
|
|
// Optional, but when provided the request is redirected using this code.
|
2016-05-19 03:53:54 +02:00
|
|
|
RedirectCode int `json:"redirect_code"`
|
2016-04-13 07:39:29 +02:00
|
|
|
}
|
|
|
|
)
|
|
|
|
|
2016-07-27 18:34:44 +02:00
|
|
|
var (
|
|
|
|
// DefaultTrailingSlashConfig is the default TrailingSlash middleware config.
|
|
|
|
DefaultTrailingSlashConfig = TrailingSlashConfig{
|
|
|
|
Skipper: defaultSkipper,
|
|
|
|
}
|
|
|
|
)
|
|
|
|
|
2016-03-31 21:17:18 +02:00
|
|
|
// AddTrailingSlash returns a root level (before router) middleware which adds a
|
|
|
|
// trailing slash to the request `URL#Path`.
|
|
|
|
//
|
|
|
|
// Usage `Echo#Pre(AddTrailingSlash())`
|
|
|
|
func AddTrailingSlash() echo.MiddlewareFunc {
|
2016-09-01 05:10:14 +02:00
|
|
|
return AddTrailingSlashWithConfig(DefaultTrailingSlashConfig)
|
2016-04-13 07:39:29 +02:00
|
|
|
}
|
|
|
|
|
2016-09-01 05:10:14 +02:00
|
|
|
// AddTrailingSlashWithConfig returns a AddTrailingSlash middleware with config.
|
2016-04-13 07:39:29 +02:00
|
|
|
// See `AddTrailingSlash()`.
|
|
|
|
func AddTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc {
|
2016-07-27 18:34:44 +02:00
|
|
|
// Defaults
|
|
|
|
if config.Skipper == nil {
|
|
|
|
config.Skipper = DefaultTrailingSlashConfig.Skipper
|
|
|
|
}
|
|
|
|
|
2016-04-02 23:19:39 +02:00
|
|
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
|
|
|
return func(c echo.Context) error {
|
2016-07-27 18:34:44 +02:00
|
|
|
if config.Skipper(c) {
|
|
|
|
return next(c)
|
|
|
|
}
|
|
|
|
|
2016-04-24 19:21:23 +02:00
|
|
|
req := c.Request()
|
|
|
|
url := req.URL()
|
2016-03-31 21:17:18 +02:00
|
|
|
path := url.Path()
|
2016-04-13 22:48:33 +02:00
|
|
|
qs := url.QueryString()
|
2016-03-31 21:17:18 +02:00
|
|
|
if path != "/" && path[len(path)-1] != '/' {
|
2016-04-13 07:39:29 +02:00
|
|
|
path += "/"
|
2016-04-13 22:48:33 +02:00
|
|
|
uri := path
|
|
|
|
if qs != "" {
|
|
|
|
uri += "?" + qs
|
|
|
|
}
|
2016-05-23 20:23:15 +02:00
|
|
|
|
|
|
|
// Redirect
|
2016-04-13 07:39:29 +02:00
|
|
|
if config.RedirectCode != 0 {
|
|
|
|
return c.Redirect(config.RedirectCode, uri)
|
|
|
|
}
|
2016-05-23 20:23:15 +02:00
|
|
|
|
|
|
|
// Forward
|
2016-04-24 19:21:23 +02:00
|
|
|
req.SetURI(uri)
|
2016-04-13 07:39:29 +02:00
|
|
|
url.SetPath(path)
|
2016-03-31 21:17:18 +02:00
|
|
|
}
|
2016-04-02 23:19:39 +02:00
|
|
|
return next(c)
|
|
|
|
}
|
2016-03-31 21:17:18 +02:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// RemoveTrailingSlash returns a root level (before router) middleware which removes
|
|
|
|
// a trailing slash from the request URI.
|
|
|
|
//
|
|
|
|
// Usage `Echo#Pre(RemoveTrailingSlash())`
|
|
|
|
func RemoveTrailingSlash() echo.MiddlewareFunc {
|
2016-04-13 07:39:29 +02:00
|
|
|
return RemoveTrailingSlashWithConfig(TrailingSlashConfig{})
|
|
|
|
}
|
|
|
|
|
2016-09-01 05:10:14 +02:00
|
|
|
// RemoveTrailingSlashWithConfig returns a RemoveTrailingSlash middleware with config.
|
2016-04-13 07:39:29 +02:00
|
|
|
// See `RemoveTrailingSlash()`.
|
|
|
|
func RemoveTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc {
|
2016-07-27 18:34:44 +02:00
|
|
|
// Defaults
|
|
|
|
if config.Skipper == nil {
|
|
|
|
config.Skipper = DefaultTrailingSlashConfig.Skipper
|
|
|
|
}
|
|
|
|
|
2016-04-02 23:19:39 +02:00
|
|
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
|
|
|
return func(c echo.Context) error {
|
2016-07-27 18:34:44 +02:00
|
|
|
if config.Skipper(c) {
|
|
|
|
return next(c)
|
|
|
|
}
|
|
|
|
|
2016-04-24 19:21:23 +02:00
|
|
|
req := c.Request()
|
|
|
|
url := req.URL()
|
2016-03-31 21:17:18 +02:00
|
|
|
path := url.Path()
|
2016-04-13 22:48:33 +02:00
|
|
|
qs := url.QueryString()
|
2016-03-31 21:17:18 +02:00
|
|
|
l := len(path) - 1
|
2016-06-01 04:10:34 +02:00
|
|
|
if l >= 0 && path != "/" && path[l] == '/' {
|
2016-04-13 07:39:29 +02:00
|
|
|
path = path[:l]
|
2016-04-13 22:48:33 +02:00
|
|
|
uri := path
|
|
|
|
if qs != "" {
|
|
|
|
uri += "?" + qs
|
|
|
|
}
|
2016-05-23 20:23:15 +02:00
|
|
|
|
|
|
|
// Redirect
|
2016-04-13 07:39:29 +02:00
|
|
|
if config.RedirectCode != 0 {
|
|
|
|
return c.Redirect(config.RedirectCode, uri)
|
|
|
|
}
|
2016-05-23 20:23:15 +02:00
|
|
|
|
|
|
|
// Forward
|
2016-04-24 19:21:23 +02:00
|
|
|
req.SetURI(uri)
|
2016-04-13 07:39:29 +02:00
|
|
|
url.SetPath(path)
|
2016-03-31 21:17:18 +02:00
|
|
|
}
|
2016-04-02 23:19:39 +02:00
|
|
|
return next(c)
|
|
|
|
}
|
2016-03-31 21:17:18 +02:00
|
|
|
}
|
|
|
|
}
|