package apis

import (
	"sync"
	"time"

	"github.com/pocketbase/pocketbase/core"
	"github.com/pocketbase/pocketbase/tools/hook"
	"github.com/pocketbase/pocketbase/tools/store"
)

const (
	DefaultRateLimitMiddlewareId       = "pbRateLimit"
	DefaultRateLimitMiddlewarePriority = -1000
)

const (
	rateLimitersStoreKey       = "__pbRateLimiters__"
	rateLimitersCronKey        = "__pbRateLimitersCleanup__"
	rateLimitersSettingsHookId = "__pbRateLimitersSettingsHook__"
)

// rateLimit defines the global rate limit middleware.
//
// This middleware is registered by default for all routes.
func rateLimit() *hook.Handler[*core.RequestEvent] {
	return &hook.Handler[*core.RequestEvent]{
		Id:       DefaultRateLimitMiddlewareId,
		Priority: DefaultRateLimitMiddlewarePriority,
		Func: func(e *core.RequestEvent) error {
			if skipRateLimit(e) {
				return e.Next()
			}

			rule, ok := e.App.Settings().RateLimits.FindRateLimitRule(defaultRateLimitLabels(e))
			if ok {
				err := checkRateLimit(e, e.Request.Pattern, rule)
				if err != nil {
					return err
				}
			}

			return e.Next()
		},
	}
}
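
// Illustrative sketch (hypothetical; the actual registration lives outside of
// this file, in the base router setup): assuming a router group value "g"
// that accepts hook handlers via a Bind method, attaching the middleware
// globally could look roughly like:
//
//	g.Bind(rateLimit())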

// collectionPathRateLimit defines a rate limit middleware for the internal collection handlers.
func collectionPathRateLimit(collectionPathParam string, baseTags ...string) *hook.Handler[*core.RequestEvent] {
	if collectionPathParam == "" {
		collectionPathParam = "collection"
	}

	return &hook.Handler[*core.RequestEvent]{
		Id:       DefaultRateLimitMiddlewareId,
		Priority: DefaultRateLimitMiddlewarePriority,
		Func: func(e *core.RequestEvent) error {
			collection, err := e.App.FindCachedCollectionByNameOrId(e.Request.PathValue(collectionPathParam))
			if err != nil {
				return e.NotFoundError("Missing or invalid collection context.", err)
			}

			if err := checkCollectionRateLimit(e, collection, baseTags...); err != nil {
				return err
			}

			return e.Next()
		},
	}
}

// checkCollectionRateLimit checks whether the current request satisfies the
// rate limit configuration for the specific collection.
//
// Each baseTags entry will be prefixed with the collection name and its wildcard variant.
func checkCollectionRateLimit(e *core.RequestEvent, collection *core.Collection, baseTags ...string) error {
	if skipRateLimit(e) {
		return nil
	}

	labels := make([]string, 0, 2+len(baseTags)*2)

	rtId := collection.Id + e.Request.Pattern

	// add the primary labels first (e.g. ["collectionName:action1", "collectionName:action2"])
	for _, baseTag := range baseTags {
		rtId += baseTag
		labels = append(labels, collection.Name+":"+baseTag)
	}

	// add the wildcard labels (e.g. ["*:action1", "*:action2"]) followed by the default request labels
	for _, baseTag := range baseTags {
		labels = append(labels, "*:"+baseTag)
	}
	labels = append(labels, defaultRateLimitLabels(e)...)

	rule, ok := e.App.Settings().RateLimits.FindRateLimitRule(labels)
	if ok {
		return checkRateLimit(e, rtId, rule)
	}

	return nil
}
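
// Illustrative sketch (hypothetical values, derived from the label construction
// above): for a collection named "posts", baseTags of ["list"] and a request
// "GET /api/collections/posts/records", FindRateLimitRule receives:
//
//	[]string{"posts:list", "*:list", "GET /api/collections/posts/records", "/api/collections/posts/records"}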

// -------------------------------------------------------------------

// @todo consider exporting as helper?
//
//nolint:unused
func isClientRateLimited(e *core.RequestEvent, rtId string) bool {
	rateLimiters, ok := e.App.Store().Get(rateLimitersStoreKey).(*store.Store[*rateLimiter])
	if !ok || rateLimiters == nil {
		return false
	}

	rt, ok := rateLimiters.GetOk(rtId)
	if !ok || rt == nil {
		return false
	}

	client, ok := rt.getClient(e.RealIP())
	if !ok || client == nil {
		return false
	}

	return client.available <= 0 && time.Now().Unix()-client.lastConsume < client.interval
}

// @todo consider exporting as helper?
func checkRateLimit(e *core.RequestEvent, rtId string, rule core.RateLimitRule) error {
	switch rule.Audience {
	case core.RateLimitRuleAudienceAll:
		// valid for both guest and regular users
	case core.RateLimitRuleAudienceGuest:
		if e.Auth != nil {
			return nil
		}
	case core.RateLimitRuleAudienceAuth:
		if e.Auth == nil {
			return nil
		}
	}

	rateLimiters := e.App.Store().GetOrSet(rateLimitersStoreKey, func() any {
		return initRateLimitersStore(e.App)
	}).(*store.Store[*rateLimiter])
	if rateLimiters == nil {
		e.App.Logger().Warn("Failed to retrieve app rate limiters store")
		return nil
	}

	rt := rateLimiters.GetOrSet(rtId, func() *rateLimiter {
		return newRateLimiter(rule.MaxRequests, rule.Duration, rule.Duration+1800)
	})
	if rt == nil {
		e.App.Logger().Warn("Failed to retrieve app rate limiter", "id", rtId)
		return nil
	}

	key := e.RealIP()
	if key == "" {
		e.App.Logger().Warn("Empty rate limit client key")
		return nil
	}

	if !rt.isAllowed(key) {
		return e.TooManyRequestsError("", nil)
	}

	return nil
}
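
// Illustrative sketch (hypothetical rule, using only the core.RateLimitRule
// fields referenced in this file): a rule limiting guests to 300 requests per
// 10 second window could look like:
//
//	rule := core.RateLimitRule{
//		Audience:    core.RateLimitRuleAudienceGuest,
//		MaxRequests: 300,
//		Duration:    10,
//	}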

func skipRateLimit(e *core.RequestEvent) bool {
	return !e.App.Settings().RateLimits.Enabled || e.HasSuperuserAuth()
}

func defaultRateLimitLabels(e *core.RequestEvent) []string {
	return []string{e.Request.Method + " " + e.Request.URL.Path, e.Request.URL.Path}
}
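
// Illustrative sketch: for a hypothetical request "GET /api/health" the
// default labels are:
//
//	[]string{"GET /api/health", "/api/health"}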

func destroyRateLimitersStore(app core.App) {
	app.OnSettingsReload().Unbind(rateLimitersSettingsHookId)
	app.Cron().Remove(rateLimitersCronKey)
	app.Store().Remove(rateLimitersStoreKey)
}

func initRateLimitersStore(app core.App) *store.Store[*rateLimiter] {
	app.Cron().Add(rateLimitersCronKey, "2 * * * *", func() { // offset a little since too many cleanup tasks execute at minute 00
		limitersStore, ok := app.Store().Get(rateLimitersStoreKey).(*store.Store[*rateLimiter])
		if !ok {
			return
		}
		limiters := limitersStore.GetAll()
		for _, limiter := range limiters {
			limiter.clean()
		}
	})

	app.OnSettingsReload().Bind(&hook.Handler[*core.SettingsReloadEvent]{
		Id: rateLimitersSettingsHookId,
		Func: func(e *core.SettingsReloadEvent) error {
			err := e.Next()
			if err != nil {
				return err
			}

			// reset the limiters store so that it can be reinitialized with the new settings
			destroyRateLimitersStore(e.App)

			return nil
		},
	})

	return store.New[*rateLimiter](nil)
}

func newRateLimiter(maxAllowed int, intervalInSec int64, minDeleteIntervalInSec int64) *rateLimiter {
	return &rateLimiter{
		maxAllowed:        maxAllowed,
		interval:          intervalInSec,
		minDeleteInterval: minDeleteIntervalInSec,
		clients:           map[string]*fixedWindow{},
	}
}
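
// Illustrative sketch (values taken from checkRateLimit above): a rule with
// Duration=10 produces a limiter whose client entries stay eligible for
// cleanup only after an extra ~30 minutes of inactivity:
//
//	rt := newRateLimiter(rule.MaxRequests, 10, 10+1800)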

type rateLimiter struct {
	clients map[string]*fixedWindow // per-client (request IP) fixed window state

	maxAllowed        int   // the max allowed tokens per interval
	interval          int64 // the interval duration in seconds
	minDeleteInterval int64 // the min inactivity duration (in seconds) before a client entry can be cleaned
	totalDeleted      int64 // the total deleted clients since the last map "shrink"

	sync.RWMutex
}

//nolint:unused
func (rt *rateLimiter) getClient(key string) (*fixedWindow, bool) {
	rt.RLock()
	client, ok := rt.clients[key]
	rt.RUnlock()

	return client, ok
}

func (rt *rateLimiter) isAllowed(key string) bool {
	// lock only reads to minimize lock contention
	rt.RLock()
	client, ok := rt.clients[key]
	rt.RUnlock()

	if !ok {
		rt.Lock()
		// check again in case the client was added by another request
		client, ok = rt.clients[key]
		if !ok {
			client = newFixedWindow(rt.maxAllowed, rt.interval)
			rt.clients[key] = client
		}
		rt.Unlock()
	}

	return client.consume()
}

func (rt *rateLimiter) clean() {
	rt.Lock()
	defer rt.Unlock()

	nowUnix := time.Now().Unix()

	for k, client := range rt.clients {
		if client.hasExpired(nowUnix, rt.minDeleteInterval) {
			delete(rt.clients, k)
			rt.totalDeleted++
		}
	}

	// "shrink" the map if too many items were deleted
	//
	// @todo remove after https://github.com/golang/go/issues/20135
	if rt.totalDeleted >= 300 {
		shrunk := make(map[string]*fixedWindow, len(rt.clients))
		for k, v := range rt.clients {
			shrunk[k] = v
		}
		rt.clients = shrunk
		rt.totalDeleted = 0
	}
}

func newFixedWindow(maxAllowed int, intervalInSec int64) *fixedWindow {
	return &fixedWindow{
		maxAllowed: maxAllowed,
		interval:   intervalInSec,
	}
}

type fixedWindow struct {
	// use plain Mutex instead of RWMutex since the operations are expected
	// to be mostly writes (e.g. consume()) and it should perform better
	sync.Mutex

	maxAllowed  int   // the max allowed tokens per interval
	available   int   // the total available tokens
	interval    int64 // in seconds
	lastConsume int64 // the time of the last consume
}

// hasExpired checks whether at least minElapsed seconds have passed since the
// lastConsume time (usually used to perform periodic cleanup of stale instances).
func (l *fixedWindow) hasExpired(relativeNow int64, minElapsed int64) bool {
	l.Lock()
	defer l.Unlock()

	return relativeNow-l.lastConsume > minElapsed
}

// consume decreases the current window allowance by 1 (if not already exhausted).
//
// It returns false if the allowance has already been exhausted and the user
// has to wait until it resets back to its maxAllowed value.
func (l *fixedWindow) consume() bool {
	l.Lock()
	defer l.Unlock()

	nowUnix := time.Now().Unix()

	// reset the available tokens if the previous window has elapsed
	if nowUnix-l.lastConsume >= l.interval {
		l.available = l.maxAllowed
	}

	if l.available > 0 {
		l.available--
		l.lastConsume = nowUnix

		return true
	}

	return false
}
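
// Illustrative sketch (hypothetical values): a window with maxAllowed=2 and a
// 1 second interval allows 2 consumes, rejects the 3rd and resets once the
// interval elapses (lastConsume is only updated on successful consumes):
//
//	w := newFixedWindow(2, 1)
//	w.consume() // true
//	w.consume() // true
//	w.consume() // false (exhausted)
//	time.Sleep(1100 * time.Millisecond)
//	w.consume() // true (window reset)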