1
0
mirror of https://github.com/labstack/echo.git synced 2025-05-31 23:19:42 +02:00

Closes #517, closes #518

Signed-off-by: Vishal Rana <vr@labstack.com>
This commit is contained in:
Vishal Rana 2016-05-22 07:58:21 -07:00
parent 795ab0ad77
commit 702e6d0967
7 changed files with 39 additions and 6 deletions

View File

@ -106,8 +106,8 @@ type (
// Del deletes data from the context. // Del deletes data from the context.
Del(string) Del(string)
// Exists checks if that key exists in the context. // Contains checks if the key exists in the context.
Exists(string) bool Contains(string) bool
// Bind binds the request body into provided type `i`. The default binder // Bind binds the request body into provided type `i`. The default binder
// does it based on Content-Type header. // does it based on Content-Type header.
@ -323,7 +323,7 @@ func (c *context) Del(key string) {
delete(c.store, key) delete(c.store, key)
} }
func (c *context) Exists(key string) bool { func (c *context) Contains(key string) bool {
_, ok := c.store[key] _, ok := c.store[key]
return ok return ok
} }

View File

@ -146,8 +146,11 @@ type (
// no values associated with the key, Get returns "". // no values associated with the key, Get returns "".
Get(string) string Get(string) string
// Keys returns header keys. // Keys returns the header keys.
Keys() []string Keys() []string
// Contains checks if the header is set.
Contains(string) bool
} }
// URL defines the interface for HTTP request url. // URL defines the interface for HTTP request url.

View File

@ -47,6 +47,11 @@ func (h *RequestHeader) Keys() (keys []string) {
return return
} }
// Contains implements `engine.Header#Contains` function.
func (h *RequestHeader) Contains(key string) bool {
return h.Contains(key)
}
func (h *RequestHeader) reset(hdr *fasthttp.RequestHeader) { func (h *RequestHeader) reset(hdr *fasthttp.RequestHeader) {
h.RequestHeader = hdr h.RequestHeader = hdr
} }
@ -82,6 +87,11 @@ func (h *ResponseHeader) Keys() (keys []string) {
return return
} }
// Contains implements `engine.Header#Contains` function.
func (h *ResponseHeader) Contains(key string) bool {
return h.Contains(key)
}
func (h *ResponseHeader) reset(hdr *fasthttp.ResponseHeader) { func (h *ResponseHeader) reset(hdr *fasthttp.ResponseHeader) {
h.ResponseHeader = hdr h.ResponseHeader = hdr
} }

View File

@ -40,6 +40,12 @@ func (h *Header) Keys() (keys []string) {
return return
} }
// Contains implements `engine.Header#Contains` function.
func (h *Header) Contains(key string) bool {
_, ok := h.Header[key]
return ok
}
func (h *Header) reset(hdr http.Header) { func (h *Header) reset(hdr http.Header) {
h.Header = hdr h.Header = hdr
} }

View File

@ -78,6 +78,7 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
req := c.Request() req := c.Request()
res := c.Response() res := c.Response()
origin := req.Header().Get(echo.HeaderOrigin) origin := req.Header().Get(echo.HeaderOrigin)
originSet := req.Header().Contains(echo.HeaderOrigin) // Issue #517
// Check allowed origins // Check allowed origins
allowedOrigin := "" allowedOrigin := ""
@ -91,7 +92,7 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
// Simple request // Simple request
if req.Method() != echo.OPTIONS { if req.Method() != echo.OPTIONS {
res.Header().Add(echo.HeaderVary, echo.HeaderOrigin) res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)
if origin == "" || allowedOrigin == "" { if !originSet || allowedOrigin == "" {
return next(c) return next(c)
} }
res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowedOrigin) res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowedOrigin)
@ -108,7 +109,7 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
res.Header().Add(echo.HeaderVary, echo.HeaderOrigin) res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)
res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestMethod) res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestMethod)
res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestHeaders) res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestHeaders)
if origin == "" || allowedOrigin == "" { if !originSet || allowedOrigin == "" {
return next(c) return next(c)
} }
res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowedOrigin) res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowedOrigin)

View File

@ -25,6 +25,14 @@ func TestCORS(t *testing.T) {
h(c) h(c)
assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
// Empty origin header
req = test.NewRequest(echo.GET, "/", nil)
rec = test.NewResponseRecorder()
c = e.NewContext(req, rec)
req.Header().Set(echo.HeaderOrigin, "")
h(c)
assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
// Wildcard origin // Wildcard origin
req = test.NewRequest(echo.GET, "/", nil) req = test.NewRequest(echo.GET, "/", nil)
rec = test.NewResponseRecorder() rec = test.NewResponseRecorder()

View File

@ -34,6 +34,11 @@ func (h *Header) Keys() (keys []string) {
return return
} }
func (h *Header) Contains(key string) bool {
_, ok := h.header[key]
return ok
}
func (h *Header) reset(hdr http.Header) { func (h *Header) reset(hdr http.Header) {
h.header = hdr h.header = hdr
} }