mirror of
https://github.com/labstack/echo.git
synced 2024-12-22 20:06:21 +02:00
236 lines
7.6 KiB
Go
236 lines
7.6 KiB
Go
package middleware
|
|
|
|
import (
|
|
"net/http"
|
|
"regexp"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/labstack/echo/v5"
|
|
)
|
|
|
|
// CORSConfig defines the config for CORS middleware.
|
|
type CORSConfig struct {
|
|
// Skipper defines a function to skip middleware.
|
|
Skipper Skipper
|
|
|
|
// AllowOrigin defines a list of origins that may access the resource.
|
|
// Optional. Default value []string{"*"}.
|
|
AllowOrigins []string
|
|
|
|
// AllowOriginFunc is a custom function to validate the origin. It takes the
|
|
// origin as an argument and returns true if allowed or false otherwise. If
|
|
// an error is returned, it is returned by the handler. If this option is
|
|
// set, AllowOrigins is ignored.
|
|
// Optional.
|
|
AllowOriginFunc func(origin string) (bool, error)
|
|
|
|
// AllowMethods defines a list methods allowed when accessing the resource.
|
|
// This is used in response to a preflight request.
|
|
// Optional. Default value DefaultCORSConfig.AllowMethods.
|
|
// If `allowMethods` is left empty will fill for preflight request `Access-Control-Allow-Methods` header value
|
|
// from `Allow` header that echo.Router set into context.
|
|
AllowMethods []string
|
|
|
|
// AllowHeaders defines a list of request headers that can be used when
|
|
// making the actual request. This is in response to a preflight request.
|
|
// Optional. Default value []string{}.
|
|
AllowHeaders []string
|
|
|
|
// AllowCredentials indicates whether or not the response to the request
|
|
// can be exposed when the credentials flag is true. When used as part of
|
|
// a response to a preflight request, this indicates whether or not the
|
|
// actual request can be made using credentials.
|
|
// Optional. Default value false.
|
|
// Security: avoid using `AllowCredentials = true` with `AllowOrigins = *`.
|
|
// See http://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html
|
|
AllowCredentials bool
|
|
|
|
// ExposeHeaders defines a whitelist headers that clients are allowed to
|
|
// access.
|
|
// Optional. Default value []string{}.
|
|
ExposeHeaders []string
|
|
|
|
// MaxAge indicates how long (in seconds) the results of a preflight request
|
|
// can be cached.
|
|
// Optional. Default value 0.
|
|
MaxAge int
|
|
}
|
|
|
|
// DefaultCORSConfig is the default CORS middleware config.
|
|
var DefaultCORSConfig = CORSConfig{
|
|
Skipper: DefaultSkipper,
|
|
AllowOrigins: []string{"*"},
|
|
AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete},
|
|
}
|
|
|
|
// CORS returns a Cross-Origin Resource Sharing (CORS) middleware.
|
|
// See: https://developer.mozilla.org/en/docs/Web/HTTP/Access_control_CORS
|
|
func CORS() echo.MiddlewareFunc {
|
|
return CORSWithConfig(DefaultCORSConfig)
|
|
}
|
|
|
|
// CORSWithConfig returns a CORS middleware with config or panics on invalid configuration.
|
|
// See: `CORS()`.
|
|
func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
|
|
return toMiddlewareOrPanic(config)
|
|
}
|
|
|
|
// ToMiddleware converts CORSConfig to middleware or returns an error for invalid configuration
|
|
func (config CORSConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
|
|
// Defaults
|
|
if config.Skipper == nil {
|
|
config.Skipper = DefaultCORSConfig.Skipper
|
|
}
|
|
if len(config.AllowOrigins) == 0 {
|
|
config.AllowOrigins = DefaultCORSConfig.AllowOrigins
|
|
}
|
|
hasCustomAllowMethods := true
|
|
if len(config.AllowMethods) == 0 {
|
|
hasCustomAllowMethods = false
|
|
config.AllowMethods = DefaultCORSConfig.AllowMethods
|
|
}
|
|
|
|
allowOriginPatterns := []string{}
|
|
for _, origin := range config.AllowOrigins {
|
|
pattern := regexp.QuoteMeta(origin)
|
|
pattern = strings.Replace(pattern, "\\*", ".*", -1)
|
|
pattern = strings.Replace(pattern, "\\?", ".", -1)
|
|
pattern = "^" + pattern + "$"
|
|
allowOriginPatterns = append(allowOriginPatterns, pattern)
|
|
}
|
|
|
|
allowMethods := strings.Join(config.AllowMethods, ",")
|
|
allowHeaders := strings.Join(config.AllowHeaders, ",")
|
|
exposeHeaders := strings.Join(config.ExposeHeaders, ",")
|
|
maxAge := strconv.Itoa(config.MaxAge)
|
|
|
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
if config.Skipper(c) {
|
|
return next(c)
|
|
}
|
|
|
|
req := c.Request()
|
|
res := c.Response()
|
|
origin := req.Header.Get(echo.HeaderOrigin)
|
|
allowOrigin := ""
|
|
|
|
res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)
|
|
|
|
// Preflight request is an OPTIONS request, using three HTTP request headers: Access-Control-Request-Method,
|
|
// Access-Control-Request-Headers, and the Origin header. See: https://developer.mozilla.org/en-US/docs/Glossary/Preflight_request
|
|
// For simplicity we just consider method type and later `Origin` header.
|
|
preflight := req.Method == http.MethodOptions
|
|
|
|
// Although router adds special handler in case of OPTIONS method we avoid calling next for OPTIONS in this middleware
|
|
// as CORS requests do not have cookies / authentication headers by default, so we could get stuck in auth
|
|
// middlewares by calling next(c).
|
|
// But we still want to send `Allow` header as response in case of Non-CORS OPTIONS request as router default
|
|
// handler does.
|
|
routerAllowMethods := ""
|
|
if preflight {
|
|
tmpAllowMethods, ok := c.Get(echo.ContextKeyHeaderAllow).(string)
|
|
if ok && tmpAllowMethods != "" {
|
|
routerAllowMethods = tmpAllowMethods
|
|
c.Response().Header().Set(echo.HeaderAllow, routerAllowMethods)
|
|
}
|
|
}
|
|
|
|
// No Origin provided. This is (probably) not request from actual browser - proceed executing middleware chain
|
|
if origin == "" {
|
|
if !preflight {
|
|
return next(c)
|
|
}
|
|
return c.NoContent(http.StatusNoContent)
|
|
}
|
|
|
|
if config.AllowOriginFunc != nil {
|
|
allowed, err := config.AllowOriginFunc(origin)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if allowed {
|
|
allowOrigin = origin
|
|
}
|
|
} else {
|
|
// Check allowed origins
|
|
for _, o := range config.AllowOrigins {
|
|
if o == "*" && config.AllowCredentials {
|
|
allowOrigin = origin
|
|
break
|
|
}
|
|
if o == "*" || o == origin {
|
|
allowOrigin = o
|
|
break
|
|
}
|
|
if matchSubdomain(origin, o) {
|
|
allowOrigin = origin
|
|
break
|
|
}
|
|
}
|
|
|
|
checkPatterns := false
|
|
if allowOrigin == "" {
|
|
// to avoid regex cost by invalid (long) domains (253 is domain name max limit)
|
|
if len(origin) <= (5+3+253) && strings.Contains(origin, "://") {
|
|
checkPatterns = true
|
|
}
|
|
}
|
|
if checkPatterns {
|
|
for _, re := range allowOriginPatterns {
|
|
if match, _ := regexp.MatchString(re, origin); match {
|
|
allowOrigin = origin
|
|
break
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Origin not allowed
|
|
if allowOrigin == "" {
|
|
if !preflight {
|
|
return next(c)
|
|
}
|
|
return c.NoContent(http.StatusNoContent)
|
|
}
|
|
|
|
res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin)
|
|
if config.AllowCredentials {
|
|
res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true")
|
|
}
|
|
|
|
// Simple request
|
|
if !preflight {
|
|
if exposeHeaders != "" {
|
|
res.Header().Set(echo.HeaderAccessControlExposeHeaders, exposeHeaders)
|
|
}
|
|
return next(c)
|
|
}
|
|
|
|
// Preflight request
|
|
res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestMethod)
|
|
res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestHeaders)
|
|
|
|
if !hasCustomAllowMethods && routerAllowMethods != "" {
|
|
res.Header().Set(echo.HeaderAccessControlAllowMethods, routerAllowMethods)
|
|
} else {
|
|
res.Header().Set(echo.HeaderAccessControlAllowMethods, allowMethods)
|
|
}
|
|
|
|
if allowHeaders != "" {
|
|
res.Header().Set(echo.HeaderAccessControlAllowHeaders, allowHeaders)
|
|
} else {
|
|
h := req.Header.Get(echo.HeaderAccessControlRequestHeaders)
|
|
if h != "" {
|
|
res.Header().Set(echo.HeaderAccessControlAllowHeaders, h)
|
|
}
|
|
}
|
|
if config.MaxAge > 0 {
|
|
res.Header().Set(echo.HeaderAccessControlMaxAge, maxAge)
|
|
}
|
|
return c.NoContent(http.StatusNoContent)
|
|
}
|
|
}, nil
|
|
}
|