mirror of
				https://github.com/labstack/echo.git
				synced 2025-10-30 23:57:38 +02:00 
			
		
		
		
	adds middleware for rate limiting (#1724)
* adds middleware for rate limiting * added comment for InMemoryStore ShouldAllow * removed redundant mutex declaration * fixed lint issues * removed sleep from tests * improved coverage * refactor: renames Identifiers, includes default SourceFunc * Added last seen stats for visitor * uses http Constants for improved readdability adds default error handler * used other handler apart from default handler to mark custom error handler for rate limiting * split tests into separate blocks added an error pair to IdentifierExtractor Includes deny handler for explicitly denying requests * adds comments for exported members Extractor and ErrorHandler * makes cleanup implementation inhouse * Avoid race for cleanup due to non-atomic access to store.expiresIn * Use a dedicated producer for rate testing * tidy commit * refactors tests, implicitly tests lastSeen property on visitor switches NewRateLimiterMemoryStore constructor to Referential Functions style (Advised by @pafuent) * switches to mock of time module for time based tests tests are now fully deterministic * improved coverage * replaces Rob Pike referential options with more conventional struct configs makes cleanup asynchronous * blocks racy access to lastCleanup * Add benchmark tests for rate limiter * Add rate limiter with sharded memory store * Racy access to store.lastCleanup eliminated Merges in shiny sharded map implementation by @lammel * Remove RateLimiterShradedMemoryStore for now * Make fields for RateLimiterStoreConfig public for external configuration * Improve docs for RateLimiter usage * Fix ErrorHandler vs. DenyHandler usage for rate limiter * Simplify NewRateLimiterMemoryStore * improved coverage * updated errorHandler and denyHandler to use echo.HTTPError * Improve wording for error and comments * Remove duplicate lastSeen marking for Allow * Improve wording for comments * Add disclaimer on perf characteristics of memory store * changes Allow signature on rate limiter to return err too Co-authored-by: Roland Lammel <rl@neotel.at>
This commit is contained in:
		
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			
						parent
						
							9b0e63046b
						
					
				
				
					commit
					7c8592a7e0
				
			
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -5,3 +5,4 @@ vendor | ||||
| .idea | ||||
| *.iml | ||||
| *.out | ||||
| .vscode | ||||
|   | ||||
							
								
								
									
										1
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										1
									
								
								go.mod
									
									
									
									
									
								
							| @@ -12,4 +12,5 @@ require ( | ||||
| 	golang.org/x/net v0.0.0-20200822124328-c89045814202 | ||||
| 	golang.org/x/sys v0.0.0-20200826173525-f9321e4c35a6 // indirect | ||||
| 	golang.org/x/text v0.3.3 // indirect | ||||
| 	golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 | ||||
| ) | ||||
|   | ||||
							
								
								
									
										2
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								go.sum
									
									
									
									
									
								
							| @@ -46,6 +46,8 @@ golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= | ||||
| golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= | ||||
| golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= | ||||
| golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= | ||||
| golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 h1:Hir2P/De0WpUhtrKGGjvSb2YxUgyZ7EFOSLIcSSpiwE= | ||||
| golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= | ||||
| golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= | ||||
| gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= | ||||
| gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= | ||||
|   | ||||
							
								
								
									
										268
									
								
								middleware/rate_limiter.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										268
									
								
								middleware/rate_limiter.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,268 @@ | ||||
| package middleware | ||||
|  | ||||
| import ( | ||||
| 	"net/http" | ||||
| 	"sync" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/labstack/echo/v4" | ||||
| 	"golang.org/x/time/rate" | ||||
| ) | ||||
|  | ||||
| type ( | ||||
| 	// RateLimiterStore is the interface to be implemented by custom stores. | ||||
| 	RateLimiterStore interface { | ||||
| 		// Stores for the rate limiter have to implement the Allow method | ||||
| 		Allow(identifier string) (bool, error) | ||||
| 	} | ||||
| ) | ||||
|  | ||||
| type ( | ||||
| 	// RateLimiterConfig defines the configuration for the rate limiter | ||||
| 	RateLimiterConfig struct { | ||||
| 		Skipper    Skipper | ||||
| 		BeforeFunc BeforeFunc | ||||
| 		// IdentifierExtractor uses echo.Context to extract the identifier for a visitor | ||||
| 		IdentifierExtractor Extractor | ||||
| 		// Store defines a store for the rate limiter | ||||
| 		Store RateLimiterStore | ||||
| 		// ErrorHandler provides a handler to be called when IdentifierExtractor returns an error | ||||
| 		ErrorHandler func(context echo.Context, err error) error | ||||
| 		// DenyHandler provides a handler to be called when RateLimiter denies access | ||||
| 		DenyHandler func(context echo.Context, identifier string, err error) error | ||||
| 	} | ||||
| 	// Extractor is used to extract data from echo.Context | ||||
| 	Extractor func(context echo.Context) (string, error) | ||||
| ) | ||||
|  | ||||
| // errors | ||||
| var ( | ||||
| 	// ErrRateLimitExceeded denotes an error raised when rate limit is exceeded | ||||
| 	ErrRateLimitExceeded = echo.NewHTTPError(http.StatusTooManyRequests, "rate limit exceeded") | ||||
| 	// ErrExtractorError denotes an error raised when extractor function is unsuccessful | ||||
| 	ErrExtractorError = echo.NewHTTPError(http.StatusForbidden, "error while extracting identifier") | ||||
| ) | ||||
|  | ||||
| // DefaultRateLimiterConfig defines default values for RateLimiterConfig | ||||
| var DefaultRateLimiterConfig = RateLimiterConfig{ | ||||
| 	Skipper: DefaultSkipper, | ||||
| 	IdentifierExtractor: func(ctx echo.Context) (string, error) { | ||||
| 		id := ctx.RealIP() | ||||
| 		return id, nil | ||||
| 	}, | ||||
| 	ErrorHandler: func(context echo.Context, err error) error { | ||||
| 		return &echo.HTTPError{ | ||||
| 			Code:     ErrExtractorError.Code, | ||||
| 			Message:  ErrExtractorError.Message, | ||||
| 			Internal: err, | ||||
| 		} | ||||
| 	}, | ||||
| 	DenyHandler: func(context echo.Context, identifier string, err error) error { | ||||
| 		return &echo.HTTPError{ | ||||
| 			Code:     ErrRateLimitExceeded.Code, | ||||
| 			Message:  ErrRateLimitExceeded.Message, | ||||
| 			Internal: err, | ||||
| 		} | ||||
| 	}, | ||||
| } | ||||
|  | ||||
| /* | ||||
| RateLimiter returns a rate limiting middleware | ||||
|  | ||||
| 	e := echo.New() | ||||
|  | ||||
| 	limiterStore := middleware.NewRateLimiterMemoryStore(20) | ||||
|  | ||||
| 	e.GET("/rate-limited", func(c echo.Context) error { | ||||
| 		return c.String(http.StatusOK, "test") | ||||
| 	}, RateLimiter(limiterStore)) | ||||
| */ | ||||
| func RateLimiter(store RateLimiterStore) echo.MiddlewareFunc { | ||||
| 	config := DefaultRateLimiterConfig | ||||
| 	config.Store = store | ||||
|  | ||||
| 	return RateLimiterWithConfig(config) | ||||
| } | ||||
|  | ||||
| /* | ||||
| RateLimiterWithConfig returns a rate limiting middleware | ||||
|  | ||||
| 	e := echo.New() | ||||
|  | ||||
| 	config := middleware.RateLimiterConfig{ | ||||
| 		Skipper: DefaultSkipper, | ||||
| 		Store: middleware.NewRateLimiterMemoryStore( | ||||
| 			middleware.RateLimiterMemoryStoreConfig{Rate: 10, Burst: 30, ExpiresIn: 3 * time.Minute} | ||||
| 		) | ||||
| 		IdentifierExtractor: func(ctx echo.Context) (string, error) { | ||||
| 			id := ctx.RealIP() | ||||
| 			return id, nil | ||||
| 		}, | ||||
| 		ErrorHandler: func(context echo.Context, err error) error { | ||||
| 			return context.JSON(http.StatusTooManyRequests, nil) | ||||
| 		}, | ||||
| 		DenyHandler: func(context echo.Context, identifier string) error { | ||||
| 			return context.JSON(http.StatusForbidden, nil) | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	e.GET("/rate-limited", func(c echo.Context) error { | ||||
| 		return c.String(http.StatusOK, "test") | ||||
| 	}, middleware.RateLimiterWithConfig(config)) | ||||
| */ | ||||
| func RateLimiterWithConfig(config RateLimiterConfig) echo.MiddlewareFunc { | ||||
| 	if config.Skipper == nil { | ||||
| 		config.Skipper = DefaultRateLimiterConfig.Skipper | ||||
| 	} | ||||
| 	if config.IdentifierExtractor == nil { | ||||
| 		config.IdentifierExtractor = DefaultRateLimiterConfig.IdentifierExtractor | ||||
| 	} | ||||
| 	if config.ErrorHandler == nil { | ||||
| 		config.ErrorHandler = DefaultRateLimiterConfig.ErrorHandler | ||||
| 	} | ||||
| 	if config.DenyHandler == nil { | ||||
| 		config.DenyHandler = DefaultRateLimiterConfig.DenyHandler | ||||
| 	} | ||||
| 	if config.Store == nil { | ||||
| 		panic("Store configuration must be provided") | ||||
| 	} | ||||
| 	return func(next echo.HandlerFunc) echo.HandlerFunc { | ||||
| 		return func(c echo.Context) error { | ||||
| 			if config.Skipper(c) { | ||||
| 				return next(c) | ||||
| 			} | ||||
| 			if config.BeforeFunc != nil { | ||||
| 				config.BeforeFunc(c) | ||||
| 			} | ||||
|  | ||||
| 			identifier, err := config.IdentifierExtractor(c) | ||||
| 			if err != nil { | ||||
| 				c.Error(config.ErrorHandler(c, err)) | ||||
| 				return nil | ||||
| 			} | ||||
|  | ||||
| 			if allow, err := config.Store.Allow(identifier); !allow { | ||||
| 				c.Error(config.DenyHandler(c, identifier, err)) | ||||
| 				return nil | ||||
| 			} | ||||
| 			return next(c) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| type ( | ||||
| 	// RateLimiterMemoryStore is the built-in store implementation for RateLimiter | ||||
| 	RateLimiterMemoryStore struct { | ||||
| 		visitors    map[string]*Visitor | ||||
| 		mutex       sync.Mutex | ||||
| 		rate        rate.Limit | ||||
| 		burst       int | ||||
| 		expiresIn   time.Duration | ||||
| 		lastCleanup time.Time | ||||
| 	} | ||||
| 	// Visitor signifies a unique user's limiter details | ||||
| 	Visitor struct { | ||||
| 		*rate.Limiter | ||||
| 		lastSeen time.Time | ||||
| 	} | ||||
| ) | ||||
|  | ||||
| /* | ||||
| NewRateLimiterMemoryStore returns an instance of RateLimiterMemoryStore with | ||||
| the provided rate (as req/s). Burst and ExpiresIn will be set to default values. | ||||
|  | ||||
| Example (with 20 requests/sec): | ||||
|  | ||||
| 	limiterStore := middleware.NewRateLimiterMemoryStore(20) | ||||
|  | ||||
| */ | ||||
| func NewRateLimiterMemoryStore(rate rate.Limit) (store *RateLimiterMemoryStore) { | ||||
| 	return NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{ | ||||
| 		Rate: rate, | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| /* | ||||
| NewRateLimiterMemoryStoreWithConfig returns an instance of RateLimiterMemoryStore | ||||
| with the provided configuration. Rate must be provided. Burst will be set to the value of | ||||
| the configured rate if not provided or set to 0. | ||||
|  | ||||
| The build-in memory store is usually capable for modest loads. For higher loads other | ||||
| store implementations should be considered. | ||||
|  | ||||
| Characteristics: | ||||
| * Concurrency above 100 parallel requests may causes measurable lock contention | ||||
| * A high number of different IP addresses (above 16000) may be impacted by the internally used Go map | ||||
| * A high number of requests from a single IP address may cause lock contention | ||||
|  | ||||
| Example: | ||||
|  | ||||
| 	limiterStore := middleware.NewRateLimiterMemoryStoreWithConfig( | ||||
| 		middleware.RateLimiterMemoryStoreConfig{Rate: 50, Burst: 200, ExpiresIn: 5 * time.Minutes}, | ||||
| 	) | ||||
| */ | ||||
| func NewRateLimiterMemoryStoreWithConfig(config RateLimiterMemoryStoreConfig) (store *RateLimiterMemoryStore) { | ||||
| 	store = &RateLimiterMemoryStore{} | ||||
|  | ||||
| 	store.rate = config.Rate | ||||
| 	store.burst = config.Burst | ||||
| 	store.expiresIn = config.ExpiresIn | ||||
| 	if config.ExpiresIn == 0 { | ||||
| 		store.expiresIn = DefaultRateLimiterMemoryStoreConfig.ExpiresIn | ||||
| 	} | ||||
| 	if config.Burst == 0 { | ||||
| 		store.burst = int(config.Rate) | ||||
| 	} | ||||
| 	store.visitors = make(map[string]*Visitor) | ||||
| 	store.lastCleanup = now() | ||||
| 	return | ||||
| } | ||||
|  | ||||
| // RateLimiterMemoryStoreConfig represents configuration for RateLimiterMemoryStore | ||||
| type RateLimiterMemoryStoreConfig struct { | ||||
| 	Rate      rate.Limit    // Rate of requests allowed to pass as req/s | ||||
| 	Burst     int           // Burst additionally allows a number of requests to pass when rate limit is reached | ||||
| 	ExpiresIn time.Duration // ExpiresIn is the duration after that a rate limiter is cleaned up | ||||
| } | ||||
|  | ||||
| // DefaultRateLimiterMemoryStoreConfig provides default configuration values for RateLimiterMemoryStore | ||||
| var DefaultRateLimiterMemoryStoreConfig = RateLimiterMemoryStoreConfig{ | ||||
| 	ExpiresIn: 3 * time.Minute, | ||||
| } | ||||
|  | ||||
| // Allow implements RateLimiterStore.Allow | ||||
| func (store *RateLimiterMemoryStore) Allow(identifier string) (bool, error) { | ||||
| 	store.mutex.Lock() | ||||
| 	limiter, exists := store.visitors[identifier] | ||||
| 	if !exists { | ||||
| 		limiter = new(Visitor) | ||||
| 		limiter.Limiter = rate.NewLimiter(store.rate, store.burst) | ||||
| 		store.visitors[identifier] = limiter | ||||
| 	} | ||||
| 	limiter.lastSeen = now() | ||||
| 	if now().Sub(store.lastCleanup) > store.expiresIn { | ||||
| 		store.cleanupStaleVisitors() | ||||
| 	} | ||||
| 	store.mutex.Unlock() | ||||
| 	return limiter.AllowN(now(), 1), nil | ||||
| } | ||||
|  | ||||
| /* | ||||
| cleanupStaleVisitors helps manage the size of the visitors map by removing stale records | ||||
| of users who haven't visited again after the configured expiry time has elapsed | ||||
| */ | ||||
| func (store *RateLimiterMemoryStore) cleanupStaleVisitors() { | ||||
| 	for id, visitor := range store.visitors { | ||||
| 		if now().Sub(visitor.lastSeen) > store.expiresIn { | ||||
| 			delete(store.visitors, id) | ||||
| 		} | ||||
| 	} | ||||
| 	store.lastCleanup = now() | ||||
| } | ||||
|  | ||||
| /* | ||||
| actual time method which is mocked in test file | ||||
| */ | ||||
| var now = func() time.Time { | ||||
| 	return time.Now() | ||||
| } | ||||
							
								
								
									
										462
									
								
								middleware/rate_limiter_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										462
									
								
								middleware/rate_limiter_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,462 @@ | ||||
| package middleware | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"math/rand" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"sync" | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/labstack/echo/v4" | ||||
| 	"github.com/labstack/gommon/random" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| 	"golang.org/x/time/rate" | ||||
| ) | ||||
|  | ||||
| func TestRateLimiter(t *testing.T) { | ||||
| 	e := echo.New() | ||||
|  | ||||
| 	handler := func(c echo.Context) error { | ||||
| 		return c.String(http.StatusOK, "test") | ||||
| 	} | ||||
|  | ||||
| 	var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) | ||||
|  | ||||
| 	mw := RateLimiter(inMemoryStore) | ||||
|  | ||||
| 	testCases := []struct { | ||||
| 		id   string | ||||
| 		code int | ||||
| 	}{ | ||||
| 		{"127.0.0.1", http.StatusOK}, | ||||
| 		{"127.0.0.1", http.StatusOK}, | ||||
| 		{"127.0.0.1", http.StatusOK}, | ||||
| 		{"127.0.0.1", http.StatusTooManyRequests}, | ||||
| 		{"127.0.0.1", http.StatusTooManyRequests}, | ||||
| 		{"127.0.0.1", http.StatusTooManyRequests}, | ||||
| 		{"127.0.0.1", http.StatusTooManyRequests}, | ||||
| 	} | ||||
|  | ||||
| 	for _, tc := range testCases { | ||||
| 		req := httptest.NewRequest(http.MethodGet, "/", nil) | ||||
| 		req.Header.Add(echo.HeaderXRealIP, tc.id) | ||||
|  | ||||
| 		rec := httptest.NewRecorder() | ||||
| 		c := e.NewContext(req, rec) | ||||
|  | ||||
| 		_ = mw(handler)(c) | ||||
| 		assert.Equal(t, tc.code, rec.Code) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestRateLimiter_panicBehaviour(t *testing.T) { | ||||
| 	var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) | ||||
|  | ||||
| 	assert.Panics(t, func() { | ||||
| 		RateLimiter(nil) | ||||
| 	}) | ||||
|  | ||||
| 	assert.NotPanics(t, func() { | ||||
| 		RateLimiter(inMemoryStore) | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func TestRateLimiterWithConfig(t *testing.T) { | ||||
| 	var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) | ||||
|  | ||||
| 	e := echo.New() | ||||
|  | ||||
| 	handler := func(c echo.Context) error { | ||||
| 		return c.String(http.StatusOK, "test") | ||||
| 	} | ||||
|  | ||||
| 	mw := RateLimiterWithConfig(RateLimiterConfig{ | ||||
| 		IdentifierExtractor: func(c echo.Context) (string, error) { | ||||
| 			id := c.Request().Header.Get(echo.HeaderXRealIP) | ||||
| 			if id == "" { | ||||
| 				return "", errors.New("invalid identifier") | ||||
| 			} | ||||
| 			return id, nil | ||||
| 		}, | ||||
| 		DenyHandler: func(ctx echo.Context, identifier string, err error) error { | ||||
| 			return ctx.JSON(http.StatusForbidden, nil) | ||||
| 		}, | ||||
| 		ErrorHandler: func(ctx echo.Context, err error) error { | ||||
| 			return ctx.JSON(http.StatusBadRequest, nil) | ||||
| 		}, | ||||
| 		Store: inMemoryStore, | ||||
| 	}) | ||||
|  | ||||
| 	testCases := []struct { | ||||
| 		id   string | ||||
| 		code int | ||||
| 	}{ | ||||
| 		{"127.0.0.1", http.StatusOK}, | ||||
| 		{"127.0.0.1", http.StatusOK}, | ||||
| 		{"127.0.0.1", http.StatusOK}, | ||||
| 		{"127.0.0.1", http.StatusForbidden}, | ||||
| 		{"", http.StatusBadRequest}, | ||||
| 		{"127.0.0.1", http.StatusForbidden}, | ||||
| 		{"127.0.0.1", http.StatusForbidden}, | ||||
| 	} | ||||
|  | ||||
| 	for _, tc := range testCases { | ||||
| 		req := httptest.NewRequest(http.MethodGet, "/", nil) | ||||
| 		req.Header.Add(echo.HeaderXRealIP, tc.id) | ||||
|  | ||||
| 		rec := httptest.NewRecorder() | ||||
|  | ||||
| 		c := e.NewContext(req, rec) | ||||
|  | ||||
| 		_ = mw(handler)(c) | ||||
|  | ||||
| 		assert.Equal(t, tc.code, rec.Code) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestRateLimiterWithConfig_defaultDenyHandler(t *testing.T) { | ||||
| 	var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) | ||||
|  | ||||
| 	e := echo.New() | ||||
|  | ||||
| 	handler := func(c echo.Context) error { | ||||
| 		return c.String(http.StatusOK, "test") | ||||
| 	} | ||||
|  | ||||
| 	mw := RateLimiterWithConfig(RateLimiterConfig{ | ||||
| 		IdentifierExtractor: func(c echo.Context) (string, error) { | ||||
| 			id := c.Request().Header.Get(echo.HeaderXRealIP) | ||||
| 			if id == "" { | ||||
| 				return "", errors.New("invalid identifier") | ||||
| 			} | ||||
| 			return id, nil | ||||
| 		}, | ||||
| 		Store: inMemoryStore, | ||||
| 	}) | ||||
|  | ||||
| 	testCases := []struct { | ||||
| 		id   string | ||||
| 		code int | ||||
| 	}{ | ||||
| 		{"127.0.0.1", http.StatusOK}, | ||||
| 		{"127.0.0.1", http.StatusOK}, | ||||
| 		{"127.0.0.1", http.StatusOK}, | ||||
| 		{"127.0.0.1", http.StatusTooManyRequests}, | ||||
| 		{"", http.StatusForbidden}, | ||||
| 		{"127.0.0.1", http.StatusTooManyRequests}, | ||||
| 		{"127.0.0.1", http.StatusTooManyRequests}, | ||||
| 	} | ||||
|  | ||||
| 	for _, tc := range testCases { | ||||
| 		req := httptest.NewRequest(http.MethodGet, "/", nil) | ||||
| 		req.Header.Add(echo.HeaderXRealIP, tc.id) | ||||
|  | ||||
| 		rec := httptest.NewRecorder() | ||||
|  | ||||
| 		c := e.NewContext(req, rec) | ||||
|  | ||||
| 		_ = mw(handler)(c) | ||||
|  | ||||
| 		assert.Equal(t, tc.code, rec.Code) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestRateLimiterWithConfig_defaultConfig(t *testing.T) { | ||||
| 	{ | ||||
| 		var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) | ||||
|  | ||||
| 		e := echo.New() | ||||
|  | ||||
| 		handler := func(c echo.Context) error { | ||||
| 			return c.String(http.StatusOK, "test") | ||||
| 		} | ||||
|  | ||||
| 		mw := RateLimiterWithConfig(RateLimiterConfig{ | ||||
| 			Store: inMemoryStore, | ||||
| 		}) | ||||
|  | ||||
| 		testCases := []struct { | ||||
| 			id   string | ||||
| 			code int | ||||
| 		}{ | ||||
| 			{"127.0.0.1", http.StatusOK}, | ||||
| 			{"127.0.0.1", http.StatusOK}, | ||||
| 			{"127.0.0.1", http.StatusOK}, | ||||
| 			{"127.0.0.1", http.StatusTooManyRequests}, | ||||
| 			{"127.0.0.1", http.StatusTooManyRequests}, | ||||
| 			{"127.0.0.1", http.StatusTooManyRequests}, | ||||
| 			{"127.0.0.1", http.StatusTooManyRequests}, | ||||
| 		} | ||||
|  | ||||
| 		for _, tc := range testCases { | ||||
| 			req := httptest.NewRequest(http.MethodGet, "/", nil) | ||||
| 			req.Header.Add(echo.HeaderXRealIP, tc.id) | ||||
|  | ||||
| 			rec := httptest.NewRecorder() | ||||
|  | ||||
| 			c := e.NewContext(req, rec) | ||||
|  | ||||
| 			_ = mw(handler)(c) | ||||
|  | ||||
| 			assert.Equal(t, tc.code, rec.Code) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestRateLimiterWithConfig_skipper(t *testing.T) { | ||||
| 	e := echo.New() | ||||
|  | ||||
| 	var beforeFuncRan bool | ||||
| 	handler := func(c echo.Context) error { | ||||
| 		return c.String(http.StatusOK, "test") | ||||
| 	} | ||||
| 	var inMemoryStore = NewRateLimiterMemoryStore(5) | ||||
|  | ||||
| 	req := httptest.NewRequest(http.MethodGet, "/", nil) | ||||
| 	req.Header.Add(echo.HeaderXRealIP, "127.0.0.1") | ||||
|  | ||||
| 	rec := httptest.NewRecorder() | ||||
|  | ||||
| 	c := e.NewContext(req, rec) | ||||
|  | ||||
| 	mw := RateLimiterWithConfig(RateLimiterConfig{ | ||||
| 		Skipper: func(c echo.Context) bool { | ||||
| 			return true | ||||
| 		}, | ||||
| 		BeforeFunc: func(c echo.Context) { | ||||
| 			beforeFuncRan = true | ||||
| 		}, | ||||
| 		Store: inMemoryStore, | ||||
| 		IdentifierExtractor: func(ctx echo.Context) (string, error) { | ||||
| 			return "127.0.0.1", nil | ||||
| 		}, | ||||
| 	}) | ||||
|  | ||||
| 	_ = mw(handler)(c) | ||||
|  | ||||
| 	assert.Equal(t, false, beforeFuncRan) | ||||
| } | ||||
|  | ||||
| func TestRateLimiterWithConfig_skipperNoSkip(t *testing.T) { | ||||
| 	e := echo.New() | ||||
|  | ||||
| 	var beforeFuncRan bool | ||||
| 	handler := func(c echo.Context) error { | ||||
| 		return c.String(http.StatusOK, "test") | ||||
| 	} | ||||
| 	var inMemoryStore = NewRateLimiterMemoryStore(5) | ||||
|  | ||||
| 	req := httptest.NewRequest(http.MethodGet, "/", nil) | ||||
| 	req.Header.Add(echo.HeaderXRealIP, "127.0.0.1") | ||||
|  | ||||
| 	rec := httptest.NewRecorder() | ||||
|  | ||||
| 	c := e.NewContext(req, rec) | ||||
|  | ||||
| 	mw := RateLimiterWithConfig(RateLimiterConfig{ | ||||
| 		Skipper: func(c echo.Context) bool { | ||||
| 			return false | ||||
| 		}, | ||||
| 		BeforeFunc: func(c echo.Context) { | ||||
| 			beforeFuncRan = true | ||||
| 		}, | ||||
| 		Store: inMemoryStore, | ||||
| 		IdentifierExtractor: func(ctx echo.Context) (string, error) { | ||||
| 			return "127.0.0.1", nil | ||||
| 		}, | ||||
| 	}) | ||||
|  | ||||
| 	_ = mw(handler)(c) | ||||
|  | ||||
| 	assert.Equal(t, true, beforeFuncRan) | ||||
| } | ||||
|  | ||||
| func TestRateLimiterWithConfig_beforeFunc(t *testing.T) { | ||||
| 	e := echo.New() | ||||
|  | ||||
| 	handler := func(c echo.Context) error { | ||||
| 		return c.String(http.StatusOK, "test") | ||||
| 	} | ||||
|  | ||||
| 	var beforeRan bool | ||||
| 	var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) | ||||
|  | ||||
| 	req := httptest.NewRequest(http.MethodGet, "/", nil) | ||||
| 	req.Header.Add(echo.HeaderXRealIP, "127.0.0.1") | ||||
|  | ||||
| 	rec := httptest.NewRecorder() | ||||
|  | ||||
| 	c := e.NewContext(req, rec) | ||||
|  | ||||
| 	mw := RateLimiterWithConfig(RateLimiterConfig{ | ||||
| 		BeforeFunc: func(c echo.Context) { | ||||
| 			beforeRan = true | ||||
| 		}, | ||||
| 		Store: inMemoryStore, | ||||
| 		IdentifierExtractor: func(ctx echo.Context) (string, error) { | ||||
| 			return "127.0.0.1", nil | ||||
| 		}, | ||||
| 	}) | ||||
|  | ||||
| 	_ = mw(handler)(c) | ||||
|  | ||||
| 	assert.Equal(t, true, beforeRan) | ||||
| } | ||||
|  | ||||
| func TestRateLimiterMemoryStore_Allow(t *testing.T) { | ||||
| 	var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3, ExpiresIn: 2 * time.Second}) | ||||
| 	testCases := []struct { | ||||
| 		id      string | ||||
| 		allowed bool | ||||
| 	}{ | ||||
| 		{"127.0.0.1", true},  // 0 ms | ||||
| 		{"127.0.0.1", true},  // 220 ms burst #2 | ||||
| 		{"127.0.0.1", true},  // 440 ms burst #3 | ||||
| 		{"127.0.0.1", false}, // 660 ms block | ||||
| 		{"127.0.0.1", false}, // 880 ms block | ||||
| 		{"127.0.0.1", true},  // 1100 ms next second #1 | ||||
| 		{"127.0.0.2", true},  // 1320 ms allow other ip | ||||
| 		{"127.0.0.1", false}, // 1540 ms no burst | ||||
| 		{"127.0.0.1", false}, // 1760 ms no burst | ||||
| 		{"127.0.0.1", false}, // 1980 ms no burst | ||||
| 		{"127.0.0.1", true},  // 2200 ms no burst | ||||
| 		{"127.0.0.1", false}, // 2420 ms no burst | ||||
| 		{"127.0.0.1", false}, // 2640 ms no burst | ||||
| 		{"127.0.0.1", false}, // 2860 ms no burst | ||||
| 		{"127.0.0.1", true},  // 3080 ms no burst | ||||
| 		{"127.0.0.1", false}, // 3300 ms no burst | ||||
| 		{"127.0.0.1", false}, // 3520 ms no burst | ||||
| 		{"127.0.0.1", false}, // 3740 ms no burst | ||||
| 		{"127.0.0.1", false}, // 3960 ms no burst | ||||
| 		{"127.0.0.1", true},  // 4180 ms no burst | ||||
| 		{"127.0.0.1", false}, // 4400 ms no burst | ||||
| 		{"127.0.0.1", false}, // 4620 ms no burst | ||||
| 		{"127.0.0.1", false}, // 4840 ms no burst | ||||
| 		{"127.0.0.1", true},  // 5060 ms no burst | ||||
| 	} | ||||
|  | ||||
| 	for i, tc := range testCases { | ||||
| 		t.Logf("Running testcase #%d => %v", i, time.Duration(i)*220*time.Millisecond) | ||||
| 		now = func() time.Time { | ||||
| 			return time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC).Add(time.Duration(i) * 220 * time.Millisecond) | ||||
| 		} | ||||
| 		allowed, _ := inMemoryStore.Allow(tc.id) | ||||
| 		assert.Equal(t, tc.allowed, allowed) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestRateLimiterMemoryStore_cleanupStaleVisitors(t *testing.T) { | ||||
| 	var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) | ||||
| 	now = func() time.Time { | ||||
| 		return time.Now() | ||||
| 	} | ||||
| 	fmt.Println(now()) | ||||
| 	inMemoryStore.visitors = map[string]*Visitor{ | ||||
| 		"A": { | ||||
| 			Limiter:  rate.NewLimiter(1, 3), | ||||
| 			lastSeen: now(), | ||||
| 		}, | ||||
| 		"B": { | ||||
| 			Limiter:  rate.NewLimiter(1, 3), | ||||
| 			lastSeen: now().Add(-1 * time.Minute), | ||||
| 		}, | ||||
| 		"C": { | ||||
| 			Limiter:  rate.NewLimiter(1, 3), | ||||
| 			lastSeen: now().Add(-5 * time.Minute), | ||||
| 		}, | ||||
| 		"D": { | ||||
| 			Limiter:  rate.NewLimiter(1, 3), | ||||
| 			lastSeen: now().Add(-10 * time.Minute), | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	inMemoryStore.Allow("D") | ||||
| 	inMemoryStore.cleanupStaleVisitors() | ||||
|  | ||||
| 	var exists bool | ||||
|  | ||||
| 	_, exists = inMemoryStore.visitors["A"] | ||||
| 	assert.Equal(t, true, exists) | ||||
|  | ||||
| 	_, exists = inMemoryStore.visitors["B"] | ||||
| 	assert.Equal(t, true, exists) | ||||
|  | ||||
| 	_, exists = inMemoryStore.visitors["C"] | ||||
| 	assert.Equal(t, false, exists) | ||||
|  | ||||
| 	_, exists = inMemoryStore.visitors["D"] | ||||
| 	assert.Equal(t, true, exists) | ||||
| } | ||||
|  | ||||
| func TestNewRateLimiterMemoryStore(t *testing.T) { | ||||
| 	testCases := []struct { | ||||
| 		rate              rate.Limit | ||||
| 		burst             int | ||||
| 		expiresIn         time.Duration | ||||
| 		expectedExpiresIn time.Duration | ||||
| 	}{ | ||||
| 		{1, 3, 5 * time.Second, 5 * time.Second}, | ||||
| 		{2, 4, 0, 3 * time.Minute}, | ||||
| 		{1, 5, 10 * time.Minute, 10 * time.Minute}, | ||||
| 		{3, 7, 0, 3 * time.Minute}, | ||||
| 	} | ||||
|  | ||||
| 	for _, tc := range testCases { | ||||
| 		store := NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: tc.rate, Burst: tc.burst, ExpiresIn: tc.expiresIn}) | ||||
| 		assert.Equal(t, tc.rate, store.rate) | ||||
| 		assert.Equal(t, tc.burst, store.burst) | ||||
| 		assert.Equal(t, tc.expectedExpiresIn, store.expiresIn) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func generateAddressList(count int) []string { | ||||
| 	addrs := make([]string, count) | ||||
| 	for i := 0; i < count; i++ { | ||||
| 		addrs[i] = random.String(15) | ||||
| 	} | ||||
| 	return addrs | ||||
| } | ||||
|  | ||||
| func run(wg *sync.WaitGroup, store RateLimiterStore, addrs []string, max int, b *testing.B) { | ||||
| 	for i := 0; i < b.N; i++ { | ||||
| 		store.Allow(addrs[rand.Intn(max)]) | ||||
| 	} | ||||
| 	wg.Done() | ||||
| } | ||||
|  | ||||
| func benchmarkStore(store RateLimiterStore, parallel int, max int, b *testing.B) { | ||||
| 	addrs := generateAddressList(max) | ||||
| 	wg := &sync.WaitGroup{} | ||||
| 	for i := 0; i < parallel; i++ { | ||||
| 		wg.Add(1) | ||||
| 		go run(wg, store, addrs, max, b) | ||||
| 	} | ||||
| 	wg.Wait() | ||||
| } | ||||
|  | ||||
| const ( | ||||
| 	testExpiresIn = 1000 * time.Millisecond | ||||
| ) | ||||
|  | ||||
| func BenchmarkRateLimiterMemoryStore_1000(b *testing.B) { | ||||
| 	var store = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 100, Burst: 200, ExpiresIn: testExpiresIn}) | ||||
| 	benchmarkStore(store, 10, 1000, b) | ||||
| } | ||||
|  | ||||
| func BenchmarkRateLimiterMemoryStore_10000(b *testing.B) { | ||||
| 	var store = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 100, Burst: 200, ExpiresIn: testExpiresIn}) | ||||
| 	benchmarkStore(store, 10, 10000, b) | ||||
| } | ||||
|  | ||||
| func BenchmarkRateLimiterMemoryStore_100000(b *testing.B) { | ||||
| 	var store = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 100, Burst: 200, ExpiresIn: testExpiresIn}) | ||||
| 	benchmarkStore(store, 10, 100000, b) | ||||
| } | ||||
|  | ||||
| func BenchmarkRateLimiterMemoryStore_conc100_10000(b *testing.B) { | ||||
| 	var store = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 100, Burst: 200, ExpiresIn: testExpiresIn}) | ||||
| 	benchmarkStore(store, 100, 10000, b) | ||||
| } | ||||
		Reference in New Issue
	
	Block a user