1
0
mirror of https://github.com/labstack/echo.git synced 2025-07-15 01:34:53 +02:00

CSRF/RequestID mw: switch math/random usage to crypto/random

This commit is contained in:
toimtoimtoim
2023-07-21 09:49:27 +03:00
committed by Martti T
parent 3f8ae15b57
commit 626f13e338
6 changed files with 47 additions and 9 deletions

View File

@ -6,7 +6,6 @@ import (
"time" "time"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/labstack/gommon/random"
) )
type ( type (
@ -103,6 +102,7 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
if config.TokenLength == 0 { if config.TokenLength == 0 {
config.TokenLength = DefaultCSRFConfig.TokenLength config.TokenLength = DefaultCSRFConfig.TokenLength
} }
if config.TokenLookup == "" { if config.TokenLookup == "" {
config.TokenLookup = DefaultCSRFConfig.TokenLookup config.TokenLookup = DefaultCSRFConfig.TokenLookup
} }
@ -132,7 +132,7 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
token := "" token := ""
if k, err := c.Cookie(config.CookieName); err != nil { if k, err := c.Cookie(config.CookieName); err != nil {
token = random.String(config.TokenLength) // Generate token token = randomString(config.TokenLength)
} else { } else {
token = k.Value // Reuse token token = k.Value // Reuse token
} }

View File

@ -8,7 +8,6 @@ import (
"testing" "testing"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/labstack/gommon/random"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -233,7 +232,7 @@ func TestCSRF(t *testing.T) {
assert.Error(t, h(c)) assert.Error(t, h(c))
// Valid CSRF token // Valid CSRF token
token := random.String(32) token := randomString(32)
req.Header.Set(echo.HeaderCookie, "_csrf="+token) req.Header.Set(echo.HeaderCookie, "_csrf="+token)
req.Header.Set(echo.HeaderXCSRFToken, token) req.Header.Set(echo.HeaderXCSRFToken, token)
if assert.NoError(t, h(c)) { if assert.NoError(t, h(c)) {

View File

@ -10,7 +10,6 @@ import (
"time" "time"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/labstack/gommon/random"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"golang.org/x/time/rate" "golang.org/x/time/rate"
) )
@ -410,7 +409,7 @@ func TestNewRateLimiterMemoryStore(t *testing.T) {
func generateAddressList(count int) []string { func generateAddressList(count int) []string {
addrs := make([]string, count) addrs := make([]string, count)
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
addrs[i] = random.String(15) addrs[i] = randomString(15)
} }
return addrs return addrs
} }

View File

@ -2,7 +2,6 @@ package middleware
import ( import (
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/labstack/gommon/random"
) )
type ( type (
@ -12,7 +11,7 @@ type (
Skipper Skipper Skipper Skipper
// Generator defines a function to generate an ID. // Generator defines a function to generate an ID.
// Optional. Default value random.String(32). // Optional. Defaults to generator for random string of length 32.
Generator func() string Generator func() string
// RequestIDHandler defines a function which is executed for a request id. // RequestIDHandler defines a function which is executed for a request id.
@ -73,5 +72,5 @@ func RequestIDWithConfig(config RequestIDConfig) echo.MiddlewareFunc {
} }
func generator() string { func generator() string {
return random.String(32) return randomString(32)
} }

View File

@ -1,6 +1,8 @@
package middleware package middleware
import ( import (
"crypto/rand"
"fmt"
"strings" "strings"
) )
@ -52,3 +54,18 @@ func matchSubdomain(domain, pattern string) bool {
} }
return false return false
} }
func randomString(length uint8) string {
charset := "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
bytes := make([]byte, length)
_, err := rand.Read(bytes)
if err != nil {
// we are out of random. let the request fail
panic(fmt.Errorf("echo randomString failed to read random bytes: %w", err))
}
for i, b := range bytes {
bytes[i] = charset[b%byte(len(charset))]
}
return string(bytes)
}

View File

@ -93,3 +93,27 @@ func Test_matchSubdomain(t *testing.T) {
assert.Equal(t, v.expected, matchSubdomain(v.domain, v.pattern)) assert.Equal(t, v.expected, matchSubdomain(v.domain, v.pattern))
} }
} }
func TestRandomString(t *testing.T) {
var testCases = []struct {
name string
whenLength uint8
expect string
}{
{
name: "ok, 16",
whenLength: 16,
},
{
name: "ok, 32",
whenLength: 32,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
uid := randomString(tc.whenLength)
assert.Len(t, uid, int(tc.whenLength))
})
}
}