package apis_test

import (
	"net/http/httptest"
	"testing"
	"time"

	"github.com/pocketbase/pocketbase/apis"
	"github.com/pocketbase/pocketbase/core"
	"github.com/pocketbase/pocketbase/tests"
)

// TestDefaultRateLimitMiddleware checks the default rate limit middleware
// installed by apis.NewRouter against a set of label/audience specific rules.
//
// The scenarios are executed sequentially against the same mux, so the
// limiter state carries over from one scenario to the next. The order of the
// entries (and their wait times) is therefore significant.
func TestDefaultRateLimitMiddleware(t *testing.T) {
	app, err := tests.NewTestApp()
	if err != nil {
		t.Fatal(err)
	}
	defer app.Cleanup()

	app.Settings().RateLimits.Enabled = true
	app.Settings().RateLimits.Rules = []core.RateLimitRule{
		{
			Label:       "/rate/",
			MaxRequests: 2,
			Duration:    1,
		},
		{
			Label:       "/rate/b",
			MaxRequests: 3,
			Duration:    1,
		},
		{
			Label:       "POST /rate/b",
			MaxRequests: 1,
			Duration:    1,
		},
		{
			Label:       "/rate/guest",
			MaxRequests: 1,
			Duration:    1,
			Audience:    core.RateLimitRuleAudienceGuest,
		},
		{
			Label:       "/rate/auth",
			MaxRequests: 1,
			Duration:    1,
			Audience:    core.RateLimitRuleAudienceAuth,
		},
	}

	pbRouter, err := apis.NewRouter(app)
	if err != nil {
		t.Fatal(err)
	}
	pbRouter.GET("/norate", func(e *core.RequestEvent) error {
		return e.String(200, "norate")
	}).BindFunc(func(e *core.RequestEvent) error {
		return e.Next()
	})
	pbRouter.GET("/rate/a", func(e *core.RequestEvent) error {
		return e.String(200, "a")
	})
	pbRouter.GET("/rate/b", func(e *core.RequestEvent) error {
		return e.String(200, "b")
	})
	pbRouter.GET("/rate/guest", func(e *core.RequestEvent) error {
		return e.String(200, "guest")
	})
	pbRouter.GET("/rate/auth", func(e *core.RequestEvent) error {
		return e.String(200, "auth")
	})

	mux, err := pbRouter.BuildMux()
	if err != nil {
		t.Fatal(err)
	}

	scenarios := []struct {
		url            string
		wait           float64 // seconds to sleep before sending the request
		authenticated  bool
		expectedStatus int
	}{
		{"/norate", 0, false, 200},
		{"/norate", 0, false, 200},
		{"/norate", 0, false, 200},
		{"/norate", 0, false, 200},
		{"/norate", 0, false, 200},

		{"/rate/a", 0, false, 200},
		{"/rate/a", 0, false, 200},
		{"/rate/a", 0, false, 429},
		{"/rate/a", 0, false, 429},
		{"/rate/a", 1.1, false, 200},
		{"/rate/a", 0, false, 200},
		{"/rate/a", 0, false, 429},

		{"/rate/b", 0, false, 200},
		{"/rate/b", 0, false, 200},
		{"/rate/b", 0, false, 200},
		{"/rate/b", 0, false, 429},
		{"/rate/b", 1.1, false, 200},
		{"/rate/b", 0, false, 200},
		{"/rate/b", 0, false, 200},
		{"/rate/b", 0, false, 429},

		// "auth" with guest (should fallback to the /rate/ rule)
		{"/rate/auth", 0, false, 200},
		{"/rate/auth", 0, false, 200},
		{"/rate/auth", 0, false, 429},
		{"/rate/auth", 0, false, 429},

		// "auth" rule with regular user (should match the /rate/auth rule)
		{"/rate/auth", 0, true, 200},
		{"/rate/auth", 0, true, 429},
		{"/rate/auth", 0, true, 429},

		// "guest" with guest (should match the /rate/guest rule)
		{"/rate/guest", 0, false, 200},
		{"/rate/guest", 0, false, 429},
		{"/rate/guest", 0, false, 429},

		// "guest" rule with regular user (should fallback to the /rate/ rule)
		{"/rate/guest", 1, true, 200},
		{"/rate/guest", 0, true, 200},
		{"/rate/guest", 0, true, 429},
		{"/rate/guest", 0, true, 429},
	}

	for _, s := range scenarios {
		t.Run(s.url, func(t *testing.T) {
			if s.wait > 0 {
				// Scale before converting: time.Duration(s.wait) would truncate
				// fractional seconds (e.g. 1.1 -> 1), leaving no margin past
				// the 1s rate limit window.
				time.Sleep(time.Duration(s.wait * float64(time.Second)))
			}

			rec := httptest.NewRecorder()
			req := httptest.NewRequest("GET", s.url, nil)

			if s.authenticated {
				auth, err := app.FindAuthRecordByEmail("users", "test@example.com")
				if err != nil {
					t.Fatal(err)
				}

				token, err := auth.NewAuthToken()
				if err != nil {
					t.Fatal(err)
				}

				req.Header.Add("Authorization", token)
			}

			mux.ServeHTTP(rec, req)

			result := rec.Result()
			defer result.Body.Close()

			if result.StatusCode != s.expectedStatus {
				t.Fatalf("Expected response status %d, got %d", s.expectedStatus, result.StatusCode)
			}
		})
	}
}