1
0
mirror of https://github.com/labstack/echo.git synced 2026-05-16 09:48:24 +02:00
Files
echo/middleware/rate_limiter_test.go
T
toimtoimtoim f071367e3c V5 changes
2026-01-18 18:14:41 +02:00

649 lines
18 KiB
Go

// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors
package middleware
import (
"errors"
"math/rand"
"net/http"
"net/http/httptest"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/labstack/echo/v5"
"github.com/stretchr/testify/assert"
"golang.org/x/time/rate"
)
func TestRateLimiter(t *testing.T) {
e := echo.New()
handler := func(c *echo.Context) error {
return c.String(http.StatusOK, "test")
}
var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3})
mw := RateLimiterWithConfig(RateLimiterConfig{Store: inMemoryStore})
testCases := []struct {
id string
expectErr string
}{
{id: "127.0.0.1"},
{id: "127.0.0.1"},
{id: "127.0.0.1"},
{id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
{id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
{id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
{id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
}
for _, tc := range testCases {
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Add(echo.HeaderXRealIP, tc.id)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := mw(handler)(c)
if tc.expectErr != "" {
assert.EqualError(t, err, tc.expectErr)
} else {
assert.NoError(t, err)
}
assert.Equal(t, http.StatusOK, rec.Code)
}
}
// TestMustRateLimiterWithConfig_panicBehaviour checks that
// RateLimiterWithConfig panics when the config carries no Store and does not
// panic once a Store is supplied.
func TestMustRateLimiterWithConfig_panicBehaviour(t *testing.T) {
	store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3})

	withoutStore := func() {
		RateLimiterWithConfig(RateLimiterConfig{})
	}
	withStore := func() {
		RateLimiterWithConfig(RateLimiterConfig{Store: store})
	}

	assert.Panics(t, withoutStore)
	assert.NotPanics(t, withStore)
}
func TestRateLimiterWithConfig(t *testing.T) {
var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3})
e := echo.New()
handler := func(c *echo.Context) error {
return c.String(http.StatusOK, "test")
}
mw, err := RateLimiterConfig{
IdentifierExtractor: func(c *echo.Context) (string, error) {
id := c.Request().Header.Get(echo.HeaderXRealIP)
if id == "" {
return "", errors.New("invalid identifier")
}
return id, nil
},
DenyHandler: func(ctx *echo.Context, identifier string, err error) error {
return ctx.JSON(http.StatusForbidden, nil)
},
ErrorHandler: func(ctx *echo.Context, err error) error {
return ctx.JSON(http.StatusBadRequest, nil)
},
Store: inMemoryStore,
}.ToMiddleware()
assert.NoError(t, err)
testCases := []struct {
id string
code int
}{
{"127.0.0.1", http.StatusOK},
{"127.0.0.1", http.StatusOK},
{"127.0.0.1", http.StatusOK},
{"127.0.0.1", http.StatusForbidden},
{"", http.StatusBadRequest},
{"127.0.0.1", http.StatusForbidden},
{"127.0.0.1", http.StatusForbidden},
}
for _, tc := range testCases {
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Add(echo.HeaderXRealIP, tc.id)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := mw(handler)(c)
assert.NoError(t, err)
assert.Equal(t, tc.code, rec.Code)
}
}
func TestRateLimiterWithConfig_defaultDenyHandler(t *testing.T) {
var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3})
e := echo.New()
handler := func(c *echo.Context) error {
return c.String(http.StatusOK, "test")
}
mw, err := RateLimiterConfig{
IdentifierExtractor: func(c *echo.Context) (string, error) {
id := c.Request().Header.Get(echo.HeaderXRealIP)
if id == "" {
return "", errors.New("invalid identifier")
}
return id, nil
},
Store: inMemoryStore,
}.ToMiddleware()
assert.NoError(t, err)
testCases := []struct {
id string
expectErr string
}{
{id: "127.0.0.1"},
{id: "127.0.0.1"},
{id: "127.0.0.1"},
{id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
{expectErr: "code=403, message=error while extracting identifier, err=invalid identifier"},
{id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
{id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"},
}
for _, tc := range testCases {
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Add(echo.HeaderXRealIP, tc.id)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
err := mw(handler)(c)
if tc.expectErr != "" {
assert.EqualError(t, err, tc.expectErr)
} else {
assert.NoError(t, err)
}
assert.Equal(t, http.StatusOK, rec.Code)
}
}
// TestRateLimiterWithConfig_defaultConfig builds the middleware from a config
// that only sets Store (Rate: 1, Burst: 3). Seven requests carrying the same
// X-Real-IP header are sent: the first three must pass and the last four must
// return the 429 error.
func TestRateLimiterWithConfig_defaultConfig(t *testing.T) {
	const tooMany = "code=429, message=rate limit exceeded"

	store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3})
	mw, err := RateLimiterConfig{Store: store}.ToMiddleware()
	assert.NoError(t, err)

	e := echo.New()
	next := func(c *echo.Context) error {
		return c.String(http.StatusOK, "test")
	}

	// Expected error per request, in order; "" means the request must pass.
	wantErrs := []string{"", "", "", tooMany, tooMany, tooMany, tooMany}

	for _, wantErr := range wantErrs {
		req := httptest.NewRequest(http.MethodGet, "/", nil)
		req.Header.Add(echo.HeaderXRealIP, "127.0.0.1")
		rec := httptest.NewRecorder()

		callErr := mw(next)(e.NewContext(req, rec))

		if wantErr == "" {
			assert.NoError(t, callErr)
		} else {
			assert.EqualError(t, callErr, wantErr)
		}
		assert.Equal(t, http.StatusOK, rec.Code)
	}
}
func TestRateLimiterWithConfig_skipper(t *testing.T) {
e := echo.New()
var beforeFuncRan bool
handler := func(c *echo.Context) error {
return c.String(http.StatusOK, "test")
}
var inMemoryStore = NewRateLimiterMemoryStore(5)
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Add(echo.HeaderXRealIP, "127.0.0.1")
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
mw, err := RateLimiterConfig{
Skipper: func(c *echo.Context) bool {
return true
},
BeforeFunc: func(c *echo.Context) {
beforeFuncRan = true
},
Store: inMemoryStore,
IdentifierExtractor: func(ctx *echo.Context) (string, error) {
return "127.0.0.1", nil
},
}.ToMiddleware()
assert.NoError(t, err)
err = mw(handler)(c)
assert.NoError(t, err)
assert.Equal(t, false, beforeFuncRan)
}
// TestRateLimiterWithConfig_skipperNoSkip checks that when Skipper returns
// false the middleware runs normally: BeforeFunc is invoked and the single
// request goes through without an error.
func TestRateLimiterWithConfig_skipperNoSkip(t *testing.T) {
	e := echo.New()

	var beforeFuncRan bool

	handler := func(c *echo.Context) error {
		return c.String(http.StatusOK, "test")
	}

	var inMemoryStore = NewRateLimiterMemoryStore(5)

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Add(echo.HeaderXRealIP, "127.0.0.1")

	rec := httptest.NewRecorder()

	c := e.NewContext(req, rec)

	mw, err := RateLimiterConfig{
		Skipper: func(c *echo.Context) bool {
			return false
		},
		BeforeFunc: func(c *echo.Context) {
			beforeFuncRan = true
		},
		Store: inMemoryStore,
		IdentifierExtractor: func(ctx *echo.Context) (string, error) {
			return "127.0.0.1", nil
		},
	}.ToMiddleware()
	assert.NoError(t, err)

	// The result used to be discarded; assert it, as the sibling skipper and
	// beforeFunc tests do, so an unexpected rejection is not masked.
	err = mw(handler)(c)
	assert.NoError(t, err)

	assert.Equal(t, true, beforeFuncRan)
}
func TestRateLimiterWithConfig_beforeFunc(t *testing.T) {
e := echo.New()
handler := func(c *echo.Context) error {
return c.String(http.StatusOK, "test")
}
var beforeRan bool
var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3})
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Add(echo.HeaderXRealIP, "127.0.0.1")
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
mw, err := RateLimiterConfig{
BeforeFunc: func(c *echo.Context) {
beforeRan = true
},
Store: inMemoryStore,
IdentifierExtractor: func(ctx *echo.Context) (string, error) {
return "127.0.0.1", nil
},
}.ToMiddleware()
assert.NoError(t, err)
err = mw(handler)(c)
assert.NoError(t, err)
assert.Equal(t, true, beforeRan)
}
// TestRateLimiterMemoryStore_Allow replays a fixed sequence of Allow calls
// against a store with Rate: 1 and Burst: 3, advancing an injected clock by
// 220 ms per call, and checks the allow/deny decision for every call.
func TestRateLimiterMemoryStore_Allow(t *testing.T) {
	var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3, ExpiresIn: 2 * time.Second})

	// Simulated time between two consecutive calls.
	const step = 220 * time.Millisecond
	start := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC)

	testCases := []struct {
		id      string
		allowed bool
	}{
		{"127.0.0.1", true},  // 0 ms burst #1
		{"127.0.0.1", true},  // 220 ms burst #2
		{"127.0.0.1", true},  // 440 ms burst #3
		{"127.0.0.1", false}, // 660 ms block
		{"127.0.0.1", false}, // 880 ms block
		{"127.0.0.1", true},  // 1100 ms next second #1
		{"127.0.0.2", true},  // 1320 ms allow other ip
		{"127.0.0.1", false}, // 1540 ms block
		{"127.0.0.1", false}, // 1760 ms block
		{"127.0.0.1", false}, // 1980 ms block
		{"127.0.0.1", true},  // 2200 ms one token refilled
		{"127.0.0.1", false}, // 2420 ms block
		{"127.0.0.1", false}, // 2640 ms block
		{"127.0.0.1", false}, // 2860 ms block
		{"127.0.0.1", true},  // 3080 ms one token refilled
		{"127.0.0.1", false}, // 3300 ms block
		{"127.0.0.1", false}, // 3520 ms block
		{"127.0.0.1", false}, // 3740 ms block
		{"127.0.0.1", false}, // 3960 ms block
		{"127.0.0.1", true},  // 4180 ms one token refilled
		{"127.0.0.1", false}, // 4400 ms block
		{"127.0.0.1", false}, // 4620 ms block
		{"127.0.0.1", false}, // 4840 ms block
		{"127.0.0.1", true},  // 5060 ms one token refilled
	}

	for i, tc := range testCases {
		offset := time.Duration(i) * step
		t.Logf("Running testcase #%d => %v", i, offset)

		// Freeze the store's clock at this step's instant.
		now := start.Add(offset)
		inMemoryStore.timeNow = func() time.Time {
			return now
		}

		allowed, err := inMemoryStore.Allow(tc.id)
		assert.NoError(t, err, "testcase #%d", i)
		assert.Equal(t, tc.allowed, allowed, "testcase #%d (%s at %v)", i, tc.id, offset)
	}
}
// TestRateLimiterMemoryStore_cleanupStaleVisitors seeds the store with four
// visitors last seen now, 1, 5 and 10 minutes ago, touches "D" through Allow,
// runs cleanupStaleVisitors and expects only "C" to have been removed.
func TestRateLimiterMemoryStore_cleanupStaleVisitors(t *testing.T) {
	store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3})

	// seenAgo builds a visitor whose lastSeen lies `age` in the past.
	seenAgo := func(age time.Duration) *Visitor {
		return &Visitor{
			Limiter:  rate.NewLimiter(1, 3),
			lastSeen: time.Now().Add(-age),
		}
	}
	store.visitors = map[string]*Visitor{
		"A": seenAgo(0),
		"B": seenAgo(1 * time.Minute),
		"C": seenAgo(5 * time.Minute),
		"D": seenAgo(10 * time.Minute),
	}

	// "D" is the oldest entry but is used again right before the cleanup.
	store.Allow("D")
	store.cleanupStaleVisitors(time.Now())

	expectations := []struct {
		key     string
		present bool
	}{
		{key: "A", present: true},
		{key: "B", present: true},
		{key: "C", present: false},
		{key: "D", present: true},
	}
	for _, want := range expectations {
		_, exists := store.visitors[want.key]
		assert.Equal(t, want.present, exists)
	}
}
func TestNewRateLimiterMemoryStore(t *testing.T) {
testCases := []struct {
rate float64
burst int
expiresIn time.Duration
expectedExpiresIn time.Duration
}{
{1, 3, 5 * time.Second, 5 * time.Second},
{2, 4, 0, 3 * time.Minute},
{1, 5, 10 * time.Minute, 10 * time.Minute},
{3, 7, 0, 3 * time.Minute},
}
for _, tc := range testCases {
store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: tc.rate, Burst: tc.burst, ExpiresIn: tc.expiresIn})
assert.Equal(t, tc.rate, store.rate)
assert.Equal(t, tc.burst, store.burst)
assert.Equal(t, tc.expectedExpiresIn, store.expiresIn)
}
}
// TestRateLimiterMemoryStore_FractionalRateDefaultBurst configures a rate of
// 0.5 with no explicit burst and expects one request to pass immediately, a
// second one at the same instant to be denied, and a third one two seconds
// later to pass again.
func TestRateLimiterMemoryStore_FractionalRateDefaultBurst(t *testing.T) {
	start := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC)

	// clock is what the store sees as "now"; the test moves it by hand.
	clock := start

	store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{
		Rate: 0.5, // fractional rate should get a burst of at least 1
	})
	store.timeNow = func() time.Time {
		return clock
	}

	first, err := store.Allow("user")
	assert.NoError(t, err)
	assert.True(t, first, "first request should not be blocked")

	second, err := store.Allow("user")
	assert.NoError(t, err)
	assert.False(t, second, "burst token should be consumed immediately")

	clock = start.Add(2 * time.Second)

	third, err := store.Allow("user")
	assert.NoError(t, err)
	assert.True(t, third, "token should refill for fractional rate after time passes")
}
// generateAddressList returns count strings, each produced by randomString(15),
// for use as identifiers in the store benchmarks.
func generateAddressList(count int) []string {
	list := make([]string, 0, count)
	for len(list) < count {
		list = append(list, randomString(15))
	}
	return list
}
// run is one benchmark worker: it calls store.Allow b.N times, each time with
// an identifier picked at random from the first max entries of addrs, and
// marks wg as done when it finishes.
func run(wg *sync.WaitGroup, store RateLimiterStore, addrs []string, max int, b *testing.B) {
	defer wg.Done()

	for remaining := b.N; remaining > 0; remaining-- {
		addr := addrs[rand.Intn(max)]
		// Only the cost of the call matters here; the decision is not checked.
		_, _ = store.Allow(addr)
	}
}
// benchmarkStore runs `parallel` goroutines against store, each issuing b.N
// Allow calls with identifiers picked at random from `max` generated
// addresses, and returns once every goroutine has finished.
//
// Generating the address list is setup rather than the work being measured,
// so the benchmark timer is reset right before the goroutines are started.
func benchmarkStore(store RateLimiterStore, parallel int, max int, b *testing.B) {
	addrs := generateAddressList(max)

	// Exclude store construction and address generation from the measurement.
	b.ResetTimer()

	var wg sync.WaitGroup
	for i := 0; i < parallel; i++ {
		wg.Add(1)
		go run(&wg, store, addrs, max, b)
	}
	wg.Wait()
}
// testExpiresIn is the ExpiresIn value used by the store benchmarks.
const testExpiresIn = time.Second
func BenchmarkRateLimiterMemoryStore_1000(b *testing.B) {
var store = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 100, Burst: 200, ExpiresIn: testExpiresIn})
benchmarkStore(store, 10, 1000, b)
}
func BenchmarkRateLimiterMemoryStore_10000(b *testing.B) {
var store = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 100, Burst: 200, ExpiresIn: testExpiresIn})
benchmarkStore(store, 10, 10000, b)
}
func BenchmarkRateLimiterMemoryStore_100000(b *testing.B) {
var store = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 100, Burst: 200, ExpiresIn: testExpiresIn})
benchmarkStore(store, 10, 100000, b)
}
func BenchmarkRateLimiterMemoryStore_conc100_10000(b *testing.B) {
var store = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 100, Burst: 200, ExpiresIn: testExpiresIn})
benchmarkStore(store, 100, 10000, b)
}
// TestRateLimiterMemoryStore_TOCTOUFix verifies that the TOCTOU race condition is fixed
// by ensuring timeNow() is only called once per Allow() call
func TestRateLimiterMemoryStore_TOCTOUFix(t *testing.T) {
t.Parallel()
store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{
Rate: 1,
Burst: 1,
ExpiresIn: 2 * time.Second,
})
// Track time calls to verify we use the same time value
timeCallCount := 0
baseTime := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC)
store.timeNow = func() time.Time {
timeCallCount++
return baseTime
}
// First request - should succeed
allowed, err := store.Allow("127.0.0.1")
assert.NoError(t, err)
assert.True(t, allowed, "First request should be allowed")
// Verify timeNow() was only called once
assert.Equal(t, 1, timeCallCount, "timeNow() should only be called once per Allow()")
}
// TestRateLimiterMemoryStore_ConcurrentAccess hits one identifier from 50
// goroutines (20 requests each, 1 ms apart) and checks that every request is
// counted as either allowed or denied, with at least one of each.
func TestRateLimiterMemoryStore_ConcurrentAccess(t *testing.T) {
	t.Parallel()

	const (
		goroutines           = 50
		requestsPerGoroutine = 20
	)

	store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{
		Rate:      10,
		Burst:     5,
		ExpiresIn: 5 * time.Second,
	})

	var (
		wg                        sync.WaitGroup
		allowedCount, deniedCount int32
	)

	// worker issues its share of requests and tallies each decision atomically.
	worker := func() {
		defer wg.Done()
		for j := 0; j < requestsPerGoroutine; j++ {
			ok, err := store.Allow("test-user")
			assert.NoError(t, err)

			counter := &deniedCount
			if ok {
				counter = &allowedCount
			}
			atomic.AddInt32(counter, 1)

			time.Sleep(time.Millisecond)
		}
	}

	wg.Add(goroutines)
	for i := 0; i < goroutines; i++ {
		go worker()
	}
	wg.Wait()

	allowed := int(atomic.LoadInt32(&allowedCount))
	denied := int(atomic.LoadInt32(&deniedCount))

	assert.Equal(t, goroutines*requestsPerGoroutine, allowed+denied, "All requests should be processed")
	assert.Greater(t, denied, 0, "Some requests should be denied due to rate limiting")
	assert.Greater(t, allowed, 0, "Some requests should be allowed")
}
// TestRateLimiterMemoryStore_RaceDetection calls Allow from 100 goroutines,
// 100 times each, spread over five identifiers. It only asserts that no call
// returns an error; its purpose is to give the race detector something to
// observe.
// Run with: go test -race ./middleware -run TestRateLimiterMemoryStore_RaceDetection
func TestRateLimiterMemoryStore_RaceDetection(t *testing.T) {
	t.Parallel()

	const (
		goroutines           = 100
		requestsPerGoroutine = 100
	)

	store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{
		Rate:      100,
		Burst:     200,
		ExpiresIn: 1 * time.Second,
	})
	identifiers := []string{"user1", "user2", "user3", "user4", "user5"}

	var wg sync.WaitGroup
	wg.Add(goroutines)
	for i := 0; i < goroutines; i++ {
		// Each goroutine sticks to one identifier, assigned round-robin.
		go func(identifier string) {
			defer wg.Done()
			for j := 0; j < requestsPerGoroutine; j++ {
				_, err := store.Allow(identifier)
				assert.NoError(t, err)
			}
		}(identifiers[i%len(identifiers)])
	}
	wg.Wait()
}
// TestRateLimiterMemoryStore_TimeOrdering verifies time ordering consistency in
// rate limiting decisions: with Rate: 1 and Burst: 2 on an injected clock, two
// requests pass, a third at the same instant is denied, and a fourth passes
// after the clock has advanced by one second. Every Allow call is also
// required to return no error.
func TestRateLimiterMemoryStore_TimeOrdering(t *testing.T) {
	t.Parallel()

	store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{
		Rate:      1,
		Burst:     2,
		ExpiresIn: 5 * time.Second,
	})

	// The closure reads currentTime on every call, so reassigning the variable
	// below moves the store's clock.
	currentTime := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC)
	store.timeNow = func() time.Time {
		return currentTime
	}

	// First two requests should succeed (burst=2)
	allowed1, err := store.Allow("user1")
	assert.NoError(t, err)
	assert.True(t, allowed1, "Request 1 should be allowed (burst)")

	allowed2, err := store.Allow("user1")
	assert.NoError(t, err)
	assert.True(t, allowed2, "Request 2 should be allowed (burst)")

	// Third request should be denied
	allowed3, err := store.Allow("user1")
	assert.NoError(t, err)
	assert.False(t, allowed3, "Request 3 should be denied (burst exhausted)")

	// Advance time by 1 second
	currentTime = currentTime.Add(1 * time.Second)

	// Fourth request should succeed
	allowed4, err := store.Allow("user1")
	assert.NoError(t, err)
	assert.True(t, allowed4, "Request 4 should be allowed (1 token available)")
}