mirror of
https://github.com/labstack/echo.git
synced 2025-07-15 01:34:53 +02:00
CSRF/RequestID mw: switch math/random usage to crypto/random
This commit is contained in:
@ -6,7 +6,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/labstack/echo/v4"
|
"github.com/labstack/echo/v4"
|
||||||
"github.com/labstack/gommon/random"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
@ -103,6 +102,7 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
|
|||||||
if config.TokenLength == 0 {
|
if config.TokenLength == 0 {
|
||||||
config.TokenLength = DefaultCSRFConfig.TokenLength
|
config.TokenLength = DefaultCSRFConfig.TokenLength
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.TokenLookup == "" {
|
if config.TokenLookup == "" {
|
||||||
config.TokenLookup = DefaultCSRFConfig.TokenLookup
|
config.TokenLookup = DefaultCSRFConfig.TokenLookup
|
||||||
}
|
}
|
||||||
@ -132,7 +132,7 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
|
|||||||
|
|
||||||
token := ""
|
token := ""
|
||||||
if k, err := c.Cookie(config.CookieName); err != nil {
|
if k, err := c.Cookie(config.CookieName); err != nil {
|
||||||
token = random.String(config.TokenLength) // Generate token
|
token = randomString(config.TokenLength)
|
||||||
} else {
|
} else {
|
||||||
token = k.Value // Reuse token
|
token = k.Value // Reuse token
|
||||||
}
|
}
|
||||||
|
@ -8,7 +8,6 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/labstack/echo/v4"
|
"github.com/labstack/echo/v4"
|
||||||
"github.com/labstack/gommon/random"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -233,7 +232,7 @@ func TestCSRF(t *testing.T) {
|
|||||||
assert.Error(t, h(c))
|
assert.Error(t, h(c))
|
||||||
|
|
||||||
// Valid CSRF token
|
// Valid CSRF token
|
||||||
token := random.String(32)
|
token := randomString(32)
|
||||||
req.Header.Set(echo.HeaderCookie, "_csrf="+token)
|
req.Header.Set(echo.HeaderCookie, "_csrf="+token)
|
||||||
req.Header.Set(echo.HeaderXCSRFToken, token)
|
req.Header.Set(echo.HeaderXCSRFToken, token)
|
||||||
if assert.NoError(t, h(c)) {
|
if assert.NoError(t, h(c)) {
|
||||||
|
@ -10,7 +10,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/labstack/echo/v4"
|
"github.com/labstack/echo/v4"
|
||||||
"github.com/labstack/gommon/random"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"golang.org/x/time/rate"
|
"golang.org/x/time/rate"
|
||||||
)
|
)
|
||||||
@ -410,7 +409,7 @@ func TestNewRateLimiterMemoryStore(t *testing.T) {
|
|||||||
func generateAddressList(count int) []string {
|
func generateAddressList(count int) []string {
|
||||||
addrs := make([]string, count)
|
addrs := make([]string, count)
|
||||||
for i := 0; i < count; i++ {
|
for i := 0; i < count; i++ {
|
||||||
addrs[i] = random.String(15)
|
addrs[i] = randomString(15)
|
||||||
}
|
}
|
||||||
return addrs
|
return addrs
|
||||||
}
|
}
|
||||||
|
@ -2,7 +2,6 @@ package middleware
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/labstack/echo/v4"
|
"github.com/labstack/echo/v4"
|
||||||
"github.com/labstack/gommon/random"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
@ -12,7 +11,7 @@ type (
|
|||||||
Skipper Skipper
|
Skipper Skipper
|
||||||
|
|
||||||
// Generator defines a function to generate an ID.
|
// Generator defines a function to generate an ID.
|
||||||
// Optional. Default value random.String(32).
|
// Optional. Defaults to generator for random string of length 32.
|
||||||
Generator func() string
|
Generator func() string
|
||||||
|
|
||||||
// RequestIDHandler defines a function which is executed for a request id.
|
// RequestIDHandler defines a function which is executed for a request id.
|
||||||
@ -73,5 +72,5 @@ func RequestIDWithConfig(config RequestIDConfig) echo.MiddlewareFunc {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func generator() string {
|
func generator() string {
|
||||||
return random.String(32)
|
return randomString(32)
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -52,3 +54,18 @@ func matchSubdomain(domain, pattern string) bool {
|
|||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func randomString(length uint8) string {
|
||||||
|
charset := "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
|
||||||
|
|
||||||
|
bytes := make([]byte, length)
|
||||||
|
_, err := rand.Read(bytes)
|
||||||
|
if err != nil {
|
||||||
|
// we are out of random. let the request fail
|
||||||
|
panic(fmt.Errorf("echo randomString failed to read random bytes: %w", err))
|
||||||
|
}
|
||||||
|
for i, b := range bytes {
|
||||||
|
bytes[i] = charset[b%byte(len(charset))]
|
||||||
|
}
|
||||||
|
return string(bytes)
|
||||||
|
}
|
||||||
|
@ -93,3 +93,27 @@ func Test_matchSubdomain(t *testing.T) {
|
|||||||
assert.Equal(t, v.expected, matchSubdomain(v.domain, v.pattern))
|
assert.Equal(t, v.expected, matchSubdomain(v.domain, v.pattern))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRandomString(t *testing.T) {
|
||||||
|
var testCases = []struct {
|
||||||
|
name string
|
||||||
|
whenLength uint8
|
||||||
|
expect string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "ok, 16",
|
||||||
|
whenLength: 16,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok, 32",
|
||||||
|
whenLength: 32,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
uid := randomString(tc.whenLength)
|
||||||
|
assert.Len(t, uid, int(tc.whenLength))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user