1
0
mirror of https://github.com/labstack/echo.git synced 2024-12-24 20:14:31 +02:00

Allow header support in Router, MethodNotFoundHandler (405) and CORS middleware

This commit is contained in:
toimtoimtoim 2021-12-04 20:02:11 +02:00 committed by Martti T
parent 4fffee2ec8
commit 5b26a5257b
7 changed files with 467 additions and 145 deletions

View File

@ -210,6 +210,13 @@ type (
}
)
const (
// ContextKeyHeaderAllow is set by Router for getting value for `Allow` header in later stages of handler call chain.
// Allow header is mandatory for status 405 (method not found) and useful for OPTIONS method requests.
// It is added to context only when Router does not find matching method handler for request.
ContextKeyHeaderAllow = "____echo____header_allow"
)
const (
defaultMemory = 32 << 20 // 32 MB
indexPage = "index.html"

14
echo.go
View File

@ -190,8 +190,11 @@ const (
// Headers
const (
HeaderAccept = "Accept"
HeaderAcceptEncoding = "Accept-Encoding"
HeaderAccept = "Accept"
HeaderAcceptEncoding = "Accept-Encoding"
// HeaderAllow is header field that lists the set of methods advertised as supported by the target resource.
// Allow header is mandatory for status 405 (method not found) and useful OPTIONS method responses.
// See: https://datatracker.ietf.org/doc/html/rfc7231#section-7.4.1
HeaderAllow = "Allow"
HeaderAuthorization = "Authorization"
HeaderContentDisposition = "Content-Disposition"
@ -302,6 +305,13 @@ var (
}
MethodNotAllowedHandler = func(c Context) error {
// 'Allow' header RFC: https://datatracker.ietf.org/doc/html/rfc7231#section-7.4.1
// >> An origin server MUST generate an Allow field in a 405 (Method Not Allowed) response
// and MAY do so in any other response.
routerAllowMethods, ok := c.Get(ContextKeyHeaderAllow).(string)
if ok && routerAllowMethods != "" {
c.Response().Header().Set(HeaderAllow, routerAllowMethods)
}
return ErrMethodNotAllowed
}
)

View File

@ -716,13 +716,16 @@ func TestEchoNotFound(t *testing.T) {
func TestEchoMethodNotAllowed(t *testing.T) {
e := New()
e.GET("/", func(c Context) error {
return c.String(http.StatusOK, "Echo!")
})
req := httptest.NewRequest(http.MethodPost, "/", nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusMethodNotAllowed, rec.Code)
assert.Equal(t, "OPTIONS, GET", rec.Header().Get(HeaderAllow))
}
func TestEchoContext(t *testing.T) {

View File

@ -29,6 +29,8 @@ type (
// AllowMethods defines a list methods allowed when accessing the resource.
// This is used in response to a preflight request.
// Optional. Default value DefaultCORSConfig.AllowMethods.
// If `allowMethods` is left empty will fill for preflight request `Access-Control-Allow-Methods` header value
// from `Allow` header that echo.Router set into context.
AllowMethods []string `yaml:"allow_methods"`
// AllowHeaders defines a list of request headers that can be used when
@ -41,6 +43,8 @@ type (
// a response to a preflight request, this indicates whether or not the
// actual request can be made using credentials.
// Optional. Default value false.
// Security: avoid using `AllowCredentials = true` with `AllowOrigins = *`.
// See http://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html
AllowCredentials bool `yaml:"allow_credentials"`
// ExposeHeaders defines a whitelist headers that clients are allowed to
@ -80,7 +84,9 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
if len(config.AllowOrigins) == 0 {
config.AllowOrigins = DefaultCORSConfig.AllowOrigins
}
hasCustomAllowMethods := true
if len(config.AllowMethods) == 0 {
hasCustomAllowMethods = false
config.AllowMethods = DefaultCORSConfig.AllowMethods
}
@ -109,10 +115,28 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
origin := req.Header.Get(echo.HeaderOrigin)
allowOrigin := ""
preflight := req.Method == http.MethodOptions
res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)
// No Origin provided
// Preflight request is an OPTIONS request, using three HTTP request headers: Access-Control-Request-Method,
// Access-Control-Request-Headers, and the Origin header. See: https://developer.mozilla.org/en-US/docs/Glossary/Preflight_request
// For simplicity we just consider method type and later `Origin` header.
preflight := req.Method == http.MethodOptions
// Although router adds special handler in case of OPTIONS method we avoid calling next for OPTIONS in this middleware
// as CORS requests do not have cookies / authentication headers by default, so we could get stuck in auth
// middlewares by calling next(c).
// But we still want to send `Allow` header as response in case of Non-CORS OPTIONS request as router default
// handler does.
routerAllowMethods := ""
if preflight {
tmpAllowMethods, ok := c.Get(echo.ContextKeyHeaderAllow).(string)
if ok && tmpAllowMethods != "" {
routerAllowMethods = tmpAllowMethods
c.Response().Header().Set(echo.HeaderAllow, routerAllowMethods)
}
}
// No Origin provided. This is (probably) not request from actual browser - proceed executing middleware chain
if origin == "" {
if !preflight {
return next(c)
@ -145,19 +169,15 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
}
}
// Check allowed origin patterns
for _, re := range allowOriginPatterns {
if allowOrigin == "" {
didx := strings.Index(origin, "://")
if didx == -1 {
continue
}
domAuth := origin[didx+3:]
// to avoid regex cost by invalid long domain
if len(domAuth) > 253 {
break
}
checkPatterns := false
if allowOrigin == "" {
// to avoid regex cost by invalid (long) domains (253 is domain name max limit)
if len(origin) <= (253+3+4) && strings.Contains(origin, "://") {
checkPatterns = true
}
}
if checkPatterns {
for _, re := range allowOriginPatterns {
if match, _ := regexp.MatchString(re, origin); match {
allowOrigin = origin
break
@ -174,12 +194,13 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
return c.NoContent(http.StatusNoContent)
}
res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin)
if config.AllowCredentials {
res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true")
}
// Simple request
if !preflight {
res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin)
if config.AllowCredentials {
res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true")
}
if exposeHeaders != "" {
res.Header().Set(echo.HeaderAccessControlExposeHeaders, exposeHeaders)
}
@ -189,11 +210,13 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
// Preflight request
res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestMethod)
res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestHeaders)
res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin)
res.Header().Set(echo.HeaderAccessControlAllowMethods, allowMethods)
if config.AllowCredentials {
res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true")
if !hasCustomAllowMethods && routerAllowMethods != "" {
res.Header().Set(echo.HeaderAccessControlAllowMethods, routerAllowMethods)
} else {
res.Header().Set(echo.HeaderAccessControlAllowMethods, allowMethods)
}
if allowHeaders != "" {
res.Header().Set(echo.HeaderAccessControlAllowHeaders, allowHeaders)
} else {

View File

@ -251,114 +251,238 @@ func Test_allowOriginSubdomain(t *testing.T) {
}
}
func TestCorsHeaders(t *testing.T) {
tests := []struct {
domain, allowedOrigin, method string
expected bool
func TestCORSWithConfig_AllowMethods(t *testing.T) {
var testCases = []struct {
name string
allowOrigins []string
allowContextKey string
whenOrigin string
whenAllowMethods []string
expectAllow string
expectAccessControlAllowMethods string
}{
{
domain: "", // Request does not have Origin header
allowedOrigin: "*",
method: http.MethodGet,
expected: false,
name: "custom AllowMethods, preflight, no origin, sets only allow header from context key",
allowContextKey: "OPTIONS, GET",
whenAllowMethods: []string{http.MethodGet, http.MethodHead},
whenOrigin: "",
expectAllow: "OPTIONS, GET",
},
{
domain: "http://example.com",
allowedOrigin: "*",
method: http.MethodGet,
expected: true,
name: "default AllowMethods, preflight, no origin, no allow header in context key and in response",
allowContextKey: "",
whenAllowMethods: nil,
whenOrigin: "",
expectAllow: "",
},
{
domain: "", // Request does not have Origin header
allowedOrigin: "http://example.com",
method: http.MethodGet,
expected: false,
name: "custom AllowMethods, preflight, existing origin, sets both headers different values",
allowContextKey: "OPTIONS, GET",
whenAllowMethods: []string{http.MethodGet, http.MethodHead},
whenOrigin: "http://google.com",
expectAllow: "OPTIONS, GET",
expectAccessControlAllowMethods: "GET,HEAD",
},
{
domain: "http://bar.com",
allowedOrigin: "http://example.com",
method: http.MethodGet,
expected: false,
name: "default AllowMethods, preflight, existing origin, sets both headers",
allowContextKey: "OPTIONS, GET",
whenAllowMethods: nil,
whenOrigin: "http://google.com",
expectAllow: "OPTIONS, GET",
expectAccessControlAllowMethods: "OPTIONS, GET",
},
{
domain: "http://example.com",
allowedOrigin: "http://example.com",
method: http.MethodGet,
expected: true,
},
{
domain: "", // Request does not have Origin header
allowedOrigin: "*",
method: http.MethodOptions,
expected: false,
},
{
domain: "http://example.com",
allowedOrigin: "*",
method: http.MethodOptions,
expected: true,
},
{
domain: "", // Request does not have Origin header
allowedOrigin: "http://example.com",
method: http.MethodOptions,
expected: false,
},
{
domain: "http://bar.com",
allowedOrigin: "http://example.com",
method: http.MethodGet,
expected: false,
},
{
domain: "http://example.com",
allowedOrigin: "http://example.com",
method: http.MethodOptions,
expected: true,
name: "default AllowMethods, preflight, existing origin, no allows, sets only CORS allow methods",
allowContextKey: "",
whenAllowMethods: nil,
whenOrigin: "http://google.com",
expectAllow: "",
expectAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE",
},
}
e := echo.New()
for _, tt := range tests {
req := httptest.NewRequest(tt.method, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
if tt.domain != "" {
req.Header.Set(echo.HeaderOrigin, tt.domain)
}
cors := CORSWithConfig(CORSConfig{
AllowOrigins: []string{tt.allowedOrigin},
//AllowCredentials: true,
//MaxAge: 3600,
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
e.GET("/test", func(c echo.Context) error {
return c.String(http.StatusOK, "OK")
})
cors := CORSWithConfig(CORSConfig{
AllowOrigins: tc.allowOrigins,
AllowMethods: tc.whenAllowMethods,
})
req := httptest.NewRequest(http.MethodOptions, "/test", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
req.Header.Set(echo.HeaderOrigin, tc.whenOrigin)
if tc.allowContextKey != "" {
c.Set(echo.ContextKeyHeaderAllow, tc.allowContextKey)
}
h := cors(echo.NotFoundHandler)
h(c)
assert.Equal(t, tc.expectAllow, rec.Header().Get(echo.HeaderAllow))
assert.Equal(t, tc.expectAccessControlAllowMethods, rec.Header().Get(echo.HeaderAccessControlAllowMethods))
})
h := cors(echo.NotFoundHandler)
h(c)
}
}
assert.Equal(t, echo.HeaderOrigin, rec.Header().Get(echo.HeaderVary))
func TestCorsHeaders(t *testing.T) {
tests := []struct {
name string
originDomain string
method string
allowedOrigin string
expected bool
expectStatus int
expectAllowHeader string
}{
{
name: "non-preflight request, allow any origin, missing origin header = no CORS logic done",
originDomain: "",
allowedOrigin: "*",
method: http.MethodGet,
expected: false,
expectStatus: http.StatusOK,
},
{
name: "non-preflight request, allow any origin, specific origin domain",
originDomain: "http://example.com",
allowedOrigin: "*",
method: http.MethodGet,
expected: true,
expectStatus: http.StatusOK,
},
{
name: "non-preflight request, allow specific origin, missing origin header = no CORS logic done",
originDomain: "", // Request does not have Origin header
allowedOrigin: "http://example.com",
method: http.MethodGet,
expected: false,
expectStatus: http.StatusOK,
},
{
name: "non-preflight request, allow specific origin, different origin header = CORS logic failure",
originDomain: "http://bar.com",
allowedOrigin: "http://example.com",
method: http.MethodGet,
expected: false,
expectStatus: http.StatusOK,
},
{
name: "non-preflight request, allow specific origin, matching origin header = CORS logic done",
originDomain: "http://example.com",
allowedOrigin: "http://example.com",
method: http.MethodGet,
expected: true,
expectStatus: http.StatusOK,
},
{
name: "preflight, allow any origin, missing origin header = no CORS logic done",
originDomain: "", // Request does not have Origin header
allowedOrigin: "*",
method: http.MethodOptions,
expected: false,
expectStatus: http.StatusNoContent,
expectAllowHeader: "OPTIONS, GET, POST",
},
{
name: "preflight, allow any origin, existing origin header = CORS logic done",
originDomain: "http://example.com",
allowedOrigin: "*",
method: http.MethodOptions,
expected: true,
expectStatus: http.StatusNoContent,
expectAllowHeader: "OPTIONS, GET, POST",
},
{
name: "preflight, allow any origin, missing origin header = no CORS logic done",
originDomain: "", // Request does not have Origin header
allowedOrigin: "http://example.com",
method: http.MethodOptions,
expected: false,
expectStatus: http.StatusNoContent,
expectAllowHeader: "OPTIONS, GET, POST",
},
{
name: "preflight, allow specific origin, different origin header = no CORS logic done",
originDomain: "http://bar.com",
allowedOrigin: "http://example.com",
method: http.MethodOptions,
expected: false,
expectStatus: http.StatusNoContent,
expectAllowHeader: "OPTIONS, GET, POST",
},
{
name: "preflight, allow specific origin, matching origin header = CORS logic done",
originDomain: "http://example.com",
allowedOrigin: "http://example.com",
method: http.MethodOptions,
expected: true,
expectStatus: http.StatusNoContent,
expectAllowHeader: "OPTIONS, GET, POST",
},
}
expectedAllowOrigin := ""
if tt.allowedOrigin == "*" {
expectedAllowOrigin = "*"
} else {
expectedAllowOrigin = tt.domain
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()
switch {
case tt.expected && tt.method == http.MethodOptions:
assert.Contains(t, rec.Header(), echo.HeaderAccessControlAllowMethods)
assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.Equal(t, 3, len(rec.Header()[echo.HeaderVary]))
case tt.expected && tt.method == http.MethodGet:
assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin
default:
assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin
}
e.Use(CORSWithConfig(CORSConfig{
AllowOrigins: []string{tc.allowedOrigin},
//AllowCredentials: true,
//MaxAge: 3600,
}))
e.GET("/", func(c echo.Context) error {
return c.String(http.StatusOK, "OK")
})
e.POST("/", func(c echo.Context) error {
return c.String(http.StatusCreated, "OK")
})
req := httptest.NewRequest(tc.method, "/", nil)
rec := httptest.NewRecorder()
if tc.originDomain != "" {
req.Header.Set(echo.HeaderOrigin, tc.originDomain)
}
// we run through whole Echo handler chain to see how CORS works with Router OPTIONS handler
e.ServeHTTP(rec, req)
assert.Equal(t, echo.HeaderOrigin, rec.Header().Get(echo.HeaderVary))
assert.Equal(t, tc.expectAllowHeader, rec.Header().Get(echo.HeaderAllow))
assert.Equal(t, tc.expectStatus, rec.Code)
expectedAllowOrigin := ""
if tc.allowedOrigin == "*" {
expectedAllowOrigin = "*"
} else {
expectedAllowOrigin = tc.originDomain
}
switch {
case tc.expected && tc.method == http.MethodOptions:
assert.Contains(t, rec.Header(), echo.HeaderAccessControlAllowMethods)
assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.Equal(t, 3, len(rec.Header()[echo.HeaderVary]))
case tc.expected && tc.method == http.MethodGet:
assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin
default:
assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin)
assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin
}
})
if tt.method == http.MethodOptions {
assert.Equal(t, http.StatusNoContent, rec.Code)
}
}
}

View File

@ -1,6 +1,7 @@
package echo
import (
"bytes"
"net/http"
)
@ -31,17 +32,18 @@ type (
kind uint8
children []*node
methodHandler struct {
connect HandlerFunc
delete HandlerFunc
get HandlerFunc
head HandlerFunc
options HandlerFunc
patch HandlerFunc
post HandlerFunc
propfind HandlerFunc
put HandlerFunc
trace HandlerFunc
report HandlerFunc
connect HandlerFunc
delete HandlerFunc
get HandlerFunc
head HandlerFunc
options HandlerFunc
patch HandlerFunc
post HandlerFunc
propfind HandlerFunc
put HandlerFunc
trace HandlerFunc
report HandlerFunc
allowHeader string
}
)
@ -68,6 +70,51 @@ func (m *methodHandler) isHandler() bool {
m.report != nil
}
func (m *methodHandler) updateAllowHeader() {
buf := new(bytes.Buffer)
buf.WriteString(http.MethodOptions)
if m.connect != nil {
buf.WriteString(", ")
buf.WriteString(http.MethodConnect)
}
if m.delete != nil {
buf.WriteString(", ")
buf.WriteString(http.MethodDelete)
}
if m.get != nil {
buf.WriteString(", ")
buf.WriteString(http.MethodGet)
}
if m.head != nil {
buf.WriteString(", ")
buf.WriteString(http.MethodHead)
}
if m.patch != nil {
buf.WriteString(", ")
buf.WriteString(http.MethodPatch)
}
if m.post != nil {
buf.WriteString(", ")
buf.WriteString(http.MethodPost)
}
if m.propfind != nil {
buf.WriteString(", PROPFIND")
}
if m.put != nil {
buf.WriteString(", ")
buf.WriteString(http.MethodPut)
}
if m.trace != nil {
buf.WriteString(", ")
buf.WriteString(http.MethodTrace)
}
if m.report != nil {
buf.WriteString(", REPORT")
}
m.allowHeader = buf.String()
}
// NewRouter returns a new Router instance.
func NewRouter(e *Echo) *Router {
return &Router{
@ -326,6 +373,7 @@ func (n *node) addHandler(method string, h HandlerFunc) {
n.methodHandler.report = h
}
n.methodHandler.updateAllowHeader()
if h != nil {
n.isHandler = true
} else {
@ -362,13 +410,14 @@ func (n *node) findHandler(method string) HandlerFunc {
}
}
func (n *node) checkMethodNotAllowed() HandlerFunc {
for _, m := range methods {
if h := n.findHandler(m); h != nil {
return MethodNotAllowedHandler
}
func optionsMethodHandler(allowMethods string) func(c Context) error {
return func(c Context) error {
// Note: we are not handling most of the CORS headers here. CORS is handled by CORS middleware
// 'OPTIONS' method RFC: https://httpwg.org/specs/rfc7231.html#OPTIONS
// 'Allow' header RFC: https://datatracker.ietf.org/doc/html/rfc7231#section-7.4.1
c.Response().Header().Add(HeaderAllow, allowMethods)
return c.NoContent(http.StatusNoContent)
}
return NotFoundHandler
}
// Find lookup a handler registered for method and path. It also parses URL for path
@ -563,7 +612,15 @@ func (r *Router) Find(method, path string, c Context) {
// use previous match as basis. although we have no matching handler we have path match.
// so we can send http.StatusMethodNotAllowed (405) instead of http.StatusNotFound (404)
currentNode = previousBestMatchNode
ctx.handler = currentNode.checkMethodNotAllowed()
ctx.handler = NotFoundHandler
if currentNode.isHandler {
ctx.Set(ContextKeyHeaderAllow, currentNode.methodHandler.allowHeader)
ctx.handler = MethodNotAllowedHandler
if method == http.MethodOptions {
ctx.handler = optionsMethodHandler(currentNode.methodHandler.allowHeader)
}
}
}
ctx.path = currentNode.ppath
ctx.pnames = currentNode.pnames

View File

@ -3,6 +3,7 @@ package echo
import (
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
@ -725,12 +726,13 @@ func TestMethodNotAllowedAndNotFound(t *testing.T) {
r.Add(http.MethodPost, "/users/:id", handlerFunc)
var testCases = []struct {
name string
whenMethod string
whenURL string
expectRoute interface{}
expectParam map[string]string
expectError error
name string
whenMethod string
whenURL string
expectRoute interface{}
expectParam map[string]string
expectError error
expectAllowHeader string
}{
{
name: "exact match for route+method",
@ -740,11 +742,12 @@ func TestMethodNotAllowedAndNotFound(t *testing.T) {
expectParam: map[string]string{"id": "1"},
},
{
name: "matches node but not method. sends 405 from best match node",
whenMethod: http.MethodPut,
whenURL: "/users/1",
expectRoute: nil,
expectError: ErrMethodNotAllowed,
name: "matches node but not method. sends 405 from best match node",
whenMethod: http.MethodPut,
whenURL: "/users/1",
expectRoute: nil,
expectError: ErrMethodNotAllowed,
expectAllowHeader: "OPTIONS, POST",
},
{
name: "best match is any route up in tree",
@ -756,7 +759,9 @@ func TestMethodNotAllowedAndNotFound(t *testing.T) {
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
c := e.NewContext(nil, nil).(*context)
req := httptest.NewRequest(tc.whenMethod, tc.whenURL, nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec).(*context)
method := http.MethodGet
if tc.whenMethod != "" {
@ -775,10 +780,36 @@ func TestMethodNotAllowedAndNotFound(t *testing.T) {
assert.Equal(t, expectedValue, c.Param(param))
}
checkUnusedParamValues(t, c, tc.expectParam)
assert.Equal(t, tc.expectAllowHeader, c.Response().Header().Get(HeaderAllow))
})
}
}
func TestRouterOptionsMethodHandler(t *testing.T) {
e := New()
var keyInContext interface{}
e.Use(func(next HandlerFunc) HandlerFunc {
return func(c Context) error {
err := next(c)
keyInContext = c.Get(ContextKeyHeaderAllow)
return err
}
})
e.GET("/test", func(c Context) error {
return c.String(http.StatusOK, "Echo!")
})
req := httptest.NewRequest(http.MethodOptions, "/test", nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusNoContent, rec.Code)
assert.Equal(t, "OPTIONS, GET", rec.Header().Get(HeaderAllow))
assert.Equal(t, "OPTIONS, GET", keyInContext)
}
func TestRouterTwoParam(t *testing.T) {
e := New()
r := e.router
@ -2288,6 +2319,73 @@ func TestRouterPanicWhenParamNoRootOnlyChildsFailsFind(t *testing.T) {
}
}
func TestRouterHandleMethodOptions(t *testing.T) {
e := New()
r := e.router
r.Add(http.MethodGet, "/users", handlerFunc)
r.Add(http.MethodPost, "/users", handlerFunc)
r.Add(http.MethodPut, "/users/:id", handlerFunc)
r.Add(http.MethodGet, "/users/:id", handlerFunc)
var testCases = []struct {
name string
whenMethod string
whenURL string
expectAllowHeader string
expectStatus int
}{
{
name: "allows GET and POST handlers",
whenMethod: http.MethodOptions,
whenURL: "/users",
expectAllowHeader: "OPTIONS, GET, POST",
expectStatus: http.StatusNoContent,
},
{
name: "allows GET and PUT handlers",
whenMethod: http.MethodOptions,
whenURL: "/users/1",
expectAllowHeader: "OPTIONS, GET, PUT",
expectStatus: http.StatusNoContent,
},
{
name: "GET does not have allows header",
whenMethod: http.MethodGet,
whenURL: "/users",
expectAllowHeader: "",
expectStatus: http.StatusOK,
},
{
name: "path with no handlers does not set Allows header",
whenMethod: http.MethodOptions,
whenURL: "/notFound",
expectAllowHeader: "",
expectStatus: http.StatusNotFound,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(tc.whenMethod, tc.whenURL, nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec).(*context)
r.Find(tc.whenMethod, tc.whenURL, c)
err := c.handler(c)
if tc.expectStatus >= 400 {
assert.Error(t, err)
he := err.(*HTTPError)
assert.Equal(t, tc.expectStatus, he.Code)
} else {
assert.NoError(t, err)
assert.Equal(t, tc.expectStatus, rec.Code)
}
assert.Equal(t, tc.expectAllowHeader, c.Response().Header().Get("Allow"))
})
}
}
func benchmarkRouterRoutes(b *testing.B, routes []*Route, routesToFind []*Route) {
e := New()
r := e.router