1
0
mirror of https://github.com/go-kratos/kratos.git synced 2025-01-26 03:52:12 +02:00

65 lines
1.5 KiB
Go
Raw Normal View History

2019-04-11 15:07:22 +08:00
package blademaster
import (
	"net/http"
	"net/url"
	"regexp"
	"strings"

	"github.com/bilibili/kratos/pkg/log"
)
// matchHostSuffix returns a validator reporting whether a referer URL's host
// falls under suffix.
//
// The comparison is case-insensitive: both the host and suffix are
// lower-cased. A suffix without a leading dot must match on a DNS label
// boundary. For example, "example.com" accepts "example.com" and
// "a.example.com" but not "evilexample.com", which a plain string-suffix test
// would wrongly allow. A suffix with a leading dot (".example.com") matches
// only hosts ending in that exact string, i.e. subdomains. An empty suffix
// matches every host.
//
// The host is tested both with its port (uri.Host) and without it
// (uri.Hostname()). So "a.example.com:8080" matches "example.com", and a
// suffix that itself carries a port still works.
func matchHostSuffix(suffix string) func(*url.URL) bool {
	suffix = strings.ToLower(suffix)
	under := func(host string) bool {
		if !strings.HasSuffix(host, suffix) {
			return false
		}
		if suffix == "" || len(host) == len(suffix) || suffix[0] == '.' {
			return true
		}
		// The byte just before the matched suffix must be a label separator,
		// otherwise "evilexample.com" would pass for "example.com".
		return host[len(host)-len(suffix)-1] == '.'
	}
	return func(uri *url.URL) bool {
		return under(strings.ToLower(uri.Host)) || under(strings.ToLower(uri.Hostname()))
	}
}
// matchPattern builds a referer validator from a precompiled regular
// expression. The URL is re-serialized with uri.String() and lower-cased
// before matching. Patterns therefore only see lower-case text, and a pattern
// containing upper-case literals will never match.
func matchPattern(pattern *regexp.Regexp) func(*url.URL) bool {
	check := func(u *url.URL) bool {
		lowered := strings.ToLower(u.String())
		return pattern.MatchString(lowered)
	}
	return check
}
// CSRF returns the csrf middleware to prevent invalid cross site request.
// Only referer is checked currently.
func CSRF(allowHosts []string, allowPattern []string) HandlerFunc {
validations := []func(*url.URL) bool{}
addHostSuffix := func(suffix string) {
validations = append(validations, matchHostSuffix(suffix))
}
addPattern := func(pattern string) {
validations = append(validations, matchPattern(regexp.MustCompile(pattern)))
}
for _, r := range allowHosts {
addHostSuffix(r)
}
for _, p := range allowPattern {
addPattern(p)
}
return func(c *Context) {
referer := c.Request.Header.Get("Referer")
if referer == "" {
log.V(5).Info("The request's Referer or Origin header is empty.")
2019-04-11 15:07:22 +08:00
c.AbortWithStatus(403)
return
}
illegal := true
if uri, err := url.Parse(referer); err == nil && uri.Host != "" {
for _, validate := range validations {
if validate(uri) {
illegal = false
break
}
}
}
if illegal {
log.V(5).Info("The request's Referer header `%s` does not match any of allowed referers.", referer)
c.AbortWithStatus(403)
return
}
}
}