diff --git a/middleware/util.go b/middleware/util.go index aa34d78f..0aa0420f 100644 --- a/middleware/util.go +++ b/middleware/util.go @@ -1,9 +1,11 @@ package middleware import ( + "bufio" "crypto/rand" - "fmt" + "io" "strings" + "sync" ) func matchScheme(domain, pattern string) bool { @@ -55,17 +57,38 @@ func matchSubdomain(domain, pattern string) bool { return false } -func randomString(length uint8) string { - charset := "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" +// https://tip.golang.org/doc/go1.19#:~:text=Read%20no%20longer%20buffers%20random%20data%20obtained%20from%20the%20operating%20system%20between%20calls +var randomReaderPool = sync.Pool{New: func() interface{} { + return bufio.NewReader(rand.Reader) +}} - bytes := make([]byte, length) - _, err := rand.Read(bytes) - if err != nil { - // we are out of random. let the request fail - panic(fmt.Errorf("echo randomString failed to read random bytes: %w", err)) +const randomStringCharset = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" +const randomStringCharsetLen = 52 // len(randomStringCharset) +const randomStringMaxByte = 255 - (256 % randomStringCharsetLen) + +func randomString(length uint8) string { + reader := randomReaderPool.Get().(*bufio.Reader) + defer randomReaderPool.Put(reader) + + b := make([]byte, length) + r := make([]byte, length+(length/4)) // perf: avoid read from rand.Reader many times + var i uint8 = 0 + + for { + _, err := io.ReadFull(reader, r) + if err != nil { + panic("unexpected error happened when reading from bufio.NewReader(crypto/rand.Reader)") + } + for _, rb := range r { + if rb > randomStringMaxByte { + // Skip this number to avoid bias. + continue + } + b[i] = randomStringCharset[rb%randomStringCharsetLen] + i++ + if i == length { + return string(b) + } + } } - for i, b := range bytes { - bytes[i] = charset[b%byte(len(charset))] - } - return string(bytes) } diff --git a/middleware/util_test.go b/middleware/util_test.go index 7562d4a5..d0f20bba 100644 --- a/middleware/util_test.go +++ b/middleware/util_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func Test_matchScheme(t *testing.T) { @@ -117,3 +118,31 @@ func TestRandomString(t *testing.T) { }) } } + +func TestRandomStringBias(t *testing.T) { + t.Parallel() + const slen = 33 + const loop = 100000 + + counts := make(map[rune]int) + var count int64 + + for i := 0; i < loop; i++ { + s := randomString(slen) + require.Equal(t, slen, len(s)) + for _, b := range s { + counts[b]++ + count++ + } + } + + require.Equal(t, randomStringCharsetLen, len(counts)) + + avg := float64(count) / float64(len(counts)) + for k, n := range counts { + diff := float64(n) / avg + if diff < 0.95 || diff > 1.05 { + t.Errorf("Bias on '%c': expected average %f, got %d", k, avg, n) + } + } +}