1
0
mirror of https://github.com/labstack/echo.git synced 2024-12-24 20:14:31 +02:00

Set subdomains to AllowOrigins with wildcard (#1301)

* Set subdomains to AllowOrigins with wildcard

* Create IsSubDomain

* Avoid panic when pattern length smaller than domain length

* Change names, improve formula
This commit is contained in:
atsushi-ishibashi 2019-03-10 03:32:49 +09:00 committed by Vishal Rana
parent 5434a5392f
commit 1f6cc362cc
4 changed files with 169 additions and 0 deletions

View File

@ -102,6 +102,10 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
allowOrigin = o
break
}
if matchSubdomain(origin, o) {
allowOrigin = origin
break
}
}
// Simple request

View File

@ -66,4 +66,20 @@ func TestCORS(t *testing.T) {
assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods))
assert.Equal(t, "true", rec.Header().Get(echo.HeaderAccessControlAllowCredentials))
assert.Equal(t, "3600", rec.Header().Get(echo.HeaderAccessControlMaxAge))
// Preflight request with `AllowOrigins` which allow all subdomains with *
req = httptest.NewRequest(http.MethodOptions, "/", nil)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
req.Header.Set(echo.HeaderOrigin, "http://aaa.example.com")
cors = CORSWithConfig(CORSConfig{
AllowOrigins: []string{"http://*.example.com"},
})
h = cors(echo.NotFoundHandler)
h(c)
assert.Equal(t, "http://aaa.example.com", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
req.Header.Set(echo.HeaderOrigin, "http://bbb.example.com")
h(c)
assert.Equal(t, "http://bbb.example.com", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
}

54
middleware/util.go Normal file
View File

@ -0,0 +1,54 @@
package middleware
import (
"strings"
)
func matchScheme(domain, pattern string) bool {
didx := strings.Index(domain, ":")
pidx := strings.Index(pattern, ":")
return didx != -1 && pidx != -1 && domain[:didx] == pattern[:pidx]
}
// matchSubdomain compares authority with wildcard
func matchSubdomain(domain, pattern string) bool {
if !matchScheme(domain, pattern) {
return false
}
didx := strings.Index(domain, "://")
pidx := strings.Index(pattern, "://")
if didx == -1 || pidx == -1 {
return false
}
domAuth := domain[didx+3:]
// to avoid long loop by invalid long domain
if len(domAuth) > 253 {
return false
}
patAuth := pattern[pidx+3:]
domComp := strings.Split(domAuth, ".")
patComp := strings.Split(patAuth, ".")
for i := len(domComp)/2 - 1; i >= 0; i-- {
opp := len(domComp) - 1 - i
domComp[i], domComp[opp] = domComp[opp], domComp[i]
}
for i := len(patComp)/2 - 1; i >= 0; i-- {
opp := len(patComp) - 1 - i
patComp[i], patComp[opp] = patComp[opp], patComp[i]
}
for i, v := range domComp {
if len(patComp) <= i {
return false
}
p := patComp[i]
if p == "*" {
return true
}
if p != v {
return false
}
}
return false
}

95
middleware/util_test.go Normal file
View File

@ -0,0 +1,95 @@
package middleware
import (
"testing"
"github.com/stretchr/testify/assert"
)
func Test_matchScheme(t *testing.T) {
tests := []struct {
domain, pattern string
expected bool
}{
{
domain: "http://example.com",
pattern: "http://example.com",
expected: true,
},
{
domain: "https://example.com",
pattern: "https://example.com",
expected: true,
},
{
domain: "http://example.com",
pattern: "https://example.com",
expected: false,
},
{
domain: "https://example.com",
pattern: "http://example.com",
expected: false,
},
}
for _, v := range tests {
assert.Equal(t, v.expected, matchScheme(v.domain, v.pattern))
}
}
func Test_matchSubdomain(t *testing.T) {
tests := []struct {
domain, pattern string
expected bool
}{
{
domain: "http://aaa.example.com",
pattern: "http://*.example.com",
expected: true,
},
{
domain: "http://bbb.aaa.example.com",
pattern: "http://*.example.com",
expected: true,
},
{
domain: "http://bbb.aaa.example.com",
pattern: "http://*.aaa.example.com",
expected: true,
},
{
domain: "http://aaa.example.com:8080",
pattern: "http://*.example.com:8080",
expected: true,
},
{
domain: "http://fuga.hoge.com",
pattern: "http://*.example.com",
expected: false,
},
{
domain: "http://ccc.bbb.example.com",
pattern: "http://*.aaa.example.com",
expected: false,
},
{
domain: `http://1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890\
.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890\
.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890\
.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.example.com`,
pattern: "http://*.example.com",
expected: false,
},
{
domain: "http://ccc.bbb.example.com",
pattern: "http://example.com",
expected: false,
},
}
for _, v := range tests {
assert.Equal(t, v.expected, matchSubdomain(v.domain, v.pattern))
}
}