1
0
mirror of https://github.com/labstack/echo.git synced 2025-07-01 00:55:04 +02:00
Signed-off-by: Vishal Rana <vr@labstack.com>
This commit is contained in:
Vishal Rana
2016-11-12 20:24:53 -08:00
parent 74ccda6546
commit 2f70d3e1c7
3 changed files with 56 additions and 47 deletions

View File

@ -15,7 +15,8 @@ type (
Skipper Skipper Skipper Skipper
// AllowOrigin defines a list of origins that may access the resource. // AllowOrigin defines a list of origins that may access the resource.
// Optional. Default value []string{"*"}. // Optional. If request header `Origin` is set, value is []string{"<Origin>"}
// else []string{"*"}.
AllowOrigins []string `json:"allow_origins"` AllowOrigins []string `json:"allow_origins"`
// AllowMethods defines a list methods allowed when accessing the resource. // AllowMethods defines a list methods allowed when accessing the resource.
@ -51,7 +52,6 @@ var (
// DefaultCORSConfig is the default CORS middleware config. // DefaultCORSConfig is the default CORS middleware config.
DefaultCORSConfig = CORSConfig{ DefaultCORSConfig = CORSConfig{
Skipper: defaultSkipper, Skipper: defaultSkipper,
AllowOrigins: []string{"*"},
AllowMethods: []string{echo.GET, echo.HEAD, echo.PUT, echo.PATCH, echo.POST, echo.DELETE}, AllowMethods: []string{echo.GET, echo.HEAD, echo.PUT, echo.PATCH, echo.POST, echo.DELETE},
} }
) )
@ -69,12 +69,10 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
if config.Skipper == nil { if config.Skipper == nil {
config.Skipper = DefaultCORSConfig.Skipper config.Skipper = DefaultCORSConfig.Skipper
} }
if len(config.AllowOrigins) == 0 {
config.AllowOrigins = DefaultCORSConfig.AllowOrigins
}
if len(config.AllowMethods) == 0 { if len(config.AllowMethods) == 0 {
config.AllowMethods = DefaultCORSConfig.AllowMethods config.AllowMethods = DefaultCORSConfig.AllowMethods
} }
allowedOrigins := strings.Join(config.AllowOrigins, ",") allowedOrigins := strings.Join(config.AllowOrigins, ",")
allowMethods := strings.Join(config.AllowMethods, ",") allowMethods := strings.Join(config.AllowMethods, ",")
allowHeaders := strings.Join(config.AllowHeaders, ",") allowHeaders := strings.Join(config.AllowHeaders, ",")
@ -89,6 +87,17 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
req := c.Request() req := c.Request()
res := c.Response() res := c.Response()
origin := req.Header.Get(echo.HeaderOrigin)
if allowedOrigins == "" {
if origin != "" {
allowedOrigins = origin
} else {
if !config.AllowCredentials {
allowedOrigins = "*"
}
}
}
// Simple request // Simple request
if req.Method != echo.OPTIONS { if req.Method != echo.OPTIONS {

View File

@ -11,21 +11,21 @@ import (
func TestCORS(t *testing.T) { func TestCORS(t *testing.T) {
e := echo.New() e := echo.New()
// Origin origin
req, _ := http.NewRequest(echo.GET, "/", nil) req, _ := http.NewRequest(echo.GET, "/", nil)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
cors := CORSWithConfig(CORSConfig{ h := CORS()(echo.NotFoundHandler)
AllowCredentials: true, req.Header.Set(echo.HeaderOrigin, "localhost")
}) h(c)
h := cors(func(c echo.Context) error { assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
return c.String(http.StatusOK, "test")
})
// Wildcard origin // Wildcard origin
req, _ = http.NewRequest(echo.GET, "/", nil) req, _ = http.NewRequest(echo.GET, "/", nil)
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec) c = e.NewContext(req, rec)
req.Header.Set(echo.HeaderOrigin, "localhost") h = CORS()(echo.NotFoundHandler)
h(c) h(c)
assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
@ -34,14 +34,7 @@ func TestCORS(t *testing.T) {
rec = httptest.NewRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec) c = e.NewContext(req, rec)
req.Header.Set(echo.HeaderOrigin, "localhost") req.Header.Set(echo.HeaderOrigin, "localhost")
cors = CORSWithConfig(CORSConfig{ h = CORS()(echo.NotFoundHandler)
AllowOrigins: []string{"localhost"},
AllowCredentials: true,
MaxAge: 3600,
})
h = cors(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})
h(c) h(c)
assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
@ -51,6 +44,12 @@ func TestCORS(t *testing.T) {
c = e.NewContext(req, rec) c = e.NewContext(req, rec)
req.Header.Set(echo.HeaderOrigin, "localhost") req.Header.Set(echo.HeaderOrigin, "localhost")
req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
cors := CORSWithConfig(CORSConfig{
AllowOrigins: []string{"localhost"},
AllowCredentials: true,
MaxAge: 3600,
})
h = cors(echo.NotFoundHandler)
h(c) h(c)
assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods)) assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods))

View File

@ -37,7 +37,8 @@ CORSConfig struct {
Skipper Skipper Skipper Skipper
// AllowOrigin defines a list of origins that may access the resource. // AllowOrigin defines a list of origins that may access the resource.
// Optional. Default value []string{"*"}. // Optional. If request header `Origin` is set, value is []string{"<Origin>"}
// else []string{"*"}.
AllowOrigins []string `json:"allow_origins"` AllowOrigins []string `json:"allow_origins"`
// AllowMethods defines a list methods allowed when accessing the resource. // AllowMethods defines a list methods allowed when accessing the resource.