From 702e6d0967e7dfa04388bb5effa38302d9ff22ad Mon Sep 17 00:00:00 2001 From: Vishal Rana Date: Sun, 22 May 2016 07:58:21 -0700 Subject: [PATCH] Closes #517, closes #518 Signed-off-by: Vishal Rana --- context.go | 6 +++--- engine/engine.go | 5 ++++- engine/fasthttp/header.go | 10 ++++++++++ engine/standard/header.go | 6 ++++++ middleware/cors.go | 5 +++-- middleware/cors_test.go | 8 ++++++++ test/header.go | 5 +++++ 7 files changed, 39 insertions(+), 6 deletions(-) diff --git a/context.go b/context.go index c4f01ba1..ee8a7cf4 100644 --- a/context.go +++ b/context.go @@ -106,8 +106,8 @@ type ( // Del deletes data from the context. Del(string) - // Exists checks if that key exists in the context. - Exists(string) bool + // Contains checks if the key exists in the context. + Contains(string) bool // Bind binds the request body into provided type `i`. The default binder // does it based on Content-Type header. @@ -323,7 +323,7 @@ func (c *context) Del(key string) { delete(c.store, key) } -func (c *context) Exists(key string) bool { +func (c *context) Contains(key string) bool { _, ok := c.store[key] return ok } diff --git a/engine/engine.go b/engine/engine.go index 31cfbc7e..e7577d86 100644 --- a/engine/engine.go +++ b/engine/engine.go @@ -146,8 +146,11 @@ type ( // no values associated with the key, Get returns "". Get(string) string - // Keys returns header keys. + // Keys returns the header keys. Keys() []string + + // Contains checks if the header is set. + Contains(string) bool } // URL defines the interface for HTTP request url. diff --git a/engine/fasthttp/header.go b/engine/fasthttp/header.go index 8b4cdb3d..9e88d266 100644 --- a/engine/fasthttp/header.go +++ b/engine/fasthttp/header.go @@ -47,6 +47,11 @@ func (h *RequestHeader) Keys() (keys []string) { return } +// Contains implements `engine.Header#Contains` function. +func (h *RequestHeader) Contains(key string) bool { + return h.Contains(key) +} + func (h *RequestHeader) reset(hdr *fasthttp.RequestHeader) { h.RequestHeader = hdr } @@ -82,6 +87,11 @@ func (h *ResponseHeader) Keys() (keys []string) { return } +// Contains implements `engine.Header#Contains` function. +func (h *ResponseHeader) Contains(key string) bool { + return h.Contains(key) +} + func (h *ResponseHeader) reset(hdr *fasthttp.ResponseHeader) { h.ResponseHeader = hdr } diff --git a/engine/standard/header.go b/engine/standard/header.go index 01f0e83c..001849f4 100644 --- a/engine/standard/header.go +++ b/engine/standard/header.go @@ -40,6 +40,12 @@ func (h *Header) Keys() (keys []string) { return } +// Contains implements `engine.Header#Contains` function. +func (h *Header) Contains(key string) bool { + _, ok := h.Header[key] + return ok +} + func (h *Header) reset(hdr http.Header) { h.Header = hdr } diff --git a/middleware/cors.go b/middleware/cors.go index b3300677..f82d383c 100644 --- a/middleware/cors.go +++ b/middleware/cors.go @@ -78,6 +78,7 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { req := c.Request() res := c.Response() origin := req.Header().Get(echo.HeaderOrigin) + originSet := req.Header().Contains(echo.HeaderOrigin) // Issue #517 // Check allowed origins allowedOrigin := "" @@ -91,7 +92,7 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { // Simple request if req.Method() != echo.OPTIONS { res.Header().Add(echo.HeaderVary, echo.HeaderOrigin) - if origin == "" || allowedOrigin == "" { + if !originSet || allowedOrigin == "" { return next(c) } res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowedOrigin) @@ -108,7 +109,7 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { res.Header().Add(echo.HeaderVary, echo.HeaderOrigin) res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestMethod) res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestHeaders) - if origin == "" || allowedOrigin == "" { + if !originSet || allowedOrigin == "" { return next(c) } res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowedOrigin) diff --git a/middleware/cors_test.go b/middleware/cors_test.go index cd36c277..846c4b34 100644 --- a/middleware/cors_test.go +++ b/middleware/cors_test.go @@ -25,6 +25,14 @@ func TestCORS(t *testing.T) { h(c) assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + // Empty origin header + req = test.NewRequest(echo.GET, "/", nil) + rec = test.NewResponseRecorder() + c = e.NewContext(req, rec) + req.Header().Set(echo.HeaderOrigin, "") + h(c) + assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + // Wildcard origin req = test.NewRequest(echo.GET, "/", nil) rec = test.NewResponseRecorder() diff --git a/test/header.go b/test/header.go index cc41fe2c..57e81fc6 100644 --- a/test/header.go +++ b/test/header.go @@ -34,6 +34,11 @@ func (h *Header) Keys() (keys []string) { return } +func (h *Header) Contains(key string) bool { + _, ok := h.header[key] + return ok +} + func (h *Header) reset(hdr http.Header) { h.header = hdr }