mirror of
https://github.com/labstack/echo.git
synced 2024-11-24 08:22:21 +02:00
CSRF/RequestID mw: switch math/random usage to crypto/random
This commit is contained in:
parent
3f8ae15b57
commit
626f13e338
@ -6,7 +6,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/labstack/gommon/random"
|
||||
)
|
||||
|
||||
type (
|
||||
@ -103,6 +102,7 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
|
||||
if config.TokenLength == 0 {
|
||||
config.TokenLength = DefaultCSRFConfig.TokenLength
|
||||
}
|
||||
|
||||
if config.TokenLookup == "" {
|
||||
config.TokenLookup = DefaultCSRFConfig.TokenLookup
|
||||
}
|
||||
@ -132,7 +132,7 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
|
||||
|
||||
token := ""
|
||||
if k, err := c.Cookie(config.CookieName); err != nil {
|
||||
token = random.String(config.TokenLength) // Generate token
|
||||
token = randomString(config.TokenLength)
|
||||
} else {
|
||||
token = k.Value // Reuse token
|
||||
}
|
||||
|
@ -8,7 +8,6 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/labstack/gommon/random"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
@ -233,7 +232,7 @@ func TestCSRF(t *testing.T) {
|
||||
assert.Error(t, h(c))
|
||||
|
||||
// Valid CSRF token
|
||||
token := random.String(32)
|
||||
token := randomString(32)
|
||||
req.Header.Set(echo.HeaderCookie, "_csrf="+token)
|
||||
req.Header.Set(echo.HeaderXCSRFToken, token)
|
||||
if assert.NoError(t, h(c)) {
|
||||
|
@ -10,7 +10,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/labstack/gommon/random"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
@ -410,7 +409,7 @@ func TestNewRateLimiterMemoryStore(t *testing.T) {
|
||||
func generateAddressList(count int) []string {
|
||||
addrs := make([]string, count)
|
||||
for i := 0; i < count; i++ {
|
||||
addrs[i] = random.String(15)
|
||||
addrs[i] = randomString(15)
|
||||
}
|
||||
return addrs
|
||||
}
|
||||
|
@ -2,7 +2,6 @@ package middleware
|
||||
|
||||
import (
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/labstack/gommon/random"
|
||||
)
|
||||
|
||||
type (
|
||||
@ -12,7 +11,7 @@ type (
|
||||
Skipper Skipper
|
||||
|
||||
// Generator defines a function to generate an ID.
|
||||
// Optional. Default value random.String(32).
|
||||
// Optional. Defaults to generator for random string of length 32.
|
||||
Generator func() string
|
||||
|
||||
// RequestIDHandler defines a function which is executed for a request id.
|
||||
@ -73,5 +72,5 @@ func RequestIDWithConfig(config RequestIDConfig) echo.MiddlewareFunc {
|
||||
}
|
||||
|
||||
func generator() string {
|
||||
return random.String(32)
|
||||
return randomString(32)
|
||||
}
|
||||
|
@ -1,6 +1,8 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@ -52,3 +54,18 @@ func matchSubdomain(domain, pattern string) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func randomString(length uint8) string {
|
||||
charset := "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
|
||||
|
||||
bytes := make([]byte, length)
|
||||
_, err := rand.Read(bytes)
|
||||
if err != nil {
|
||||
// we are out of random. let the request fail
|
||||
panic(fmt.Errorf("echo randomString failed to read random bytes: %w", err))
|
||||
}
|
||||
for i, b := range bytes {
|
||||
bytes[i] = charset[b%byte(len(charset))]
|
||||
}
|
||||
return string(bytes)
|
||||
}
|
||||
|
@ -93,3 +93,27 @@ func Test_matchSubdomain(t *testing.T) {
|
||||
assert.Equal(t, v.expected, matchSubdomain(v.domain, v.pattern))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRandomString(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
whenLength uint8
|
||||
expect string
|
||||
}{
|
||||
{
|
||||
name: "ok, 16",
|
||||
whenLength: 16,
|
||||
},
|
||||
{
|
||||
name: "ok, 32",
|
||||
whenLength: 32,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
uid := randomString(tc.whenLength)
|
||||
assert.Len(t, uid, int(tc.whenLength))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user