1
0
mirror of https://github.com/labstack/echo.git synced 2025-07-15 01:34:53 +02:00

Allow header support in Router, MethodNotFoundHandler (405) and CORS middleware

This commit is contained in:
toimtoimtoim
2021-12-04 20:02:11 +02:00
committed by Martti T
parent 4fffee2ec8
commit 5b26a5257b
7 changed files with 467 additions and 145 deletions

View File

@ -210,6 +210,13 @@ type (
} }
) )
const (
// ContextKeyHeaderAllow is set by Router for getting value for `Allow` header in later stages of handler call chain.
// Allow header is mandatory for status 405 (method not found) and useful for OPTIONS method requests.
// It is added to context only when Router does not find matching method handler for request.
ContextKeyHeaderAllow = "____echo____header_allow"
)
const ( const (
defaultMemory = 32 << 20 // 32 MB defaultMemory = 32 << 20 // 32 MB
indexPage = "index.html" indexPage = "index.html"

14
echo.go
View File

@ -190,8 +190,11 @@ const (
// Headers // Headers
const ( const (
HeaderAccept = "Accept" HeaderAccept = "Accept"
HeaderAcceptEncoding = "Accept-Encoding" HeaderAcceptEncoding = "Accept-Encoding"
// HeaderAllow is header field that lists the set of methods advertised as supported by the target resource.
// Allow header is mandatory for status 405 (method not found) and useful OPTIONS method responses.
// See: https://datatracker.ietf.org/doc/html/rfc7231#section-7.4.1
HeaderAllow = "Allow" HeaderAllow = "Allow"
HeaderAuthorization = "Authorization" HeaderAuthorization = "Authorization"
HeaderContentDisposition = "Content-Disposition" HeaderContentDisposition = "Content-Disposition"
@ -302,6 +305,13 @@ var (
} }
MethodNotAllowedHandler = func(c Context) error { MethodNotAllowedHandler = func(c Context) error {
// 'Allow' header RFC: https://datatracker.ietf.org/doc/html/rfc7231#section-7.4.1
// >> An origin server MUST generate an Allow field in a 405 (Method Not Allowed) response
// and MAY do so in any other response.
routerAllowMethods, ok := c.Get(ContextKeyHeaderAllow).(string)
if ok && routerAllowMethods != "" {
c.Response().Header().Set(HeaderAllow, routerAllowMethods)
}
return ErrMethodNotAllowed return ErrMethodNotAllowed
} }
) )

View File

@ -716,13 +716,16 @@ func TestEchoNotFound(t *testing.T) {
func TestEchoMethodNotAllowed(t *testing.T) { func TestEchoMethodNotAllowed(t *testing.T) {
e := New() e := New()
e.GET("/", func(c Context) error { e.GET("/", func(c Context) error {
return c.String(http.StatusOK, "Echo!") return c.String(http.StatusOK, "Echo!")
}) })
req := httptest.NewRequest(http.MethodPost, "/", nil) req := httptest.NewRequest(http.MethodPost, "/", nil)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
e.ServeHTTP(rec, req) e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusMethodNotAllowed, rec.Code) assert.Equal(t, http.StatusMethodNotAllowed, rec.Code)
assert.Equal(t, "OPTIONS, GET", rec.Header().Get(HeaderAllow))
} }
func TestEchoContext(t *testing.T) { func TestEchoContext(t *testing.T) {

View File

@ -29,6 +29,8 @@ type (
// AllowMethods defines a list methods allowed when accessing the resource. // AllowMethods defines a list methods allowed when accessing the resource.
// This is used in response to a preflight request. // This is used in response to a preflight request.
// Optional. Default value DefaultCORSConfig.AllowMethods. // Optional. Default value DefaultCORSConfig.AllowMethods.
// If `allowMethods` is left empty will fill for preflight request `Access-Control-Allow-Methods` header value
// from `Allow` header that echo.Router set into context.
AllowMethods []string `yaml:"allow_methods"` AllowMethods []string `yaml:"allow_methods"`
// AllowHeaders defines a list of request headers that can be used when // AllowHeaders defines a list of request headers that can be used when
@ -41,6 +43,8 @@ type (
// a response to a preflight request, this indicates whether or not the // a response to a preflight request, this indicates whether or not the
// actual request can be made using credentials. // actual request can be made using credentials.
// Optional. Default value false. // Optional. Default value false.
// Security: avoid using `AllowCredentials = true` with `AllowOrigins = *`.
// See http://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html
AllowCredentials bool `yaml:"allow_credentials"` AllowCredentials bool `yaml:"allow_credentials"`
// ExposeHeaders defines a whitelist headers that clients are allowed to // ExposeHeaders defines a whitelist headers that clients are allowed to
@ -80,7 +84,9 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
if len(config.AllowOrigins) == 0 { if len(config.AllowOrigins) == 0 {
config.AllowOrigins = DefaultCORSConfig.AllowOrigins config.AllowOrigins = DefaultCORSConfig.AllowOrigins
} }
hasCustomAllowMethods := true
if len(config.AllowMethods) == 0 { if len(config.AllowMethods) == 0 {
hasCustomAllowMethods = false
config.AllowMethods = DefaultCORSConfig.AllowMethods config.AllowMethods = DefaultCORSConfig.AllowMethods
} }
@ -109,10 +115,28 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
origin := req.Header.Get(echo.HeaderOrigin) origin := req.Header.Get(echo.HeaderOrigin)
allowOrigin := "" allowOrigin := ""
preflight := req.Method == http.MethodOptions
res.Header().Add(echo.HeaderVary, echo.HeaderOrigin) res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)
// No Origin provided // Preflight request is an OPTIONS request, using three HTTP request headers: Access-Control-Request-Method,
// Access-Control-Request-Headers, and the Origin header. See: https://developer.mozilla.org/en-US/docs/Glossary/Preflight_request
// For simplicity we just consider method type and later `Origin` header.
preflight := req.Method == http.MethodOptions
// Although router adds special handler in case of OPTIONS method we avoid calling next for OPTIONS in this middleware
// as CORS requests do not have cookies / authentication headers by default, so we could get stuck in auth
// middlewares by calling next(c).
// But we still want to send `Allow` header as response in case of Non-CORS OPTIONS request as router default
// handler does.
routerAllowMethods := ""
if preflight {
tmpAllowMethods, ok := c.Get(echo.ContextKeyHeaderAllow).(string)
if ok && tmpAllowMethods != "" {
routerAllowMethods = tmpAllowMethods
c.Response().Header().Set(echo.HeaderAllow, routerAllowMethods)
}
}
// No Origin provided. This is (probably) not request from actual browser - proceed executing middleware chain
if origin == "" { if origin == "" {
if !preflight { if !preflight {
return next(c) return next(c)
@ -145,19 +169,15 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
} }
} }
// Check allowed origin patterns checkPatterns := false
for _, re := range allowOriginPatterns { if allowOrigin == "" {
if allowOrigin == "" { // to avoid regex cost by invalid (long) domains (253 is domain name max limit)
didx := strings.Index(origin, "://") if len(origin) <= (253+3+4) && strings.Contains(origin, "://") {
if didx == -1 { checkPatterns = true
continue }
} }
domAuth := origin[didx+3:] if checkPatterns {
// to avoid regex cost by invalid long domain for _, re := range allowOriginPatterns {
if len(domAuth) > 253 {
break
}
if match, _ := regexp.MatchString(re, origin); match { if match, _ := regexp.MatchString(re, origin); match {
allowOrigin = origin allowOrigin = origin
break break
@ -174,12 +194,13 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
return c.NoContent(http.StatusNoContent) return c.NoContent(http.StatusNoContent)
} }
res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin)
if config.AllowCredentials {
res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true")
}
// Simple request // Simple request
if !preflight { if !preflight {
res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin)
if config.AllowCredentials {
res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true")
}
if exposeHeaders != "" { if exposeHeaders != "" {
res.Header().Set(echo.HeaderAccessControlExposeHeaders, exposeHeaders) res.Header().Set(echo.HeaderAccessControlExposeHeaders, exposeHeaders)
} }
@ -189,11 +210,13 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
// Preflight request // Preflight request
res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestMethod) res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestMethod)
res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestHeaders) res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestHeaders)
res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin)
res.Header().Set(echo.HeaderAccessControlAllowMethods, allowMethods) if !hasCustomAllowMethods && routerAllowMethods != "" {
if config.AllowCredentials { res.Header().Set(echo.HeaderAccessControlAllowMethods, routerAllowMethods)
res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true") } else {
res.Header().Set(echo.HeaderAccessControlAllowMethods, allowMethods)
} }
if allowHeaders != "" { if allowHeaders != "" {
res.Header().Set(echo.HeaderAccessControlAllowHeaders, allowHeaders) res.Header().Set(echo.HeaderAccessControlAllowHeaders, allowHeaders)
} else { } else {

View File

@ -251,114 +251,238 @@ func Test_allowOriginSubdomain(t *testing.T) {
} }
} }
func TestCorsHeaders(t *testing.T) { func TestCORSWithConfig_AllowMethods(t *testing.T) {
tests := []struct { var testCases = []struct {
domain, allowedOrigin, method string name string
expected bool allowOrigins []string
allowContextKey string
whenOrigin string
whenAllowMethods []string
expectAllow string
expectAccessControlAllowMethods string
}{ }{
{ {
domain: "", // Request does not have Origin header name: "custom AllowMethods, preflight, no origin, sets only allow header from context key",
allowedOrigin: "*", allowContextKey: "OPTIONS, GET",
method: http.MethodGet, whenAllowMethods: []string{http.MethodGet, http.MethodHead},
expected: false, whenOrigin: "",
expectAllow: "OPTIONS, GET",
}, },
{ {
domain: "http://example.com", name: "default AllowMethods, preflight, no origin, no allow header in context key and in response",
allowedOrigin: "*", allowContextKey: "",
method: http.MethodGet, whenAllowMethods: nil,
expected: true, whenOrigin: "",
expectAllow: "",
}, },
{ {
domain: "", // Request does not have Origin header name: "custom AllowMethods, preflight, existing origin, sets both headers different values",
allowedOrigin: "http://example.com", allowContextKey: "OPTIONS, GET",
method: http.MethodGet, whenAllowMethods: []string{http.MethodGet, http.MethodHead},
expected: false, whenOrigin: "http://google.com",
expectAllow: "OPTIONS, GET",
expectAccessControlAllowMethods: "GET,HEAD",
}, },
{ {
domain: "http://bar.com", name: "default AllowMethods, preflight, existing origin, sets both headers",
allowedOrigin: "http://example.com", allowContextKey: "OPTIONS, GET",
method: http.MethodGet, whenAllowMethods: nil,
expected: false, whenOrigin: "http://google.com",
expectAllow: "OPTIONS, GET",
expectAccessControlAllowMethods: "OPTIONS, GET",
}, },
{ {
domain: "http://example.com", name: "default AllowMethods, preflight, existing origin, no allows, sets only CORS allow methods",
allowedOrigin: "http://example.com", allowContextKey: "",
method: http.MethodGet, whenAllowMethods: nil,
expected: true, whenOrigin: "http://google.com",
}, expectAllow: "",
{ expectAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE",
domain: "", // Request does not have Origin header
allowedOrigin: "*",
method: http.MethodOptions,
expected: false,
},
{
domain: "http://example.com",
allowedOrigin: "*",
method: http.MethodOptions,
expected: true,
},
{
domain: "", // Request does not have Origin header
allowedOrigin: "http://example.com",
method: http.MethodOptions,
expected: false,
},
{
domain: "http://bar.com",
allowedOrigin: "http://example.com",
method: http.MethodGet,
expected: false,
},
{
domain: "http://example.com",
allowedOrigin: "http://example.com",
method: http.MethodOptions,
expected: true,
}, },
} }
e := echo.New() for _, tc := range testCases {
for _, tt := range tests { t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(tt.method, "/", nil) e := echo.New()
rec := httptest.NewRecorder() e.GET("/test", func(c echo.Context) error {
c := e.NewContext(req, rec) return c.String(http.StatusOK, "OK")
if tt.domain != "" { })
req.Header.Set(echo.HeaderOrigin, tt.domain)
} cors := CORSWithConfig(CORSConfig{
cors := CORSWithConfig(CORSConfig{ AllowOrigins: tc.allowOrigins,
AllowOrigins: []string{tt.allowedOrigin}, AllowMethods: tc.whenAllowMethods,
//AllowCredentials: true, })
//MaxAge: 3600,
req := httptest.NewRequest(http.MethodOptions, "/test", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
req.Header.Set(echo.HeaderOrigin, tc.whenOrigin)
if tc.allowContextKey != "" {
c.Set(echo.ContextKeyHeaderAllow, tc.allowContextKey)
}
h := cors(echo.NotFoundHandler)
h(c)
assert.Equal(t, tc.expectAllow, rec.Header().Get(echo.HeaderAllow))
assert.Equal(t, tc.expectAccessControlAllowMethods, rec.Header().Get(echo.HeaderAccessControlAllowMethods))
}) })
h := cors(echo.NotFoundHandler) }
h(c) }
assert.Equal(t, echo.HeaderOrigin, rec.Header().Get(echo.HeaderVary)) func TestCorsHeaders(t *testing.T) {
tests := []struct {
name string
originDomain string
method string
allowedOrigin string
expected bool
expectStatus int
expectAllowHeader string
}{
{
name: "non-preflight request, allow any origin, missing origin header = no CORS logic done",
originDomain: "",
allowedOrigin: "*",
method: http.MethodGet,
expected: false,
expectStatus: http.StatusOK,
},
{
name: "non-preflight request, allow any origin, specific origin domain",
originDomain: "http://example.com",
allowedOrigin: "*",
method: http.MethodGet,
expected: true,
expectStatus: http.StatusOK,
},
{
name: "non-preflight request, allow specific origin, missing origin header = no CORS logic done",
originDomain: "", // Request does not have Origin header
allowedOrigin: "http://example.com",
method: http.MethodGet,
expected: false,
expectStatus: http.StatusOK,
},
{
name: "non-preflight request, allow specific origin, different origin header = CORS logic failure",
originDomain: "http://bar.com",
allowedOrigin: "http://example.com",
method: http.MethodGet,
expected: false,
expectStatus: http.StatusOK,
},
{
name: "non-preflight request, allow specific origin, matching origin header = CORS logic done",
originDomain: "http://example.com",
allowedOrigin: "http://example.com",
method: http.MethodGet,
expected: true,
expectStatus: http.StatusOK,
},
{
name: "preflight, allow any origin, missing origin header = no CORS logic done",
originDomain: "", // Request does not have Origin header
allowedOrigin: "*",
method: http.MethodOptions,
expected: false,
expectStatus: http.StatusNoContent,
expectAllowHeader: "OPTIONS, GET, POST",
},
{
name: "preflight, allow any origin, existing origin header = CORS logic done",
originDomain: "http://example.com",
allowedOrigin: "*",
method: http.MethodOptions,
expected: true,
expectStatus: http.StatusNoContent,
expectAllowHeader: "OPTIONS, GET, POST",
},
{
name: "preflight, allow any origin, missing origin header = no CORS logic done",
originDomain: "", // Request does not have Origin header
allowedOrigin: "http://example.com",
method: http.MethodOptions,
expected: false,
expectStatus: http.StatusNoContent,
expectAllowHeader: "OPTIONS, GET, POST",
},
{
name: "preflight, allow specific origin, different origin header = no CORS logic done",
originDomain: "http://bar.com",
allowedOrigin: "http://example.com",
method: http.MethodOptions,
expected: false,
expectStatus: http.StatusNoContent,
expectAllowHeader: "OPTIONS, GET, POST",
},
{
name: "preflight, allow specific origin, matching origin header = CORS logic done",
originDomain: "http://example.com",
allowedOrigin: "http://example.com",
method: http.MethodOptions,
expected: true,
expectStatus: http.StatusNoContent,
expectAllowHeader: "OPTIONS, GET, POST",
},
}
expectedAllowOrigin := "" for _, tc := range tests {
if tt.allowedOrigin == "*" { t.Run(tc.name, func(t *testing.T) {
expectedAllowOrigin = "*" e := echo.New()
} else {
expectedAllowOrigin = tt.domain
}
switch { e.Use(CORSWithConfig(CORSConfig{
case tt.expected && tt.method == http.MethodOptions: AllowOrigins: []string{tc.allowedOrigin},
assert.Contains(t, rec.Header(), echo.HeaderAccessControlAllowMethods) //AllowCredentials: true,
assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) //MaxAge: 3600,
assert.Equal(t, 3, len(rec.Header()[echo.HeaderVary])) }))
case tt.expected && tt.method == http.MethodGet:
assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) e.GET("/", func(c echo.Context) error {
assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin return c.String(http.StatusOK, "OK")
default: })
assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin) e.POST("/", func(c echo.Context) error {
assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin return c.String(http.StatusCreated, "OK")
} })
req := httptest.NewRequest(tc.method, "/", nil)
rec := httptest.NewRecorder()
if tc.originDomain != "" {
req.Header.Set(echo.HeaderOrigin, tc.originDomain)
}
// we run through whole Echo handler chain to see how CORS works with Router OPTIONS handler
e.ServeHTTP(rec, req)
assert.Equal(t, echo.HeaderOrigin, rec.Header().Get(echo.HeaderVary))
assert.Equal(t, tc.expectAllowHeader, rec.Header().Get(echo.HeaderAllow))
assert.Equal(t, tc.expectStatus, rec.Code)
expectedAllowOrigin := ""
if tc.allowedOrigin == "*" {
expectedAllowOrigin = "*"
} else {
expectedAllowOrigin = tc.originDomain
}
switch {
case tc.expected && tc.method == http.MethodOptions:
assert.Contains(t, rec.Header(), echo.HeaderAccessControlAllowMethods)
assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.Equal(t, 3, len(rec.Header()[echo.HeaderVary]))
case tc.expected && tc.method == http.MethodGet:
assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin
default:
assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin
}
})
if tt.method == http.MethodOptions {
assert.Equal(t, http.StatusNoContent, rec.Code)
}
} }
} }

View File

@ -1,6 +1,7 @@
package echo package echo
import ( import (
"bytes"
"net/http" "net/http"
) )
@ -31,17 +32,18 @@ type (
kind uint8 kind uint8
children []*node children []*node
methodHandler struct { methodHandler struct {
connect HandlerFunc connect HandlerFunc
delete HandlerFunc delete HandlerFunc
get HandlerFunc get HandlerFunc
head HandlerFunc head HandlerFunc
options HandlerFunc options HandlerFunc
patch HandlerFunc patch HandlerFunc
post HandlerFunc post HandlerFunc
propfind HandlerFunc propfind HandlerFunc
put HandlerFunc put HandlerFunc
trace HandlerFunc trace HandlerFunc
report HandlerFunc report HandlerFunc
allowHeader string
} }
) )
@ -68,6 +70,51 @@ func (m *methodHandler) isHandler() bool {
m.report != nil m.report != nil
} }
func (m *methodHandler) updateAllowHeader() {
buf := new(bytes.Buffer)
buf.WriteString(http.MethodOptions)
if m.connect != nil {
buf.WriteString(", ")
buf.WriteString(http.MethodConnect)
}
if m.delete != nil {
buf.WriteString(", ")
buf.WriteString(http.MethodDelete)
}
if m.get != nil {
buf.WriteString(", ")
buf.WriteString(http.MethodGet)
}
if m.head != nil {
buf.WriteString(", ")
buf.WriteString(http.MethodHead)
}
if m.patch != nil {
buf.WriteString(", ")
buf.WriteString(http.MethodPatch)
}
if m.post != nil {
buf.WriteString(", ")
buf.WriteString(http.MethodPost)
}
if m.propfind != nil {
buf.WriteString(", PROPFIND")
}
if m.put != nil {
buf.WriteString(", ")
buf.WriteString(http.MethodPut)
}
if m.trace != nil {
buf.WriteString(", ")
buf.WriteString(http.MethodTrace)
}
if m.report != nil {
buf.WriteString(", REPORT")
}
m.allowHeader = buf.String()
}
// NewRouter returns a new Router instance. // NewRouter returns a new Router instance.
func NewRouter(e *Echo) *Router { func NewRouter(e *Echo) *Router {
return &Router{ return &Router{
@ -326,6 +373,7 @@ func (n *node) addHandler(method string, h HandlerFunc) {
n.methodHandler.report = h n.methodHandler.report = h
} }
n.methodHandler.updateAllowHeader()
if h != nil { if h != nil {
n.isHandler = true n.isHandler = true
} else { } else {
@ -362,13 +410,14 @@ func (n *node) findHandler(method string) HandlerFunc {
} }
} }
func (n *node) checkMethodNotAllowed() HandlerFunc { func optionsMethodHandler(allowMethods string) func(c Context) error {
for _, m := range methods { return func(c Context) error {
if h := n.findHandler(m); h != nil { // Note: we are not handling most of the CORS headers here. CORS is handled by CORS middleware
return MethodNotAllowedHandler // 'OPTIONS' method RFC: https://httpwg.org/specs/rfc7231.html#OPTIONS
} // 'Allow' header RFC: https://datatracker.ietf.org/doc/html/rfc7231#section-7.4.1
c.Response().Header().Add(HeaderAllow, allowMethods)
return c.NoContent(http.StatusNoContent)
} }
return NotFoundHandler
} }
// Find lookup a handler registered for method and path. It also parses URL for path // Find lookup a handler registered for method and path. It also parses URL for path
@ -563,7 +612,15 @@ func (r *Router) Find(method, path string, c Context) {
// use previous match as basis. although we have no matching handler we have path match. // use previous match as basis. although we have no matching handler we have path match.
// so we can send http.StatusMethodNotAllowed (405) instead of http.StatusNotFound (404) // so we can send http.StatusMethodNotAllowed (405) instead of http.StatusNotFound (404)
currentNode = previousBestMatchNode currentNode = previousBestMatchNode
ctx.handler = currentNode.checkMethodNotAllowed()
ctx.handler = NotFoundHandler
if currentNode.isHandler {
ctx.Set(ContextKeyHeaderAllow, currentNode.methodHandler.allowHeader)
ctx.handler = MethodNotAllowedHandler
if method == http.MethodOptions {
ctx.handler = optionsMethodHandler(currentNode.methodHandler.allowHeader)
}
}
} }
ctx.path = currentNode.ppath ctx.path = currentNode.ppath
ctx.pnames = currentNode.pnames ctx.pnames = currentNode.pnames

View File

@ -3,6 +3,7 @@ package echo
import ( import (
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest"
"strings" "strings"
"testing" "testing"
@ -725,12 +726,13 @@ func TestMethodNotAllowedAndNotFound(t *testing.T) {
r.Add(http.MethodPost, "/users/:id", handlerFunc) r.Add(http.MethodPost, "/users/:id", handlerFunc)
var testCases = []struct { var testCases = []struct {
name string name string
whenMethod string whenMethod string
whenURL string whenURL string
expectRoute interface{} expectRoute interface{}
expectParam map[string]string expectParam map[string]string
expectError error expectError error
expectAllowHeader string
}{ }{
{ {
name: "exact match for route+method", name: "exact match for route+method",
@ -740,11 +742,12 @@ func TestMethodNotAllowedAndNotFound(t *testing.T) {
expectParam: map[string]string{"id": "1"}, expectParam: map[string]string{"id": "1"},
}, },
{ {
name: "matches node but not method. sends 405 from best match node", name: "matches node but not method. sends 405 from best match node",
whenMethod: http.MethodPut, whenMethod: http.MethodPut,
whenURL: "/users/1", whenURL: "/users/1",
expectRoute: nil, expectRoute: nil,
expectError: ErrMethodNotAllowed, expectError: ErrMethodNotAllowed,
expectAllowHeader: "OPTIONS, POST",
}, },
{ {
name: "best match is any route up in tree", name: "best match is any route up in tree",
@ -756,7 +759,9 @@ func TestMethodNotAllowedAndNotFound(t *testing.T) {
} }
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
c := e.NewContext(nil, nil).(*context) req := httptest.NewRequest(tc.whenMethod, tc.whenURL, nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec).(*context)
method := http.MethodGet method := http.MethodGet
if tc.whenMethod != "" { if tc.whenMethod != "" {
@ -775,10 +780,36 @@ func TestMethodNotAllowedAndNotFound(t *testing.T) {
assert.Equal(t, expectedValue, c.Param(param)) assert.Equal(t, expectedValue, c.Param(param))
} }
checkUnusedParamValues(t, c, tc.expectParam) checkUnusedParamValues(t, c, tc.expectParam)
assert.Equal(t, tc.expectAllowHeader, c.Response().Header().Get(HeaderAllow))
}) })
} }
} }
func TestRouterOptionsMethodHandler(t *testing.T) {
e := New()
var keyInContext interface{}
e.Use(func(next HandlerFunc) HandlerFunc {
return func(c Context) error {
err := next(c)
keyInContext = c.Get(ContextKeyHeaderAllow)
return err
}
})
e.GET("/test", func(c Context) error {
return c.String(http.StatusOK, "Echo!")
})
req := httptest.NewRequest(http.MethodOptions, "/test", nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusNoContent, rec.Code)
assert.Equal(t, "OPTIONS, GET", rec.Header().Get(HeaderAllow))
assert.Equal(t, "OPTIONS, GET", keyInContext)
}
func TestRouterTwoParam(t *testing.T) { func TestRouterTwoParam(t *testing.T) {
e := New() e := New()
r := e.router r := e.router
@ -2288,6 +2319,73 @@ func TestRouterPanicWhenParamNoRootOnlyChildsFailsFind(t *testing.T) {
} }
} }
func TestRouterHandleMethodOptions(t *testing.T) {
e := New()
r := e.router
r.Add(http.MethodGet, "/users", handlerFunc)
r.Add(http.MethodPost, "/users", handlerFunc)
r.Add(http.MethodPut, "/users/:id", handlerFunc)
r.Add(http.MethodGet, "/users/:id", handlerFunc)
var testCases = []struct {
name string
whenMethod string
whenURL string
expectAllowHeader string
expectStatus int
}{
{
name: "allows GET and POST handlers",
whenMethod: http.MethodOptions,
whenURL: "/users",
expectAllowHeader: "OPTIONS, GET, POST",
expectStatus: http.StatusNoContent,
},
{
name: "allows GET and PUT handlers",
whenMethod: http.MethodOptions,
whenURL: "/users/1",
expectAllowHeader: "OPTIONS, GET, PUT",
expectStatus: http.StatusNoContent,
},
{
name: "GET does not have allows header",
whenMethod: http.MethodGet,
whenURL: "/users",
expectAllowHeader: "",
expectStatus: http.StatusOK,
},
{
name: "path with no handlers does not set Allows header",
whenMethod: http.MethodOptions,
whenURL: "/notFound",
expectAllowHeader: "",
expectStatus: http.StatusNotFound,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(tc.whenMethod, tc.whenURL, nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec).(*context)
r.Find(tc.whenMethod, tc.whenURL, c)
err := c.handler(c)
if tc.expectStatus >= 400 {
assert.Error(t, err)
he := err.(*HTTPError)
assert.Equal(t, tc.expectStatus, he.Code)
} else {
assert.NoError(t, err)
assert.Equal(t, tc.expectStatus, rec.Code)
}
assert.Equal(t, tc.expectAllowHeader, c.Response().Header().Get("Allow"))
})
}
}
func benchmarkRouterRoutes(b *testing.B, routes []*Route, routesToFind []*Route) { func benchmarkRouterRoutes(b *testing.B, routes []*Route, routesToFind []*Route) {
e := New() e := New()
r := e.router r := e.router