1
0
mirror of https://github.com/labstack/echo.git synced 2024-12-24 20:14:31 +02:00

Ability to skip a middleware via callback

Signed-off-by: Vishal Rana <vr@labstack.com>
This commit is contained in:
Vishal Rana 2016-07-27 09:34:44 -07:00
parent eb48e7de60
commit 27f9b326b8
13 changed files with 229 additions and 46 deletions

View File

@ -7,13 +7,16 @@ import (
)
type (
// BasicAuthConfig defines the config for HTTP basic auth middleware.
// BasicAuthConfig defines the config for BasicAuth middleware.
BasicAuthConfig struct {
// Validator is a function to validate basic auth credentials.
// Skipper defines a function to skip middleware.
Skipper Skipper
// Validator is a function to validate BasicAuth credentials.
Validator BasicAuthValidator
}
// BasicAuthValidator defines a function to validate basic auth credentials.
// BasicAuthValidator defines a function to validate BasicAuth credentials.
BasicAuthValidator func(string, string) bool
)
@ -21,20 +24,38 @@ const (
basic = "Basic"
)
// BasicAuth returns an HTTP basic auth middleware.
var (
// DefaultBasicAuthConfig is the default BasicAuth middleware config.
DefaultBasicAuthConfig = BasicAuthConfig{
Skipper: defaultSkipper,
}
)
// BasicAuth returns an BasicAuth middleware.
//
// For valid credentials it calls the next handler.
// For invalid credentials, it sends "401 - Unauthorized" response.
// For empty or invalid `Authorization` header, it sends "400 - Bad Request" response.
func BasicAuth(fn BasicAuthValidator) echo.MiddlewareFunc {
return BasicAuthWithConfig(BasicAuthConfig{fn})
c := DefaultBasicAuthConfig
c.Validator = fn
return BasicAuthWithConfig(c)
}
// BasicAuthWithConfig returns an HTTP basic auth middleware from config.
// BasicAuthWithConfig returns an BasicAuth middleware from config.
// See `BasicAuth()`.
func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultBasicAuthConfig.Skipper
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
auth := c.Request().Header().Get(echo.HeaderAuthorization)
l := len(basic)

View File

@ -10,8 +10,11 @@ import (
)
type (
// BodyLimitConfig defines the config for body limit middleware.
// BodyLimitConfig defines the config for BodyLimit middleware.
BodyLimitConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// Maximum allowed size for a request body, it can be specified
// as `4x` or `4xB`, where x is one of the multiple from K, M, G, T or P.
Limit string `json:"limit"`
@ -26,21 +29,35 @@ type (
}
)
// BodyLimit returns a body limit middleware.
var (
// DefaultBodyLimitConfig is the default Gzip middleware config.
DefaultBodyLimitConfig = BodyLimitConfig{
Skipper: defaultSkipper,
}
)
// BodyLimit returns a BodyLimit middleware.
//
// BodyLimit middleware sets the maximum allowed size for a request body, if the
// size exceeds the configured limit, it sends "413 - Request Entity Too Large"
// response. The body limit is determined based on both `Content-Length` request
// response. The BodyLimit is determined based on both `Content-Length` request
// header and actual content read, which makes it super secure.
// Limit can be specified as `4x` or `4xB`, where x is one of the multiple from K, M,
// G, T or P.
func BodyLimit(limit string) echo.MiddlewareFunc {
return BodyLimitWithConfig(BodyLimitConfig{Limit: limit})
c := DefaultBodyLimitConfig
c.Limit = limit
return BodyLimitWithConfig(c)
}
// BodyLimitWithConfig returns a body limit middleware from config.
// BodyLimitWithConfig returns a BodyLimit middleware from config.
// See: `BodyLimit()`.
func BodyLimitWithConfig(config BodyLimitConfig) echo.MiddlewareFunc {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultBodyLimitConfig.Skipper
}
limit, err := bytes.Parse(config.Limit)
if err != nil {
panic(fmt.Errorf("invalid body-limit=%s", config.Limit))
@ -50,6 +67,10 @@ func BodyLimitWithConfig(config BodyLimitConfig) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
req := c.Request()
// Based on content length

View File

@ -13,8 +13,11 @@ import (
)
type (
// GzipConfig defines the config for gzip middleware.
// GzipConfig defines the config for Gzip middleware.
GzipConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// Gzip compression level.
// Optional. Default value -1.
Level int `json:"level"`
@ -27,9 +30,10 @@ type (
)
var (
// DefaultGzipConfig is the default gzip middleware config.
// DefaultGzipConfig is the default Gzip middleware config.
DefaultGzipConfig = GzipConfig{
Level: -1,
Skipper: defaultSkipper,
Level: -1,
}
)
@ -39,10 +43,13 @@ func Gzip() echo.MiddlewareFunc {
return GzipWithConfig(DefaultGzipConfig)
}
// GzipWithConfig return gzip middleware from config.
// GzipWithConfig return Gzip middleware from config.
// See: `Gzip()`.
func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultGzipConfig.Skipper
}
if config.Level == 0 {
config.Level = DefaultGzipConfig.Level
}
@ -52,6 +59,10 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
res := c.Response()
res.Header().Add(echo.HeaderVary, echo.HeaderAcceptEncoding)
if strings.Contains(c.Request().Header().Get(echo.HeaderAcceptEncoding), scheme) {

View File

@ -11,6 +11,9 @@ import (
type (
// CORSConfig defines the config for CORS middleware.
CORSConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// AllowOrigin defines a list of origins that may access the resource.
// Optional. Default value []string{"*"}.
AllowOrigins []string `json:"allow_origins"`
@ -47,6 +50,7 @@ type (
var (
// DefaultCORSConfig is the default CORS middleware config.
DefaultCORSConfig = CORSConfig{
Skipper: defaultSkipper,
AllowOrigins: []string{"*"},
AllowMethods: []string{echo.GET, echo.HEAD, echo.PUT, echo.PATCH, echo.POST, echo.DELETE},
}
@ -62,6 +66,9 @@ func CORS() echo.MiddlewareFunc {
// See: `CORS()`.
func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultCORSConfig.Skipper
}
if len(config.AllowOrigins) == 0 {
config.AllowOrigins = DefaultCORSConfig.AllowOrigins
}
@ -75,6 +82,10 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
req := c.Request()
res := c.Response()
origin := req.Header().Get(echo.HeaderOrigin)

View File

@ -14,6 +14,9 @@ import (
type (
// CSRFConfig defines the config for CSRF middleware.
CSRFConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// TokenLength is the length of the generated token.
TokenLength uint8 `json:"token_length"`
// Optional. Default value 32.
@ -64,6 +67,7 @@ type (
var (
// DefaultCSRFConfig is the default CSRF middleware config.
DefaultCSRFConfig = CSRFConfig{
Skipper: defaultSkipper,
TokenLength: 32,
TokenLookup: "header:" + echo.HeaderXCSRFToken,
ContextKey: "csrf",
@ -83,6 +87,9 @@ func CSRF() echo.MiddlewareFunc {
// See `CSRF()`.
func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultCSRFConfig.Skipper
}
if config.TokenLength == 0 {
config.TokenLength = DefaultCSRFConfig.TokenLength
}
@ -111,6 +118,10 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
req := c.Request()
k, err := c.Cookie(config.CookieName)
token := ""

View File

@ -11,8 +11,11 @@ import (
)
type (
// JWTConfig defines the config for JWT auth middleware.
// JWTConfig defines the config for JWT middleware.
JWTConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// Signing key to validate token.
// Required.
SigningKey []byte `json:"signing_key"`
@ -49,6 +52,7 @@ const (
var (
// DefaultJWTConfig is the default JWT auth middleware config.
DefaultJWTConfig = JWTConfig{
Skipper: defaultSkipper,
SigningMethod: AlgorithmHS256,
ContextKey: "user",
TokenLookup: "header:" + echo.HeaderAuthorization,
@ -72,6 +76,9 @@ func JWT(key []byte) echo.MiddlewareFunc {
// See: `JWT()`.
func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultJWTConfig.Skipper
}
if config.SigningKey == nil {
panic("jwt middleware requires signing key")
}
@ -95,6 +102,10 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
auth, err := extractor(c)
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, err.Error())

View File

@ -16,8 +16,11 @@ import (
)
type (
// LoggerConfig defines the config for logger middleware.
// LoggerConfig defines the config for Logger middleware.
LoggerConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// Log format which can be constructed using the following tags:
//
// - time_rfc3339
@ -32,8 +35,8 @@ type (
// - status
// - latency (In microseconds)
// - latency_human (Human readable)
// - rx_bytes (Bytes received)
// - tx_bytes (Bytes sent)
// - bytes_in (Bytes received)
// - bytes_out (Bytes sent)
//
// Example "${remote_ip} ${status}"
//
@ -51,14 +54,15 @@ type (
)
var (
// DefaultLoggerConfig is the default logger middleware config.
// DefaultLoggerConfig is the default Logger middleware config.
DefaultLoggerConfig = LoggerConfig{
Skipper: defaultSkipper,
Format: `{"time":"${time_rfc3339}","remote_ip":"${remote_ip}",` +
`"method":"${method}","uri":"${uri}","status":${status}, "latency":${latency},` +
`"latency_human":"${latency_human}","rx_bytes":${rx_bytes},` +
`"tx_bytes":${tx_bytes}}` + "\n",
color: color.New(),
`"latency_human":"${latency_human}","bytes_in":${bytes_in},` +
`"bytes_out":${bytes_out}}` + "\n",
Output: os.Stdout,
color: color.New(),
}
)
@ -67,10 +71,13 @@ func Logger() echo.MiddlewareFunc {
return LoggerWithConfig(DefaultLoggerConfig)
}
// LoggerWithConfig returns a logger middleware from config.
// LoggerWithConfig returns a Logger middleware from config.
// See: `Logger()`.
func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultLoggerConfig.Skipper
}
if config.Format == "" {
config.Format = DefaultLoggerConfig.Format
}
@ -91,6 +98,10 @@ func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) (err error) {
if config.Skipper(c) {
return next(c)
}
req := c.Request()
res := c.Response()
start := time.Now()
@ -149,13 +160,13 @@ func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc {
return w.Write([]byte(strconv.FormatInt(l, 10)))
case "latency_human":
return w.Write([]byte(stop.Sub(start).String()))
case "rx_bytes":
case "bytes_in":
b := req.Header().Get(echo.HeaderContentLength)
if b == "" {
b = "0"
}
return w.Write([]byte(b))
case "tx_bytes":
case "bytes_out":
return w.Write([]byte(strconv.FormatInt(res.Size(), 10)))
}
return 0, nil

View File

@ -1,12 +1,13 @@
package middleware
import (
"github.com/labstack/echo"
)
import "github.com/labstack/echo"
type (
// MethodOverrideConfig defines the config for method override middleware.
// MethodOverrideConfig defines the config for MethodOverride middleware.
MethodOverrideConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// Getter is a function that gets overridden method from the request.
// Optional. Default values MethodFromHeader(echo.HeaderXHTTPMethodOverride).
Getter MethodOverrideGetter
@ -17,13 +18,14 @@ type (
)
var (
// DefaultMethodOverrideConfig is the default method override middleware config.
// DefaultMethodOverrideConfig is the default MethodOverride middleware config.
DefaultMethodOverrideConfig = MethodOverrideConfig{
Getter: MethodFromHeader(echo.HeaderXHTTPMethodOverride),
Skipper: defaultSkipper,
Getter: MethodFromHeader(echo.HeaderXHTTPMethodOverride),
}
)
// MethodOverride returns a method override middleware.
// MethodOverride returns a MethodOverride middleware.
// MethodOverride middleware checks for the overridden method from the request and
// uses it instead of the original method.
//
@ -32,16 +34,23 @@ func MethodOverride() echo.MiddlewareFunc {
return MethodOverrideWithConfig(DefaultMethodOverrideConfig)
}
// MethodOverrideWithConfig returns a method override middleware from config.
// MethodOverrideWithConfig returns a MethodOverride middleware from config.
// See: `MethodOverride()`.
func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultMethodOverrideConfig.Skipper
}
if config.Getter == nil {
config.Getter = DefaultMethodOverrideConfig.Getter
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
req := c.Request()
if req.Method() == echo.POST {
m := config.Getter(c)

14
middleware/middleware.go Normal file
View File

@ -0,0 +1,14 @@
package middleware
import "github.com/labstack/echo"
type (
// Skipper defines a function to skip middleware. Returning true skips processing
// the middleware.
Skipper func(c echo.Context) bool
)
// defaultSkipper returns false which processes the middleware.
func defaultSkipper(c echo.Context) bool {
return false
}

View File

@ -9,8 +9,11 @@ import (
)
type (
// RecoverConfig defines the config for recover middleware.
// RecoverConfig defines the config for Recover middleware.
RecoverConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// Size of the stack to be printed.
// Optional. Default value 4KB.
StackSize int `json:"stack_size"`
@ -27,8 +30,9 @@ type (
)
var (
// DefaultRecoverConfig is the default recover middleware config.
// DefaultRecoverConfig is the default Recover middleware config.
DefaultRecoverConfig = RecoverConfig{
Skipper: defaultSkipper,
StackSize: 4 << 10, // 4 KB
DisableStackAll: false,
DisablePrintStack: false,
@ -41,16 +45,23 @@ func Recover() echo.MiddlewareFunc {
return RecoverWithConfig(DefaultRecoverConfig)
}
// RecoverWithConfig returns a recover middleware from config.
// RecoverWithConfig returns a Recover middleware from config.
// See: `Recover()`.
func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultRecoverConfig.Skipper
}
if config.StackSize == 0 {
config.StackSize = DefaultRecoverConfig.StackSize
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
defer func() {
if r := recover(); r != nil {
var err error

View File

@ -7,8 +7,11 @@ import (
)
type (
// SecureConfig defines the config for secure middleware.
// SecureConfig defines the config for Secure middleware.
SecureConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// XSSProtection provides protection against cross-site scripting attack (XSS)
// by setting the `X-XSS-Protection` header.
// Optional. Default value "1; mode=block".
@ -54,15 +57,16 @@ type (
)
var (
// DefaultSecureConfig is the default secure middleware config.
// DefaultSecureConfig is the default Secure middleware config.
DefaultSecureConfig = SecureConfig{
Skipper: defaultSkipper,
XSSProtection: "1; mode=block",
ContentTypeNosniff: "nosniff",
XFrameOptions: "SAMEORIGIN",
}
)
// Secure returns a secure middleware.
// Secure returns a Secure middleware.
// Secure middleware provides protection against cross-site scripting (XSS) attack,
// content type sniffing, clickjacking, insecure connection and other code injection
// attacks.
@ -70,11 +74,20 @@ func Secure() echo.MiddlewareFunc {
return SecureWithConfig(DefaultSecureConfig)
}
// SecureWithConfig returns a secure middleware from config.
// SecureWithConfig returns a Secure middleware from config.
// See: `Secure()`.
func SecureWithConfig(config SecureConfig) echo.MiddlewareFunc {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultSecureConfig.Skipper
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
req := c.Request()
res := c.Response()

View File

@ -7,12 +7,22 @@ import (
type (
// TrailingSlashConfig defines the config for TrailingSlash middleware.
TrailingSlashConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// Status code to be used when redirecting the request.
// Optional, but when provided the request is redirected using this code.
RedirectCode int `json:"redirect_code"`
}
)
var (
// DefaultTrailingSlashConfig is the default TrailingSlash middleware config.
DefaultTrailingSlashConfig = TrailingSlashConfig{
Skipper: defaultSkipper,
}
)
// AddTrailingSlash returns a root level (before router) middleware which adds a
// trailing slash to the request `URL#Path`.
//
@ -24,8 +34,17 @@ func AddTrailingSlash() echo.MiddlewareFunc {
// AddTrailingSlashWithConfig returns a AddTrailingSlash middleware from config.
// See `AddTrailingSlash()`.
func AddTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultTrailingSlashConfig.Skipper
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
req := c.Request()
url := req.URL()
path := url.Path()
@ -62,8 +81,17 @@ func RemoveTrailingSlash() echo.MiddlewareFunc {
// RemoveTrailingSlashWithConfig returns a RemoveTrailingSlash middleware from config.
// See `RemoveTrailingSlash()`.
func RemoveTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultTrailingSlashConfig.Skipper
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
req := c.Request()
url := req.URL()
path := url.Path()

View File

@ -10,8 +10,11 @@ import (
)
type (
// StaticConfig defines the config for static middleware.
// StaticConfig defines the config for Static middleware.
StaticConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// Root directory from where the static content is served.
// Required.
Root string `json:"root"`
@ -32,13 +35,14 @@ type (
)
var (
// DefaultStaticConfig is the default static middleware config.
// DefaultStaticConfig is the default Static middleware config.
DefaultStaticConfig = StaticConfig{
Index: "index.html",
Skipper: defaultSkipper,
Index: "index.html",
}
)
// Static returns a static middleware to serves static content from the provided
// Static returns a Static middleware to serves static content from the provided
// root directory.
func Static(root string) echo.MiddlewareFunc {
c := DefaultStaticConfig
@ -46,16 +50,23 @@ func Static(root string) echo.MiddlewareFunc {
return StaticWithConfig(c)
}
// StaticWithConfig returns a static middleware from config.
// StaticWithConfig returns a Static middleware from config.
// See `Static()`.
func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultStaticConfig.Skipper
}
if config.Index == "" {
config.Index = DefaultStaticConfig.Index
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
fs := http.Dir(config.Root)
p := c.Request().URL().Path()
if strings.Contains(c.Path(), "*") { // If serving from a group, e.g. `/static*`.