mirror of
https://github.com/labstack/echo.git
synced 2025-11-27 22:38:25 +02:00
Allow header support in Router, MethodNotFoundHandler (405) and CORS middleware
This commit is contained in:
@@ -251,114 +251,238 @@ func Test_allowOriginSubdomain(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCorsHeaders(t *testing.T) {
|
||||
tests := []struct {
|
||||
domain, allowedOrigin, method string
|
||||
expected bool
|
||||
func TestCORSWithConfig_AllowMethods(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
allowOrigins []string
|
||||
allowContextKey string
|
||||
|
||||
whenOrigin string
|
||||
whenAllowMethods []string
|
||||
|
||||
expectAllow string
|
||||
expectAccessControlAllowMethods string
|
||||
}{
|
||||
{
|
||||
domain: "", // Request does not have Origin header
|
||||
allowedOrigin: "*",
|
||||
method: http.MethodGet,
|
||||
expected: false,
|
||||
name: "custom AllowMethods, preflight, no origin, sets only allow header from context key",
|
||||
allowContextKey: "OPTIONS, GET",
|
||||
whenAllowMethods: []string{http.MethodGet, http.MethodHead},
|
||||
whenOrigin: "",
|
||||
expectAllow: "OPTIONS, GET",
|
||||
},
|
||||
{
|
||||
domain: "http://example.com",
|
||||
allowedOrigin: "*",
|
||||
method: http.MethodGet,
|
||||
expected: true,
|
||||
name: "default AllowMethods, preflight, no origin, no allow header in context key and in response",
|
||||
allowContextKey: "",
|
||||
whenAllowMethods: nil,
|
||||
whenOrigin: "",
|
||||
expectAllow: "",
|
||||
},
|
||||
{
|
||||
domain: "", // Request does not have Origin header
|
||||
allowedOrigin: "http://example.com",
|
||||
method: http.MethodGet,
|
||||
expected: false,
|
||||
name: "custom AllowMethods, preflight, existing origin, sets both headers different values",
|
||||
allowContextKey: "OPTIONS, GET",
|
||||
whenAllowMethods: []string{http.MethodGet, http.MethodHead},
|
||||
whenOrigin: "http://google.com",
|
||||
expectAllow: "OPTIONS, GET",
|
||||
expectAccessControlAllowMethods: "GET,HEAD",
|
||||
},
|
||||
{
|
||||
domain: "http://bar.com",
|
||||
allowedOrigin: "http://example.com",
|
||||
method: http.MethodGet,
|
||||
expected: false,
|
||||
name: "default AllowMethods, preflight, existing origin, sets both headers",
|
||||
allowContextKey: "OPTIONS, GET",
|
||||
whenAllowMethods: nil,
|
||||
whenOrigin: "http://google.com",
|
||||
expectAllow: "OPTIONS, GET",
|
||||
expectAccessControlAllowMethods: "OPTIONS, GET",
|
||||
},
|
||||
{
|
||||
domain: "http://example.com",
|
||||
allowedOrigin: "http://example.com",
|
||||
method: http.MethodGet,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
domain: "", // Request does not have Origin header
|
||||
allowedOrigin: "*",
|
||||
method: http.MethodOptions,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
domain: "http://example.com",
|
||||
allowedOrigin: "*",
|
||||
method: http.MethodOptions,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
domain: "", // Request does not have Origin header
|
||||
allowedOrigin: "http://example.com",
|
||||
method: http.MethodOptions,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
domain: "http://bar.com",
|
||||
allowedOrigin: "http://example.com",
|
||||
method: http.MethodGet,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
domain: "http://example.com",
|
||||
allowedOrigin: "http://example.com",
|
||||
method: http.MethodOptions,
|
||||
expected: true,
|
||||
name: "default AllowMethods, preflight, existing origin, no allows, sets only CORS allow methods",
|
||||
allowContextKey: "",
|
||||
whenAllowMethods: nil,
|
||||
whenOrigin: "http://google.com",
|
||||
expectAllow: "",
|
||||
expectAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE",
|
||||
},
|
||||
}
|
||||
|
||||
e := echo.New()
|
||||
for _, tt := range tests {
|
||||
req := httptest.NewRequest(tt.method, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
if tt.domain != "" {
|
||||
req.Header.Set(echo.HeaderOrigin, tt.domain)
|
||||
}
|
||||
cors := CORSWithConfig(CORSConfig{
|
||||
AllowOrigins: []string{tt.allowedOrigin},
|
||||
//AllowCredentials: true,
|
||||
//MaxAge: 3600,
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := echo.New()
|
||||
e.GET("/test", func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "OK")
|
||||
})
|
||||
|
||||
cors := CORSWithConfig(CORSConfig{
|
||||
AllowOrigins: tc.allowOrigins,
|
||||
AllowMethods: tc.whenAllowMethods,
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodOptions, "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
req.Header.Set(echo.HeaderOrigin, tc.whenOrigin)
|
||||
if tc.allowContextKey != "" {
|
||||
c.Set(echo.ContextKeyHeaderAllow, tc.allowContextKey)
|
||||
}
|
||||
|
||||
h := cors(echo.NotFoundHandler)
|
||||
h(c)
|
||||
|
||||
assert.Equal(t, tc.expectAllow, rec.Header().Get(echo.HeaderAllow))
|
||||
assert.Equal(t, tc.expectAccessControlAllowMethods, rec.Header().Get(echo.HeaderAccessControlAllowMethods))
|
||||
})
|
||||
h := cors(echo.NotFoundHandler)
|
||||
h(c)
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equal(t, echo.HeaderOrigin, rec.Header().Get(echo.HeaderVary))
|
||||
func TestCorsHeaders(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
originDomain string
|
||||
method string
|
||||
allowedOrigin string
|
||||
expected bool
|
||||
expectStatus int
|
||||
expectAllowHeader string
|
||||
}{
|
||||
{
|
||||
name: "non-preflight request, allow any origin, missing origin header = no CORS logic done",
|
||||
originDomain: "",
|
||||
allowedOrigin: "*",
|
||||
method: http.MethodGet,
|
||||
expected: false,
|
||||
expectStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "non-preflight request, allow any origin, specific origin domain",
|
||||
originDomain: "http://example.com",
|
||||
allowedOrigin: "*",
|
||||
method: http.MethodGet,
|
||||
expected: true,
|
||||
expectStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "non-preflight request, allow specific origin, missing origin header = no CORS logic done",
|
||||
originDomain: "", // Request does not have Origin header
|
||||
allowedOrigin: "http://example.com",
|
||||
method: http.MethodGet,
|
||||
expected: false,
|
||||
expectStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "non-preflight request, allow specific origin, different origin header = CORS logic failure",
|
||||
originDomain: "http://bar.com",
|
||||
allowedOrigin: "http://example.com",
|
||||
method: http.MethodGet,
|
||||
expected: false,
|
||||
expectStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "non-preflight request, allow specific origin, matching origin header = CORS logic done",
|
||||
originDomain: "http://example.com",
|
||||
allowedOrigin: "http://example.com",
|
||||
method: http.MethodGet,
|
||||
expected: true,
|
||||
expectStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "preflight, allow any origin, missing origin header = no CORS logic done",
|
||||
originDomain: "", // Request does not have Origin header
|
||||
allowedOrigin: "*",
|
||||
method: http.MethodOptions,
|
||||
expected: false,
|
||||
expectStatus: http.StatusNoContent,
|
||||
expectAllowHeader: "OPTIONS, GET, POST",
|
||||
},
|
||||
{
|
||||
name: "preflight, allow any origin, existing origin header = CORS logic done",
|
||||
originDomain: "http://example.com",
|
||||
allowedOrigin: "*",
|
||||
method: http.MethodOptions,
|
||||
expected: true,
|
||||
expectStatus: http.StatusNoContent,
|
||||
expectAllowHeader: "OPTIONS, GET, POST",
|
||||
},
|
||||
{
|
||||
name: "preflight, allow any origin, missing origin header = no CORS logic done",
|
||||
originDomain: "", // Request does not have Origin header
|
||||
allowedOrigin: "http://example.com",
|
||||
method: http.MethodOptions,
|
||||
expected: false,
|
||||
expectStatus: http.StatusNoContent,
|
||||
expectAllowHeader: "OPTIONS, GET, POST",
|
||||
},
|
||||
{
|
||||
name: "preflight, allow specific origin, different origin header = no CORS logic done",
|
||||
originDomain: "http://bar.com",
|
||||
allowedOrigin: "http://example.com",
|
||||
method: http.MethodOptions,
|
||||
expected: false,
|
||||
expectStatus: http.StatusNoContent,
|
||||
expectAllowHeader: "OPTIONS, GET, POST",
|
||||
},
|
||||
{
|
||||
name: "preflight, allow specific origin, matching origin header = CORS logic done",
|
||||
originDomain: "http://example.com",
|
||||
allowedOrigin: "http://example.com",
|
||||
method: http.MethodOptions,
|
||||
expected: true,
|
||||
expectStatus: http.StatusNoContent,
|
||||
expectAllowHeader: "OPTIONS, GET, POST",
|
||||
},
|
||||
}
|
||||
|
||||
expectedAllowOrigin := ""
|
||||
if tt.allowedOrigin == "*" {
|
||||
expectedAllowOrigin = "*"
|
||||
} else {
|
||||
expectedAllowOrigin = tt.domain
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := echo.New()
|
||||
|
||||
switch {
|
||||
case tt.expected && tt.method == http.MethodOptions:
|
||||
assert.Contains(t, rec.Header(), echo.HeaderAccessControlAllowMethods)
|
||||
assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||
assert.Equal(t, 3, len(rec.Header()[echo.HeaderVary]))
|
||||
case tt.expected && tt.method == http.MethodGet:
|
||||
assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||
assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin
|
||||
default:
|
||||
assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
|
||||
assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin
|
||||
}
|
||||
e.Use(CORSWithConfig(CORSConfig{
|
||||
AllowOrigins: []string{tc.allowedOrigin},
|
||||
//AllowCredentials: true,
|
||||
//MaxAge: 3600,
|
||||
}))
|
||||
|
||||
e.GET("/", func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "OK")
|
||||
})
|
||||
e.POST("/", func(c echo.Context) error {
|
||||
return c.String(http.StatusCreated, "OK")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(tc.method, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
if tc.originDomain != "" {
|
||||
req.Header.Set(echo.HeaderOrigin, tc.originDomain)
|
||||
}
|
||||
|
||||
// we run through whole Echo handler chain to see how CORS works with Router OPTIONS handler
|
||||
e.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, echo.HeaderOrigin, rec.Header().Get(echo.HeaderVary))
|
||||
assert.Equal(t, tc.expectAllowHeader, rec.Header().Get(echo.HeaderAllow))
|
||||
assert.Equal(t, tc.expectStatus, rec.Code)
|
||||
|
||||
expectedAllowOrigin := ""
|
||||
if tc.allowedOrigin == "*" {
|
||||
expectedAllowOrigin = "*"
|
||||
} else {
|
||||
expectedAllowOrigin = tc.originDomain
|
||||
}
|
||||
switch {
|
||||
case tc.expected && tc.method == http.MethodOptions:
|
||||
assert.Contains(t, rec.Header(), echo.HeaderAccessControlAllowMethods)
|
||||
assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||
|
||||
assert.Equal(t, 3, len(rec.Header()[echo.HeaderVary]))
|
||||
|
||||
case tc.expected && tc.method == http.MethodGet:
|
||||
assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||
assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin
|
||||
default:
|
||||
assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
|
||||
assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin
|
||||
}
|
||||
})
|
||||
|
||||
if tt.method == http.MethodOptions {
|
||||
assert.Equal(t, http.StatusNoContent, rec.Code)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user