diff --git a/middleware/cors.go b/middleware/cors.go index 07df0e57..d6ef8964 100644 --- a/middleware/cors.go +++ b/middleware/cors.go @@ -19,6 +19,13 @@ type ( // Optional. Default value []string{"*"}. AllowOrigins []string `yaml:"allow_origins"` + // AllowOriginFunc is a custom function to validate the origin. It takes the + // origin as an argument and returns true if allowed or false otherwise. If + // an error is returned, it is returned by the handler. If this option is + // set, AllowOrigins is ignored. + // Optional. + AllowOriginFunc func(origin string) (bool, error) `yaml:"allow_origin_func"` + // AllowMethods defines a list methods allowed when accessing the resource. // This is used in response to a preflight request. // Optional. Default value DefaultCORSConfig.AllowMethods. @@ -113,39 +120,49 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { return c.NoContent(http.StatusNoContent) } - // Check allowed origins - for _, o := range config.AllowOrigins { - if o == "*" && config.AllowCredentials { + if config.AllowOriginFunc != nil { + allowed, err := config.AllowOriginFunc(origin) + if err != nil { + return err + } + if allowed { allowOrigin = origin - break } - if o == "*" || o == origin { - allowOrigin = o - break - } - if matchSubdomain(origin, o) { - allowOrigin = origin - break - } - } - - // Check allowed origin patterns - for _, re := range allowOriginPatterns { - if allowOrigin == "" { - didx := strings.Index(origin, "://") - if didx == -1 { - continue - } - domAuth := origin[didx+3:] - // to avoid regex cost by invalid long domain - if len(domAuth) > 253 { - break - } - - if match, _ := regexp.MatchString(re, origin); match { + } else { + // Check allowed origins + for _, o := range config.AllowOrigins { + if o == "*" && config.AllowCredentials { allowOrigin = origin break } + if o == "*" || o == origin { + allowOrigin = o + break + } + if matchSubdomain(origin, o) { + allowOrigin = origin + break + } + } + + // Check allowed origin patterns + for _, re := range allowOriginPatterns { + if allowOrigin == "" { + didx := strings.Index(origin, "://") + if didx == -1 { + continue + } + domAuth := origin[didx+3:] + // to avoid regex cost by invalid long domain + if len(domAuth) > 253 { + break + } + + if match, _ := regexp.MatchString(re, origin); match { + allowOrigin = origin + break + } + } } } diff --git a/middleware/cors_test.go b/middleware/cors_test.go index fc34694d..717abe49 100644 --- a/middleware/cors_test.go +++ b/middleware/cors_test.go @@ -1,6 +1,7 @@ package middleware import ( + "errors" "net/http" "net/http/httptest" "testing" @@ -360,3 +361,49 @@ func TestCorsHeaders(t *testing.T) { } } } + +func Test_allowOriginFunc(t *testing.T) { + returnTrue := func(origin string) (bool, error) { + return true, nil + } + returnFalse := func(origin string) (bool, error) { + return false, nil + } + returnError := func(origin string) (bool, error) { + return true, errors.New("this is a test error") + } + + allowOriginFuncs := []func(origin string) (bool, error){ + returnTrue, + returnFalse, + returnError, + } + + const origin = "http://example.com" + + e := echo.New() + for _, allowOriginFunc := range allowOriginFuncs { + req := httptest.NewRequest(http.MethodOptions, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + req.Header.Set(echo.HeaderOrigin, origin) + cors := CORSWithConfig(CORSConfig{ + AllowOriginFunc: allowOriginFunc, + }) + h := cors(echo.NotFoundHandler) + err := h(c) + + expected, expectedErr := allowOriginFunc(origin) + if expectedErr != nil { + assert.Equal(t, expectedErr, err) + assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + continue + } + + if expected { + assert.Equal(t, origin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + } else { + assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + } + } +}