mirror of
				https://github.com/labstack/echo.git
				synced 2025-10-30 23:57:38 +02:00 
			
		
		
		
	Merge pull request #1651 from curvegrid/cors-allow-origin-func
CORS: add an optional custom function to validate the origin
This commit is contained in:
		| @@ -19,6 +19,13 @@ type ( | ||||
| 		// Optional. Default value []string{"*"}. | ||||
| 		AllowOrigins []string `yaml:"allow_origins"` | ||||
|  | ||||
| 		// AllowOriginFunc is a custom function to validate the origin. It takes the | ||||
| 		// origin as an argument and returns true if allowed or false otherwise. If | ||||
| 		// an error is returned, it is returned by the handler. If this option is | ||||
| 		// set, AllowOrigins is ignored. | ||||
| 		// Optional. | ||||
| 		AllowOriginFunc func(origin string) (bool, error) `yaml:"allow_origin_func"` | ||||
|  | ||||
| 		// AllowMethods defines a list methods allowed when accessing the resource. | ||||
| 		// This is used in response to a preflight request. | ||||
| 		// Optional. Default value DefaultCORSConfig.AllowMethods. | ||||
| @@ -113,39 +120,49 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { | ||||
| 				return c.NoContent(http.StatusNoContent) | ||||
| 			} | ||||
|  | ||||
| 			// Check allowed origins | ||||
| 			for _, o := range config.AllowOrigins { | ||||
| 				if o == "*" && config.AllowCredentials { | ||||
| 			if config.AllowOriginFunc != nil { | ||||
| 				allowed, err := config.AllowOriginFunc(origin) | ||||
| 				if err != nil { | ||||
| 					return err | ||||
| 				} | ||||
| 				if allowed { | ||||
| 					allowOrigin = origin | ||||
| 					break | ||||
| 				} | ||||
| 				if o == "*" || o == origin { | ||||
| 					allowOrigin = o | ||||
| 					break | ||||
| 				} | ||||
| 				if matchSubdomain(origin, o) { | ||||
| 					allowOrigin = origin | ||||
| 					break | ||||
| 				} | ||||
| 			} | ||||
|  | ||||
| 			// Check allowed origin patterns | ||||
| 			for _, re := range allowOriginPatterns { | ||||
| 				if allowOrigin == "" { | ||||
| 					didx := strings.Index(origin, "://") | ||||
| 					if didx == -1 { | ||||
| 						continue | ||||
| 					} | ||||
| 					domAuth := origin[didx+3:] | ||||
| 					// to avoid regex cost by invalid long domain | ||||
| 					if len(domAuth) > 253 { | ||||
| 						break | ||||
| 					} | ||||
|  | ||||
| 					if match, _ := regexp.MatchString(re, origin); match { | ||||
| 			} else { | ||||
| 				// Check allowed origins | ||||
| 				for _, o := range config.AllowOrigins { | ||||
| 					if o == "*" && config.AllowCredentials { | ||||
| 						allowOrigin = origin | ||||
| 						break | ||||
| 					} | ||||
| 					if o == "*" || o == origin { | ||||
| 						allowOrigin = o | ||||
| 						break | ||||
| 					} | ||||
| 					if matchSubdomain(origin, o) { | ||||
| 						allowOrigin = origin | ||||
| 						break | ||||
| 					} | ||||
| 				} | ||||
|  | ||||
| 				// Check allowed origin patterns | ||||
| 				for _, re := range allowOriginPatterns { | ||||
| 					if allowOrigin == "" { | ||||
| 						didx := strings.Index(origin, "://") | ||||
| 						if didx == -1 { | ||||
| 							continue | ||||
| 						} | ||||
| 						domAuth := origin[didx+3:] | ||||
| 						// to avoid regex cost by invalid long domain | ||||
| 						if len(domAuth) > 253 { | ||||
| 							break | ||||
| 						} | ||||
|  | ||||
| 						if match, _ := regexp.MatchString(re, origin); match { | ||||
| 							allowOrigin = origin | ||||
| 							break | ||||
| 						} | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
|  | ||||
|   | ||||
| @@ -1,6 +1,7 @@ | ||||
| package middleware | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"testing" | ||||
| @@ -360,3 +361,49 @@ func TestCorsHeaders(t *testing.T) { | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func Test_allowOriginFunc(t *testing.T) { | ||||
| 	returnTrue := func(origin string) (bool, error) { | ||||
| 		return true, nil | ||||
| 	} | ||||
| 	returnFalse := func(origin string) (bool, error) { | ||||
| 		return false, nil | ||||
| 	} | ||||
| 	returnError := func(origin string) (bool, error) { | ||||
| 		return true, errors.New("this is a test error") | ||||
| 	} | ||||
|  | ||||
| 	allowOriginFuncs := []func(origin string) (bool, error){ | ||||
| 		returnTrue, | ||||
| 		returnFalse, | ||||
| 		returnError, | ||||
| 	} | ||||
|  | ||||
| 	const origin = "http://example.com" | ||||
|  | ||||
| 	e := echo.New() | ||||
| 	for _, allowOriginFunc := range allowOriginFuncs { | ||||
| 		req := httptest.NewRequest(http.MethodOptions, "/", nil) | ||||
| 		rec := httptest.NewRecorder() | ||||
| 		c := e.NewContext(req, rec) | ||||
| 		req.Header.Set(echo.HeaderOrigin, origin) | ||||
| 		cors := CORSWithConfig(CORSConfig{ | ||||
| 			AllowOriginFunc: allowOriginFunc, | ||||
| 		}) | ||||
| 		h := cors(echo.NotFoundHandler) | ||||
| 		err := h(c) | ||||
|  | ||||
| 		expected, expectedErr := allowOriginFunc(origin) | ||||
| 		if expectedErr != nil { | ||||
| 			assert.Equal(t, expectedErr, err) | ||||
| 			assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		if expected { | ||||
| 			assert.Equal(t, origin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) | ||||
| 		} else { | ||||
| 			assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user