1
0
mirror of https://github.com/labstack/echo.git synced 2024-12-22 20:06:21 +02:00

cors: not checking for origin header

Signed-off-by: Vishal Rana <vr@labstack.com>
This commit is contained in:
Vishal Rana 2016-11-12 14:29:11 -08:00
parent e08070379a
commit 4c78b7122b
2 changed files with 3 additions and 31 deletions

View File

@ -75,6 +75,7 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
if len(config.AllowMethods) == 0 { if len(config.AllowMethods) == 0 {
config.AllowMethods = DefaultCORSConfig.AllowMethods config.AllowMethods = DefaultCORSConfig.AllowMethods
} }
allowOrigins := strings.Join(config.AllowOrigins, ",")
allowMethods := strings.Join(config.AllowMethods, ",") allowMethods := strings.Join(config.AllowMethods, ",")
allowHeaders := strings.Join(config.AllowHeaders, ",") allowHeaders := strings.Join(config.AllowHeaders, ",")
exposeHeaders := strings.Join(config.ExposeHeaders, ",") exposeHeaders := strings.Join(config.ExposeHeaders, ",")
@ -88,25 +89,11 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
req := c.Request() req := c.Request()
res := c.Response() res := c.Response()
origin := req.Header().Get(echo.HeaderOrigin)
originSet := req.Header().Contains(echo.HeaderOrigin) // Issue #517
// Check allowed origins
allowedOrigin := ""
for _, o := range config.AllowOrigins {
if o == "*" || o == origin {
allowedOrigin = o
break
}
}
// Simple request // Simple request
if req.Method() != echo.OPTIONS { if req.Method() != echo.OPTIONS {
res.Header().Add(echo.HeaderVary, echo.HeaderOrigin) res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)
if !originSet || allowedOrigin == "" { res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigins)
return next(c)
}
res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowedOrigin)
if config.AllowCredentials { if config.AllowCredentials {
res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true") res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true")
} }
@ -120,10 +107,7 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
res.Header().Add(echo.HeaderVary, echo.HeaderOrigin) res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)
res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestMethod) res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestMethod)
res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestHeaders) res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestHeaders)
if !originSet || allowedOrigin == "" { res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigins)
return next(c)
}
res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowedOrigin)
res.Header().Set(echo.HeaderAccessControlAllowMethods, allowMethods) res.Header().Set(echo.HeaderAccessControlAllowMethods, allowMethods)
if config.AllowCredentials { if config.AllowCredentials {
res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true") res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true")

View File

@ -21,18 +21,6 @@ func TestCORS(t *testing.T) {
return c.String(http.StatusOK, "test") return c.String(http.StatusOK, "test")
}) })
// No origin header
h(c)
assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
// Empty origin header
req = test.NewRequest(echo.GET, "/", nil)
rec = test.NewResponseRecorder()
c = e.NewContext(req, rec)
req.Header().Set(echo.HeaderOrigin, "")
h(c)
assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
// Wildcard origin // Wildcard origin
req = test.NewRequest(echo.GET, "/", nil) req = test.NewRequest(echo.GET, "/", nil)
rec = test.NewResponseRecorder() rec = test.NewResponseRecorder()