1
0
mirror of https://github.com/labstack/echo.git synced 2025-01-12 01:22:21 +02:00

Merge pull request #1651 from curvegrid/cors-allow-origin-func

CORS: add an optional custom function to validate the origin
This commit is contained in:
Roland Lammel 2020-11-25 10:36:46 +01:00 committed by GitHub
commit 502cce28d5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 92 additions and 28 deletions

View File

@ -19,6 +19,13 @@ type (
// Optional. Default value []string{"*"}. // Optional. Default value []string{"*"}.
AllowOrigins []string `yaml:"allow_origins"` AllowOrigins []string `yaml:"allow_origins"`
// AllowOriginFunc is a custom function to validate the origin. It takes the
// origin as an argument and returns true if allowed or false otherwise. If
// an error is returned, it is returned by the handler. If this option is
// set, AllowOrigins is ignored.
// Optional.
AllowOriginFunc func(origin string) (bool, error) `yaml:"allow_origin_func"`
// AllowMethods defines a list methods allowed when accessing the resource. // AllowMethods defines a list methods allowed when accessing the resource.
// This is used in response to a preflight request. // This is used in response to a preflight request.
// Optional. Default value DefaultCORSConfig.AllowMethods. // Optional. Default value DefaultCORSConfig.AllowMethods.
@ -113,6 +120,15 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
return c.NoContent(http.StatusNoContent) return c.NoContent(http.StatusNoContent)
} }
if config.AllowOriginFunc != nil {
allowed, err := config.AllowOriginFunc(origin)
if err != nil {
return err
}
if allowed {
allowOrigin = origin
}
} else {
// Check allowed origins // Check allowed origins
for _, o := range config.AllowOrigins { for _, o := range config.AllowOrigins {
if o == "*" && config.AllowCredentials { if o == "*" && config.AllowCredentials {
@ -148,6 +164,7 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
} }
} }
} }
}
// Origin not allowed // Origin not allowed
if allowOrigin == "" { if allowOrigin == "" {

View File

@ -1,6 +1,7 @@
package middleware package middleware
import ( import (
"errors"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
@ -360,3 +361,49 @@ func TestCorsHeaders(t *testing.T) {
} }
} }
} }
func Test_allowOriginFunc(t *testing.T) {
returnTrue := func(origin string) (bool, error) {
return true, nil
}
returnFalse := func(origin string) (bool, error) {
return false, nil
}
returnError := func(origin string) (bool, error) {
return true, errors.New("this is a test error")
}
allowOriginFuncs := []func(origin string) (bool, error){
returnTrue,
returnFalse,
returnError,
}
const origin = "http://example.com"
e := echo.New()
for _, allowOriginFunc := range allowOriginFuncs {
req := httptest.NewRequest(http.MethodOptions, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
req.Header.Set(echo.HeaderOrigin, origin)
cors := CORSWithConfig(CORSConfig{
AllowOriginFunc: allowOriginFunc,
})
h := cors(echo.NotFoundHandler)
err := h(c)
expected, expectedErr := allowOriginFunc(origin)
if expectedErr != nil {
assert.Equal(t, expectedErr, err)
assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
continue
}
if expected {
assert.Equal(t, origin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
} else {
assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
}
}
}