package apis // ------------------------------------------------------------------- // This middleware is ported from echo/middleware to minimize the breaking // changes and differences in the API behavior from earlier PocketBase versions // (https://github.com/labstack/echo/blob/ec5b858dab6105ab4c3ed2627d1ebdfb6ae1ecb8/middleware/cors.go). // // I doubt that this would matter for most cases, but the only major difference // is that for non-supported routes this middleware doesn't return 405 and fallbacks // to the default catch-all PocketBase route (aka. returns 404) to avoid // the extra overhead of further hijacking and wrapping the Go default mux // (https://github.com/golang/go/issues/65648#issuecomment-1955328807). // ------------------------------------------------------------------- import ( "log" "net/http" "regexp" "strconv" "strings" "github.com/pocketbase/pocketbase/core" "github.com/pocketbase/pocketbase/tools/hook" ) const ( DefaultCorsMiddlewareId = "pbCors" DefaultCorsMiddlewarePriority = DefaultActivityLoggerMiddlewarePriority - 1 // before the activity logger and rate limit so that OPTIONS preflight requests are not counted ) // CORSConfig defines the config for CORS middleware. type CORSConfig struct { // AllowOrigins determines the value of the Access-Control-Allow-Origin // response header. This header defines a list of origins that may access the // resource. The wildcard characters '*' and '?' are supported and are // converted to regex fragments '.*' and '.' accordingly. // // Security: use extreme caution when handling the origin, and carefully // validate any logic. Remember that attackers may register hostile domain names. // See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html // // Optional. Default value []string{"*"}. // // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin AllowOrigins []string // AllowOriginFunc is a custom function to validate the origin. It takes the // origin as an argument and returns true if allowed or false otherwise. If // an error is returned, it is returned by the handler. If this option is // set, AllowOrigins is ignored. // // Security: use extreme caution when handling the origin, and carefully // validate any logic. Remember that attackers may register hostile domain names. // See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html // // Optional. AllowOriginFunc func(origin string) (bool, error) // AllowMethods determines the value of the Access-Control-Allow-Methods // response header. This header specified the list of methods allowed when // accessing the resource. This is used in response to a preflight request. // // Optional. Default value DefaultCORSConfig.AllowMethods. // // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods AllowMethods []string // AllowHeaders determines the value of the Access-Control-Allow-Headers // response header. This header is used in response to a preflight request to // indicate which HTTP headers can be used when making the actual request. // // Optional. Default value []string{}. // // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers AllowHeaders []string // AllowCredentials determines the value of the // Access-Control-Allow-Credentials response header. This header indicates // whether or not the response to the request can be exposed when the // credentials mode (Request.credentials) is true. When used as part of a // response to a preflight request, this indicates whether or not the actual // request can be made using credentials. See also // [MDN: Access-Control-Allow-Credentials]. // // Optional. Default value false, in which case the header is not set. // // Security: avoid using `AllowCredentials = true` with `AllowOrigins = *`. // See "Exploiting CORS misconfigurations for Bitcoins and bounties", // https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html // // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials AllowCredentials bool // UnsafeWildcardOriginWithAllowCredentials UNSAFE/INSECURE: allows wildcard '*' origin to be used with AllowCredentials // flag. In that case we consider any origin allowed and send it back to the client with `Access-Control-Allow-Origin` header. // // This is INSECURE and potentially leads to [cross-origin](https://portswigger.net/research/exploiting-cors-misconfigurations-for-bitcoins-and-bounties) // attacks. See: https://github.com/labstack/echo/issues/2400 for discussion on the subject. // // Optional. Default value is false. UnsafeWildcardOriginWithAllowCredentials bool // ExposeHeaders determines the value of Access-Control-Expose-Headers, which // defines a list of headers that clients are allowed to access. // // Optional. Default value []string{}, in which case the header is not set. // // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Header ExposeHeaders []string // MaxAge determines the value of the Access-Control-Max-Age response header. // This header indicates how long (in seconds) the results of a preflight // request can be cached. // The header is set only if MaxAge != 0, negative value sends "0" which instructs browsers not to cache that response. // // Optional. Default value 0 - meaning header is not sent. // // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age MaxAge int } // DefaultCORSConfig is the default CORS middleware config. var DefaultCORSConfig = CORSConfig{ AllowOrigins: []string{"*"}, AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete}, } // CORS returns a CORS middleware. func CORS(config CORSConfig) *hook.Handler[*core.RequestEvent] { // Defaults if len(config.AllowOrigins) == 0 { config.AllowOrigins = DefaultCORSConfig.AllowOrigins } if len(config.AllowMethods) == 0 { config.AllowMethods = DefaultCORSConfig.AllowMethods } allowOriginPatterns := make([]*regexp.Regexp, 0, len(config.AllowOrigins)) for _, origin := range config.AllowOrigins { if origin == "*" { continue // "*" is handled differently and does not need regexp } pattern := regexp.QuoteMeta(origin) pattern = strings.ReplaceAll(pattern, "\\*", ".*") pattern = strings.ReplaceAll(pattern, "\\?", ".") pattern = "^" + pattern + "$" re, err := regexp.Compile(pattern) if err != nil { // This is to preserve previous behaviour - invalid patterns were just ignored. // If we would turn this to panic, users with invalid patterns // would have applications crashing in production due unrecovered panic. log.Println("invalid AllowOrigins pattern", origin) continue } allowOriginPatterns = append(allowOriginPatterns, re) } allowMethods := strings.Join(config.AllowMethods, ",") allowHeaders := strings.Join(config.AllowHeaders, ",") exposeHeaders := strings.Join(config.ExposeHeaders, ",") maxAge := "0" if config.MaxAge > 0 { maxAge = strconv.Itoa(config.MaxAge) } return &hook.Handler[*core.RequestEvent]{ Id: DefaultCorsMiddlewareId, Priority: DefaultCorsMiddlewarePriority, Func: func(e *core.RequestEvent) error { req := e.Request res := e.Response origin := req.Header.Get("Origin") allowOrigin := "" res.Header().Add("Vary", "Origin") // Preflight request is an OPTIONS request, using three HTTP request headers: Access-Control-Request-Method, // Access-Control-Request-Headers, and the Origin header. See: https://developer.mozilla.org/en-US/docs/Glossary/Preflight_request // For simplicity we just consider method type and later `Origin` header. preflight := req.Method == http.MethodOptions // No Origin provided. This is (probably) not request from actual browser - proceed executing middleware chain if origin == "" { if !preflight { return e.Next() } return e.NoContent(http.StatusNoContent) } if config.AllowOriginFunc != nil { allowed, err := config.AllowOriginFunc(origin) if err != nil { return err } if allowed { allowOrigin = origin } } else { // Check allowed origins for _, o := range config.AllowOrigins { if o == "*" && config.AllowCredentials && config.UnsafeWildcardOriginWithAllowCredentials { allowOrigin = origin break } if o == "*" || o == origin { allowOrigin = o break } if matchSubdomain(origin, o) { allowOrigin = origin break } } checkPatterns := false if allowOrigin == "" { // to avoid regex cost by invalid (long) domains (253 is domain name max limit) if len(origin) <= (253+3+5) && strings.Contains(origin, "://") { checkPatterns = true } } if checkPatterns { for _, re := range allowOriginPatterns { if match := re.MatchString(origin); match { allowOrigin = origin break } } } } // Origin not allowed if allowOrigin == "" { if !preflight { return e.Next() } return e.NoContent(http.StatusNoContent) } res.Header().Set("Access-Control-Allow-Origin", allowOrigin) if config.AllowCredentials { res.Header().Set("Access-Control-Allow-Credentials", "true") } // Simple request if !preflight { if exposeHeaders != "" { res.Header().Set("Access-Control-Expose-Headers", exposeHeaders) } return e.Next() } // Preflight request res.Header().Add("Vary", "Access-Control-Request-Method") res.Header().Add("Vary", "Access-Control-Request-Headers") res.Header().Set("Access-Control-Allow-Methods", allowMethods) if allowHeaders != "" { res.Header().Set("Access-Control-Allow-Headers", allowHeaders) } else { h := req.Header.Get("Access-Control-Request-Headers") if h != "" { res.Header().Set("Access-Control-Allow-Headers", h) } } if config.MaxAge != 0 { res.Header().Set("Access-Control-Max-Age", maxAge) } return e.NoContent(http.StatusNoContent) }, } } func matchScheme(domain, pattern string) bool { didx := strings.Index(domain, ":") pidx := strings.Index(pattern, ":") return didx != -1 && pidx != -1 && domain[:didx] == pattern[:pidx] } // matchSubdomain compares authority with wildcard func matchSubdomain(domain, pattern string) bool { if !matchScheme(domain, pattern) { return false } didx := strings.Index(domain, "://") pidx := strings.Index(pattern, "://") if didx == -1 || pidx == -1 { return false } domAuth := domain[didx+3:] // to avoid long loop by invalid long domain if len(domAuth) > 253 { return false } patAuth := pattern[pidx+3:] domComp := strings.Split(domAuth, ".") patComp := strings.Split(patAuth, ".") for i := len(domComp)/2 - 1; i >= 0; i-- { opp := len(domComp) - 1 - i domComp[i], domComp[opp] = domComp[opp], domComp[i] } for i := len(patComp)/2 - 1; i >= 0; i-- { opp := len(patComp) - 1 - i patComp[i], patComp[opp] = patComp[opp], patComp[i] } for i, v := range domComp { if len(patComp) <= i { return false } p := patComp[i] if p == "*" { return true } if p != v { return false } } return false }