1
0
mirror of https://github.com/labstack/echo.git synced 2025-01-24 03:16:14 +02:00
Signed-off-by: Vishal Rana <vr@labstack.com>
This commit is contained in:
Vishal Rana 2016-04-07 16:16:58 -07:00
parent 0172fe675b
commit a3352d880c
4 changed files with 220 additions and 14 deletions

37
echo.go
View File

@ -152,20 +152,29 @@ const (
// Headers
const (
HeaderAcceptEncoding = "Accept-Encoding"
HeaderAuthorization = "Authorization"
HeaderContentDisposition = "Content-Disposition"
HeaderContentEncoding = "Content-Encoding"
HeaderContentLength = "Content-Length"
HeaderContentType = "Content-Type"
HeaderIfModifiedSince = "If-Modified-Since"
HeaderLastModified = "Last-Modified"
HeaderLocation = "Location"
HeaderUpgrade = "Upgrade"
HeaderVary = "Vary"
HeaderWWWAuthenticate = "WWW-Authenticate"
HeaderXForwardedFor = "X-Forwarded-For"
HeaderXRealIP = "X-Real-IP"
HeaderAcceptEncoding = "Accept-Encoding"
HeaderAuthorization = "Authorization"
HeaderContentDisposition = "Content-Disposition"
HeaderContentEncoding = "Content-Encoding"
HeaderContentLength = "Content-Length"
HeaderContentType = "Content-Type"
HeaderIfModifiedSince = "If-Modified-Since"
HeaderLastModified = "Last-Modified"
HeaderLocation = "Location"
HeaderUpgrade = "Upgrade"
HeaderVary = "Vary"
HeaderWWWAuthenticate = "WWW-Authenticate"
HeaderXForwardedFor = "X-Forwarded-For"
HeaderXRealIP = "X-Real-IP"
HeaderOrigin = "Origin"
HeaderAccessControlRequestMethod = "Access-Control-Request-Method"
HeaderAccessControlRequestHeaders = "Access-Control-Request-Headers"
HeaderAccessControlAllowOrigin = "Access-Control-Allow-Origin"
HeaderAccessControlAllowMethods = "Access-Control-Allow-Methods"
HeaderAccessControlAllowHeaders = "Access-Control-Allow-Headers"
HeaderAccessControlAllowCredentials = "Access-Control-Allow-Credentials"
HeaderAccessControlExposeHeaders = "Access-Control-Expose-Headers"
HeaderAccessControlMaxAge = "Access-Control-Max-Age"
)
var (

133
middleware/cors.go Normal file
View File

@ -0,0 +1,133 @@
package middleware
import (
"net/http"
"strconv"
"strings"
"github.com/labstack/echo"
)
type (
// CORSConfig defines the config for CORS middleware.
CORSConfig struct {
// AllowOrigin defines a list of origins that may access the resource.
// Optional with default value as []string{"*"}.
AllowOrigins []string
// AllowMethods defines a list methods allowed when accessing the resource.
// This is used in response to a preflight request.
// Optional with default value as `DefaultCORSConfig.AllowMethods`.
AllowMethods []string
// AllowHeaders defines a list of request headers that can be used when
// making the actual request. This in response to a preflight request.
// Optional with default value as []string{}.
AllowHeaders []string
// AllowCredentials indicates whether or not the response to the request
// can be exposed when the credentials flag is true. When used as part of
// a response to a preflight request, this indicates whether or not the
// actual request can be made using credentials.
// Optional with default value as false.
AllowCredentials bool
// ExposeHeaders defines a whitelist headers that clients are allowed to
// access.
// Optional with default value as []string{}.
ExposeHeaders []string
// MaxAge indicates how long (in seconds) the results of a preflight request
// can be cached.
// Optional with default value as 0.
MaxAge int
}
)
var (
// DefaultCORSConfig is the default CORS middleware config.
DefaultCORSConfig = CORSConfig{
AllowOrigins: []string{"*"},
AllowMethods: []string{echo.GET, echo.HEAD, echo.PUT, echo.POST, echo.DELETE},
}
)
// CORS returns a cross-origin HTTP request (CORS) middleware.
// See https://developer.mozilla.org/en/docs/Web/HTTP/Access_control_CORS
func CORS() echo.MiddlewareFunc {
return CORSFromConfig(DefaultCORSConfig)
}
// CORSFromConfig returns a CORS middleware from config.
// See `CORS()`.
func CORSFromConfig(config CORSConfig) echo.MiddlewareFunc {
// Defaults
if len(config.AllowOrigins) == 0 {
config.AllowOrigins = DefaultCORSConfig.AllowOrigins
}
if len(config.AllowMethods) == 0 {
config.AllowMethods = DefaultCORSConfig.AllowMethods
}
allowMethods := strings.Join(config.AllowMethods, ",")
allowHeaders := strings.Join(config.AllowHeaders, ",")
exposeHeaders := strings.Join(config.ExposeHeaders, ",")
maxAge := strconv.Itoa(config.MaxAge)
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
rq := c.Request()
origin := c.Request().Header().Get(echo.HeaderOrigin)
header := c.Response().Header()
// Check allowed origins
allowedOrigin := ""
for _, o := range config.AllowOrigins {
if o == "*" || o == origin {
allowedOrigin = o
break
}
}
// Simple request
if rq.Method() != echo.OPTIONS {
header.Add(echo.HeaderVary, echo.HeaderOrigin)
if origin == "" || allowedOrigin == "" {
return next(c)
}
header.Set(echo.HeaderAccessControlAllowOrigin, allowedOrigin)
if config.AllowCredentials {
header.Set(echo.HeaderAccessControlAllowCredentials, "true")
}
if exposeHeaders != "" {
header.Set(echo.HeaderAccessControlExposeHeaders, exposeHeaders)
}
return next(c)
}
// Preflight request
header.Add(echo.HeaderVary, echo.HeaderOrigin)
header.Add(echo.HeaderVary, echo.HeaderAccessControlRequestMethod)
header.Add(echo.HeaderVary, echo.HeaderAccessControlRequestHeaders)
if origin == "" || allowedOrigin == "" {
return next(c)
}
header.Set(echo.HeaderAccessControlAllowOrigin, allowedOrigin)
header.Set(echo.HeaderAccessControlAllowMethods, allowMethods)
if config.AllowCredentials {
header.Set(echo.HeaderAccessControlAllowCredentials, "true")
}
if allowHeaders != "" {
header.Set(echo.HeaderAccessControlAllowHeaders, allowHeaders)
} else {
h := rq.Header().Get(echo.HeaderAccessControlRequestHeaders)
if h != "" {
header.Set(echo.HeaderAccessControlAllowHeaders, h)
}
}
if config.MaxAge > 0 {
header.Set(echo.HeaderAccessControlMaxAge, maxAge)
}
return c.NoContent(http.StatusNoContent)
}
}
}

63
middleware/cors_test.go Normal file
View File

@ -0,0 +1,63 @@
package middleware
import (
"net/http"
"testing"
"github.com/labstack/echo"
"github.com/labstack/echo/test"
"github.com/stretchr/testify/assert"
)
func TestCORS(t *testing.T) {
e := echo.New()
rq := test.NewRequest(echo.GET, "/", nil)
rc := test.NewResponseRecorder()
c := echo.NewContext(rq, rc, e)
cors := CORSFromConfig(CORSConfig{
AllowCredentials: true,
})
h := cors(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})
// No origin header
h(c)
assert.Equal(t, "", rc.Header().Get(echo.HeaderAccessControlAllowOrigin))
// Wildcard origin
rq = test.NewRequest(echo.GET, "/", nil)
rc = test.NewResponseRecorder()
c = echo.NewContext(rq, rc, e)
rq.Header().Set(echo.HeaderOrigin, "localhost")
h(c)
assert.Equal(t, "*", rc.Header().Get(echo.HeaderAccessControlAllowOrigin))
// Simple request
rq = test.NewRequest(echo.GET, "/", nil)
rc = test.NewResponseRecorder()
c = echo.NewContext(rq, rc, e)
rq.Header().Set(echo.HeaderOrigin, "localhost")
cors = CORSFromConfig(CORSConfig{
AllowOrigins: []string{"localhost"},
AllowCredentials: true,
MaxAge: 3600,
})
h = cors(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})
h(c)
assert.Equal(t, "localhost", rc.Header().Get(echo.HeaderAccessControlAllowOrigin))
// Preflight request
rq = test.NewRequest(echo.OPTIONS, "/", nil)
rc = test.NewResponseRecorder()
c = echo.NewContext(rq, rc, e)
rq.Header().Set(echo.HeaderOrigin, "localhost")
rq.Header().Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
h(c)
assert.Equal(t, "localhost", rc.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.NotEmpty(t, rc.Header().Get(echo.HeaderAccessControlAllowMethods))
assert.Equal(t, "true", rc.Header().Get(echo.HeaderAccessControlAllowCredentials))
assert.Equal(t, "3600", rc.Header().Get(echo.HeaderAccessControlMaxAge))
}

View File

@ -0,0 +1 @@
package middleware