mirror of
https://github.com/labstack/echo.git
synced 2025-01-24 03:16:14 +02:00
parent
0172fe675b
commit
a3352d880c
37
echo.go
37
echo.go
@ -152,20 +152,29 @@ const (
|
||||
|
||||
// Headers
|
||||
const (
|
||||
HeaderAcceptEncoding = "Accept-Encoding"
|
||||
HeaderAuthorization = "Authorization"
|
||||
HeaderContentDisposition = "Content-Disposition"
|
||||
HeaderContentEncoding = "Content-Encoding"
|
||||
HeaderContentLength = "Content-Length"
|
||||
HeaderContentType = "Content-Type"
|
||||
HeaderIfModifiedSince = "If-Modified-Since"
|
||||
HeaderLastModified = "Last-Modified"
|
||||
HeaderLocation = "Location"
|
||||
HeaderUpgrade = "Upgrade"
|
||||
HeaderVary = "Vary"
|
||||
HeaderWWWAuthenticate = "WWW-Authenticate"
|
||||
HeaderXForwardedFor = "X-Forwarded-For"
|
||||
HeaderXRealIP = "X-Real-IP"
|
||||
HeaderAcceptEncoding = "Accept-Encoding"
|
||||
HeaderAuthorization = "Authorization"
|
||||
HeaderContentDisposition = "Content-Disposition"
|
||||
HeaderContentEncoding = "Content-Encoding"
|
||||
HeaderContentLength = "Content-Length"
|
||||
HeaderContentType = "Content-Type"
|
||||
HeaderIfModifiedSince = "If-Modified-Since"
|
||||
HeaderLastModified = "Last-Modified"
|
||||
HeaderLocation = "Location"
|
||||
HeaderUpgrade = "Upgrade"
|
||||
HeaderVary = "Vary"
|
||||
HeaderWWWAuthenticate = "WWW-Authenticate"
|
||||
HeaderXForwardedFor = "X-Forwarded-For"
|
||||
HeaderXRealIP = "X-Real-IP"
|
||||
HeaderOrigin = "Origin"
|
||||
HeaderAccessControlRequestMethod = "Access-Control-Request-Method"
|
||||
HeaderAccessControlRequestHeaders = "Access-Control-Request-Headers"
|
||||
HeaderAccessControlAllowOrigin = "Access-Control-Allow-Origin"
|
||||
HeaderAccessControlAllowMethods = "Access-Control-Allow-Methods"
|
||||
HeaderAccessControlAllowHeaders = "Access-Control-Allow-Headers"
|
||||
HeaderAccessControlAllowCredentials = "Access-Control-Allow-Credentials"
|
||||
HeaderAccessControlExposeHeaders = "Access-Control-Expose-Headers"
|
||||
HeaderAccessControlMaxAge = "Access-Control-Max-Age"
|
||||
)
|
||||
|
||||
var (
|
||||
|
133
middleware/cors.go
Normal file
133
middleware/cors.go
Normal file
@ -0,0 +1,133 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/echo"
|
||||
)
|
||||
|
||||
type (
|
||||
// CORSConfig defines the config for CORS middleware.
|
||||
CORSConfig struct {
|
||||
// AllowOrigin defines a list of origins that may access the resource.
|
||||
// Optional with default value as []string{"*"}.
|
||||
AllowOrigins []string
|
||||
|
||||
// AllowMethods defines a list methods allowed when accessing the resource.
|
||||
// This is used in response to a preflight request.
|
||||
// Optional with default value as `DefaultCORSConfig.AllowMethods`.
|
||||
AllowMethods []string
|
||||
|
||||
// AllowHeaders defines a list of request headers that can be used when
|
||||
// making the actual request. This in response to a preflight request.
|
||||
// Optional with default value as []string{}.
|
||||
AllowHeaders []string
|
||||
|
||||
// AllowCredentials indicates whether or not the response to the request
|
||||
// can be exposed when the credentials flag is true. When used as part of
|
||||
// a response to a preflight request, this indicates whether or not the
|
||||
// actual request can be made using credentials.
|
||||
// Optional with default value as false.
|
||||
AllowCredentials bool
|
||||
|
||||
// ExposeHeaders defines a whitelist headers that clients are allowed to
|
||||
// access.
|
||||
// Optional with default value as []string{}.
|
||||
ExposeHeaders []string
|
||||
|
||||
// MaxAge indicates how long (in seconds) the results of a preflight request
|
||||
// can be cached.
|
||||
// Optional with default value as 0.
|
||||
MaxAge int
|
||||
}
|
||||
)
|
||||
|
||||
var (
|
||||
// DefaultCORSConfig is the default CORS middleware config.
|
||||
DefaultCORSConfig = CORSConfig{
|
||||
AllowOrigins: []string{"*"},
|
||||
AllowMethods: []string{echo.GET, echo.HEAD, echo.PUT, echo.POST, echo.DELETE},
|
||||
}
|
||||
)
|
||||
|
||||
// CORS returns a cross-origin HTTP request (CORS) middleware.
|
||||
// See https://developer.mozilla.org/en/docs/Web/HTTP/Access_control_CORS
|
||||
func CORS() echo.MiddlewareFunc {
|
||||
return CORSFromConfig(DefaultCORSConfig)
|
||||
}
|
||||
|
||||
// CORSFromConfig returns a CORS middleware from config.
|
||||
// See `CORS()`.
|
||||
func CORSFromConfig(config CORSConfig) echo.MiddlewareFunc {
|
||||
// Defaults
|
||||
if len(config.AllowOrigins) == 0 {
|
||||
config.AllowOrigins = DefaultCORSConfig.AllowOrigins
|
||||
}
|
||||
if len(config.AllowMethods) == 0 {
|
||||
config.AllowMethods = DefaultCORSConfig.AllowMethods
|
||||
}
|
||||
allowMethods := strings.Join(config.AllowMethods, ",")
|
||||
allowHeaders := strings.Join(config.AllowHeaders, ",")
|
||||
exposeHeaders := strings.Join(config.ExposeHeaders, ",")
|
||||
maxAge := strconv.Itoa(config.MaxAge)
|
||||
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
rq := c.Request()
|
||||
origin := c.Request().Header().Get(echo.HeaderOrigin)
|
||||
header := c.Response().Header()
|
||||
|
||||
// Check allowed origins
|
||||
allowedOrigin := ""
|
||||
for _, o := range config.AllowOrigins {
|
||||
if o == "*" || o == origin {
|
||||
allowedOrigin = o
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Simple request
|
||||
if rq.Method() != echo.OPTIONS {
|
||||
header.Add(echo.HeaderVary, echo.HeaderOrigin)
|
||||
if origin == "" || allowedOrigin == "" {
|
||||
return next(c)
|
||||
}
|
||||
header.Set(echo.HeaderAccessControlAllowOrigin, allowedOrigin)
|
||||
if config.AllowCredentials {
|
||||
header.Set(echo.HeaderAccessControlAllowCredentials, "true")
|
||||
}
|
||||
if exposeHeaders != "" {
|
||||
header.Set(echo.HeaderAccessControlExposeHeaders, exposeHeaders)
|
||||
}
|
||||
return next(c)
|
||||
}
|
||||
|
||||
// Preflight request
|
||||
header.Add(echo.HeaderVary, echo.HeaderOrigin)
|
||||
header.Add(echo.HeaderVary, echo.HeaderAccessControlRequestMethod)
|
||||
header.Add(echo.HeaderVary, echo.HeaderAccessControlRequestHeaders)
|
||||
if origin == "" || allowedOrigin == "" {
|
||||
return next(c)
|
||||
}
|
||||
header.Set(echo.HeaderAccessControlAllowOrigin, allowedOrigin)
|
||||
header.Set(echo.HeaderAccessControlAllowMethods, allowMethods)
|
||||
if config.AllowCredentials {
|
||||
header.Set(echo.HeaderAccessControlAllowCredentials, "true")
|
||||
}
|
||||
if allowHeaders != "" {
|
||||
header.Set(echo.HeaderAccessControlAllowHeaders, allowHeaders)
|
||||
} else {
|
||||
h := rq.Header().Get(echo.HeaderAccessControlRequestHeaders)
|
||||
if h != "" {
|
||||
header.Set(echo.HeaderAccessControlAllowHeaders, h)
|
||||
}
|
||||
}
|
||||
if config.MaxAge > 0 {
|
||||
header.Set(echo.HeaderAccessControlMaxAge, maxAge)
|
||||
}
|
||||
return c.NoContent(http.StatusNoContent)
|
||||
}
|
||||
}
|
||||
}
|
63
middleware/cors_test.go
Normal file
63
middleware/cors_test.go
Normal file
@ -0,0 +1,63 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo"
|
||||
"github.com/labstack/echo/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCORS(t *testing.T) {
|
||||
e := echo.New()
|
||||
rq := test.NewRequest(echo.GET, "/", nil)
|
||||
rc := test.NewResponseRecorder()
|
||||
c := echo.NewContext(rq, rc, e)
|
||||
cors := CORSFromConfig(CORSConfig{
|
||||
AllowCredentials: true,
|
||||
})
|
||||
h := cors(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "test")
|
||||
})
|
||||
|
||||
// No origin header
|
||||
h(c)
|
||||
assert.Equal(t, "", rc.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||
|
||||
// Wildcard origin
|
||||
rq = test.NewRequest(echo.GET, "/", nil)
|
||||
rc = test.NewResponseRecorder()
|
||||
c = echo.NewContext(rq, rc, e)
|
||||
rq.Header().Set(echo.HeaderOrigin, "localhost")
|
||||
h(c)
|
||||
assert.Equal(t, "*", rc.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||
|
||||
// Simple request
|
||||
rq = test.NewRequest(echo.GET, "/", nil)
|
||||
rc = test.NewResponseRecorder()
|
||||
c = echo.NewContext(rq, rc, e)
|
||||
rq.Header().Set(echo.HeaderOrigin, "localhost")
|
||||
cors = CORSFromConfig(CORSConfig{
|
||||
AllowOrigins: []string{"localhost"},
|
||||
AllowCredentials: true,
|
||||
MaxAge: 3600,
|
||||
})
|
||||
h = cors(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "test")
|
||||
})
|
||||
h(c)
|
||||
assert.Equal(t, "localhost", rc.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||
|
||||
// Preflight request
|
||||
rq = test.NewRequest(echo.OPTIONS, "/", nil)
|
||||
rc = test.NewResponseRecorder()
|
||||
c = echo.NewContext(rq, rc, e)
|
||||
rq.Header().Set(echo.HeaderOrigin, "localhost")
|
||||
rq.Header().Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
|
||||
h(c)
|
||||
assert.Equal(t, "localhost", rc.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||
assert.NotEmpty(t, rc.Header().Get(echo.HeaderAccessControlAllowMethods))
|
||||
assert.Equal(t, "true", rc.Header().Get(echo.HeaderAccessControlAllowCredentials))
|
||||
assert.Equal(t, "3600", rc.Header().Get(echo.HeaderAccessControlMaxAge))
|
||||
}
|
1
middleware/static_test.go
Normal file
1
middleware/static_test.go
Normal file
@ -0,0 +1 @@
|
||||
package middleware
|
Loading…
x
Reference in New Issue
Block a user