From a3352d880c61bdd1310ef7eefb606204dbd90e75 Mon Sep 17 00:00:00 2001 From: Vishal Rana Date: Thu, 7 Apr 2016 16:16:58 -0700 Subject: [PATCH] Fixed #454, Fixed #274 Signed-off-by: Vishal Rana --- echo.go | 37 +++++++---- middleware/cors.go | 133 ++++++++++++++++++++++++++++++++++++++ middleware/cors_test.go | 63 ++++++++++++++++++ middleware/static_test.go | 1 + 4 files changed, 220 insertions(+), 14 deletions(-) create mode 100644 middleware/cors.go create mode 100644 middleware/cors_test.go create mode 100644 middleware/static_test.go diff --git a/echo.go b/echo.go index f2621b2f..3be0edb2 100644 --- a/echo.go +++ b/echo.go @@ -152,20 +152,29 @@ const ( // Headers const ( - HeaderAcceptEncoding = "Accept-Encoding" - HeaderAuthorization = "Authorization" - HeaderContentDisposition = "Content-Disposition" - HeaderContentEncoding = "Content-Encoding" - HeaderContentLength = "Content-Length" - HeaderContentType = "Content-Type" - HeaderIfModifiedSince = "If-Modified-Since" - HeaderLastModified = "Last-Modified" - HeaderLocation = "Location" - HeaderUpgrade = "Upgrade" - HeaderVary = "Vary" - HeaderWWWAuthenticate = "WWW-Authenticate" - HeaderXForwardedFor = "X-Forwarded-For" - HeaderXRealIP = "X-Real-IP" + HeaderAcceptEncoding = "Accept-Encoding" + HeaderAuthorization = "Authorization" + HeaderContentDisposition = "Content-Disposition" + HeaderContentEncoding = "Content-Encoding" + HeaderContentLength = "Content-Length" + HeaderContentType = "Content-Type" + HeaderIfModifiedSince = "If-Modified-Since" + HeaderLastModified = "Last-Modified" + HeaderLocation = "Location" + HeaderUpgrade = "Upgrade" + HeaderVary = "Vary" + HeaderWWWAuthenticate = "WWW-Authenticate" + HeaderXForwardedFor = "X-Forwarded-For" + HeaderXRealIP = "X-Real-IP" + HeaderOrigin = "Origin" + HeaderAccessControlRequestMethod = "Access-Control-Request-Method" + HeaderAccessControlRequestHeaders = "Access-Control-Request-Headers" + HeaderAccessControlAllowOrigin = "Access-Control-Allow-Origin" + HeaderAccessControlAllowMethods = "Access-Control-Allow-Methods" + HeaderAccessControlAllowHeaders = "Access-Control-Allow-Headers" + HeaderAccessControlAllowCredentials = "Access-Control-Allow-Credentials" + HeaderAccessControlExposeHeaders = "Access-Control-Expose-Headers" + HeaderAccessControlMaxAge = "Access-Control-Max-Age" ) var ( diff --git a/middleware/cors.go b/middleware/cors.go new file mode 100644 index 00000000..466c817e --- /dev/null +++ b/middleware/cors.go @@ -0,0 +1,133 @@ +package middleware + +import ( + "net/http" + "strconv" + "strings" + + "github.com/labstack/echo" +) + +type ( + // CORSConfig defines the config for CORS middleware. + CORSConfig struct { + // AllowOrigin defines a list of origins that may access the resource. + // Optional with default value as []string{"*"}. + AllowOrigins []string + + // AllowMethods defines a list methods allowed when accessing the resource. + // This is used in response to a preflight request. + // Optional with default value as `DefaultCORSConfig.AllowMethods`. + AllowMethods []string + + // AllowHeaders defines a list of request headers that can be used when + // making the actual request. This in response to a preflight request. + // Optional with default value as []string{}. + AllowHeaders []string + + // AllowCredentials indicates whether or not the response to the request + // can be exposed when the credentials flag is true. When used as part of + // a response to a preflight request, this indicates whether or not the + // actual request can be made using credentials. + // Optional with default value as false. + AllowCredentials bool + + // ExposeHeaders defines a whitelist headers that clients are allowed to + // access. + // Optional with default value as []string{}. + ExposeHeaders []string + + // MaxAge indicates how long (in seconds) the results of a preflight request + // can be cached. + // Optional with default value as 0. + MaxAge int + } +) + +var ( + // DefaultCORSConfig is the default CORS middleware config. + DefaultCORSConfig = CORSConfig{ + AllowOrigins: []string{"*"}, + AllowMethods: []string{echo.GET, echo.HEAD, echo.PUT, echo.POST, echo.DELETE}, + } +) + +// CORS returns a cross-origin HTTP request (CORS) middleware. +// See https://developer.mozilla.org/en/docs/Web/HTTP/Access_control_CORS +func CORS() echo.MiddlewareFunc { + return CORSFromConfig(DefaultCORSConfig) +} + +// CORSFromConfig returns a CORS middleware from config. +// See `CORS()`. +func CORSFromConfig(config CORSConfig) echo.MiddlewareFunc { + // Defaults + if len(config.AllowOrigins) == 0 { + config.AllowOrigins = DefaultCORSConfig.AllowOrigins + } + if len(config.AllowMethods) == 0 { + config.AllowMethods = DefaultCORSConfig.AllowMethods + } + allowMethods := strings.Join(config.AllowMethods, ",") + allowHeaders := strings.Join(config.AllowHeaders, ",") + exposeHeaders := strings.Join(config.ExposeHeaders, ",") + maxAge := strconv.Itoa(config.MaxAge) + + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + rq := c.Request() + origin := c.Request().Header().Get(echo.HeaderOrigin) + header := c.Response().Header() + + // Check allowed origins + allowedOrigin := "" + for _, o := range config.AllowOrigins { + if o == "*" || o == origin { + allowedOrigin = o + break + } + } + + // Simple request + if rq.Method() != echo.OPTIONS { + header.Add(echo.HeaderVary, echo.HeaderOrigin) + if origin == "" || allowedOrigin == "" { + return next(c) + } + header.Set(echo.HeaderAccessControlAllowOrigin, allowedOrigin) + if config.AllowCredentials { + header.Set(echo.HeaderAccessControlAllowCredentials, "true") + } + if exposeHeaders != "" { + header.Set(echo.HeaderAccessControlExposeHeaders, exposeHeaders) + } + return next(c) + } + + // Preflight request + header.Add(echo.HeaderVary, echo.HeaderOrigin) + header.Add(echo.HeaderVary, echo.HeaderAccessControlRequestMethod) + header.Add(echo.HeaderVary, echo.HeaderAccessControlRequestHeaders) + if origin == "" || allowedOrigin == "" { + return next(c) + } + header.Set(echo.HeaderAccessControlAllowOrigin, allowedOrigin) + header.Set(echo.HeaderAccessControlAllowMethods, allowMethods) + if config.AllowCredentials { + header.Set(echo.HeaderAccessControlAllowCredentials, "true") + } + if allowHeaders != "" { + header.Set(echo.HeaderAccessControlAllowHeaders, allowHeaders) + } else { + h := rq.Header().Get(echo.HeaderAccessControlRequestHeaders) + if h != "" { + header.Set(echo.HeaderAccessControlAllowHeaders, h) + } + } + if config.MaxAge > 0 { + header.Set(echo.HeaderAccessControlMaxAge, maxAge) + } + return c.NoContent(http.StatusNoContent) + } + } +} diff --git a/middleware/cors_test.go b/middleware/cors_test.go new file mode 100644 index 00000000..7b2d24e8 --- /dev/null +++ b/middleware/cors_test.go @@ -0,0 +1,63 @@ +package middleware + +import ( + "net/http" + "testing" + + "github.com/labstack/echo" + "github.com/labstack/echo/test" + "github.com/stretchr/testify/assert" +) + +func TestCORS(t *testing.T) { + e := echo.New() + rq := test.NewRequest(echo.GET, "/", nil) + rc := test.NewResponseRecorder() + c := echo.NewContext(rq, rc, e) + cors := CORSFromConfig(CORSConfig{ + AllowCredentials: true, + }) + h := cors(func(c echo.Context) error { + return c.String(http.StatusOK, "test") + }) + + // No origin header + h(c) + assert.Equal(t, "", rc.Header().Get(echo.HeaderAccessControlAllowOrigin)) + + // Wildcard origin + rq = test.NewRequest(echo.GET, "/", nil) + rc = test.NewResponseRecorder() + c = echo.NewContext(rq, rc, e) + rq.Header().Set(echo.HeaderOrigin, "localhost") + h(c) + assert.Equal(t, "*", rc.Header().Get(echo.HeaderAccessControlAllowOrigin)) + + // Simple request + rq = test.NewRequest(echo.GET, "/", nil) + rc = test.NewResponseRecorder() + c = echo.NewContext(rq, rc, e) + rq.Header().Set(echo.HeaderOrigin, "localhost") + cors = CORSFromConfig(CORSConfig{ + AllowOrigins: []string{"localhost"}, + AllowCredentials: true, + MaxAge: 3600, + }) + h = cors(func(c echo.Context) error { + return c.String(http.StatusOK, "test") + }) + h(c) + assert.Equal(t, "localhost", rc.Header().Get(echo.HeaderAccessControlAllowOrigin)) + + // Preflight request + rq = test.NewRequest(echo.OPTIONS, "/", nil) + rc = test.NewResponseRecorder() + c = echo.NewContext(rq, rc, e) + rq.Header().Set(echo.HeaderOrigin, "localhost") + rq.Header().Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + h(c) + assert.Equal(t, "localhost", rc.Header().Get(echo.HeaderAccessControlAllowOrigin)) + assert.NotEmpty(t, rc.Header().Get(echo.HeaderAccessControlAllowMethods)) + assert.Equal(t, "true", rc.Header().Get(echo.HeaderAccessControlAllowCredentials)) + assert.Equal(t, "3600", rc.Header().Get(echo.HeaderAccessControlMaxAge)) +} diff --git a/middleware/static_test.go b/middleware/static_test.go new file mode 100644 index 00000000..c870d7c1 --- /dev/null +++ b/middleware/static_test.go @@ -0,0 +1 @@ +package middleware