mirror of
https://github.com/labstack/echo.git
synced 2024-12-24 20:14:31 +02:00
Set subdomains to AllowOrigins with wildcard (#1301)
* Set subdomains to AllowOrigins with wildcard * Create IsSubDomain * Avoid panic when pattern length smaller than domain length * Change names, improve formula
This commit is contained in:
parent
5434a5392f
commit
1f6cc362cc
@ -102,6 +102,10 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
|
||||
allowOrigin = o
|
||||
break
|
||||
}
|
||||
if matchSubdomain(origin, o) {
|
||||
allowOrigin = origin
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Simple request
|
||||
|
@ -66,4 +66,20 @@ func TestCORS(t *testing.T) {
|
||||
assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods))
|
||||
assert.Equal(t, "true", rec.Header().Get(echo.HeaderAccessControlAllowCredentials))
|
||||
assert.Equal(t, "3600", rec.Header().Get(echo.HeaderAccessControlMaxAge))
|
||||
|
||||
// Preflight request with `AllowOrigins` which allow all subdomains with *
|
||||
req = httptest.NewRequest(http.MethodOptions, "/", nil)
|
||||
rec = httptest.NewRecorder()
|
||||
c = e.NewContext(req, rec)
|
||||
req.Header.Set(echo.HeaderOrigin, "http://aaa.example.com")
|
||||
cors = CORSWithConfig(CORSConfig{
|
||||
AllowOrigins: []string{"http://*.example.com"},
|
||||
})
|
||||
h = cors(echo.NotFoundHandler)
|
||||
h(c)
|
||||
assert.Equal(t, "http://aaa.example.com", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||
|
||||
req.Header.Set(echo.HeaderOrigin, "http://bbb.example.com")
|
||||
h(c)
|
||||
assert.Equal(t, "http://bbb.example.com", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||
}
|
||||
|
54
middleware/util.go
Normal file
54
middleware/util.go
Normal file
@ -0,0 +1,54 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
func matchScheme(domain, pattern string) bool {
|
||||
didx := strings.Index(domain, ":")
|
||||
pidx := strings.Index(pattern, ":")
|
||||
return didx != -1 && pidx != -1 && domain[:didx] == pattern[:pidx]
|
||||
}
|
||||
|
||||
// matchSubdomain compares authority with wildcard
|
||||
func matchSubdomain(domain, pattern string) bool {
|
||||
if !matchScheme(domain, pattern) {
|
||||
return false
|
||||
}
|
||||
didx := strings.Index(domain, "://")
|
||||
pidx := strings.Index(pattern, "://")
|
||||
if didx == -1 || pidx == -1 {
|
||||
return false
|
||||
}
|
||||
domAuth := domain[didx+3:]
|
||||
// to avoid long loop by invalid long domain
|
||||
if len(domAuth) > 253 {
|
||||
return false
|
||||
}
|
||||
patAuth := pattern[pidx+3:]
|
||||
|
||||
domComp := strings.Split(domAuth, ".")
|
||||
patComp := strings.Split(patAuth, ".")
|
||||
for i := len(domComp)/2 - 1; i >= 0; i-- {
|
||||
opp := len(domComp) - 1 - i
|
||||
domComp[i], domComp[opp] = domComp[opp], domComp[i]
|
||||
}
|
||||
for i := len(patComp)/2 - 1; i >= 0; i-- {
|
||||
opp := len(patComp) - 1 - i
|
||||
patComp[i], patComp[opp] = patComp[opp], patComp[i]
|
||||
}
|
||||
|
||||
for i, v := range domComp {
|
||||
if len(patComp) <= i {
|
||||
return false
|
||||
}
|
||||
p := patComp[i]
|
||||
if p == "*" {
|
||||
return true
|
||||
}
|
||||
if p != v {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
95
middleware/util_test.go
Normal file
95
middleware/util_test.go
Normal file
@ -0,0 +1,95 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_matchScheme(t *testing.T) {
|
||||
tests := []struct {
|
||||
domain, pattern string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
domain: "http://example.com",
|
||||
pattern: "http://example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
domain: "https://example.com",
|
||||
pattern: "https://example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
domain: "http://example.com",
|
||||
pattern: "https://example.com",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
domain: "https://example.com",
|
||||
pattern: "http://example.com",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, v := range tests {
|
||||
assert.Equal(t, v.expected, matchScheme(v.domain, v.pattern))
|
||||
}
|
||||
}
|
||||
|
||||
func Test_matchSubdomain(t *testing.T) {
|
||||
tests := []struct {
|
||||
domain, pattern string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
domain: "http://aaa.example.com",
|
||||
pattern: "http://*.example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
domain: "http://bbb.aaa.example.com",
|
||||
pattern: "http://*.example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
domain: "http://bbb.aaa.example.com",
|
||||
pattern: "http://*.aaa.example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
domain: "http://aaa.example.com:8080",
|
||||
pattern: "http://*.example.com:8080",
|
||||
expected: true,
|
||||
},
|
||||
|
||||
{
|
||||
domain: "http://fuga.hoge.com",
|
||||
pattern: "http://*.example.com",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
domain: "http://ccc.bbb.example.com",
|
||||
pattern: "http://*.aaa.example.com",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
domain: `http://1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890\
|
||||
.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890\
|
||||
.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890\
|
||||
.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.example.com`,
|
||||
pattern: "http://*.example.com",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
domain: "http://ccc.bbb.example.com",
|
||||
pattern: "http://example.com",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, v := range tests {
|
||||
assert.Equal(t, v.expected, matchSubdomain(v.domain, v.pattern))
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user