From 6424d779dc8419de44baa27c8a11d8acce979614 Mon Sep 17 00:00:00 2001 From: Vishal Rana Date: Tue, 3 May 2016 08:32:28 -0700 Subject: [PATCH] Added test for secure middleware Signed-off-by: Vishal Rana --- echo.go | 1 + middleware/cors.go | 30 +++++------ middleware/method_override.go | 7 ++- middleware/method_override_test.go | 2 +- middleware/secure.go | 80 ++++++++++++++++++++++++------ middleware/secure_test.go | 73 ++++++++++++++++----------- 6 files changed, 130 insertions(+), 63 deletions(-) diff --git a/echo.go b/echo.go index dcc7b12c..ec002e45 100644 --- a/echo.go +++ b/echo.go @@ -153,6 +153,7 @@ const ( HeaderUpgrade = "Upgrade" HeaderVary = "Vary" HeaderWWWAuthenticate = "WWW-Authenticate" + HeaderXForwardedProto = "X-Forwarded-Proto" HeaderXHTTPMethodOverride = "X-HTTP-Method-Override" HeaderXForwardedFor = "X-Forwarded-For" HeaderXRealIP = "X-Real-IP" diff --git a/middleware/cors.go b/middleware/cors.go index 54025a10..d6724062 100644 --- a/middleware/cors.go +++ b/middleware/cors.go @@ -76,8 +76,8 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { req := c.Request() - origin := c.Request().Header().Get(echo.HeaderOrigin) - header := c.Response().Header() + res := c.Response() + origin := req.Header().Get(echo.HeaderOrigin) // Check allowed origins allowedOrigin := "" @@ -90,42 +90,42 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { // Simple request if req.Method() != echo.OPTIONS { - header.Add(echo.HeaderVary, echo.HeaderOrigin) + res.Header().Add(echo.HeaderVary, echo.HeaderOrigin) if origin == "" || allowedOrigin == "" { return next(c) } - header.Set(echo.HeaderAccessControlAllowOrigin, allowedOrigin) + res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowedOrigin) if config.AllowCredentials { - header.Set(echo.HeaderAccessControlAllowCredentials, "true") + res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true") } if exposeHeaders != "" { - header.Set(echo.HeaderAccessControlExposeHeaders, exposeHeaders) + res.Header().Set(echo.HeaderAccessControlExposeHeaders, exposeHeaders) } return next(c) } // Preflight request - header.Add(echo.HeaderVary, echo.HeaderOrigin) - header.Add(echo.HeaderVary, echo.HeaderAccessControlRequestMethod) - header.Add(echo.HeaderVary, echo.HeaderAccessControlRequestHeaders) + res.Header().Add(echo.HeaderVary, echo.HeaderOrigin) + res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestMethod) + res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestHeaders) if origin == "" || allowedOrigin == "" { return next(c) } - header.Set(echo.HeaderAccessControlAllowOrigin, allowedOrigin) - header.Set(echo.HeaderAccessControlAllowMethods, allowMethods) + res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowedOrigin) + res.Header().Set(echo.HeaderAccessControlAllowMethods, allowMethods) if config.AllowCredentials { - header.Set(echo.HeaderAccessControlAllowCredentials, "true") + res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true") } if allowHeaders != "" { - header.Set(echo.HeaderAccessControlAllowHeaders, allowHeaders) + res.Header().Set(echo.HeaderAccessControlAllowHeaders, allowHeaders) } else { h := req.Header().Get(echo.HeaderAccessControlRequestHeaders) if h != "" { - header.Set(echo.HeaderAccessControlAllowHeaders, h) + res.Header().Set(echo.HeaderAccessControlAllowHeaders, h) } } if config.MaxAge > 0 { - header.Set(echo.HeaderAccessControlMaxAge, maxAge) + res.Header().Set(echo.HeaderAccessControlMaxAge, maxAge) } return c.NoContent(http.StatusNoContent) } diff --git a/middleware/method_override.go b/middleware/method_override.go index 27f6b996..4512b899 100644 --- a/middleware/method_override.go +++ b/middleware/method_override.go @@ -35,13 +35,18 @@ func MethodOverride() echo.MiddlewareFunc { // MethodOverrideWithConfig returns a method override middleware from config. // See `MethodOverride()`. func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc { + // Defaults + if config.Getter == nil { + config.Getter = DefaultMethodOverrideConfig.Getter + } + return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { req := c.Request() if req.Method() == echo.POST { m := config.Getter(c) if m != "" { - c.Request().SetMethod(m) + req.SetMethod(m) } } return next(c) diff --git a/middleware/method_override_test.go b/middleware/method_override_test.go index 0fccd4eb..964ed1a7 100644 --- a/middleware/method_override_test.go +++ b/middleware/method_override_test.go @@ -14,7 +14,7 @@ func TestMethodOverride(t *testing.T) { e := echo.New() m := MethodOverride() h := func(c echo.Context) error { - return c.String(http.StatusOK, "Okay") + return c.String(http.StatusOK, "test") } // Override with http header diff --git a/middleware/secure.go b/middleware/secure.go index c3e9e84f..eed75d8e 100644 --- a/middleware/secure.go +++ b/middleware/secure.go @@ -7,47 +7,95 @@ import ( ) type ( + // SecureConfig defines the config for secure middleware. SecureConfig struct { - DisableXSSProtection bool - DisableContentTypeNosniff bool - XFrameOptions string - DisableHSTSIncludeSubdomains bool - HSTSMaxAge int - ContentSecurityPolicy string + // XSSProtection provides protection against cross-site scripting attack (XSS) + // by setting the `X-XSS-Protection` header. + // Optional, with default value as `1; mode=block`. + XSSProtection string + + // ContentTypeNosniff provides protection against overriding Content-Type + // header by setting the `X-Content-Type-Options` header. + // Optional, with default value as "nosniff". + ContentTypeNosniff string + + // XFrameOptions can be used to indicate whether or not a browser should + // be allowed to render a page in a ,