From 27f9b326b80b2fa21a2e6265b00802122ff093f7 Mon Sep 17 00:00:00 2001 From: Vishal Rana Date: Wed, 27 Jul 2016 09:34:44 -0700 Subject: [PATCH] Ability to skip a middleware via callback Signed-off-by: Vishal Rana --- middleware/basic_auth.go | 33 +++++++++++++++++++++++++++------ middleware/body_limit.go | 31 ++++++++++++++++++++++++++----- middleware/compress.go | 19 +++++++++++++++---- middleware/cors.go | 11 +++++++++++ middleware/csrf.go | 11 +++++++++++ middleware/jwt.go | 13 ++++++++++++- middleware/logger.go | 31 +++++++++++++++++++++---------- middleware/method_override.go | 25 +++++++++++++++++-------- middleware/middleware.go | 14 ++++++++++++++ middleware/recover.go | 17 ++++++++++++++--- middleware/secure.go | 21 +++++++++++++++++---- middleware/slash.go | 28 ++++++++++++++++++++++++++++ middleware/static.go | 21 ++++++++++++++++----- 13 files changed, 229 insertions(+), 46 deletions(-) create mode 100644 middleware/middleware.go diff --git a/middleware/basic_auth.go b/middleware/basic_auth.go index bb083905..4c932e1b 100644 --- a/middleware/basic_auth.go +++ b/middleware/basic_auth.go @@ -7,13 +7,16 @@ import ( ) type ( - // BasicAuthConfig defines the config for HTTP basic auth middleware. + // BasicAuthConfig defines the config for BasicAuth middleware. BasicAuthConfig struct { - // Validator is a function to validate basic auth credentials. + // Skipper defines a function to skip middleware. + Skipper Skipper + + // Validator is a function to validate BasicAuth credentials. Validator BasicAuthValidator } - // BasicAuthValidator defines a function to validate basic auth credentials. + // BasicAuthValidator defines a function to validate BasicAuth credentials. BasicAuthValidator func(string, string) bool ) @@ -21,20 +24,38 @@ const ( basic = "Basic" ) -// BasicAuth returns an HTTP basic auth middleware. +var ( + // DefaultBasicAuthConfig is the default BasicAuth middleware config. + DefaultBasicAuthConfig = BasicAuthConfig{ + Skipper: defaultSkipper, + } +) + +// BasicAuth returns an BasicAuth middleware. // // For valid credentials it calls the next handler. // For invalid credentials, it sends "401 - Unauthorized" response. // For empty or invalid `Authorization` header, it sends "400 - Bad Request" response. func BasicAuth(fn BasicAuthValidator) echo.MiddlewareFunc { - return BasicAuthWithConfig(BasicAuthConfig{fn}) + c := DefaultBasicAuthConfig + c.Validator = fn + return BasicAuthWithConfig(c) } -// BasicAuthWithConfig returns an HTTP basic auth middleware from config. +// BasicAuthWithConfig returns an BasicAuth middleware from config. // See `BasicAuth()`. func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc { + // Defaults + if config.Skipper == nil { + config.Skipper = DefaultBasicAuthConfig.Skipper + } + return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { + if config.Skipper(c) { + return next(c) + } + auth := c.Request().Header().Get(echo.HeaderAuthorization) l := len(basic) diff --git a/middleware/body_limit.go b/middleware/body_limit.go index d804d233..3aa60b24 100644 --- a/middleware/body_limit.go +++ b/middleware/body_limit.go @@ -10,8 +10,11 @@ import ( ) type ( - // BodyLimitConfig defines the config for body limit middleware. + // BodyLimitConfig defines the config for BodyLimit middleware. BodyLimitConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + // Maximum allowed size for a request body, it can be specified // as `4x` or `4xB`, where x is one of the multiple from K, M, G, T or P. Limit string `json:"limit"` @@ -26,21 +29,35 @@ type ( } ) -// BodyLimit returns a body limit middleware. +var ( + // DefaultBodyLimitConfig is the default Gzip middleware config. + DefaultBodyLimitConfig = BodyLimitConfig{ + Skipper: defaultSkipper, + } +) + +// BodyLimit returns a BodyLimit middleware. // // BodyLimit middleware sets the maximum allowed size for a request body, if the // size exceeds the configured limit, it sends "413 - Request Entity Too Large" -// response. The body limit is determined based on both `Content-Length` request +// response. The BodyLimit is determined based on both `Content-Length` request // header and actual content read, which makes it super secure. // Limit can be specified as `4x` or `4xB`, where x is one of the multiple from K, M, // G, T or P. func BodyLimit(limit string) echo.MiddlewareFunc { - return BodyLimitWithConfig(BodyLimitConfig{Limit: limit}) + c := DefaultBodyLimitConfig + c.Limit = limit + return BodyLimitWithConfig(c) } -// BodyLimitWithConfig returns a body limit middleware from config. +// BodyLimitWithConfig returns a BodyLimit middleware from config. // See: `BodyLimit()`. func BodyLimitWithConfig(config BodyLimitConfig) echo.MiddlewareFunc { + // Defaults + if config.Skipper == nil { + config.Skipper = DefaultBodyLimitConfig.Skipper + } + limit, err := bytes.Parse(config.Limit) if err != nil { panic(fmt.Errorf("invalid body-limit=%s", config.Limit)) @@ -50,6 +67,10 @@ func BodyLimitWithConfig(config BodyLimitConfig) echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { + if config.Skipper(c) { + return next(c) + } + req := c.Request() // Based on content length diff --git a/middleware/compress.go b/middleware/compress.go index 41e0b836..48599103 100644 --- a/middleware/compress.go +++ b/middleware/compress.go @@ -13,8 +13,11 @@ import ( ) type ( - // GzipConfig defines the config for gzip middleware. + // GzipConfig defines the config for Gzip middleware. GzipConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + // Gzip compression level. // Optional. Default value -1. Level int `json:"level"` @@ -27,9 +30,10 @@ type ( ) var ( - // DefaultGzipConfig is the default gzip middleware config. + // DefaultGzipConfig is the default Gzip middleware config. DefaultGzipConfig = GzipConfig{ - Level: -1, + Skipper: defaultSkipper, + Level: -1, } ) @@ -39,10 +43,13 @@ func Gzip() echo.MiddlewareFunc { return GzipWithConfig(DefaultGzipConfig) } -// GzipWithConfig return gzip middleware from config. +// GzipWithConfig return Gzip middleware from config. // See: `Gzip()`. func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { // Defaults + if config.Skipper == nil { + config.Skipper = DefaultGzipConfig.Skipper + } if config.Level == 0 { config.Level = DefaultGzipConfig.Level } @@ -52,6 +59,10 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { + if config.Skipper(c) { + return next(c) + } + res := c.Response() res.Header().Add(echo.HeaderVary, echo.HeaderAcceptEncoding) if strings.Contains(c.Request().Header().Get(echo.HeaderAcceptEncoding), scheme) { diff --git a/middleware/cors.go b/middleware/cors.go index 79bcf2cb..2c1b3797 100644 --- a/middleware/cors.go +++ b/middleware/cors.go @@ -11,6 +11,9 @@ import ( type ( // CORSConfig defines the config for CORS middleware. CORSConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + // AllowOrigin defines a list of origins that may access the resource. // Optional. Default value []string{"*"}. AllowOrigins []string `json:"allow_origins"` @@ -47,6 +50,7 @@ type ( var ( // DefaultCORSConfig is the default CORS middleware config. DefaultCORSConfig = CORSConfig{ + Skipper: defaultSkipper, AllowOrigins: []string{"*"}, AllowMethods: []string{echo.GET, echo.HEAD, echo.PUT, echo.PATCH, echo.POST, echo.DELETE}, } @@ -62,6 +66,9 @@ func CORS() echo.MiddlewareFunc { // See: `CORS()`. func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { // Defaults + if config.Skipper == nil { + config.Skipper = DefaultCORSConfig.Skipper + } if len(config.AllowOrigins) == 0 { config.AllowOrigins = DefaultCORSConfig.AllowOrigins } @@ -75,6 +82,10 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { + if config.Skipper(c) { + return next(c) + } + req := c.Request() res := c.Response() origin := req.Header().Get(echo.HeaderOrigin) diff --git a/middleware/csrf.go b/middleware/csrf.go index b3eef01c..44fdfc23 100644 --- a/middleware/csrf.go +++ b/middleware/csrf.go @@ -14,6 +14,9 @@ import ( type ( // CSRFConfig defines the config for CSRF middleware. CSRFConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + // TokenLength is the length of the generated token. TokenLength uint8 `json:"token_length"` // Optional. Default value 32. @@ -64,6 +67,7 @@ type ( var ( // DefaultCSRFConfig is the default CSRF middleware config. DefaultCSRFConfig = CSRFConfig{ + Skipper: defaultSkipper, TokenLength: 32, TokenLookup: "header:" + echo.HeaderXCSRFToken, ContextKey: "csrf", @@ -83,6 +87,9 @@ func CSRF() echo.MiddlewareFunc { // See `CSRF()`. func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { // Defaults + if config.Skipper == nil { + config.Skipper = DefaultCSRFConfig.Skipper + } if config.TokenLength == 0 { config.TokenLength = DefaultCSRFConfig.TokenLength } @@ -111,6 +118,10 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { + if config.Skipper(c) { + return next(c) + } + req := c.Request() k, err := c.Cookie(config.CookieName) token := "" diff --git a/middleware/jwt.go b/middleware/jwt.go index d8c7af28..8432cef7 100644 --- a/middleware/jwt.go +++ b/middleware/jwt.go @@ -11,8 +11,11 @@ import ( ) type ( - // JWTConfig defines the config for JWT auth middleware. + // JWTConfig defines the config for JWT middleware. JWTConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + // Signing key to validate token. // Required. SigningKey []byte `json:"signing_key"` @@ -49,6 +52,7 @@ const ( var ( // DefaultJWTConfig is the default JWT auth middleware config. DefaultJWTConfig = JWTConfig{ + Skipper: defaultSkipper, SigningMethod: AlgorithmHS256, ContextKey: "user", TokenLookup: "header:" + echo.HeaderAuthorization, @@ -72,6 +76,9 @@ func JWT(key []byte) echo.MiddlewareFunc { // See: `JWT()`. func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { // Defaults + if config.Skipper == nil { + config.Skipper = DefaultJWTConfig.Skipper + } if config.SigningKey == nil { panic("jwt middleware requires signing key") } @@ -95,6 +102,10 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { + if config.Skipper(c) { + return next(c) + } + auth, err := extractor(c) if err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) diff --git a/middleware/logger.go b/middleware/logger.go index 6ef9d83f..8b069c61 100644 --- a/middleware/logger.go +++ b/middleware/logger.go @@ -16,8 +16,11 @@ import ( ) type ( - // LoggerConfig defines the config for logger middleware. + // LoggerConfig defines the config for Logger middleware. LoggerConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + // Log format which can be constructed using the following tags: // // - time_rfc3339 @@ -32,8 +35,8 @@ type ( // - status // - latency (In microseconds) // - latency_human (Human readable) - // - rx_bytes (Bytes received) - // - tx_bytes (Bytes sent) + // - bytes_in (Bytes received) + // - bytes_out (Bytes sent) // // Example "${remote_ip} ${status}" // @@ -51,14 +54,15 @@ type ( ) var ( - // DefaultLoggerConfig is the default logger middleware config. + // DefaultLoggerConfig is the default Logger middleware config. DefaultLoggerConfig = LoggerConfig{ + Skipper: defaultSkipper, Format: `{"time":"${time_rfc3339}","remote_ip":"${remote_ip}",` + `"method":"${method}","uri":"${uri}","status":${status}, "latency":${latency},` + - `"latency_human":"${latency_human}","rx_bytes":${rx_bytes},` + - `"tx_bytes":${tx_bytes}}` + "\n", - color: color.New(), + `"latency_human":"${latency_human}","bytes_in":${bytes_in},` + + `"bytes_out":${bytes_out}}` + "\n", Output: os.Stdout, + color: color.New(), } ) @@ -67,10 +71,13 @@ func Logger() echo.MiddlewareFunc { return LoggerWithConfig(DefaultLoggerConfig) } -// LoggerWithConfig returns a logger middleware from config. +// LoggerWithConfig returns a Logger middleware from config. // See: `Logger()`. func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc { // Defaults + if config.Skipper == nil { + config.Skipper = DefaultLoggerConfig.Skipper + } if config.Format == "" { config.Format = DefaultLoggerConfig.Format } @@ -91,6 +98,10 @@ func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) (err error) { + if config.Skipper(c) { + return next(c) + } + req := c.Request() res := c.Response() start := time.Now() @@ -149,13 +160,13 @@ func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc { return w.Write([]byte(strconv.FormatInt(l, 10))) case "latency_human": return w.Write([]byte(stop.Sub(start).String())) - case "rx_bytes": + case "bytes_in": b := req.Header().Get(echo.HeaderContentLength) if b == "" { b = "0" } return w.Write([]byte(b)) - case "tx_bytes": + case "bytes_out": return w.Write([]byte(strconv.FormatInt(res.Size(), 10))) } return 0, nil diff --git a/middleware/method_override.go b/middleware/method_override.go index 7a0744e9..71c93235 100644 --- a/middleware/method_override.go +++ b/middleware/method_override.go @@ -1,12 +1,13 @@ package middleware -import ( - "github.com/labstack/echo" -) +import "github.com/labstack/echo" type ( - // MethodOverrideConfig defines the config for method override middleware. + // MethodOverrideConfig defines the config for MethodOverride middleware. MethodOverrideConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + // Getter is a function that gets overridden method from the request. // Optional. Default values MethodFromHeader(echo.HeaderXHTTPMethodOverride). Getter MethodOverrideGetter @@ -17,13 +18,14 @@ type ( ) var ( - // DefaultMethodOverrideConfig is the default method override middleware config. + // DefaultMethodOverrideConfig is the default MethodOverride middleware config. DefaultMethodOverrideConfig = MethodOverrideConfig{ - Getter: MethodFromHeader(echo.HeaderXHTTPMethodOverride), + Skipper: defaultSkipper, + Getter: MethodFromHeader(echo.HeaderXHTTPMethodOverride), } ) -// MethodOverride returns a method override middleware. +// MethodOverride returns a MethodOverride middleware. // MethodOverride middleware checks for the overridden method from the request and // uses it instead of the original method. // @@ -32,16 +34,23 @@ func MethodOverride() echo.MiddlewareFunc { return MethodOverrideWithConfig(DefaultMethodOverrideConfig) } -// MethodOverrideWithConfig returns a method override middleware from config. +// MethodOverrideWithConfig returns a MethodOverride middleware from config. // See: `MethodOverride()`. func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc { // Defaults + if config.Skipper == nil { + config.Skipper = DefaultMethodOverrideConfig.Skipper + } if config.Getter == nil { config.Getter = DefaultMethodOverrideConfig.Getter } return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { + if config.Skipper(c) { + return next(c) + } + req := c.Request() if req.Method() == echo.POST { m := config.Getter(c) diff --git a/middleware/middleware.go b/middleware/middleware.go new file mode 100644 index 00000000..bf3e07c4 --- /dev/null +++ b/middleware/middleware.go @@ -0,0 +1,14 @@ +package middleware + +import "github.com/labstack/echo" + +type ( + // Skipper defines a function to skip middleware. Returning true skips processing + // the middleware. + Skipper func(c echo.Context) bool +) + +// defaultSkipper returns false which processes the middleware. +func defaultSkipper(c echo.Context) bool { + return false +} diff --git a/middleware/recover.go b/middleware/recover.go index 4dc2ff10..572d6816 100644 --- a/middleware/recover.go +++ b/middleware/recover.go @@ -9,8 +9,11 @@ import ( ) type ( - // RecoverConfig defines the config for recover middleware. + // RecoverConfig defines the config for Recover middleware. RecoverConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + // Size of the stack to be printed. // Optional. Default value 4KB. StackSize int `json:"stack_size"` @@ -27,8 +30,9 @@ type ( ) var ( - // DefaultRecoverConfig is the default recover middleware config. + // DefaultRecoverConfig is the default Recover middleware config. DefaultRecoverConfig = RecoverConfig{ + Skipper: defaultSkipper, StackSize: 4 << 10, // 4 KB DisableStackAll: false, DisablePrintStack: false, @@ -41,16 +45,23 @@ func Recover() echo.MiddlewareFunc { return RecoverWithConfig(DefaultRecoverConfig) } -// RecoverWithConfig returns a recover middleware from config. +// RecoverWithConfig returns a Recover middleware from config. // See: `Recover()`. func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc { // Defaults + if config.Skipper == nil { + config.Skipper = DefaultRecoverConfig.Skipper + } if config.StackSize == 0 { config.StackSize = DefaultRecoverConfig.StackSize } return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { + if config.Skipper(c) { + return next(c) + } + defer func() { if r := recover(); r != nil { var err error diff --git a/middleware/secure.go b/middleware/secure.go index 84c00960..c381a375 100644 --- a/middleware/secure.go +++ b/middleware/secure.go @@ -7,8 +7,11 @@ import ( ) type ( - // SecureConfig defines the config for secure middleware. + // SecureConfig defines the config for Secure middleware. SecureConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + // XSSProtection provides protection against cross-site scripting attack (XSS) // by setting the `X-XSS-Protection` header. // Optional. Default value "1; mode=block". @@ -54,15 +57,16 @@ type ( ) var ( - // DefaultSecureConfig is the default secure middleware config. + // DefaultSecureConfig is the default Secure middleware config. DefaultSecureConfig = SecureConfig{ + Skipper: defaultSkipper, XSSProtection: "1; mode=block", ContentTypeNosniff: "nosniff", XFrameOptions: "SAMEORIGIN", } ) -// Secure returns a secure middleware. +// Secure returns a Secure middleware. // Secure middleware provides protection against cross-site scripting (XSS) attack, // content type sniffing, clickjacking, insecure connection and other code injection // attacks. @@ -70,11 +74,20 @@ func Secure() echo.MiddlewareFunc { return SecureWithConfig(DefaultSecureConfig) } -// SecureWithConfig returns a secure middleware from config. +// SecureWithConfig returns a Secure middleware from config. // See: `Secure()`. func SecureWithConfig(config SecureConfig) echo.MiddlewareFunc { + // Defaults + if config.Skipper == nil { + config.Skipper = DefaultSecureConfig.Skipper + } + return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { + if config.Skipper(c) { + return next(c) + } + req := c.Request() res := c.Response() diff --git a/middleware/slash.go b/middleware/slash.go index 17a78031..2d4eebd7 100644 --- a/middleware/slash.go +++ b/middleware/slash.go @@ -7,12 +7,22 @@ import ( type ( // TrailingSlashConfig defines the config for TrailingSlash middleware. TrailingSlashConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + // Status code to be used when redirecting the request. // Optional, but when provided the request is redirected using this code. RedirectCode int `json:"redirect_code"` } ) +var ( + // DefaultTrailingSlashConfig is the default TrailingSlash middleware config. + DefaultTrailingSlashConfig = TrailingSlashConfig{ + Skipper: defaultSkipper, + } +) + // AddTrailingSlash returns a root level (before router) middleware which adds a // trailing slash to the request `URL#Path`. // @@ -24,8 +34,17 @@ func AddTrailingSlash() echo.MiddlewareFunc { // AddTrailingSlashWithConfig returns a AddTrailingSlash middleware from config. // See `AddTrailingSlash()`. func AddTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc { + // Defaults + if config.Skipper == nil { + config.Skipper = DefaultTrailingSlashConfig.Skipper + } + return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { + if config.Skipper(c) { + return next(c) + } + req := c.Request() url := req.URL() path := url.Path() @@ -62,8 +81,17 @@ func RemoveTrailingSlash() echo.MiddlewareFunc { // RemoveTrailingSlashWithConfig returns a RemoveTrailingSlash middleware from config. // See `RemoveTrailingSlash()`. func RemoveTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc { + // Defaults + if config.Skipper == nil { + config.Skipper = DefaultTrailingSlashConfig.Skipper + } + return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { + if config.Skipper(c) { + return next(c) + } + req := c.Request() url := req.URL() path := url.Path() diff --git a/middleware/static.go b/middleware/static.go index 8bbba2ed..94314556 100644 --- a/middleware/static.go +++ b/middleware/static.go @@ -10,8 +10,11 @@ import ( ) type ( - // StaticConfig defines the config for static middleware. + // StaticConfig defines the config for Static middleware. StaticConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + // Root directory from where the static content is served. // Required. Root string `json:"root"` @@ -32,13 +35,14 @@ type ( ) var ( - // DefaultStaticConfig is the default static middleware config. + // DefaultStaticConfig is the default Static middleware config. DefaultStaticConfig = StaticConfig{ - Index: "index.html", + Skipper: defaultSkipper, + Index: "index.html", } ) -// Static returns a static middleware to serves static content from the provided +// Static returns a Static middleware to serves static content from the provided // root directory. func Static(root string) echo.MiddlewareFunc { c := DefaultStaticConfig @@ -46,16 +50,23 @@ func Static(root string) echo.MiddlewareFunc { return StaticWithConfig(c) } -// StaticWithConfig returns a static middleware from config. +// StaticWithConfig returns a Static middleware from config. // See `Static()`. func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc { // Defaults + if config.Skipper == nil { + config.Skipper = DefaultStaticConfig.Skipper + } if config.Index == "" { config.Index = DefaultStaticConfig.Index } return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { + if config.Skipper(c) { + return next(c) + } + fs := http.Dir(config.Root) p := c.Request().URL().Path() if strings.Contains(c.Path(), "*") { // If serving from a group, e.g. `/static*`.