diff --git a/context.go b/context.go index 91ab6e48..ea542cb8 100644 --- a/context.go +++ b/context.go @@ -210,6 +210,13 @@ type ( } ) +const ( + // ContextKeyHeaderAllow is set by Router for getting value for `Allow` header in later stages of handler call chain. + // Allow header is mandatory for status 405 (method not found) and useful for OPTIONS method requests. + // It is added to context only when Router does not find matching method handler for request. + ContextKeyHeaderAllow = "____echo____header_allow" +) + const ( defaultMemory = 32 << 20 // 32 MB indexPage = "index.html" diff --git a/echo.go b/echo.go index ad03dd51..8747039e 100644 --- a/echo.go +++ b/echo.go @@ -190,8 +190,11 @@ const ( // Headers const ( - HeaderAccept = "Accept" - HeaderAcceptEncoding = "Accept-Encoding" + HeaderAccept = "Accept" + HeaderAcceptEncoding = "Accept-Encoding" + // HeaderAllow is header field that lists the set of methods advertised as supported by the target resource. + // Allow header is mandatory for status 405 (method not found) and useful OPTIONS method responses. + // See: https://datatracker.ietf.org/doc/html/rfc7231#section-7.4.1 HeaderAllow = "Allow" HeaderAuthorization = "Authorization" HeaderContentDisposition = "Content-Disposition" @@ -302,6 +305,13 @@ var ( } MethodNotAllowedHandler = func(c Context) error { + // 'Allow' header RFC: https://datatracker.ietf.org/doc/html/rfc7231#section-7.4.1 + // >> An origin server MUST generate an Allow field in a 405 (Method Not Allowed) response + // and MAY do so in any other response. + routerAllowMethods, ok := c.Get(ContextKeyHeaderAllow).(string) + if ok && routerAllowMethods != "" { + c.Response().Header().Set(HeaderAllow, routerAllowMethods) + } return ErrMethodNotAllowed } ) diff --git a/echo_test.go b/echo_test.go index f2891586..13a51b6c 100644 --- a/echo_test.go +++ b/echo_test.go @@ -716,13 +716,16 @@ func TestEchoNotFound(t *testing.T) { func TestEchoMethodNotAllowed(t *testing.T) { e := New() + e.GET("/", func(c Context) error { return c.String(http.StatusOK, "Echo!") }) req := httptest.NewRequest(http.MethodPost, "/", nil) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) + assert.Equal(t, http.StatusMethodNotAllowed, rec.Code) + assert.Equal(t, "OPTIONS, GET", rec.Header().Get(HeaderAllow)) } func TestEchoContext(t *testing.T) { diff --git a/middleware/cors.go b/middleware/cors.go index d6ef8964..a5122f26 100644 --- a/middleware/cors.go +++ b/middleware/cors.go @@ -29,6 +29,8 @@ type ( // AllowMethods defines a list methods allowed when accessing the resource. // This is used in response to a preflight request. // Optional. Default value DefaultCORSConfig.AllowMethods. + // If `allowMethods` is left empty will fill for preflight request `Access-Control-Allow-Methods` header value + // from `Allow` header that echo.Router set into context. AllowMethods []string `yaml:"allow_methods"` // AllowHeaders defines a list of request headers that can be used when @@ -41,6 +43,8 @@ type ( // a response to a preflight request, this indicates whether or not the // actual request can be made using credentials. // Optional. Default value false. + // Security: avoid using `AllowCredentials = true` with `AllowOrigins = *`. + // See http://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html AllowCredentials bool `yaml:"allow_credentials"` // ExposeHeaders defines a whitelist headers that clients are allowed to @@ -80,7 +84,9 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { if len(config.AllowOrigins) == 0 { config.AllowOrigins = DefaultCORSConfig.AllowOrigins } + hasCustomAllowMethods := true if len(config.AllowMethods) == 0 { + hasCustomAllowMethods = false config.AllowMethods = DefaultCORSConfig.AllowMethods } @@ -109,10 +115,28 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { origin := req.Header.Get(echo.HeaderOrigin) allowOrigin := "" - preflight := req.Method == http.MethodOptions res.Header().Add(echo.HeaderVary, echo.HeaderOrigin) - // No Origin provided + // Preflight request is an OPTIONS request, using three HTTP request headers: Access-Control-Request-Method, + // Access-Control-Request-Headers, and the Origin header. See: https://developer.mozilla.org/en-US/docs/Glossary/Preflight_request + // For simplicity we just consider method type and later `Origin` header. + preflight := req.Method == http.MethodOptions + + // Although router adds special handler in case of OPTIONS method we avoid calling next for OPTIONS in this middleware + // as CORS requests do not have cookies / authentication headers by default, so we could get stuck in auth + // middlewares by calling next(c). + // But we still want to send `Allow` header as response in case of Non-CORS OPTIONS request as router default + // handler does. + routerAllowMethods := "" + if preflight { + tmpAllowMethods, ok := c.Get(echo.ContextKeyHeaderAllow).(string) + if ok && tmpAllowMethods != "" { + routerAllowMethods = tmpAllowMethods + c.Response().Header().Set(echo.HeaderAllow, routerAllowMethods) + } + } + + // No Origin provided. This is (probably) not request from actual browser - proceed executing middleware chain if origin == "" { if !preflight { return next(c) @@ -145,19 +169,15 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { } } - // Check allowed origin patterns - for _, re := range allowOriginPatterns { - if allowOrigin == "" { - didx := strings.Index(origin, "://") - if didx == -1 { - continue - } - domAuth := origin[didx+3:] - // to avoid regex cost by invalid long domain - if len(domAuth) > 253 { - break - } - + checkPatterns := false + if allowOrigin == "" { + // to avoid regex cost by invalid (long) domains (253 is domain name max limit) + if len(origin) <= (253+3+4) && strings.Contains(origin, "://") { + checkPatterns = true + } + } + if checkPatterns { + for _, re := range allowOriginPatterns { if match, _ := regexp.MatchString(re, origin); match { allowOrigin = origin break @@ -174,12 +194,13 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { return c.NoContent(http.StatusNoContent) } + res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin) + if config.AllowCredentials { + res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true") + } + // Simple request if !preflight { - res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin) - if config.AllowCredentials { - res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true") - } if exposeHeaders != "" { res.Header().Set(echo.HeaderAccessControlExposeHeaders, exposeHeaders) } @@ -189,11 +210,13 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { // Preflight request res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestMethod) res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestHeaders) - res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin) - res.Header().Set(echo.HeaderAccessControlAllowMethods, allowMethods) - if config.AllowCredentials { - res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true") + + if !hasCustomAllowMethods && routerAllowMethods != "" { + res.Header().Set(echo.HeaderAccessControlAllowMethods, routerAllowMethods) + } else { + res.Header().Set(echo.HeaderAccessControlAllowMethods, allowMethods) } + if allowHeaders != "" { res.Header().Set(echo.HeaderAccessControlAllowHeaders, allowHeaders) } else { diff --git a/middleware/cors_test.go b/middleware/cors_test.go index 717abe49..daadbab6 100644 --- a/middleware/cors_test.go +++ b/middleware/cors_test.go @@ -251,114 +251,238 @@ func Test_allowOriginSubdomain(t *testing.T) { } } -func TestCorsHeaders(t *testing.T) { - tests := []struct { - domain, allowedOrigin, method string - expected bool +func TestCORSWithConfig_AllowMethods(t *testing.T) { + var testCases = []struct { + name string + allowOrigins []string + allowContextKey string + + whenOrigin string + whenAllowMethods []string + + expectAllow string + expectAccessControlAllowMethods string }{ { - domain: "", // Request does not have Origin header - allowedOrigin: "*", - method: http.MethodGet, - expected: false, + name: "custom AllowMethods, preflight, no origin, sets only allow header from context key", + allowContextKey: "OPTIONS, GET", + whenAllowMethods: []string{http.MethodGet, http.MethodHead}, + whenOrigin: "", + expectAllow: "OPTIONS, GET", }, { - domain: "http://example.com", - allowedOrigin: "*", - method: http.MethodGet, - expected: true, + name: "default AllowMethods, preflight, no origin, no allow header in context key and in response", + allowContextKey: "", + whenAllowMethods: nil, + whenOrigin: "", + expectAllow: "", }, { - domain: "", // Request does not have Origin header - allowedOrigin: "http://example.com", - method: http.MethodGet, - expected: false, + name: "custom AllowMethods, preflight, existing origin, sets both headers different values", + allowContextKey: "OPTIONS, GET", + whenAllowMethods: []string{http.MethodGet, http.MethodHead}, + whenOrigin: "http://google.com", + expectAllow: "OPTIONS, GET", + expectAccessControlAllowMethods: "GET,HEAD", }, { - domain: "http://bar.com", - allowedOrigin: "http://example.com", - method: http.MethodGet, - expected: false, + name: "default AllowMethods, preflight, existing origin, sets both headers", + allowContextKey: "OPTIONS, GET", + whenAllowMethods: nil, + whenOrigin: "http://google.com", + expectAllow: "OPTIONS, GET", + expectAccessControlAllowMethods: "OPTIONS, GET", }, { - domain: "http://example.com", - allowedOrigin: "http://example.com", - method: http.MethodGet, - expected: true, - }, - { - domain: "", // Request does not have Origin header - allowedOrigin: "*", - method: http.MethodOptions, - expected: false, - }, - { - domain: "http://example.com", - allowedOrigin: "*", - method: http.MethodOptions, - expected: true, - }, - { - domain: "", // Request does not have Origin header - allowedOrigin: "http://example.com", - method: http.MethodOptions, - expected: false, - }, - { - domain: "http://bar.com", - allowedOrigin: "http://example.com", - method: http.MethodGet, - expected: false, - }, - { - domain: "http://example.com", - allowedOrigin: "http://example.com", - method: http.MethodOptions, - expected: true, + name: "default AllowMethods, preflight, existing origin, no allows, sets only CORS allow methods", + allowContextKey: "", + whenAllowMethods: nil, + whenOrigin: "http://google.com", + expectAllow: "", + expectAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE", }, } - e := echo.New() - for _, tt := range tests { - req := httptest.NewRequest(tt.method, "/", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - if tt.domain != "" { - req.Header.Set(echo.HeaderOrigin, tt.domain) - } - cors := CORSWithConfig(CORSConfig{ - AllowOrigins: []string{tt.allowedOrigin}, - //AllowCredentials: true, - //MaxAge: 3600, + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + e.GET("/test", func(c echo.Context) error { + return c.String(http.StatusOK, "OK") + }) + + cors := CORSWithConfig(CORSConfig{ + AllowOrigins: tc.allowOrigins, + AllowMethods: tc.whenAllowMethods, + }) + + req := httptest.NewRequest(http.MethodOptions, "/test", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + req.Header.Set(echo.HeaderOrigin, tc.whenOrigin) + if tc.allowContextKey != "" { + c.Set(echo.ContextKeyHeaderAllow, tc.allowContextKey) + } + + h := cors(echo.NotFoundHandler) + h(c) + + assert.Equal(t, tc.expectAllow, rec.Header().Get(echo.HeaderAllow)) + assert.Equal(t, tc.expectAccessControlAllowMethods, rec.Header().Get(echo.HeaderAccessControlAllowMethods)) }) - h := cors(echo.NotFoundHandler) - h(c) + } +} - assert.Equal(t, echo.HeaderOrigin, rec.Header().Get(echo.HeaderVary)) +func TestCorsHeaders(t *testing.T) { + tests := []struct { + name string + originDomain string + method string + allowedOrigin string + expected bool + expectStatus int + expectAllowHeader string + }{ + { + name: "non-preflight request, allow any origin, missing origin header = no CORS logic done", + originDomain: "", + allowedOrigin: "*", + method: http.MethodGet, + expected: false, + expectStatus: http.StatusOK, + }, + { + name: "non-preflight request, allow any origin, specific origin domain", + originDomain: "http://example.com", + allowedOrigin: "*", + method: http.MethodGet, + expected: true, + expectStatus: http.StatusOK, + }, + { + name: "non-preflight request, allow specific origin, missing origin header = no CORS logic done", + originDomain: "", // Request does not have Origin header + allowedOrigin: "http://example.com", + method: http.MethodGet, + expected: false, + expectStatus: http.StatusOK, + }, + { + name: "non-preflight request, allow specific origin, different origin header = CORS logic failure", + originDomain: "http://bar.com", + allowedOrigin: "http://example.com", + method: http.MethodGet, + expected: false, + expectStatus: http.StatusOK, + }, + { + name: "non-preflight request, allow specific origin, matching origin header = CORS logic done", + originDomain: "http://example.com", + allowedOrigin: "http://example.com", + method: http.MethodGet, + expected: true, + expectStatus: http.StatusOK, + }, + { + name: "preflight, allow any origin, missing origin header = no CORS logic done", + originDomain: "", // Request does not have Origin header + allowedOrigin: "*", + method: http.MethodOptions, + expected: false, + expectStatus: http.StatusNoContent, + expectAllowHeader: "OPTIONS, GET, POST", + }, + { + name: "preflight, allow any origin, existing origin header = CORS logic done", + originDomain: "http://example.com", + allowedOrigin: "*", + method: http.MethodOptions, + expected: true, + expectStatus: http.StatusNoContent, + expectAllowHeader: "OPTIONS, GET, POST", + }, + { + name: "preflight, allow any origin, missing origin header = no CORS logic done", + originDomain: "", // Request does not have Origin header + allowedOrigin: "http://example.com", + method: http.MethodOptions, + expected: false, + expectStatus: http.StatusNoContent, + expectAllowHeader: "OPTIONS, GET, POST", + }, + { + name: "preflight, allow specific origin, different origin header = no CORS logic done", + originDomain: "http://bar.com", + allowedOrigin: "http://example.com", + method: http.MethodOptions, + expected: false, + expectStatus: http.StatusNoContent, + expectAllowHeader: "OPTIONS, GET, POST", + }, + { + name: "preflight, allow specific origin, matching origin header = CORS logic done", + originDomain: "http://example.com", + allowedOrigin: "http://example.com", + method: http.MethodOptions, + expected: true, + expectStatus: http.StatusNoContent, + expectAllowHeader: "OPTIONS, GET, POST", + }, + } - expectedAllowOrigin := "" - if tt.allowedOrigin == "*" { - expectedAllowOrigin = "*" - } else { - expectedAllowOrigin = tt.domain - } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() - switch { - case tt.expected && tt.method == http.MethodOptions: - assert.Contains(t, rec.Header(), echo.HeaderAccessControlAllowMethods) - assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) - assert.Equal(t, 3, len(rec.Header()[echo.HeaderVary])) - case tt.expected && tt.method == http.MethodGet: - assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) - assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin - default: - assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin) - assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin - } + e.Use(CORSWithConfig(CORSConfig{ + AllowOrigins: []string{tc.allowedOrigin}, + //AllowCredentials: true, + //MaxAge: 3600, + })) + + e.GET("/", func(c echo.Context) error { + return c.String(http.StatusOK, "OK") + }) + e.POST("/", func(c echo.Context) error { + return c.String(http.StatusCreated, "OK") + }) + + req := httptest.NewRequest(tc.method, "/", nil) + rec := httptest.NewRecorder() + + if tc.originDomain != "" { + req.Header.Set(echo.HeaderOrigin, tc.originDomain) + } + + // we run through whole Echo handler chain to see how CORS works with Router OPTIONS handler + e.ServeHTTP(rec, req) + + assert.Equal(t, echo.HeaderOrigin, rec.Header().Get(echo.HeaderVary)) + assert.Equal(t, tc.expectAllowHeader, rec.Header().Get(echo.HeaderAllow)) + assert.Equal(t, tc.expectStatus, rec.Code) + + expectedAllowOrigin := "" + if tc.allowedOrigin == "*" { + expectedAllowOrigin = "*" + } else { + expectedAllowOrigin = tc.originDomain + } + switch { + case tc.expected && tc.method == http.MethodOptions: + assert.Contains(t, rec.Header(), echo.HeaderAccessControlAllowMethods) + assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + + assert.Equal(t, 3, len(rec.Header()[echo.HeaderVary])) + + case tc.expected && tc.method == http.MethodGet: + assert.Equal(t, expectedAllowOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin + default: + assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin) + assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary])) // Vary: Origin + } + }) - if tt.method == http.MethodOptions { - assert.Equal(t, http.StatusNoContent, rec.Code) - } } } diff --git a/router.go b/router.go index dc93e29c..1a2ce561 100644 --- a/router.go +++ b/router.go @@ -1,6 +1,7 @@ package echo import ( + "bytes" "net/http" ) @@ -31,17 +32,18 @@ type ( kind uint8 children []*node methodHandler struct { - connect HandlerFunc - delete HandlerFunc - get HandlerFunc - head HandlerFunc - options HandlerFunc - patch HandlerFunc - post HandlerFunc - propfind HandlerFunc - put HandlerFunc - trace HandlerFunc - report HandlerFunc + connect HandlerFunc + delete HandlerFunc + get HandlerFunc + head HandlerFunc + options HandlerFunc + patch HandlerFunc + post HandlerFunc + propfind HandlerFunc + put HandlerFunc + trace HandlerFunc + report HandlerFunc + allowHeader string } ) @@ -68,6 +70,51 @@ func (m *methodHandler) isHandler() bool { m.report != nil } +func (m *methodHandler) updateAllowHeader() { + buf := new(bytes.Buffer) + buf.WriteString(http.MethodOptions) + + if m.connect != nil { + buf.WriteString(", ") + buf.WriteString(http.MethodConnect) + } + if m.delete != nil { + buf.WriteString(", ") + buf.WriteString(http.MethodDelete) + } + if m.get != nil { + buf.WriteString(", ") + buf.WriteString(http.MethodGet) + } + if m.head != nil { + buf.WriteString(", ") + buf.WriteString(http.MethodHead) + } + if m.patch != nil { + buf.WriteString(", ") + buf.WriteString(http.MethodPatch) + } + if m.post != nil { + buf.WriteString(", ") + buf.WriteString(http.MethodPost) + } + if m.propfind != nil { + buf.WriteString(", PROPFIND") + } + if m.put != nil { + buf.WriteString(", ") + buf.WriteString(http.MethodPut) + } + if m.trace != nil { + buf.WriteString(", ") + buf.WriteString(http.MethodTrace) + } + if m.report != nil { + buf.WriteString(", REPORT") + } + m.allowHeader = buf.String() +} + // NewRouter returns a new Router instance. func NewRouter(e *Echo) *Router { return &Router{ @@ -326,6 +373,7 @@ func (n *node) addHandler(method string, h HandlerFunc) { n.methodHandler.report = h } + n.methodHandler.updateAllowHeader() if h != nil { n.isHandler = true } else { @@ -362,13 +410,14 @@ func (n *node) findHandler(method string) HandlerFunc { } } -func (n *node) checkMethodNotAllowed() HandlerFunc { - for _, m := range methods { - if h := n.findHandler(m); h != nil { - return MethodNotAllowedHandler - } +func optionsMethodHandler(allowMethods string) func(c Context) error { + return func(c Context) error { + // Note: we are not handling most of the CORS headers here. CORS is handled by CORS middleware + // 'OPTIONS' method RFC: https://httpwg.org/specs/rfc7231.html#OPTIONS + // 'Allow' header RFC: https://datatracker.ietf.org/doc/html/rfc7231#section-7.4.1 + c.Response().Header().Add(HeaderAllow, allowMethods) + return c.NoContent(http.StatusNoContent) } - return NotFoundHandler } // Find lookup a handler registered for method and path. It also parses URL for path @@ -563,7 +612,15 @@ func (r *Router) Find(method, path string, c Context) { // use previous match as basis. although we have no matching handler we have path match. // so we can send http.StatusMethodNotAllowed (405) instead of http.StatusNotFound (404) currentNode = previousBestMatchNode - ctx.handler = currentNode.checkMethodNotAllowed() + + ctx.handler = NotFoundHandler + if currentNode.isHandler { + ctx.Set(ContextKeyHeaderAllow, currentNode.methodHandler.allowHeader) + ctx.handler = MethodNotAllowedHandler + if method == http.MethodOptions { + ctx.handler = optionsMethodHandler(currentNode.methodHandler.allowHeader) + } + } } ctx.path = currentNode.ppath ctx.pnames = currentNode.pnames diff --git a/router_test.go b/router_test.go index 57be74de..5cbb8d9b 100644 --- a/router_test.go +++ b/router_test.go @@ -3,6 +3,7 @@ package echo import ( "fmt" "net/http" + "net/http/httptest" "strings" "testing" @@ -725,12 +726,13 @@ func TestMethodNotAllowedAndNotFound(t *testing.T) { r.Add(http.MethodPost, "/users/:id", handlerFunc) var testCases = []struct { - name string - whenMethod string - whenURL string - expectRoute interface{} - expectParam map[string]string - expectError error + name string + whenMethod string + whenURL string + expectRoute interface{} + expectParam map[string]string + expectError error + expectAllowHeader string }{ { name: "exact match for route+method", @@ -740,11 +742,12 @@ func TestMethodNotAllowedAndNotFound(t *testing.T) { expectParam: map[string]string{"id": "1"}, }, { - name: "matches node but not method. sends 405 from best match node", - whenMethod: http.MethodPut, - whenURL: "/users/1", - expectRoute: nil, - expectError: ErrMethodNotAllowed, + name: "matches node but not method. sends 405 from best match node", + whenMethod: http.MethodPut, + whenURL: "/users/1", + expectRoute: nil, + expectError: ErrMethodNotAllowed, + expectAllowHeader: "OPTIONS, POST", }, { name: "best match is any route up in tree", @@ -756,7 +759,9 @@ func TestMethodNotAllowedAndNotFound(t *testing.T) { } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - c := e.NewContext(nil, nil).(*context) + req := httptest.NewRequest(tc.whenMethod, tc.whenURL, nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec).(*context) method := http.MethodGet if tc.whenMethod != "" { @@ -775,10 +780,36 @@ func TestMethodNotAllowedAndNotFound(t *testing.T) { assert.Equal(t, expectedValue, c.Param(param)) } checkUnusedParamValues(t, c, tc.expectParam) + + assert.Equal(t, tc.expectAllowHeader, c.Response().Header().Get(HeaderAllow)) }) } } +func TestRouterOptionsMethodHandler(t *testing.T) { + e := New() + + var keyInContext interface{} + e.Use(func(next HandlerFunc) HandlerFunc { + return func(c Context) error { + err := next(c) + keyInContext = c.Get(ContextKeyHeaderAllow) + return err + } + }) + e.GET("/test", func(c Context) error { + return c.String(http.StatusOK, "Echo!") + }) + + req := httptest.NewRequest(http.MethodOptions, "/test", nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusNoContent, rec.Code) + assert.Equal(t, "OPTIONS, GET", rec.Header().Get(HeaderAllow)) + assert.Equal(t, "OPTIONS, GET", keyInContext) +} + func TestRouterTwoParam(t *testing.T) { e := New() r := e.router @@ -2288,6 +2319,73 @@ func TestRouterPanicWhenParamNoRootOnlyChildsFailsFind(t *testing.T) { } } +func TestRouterHandleMethodOptions(t *testing.T) { + e := New() + r := e.router + + r.Add(http.MethodGet, "/users", handlerFunc) + r.Add(http.MethodPost, "/users", handlerFunc) + r.Add(http.MethodPut, "/users/:id", handlerFunc) + r.Add(http.MethodGet, "/users/:id", handlerFunc) + + var testCases = []struct { + name string + whenMethod string + whenURL string + expectAllowHeader string + expectStatus int + }{ + { + name: "allows GET and POST handlers", + whenMethod: http.MethodOptions, + whenURL: "/users", + expectAllowHeader: "OPTIONS, GET, POST", + expectStatus: http.StatusNoContent, + }, + { + name: "allows GET and PUT handlers", + whenMethod: http.MethodOptions, + whenURL: "/users/1", + expectAllowHeader: "OPTIONS, GET, PUT", + expectStatus: http.StatusNoContent, + }, + { + name: "GET does not have allows header", + whenMethod: http.MethodGet, + whenURL: "/users", + expectAllowHeader: "", + expectStatus: http.StatusOK, + }, + { + name: "path with no handlers does not set Allows header", + whenMethod: http.MethodOptions, + whenURL: "/notFound", + expectAllowHeader: "", + expectStatus: http.StatusNotFound, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(tc.whenMethod, tc.whenURL, nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec).(*context) + + r.Find(tc.whenMethod, tc.whenURL, c) + err := c.handler(c) + + if tc.expectStatus >= 400 { + assert.Error(t, err) + he := err.(*HTTPError) + assert.Equal(t, tc.expectStatus, he.Code) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expectStatus, rec.Code) + } + assert.Equal(t, tc.expectAllowHeader, c.Response().Header().Get("Allow")) + }) + } +} + func benchmarkRouterRoutes(b *testing.B, routes []*Route, routesToFind []*Route) { e := New() r := e.router