package apis import ( "sync" "time" "github.com/pocketbase/pocketbase/core" "github.com/pocketbase/pocketbase/tools/hook" "github.com/pocketbase/pocketbase/tools/store" ) const ( DefaultRateLimitMiddlewareId = "pbRateLimit" DefaultRateLimitMiddlewarePriority = -1000 ) const ( rateLimitersStoreKey = "__pbRateLimiters__" rateLimitersCronKey = "__pbRateLimitersCleanup__" rateLimitersSettingsHookId = "__pbRateLimitersSettingsHook__" ) // rateLimit defines the global rate limit middleware. // // This middleware is registered by default for all routes. func rateLimit() *hook.Handler[*core.RequestEvent] { return &hook.Handler[*core.RequestEvent]{ Id: DefaultRateLimitMiddlewareId, Priority: DefaultRateLimitMiddlewarePriority, Func: func(e *core.RequestEvent) error { if skipRateLimit(e) { return e.Next() } rule, ok := e.App.Settings().RateLimits.FindRateLimitRule(defaultRateLimitLabels(e)) if ok { err := checkRateLimit(e, e.Request.Pattern, rule) if err != nil { return err } } return e.Next() }, } } // collectionPathRateLimit defines a rate limit middleware for the internal collection handlers. func collectionPathRateLimit(collectionPathParam string, baseTags ...string) *hook.Handler[*core.RequestEvent] { if collectionPathParam == "" { collectionPathParam = "collection" } return &hook.Handler[*core.RequestEvent]{ Id: DefaultRateLimitMiddlewareId, Priority: DefaultRateLimitMiddlewarePriority, Func: func(e *core.RequestEvent) error { collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue(collectionPathParam)) if err != nil { return e.NotFoundError("Missing or invalid collection context.", err) } if err := checkCollectionRateLimit(e, collection, baseTags...); err != nil { return err } return e.Next() }, } } // checkCollectionRateLimit checks whether the current request satisfy the // rate limit configuration for the specific collection. // // Each baseTags entry will be prefixed with the collection name and its wildcard variant. func checkCollectionRateLimit(e *core.RequestEvent, collection *core.Collection, baseTags ...string) error { if skipRateLimit(e) { return nil } labels := make([]string, 0, 2+len(baseTags)*2) rtId := collection.Id + e.Request.Pattern // add first the primary labels (aka. ["collectionName:action1", "collectionName:action2"]) for _, baseTag := range baseTags { rtId += baseTag labels = append(labels, collection.Name+":"+baseTag) } // add the wildcard labels (aka. [..., "*:action1","*:action2", "*"]) for _, baseTag := range baseTags { labels = append(labels, "*:"+baseTag) } labels = append(labels, defaultRateLimitLabels(e)...) rule, ok := e.App.Settings().RateLimits.FindRateLimitRule(labels) if ok { return checkRateLimit(e, rtId, rule) } return nil } // ------------------------------------------------------------------- // @todo consider exporting as helper? // //nolint:unused func isClientRateLimited(e *core.RequestEvent, rtId string) bool { rateLimiters, ok := e.App.Store().Get(rateLimitersStoreKey).(*store.Store[*rateLimiter]) if !ok || rateLimiters == nil { return false } rt, ok := rateLimiters.GetOk(rtId) if !ok || rt == nil { return false } client, ok := rt.getClient(e.RealIP()) if !ok || client == nil { return false } return client.available <= 0 && time.Now().Unix()-client.lastConsume < client.interval } // @todo consider exporting as helper? func checkRateLimit(e *core.RequestEvent, rtId string, rule core.RateLimitRule) error { switch rule.Audience { case core.RateLimitRuleAudienceAll: // valid for both guest and regular users case core.RateLimitRuleAudienceGuest: if e.Auth != nil { return nil } case core.RateLimitRuleAudienceAuth: if e.Auth == nil { return nil } } rateLimiters := e.App.Store().GetOrSet(rateLimitersStoreKey, func() any { return initRateLimitersStore(e.App) }).(*store.Store[*rateLimiter]) if rateLimiters == nil { e.App.Logger().Warn("Failed to retrieve app rate limiters store") return nil } rt := rateLimiters.GetOrSet(rtId, func() *rateLimiter { return newRateLimiter(rule.MaxRequests, rule.Duration, rule.Duration+1800) }) if rt == nil { e.App.Logger().Warn("Failed to retrieve app rate limiter", "id", rtId) return nil } key := e.RealIP() if key == "" { e.App.Logger().Warn("Empty rate limit client key") return nil } if !rt.isAllowed(key) { return e.TooManyRequestsError("", nil) } return nil } func skipRateLimit(e *core.RequestEvent) bool { return !e.App.Settings().RateLimits.Enabled || e.HasSuperuserAuth() } func defaultRateLimitLabels(e *core.RequestEvent) []string { return []string{e.Request.Method + " " + e.Request.URL.Path, e.Request.URL.Path} } func destroyRateLimitersStore(app core.App) { app.OnSettingsReload().Unbind(rateLimitersSettingsHookId) app.Cron().Remove(rateLimitersCronKey) app.Store().Remove(rateLimitersStoreKey) } func initRateLimitersStore(app core.App) *store.Store[*rateLimiter] { app.Cron().Add(rateLimitersCronKey, "2 * * * *", func() { // offset a little since too many cleanup tasks execute at 00 limitersStore, ok := app.Store().Get(rateLimitersStoreKey).(*store.Store[*rateLimiter]) if !ok { return } limiters := limitersStore.GetAll() for _, limiter := range limiters { limiter.clean() } }) app.OnSettingsReload().Bind(&hook.Handler[*core.SettingsReloadEvent]{ Id: rateLimitersSettingsHookId, Func: func(e *core.SettingsReloadEvent) error { err := e.Next() if err != nil { return err } // reset destroyRateLimitersStore(e.App) return nil }, }) return store.New[*rateLimiter](nil) } func newRateLimiter(maxAllowed int, intervalInSec int64, minDeleteIntervalInSec int64) *rateLimiter { return &rateLimiter{ maxAllowed: maxAllowed, interval: intervalInSec, minDeleteInterval: minDeleteIntervalInSec, clients: map[string]*fixedWindow{}, } } type rateLimiter struct { clients map[string]*fixedWindow maxAllowed int interval int64 minDeleteInterval int64 totalDeleted int64 sync.RWMutex } //nolint:unused func (rt *rateLimiter) getClient(key string) (*fixedWindow, bool) { rt.RLock() client, ok := rt.clients[key] rt.RUnlock() return client, ok } func (rt *rateLimiter) isAllowed(key string) bool { // lock only reads to minimize locks contention rt.RLock() client, ok := rt.clients[key] rt.RUnlock() if !ok { rt.Lock() // check again in case the client was added by another request client, ok = rt.clients[key] if !ok { client = newFixedWindow(rt.maxAllowed, rt.interval) rt.clients[key] = client } rt.Unlock() } return client.consume() } func (rt *rateLimiter) clean() { rt.Lock() defer rt.Unlock() nowUnix := time.Now().Unix() for k, client := range rt.clients { if client.hasExpired(nowUnix, rt.minDeleteInterval) { delete(rt.clients, k) rt.totalDeleted++ } } // "shrink" the map if too may items were deleted // // @todo remove after https://github.com/golang/go/issues/20135 if rt.totalDeleted >= 300 { shrunk := make(map[string]*fixedWindow, len(rt.clients)) for k, v := range rt.clients { shrunk[k] = v } rt.clients = shrunk rt.totalDeleted = 0 } } func newFixedWindow(maxAllowed int, intervalInSec int64) *fixedWindow { return &fixedWindow{ maxAllowed: maxAllowed, interval: intervalInSec, } } type fixedWindow struct { // use plain Mutex instead of RWMutex since the operations are expected // to be mostly writes (e.g. consume()) and it should perform better sync.Mutex maxAllowed int // the max allowed tokens per interval available int // the total available tokens interval int64 // in seconds lastConsume int64 // the time of the last consume } // hasExpired checks whether it has been at least minElapsed seconds since the lastConsume time. // (usually used to perform periodic cleanup of staled instances). func (l *fixedWindow) hasExpired(relativeNow int64, minElapsed int64) bool { l.Lock() defer l.Unlock() return relativeNow-l.lastConsume > minElapsed } // consume decrease the current window allowance with 1 (if not exhausted already). // // It returns false if the allowance has been already exhausted and the user // has to wait until it resets back to its maxAllowed value. func (l *fixedWindow) consume() bool { l.Lock() defer l.Unlock() nowUnix := time.Now().Unix() // reset consumed counter if nowUnix-l.lastConsume >= l.interval { l.available = l.maxAllowed } if l.available > 0 { l.available-- l.lastConsume = nowUnix return true } return false }