mirror of
https://github.com/pocketbase/pocketbase.git
synced 2025-03-19 06:07:48 +02:00
160 lines
3.6 KiB
Go
160 lines
3.6 KiB
Go
package apis_test
|
|
|
|
import (
|
|
"net/http/httptest"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/pocketbase/pocketbase/apis"
|
|
"github.com/pocketbase/pocketbase/core"
|
|
"github.com/pocketbase/pocketbase/tests"
|
|
)
|
|
|
|
func TestDefaultRateLimitMiddleware(t *testing.T) {
|
|
app, _ := tests.NewTestApp()
|
|
defer app.Cleanup()
|
|
|
|
app.Settings().RateLimits.Enabled = true
|
|
app.Settings().RateLimits.Rules = []core.RateLimitRule{
|
|
{
|
|
Label: "/rate/",
|
|
MaxRequests: 2,
|
|
Duration: 1,
|
|
},
|
|
{
|
|
Label: "/rate/b",
|
|
MaxRequests: 3,
|
|
Duration: 1,
|
|
},
|
|
{
|
|
Label: "POST /rate/b",
|
|
MaxRequests: 1,
|
|
Duration: 1,
|
|
},
|
|
{
|
|
Label: "/rate/guest",
|
|
MaxRequests: 1,
|
|
Duration: 1,
|
|
Audience: core.RateLimitRuleAudienceGuest,
|
|
},
|
|
{
|
|
Label: "/rate/auth",
|
|
MaxRequests: 1,
|
|
Duration: 1,
|
|
Audience: core.RateLimitRuleAudienceAuth,
|
|
},
|
|
}
|
|
|
|
pbRouter, err := apis.NewRouter(app)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
pbRouter.GET("/norate", func(e *core.RequestEvent) error {
|
|
return e.String(200, "norate")
|
|
}).BindFunc(func(e *core.RequestEvent) error {
|
|
return e.Next()
|
|
})
|
|
pbRouter.GET("/rate/a", func(e *core.RequestEvent) error {
|
|
return e.String(200, "a")
|
|
})
|
|
pbRouter.GET("/rate/b", func(e *core.RequestEvent) error {
|
|
return e.String(200, "b")
|
|
})
|
|
pbRouter.GET("/rate/guest", func(e *core.RequestEvent) error {
|
|
return e.String(200, "guest")
|
|
})
|
|
pbRouter.GET("/rate/auth", func(e *core.RequestEvent) error {
|
|
return e.String(200, "auth")
|
|
})
|
|
|
|
mux, err := pbRouter.BuildMux()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
scenarios := []struct {
|
|
url string
|
|
wait float64
|
|
authenticated bool
|
|
expectedStatus int
|
|
}{
|
|
{"/norate", 0, false, 200},
|
|
{"/norate", 0, false, 200},
|
|
{"/norate", 0, false, 200},
|
|
{"/norate", 0, false, 200},
|
|
{"/norate", 0, false, 200},
|
|
|
|
{"/rate/a", 0, false, 200},
|
|
{"/rate/a", 0, false, 200},
|
|
{"/rate/a", 0, false, 429},
|
|
{"/rate/a", 0, false, 429},
|
|
{"/rate/a", 1.1, false, 200},
|
|
{"/rate/a", 0, false, 200},
|
|
{"/rate/a", 0, false, 429},
|
|
|
|
{"/rate/b", 0, false, 200},
|
|
{"/rate/b", 0, false, 200},
|
|
{"/rate/b", 0, false, 200},
|
|
{"/rate/b", 0, false, 429},
|
|
{"/rate/b", 1.1, false, 200},
|
|
{"/rate/b", 0, false, 200},
|
|
{"/rate/b", 0, false, 200},
|
|
{"/rate/b", 0, false, 429},
|
|
|
|
// "auth" with guest (should fallback to the /rate/ rule)
|
|
{"/rate/auth", 0, false, 200},
|
|
{"/rate/auth", 0, false, 200},
|
|
{"/rate/auth", 0, false, 429},
|
|
{"/rate/auth", 0, false, 429},
|
|
|
|
// "auth" rule with regular user (should match the /rate/auth rule)
|
|
{"/rate/auth", 0, true, 200},
|
|
{"/rate/auth", 0, true, 429},
|
|
{"/rate/auth", 0, true, 429},
|
|
|
|
// "guest" with guest (should match the /rate/guest rule)
|
|
{"/rate/guest", 0, false, 200},
|
|
{"/rate/guest", 0, false, 429},
|
|
{"/rate/guest", 0, false, 429},
|
|
|
|
// "guest" rule with regular user (should fallback to the /rate/ rule)
|
|
{"/rate/guest", 1, true, 200},
|
|
{"/rate/guest", 0, true, 200},
|
|
{"/rate/guest", 0, true, 429},
|
|
{"/rate/guest", 0, true, 429},
|
|
}
|
|
|
|
for _, s := range scenarios {
|
|
t.Run(s.url, func(t *testing.T) {
|
|
if s.wait > 0 {
|
|
time.Sleep(time.Duration(s.wait) * time.Second)
|
|
}
|
|
|
|
rec := httptest.NewRecorder()
|
|
req := httptest.NewRequest("GET", s.url, nil)
|
|
|
|
if s.authenticated {
|
|
auth, err := app.FindAuthRecordByEmail("users", "test@example.com")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
token, err := auth.NewAuthToken()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
req.Header.Add("Authorization", token)
|
|
}
|
|
|
|
mux.ServeHTTP(rec, req)
|
|
|
|
result := rec.Result()
|
|
|
|
if result.StatusCode != s.expectedStatus {
|
|
t.Fatalf("Expected response status %d, got %d", s.expectedStatus, result.StatusCode)
|
|
}
|
|
})
|
|
}
|
|
}
|