1
0
mirror of https://github.com/pocketbase/pocketbase.git synced 2025-03-19 06:07:48 +02:00
pocketbase/apis/middlewares_rate_limit_test.go

160 lines
3.6 KiB
Go

package apis_test
import (
"net/http/httptest"
"testing"
"time"
"github.com/pocketbase/pocketbase/apis"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tests"
)
// TestDefaultRateLimitMiddleware enables rate limiting on a test app, builds a
// router with apis.NewRouter and replays an ordered list of GET requests
// against it, checking the response status of every request.
//
// The scenarios are executed sequentially against the same mux, so the
// expected status of each entry depends on the requests issued before it
// (and on the optional wait preceding it).
func TestDefaultRateLimitMiddleware(t *testing.T) {
	app, err := tests.NewTestApp()
	if err != nil {
		t.Fatal(err)
	}
	defer app.Cleanup()

	app.Settings().RateLimits.Enabled = true
	// NOTE(review): Duration is presumably expressed in seconds (the scenarios
	// below wait ~1.1s to get past a window) - confirm against core.RateLimitRule.
	app.Settings().RateLimits.Rules = []core.RateLimitRule{
		{
			Label:       "/rate/",
			MaxRequests: 2,
			Duration:    1,
		},
		{
			Label:       "/rate/b",
			MaxRequests: 3,
			Duration:    1,
		},
		{
			Label:       "POST /rate/b",
			MaxRequests: 1,
			Duration:    1,
		},
		{
			Label:       "/rate/guest",
			MaxRequests: 1,
			Duration:    1,
			Audience:    core.RateLimitRuleAudienceGuest,
		},
		{
			Label:       "/rate/auth",
			MaxRequests: 1,
			Duration:    1,
			Audience:    core.RateLimitRuleAudienceAuth,
		},
	}

	pbRouter, err := apis.NewRouter(app)
	if err != nil {
		t.Fatal(err)
	}

	// Route without a matching rate limit rule label; it also has a
	// pass-through middleware bound directly to it.
	pbRouter.GET("/norate", func(e *core.RequestEvent) error {
		return e.String(200, "norate")
	}).BindFunc(func(e *core.RequestEvent) error {
		return e.Next()
	})
	pbRouter.GET("/rate/a", func(e *core.RequestEvent) error {
		return e.String(200, "a")
	})
	pbRouter.GET("/rate/b", func(e *core.RequestEvent) error {
		return e.String(200, "b")
	})
	pbRouter.GET("/rate/guest", func(e *core.RequestEvent) error {
		return e.String(200, "guest")
	})
	pbRouter.GET("/rate/auth", func(e *core.RequestEvent) error {
		return e.String(200, "auth")
	})

	mux, err := pbRouter.BuildMux()
	if err != nil {
		t.Fatal(err)
	}

	scenarios := []struct {
		url            string
		wait           float64 // seconds to sleep before sending the request
		authenticated  bool    // whether to attach a users auth token
		expectedStatus int
	}{
		{"/norate", 0, false, 200},
		{"/norate", 0, false, 200},
		{"/norate", 0, false, 200},
		{"/norate", 0, false, 200},
		{"/norate", 0, false, 200},

		{"/rate/a", 0, false, 200},
		{"/rate/a", 0, false, 200},
		{"/rate/a", 0, false, 429},
		{"/rate/a", 0, false, 429},
		{"/rate/a", 1.1, false, 200},
		{"/rate/a", 0, false, 200},
		{"/rate/a", 0, false, 429},

		{"/rate/b", 0, false, 200},
		{"/rate/b", 0, false, 200},
		{"/rate/b", 0, false, 200},
		{"/rate/b", 0, false, 429},
		{"/rate/b", 1.1, false, 200},
		{"/rate/b", 0, false, 200},
		{"/rate/b", 0, false, 200},
		{"/rate/b", 0, false, 429},

		// "auth" with guest (should fall back to the /rate/ rule)
		{"/rate/auth", 0, false, 200},
		{"/rate/auth", 0, false, 200},
		{"/rate/auth", 0, false, 429},
		{"/rate/auth", 0, false, 429},

		// "auth" rule with regular user (should match the /rate/auth rule)
		{"/rate/auth", 0, true, 200},
		{"/rate/auth", 0, true, 429},
		{"/rate/auth", 0, true, 429},

		// "guest" with guest (should match the /rate/guest rule)
		{"/rate/guest", 0, false, 200},
		{"/rate/guest", 0, false, 429},
		{"/rate/guest", 0, false, 429},

		// "guest" rule with regular user (should fall back to the /rate/ rule)
		{"/rate/guest", 1, true, 200},
		{"/rate/guest", 0, true, 200},
		{"/rate/guest", 0, true, 429},
		{"/rate/guest", 0, true, 429},
	}

	for _, s := range scenarios {
		t.Run(s.url, func(t *testing.T) {
			if s.wait > 0 {
				// Scale in float space before converting: time.Duration(1.1)
				// would truncate to 1 and drop the fractional part of the wait.
				time.Sleep(time.Duration(s.wait * float64(time.Second)))
			}

			rec := httptest.NewRecorder()
			req := httptest.NewRequest("GET", s.url, nil)

			if s.authenticated {
				auth, err := app.FindAuthRecordByEmail("users", "test@example.com")
				if err != nil {
					t.Fatal(err)
				}

				token, err := auth.NewAuthToken()
				if err != nil {
					t.Fatal(err)
				}

				req.Header.Add("Authorization", token)
			}

			mux.ServeHTTP(rec, req)

			result := rec.Result()
			defer result.Body.Close()

			if result.StatusCode != s.expectedStatus {
				t.Fatalf("Expected response status %d, got %d", s.expectedStatus, result.StatusCode)
			}
		})
	}
}