mirror of
https://github.com/labstack/echo.git
synced 2024-11-24 08:22:21 +02:00
adds middleware for rate limiting (#1724)
* adds middleware for rate limiting * added comment for InMemoryStore ShouldAllow * removed redundant mutex declaration * fixed lint issues * removed sleep from tests * improved coverage * refactor: renames Identifiers, includes default SourceFunc * Added last seen stats for visitor * uses http Constants for improved readdability adds default error handler * used other handler apart from default handler to mark custom error handler for rate limiting * split tests into separate blocks added an error pair to IdentifierExtractor Includes deny handler for explicitly denying requests * adds comments for exported members Extractor and ErrorHandler * makes cleanup implementation inhouse * Avoid race for cleanup due to non-atomic access to store.expiresIn * Use a dedicated producer for rate testing * tidy commit * refactors tests, implicitly tests lastSeen property on visitor switches NewRateLimiterMemoryStore constructor to Referential Functions style (Advised by @pafuent) * switches to mock of time module for time based tests tests are now fully deterministic * improved coverage * replaces Rob Pike referential options with more conventional struct configs makes cleanup asynchronous * blocks racy access to lastCleanup * Add benchmark tests for rate limiter * Add rate limiter with sharded memory store * Racy access to store.lastCleanup eliminated Merges in shiny sharded map implementation by @lammel * Remove RateLimiterShradedMemoryStore for now * Make fields for RateLimiterStoreConfig public for external configuration * Improve docs for RateLimiter usage * Fix ErrorHandler vs. DenyHandler usage for rate limiter * Simplify NewRateLimiterMemoryStore * improved coverage * updated errorHandler and denyHandler to use echo.HTTPError * Improve wording for error and comments * Remove duplicate lastSeen marking for Allow * Improve wording for comments * Add disclaimer on perf characteristics of memory store * changes Allow signature on rate limiter to return err too Co-authored-by: Roland Lammel <rl@neotel.at>
This commit is contained in:
parent
9b0e63046b
commit
7c8592a7e0
1
.gitignore
vendored
1
.gitignore
vendored
@ -5,3 +5,4 @@ vendor
|
|||||||
.idea
|
.idea
|
||||||
*.iml
|
*.iml
|
||||||
*.out
|
*.out
|
||||||
|
.vscode
|
||||||
|
1
go.mod
1
go.mod
@ -12,4 +12,5 @@ require (
|
|||||||
golang.org/x/net v0.0.0-20200822124328-c89045814202
|
golang.org/x/net v0.0.0-20200822124328-c89045814202
|
||||||
golang.org/x/sys v0.0.0-20200826173525-f9321e4c35a6 // indirect
|
golang.org/x/sys v0.0.0-20200826173525-f9321e4c35a6 // indirect
|
||||||
golang.org/x/text v0.3.3 // indirect
|
golang.org/x/text v0.3.3 // indirect
|
||||||
|
golang.org/x/time v0.0.0-20201208040808-7e3f01d25324
|
||||||
)
|
)
|
||||||
|
2
go.sum
2
go.sum
@ -46,6 +46,8 @@ golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg=
|
|||||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k=
|
golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k=
|
||||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||||
|
golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 h1:Hir2P/De0WpUhtrKGGjvSb2YxUgyZ7EFOSLIcSSpiwE=
|
||||||
|
golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
268
middleware/rate_limiter.go
Normal file
268
middleware/rate_limiter.go
Normal file
@ -0,0 +1,268 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/labstack/echo/v4"
|
||||||
|
"golang.org/x/time/rate"
|
||||||
|
)
|
||||||
|
|
||||||
|
type (
|
||||||
|
// RateLimiterStore is the interface to be implemented by custom stores.
|
||||||
|
RateLimiterStore interface {
|
||||||
|
// Stores for the rate limiter have to implement the Allow method
|
||||||
|
Allow(identifier string) (bool, error)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
type (
|
||||||
|
// RateLimiterConfig defines the configuration for the rate limiter
|
||||||
|
RateLimiterConfig struct {
|
||||||
|
Skipper Skipper
|
||||||
|
BeforeFunc BeforeFunc
|
||||||
|
// IdentifierExtractor uses echo.Context to extract the identifier for a visitor
|
||||||
|
IdentifierExtractor Extractor
|
||||||
|
// Store defines a store for the rate limiter
|
||||||
|
Store RateLimiterStore
|
||||||
|
// ErrorHandler provides a handler to be called when IdentifierExtractor returns an error
|
||||||
|
ErrorHandler func(context echo.Context, err error) error
|
||||||
|
// DenyHandler provides a handler to be called when RateLimiter denies access
|
||||||
|
DenyHandler func(context echo.Context, identifier string, err error) error
|
||||||
|
}
|
||||||
|
// Extractor is used to extract data from echo.Context
|
||||||
|
Extractor func(context echo.Context) (string, error)
|
||||||
|
)
|
||||||
|
|
||||||
|
// errors
|
||||||
|
var (
|
||||||
|
// ErrRateLimitExceeded denotes an error raised when rate limit is exceeded
|
||||||
|
ErrRateLimitExceeded = echo.NewHTTPError(http.StatusTooManyRequests, "rate limit exceeded")
|
||||||
|
// ErrExtractorError denotes an error raised when extractor function is unsuccessful
|
||||||
|
ErrExtractorError = echo.NewHTTPError(http.StatusForbidden, "error while extracting identifier")
|
||||||
|
)
|
||||||
|
|
||||||
|
// DefaultRateLimiterConfig defines default values for RateLimiterConfig
|
||||||
|
var DefaultRateLimiterConfig = RateLimiterConfig{
|
||||||
|
Skipper: DefaultSkipper,
|
||||||
|
IdentifierExtractor: func(ctx echo.Context) (string, error) {
|
||||||
|
id := ctx.RealIP()
|
||||||
|
return id, nil
|
||||||
|
},
|
||||||
|
ErrorHandler: func(context echo.Context, err error) error {
|
||||||
|
return &echo.HTTPError{
|
||||||
|
Code: ErrExtractorError.Code,
|
||||||
|
Message: ErrExtractorError.Message,
|
||||||
|
Internal: err,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
DenyHandler: func(context echo.Context, identifier string, err error) error {
|
||||||
|
return &echo.HTTPError{
|
||||||
|
Code: ErrRateLimitExceeded.Code,
|
||||||
|
Message: ErrRateLimitExceeded.Message,
|
||||||
|
Internal: err,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
RateLimiter returns a rate limiting middleware
|
||||||
|
|
||||||
|
e := echo.New()
|
||||||
|
|
||||||
|
limiterStore := middleware.NewRateLimiterMemoryStore(20)
|
||||||
|
|
||||||
|
e.GET("/rate-limited", func(c echo.Context) error {
|
||||||
|
return c.String(http.StatusOK, "test")
|
||||||
|
}, RateLimiter(limiterStore))
|
||||||
|
*/
|
||||||
|
func RateLimiter(store RateLimiterStore) echo.MiddlewareFunc {
|
||||||
|
config := DefaultRateLimiterConfig
|
||||||
|
config.Store = store
|
||||||
|
|
||||||
|
return RateLimiterWithConfig(config)
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
RateLimiterWithConfig returns a rate limiting middleware
|
||||||
|
|
||||||
|
e := echo.New()
|
||||||
|
|
||||||
|
config := middleware.RateLimiterConfig{
|
||||||
|
Skipper: DefaultSkipper,
|
||||||
|
Store: middleware.NewRateLimiterMemoryStore(
|
||||||
|
middleware.RateLimiterMemoryStoreConfig{Rate: 10, Burst: 30, ExpiresIn: 3 * time.Minute}
|
||||||
|
)
|
||||||
|
IdentifierExtractor: func(ctx echo.Context) (string, error) {
|
||||||
|
id := ctx.RealIP()
|
||||||
|
return id, nil
|
||||||
|
},
|
||||||
|
ErrorHandler: func(context echo.Context, err error) error {
|
||||||
|
return context.JSON(http.StatusTooManyRequests, nil)
|
||||||
|
},
|
||||||
|
DenyHandler: func(context echo.Context, identifier string) error {
|
||||||
|
return context.JSON(http.StatusForbidden, nil)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
e.GET("/rate-limited", func(c echo.Context) error {
|
||||||
|
return c.String(http.StatusOK, "test")
|
||||||
|
}, middleware.RateLimiterWithConfig(config))
|
||||||
|
*/
|
||||||
|
func RateLimiterWithConfig(config RateLimiterConfig) echo.MiddlewareFunc {
|
||||||
|
if config.Skipper == nil {
|
||||||
|
config.Skipper = DefaultRateLimiterConfig.Skipper
|
||||||
|
}
|
||||||
|
if config.IdentifierExtractor == nil {
|
||||||
|
config.IdentifierExtractor = DefaultRateLimiterConfig.IdentifierExtractor
|
||||||
|
}
|
||||||
|
if config.ErrorHandler == nil {
|
||||||
|
config.ErrorHandler = DefaultRateLimiterConfig.ErrorHandler
|
||||||
|
}
|
||||||
|
if config.DenyHandler == nil {
|
||||||
|
config.DenyHandler = DefaultRateLimiterConfig.DenyHandler
|
||||||
|
}
|
||||||
|
if config.Store == nil {
|
||||||
|
panic("Store configuration must be provided")
|
||||||
|
}
|
||||||
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||||
|
return func(c echo.Context) error {
|
||||||
|
if config.Skipper(c) {
|
||||||
|
return next(c)
|
||||||
|
}
|
||||||
|
if config.BeforeFunc != nil {
|
||||||
|
config.BeforeFunc(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
identifier, err := config.IdentifierExtractor(c)
|
||||||
|
if err != nil {
|
||||||
|
c.Error(config.ErrorHandler(c, err))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if allow, err := config.Store.Allow(identifier); !allow {
|
||||||
|
c.Error(config.DenyHandler(c, identifier, err))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return next(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type (
|
||||||
|
// RateLimiterMemoryStore is the built-in store implementation for RateLimiter
|
||||||
|
RateLimiterMemoryStore struct {
|
||||||
|
visitors map[string]*Visitor
|
||||||
|
mutex sync.Mutex
|
||||||
|
rate rate.Limit
|
||||||
|
burst int
|
||||||
|
expiresIn time.Duration
|
||||||
|
lastCleanup time.Time
|
||||||
|
}
|
||||||
|
// Visitor signifies a unique user's limiter details
|
||||||
|
Visitor struct {
|
||||||
|
*rate.Limiter
|
||||||
|
lastSeen time.Time
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
/*
|
||||||
|
NewRateLimiterMemoryStore returns an instance of RateLimiterMemoryStore with
|
||||||
|
the provided rate (as req/s). Burst and ExpiresIn will be set to default values.
|
||||||
|
|
||||||
|
Example (with 20 requests/sec):
|
||||||
|
|
||||||
|
limiterStore := middleware.NewRateLimiterMemoryStore(20)
|
||||||
|
|
||||||
|
*/
|
||||||
|
func NewRateLimiterMemoryStore(rate rate.Limit) (store *RateLimiterMemoryStore) {
|
||||||
|
return NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{
|
||||||
|
Rate: rate,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
NewRateLimiterMemoryStoreWithConfig returns an instance of RateLimiterMemoryStore
|
||||||
|
with the provided configuration. Rate must be provided. Burst will be set to the value of
|
||||||
|
the configured rate if not provided or set to 0.
|
||||||
|
|
||||||
|
The build-in memory store is usually capable for modest loads. For higher loads other
|
||||||
|
store implementations should be considered.
|
||||||
|
|
||||||
|
Characteristics:
|
||||||
|
* Concurrency above 100 parallel requests may causes measurable lock contention
|
||||||
|
* A high number of different IP addresses (above 16000) may be impacted by the internally used Go map
|
||||||
|
* A high number of requests from a single IP address may cause lock contention
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
limiterStore := middleware.NewRateLimiterMemoryStoreWithConfig(
|
||||||
|
middleware.RateLimiterMemoryStoreConfig{Rate: 50, Burst: 200, ExpiresIn: 5 * time.Minutes},
|
||||||
|
)
|
||||||
|
*/
|
||||||
|
func NewRateLimiterMemoryStoreWithConfig(config RateLimiterMemoryStoreConfig) (store *RateLimiterMemoryStore) {
|
||||||
|
store = &RateLimiterMemoryStore{}
|
||||||
|
|
||||||
|
store.rate = config.Rate
|
||||||
|
store.burst = config.Burst
|
||||||
|
store.expiresIn = config.ExpiresIn
|
||||||
|
if config.ExpiresIn == 0 {
|
||||||
|
store.expiresIn = DefaultRateLimiterMemoryStoreConfig.ExpiresIn
|
||||||
|
}
|
||||||
|
if config.Burst == 0 {
|
||||||
|
store.burst = int(config.Rate)
|
||||||
|
}
|
||||||
|
store.visitors = make(map[string]*Visitor)
|
||||||
|
store.lastCleanup = now()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// RateLimiterMemoryStoreConfig represents configuration for RateLimiterMemoryStore
|
||||||
|
type RateLimiterMemoryStoreConfig struct {
|
||||||
|
Rate rate.Limit // Rate of requests allowed to pass as req/s
|
||||||
|
Burst int // Burst additionally allows a number of requests to pass when rate limit is reached
|
||||||
|
ExpiresIn time.Duration // ExpiresIn is the duration after that a rate limiter is cleaned up
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultRateLimiterMemoryStoreConfig provides default configuration values for RateLimiterMemoryStore
|
||||||
|
var DefaultRateLimiterMemoryStoreConfig = RateLimiterMemoryStoreConfig{
|
||||||
|
ExpiresIn: 3 * time.Minute,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allow implements RateLimiterStore.Allow
|
||||||
|
func (store *RateLimiterMemoryStore) Allow(identifier string) (bool, error) {
|
||||||
|
store.mutex.Lock()
|
||||||
|
limiter, exists := store.visitors[identifier]
|
||||||
|
if !exists {
|
||||||
|
limiter = new(Visitor)
|
||||||
|
limiter.Limiter = rate.NewLimiter(store.rate, store.burst)
|
||||||
|
store.visitors[identifier] = limiter
|
||||||
|
}
|
||||||
|
limiter.lastSeen = now()
|
||||||
|
if now().Sub(store.lastCleanup) > store.expiresIn {
|
||||||
|
store.cleanupStaleVisitors()
|
||||||
|
}
|
||||||
|
store.mutex.Unlock()
|
||||||
|
return limiter.AllowN(now(), 1), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
cleanupStaleVisitors helps manage the size of the visitors map by removing stale records
|
||||||
|
of users who haven't visited again after the configured expiry time has elapsed
|
||||||
|
*/
|
||||||
|
func (store *RateLimiterMemoryStore) cleanupStaleVisitors() {
|
||||||
|
for id, visitor := range store.visitors {
|
||||||
|
if now().Sub(visitor.lastSeen) > store.expiresIn {
|
||||||
|
delete(store.visitors, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
store.lastCleanup = now()
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
actual time method which is mocked in test file
|
||||||
|
*/
|
||||||
|
var now = func() time.Time {
|
||||||
|
return time.Now()
|
||||||
|
}
|
462
middleware/rate_limiter_test.go
Normal file
462
middleware/rate_limiter_test.go
Normal file
@ -0,0 +1,462 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/labstack/echo/v4"
|
||||||
|
"github.com/labstack/gommon/random"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"golang.org/x/time/rate"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRateLimiter(t *testing.T) {
|
||||||
|
e := echo.New()
|
||||||
|
|
||||||
|
handler := func(c echo.Context) error {
|
||||||
|
return c.String(http.StatusOK, "test")
|
||||||
|
}
|
||||||
|
|
||||||
|
var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3})
|
||||||
|
|
||||||
|
mw := RateLimiter(inMemoryStore)
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
id string
|
||||||
|
code int
|
||||||
|
}{
|
||||||
|
{"127.0.0.1", http.StatusOK},
|
||||||
|
{"127.0.0.1", http.StatusOK},
|
||||||
|
{"127.0.0.1", http.StatusOK},
|
||||||
|
{"127.0.0.1", http.StatusTooManyRequests},
|
||||||
|
{"127.0.0.1", http.StatusTooManyRequests},
|
||||||
|
{"127.0.0.1", http.StatusTooManyRequests},
|
||||||
|
{"127.0.0.1", http.StatusTooManyRequests},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.Header.Add(echo.HeaderXRealIP, tc.id)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
|
||||||
|
_ = mw(handler)(c)
|
||||||
|
assert.Equal(t, tc.code, rec.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimiter_panicBehaviour(t *testing.T) {
|
||||||
|
var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3})
|
||||||
|
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
RateLimiter(nil)
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
RateLimiter(inMemoryStore)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimiterWithConfig(t *testing.T) {
|
||||||
|
var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3})
|
||||||
|
|
||||||
|
e := echo.New()
|
||||||
|
|
||||||
|
handler := func(c echo.Context) error {
|
||||||
|
return c.String(http.StatusOK, "test")
|
||||||
|
}
|
||||||
|
|
||||||
|
mw := RateLimiterWithConfig(RateLimiterConfig{
|
||||||
|
IdentifierExtractor: func(c echo.Context) (string, error) {
|
||||||
|
id := c.Request().Header.Get(echo.HeaderXRealIP)
|
||||||
|
if id == "" {
|
||||||
|
return "", errors.New("invalid identifier")
|
||||||
|
}
|
||||||
|
return id, nil
|
||||||
|
},
|
||||||
|
DenyHandler: func(ctx echo.Context, identifier string, err error) error {
|
||||||
|
return ctx.JSON(http.StatusForbidden, nil)
|
||||||
|
},
|
||||||
|
ErrorHandler: func(ctx echo.Context, err error) error {
|
||||||
|
return ctx.JSON(http.StatusBadRequest, nil)
|
||||||
|
},
|
||||||
|
Store: inMemoryStore,
|
||||||
|
})
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
id string
|
||||||
|
code int
|
||||||
|
}{
|
||||||
|
{"127.0.0.1", http.StatusOK},
|
||||||
|
{"127.0.0.1", http.StatusOK},
|
||||||
|
{"127.0.0.1", http.StatusOK},
|
||||||
|
{"127.0.0.1", http.StatusForbidden},
|
||||||
|
{"", http.StatusBadRequest},
|
||||||
|
{"127.0.0.1", http.StatusForbidden},
|
||||||
|
{"127.0.0.1", http.StatusForbidden},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.Header.Add(echo.HeaderXRealIP, tc.id)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
|
||||||
|
_ = mw(handler)(c)
|
||||||
|
|
||||||
|
assert.Equal(t, tc.code, rec.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimiterWithConfig_defaultDenyHandler(t *testing.T) {
|
||||||
|
var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3})
|
||||||
|
|
||||||
|
e := echo.New()
|
||||||
|
|
||||||
|
handler := func(c echo.Context) error {
|
||||||
|
return c.String(http.StatusOK, "test")
|
||||||
|
}
|
||||||
|
|
||||||
|
mw := RateLimiterWithConfig(RateLimiterConfig{
|
||||||
|
IdentifierExtractor: func(c echo.Context) (string, error) {
|
||||||
|
id := c.Request().Header.Get(echo.HeaderXRealIP)
|
||||||
|
if id == "" {
|
||||||
|
return "", errors.New("invalid identifier")
|
||||||
|
}
|
||||||
|
return id, nil
|
||||||
|
},
|
||||||
|
Store: inMemoryStore,
|
||||||
|
})
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
id string
|
||||||
|
code int
|
||||||
|
}{
|
||||||
|
{"127.0.0.1", http.StatusOK},
|
||||||
|
{"127.0.0.1", http.StatusOK},
|
||||||
|
{"127.0.0.1", http.StatusOK},
|
||||||
|
{"127.0.0.1", http.StatusTooManyRequests},
|
||||||
|
{"", http.StatusForbidden},
|
||||||
|
{"127.0.0.1", http.StatusTooManyRequests},
|
||||||
|
{"127.0.0.1", http.StatusTooManyRequests},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.Header.Add(echo.HeaderXRealIP, tc.id)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
|
||||||
|
_ = mw(handler)(c)
|
||||||
|
|
||||||
|
assert.Equal(t, tc.code, rec.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimiterWithConfig_defaultConfig(t *testing.T) {
|
||||||
|
{
|
||||||
|
var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3})
|
||||||
|
|
||||||
|
e := echo.New()
|
||||||
|
|
||||||
|
handler := func(c echo.Context) error {
|
||||||
|
return c.String(http.StatusOK, "test")
|
||||||
|
}
|
||||||
|
|
||||||
|
mw := RateLimiterWithConfig(RateLimiterConfig{
|
||||||
|
Store: inMemoryStore,
|
||||||
|
})
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
id string
|
||||||
|
code int
|
||||||
|
}{
|
||||||
|
{"127.0.0.1", http.StatusOK},
|
||||||
|
{"127.0.0.1", http.StatusOK},
|
||||||
|
{"127.0.0.1", http.StatusOK},
|
||||||
|
{"127.0.0.1", http.StatusTooManyRequests},
|
||||||
|
{"127.0.0.1", http.StatusTooManyRequests},
|
||||||
|
{"127.0.0.1", http.StatusTooManyRequests},
|
||||||
|
{"127.0.0.1", http.StatusTooManyRequests},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.Header.Add(echo.HeaderXRealIP, tc.id)
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
|
||||||
|
_ = mw(handler)(c)
|
||||||
|
|
||||||
|
assert.Equal(t, tc.code, rec.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimiterWithConfig_skipper(t *testing.T) {
|
||||||
|
e := echo.New()
|
||||||
|
|
||||||
|
var beforeFuncRan bool
|
||||||
|
handler := func(c echo.Context) error {
|
||||||
|
return c.String(http.StatusOK, "test")
|
||||||
|
}
|
||||||
|
var inMemoryStore = NewRateLimiterMemoryStore(5)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.Header.Add(echo.HeaderXRealIP, "127.0.0.1")
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
|
||||||
|
mw := RateLimiterWithConfig(RateLimiterConfig{
|
||||||
|
Skipper: func(c echo.Context) bool {
|
||||||
|
return true
|
||||||
|
},
|
||||||
|
BeforeFunc: func(c echo.Context) {
|
||||||
|
beforeFuncRan = true
|
||||||
|
},
|
||||||
|
Store: inMemoryStore,
|
||||||
|
IdentifierExtractor: func(ctx echo.Context) (string, error) {
|
||||||
|
return "127.0.0.1", nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
_ = mw(handler)(c)
|
||||||
|
|
||||||
|
assert.Equal(t, false, beforeFuncRan)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimiterWithConfig_skipperNoSkip(t *testing.T) {
|
||||||
|
e := echo.New()
|
||||||
|
|
||||||
|
var beforeFuncRan bool
|
||||||
|
handler := func(c echo.Context) error {
|
||||||
|
return c.String(http.StatusOK, "test")
|
||||||
|
}
|
||||||
|
var inMemoryStore = NewRateLimiterMemoryStore(5)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.Header.Add(echo.HeaderXRealIP, "127.0.0.1")
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
|
||||||
|
mw := RateLimiterWithConfig(RateLimiterConfig{
|
||||||
|
Skipper: func(c echo.Context) bool {
|
||||||
|
return false
|
||||||
|
},
|
||||||
|
BeforeFunc: func(c echo.Context) {
|
||||||
|
beforeFuncRan = true
|
||||||
|
},
|
||||||
|
Store: inMemoryStore,
|
||||||
|
IdentifierExtractor: func(ctx echo.Context) (string, error) {
|
||||||
|
return "127.0.0.1", nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
_ = mw(handler)(c)
|
||||||
|
|
||||||
|
assert.Equal(t, true, beforeFuncRan)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimiterWithConfig_beforeFunc(t *testing.T) {
|
||||||
|
e := echo.New()
|
||||||
|
|
||||||
|
handler := func(c echo.Context) error {
|
||||||
|
return c.String(http.StatusOK, "test")
|
||||||
|
}
|
||||||
|
|
||||||
|
var beforeRan bool
|
||||||
|
var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3})
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.Header.Add(echo.HeaderXRealIP, "127.0.0.1")
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
c := e.NewContext(req, rec)
|
||||||
|
|
||||||
|
mw := RateLimiterWithConfig(RateLimiterConfig{
|
||||||
|
BeforeFunc: func(c echo.Context) {
|
||||||
|
beforeRan = true
|
||||||
|
},
|
||||||
|
Store: inMemoryStore,
|
||||||
|
IdentifierExtractor: func(ctx echo.Context) (string, error) {
|
||||||
|
return "127.0.0.1", nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
_ = mw(handler)(c)
|
||||||
|
|
||||||
|
assert.Equal(t, true, beforeRan)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimiterMemoryStore_Allow(t *testing.T) {
|
||||||
|
var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3, ExpiresIn: 2 * time.Second})
|
||||||
|
testCases := []struct {
|
||||||
|
id string
|
||||||
|
allowed bool
|
||||||
|
}{
|
||||||
|
{"127.0.0.1", true}, // 0 ms
|
||||||
|
{"127.0.0.1", true}, // 220 ms burst #2
|
||||||
|
{"127.0.0.1", true}, // 440 ms burst #3
|
||||||
|
{"127.0.0.1", false}, // 660 ms block
|
||||||
|
{"127.0.0.1", false}, // 880 ms block
|
||||||
|
{"127.0.0.1", true}, // 1100 ms next second #1
|
||||||
|
{"127.0.0.2", true}, // 1320 ms allow other ip
|
||||||
|
{"127.0.0.1", false}, // 1540 ms no burst
|
||||||
|
{"127.0.0.1", false}, // 1760 ms no burst
|
||||||
|
{"127.0.0.1", false}, // 1980 ms no burst
|
||||||
|
{"127.0.0.1", true}, // 2200 ms no burst
|
||||||
|
{"127.0.0.1", false}, // 2420 ms no burst
|
||||||
|
{"127.0.0.1", false}, // 2640 ms no burst
|
||||||
|
{"127.0.0.1", false}, // 2860 ms no burst
|
||||||
|
{"127.0.0.1", true}, // 3080 ms no burst
|
||||||
|
{"127.0.0.1", false}, // 3300 ms no burst
|
||||||
|
{"127.0.0.1", false}, // 3520 ms no burst
|
||||||
|
{"127.0.0.1", false}, // 3740 ms no burst
|
||||||
|
{"127.0.0.1", false}, // 3960 ms no burst
|
||||||
|
{"127.0.0.1", true}, // 4180 ms no burst
|
||||||
|
{"127.0.0.1", false}, // 4400 ms no burst
|
||||||
|
{"127.0.0.1", false}, // 4620 ms no burst
|
||||||
|
{"127.0.0.1", false}, // 4840 ms no burst
|
||||||
|
{"127.0.0.1", true}, // 5060 ms no burst
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tc := range testCases {
|
||||||
|
t.Logf("Running testcase #%d => %v", i, time.Duration(i)*220*time.Millisecond)
|
||||||
|
now = func() time.Time {
|
||||||
|
return time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC).Add(time.Duration(i) * 220 * time.Millisecond)
|
||||||
|
}
|
||||||
|
allowed, _ := inMemoryStore.Allow(tc.id)
|
||||||
|
assert.Equal(t, tc.allowed, allowed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimiterMemoryStore_cleanupStaleVisitors(t *testing.T) {
|
||||||
|
var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3})
|
||||||
|
now = func() time.Time {
|
||||||
|
return time.Now()
|
||||||
|
}
|
||||||
|
fmt.Println(now())
|
||||||
|
inMemoryStore.visitors = map[string]*Visitor{
|
||||||
|
"A": {
|
||||||
|
Limiter: rate.NewLimiter(1, 3),
|
||||||
|
lastSeen: now(),
|
||||||
|
},
|
||||||
|
"B": {
|
||||||
|
Limiter: rate.NewLimiter(1, 3),
|
||||||
|
lastSeen: now().Add(-1 * time.Minute),
|
||||||
|
},
|
||||||
|
"C": {
|
||||||
|
Limiter: rate.NewLimiter(1, 3),
|
||||||
|
lastSeen: now().Add(-5 * time.Minute),
|
||||||
|
},
|
||||||
|
"D": {
|
||||||
|
Limiter: rate.NewLimiter(1, 3),
|
||||||
|
lastSeen: now().Add(-10 * time.Minute),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
inMemoryStore.Allow("D")
|
||||||
|
inMemoryStore.cleanupStaleVisitors()
|
||||||
|
|
||||||
|
var exists bool
|
||||||
|
|
||||||
|
_, exists = inMemoryStore.visitors["A"]
|
||||||
|
assert.Equal(t, true, exists)
|
||||||
|
|
||||||
|
_, exists = inMemoryStore.visitors["B"]
|
||||||
|
assert.Equal(t, true, exists)
|
||||||
|
|
||||||
|
_, exists = inMemoryStore.visitors["C"]
|
||||||
|
assert.Equal(t, false, exists)
|
||||||
|
|
||||||
|
_, exists = inMemoryStore.visitors["D"]
|
||||||
|
assert.Equal(t, true, exists)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewRateLimiterMemoryStore(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
rate rate.Limit
|
||||||
|
burst int
|
||||||
|
expiresIn time.Duration
|
||||||
|
expectedExpiresIn time.Duration
|
||||||
|
}{
|
||||||
|
{1, 3, 5 * time.Second, 5 * time.Second},
|
||||||
|
{2, 4, 0, 3 * time.Minute},
|
||||||
|
{1, 5, 10 * time.Minute, 10 * time.Minute},
|
||||||
|
{3, 7, 0, 3 * time.Minute},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: tc.rate, Burst: tc.burst, ExpiresIn: tc.expiresIn})
|
||||||
|
assert.Equal(t, tc.rate, store.rate)
|
||||||
|
assert.Equal(t, tc.burst, store.burst)
|
||||||
|
assert.Equal(t, tc.expectedExpiresIn, store.expiresIn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateAddressList(count int) []string {
|
||||||
|
addrs := make([]string, count)
|
||||||
|
for i := 0; i < count; i++ {
|
||||||
|
addrs[i] = random.String(15)
|
||||||
|
}
|
||||||
|
return addrs
|
||||||
|
}
|
||||||
|
|
||||||
|
func run(wg *sync.WaitGroup, store RateLimiterStore, addrs []string, max int, b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
store.Allow(addrs[rand.Intn(max)])
|
||||||
|
}
|
||||||
|
wg.Done()
|
||||||
|
}
|
||||||
|
|
||||||
|
func benchmarkStore(store RateLimiterStore, parallel int, max int, b *testing.B) {
|
||||||
|
addrs := generateAddressList(max)
|
||||||
|
wg := &sync.WaitGroup{}
|
||||||
|
for i := 0; i < parallel; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go run(wg, store, addrs, max, b)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
testExpiresIn = 1000 * time.Millisecond
|
||||||
|
)
|
||||||
|
|
||||||
|
func BenchmarkRateLimiterMemoryStore_1000(b *testing.B) {
|
||||||
|
var store = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 100, Burst: 200, ExpiresIn: testExpiresIn})
|
||||||
|
benchmarkStore(store, 10, 1000, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkRateLimiterMemoryStore_10000(b *testing.B) {
|
||||||
|
var store = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 100, Burst: 200, ExpiresIn: testExpiresIn})
|
||||||
|
benchmarkStore(store, 10, 10000, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkRateLimiterMemoryStore_100000(b *testing.B) {
|
||||||
|
var store = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 100, Burst: 200, ExpiresIn: testExpiresIn})
|
||||||
|
benchmarkStore(store, 10, 100000, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkRateLimiterMemoryStore_conc100_10000(b *testing.B) {
|
||||||
|
var store = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 100, Burst: 200, ExpiresIn: testExpiresIn})
|
||||||
|
benchmarkStore(store, 100, 10000, b)
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user