1
0
mirror of https://github.com/labstack/echo.git synced 2025-07-17 01:43:02 +02:00

Allow header support in Router, MethodNotFoundHandler (405) and CORS middleware

This commit is contained in:
toimtoimtoim
2021-12-04 20:02:11 +02:00
committed by Martti T
parent 4fffee2ec8
commit 5b26a5257b
7 changed files with 467 additions and 145 deletions

View File

@ -210,6 +210,13 @@ type (
} }
) )
const (
// ContextKeyHeaderAllow is set by Router for getting value for `Allow` header in later stages of handler call chain.
// Allow header is mandatory for status 405 (method not found) and useful for OPTIONS method requests.
// It is added to context only when Router does not find matching method handler for request.
ContextKeyHeaderAllow = "____echo____header_allow"
)
const ( const (
defaultMemory = 32 << 20 // 32 MB defaultMemory = 32 << 20 // 32 MB
indexPage = "index.html" indexPage = "index.html"

10
echo.go
View File

@ -192,6 +192,9 @@ const (
const ( const (
HeaderAccept = "Accept" HeaderAccept = "Accept"
HeaderAcceptEncoding = "Accept-Encoding" HeaderAcceptEncoding = "Accept-Encoding"
// HeaderAllow is header field that lists the set of methods advertised as supported by the target resource.
// Allow header is mandatory for status 405 (method not found) and useful OPTIONS method responses.
// See: https://datatracker.ietf.org/doc/html/rfc7231#section-7.4.1
HeaderAllow = "Allow" HeaderAllow = "Allow"
HeaderAuthorization = "Authorization" HeaderAuthorization = "Authorization"
HeaderContentDisposition = "Content-Disposition" HeaderContentDisposition = "Content-Disposition"
@ -302,6 +305,13 @@ var (
} }
MethodNotAllowedHandler = func(c Context) error { MethodNotAllowedHandler = func(c Context) error {
// 'Allow' header RFC: https://datatracker.ietf.org/doc/html/rfc7231#section-7.4.1
// >> An origin server MUST generate an Allow field in a 405 (Method Not Allowed) response
// and MAY do so in any other response.
routerAllowMethods, ok := c.Get(ContextKeyHeaderAllow).(string)
if ok && routerAllowMethods != "" {
c.Response().Header().Set(HeaderAllow, routerAllowMethods)
}
return ErrMethodNotAllowed return ErrMethodNotAllowed
} }
) )

View File

@ -716,13 +716,16 @@ func TestEchoNotFound(t *testing.T) {
func TestEchoMethodNotAllowed(t *testing.T) { func TestEchoMethodNotAllowed(t *testing.T) {
e := New() e := New()
e.GET("/", func(c Context) error { e.GET("/", func(c Context) error {
return c.String(http.StatusOK, "Echo!") return c.String(http.StatusOK, "Echo!")
}) })
req := httptest.NewRequest(http.MethodPost, "/", nil) req := httptest.NewRequest(http.MethodPost, "/", nil)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
e.ServeHTTP(rec, req) e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusMethodNotAllowed, rec.Code) assert.Equal(t, http.StatusMethodNotAllowed, rec.Code)
assert.Equal(t, "OPTIONS, GET", rec.Header().Get(HeaderAllow))
} }
func TestEchoContext(t *testing.T) { func TestEchoContext(t *testing.T) {

View File

@ -29,6 +29,8 @@ type (
// AllowMethods defines a list methods allowed when accessing the resource. // AllowMethods defines a list methods allowed when accessing the resource.
// This is used in response to a preflight request. // This is used in response to a preflight request.
// Optional. Default value DefaultCORSConfig.AllowMethods. // Optional. Default value DefaultCORSConfig.AllowMethods.
// If `allowMethods` is left empty will fill for preflight request `Access-Control-Allow-Methods` header value
// from `Allow` header that echo.Router set into context.
AllowMethods []string `yaml:"allow_methods"` AllowMethods []string `yaml:"allow_methods"`
// AllowHeaders defines a list of request headers that can be used when // AllowHeaders defines a list of request headers that can be used when
@ -41,6 +43,8 @@ type (
// a response to a preflight request, this indicates whether or not the // a response to a preflight request, this indicates whether or not the
// actual request can be made using credentials. // actual request can be made using credentials.
// Optional. Default value false. // Optional. Default value false.
// Security: avoid using `AllowCredentials = true` with `AllowOrigins = *`.
// See http://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html
AllowCredentials bool `yaml:"allow_credentials"` AllowCredentials bool `yaml:"allow_credentials"`
// ExposeHeaders defines a whitelist headers that clients are allowed to // ExposeHeaders defines a whitelist headers that clients are allowed to
@ -80,7 +84,9 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
if len(config.AllowOrigins) == 0 { if len(config.AllowOrigins) == 0 {
config.AllowOrigins = DefaultCORSConfig.AllowOrigins config.AllowOrigins = DefaultCORSConfig.AllowOrigins
} }
hasCustomAllowMethods := true
if len(config.AllowMethods) == 0 { if len(config.AllowMethods) == 0 {
hasCustomAllowMethods = false
config.AllowMethods = DefaultCORSConfig.AllowMethods config.AllowMethods = DefaultCORSConfig.AllowMethods
} }
@ -109,10 +115,28 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
origin := req.Header.Get(echo.HeaderOrigin) origin := req.Header.Get(echo.HeaderOrigin)
allowOrigin := "" allowOrigin := ""
preflight := req.Method == http.MethodOptions
res.Header().Add(echo.HeaderVary, echo.HeaderOrigin) res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)
// No Origin provided // Preflight request is an OPTIONS request, using three HTTP request headers: Access-Control-Request-Method,
// Access-Control-Request-Headers, and the Origin header. See: https://developer.mozilla.org/en-US/docs/Glossary/Preflight_request
// For simplicity we just consider method type and later `Origin` header.
preflight := req.Method == http.MethodOptions
// Although router adds special handler in case of OPTIONS method we avoid calling next for OPTIONS in this middleware
// as CORS requests do not have cookies / authentication headers by default, so we could get stuck in auth
// middlewares by calling next(c).
// But we still want to send `Allow` header as response in case of Non-CORS OPTIONS request as router default
// handler does.
routerAllowMethods := ""
if preflight {
tmpAllowMethods, ok := c.Get(echo.ContextKeyHeaderAllow).(string)
if ok && tmpAllowMethods != "" {
routerAllowMethods = tmpAllowMethods
c.Response().Header().Set(echo.HeaderAllow, routerAllowMethods)
}
}
// No Origin provided. This is (probably) not request from actual browser - proceed executing middleware chain
if origin == "" { if origin == "" {
if !preflight { if !preflight {
return next(c) return next(c)
@ -145,19 +169,15 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
} }
} }
// Check allowed origin patterns checkPatterns := false
for _, re := range allowOriginPatterns {
if allowOrigin == "" { if allowOrigin == "" {
didx := strings.Index(origin, "://") // to avoid regex cost by invalid (long) domains (253 is domain name max limit)
if didx == -1 { if len(origin) <= (253+3+4) && strings.Contains(origin, "://") {
continue checkPatterns = true
} }
domAuth := origin[didx+3:]
// to avoid regex cost by invalid long domain
if len(domAuth) > 253 {
break
} }
if checkPatterns {
for _, re := range allowOriginPatterns {
if match, _ := regexp.MatchString(re, origin); match { if match, _ := regexp.MatchString(re, origin); match {
allowOrigin = origin allowOrigin = origin
break break
@ -174,12 +194,13 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
return c.NoContent(http.StatusNoContent) return c.NoContent(http.StatusNoContent)
} }
// Simple request
if !preflight {
res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin) res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin)
if config.AllowCredentials { if config.AllowCredentials {
res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true") res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true")
} }
// Simple request
if !preflight {
if exposeHeaders != "" { if exposeHeaders != "" {
res.Header().Set(echo.HeaderAccessControlExposeHeaders, exposeHeaders) res.Header().Set(echo.HeaderAccessControlExposeHeaders, exposeHeaders)
} }
@ -189,11 +210,13 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
// Preflight request // Preflight request
res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestMethod) res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestMethod)
res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestHeaders) res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestHeaders)
res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin)
if !hasCustomAllowMethods && routerAllowMethods != "" {
res.Header().Set(echo.HeaderAccessControlAllowMethods, routerAllowMethods)
} else {
res.Header().Set(echo.HeaderAccessControlAllowMethods, allowMethods) res.Header().Set(echo.HeaderAccessControlAllowMethods, allowMethods)
if config.AllowCredentials {
res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true")
} }
if allowHeaders != "" { if allowHeaders != "" {
res.Header().Set(echo.HeaderAccessControlAllowHeaders, allowHeaders) res.Header().Set(echo.HeaderAccessControlAllowHeaders, allowHeaders)
} else { } else {

View File

@ -251,114 +251,238 @@ func Test_allowOriginSubdomain(t *testing.T) {
} }
} }
func TestCorsHeaders(t *testing.T) { func TestCORSWithConfig_AllowMethods(t *testing.T) {
tests := []struct { var testCases = []struct {
domain, allowedOrigin, method string name string
expected bool allowOrigins []string
allowContextKey string
whenOrigin string
whenAllowMethods []string
expectAllow string
expectAccessControlAllowMethods string
}{ }{
{ {
domain: "", // Request does not have Origin header name: "custom AllowMethods, preflight, no origin, sets only allow header from context key",
allowedOrigin: "*", allowContextKey: "OPTIONS, GET",
method: http.MethodGet, whenAllowMethods: []string{http.MethodGet, http.MethodHead},
expected: false, whenOrigin: "",
expectAllow: "OPTIONS, GET",
}, },
{ {
domain: "http://example.com", name: "default AllowMethods, preflight, no origin, no allow header in context key and in response",
allowedOrigin: "*", allowContextKey: "",
method: http.MethodGet, whenAllowMethods: nil,
expected: true, whenOrigin: "",
expectAllow: "",
}, },
{ {
domain: "", // Request does not have Origin header name: "custom AllowMethods, preflight, existing origin, sets both headers different values",
allowedOrigin: "http://example.com", allowContextKey: "OPTIONS, GET",
method: http.MethodGet, whenAllowMethods: []string{http.MethodGet, http.MethodHead},
expected: false, whenOrigin: "http://google.com",
expectAllow: "OPTIONS, GET",
expectAccessControlAllowMethods: "GET,HEAD",
}, },
{ {
domain: "http://bar.com", name: "default AllowMethods, preflight, existing origin, sets both headers",
allowedOrigin: "http://example.com", allowContextKey: "OPTIONS, GET",
method: http.MethodGet, whenAllowMethods: nil,
expected: false, whenOrigin: "http://google.com",
expectAllow: "OPTIONS, GET",
expectAccessControlAllowMethods: "OPTIONS, GET",
}, },
{ {
domain: "http://example.com", name: "default AllowMethods, preflight, existing origin, no allows, sets only CORS allow methods",
allowedOrigin: "http://example.com", allowContextKey: "",
method: http.MethodGet, whenAllowMethods: nil,
expected: true, whenOrigin: "http://google.com",
}, expectAllow: "",
{ expectAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE",
domain: "", // Request does not have Origin header
allowedOrigin: "*",
method: http.MethodOptions,
expected: false,
},
{
domain: "http://example.com",
allowedOrigin: "*",
method: http.MethodOptions,
expected: true,
},
{
domain: "", // Request does not have Origin header
allowedOrigin: "http://example.com",
method: http.MethodOptions,
expected: false,
},
{
domain: "http://bar.com",
allowedOrigin: "http://example.com",
method: http.MethodGet,
expected: false,
},
{
domain: "http://example.com",
allowedOrigin: "http://example.com",
method: http.MethodOptions,
expected: true,
}, },
} }
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New() e := echo.New()
for _, tt := range tests { e.GET("/test", func(c echo.Context) error {
req := httptest.NewRequest(tt.method, "/", nil) return c.String(http.StatusOK, "OK")
})
cors := CORSWithConfig(CORSConfig{
AllowOrigins: tc.allowOrigins,
AllowMethods: tc.whenAllowMethods,
})
req := httptest.NewRequest(http.MethodOptions, "/test", nil)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
if tt.domain != "" {
req.Header.Set(echo.HeaderOrigin, tt.domain) req.Header.Set(echo.HeaderOrigin, tc.whenOrigin)
if tc.allowContextKey != "" {
c.Set(echo.ContextKeyHeaderAllow, tc.allowContextKey)
} }
cors := CORSWithConfig(CORSConfig{
AllowOrigins: []string{tt.allowedOrigin},
//AllowCredentials: true,
//MaxAge: 3600,
})
h := cors(echo.NotFoundHandler) h := cors(echo.NotFoundHandler)
h(c) h(c)
assert.Equal(t, echo.HeaderOrigin, rec.Header().Get(echo.HeaderVary)) assert.Equal(t, tc.expectAllow, rec.Header().Get(echo.HeaderAllow))
assert.Equal(t, tc.expectAccessControlAllowMethods, rec.Header().Get(echo.HeaderAccessControlAllowMethods))
expectedAllowOrigin := "" })
if tt.allowedOrigin == "*" { }
expectedAllowOrigin = "*"
} else {
expectedAllowOrigin = tt.domain
} }
func TestCorsHeaders(t *testing.T) {
tests := []struct {
name string
originDomain string
method string
allowedOrigin string
expected bool
expectStatus int
expectAllowHeader string
}{
{
name: "non-preflight request, allow any origin, missing origin header = no CORS logic done",
originDomain: "",
allowedOrigin: "*",
method: http.MethodGet,
expected: false,
expectStatus: http.StatusOK,
},
{
name: "non-preflight request, allow any origin, specific origin domain",
originDomain: "http://example.com",
allowedOrigin: "*",
method: http.MethodGet,
expected: true,
expectStatus: http.StatusOK,
},
{
name: "non-preflight request, allow specific origin, missing origin header = no CORS logic done",
originDomain: "", // Request does not have Origin header
allowedOrigin: "http://example.com",
method: http.MethodGet,
expected: false,
expectStatus: http.StatusOK,
},
{
name: "non-preflight request, allow specific origin, different origin header = CORS logic failure",
originDomain: "http://bar.com",
allowedOrigin: "http://example.com",
method: http.MethodGet,
expected: false,
expectStatus: http.StatusOK,
},
{
name: "non-preflight request, allow specific origin, matching origin header = CORS logic done",
originDomain: "http://example.com",
allowedOrigin: "http://example.com",
method: http.MethodGet,
expected: true,
expectStatus: http.StatusOK,
},
{
name: "preflight, allow any origin, missing origin header = no CORS logic done",
originDomain: "", // Request does not have Origin header
allowedOrigin: "*",
method: http.MethodOptions,
expected: false,
expectStatus: http.StatusNoContent,
expectAllowHeader: "OPTIONS, GET, POST",
},
{
name: "preflight, allow any origin, existing origin header = CORS logic done",
originDomain: "http://example.com",
allowedOrigin: "*",
method: http.MethodOptions,
expected: true,
expectStatus: http.StatusNoContent,
expectAllowHeader: "OPTIONS, GET, POST",
},
{
name: "preflight, allow any origin, missing origin header = no CORS logic done",
originDomain: "", // Request does not have Origin header
allowedOrigin: "http://example.com",
method: http.MethodOptions,
expected: false,
expectStatus: http.StatusNoContent,
expectAllowHeader: "OPTIONS, GET, POST",
},
{
name: "preflight, allow specific origin, different origin header = no CORS logic done",
originDomain: "http://bar.com",
allowedOrigin: "http://example.com",
method: http.MethodOptions,
expected: false,
expectStatus: http.StatusNoContent,
expectAllowHeader: "OPTIONS, GET, POST",
},
{
name: "preflight, allow specific origin, matching origin header = CORS logic done",
originDomain: "http://example.com",
allowedOrigin: "http://example.com",
method: http.MethodOptions,
expected: true,
expectStatus: http.StatusNoContent,
expectAllowHeader: "OPTIONS, GET, POST",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
e.Use(CORSWithConfig(CORSConfig{
AllowOrigins: []string{tc.allowedOrigin},
//AllowCredentials: true,
//MaxAge: 3600,
}))
e.GET("/", func(c echo.Context) error {
return c.String(http.StatusOK, "OK")
})
e.POST("/", func(c echo.Context) error {
return c.String(http.StatusCreated, "OK")
})
req := httptest.NewRequest(tc.method, "/", nil)
rec := httptest.NewRecorder()
if tc.originDomain != "" {
req.Header.Set(echo.HeaderOrigin, tc.originDomain)
}
// we run through whole Echo handler chain to see how CORS works with Router OPTIONS handler
e.ServeHTTP(rec, req)
assert.Equal(t, echo.HeaderOrigin, rec.Header().Get(echo.HeaderVary))
assert.Equal(t, tc.expectAllowHeader, rec.Header().Get(echo.HeaderAllow))
assert.Equal(t, tc.expectStatus, rec.Code)
expectedAllowOrigin := ""
if tc.allowedOrigin == "*" {
expectedAllowOrigin = "*"
} else {
expectedAllowOrigin = tc.originDomain
}
switch { switch {
case tt.expected && tt.method == http.MethodOptions: case tc.expected && tc.method == http.MethodOptions:
assert.Contains(t, rec.Header(), echo.HeaderAccessControlAllowMethods) assert.Contains(t, rec.Header(), echo.HeaderAccessControlAllowMethods)
assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.Equal(t, 3, len(rec.Header()[echo.HeaderVary])) assert.Equal(t, 3, len(rec.Header()[echo.HeaderVary]))
case tt.expected && tt.method == http.MethodGet:
case tc.expected && tc.method == http.MethodGet:
assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin
default: default:
assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin) assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin
} }
})
if tt.method == http.MethodOptions {
assert.Equal(t, http.StatusNoContent, rec.Code)
}
} }
} }

View File

@ -1,6 +1,7 @@
package echo package echo
import ( import (
"bytes"
"net/http" "net/http"
) )
@ -42,6 +43,7 @@ type (
put HandlerFunc put HandlerFunc
trace HandlerFunc trace HandlerFunc
report HandlerFunc report HandlerFunc
allowHeader string
} }
) )
@ -68,6 +70,51 @@ func (m *methodHandler) isHandler() bool {
m.report != nil m.report != nil
} }
func (m *methodHandler) updateAllowHeader() {
buf := new(bytes.Buffer)
buf.WriteString(http.MethodOptions)
if m.connect != nil {
buf.WriteString(", ")
buf.WriteString(http.MethodConnect)
}
if m.delete != nil {
buf.WriteString(", ")
buf.WriteString(http.MethodDelete)
}
if m.get != nil {
buf.WriteString(", ")
buf.WriteString(http.MethodGet)
}
if m.head != nil {
buf.WriteString(", ")
buf.WriteString(http.MethodHead)
}
if m.patch != nil {
buf.WriteString(", ")
buf.WriteString(http.MethodPatch)
}
if m.post != nil {
buf.WriteString(", ")
buf.WriteString(http.MethodPost)
}
if m.propfind != nil {
buf.WriteString(", PROPFIND")
}
if m.put != nil {
buf.WriteString(", ")
buf.WriteString(http.MethodPut)
}
if m.trace != nil {
buf.WriteString(", ")
buf.WriteString(http.MethodTrace)
}
if m.report != nil {
buf.WriteString(", REPORT")
}
m.allowHeader = buf.String()
}
// NewRouter returns a new Router instance. // NewRouter returns a new Router instance.
func NewRouter(e *Echo) *Router { func NewRouter(e *Echo) *Router {
return &Router{ return &Router{
@ -326,6 +373,7 @@ func (n *node) addHandler(method string, h HandlerFunc) {
n.methodHandler.report = h n.methodHandler.report = h
} }
n.methodHandler.updateAllowHeader()
if h != nil { if h != nil {
n.isHandler = true n.isHandler = true
} else { } else {
@ -362,14 +410,15 @@ func (n *node) findHandler(method string) HandlerFunc {
} }
} }
func (n *node) checkMethodNotAllowed() HandlerFunc { func optionsMethodHandler(allowMethods string) func(c Context) error {
for _, m := range methods { return func(c Context) error {
if h := n.findHandler(m); h != nil { // Note: we are not handling most of the CORS headers here. CORS is handled by CORS middleware
return MethodNotAllowedHandler // 'OPTIONS' method RFC: https://httpwg.org/specs/rfc7231.html#OPTIONS
// 'Allow' header RFC: https://datatracker.ietf.org/doc/html/rfc7231#section-7.4.1
c.Response().Header().Add(HeaderAllow, allowMethods)
return c.NoContent(http.StatusNoContent)
} }
} }
return NotFoundHandler
}
// Find lookup a handler registered for method and path. It also parses URL for path // Find lookup a handler registered for method and path. It also parses URL for path
// parameters and load them into context. // parameters and load them into context.
@ -563,7 +612,15 @@ func (r *Router) Find(method, path string, c Context) {
// use previous match as basis. although we have no matching handler we have path match. // use previous match as basis. although we have no matching handler we have path match.
// so we can send http.StatusMethodNotAllowed (405) instead of http.StatusNotFound (404) // so we can send http.StatusMethodNotAllowed (405) instead of http.StatusNotFound (404)
currentNode = previousBestMatchNode currentNode = previousBestMatchNode
ctx.handler = currentNode.checkMethodNotAllowed()
ctx.handler = NotFoundHandler
if currentNode.isHandler {
ctx.Set(ContextKeyHeaderAllow, currentNode.methodHandler.allowHeader)
ctx.handler = MethodNotAllowedHandler
if method == http.MethodOptions {
ctx.handler = optionsMethodHandler(currentNode.methodHandler.allowHeader)
}
}
} }
ctx.path = currentNode.ppath ctx.path = currentNode.ppath
ctx.pnames = currentNode.pnames ctx.pnames = currentNode.pnames

View File

@ -3,6 +3,7 @@ package echo
import ( import (
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest"
"strings" "strings"
"testing" "testing"
@ -731,6 +732,7 @@ func TestMethodNotAllowedAndNotFound(t *testing.T) {
expectRoute interface{} expectRoute interface{}
expectParam map[string]string expectParam map[string]string
expectError error expectError error
expectAllowHeader string
}{ }{
{ {
name: "exact match for route+method", name: "exact match for route+method",
@ -745,6 +747,7 @@ func TestMethodNotAllowedAndNotFound(t *testing.T) {
whenURL: "/users/1", whenURL: "/users/1",
expectRoute: nil, expectRoute: nil,
expectError: ErrMethodNotAllowed, expectError: ErrMethodNotAllowed,
expectAllowHeader: "OPTIONS, POST",
}, },
{ {
name: "best match is any route up in tree", name: "best match is any route up in tree",
@ -756,7 +759,9 @@ func TestMethodNotAllowedAndNotFound(t *testing.T) {
} }
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
c := e.NewContext(nil, nil).(*context) req := httptest.NewRequest(tc.whenMethod, tc.whenURL, nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec).(*context)
method := http.MethodGet method := http.MethodGet
if tc.whenMethod != "" { if tc.whenMethod != "" {
@ -775,10 +780,36 @@ func TestMethodNotAllowedAndNotFound(t *testing.T) {
assert.Equal(t, expectedValue, c.Param(param)) assert.Equal(t, expectedValue, c.Param(param))
} }
checkUnusedParamValues(t, c, tc.expectParam) checkUnusedParamValues(t, c, tc.expectParam)
assert.Equal(t, tc.expectAllowHeader, c.Response().Header().Get(HeaderAllow))
}) })
} }
} }
func TestRouterOptionsMethodHandler(t *testing.T) {
e := New()
var keyInContext interface{}
e.Use(func(next HandlerFunc) HandlerFunc {
return func(c Context) error {
err := next(c)
keyInContext = c.Get(ContextKeyHeaderAllow)
return err
}
})
e.GET("/test", func(c Context) error {
return c.String(http.StatusOK, "Echo!")
})
req := httptest.NewRequest(http.MethodOptions, "/test", nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusNoContent, rec.Code)
assert.Equal(t, "OPTIONS, GET", rec.Header().Get(HeaderAllow))
assert.Equal(t, "OPTIONS, GET", keyInContext)
}
func TestRouterTwoParam(t *testing.T) { func TestRouterTwoParam(t *testing.T) {
e := New() e := New()
r := e.router r := e.router
@ -2288,6 +2319,73 @@ func TestRouterPanicWhenParamNoRootOnlyChildsFailsFind(t *testing.T) {
} }
} }
func TestRouterHandleMethodOptions(t *testing.T) {
e := New()
r := e.router
r.Add(http.MethodGet, "/users", handlerFunc)
r.Add(http.MethodPost, "/users", handlerFunc)
r.Add(http.MethodPut, "/users/:id", handlerFunc)
r.Add(http.MethodGet, "/users/:id", handlerFunc)
var testCases = []struct {
name string
whenMethod string
whenURL string
expectAllowHeader string
expectStatus int
}{
{
name: "allows GET and POST handlers",
whenMethod: http.MethodOptions,
whenURL: "/users",
expectAllowHeader: "OPTIONS, GET, POST",
expectStatus: http.StatusNoContent,
},
{
name: "allows GET and PUT handlers",
whenMethod: http.MethodOptions,
whenURL: "/users/1",
expectAllowHeader: "OPTIONS, GET, PUT",
expectStatus: http.StatusNoContent,
},
{
name: "GET does not have allows header",
whenMethod: http.MethodGet,
whenURL: "/users",
expectAllowHeader: "",
expectStatus: http.StatusOK,
},
{
name: "path with no handlers does not set Allows header",
whenMethod: http.MethodOptions,
whenURL: "/notFound",
expectAllowHeader: "",
expectStatus: http.StatusNotFound,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(tc.whenMethod, tc.whenURL, nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec).(*context)
r.Find(tc.whenMethod, tc.whenURL, c)
err := c.handler(c)
if tc.expectStatus >= 400 {
assert.Error(t, err)
he := err.(*HTTPError)
assert.Equal(t, tc.expectStatus, he.Code)
} else {
assert.NoError(t, err)
assert.Equal(t, tc.expectStatus, rec.Code)
}
assert.Equal(t, tc.expectAllowHeader, c.Response().Header().Get("Allow"))
})
}
}
func benchmarkRouterRoutes(b *testing.B, routes []*Route, routesToFind []*Route) { func benchmarkRouterRoutes(b *testing.B, routes []*Route, routesToFind []*Route) {
e := New() e := New()
r := e.router r := e.router