mirror of
https://github.com/labstack/echo.git
synced 2025-02-03 13:11:39 +02:00
Merge pull request #1669 from ulasakdeniz/fix-incorrect-cors-headers
Fix empty/incorrect CORS headers
This commit is contained in:
commit
90bef88e1a
@ -102,6 +102,17 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
|
|||||||
origin := req.Header.Get(echo.HeaderOrigin)
|
origin := req.Header.Get(echo.HeaderOrigin)
|
||||||
allowOrigin := ""
|
allowOrigin := ""
|
||||||
|
|
||||||
|
preflight := req.Method == http.MethodOptions
|
||||||
|
res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)
|
||||||
|
|
||||||
|
// No Origin provided
|
||||||
|
if origin == "" {
|
||||||
|
if !preflight {
|
||||||
|
return next(c)
|
||||||
|
}
|
||||||
|
return c.NoContent(http.StatusNoContent)
|
||||||
|
}
|
||||||
|
|
||||||
// Check allowed origins
|
// Check allowed origins
|
||||||
for _, o := range config.AllowOrigins {
|
for _, o := range config.AllowOrigins {
|
||||||
if o == "*" && config.AllowCredentials {
|
if o == "*" && config.AllowCredentials {
|
||||||
@ -138,9 +149,16 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Origin not allowed
|
||||||
|
if allowOrigin == "" {
|
||||||
|
if !preflight {
|
||||||
|
return next(c)
|
||||||
|
}
|
||||||
|
return c.NoContent(http.StatusNoContent)
|
||||||
|
}
|
||||||
|
|
||||||
// Simple request
|
// Simple request
|
||||||
if req.Method != http.MethodOptions {
|
if !preflight {
|
||||||
res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)
|
|
||||||
res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin)
|
res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin)
|
||||||
if config.AllowCredentials {
|
if config.AllowCredentials {
|
||||||
res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true")
|
res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true")
|
||||||
@ -152,7 +170,6 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Preflight request
|
// Preflight request
|
||||||
res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)
|
|
||||||
res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestMethod)
|
res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestMethod)
|
||||||
res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestHeaders)
|
res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestHeaders)
|
||||||
res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin)
|
res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin)
|
||||||
|
@ -17,19 +17,31 @@ func TestCORS(t *testing.T) {
|
|||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
c := e.NewContext(req, rec)
|
c := e.NewContext(req, rec)
|
||||||
h := CORS()(echo.NotFoundHandler)
|
h := CORS()(echo.NotFoundHandler)
|
||||||
|
req.Header.Set(echo.HeaderOrigin, "localhost")
|
||||||
h(c)
|
h(c)
|
||||||
assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||||
|
|
||||||
|
// Wildcard AllowedOrigin with no Origin header in request
|
||||||
|
req = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
rec = httptest.NewRecorder()
|
||||||
|
c = e.NewContext(req, rec)
|
||||||
|
h = CORS()(echo.NotFoundHandler)
|
||||||
|
h(c)
|
||||||
|
assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
|
||||||
|
|
||||||
// Allow origins
|
// Allow origins
|
||||||
req = httptest.NewRequest(http.MethodGet, "/", nil)
|
req = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
rec = httptest.NewRecorder()
|
rec = httptest.NewRecorder()
|
||||||
c = e.NewContext(req, rec)
|
c = e.NewContext(req, rec)
|
||||||
h = CORSWithConfig(CORSConfig{
|
h = CORSWithConfig(CORSConfig{
|
||||||
AllowOrigins: []string{"localhost"},
|
AllowOrigins: []string{"localhost"},
|
||||||
|
AllowCredentials: true,
|
||||||
|
MaxAge: 3600,
|
||||||
})(echo.NotFoundHandler)
|
})(echo.NotFoundHandler)
|
||||||
req.Header.Set(echo.HeaderOrigin, "localhost")
|
req.Header.Set(echo.HeaderOrigin, "localhost")
|
||||||
h(c)
|
h(c)
|
||||||
assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||||
|
assert.Equal(t, "true", rec.Header().Get(echo.HeaderAccessControlAllowCredentials))
|
||||||
|
|
||||||
// Preflight request
|
// Preflight request
|
||||||
req = httptest.NewRequest(http.MethodOptions, "/", nil)
|
req = httptest.NewRequest(http.MethodOptions, "/", nil)
|
||||||
@ -67,6 +79,22 @@ func TestCORS(t *testing.T) {
|
|||||||
assert.Equal(t, "true", rec.Header().Get(echo.HeaderAccessControlAllowCredentials))
|
assert.Equal(t, "true", rec.Header().Get(echo.HeaderAccessControlAllowCredentials))
|
||||||
assert.Equal(t, "3600", rec.Header().Get(echo.HeaderAccessControlMaxAge))
|
assert.Equal(t, "3600", rec.Header().Get(echo.HeaderAccessControlMaxAge))
|
||||||
|
|
||||||
|
// Preflight request with Access-Control-Request-Headers
|
||||||
|
req = httptest.NewRequest(http.MethodOptions, "/", nil)
|
||||||
|
rec = httptest.NewRecorder()
|
||||||
|
c = e.NewContext(req, rec)
|
||||||
|
req.Header.Set(echo.HeaderOrigin, "localhost")
|
||||||
|
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
|
||||||
|
req.Header.Set(echo.HeaderAccessControlRequestHeaders, "Special-Request-Header")
|
||||||
|
cors = CORSWithConfig(CORSConfig{
|
||||||
|
AllowOrigins: []string{"*"},
|
||||||
|
})
|
||||||
|
h = cors(echo.NotFoundHandler)
|
||||||
|
h(c)
|
||||||
|
assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||||
|
assert.Equal(t, "Special-Request-Header", rec.Header().Get(echo.HeaderAccessControlAllowHeaders))
|
||||||
|
assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods))
|
||||||
|
|
||||||
// Preflight request with `AllowOrigins` which allow all subdomains with *
|
// Preflight request with `AllowOrigins` which allow all subdomains with *
|
||||||
req = httptest.NewRequest(http.MethodOptions, "/", nil)
|
req = httptest.NewRequest(http.MethodOptions, "/", nil)
|
||||||
rec = httptest.NewRecorder()
|
rec = httptest.NewRecorder()
|
||||||
@ -126,7 +154,7 @@ func Test_allowOriginScheme(t *testing.T) {
|
|||||||
if tt.expected {
|
if tt.expected {
|
||||||
assert.Equal(t, tt.domain, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
assert.Equal(t, tt.domain, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||||
} else {
|
} else {
|
||||||
assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -217,7 +245,118 @@ func Test_allowOriginSubdomain(t *testing.T) {
|
|||||||
if tt.expected {
|
if tt.expected {
|
||||||
assert.Equal(t, tt.domain, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
assert.Equal(t, tt.domain, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||||
} else {
|
} else {
|
||||||
assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCorsHeaders(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
domain, allowedOrigin, method string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
domain: "", // Request does not have Origin header
|
||||||
|
allowedOrigin: "*",
|
||||||
|
method: http.MethodGet,
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: "http://example.com",
|
||||||
|
allowedOrigin: "*",
|
||||||
|
method: http.MethodGet,
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: "", // Request does not have Origin header
|
||||||
|
allowedOrigin: "http://example.com",
|
||||||
|
method: http.MethodGet,
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: "http://bar.com",
|
||||||
|
allowedOrigin: "http://example.com",
|
||||||
|
method: http.MethodGet,
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: "http://example.com",
|
||||||
|
allowedOrigin: "http://example.com",
|
||||||
|
method: http.MethodGet,
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: "", // Request does not have Origin header
|
||||||
|
allowedOrigin: "*",
|
||||||
|
method: http.MethodOptions,
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: "http://example.com",
|
||||||
|
allowedOrigin: "*",
|
||||||
|
method: http.MethodOptions,
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: "", // Request does not have Origin header
|
||||||
|
allowedOrigin: "http://example.com",
|
||||||
|
method: http.MethodOptions,
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: "http://bar.com",
|
||||||
|
allowedOrigin: "http://example.com",
|
||||||
|
method: http.MethodGet,
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
domain: "http://example.com",
|
||||||
|
allowedOrigin: "http://example.com",
|
||||||
|
method: http.MethodOptions,
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
e := echo.New()
|
||||||
|
for _, tt := range tests {
|
||||||
|
req := httptest.NewRequest(tt.method, "/", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
if tt.domain != "" {
|
||||||
|
req.Header.Set(echo.HeaderOrigin, tt.domain)
|
||||||
|
}
|
||||||
|
cors := CORSWithConfig(CORSConfig{
|
||||||
|
AllowOrigins: []string{tt.allowedOrigin},
|
||||||
|
//AllowCredentials: true,
|
||||||
|
//MaxAge: 3600,
|
||||||
|
})
|
||||||
|
h := cors(echo.NotFoundHandler)
|
||||||
|
h(c)
|
||||||
|
|
||||||
|
assert.Equal(t, echo.HeaderOrigin, rec.Header().Get(echo.HeaderVary))
|
||||||
|
|
||||||
|
expectedAllowOrigin := ""
|
||||||
|
if tt.allowedOrigin == "*" {
|
||||||
|
expectedAllowOrigin = "*"
|
||||||
|
} else {
|
||||||
|
expectedAllowOrigin = tt.domain
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case tt.expected && tt.method == http.MethodOptions:
|
||||||
|
assert.Contains(t, rec.Header(), echo.HeaderAccessControlAllowMethods)
|
||||||
|
assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||||
|
assert.Equal(t, 3, len(rec.Header()[echo.HeaderVary]))
|
||||||
|
case tt.expected && tt.method == http.MethodGet:
|
||||||
|
assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||||
|
assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin
|
||||||
|
default:
|
||||||
|
assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
|
||||||
|
assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.method == http.MethodOptions {
|
||||||
|
assert.Equal(t, http.StatusNoContent, rec.Code)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user