mirror of
https://github.com/labstack/echo.git
synced 2024-12-24 20:14:31 +02:00
Merge pull request #1651 from curvegrid/cors-allow-origin-func
CORS: add an optional custom function to validate the origin
This commit is contained in:
commit
502cce28d5
@ -19,6 +19,13 @@ type (
|
||||
// Optional. Default value []string{"*"}.
|
||||
AllowOrigins []string `yaml:"allow_origins"`
|
||||
|
||||
// AllowOriginFunc is a custom function to validate the origin. It takes the
|
||||
// origin as an argument and returns true if allowed or false otherwise. If
|
||||
// an error is returned, it is returned by the handler. If this option is
|
||||
// set, AllowOrigins is ignored.
|
||||
// Optional.
|
||||
AllowOriginFunc func(origin string) (bool, error) `yaml:"allow_origin_func"`
|
||||
|
||||
// AllowMethods defines a list methods allowed when accessing the resource.
|
||||
// This is used in response to a preflight request.
|
||||
// Optional. Default value DefaultCORSConfig.AllowMethods.
|
||||
@ -113,39 +120,49 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
|
||||
return c.NoContent(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// Check allowed origins
|
||||
for _, o := range config.AllowOrigins {
|
||||
if o == "*" && config.AllowCredentials {
|
||||
if config.AllowOriginFunc != nil {
|
||||
allowed, err := config.AllowOriginFunc(origin)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if allowed {
|
||||
allowOrigin = origin
|
||||
break
|
||||
}
|
||||
if o == "*" || o == origin {
|
||||
allowOrigin = o
|
||||
break
|
||||
}
|
||||
if matchSubdomain(origin, o) {
|
||||
allowOrigin = origin
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Check allowed origin patterns
|
||||
for _, re := range allowOriginPatterns {
|
||||
if allowOrigin == "" {
|
||||
didx := strings.Index(origin, "://")
|
||||
if didx == -1 {
|
||||
continue
|
||||
}
|
||||
domAuth := origin[didx+3:]
|
||||
// to avoid regex cost by invalid long domain
|
||||
if len(domAuth) > 253 {
|
||||
break
|
||||
}
|
||||
|
||||
if match, _ := regexp.MatchString(re, origin); match {
|
||||
} else {
|
||||
// Check allowed origins
|
||||
for _, o := range config.AllowOrigins {
|
||||
if o == "*" && config.AllowCredentials {
|
||||
allowOrigin = origin
|
||||
break
|
||||
}
|
||||
if o == "*" || o == origin {
|
||||
allowOrigin = o
|
||||
break
|
||||
}
|
||||
if matchSubdomain(origin, o) {
|
||||
allowOrigin = origin
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Check allowed origin patterns
|
||||
for _, re := range allowOriginPatterns {
|
||||
if allowOrigin == "" {
|
||||
didx := strings.Index(origin, "://")
|
||||
if didx == -1 {
|
||||
continue
|
||||
}
|
||||
domAuth := origin[didx+3:]
|
||||
// to avoid regex cost by invalid long domain
|
||||
if len(domAuth) > 253 {
|
||||
break
|
||||
}
|
||||
|
||||
if match, _ := regexp.MatchString(re, origin); match {
|
||||
allowOrigin = origin
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
@ -360,3 +361,49 @@ func TestCorsHeaders(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_allowOriginFunc(t *testing.T) {
|
||||
returnTrue := func(origin string) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
returnFalse := func(origin string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
returnError := func(origin string) (bool, error) {
|
||||
return true, errors.New("this is a test error")
|
||||
}
|
||||
|
||||
allowOriginFuncs := []func(origin string) (bool, error){
|
||||
returnTrue,
|
||||
returnFalse,
|
||||
returnError,
|
||||
}
|
||||
|
||||
const origin = "http://example.com"
|
||||
|
||||
e := echo.New()
|
||||
for _, allowOriginFunc := range allowOriginFuncs {
|
||||
req := httptest.NewRequest(http.MethodOptions, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
req.Header.Set(echo.HeaderOrigin, origin)
|
||||
cors := CORSWithConfig(CORSConfig{
|
||||
AllowOriginFunc: allowOriginFunc,
|
||||
})
|
||||
h := cors(echo.NotFoundHandler)
|
||||
err := h(c)
|
||||
|
||||
expected, expectedErr := allowOriginFunc(origin)
|
||||
if expectedErr != nil {
|
||||
assert.Equal(t, expectedErr, err)
|
||||
assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||
continue
|
||||
}
|
||||
|
||||
if expected {
|
||||
assert.Equal(t, origin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||
} else {
|
||||
assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user