mirror of
https://github.com/labstack/echo.git
synced 2024-11-28 08:38:39 +02:00
CORS: add an optional custom function to validate the origin
This commit is contained in:
parent
17a5fca161
commit
26ab188922
@ -19,6 +19,13 @@ type (
|
|||||||
// Optional. Default value []string{"*"}.
|
// Optional. Default value []string{"*"}.
|
||||||
AllowOrigins []string `yaml:"allow_origins"`
|
AllowOrigins []string `yaml:"allow_origins"`
|
||||||
|
|
||||||
|
// AllowOriginFunc is a custom function to validate the origin. It takes the
|
||||||
|
// origin as an argument and returns true if allowed or false otherwise. If
|
||||||
|
// an error is returned, it is returned by the handler. If this option is
|
||||||
|
// set, AllowOrigins is ignored.
|
||||||
|
// Optional.
|
||||||
|
AllowOriginFunc func(origin string) (bool, error) `yaml:"allow_origin_func"`
|
||||||
|
|
||||||
// AllowMethods defines a list methods allowed when accessing the resource.
|
// AllowMethods defines a list methods allowed when accessing the resource.
|
||||||
// This is used in response to a preflight request.
|
// This is used in response to a preflight request.
|
||||||
// Optional. Default value DefaultCORSConfig.AllowMethods.
|
// Optional. Default value DefaultCORSConfig.AllowMethods.
|
||||||
@ -113,39 +120,49 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
|
|||||||
return c.NoContent(http.StatusNoContent)
|
return c.NoContent(http.StatusNoContent)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check allowed origins
|
if config.AllowOriginFunc == nil {
|
||||||
for _, o := range config.AllowOrigins {
|
// Check allowed origins
|
||||||
if o == "*" && config.AllowCredentials {
|
for _, o := range config.AllowOrigins {
|
||||||
allowOrigin = origin
|
if o == "*" && config.AllowCredentials {
|
||||||
break
|
|
||||||
}
|
|
||||||
if o == "*" || o == origin {
|
|
||||||
allowOrigin = o
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if matchSubdomain(origin, o) {
|
|
||||||
allowOrigin = origin
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check allowed origin patterns
|
|
||||||
for _, re := range allowOriginPatterns {
|
|
||||||
if allowOrigin == "" {
|
|
||||||
didx := strings.Index(origin, "://")
|
|
||||||
if didx == -1 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
domAuth := origin[didx+3:]
|
|
||||||
// to avoid regex cost by invalid long domain
|
|
||||||
if len(domAuth) > 253 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
if match, _ := regexp.MatchString(re, origin); match {
|
|
||||||
allowOrigin = origin
|
allowOrigin = origin
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
if o == "*" || o == origin {
|
||||||
|
allowOrigin = o
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if matchSubdomain(origin, o) {
|
||||||
|
allowOrigin = origin
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check allowed origin patterns
|
||||||
|
for _, re := range allowOriginPatterns {
|
||||||
|
if allowOrigin == "" {
|
||||||
|
didx := strings.Index(origin, "://")
|
||||||
|
if didx == -1 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
domAuth := origin[didx+3:]
|
||||||
|
// to avoid regex cost by invalid long domain
|
||||||
|
if len(domAuth) > 253 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if match, _ := regexp.MatchString(re, origin); match {
|
||||||
|
allowOrigin = origin
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
allowed, err := config.AllowOriginFunc(origin)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if allowed {
|
||||||
|
allowOrigin = origin
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
@ -360,3 +361,49 @@ func TestCorsHeaders(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_allowOriginFunc(t *testing.T) {
|
||||||
|
returnTrue := func(origin string) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
returnFalse := func(origin string) (bool, error) {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
returnError := func(origin string) (bool, error) {
|
||||||
|
return true, errors.New("this is a test error")
|
||||||
|
}
|
||||||
|
|
||||||
|
allowOriginFuncs := []func(origin string) (bool, error){
|
||||||
|
returnTrue,
|
||||||
|
returnFalse,
|
||||||
|
returnError,
|
||||||
|
}
|
||||||
|
|
||||||
|
const origin = "http://example.com"
|
||||||
|
|
||||||
|
e := echo.New()
|
||||||
|
for _, allowOriginFunc := range allowOriginFuncs {
|
||||||
|
req := httptest.NewRequest(http.MethodOptions, "/", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
req.Header.Set(echo.HeaderOrigin, origin)
|
||||||
|
cors := CORSWithConfig(CORSConfig{
|
||||||
|
AllowOriginFunc: allowOriginFunc,
|
||||||
|
})
|
||||||
|
h := cors(echo.NotFoundHandler)
|
||||||
|
err := h(c)
|
||||||
|
|
||||||
|
expected, expectedErr := allowOriginFunc(origin)
|
||||||
|
if expectedErr != nil {
|
||||||
|
assert.Equal(t, expectedErr, err)
|
||||||
|
assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if expected {
|
||||||
|
assert.Equal(t, origin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user