mirror of
https://github.com/labstack/echo.git
synced 2024-12-24 20:14:31 +02:00
Allow
header support in Router, MethodNotFoundHandler (405) and CORS middleware
This commit is contained in:
parent
4fffee2ec8
commit
5b26a5257b
@ -210,6 +210,13 @@ type (
|
||||
}
|
||||
)
|
||||
|
||||
const (
|
||||
// ContextKeyHeaderAllow is set by Router for getting value for `Allow` header in later stages of handler call chain.
|
||||
// Allow header is mandatory for status 405 (method not found) and useful for OPTIONS method requests.
|
||||
// It is added to context only when Router does not find matching method handler for request.
|
||||
ContextKeyHeaderAllow = "____echo____header_allow"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultMemory = 32 << 20 // 32 MB
|
||||
indexPage = "index.html"
|
||||
|
14
echo.go
14
echo.go
@ -190,8 +190,11 @@ const (
|
||||
|
||||
// Headers
|
||||
const (
|
||||
HeaderAccept = "Accept"
|
||||
HeaderAcceptEncoding = "Accept-Encoding"
|
||||
HeaderAccept = "Accept"
|
||||
HeaderAcceptEncoding = "Accept-Encoding"
|
||||
// HeaderAllow is header field that lists the set of methods advertised as supported by the target resource.
|
||||
// Allow header is mandatory for status 405 (method not found) and useful OPTIONS method responses.
|
||||
// See: https://datatracker.ietf.org/doc/html/rfc7231#section-7.4.1
|
||||
HeaderAllow = "Allow"
|
||||
HeaderAuthorization = "Authorization"
|
||||
HeaderContentDisposition = "Content-Disposition"
|
||||
@ -302,6 +305,13 @@ var (
|
||||
}
|
||||
|
||||
MethodNotAllowedHandler = func(c Context) error {
|
||||
// 'Allow' header RFC: https://datatracker.ietf.org/doc/html/rfc7231#section-7.4.1
|
||||
// >> An origin server MUST generate an Allow field in a 405 (Method Not Allowed) response
|
||||
// and MAY do so in any other response.
|
||||
routerAllowMethods, ok := c.Get(ContextKeyHeaderAllow).(string)
|
||||
if ok && routerAllowMethods != "" {
|
||||
c.Response().Header().Set(HeaderAllow, routerAllowMethods)
|
||||
}
|
||||
return ErrMethodNotAllowed
|
||||
}
|
||||
)
|
||||
|
@ -716,13 +716,16 @@ func TestEchoNotFound(t *testing.T) {
|
||||
|
||||
func TestEchoMethodNotAllowed(t *testing.T) {
|
||||
e := New()
|
||||
|
||||
e.GET("/", func(c Context) error {
|
||||
return c.String(http.StatusOK, "Echo!")
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
e.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusMethodNotAllowed, rec.Code)
|
||||
assert.Equal(t, "OPTIONS, GET", rec.Header().Get(HeaderAllow))
|
||||
}
|
||||
|
||||
func TestEchoContext(t *testing.T) {
|
||||
|
@ -29,6 +29,8 @@ type (
|
||||
// AllowMethods defines a list methods allowed when accessing the resource.
|
||||
// This is used in response to a preflight request.
|
||||
// Optional. Default value DefaultCORSConfig.AllowMethods.
|
||||
// If `allowMethods` is left empty will fill for preflight request `Access-Control-Allow-Methods` header value
|
||||
// from `Allow` header that echo.Router set into context.
|
||||
AllowMethods []string `yaml:"allow_methods"`
|
||||
|
||||
// AllowHeaders defines a list of request headers that can be used when
|
||||
@ -41,6 +43,8 @@ type (
|
||||
// a response to a preflight request, this indicates whether or not the
|
||||
// actual request can be made using credentials.
|
||||
// Optional. Default value false.
|
||||
// Security: avoid using `AllowCredentials = true` with `AllowOrigins = *`.
|
||||
// See http://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html
|
||||
AllowCredentials bool `yaml:"allow_credentials"`
|
||||
|
||||
// ExposeHeaders defines a whitelist headers that clients are allowed to
|
||||
@ -80,7 +84,9 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
|
||||
if len(config.AllowOrigins) == 0 {
|
||||
config.AllowOrigins = DefaultCORSConfig.AllowOrigins
|
||||
}
|
||||
hasCustomAllowMethods := true
|
||||
if len(config.AllowMethods) == 0 {
|
||||
hasCustomAllowMethods = false
|
||||
config.AllowMethods = DefaultCORSConfig.AllowMethods
|
||||
}
|
||||
|
||||
@ -109,10 +115,28 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
|
||||
origin := req.Header.Get(echo.HeaderOrigin)
|
||||
allowOrigin := ""
|
||||
|
||||
preflight := req.Method == http.MethodOptions
|
||||
res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)
|
||||
|
||||
// No Origin provided
|
||||
// Preflight request is an OPTIONS request, using three HTTP request headers: Access-Control-Request-Method,
|
||||
// Access-Control-Request-Headers, and the Origin header. See: https://developer.mozilla.org/en-US/docs/Glossary/Preflight_request
|
||||
// For simplicity we just consider method type and later `Origin` header.
|
||||
preflight := req.Method == http.MethodOptions
|
||||
|
||||
// Although router adds special handler in case of OPTIONS method we avoid calling next for OPTIONS in this middleware
|
||||
// as CORS requests do not have cookies / authentication headers by default, so we could get stuck in auth
|
||||
// middlewares by calling next(c).
|
||||
// But we still want to send `Allow` header as response in case of Non-CORS OPTIONS request as router default
|
||||
// handler does.
|
||||
routerAllowMethods := ""
|
||||
if preflight {
|
||||
tmpAllowMethods, ok := c.Get(echo.ContextKeyHeaderAllow).(string)
|
||||
if ok && tmpAllowMethods != "" {
|
||||
routerAllowMethods = tmpAllowMethods
|
||||
c.Response().Header().Set(echo.HeaderAllow, routerAllowMethods)
|
||||
}
|
||||
}
|
||||
|
||||
// No Origin provided. This is (probably) not request from actual browser - proceed executing middleware chain
|
||||
if origin == "" {
|
||||
if !preflight {
|
||||
return next(c)
|
||||
@ -145,19 +169,15 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// Check allowed origin patterns
|
||||
for _, re := range allowOriginPatterns {
|
||||
if allowOrigin == "" {
|
||||
didx := strings.Index(origin, "://")
|
||||
if didx == -1 {
|
||||
continue
|
||||
}
|
||||
domAuth := origin[didx+3:]
|
||||
// to avoid regex cost by invalid long domain
|
||||
if len(domAuth) > 253 {
|
||||
break
|
||||
}
|
||||
|
||||
checkPatterns := false
|
||||
if allowOrigin == "" {
|
||||
// to avoid regex cost by invalid (long) domains (253 is domain name max limit)
|
||||
if len(origin) <= (253+3+4) && strings.Contains(origin, "://") {
|
||||
checkPatterns = true
|
||||
}
|
||||
}
|
||||
if checkPatterns {
|
||||
for _, re := range allowOriginPatterns {
|
||||
if match, _ := regexp.MatchString(re, origin); match {
|
||||
allowOrigin = origin
|
||||
break
|
||||
@ -174,12 +194,13 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
|
||||
return c.NoContent(http.StatusNoContent)
|
||||
}
|
||||
|
||||
res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin)
|
||||
if config.AllowCredentials {
|
||||
res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true")
|
||||
}
|
||||
|
||||
// Simple request
|
||||
if !preflight {
|
||||
res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin)
|
||||
if config.AllowCredentials {
|
||||
res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true")
|
||||
}
|
||||
if exposeHeaders != "" {
|
||||
res.Header().Set(echo.HeaderAccessControlExposeHeaders, exposeHeaders)
|
||||
}
|
||||
@ -189,11 +210,13 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
|
||||
// Preflight request
|
||||
res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestMethod)
|
||||
res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestHeaders)
|
||||
res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin)
|
||||
res.Header().Set(echo.HeaderAccessControlAllowMethods, allowMethods)
|
||||
if config.AllowCredentials {
|
||||
res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true")
|
||||
|
||||
if !hasCustomAllowMethods && routerAllowMethods != "" {
|
||||
res.Header().Set(echo.HeaderAccessControlAllowMethods, routerAllowMethods)
|
||||
} else {
|
||||
res.Header().Set(echo.HeaderAccessControlAllowMethods, allowMethods)
|
||||
}
|
||||
|
||||
if allowHeaders != "" {
|
||||
res.Header().Set(echo.HeaderAccessControlAllowHeaders, allowHeaders)
|
||||
} else {
|
||||
|
@ -251,114 +251,238 @@ func Test_allowOriginSubdomain(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCorsHeaders(t *testing.T) {
|
||||
tests := []struct {
|
||||
domain, allowedOrigin, method string
|
||||
expected bool
|
||||
func TestCORSWithConfig_AllowMethods(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
name string
|
||||
allowOrigins []string
|
||||
allowContextKey string
|
||||
|
||||
whenOrigin string
|
||||
whenAllowMethods []string
|
||||
|
||||
expectAllow string
|
||||
expectAccessControlAllowMethods string
|
||||
}{
|
||||
{
|
||||
domain: "", // Request does not have Origin header
|
||||
allowedOrigin: "*",
|
||||
method: http.MethodGet,
|
||||
expected: false,
|
||||
name: "custom AllowMethods, preflight, no origin, sets only allow header from context key",
|
||||
allowContextKey: "OPTIONS, GET",
|
||||
whenAllowMethods: []string{http.MethodGet, http.MethodHead},
|
||||
whenOrigin: "",
|
||||
expectAllow: "OPTIONS, GET",
|
||||
},
|
||||
{
|
||||
domain: "http://example.com",
|
||||
allowedOrigin: "*",
|
||||
method: http.MethodGet,
|
||||
expected: true,
|
||||
name: "default AllowMethods, preflight, no origin, no allow header in context key and in response",
|
||||
allowContextKey: "",
|
||||
whenAllowMethods: nil,
|
||||
whenOrigin: "",
|
||||
expectAllow: "",
|
||||
},
|
||||
{
|
||||
domain: "", // Request does not have Origin header
|
||||
allowedOrigin: "http://example.com",
|
||||
method: http.MethodGet,
|
||||
expected: false,
|
||||
name: "custom AllowMethods, preflight, existing origin, sets both headers different values",
|
||||
allowContextKey: "OPTIONS, GET",
|
||||
whenAllowMethods: []string{http.MethodGet, http.MethodHead},
|
||||
whenOrigin: "http://google.com",
|
||||
expectAllow: "OPTIONS, GET",
|
||||
expectAccessControlAllowMethods: "GET,HEAD",
|
||||
},
|
||||
{
|
||||
domain: "http://bar.com",
|
||||
allowedOrigin: "http://example.com",
|
||||
method: http.MethodGet,
|
||||
expected: false,
|
||||
name: "default AllowMethods, preflight, existing origin, sets both headers",
|
||||
allowContextKey: "OPTIONS, GET",
|
||||
whenAllowMethods: nil,
|
||||
whenOrigin: "http://google.com",
|
||||
expectAllow: "OPTIONS, GET",
|
||||
expectAccessControlAllowMethods: "OPTIONS, GET",
|
||||
},
|
||||
{
|
||||
domain: "http://example.com",
|
||||
allowedOrigin: "http://example.com",
|
||||
method: http.MethodGet,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
domain: "", // Request does not have Origin header
|
||||
allowedOrigin: "*",
|
||||
method: http.MethodOptions,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
domain: "http://example.com",
|
||||
allowedOrigin: "*",
|
||||
method: http.MethodOptions,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
domain: "", // Request does not have Origin header
|
||||
allowedOrigin: "http://example.com",
|
||||
method: http.MethodOptions,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
domain: "http://bar.com",
|
||||
allowedOrigin: "http://example.com",
|
||||
method: http.MethodGet,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
domain: "http://example.com",
|
||||
allowedOrigin: "http://example.com",
|
||||
method: http.MethodOptions,
|
||||
expected: true,
|
||||
name: "default AllowMethods, preflight, existing origin, no allows, sets only CORS allow methods",
|
||||
allowContextKey: "",
|
||||
whenAllowMethods: nil,
|
||||
whenOrigin: "http://google.com",
|
||||
expectAllow: "",
|
||||
expectAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE",
|
||||
},
|
||||
}
|
||||
|
||||
e := echo.New()
|
||||
for _, tt := range tests {
|
||||
req := httptest.NewRequest(tt.method, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
if tt.domain != "" {
|
||||
req.Header.Set(echo.HeaderOrigin, tt.domain)
|
||||
}
|
||||
cors := CORSWithConfig(CORSConfig{
|
||||
AllowOrigins: []string{tt.allowedOrigin},
|
||||
//AllowCredentials: true,
|
||||
//MaxAge: 3600,
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := echo.New()
|
||||
e.GET("/test", func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "OK")
|
||||
})
|
||||
|
||||
cors := CORSWithConfig(CORSConfig{
|
||||
AllowOrigins: tc.allowOrigins,
|
||||
AllowMethods: tc.whenAllowMethods,
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodOptions, "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec)
|
||||
|
||||
req.Header.Set(echo.HeaderOrigin, tc.whenOrigin)
|
||||
if tc.allowContextKey != "" {
|
||||
c.Set(echo.ContextKeyHeaderAllow, tc.allowContextKey)
|
||||
}
|
||||
|
||||
h := cors(echo.NotFoundHandler)
|
||||
h(c)
|
||||
|
||||
assert.Equal(t, tc.expectAllow, rec.Header().Get(echo.HeaderAllow))
|
||||
assert.Equal(t, tc.expectAccessControlAllowMethods, rec.Header().Get(echo.HeaderAccessControlAllowMethods))
|
||||
})
|
||||
h := cors(echo.NotFoundHandler)
|
||||
h(c)
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equal(t, echo.HeaderOrigin, rec.Header().Get(echo.HeaderVary))
|
||||
func TestCorsHeaders(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
originDomain string
|
||||
method string
|
||||
allowedOrigin string
|
||||
expected bool
|
||||
expectStatus int
|
||||
expectAllowHeader string
|
||||
}{
|
||||
{
|
||||
name: "non-preflight request, allow any origin, missing origin header = no CORS logic done",
|
||||
originDomain: "",
|
||||
allowedOrigin: "*",
|
||||
method: http.MethodGet,
|
||||
expected: false,
|
||||
expectStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "non-preflight request, allow any origin, specific origin domain",
|
||||
originDomain: "http://example.com",
|
||||
allowedOrigin: "*",
|
||||
method: http.MethodGet,
|
||||
expected: true,
|
||||
expectStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "non-preflight request, allow specific origin, missing origin header = no CORS logic done",
|
||||
originDomain: "", // Request does not have Origin header
|
||||
allowedOrigin: "http://example.com",
|
||||
method: http.MethodGet,
|
||||
expected: false,
|
||||
expectStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "non-preflight request, allow specific origin, different origin header = CORS logic failure",
|
||||
originDomain: "http://bar.com",
|
||||
allowedOrigin: "http://example.com",
|
||||
method: http.MethodGet,
|
||||
expected: false,
|
||||
expectStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "non-preflight request, allow specific origin, matching origin header = CORS logic done",
|
||||
originDomain: "http://example.com",
|
||||
allowedOrigin: "http://example.com",
|
||||
method: http.MethodGet,
|
||||
expected: true,
|
||||
expectStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "preflight, allow any origin, missing origin header = no CORS logic done",
|
||||
originDomain: "", // Request does not have Origin header
|
||||
allowedOrigin: "*",
|
||||
method: http.MethodOptions,
|
||||
expected: false,
|
||||
expectStatus: http.StatusNoContent,
|
||||
expectAllowHeader: "OPTIONS, GET, POST",
|
||||
},
|
||||
{
|
||||
name: "preflight, allow any origin, existing origin header = CORS logic done",
|
||||
originDomain: "http://example.com",
|
||||
allowedOrigin: "*",
|
||||
method: http.MethodOptions,
|
||||
expected: true,
|
||||
expectStatus: http.StatusNoContent,
|
||||
expectAllowHeader: "OPTIONS, GET, POST",
|
||||
},
|
||||
{
|
||||
name: "preflight, allow any origin, missing origin header = no CORS logic done",
|
||||
originDomain: "", // Request does not have Origin header
|
||||
allowedOrigin: "http://example.com",
|
||||
method: http.MethodOptions,
|
||||
expected: false,
|
||||
expectStatus: http.StatusNoContent,
|
||||
expectAllowHeader: "OPTIONS, GET, POST",
|
||||
},
|
||||
{
|
||||
name: "preflight, allow specific origin, different origin header = no CORS logic done",
|
||||
originDomain: "http://bar.com",
|
||||
allowedOrigin: "http://example.com",
|
||||
method: http.MethodOptions,
|
||||
expected: false,
|
||||
expectStatus: http.StatusNoContent,
|
||||
expectAllowHeader: "OPTIONS, GET, POST",
|
||||
},
|
||||
{
|
||||
name: "preflight, allow specific origin, matching origin header = CORS logic done",
|
||||
originDomain: "http://example.com",
|
||||
allowedOrigin: "http://example.com",
|
||||
method: http.MethodOptions,
|
||||
expected: true,
|
||||
expectStatus: http.StatusNoContent,
|
||||
expectAllowHeader: "OPTIONS, GET, POST",
|
||||
},
|
||||
}
|
||||
|
||||
expectedAllowOrigin := ""
|
||||
if tt.allowedOrigin == "*" {
|
||||
expectedAllowOrigin = "*"
|
||||
} else {
|
||||
expectedAllowOrigin = tt.domain
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e := echo.New()
|
||||
|
||||
switch {
|
||||
case tt.expected && tt.method == http.MethodOptions:
|
||||
assert.Contains(t, rec.Header(), echo.HeaderAccessControlAllowMethods)
|
||||
assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||
assert.Equal(t, 3, len(rec.Header()[echo.HeaderVary]))
|
||||
case tt.expected && tt.method == http.MethodGet:
|
||||
assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||
assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin
|
||||
default:
|
||||
assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
|
||||
assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin
|
||||
}
|
||||
e.Use(CORSWithConfig(CORSConfig{
|
||||
AllowOrigins: []string{tc.allowedOrigin},
|
||||
//AllowCredentials: true,
|
||||
//MaxAge: 3600,
|
||||
}))
|
||||
|
||||
e.GET("/", func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "OK")
|
||||
})
|
||||
e.POST("/", func(c echo.Context) error {
|
||||
return c.String(http.StatusCreated, "OK")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(tc.method, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
if tc.originDomain != "" {
|
||||
req.Header.Set(echo.HeaderOrigin, tc.originDomain)
|
||||
}
|
||||
|
||||
// we run through whole Echo handler chain to see how CORS works with Router OPTIONS handler
|
||||
e.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, echo.HeaderOrigin, rec.Header().Get(echo.HeaderVary))
|
||||
assert.Equal(t, tc.expectAllowHeader, rec.Header().Get(echo.HeaderAllow))
|
||||
assert.Equal(t, tc.expectStatus, rec.Code)
|
||||
|
||||
expectedAllowOrigin := ""
|
||||
if tc.allowedOrigin == "*" {
|
||||
expectedAllowOrigin = "*"
|
||||
} else {
|
||||
expectedAllowOrigin = tc.originDomain
|
||||
}
|
||||
switch {
|
||||
case tc.expected && tc.method == http.MethodOptions:
|
||||
assert.Contains(t, rec.Header(), echo.HeaderAccessControlAllowMethods)
|
||||
assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||
|
||||
assert.Equal(t, 3, len(rec.Header()[echo.HeaderVary]))
|
||||
|
||||
case tc.expected && tc.method == http.MethodGet:
|
||||
assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
|
||||
assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin
|
||||
default:
|
||||
assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
|
||||
assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin
|
||||
}
|
||||
})
|
||||
|
||||
if tt.method == http.MethodOptions {
|
||||
assert.Equal(t, http.StatusNoContent, rec.Code)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
93
router.go
93
router.go
@ -1,6 +1,7 @@
|
||||
package echo
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
@ -31,17 +32,18 @@ type (
|
||||
kind uint8
|
||||
children []*node
|
||||
methodHandler struct {
|
||||
connect HandlerFunc
|
||||
delete HandlerFunc
|
||||
get HandlerFunc
|
||||
head HandlerFunc
|
||||
options HandlerFunc
|
||||
patch HandlerFunc
|
||||
post HandlerFunc
|
||||
propfind HandlerFunc
|
||||
put HandlerFunc
|
||||
trace HandlerFunc
|
||||
report HandlerFunc
|
||||
connect HandlerFunc
|
||||
delete HandlerFunc
|
||||
get HandlerFunc
|
||||
head HandlerFunc
|
||||
options HandlerFunc
|
||||
patch HandlerFunc
|
||||
post HandlerFunc
|
||||
propfind HandlerFunc
|
||||
put HandlerFunc
|
||||
trace HandlerFunc
|
||||
report HandlerFunc
|
||||
allowHeader string
|
||||
}
|
||||
)
|
||||
|
||||
@ -68,6 +70,51 @@ func (m *methodHandler) isHandler() bool {
|
||||
m.report != nil
|
||||
}
|
||||
|
||||
func (m *methodHandler) updateAllowHeader() {
|
||||
buf := new(bytes.Buffer)
|
||||
buf.WriteString(http.MethodOptions)
|
||||
|
||||
if m.connect != nil {
|
||||
buf.WriteString(", ")
|
||||
buf.WriteString(http.MethodConnect)
|
||||
}
|
||||
if m.delete != nil {
|
||||
buf.WriteString(", ")
|
||||
buf.WriteString(http.MethodDelete)
|
||||
}
|
||||
if m.get != nil {
|
||||
buf.WriteString(", ")
|
||||
buf.WriteString(http.MethodGet)
|
||||
}
|
||||
if m.head != nil {
|
||||
buf.WriteString(", ")
|
||||
buf.WriteString(http.MethodHead)
|
||||
}
|
||||
if m.patch != nil {
|
||||
buf.WriteString(", ")
|
||||
buf.WriteString(http.MethodPatch)
|
||||
}
|
||||
if m.post != nil {
|
||||
buf.WriteString(", ")
|
||||
buf.WriteString(http.MethodPost)
|
||||
}
|
||||
if m.propfind != nil {
|
||||
buf.WriteString(", PROPFIND")
|
||||
}
|
||||
if m.put != nil {
|
||||
buf.WriteString(", ")
|
||||
buf.WriteString(http.MethodPut)
|
||||
}
|
||||
if m.trace != nil {
|
||||
buf.WriteString(", ")
|
||||
buf.WriteString(http.MethodTrace)
|
||||
}
|
||||
if m.report != nil {
|
||||
buf.WriteString(", REPORT")
|
||||
}
|
||||
m.allowHeader = buf.String()
|
||||
}
|
||||
|
||||
// NewRouter returns a new Router instance.
|
||||
func NewRouter(e *Echo) *Router {
|
||||
return &Router{
|
||||
@ -326,6 +373,7 @@ func (n *node) addHandler(method string, h HandlerFunc) {
|
||||
n.methodHandler.report = h
|
||||
}
|
||||
|
||||
n.methodHandler.updateAllowHeader()
|
||||
if h != nil {
|
||||
n.isHandler = true
|
||||
} else {
|
||||
@ -362,13 +410,14 @@ func (n *node) findHandler(method string) HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func (n *node) checkMethodNotAllowed() HandlerFunc {
|
||||
for _, m := range methods {
|
||||
if h := n.findHandler(m); h != nil {
|
||||
return MethodNotAllowedHandler
|
||||
}
|
||||
func optionsMethodHandler(allowMethods string) func(c Context) error {
|
||||
return func(c Context) error {
|
||||
// Note: we are not handling most of the CORS headers here. CORS is handled by CORS middleware
|
||||
// 'OPTIONS' method RFC: https://httpwg.org/specs/rfc7231.html#OPTIONS
|
||||
// 'Allow' header RFC: https://datatracker.ietf.org/doc/html/rfc7231#section-7.4.1
|
||||
c.Response().Header().Add(HeaderAllow, allowMethods)
|
||||
return c.NoContent(http.StatusNoContent)
|
||||
}
|
||||
return NotFoundHandler
|
||||
}
|
||||
|
||||
// Find lookup a handler registered for method and path. It also parses URL for path
|
||||
@ -563,7 +612,15 @@ func (r *Router) Find(method, path string, c Context) {
|
||||
// use previous match as basis. although we have no matching handler we have path match.
|
||||
// so we can send http.StatusMethodNotAllowed (405) instead of http.StatusNotFound (404)
|
||||
currentNode = previousBestMatchNode
|
||||
ctx.handler = currentNode.checkMethodNotAllowed()
|
||||
|
||||
ctx.handler = NotFoundHandler
|
||||
if currentNode.isHandler {
|
||||
ctx.Set(ContextKeyHeaderAllow, currentNode.methodHandler.allowHeader)
|
||||
ctx.handler = MethodNotAllowedHandler
|
||||
if method == http.MethodOptions {
|
||||
ctx.handler = optionsMethodHandler(currentNode.methodHandler.allowHeader)
|
||||
}
|
||||
}
|
||||
}
|
||||
ctx.path = currentNode.ppath
|
||||
ctx.pnames = currentNode.pnames
|
||||
|
122
router_test.go
122
router_test.go
@ -3,6 +3,7 @@ package echo
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@ -725,12 +726,13 @@ func TestMethodNotAllowedAndNotFound(t *testing.T) {
|
||||
r.Add(http.MethodPost, "/users/:id", handlerFunc)
|
||||
|
||||
var testCases = []struct {
|
||||
name string
|
||||
whenMethod string
|
||||
whenURL string
|
||||
expectRoute interface{}
|
||||
expectParam map[string]string
|
||||
expectError error
|
||||
name string
|
||||
whenMethod string
|
||||
whenURL string
|
||||
expectRoute interface{}
|
||||
expectParam map[string]string
|
||||
expectError error
|
||||
expectAllowHeader string
|
||||
}{
|
||||
{
|
||||
name: "exact match for route+method",
|
||||
@ -740,11 +742,12 @@ func TestMethodNotAllowedAndNotFound(t *testing.T) {
|
||||
expectParam: map[string]string{"id": "1"},
|
||||
},
|
||||
{
|
||||
name: "matches node but not method. sends 405 from best match node",
|
||||
whenMethod: http.MethodPut,
|
||||
whenURL: "/users/1",
|
||||
expectRoute: nil,
|
||||
expectError: ErrMethodNotAllowed,
|
||||
name: "matches node but not method. sends 405 from best match node",
|
||||
whenMethod: http.MethodPut,
|
||||
whenURL: "/users/1",
|
||||
expectRoute: nil,
|
||||
expectError: ErrMethodNotAllowed,
|
||||
expectAllowHeader: "OPTIONS, POST",
|
||||
},
|
||||
{
|
||||
name: "best match is any route up in tree",
|
||||
@ -756,7 +759,9 @@ func TestMethodNotAllowedAndNotFound(t *testing.T) {
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
c := e.NewContext(nil, nil).(*context)
|
||||
req := httptest.NewRequest(tc.whenMethod, tc.whenURL, nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec).(*context)
|
||||
|
||||
method := http.MethodGet
|
||||
if tc.whenMethod != "" {
|
||||
@ -775,10 +780,36 @@ func TestMethodNotAllowedAndNotFound(t *testing.T) {
|
||||
assert.Equal(t, expectedValue, c.Param(param))
|
||||
}
|
||||
checkUnusedParamValues(t, c, tc.expectParam)
|
||||
|
||||
assert.Equal(t, tc.expectAllowHeader, c.Response().Header().Get(HeaderAllow))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouterOptionsMethodHandler(t *testing.T) {
|
||||
e := New()
|
||||
|
||||
var keyInContext interface{}
|
||||
e.Use(func(next HandlerFunc) HandlerFunc {
|
||||
return func(c Context) error {
|
||||
err := next(c)
|
||||
keyInContext = c.Get(ContextKeyHeaderAllow)
|
||||
return err
|
||||
}
|
||||
})
|
||||
e.GET("/test", func(c Context) error {
|
||||
return c.String(http.StatusOK, "Echo!")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodOptions, "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
e.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusNoContent, rec.Code)
|
||||
assert.Equal(t, "OPTIONS, GET", rec.Header().Get(HeaderAllow))
|
||||
assert.Equal(t, "OPTIONS, GET", keyInContext)
|
||||
}
|
||||
|
||||
func TestRouterTwoParam(t *testing.T) {
|
||||
e := New()
|
||||
r := e.router
|
||||
@ -2288,6 +2319,73 @@ func TestRouterPanicWhenParamNoRootOnlyChildsFailsFind(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouterHandleMethodOptions(t *testing.T) {
|
||||
e := New()
|
||||
r := e.router
|
||||
|
||||
r.Add(http.MethodGet, "/users", handlerFunc)
|
||||
r.Add(http.MethodPost, "/users", handlerFunc)
|
||||
r.Add(http.MethodPut, "/users/:id", handlerFunc)
|
||||
r.Add(http.MethodGet, "/users/:id", handlerFunc)
|
||||
|
||||
var testCases = []struct {
|
||||
name string
|
||||
whenMethod string
|
||||
whenURL string
|
||||
expectAllowHeader string
|
||||
expectStatus int
|
||||
}{
|
||||
{
|
||||
name: "allows GET and POST handlers",
|
||||
whenMethod: http.MethodOptions,
|
||||
whenURL: "/users",
|
||||
expectAllowHeader: "OPTIONS, GET, POST",
|
||||
expectStatus: http.StatusNoContent,
|
||||
},
|
||||
{
|
||||
name: "allows GET and PUT handlers",
|
||||
whenMethod: http.MethodOptions,
|
||||
whenURL: "/users/1",
|
||||
expectAllowHeader: "OPTIONS, GET, PUT",
|
||||
expectStatus: http.StatusNoContent,
|
||||
},
|
||||
{
|
||||
name: "GET does not have allows header",
|
||||
whenMethod: http.MethodGet,
|
||||
whenURL: "/users",
|
||||
expectAllowHeader: "",
|
||||
expectStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "path with no handlers does not set Allows header",
|
||||
whenMethod: http.MethodOptions,
|
||||
whenURL: "/notFound",
|
||||
expectAllowHeader: "",
|
||||
expectStatus: http.StatusNotFound,
|
||||
},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(tc.whenMethod, tc.whenURL, nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c := e.NewContext(req, rec).(*context)
|
||||
|
||||
r.Find(tc.whenMethod, tc.whenURL, c)
|
||||
err := c.handler(c)
|
||||
|
||||
if tc.expectStatus >= 400 {
|
||||
assert.Error(t, err)
|
||||
he := err.(*HTTPError)
|
||||
assert.Equal(t, tc.expectStatus, he.Code)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tc.expectStatus, rec.Code)
|
||||
}
|
||||
assert.Equal(t, tc.expectAllowHeader, c.Response().Header().Get("Allow"))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func benchmarkRouterRoutes(b *testing.B, routes []*Route, routesToFind []*Route) {
|
||||
e := New()
|
||||
r := e.router
|
||||
|
Loading…
Reference in New Issue
Block a user