From e0364caf3635d3df1fcbd445b4e212e1ae1729ca Mon Sep 17 00:00:00 2001 From: Vishal Rana Date: Fri, 22 May 2015 20:26:52 -0700 Subject: [PATCH] Better handling in middleware for WebSocket Signed-off-by: Vishal Rana --- echo.go | 58 ++++++++++++++++++++++++++---------------- group.go | 7 +++++ middleware/auth.go | 5 ++++ middleware/compress.go | 6 +++-- router.go | 14 +++++----- router_test.go | 20 +++++++-------- 6 files changed, 69 insertions(+), 41 deletions(-) create mode 100644 group.go diff --git a/echo.go b/echo.go index a20f5971..1b95ee8a 100644 --- a/echo.go +++ b/echo.go @@ -20,7 +20,7 @@ import ( type ( Echo struct { - Router *router + router *Router prefix string middleware []MiddlewareFunc http2 bool @@ -33,10 +33,12 @@ type ( pool sync.Pool debug bool } + HTTPError struct { code int message string } + Middleware interface{} MiddlewareFunc func(HandlerFunc) HandlerFunc Handler interface{} @@ -99,6 +101,13 @@ const ( ContentLength = "Content-Length" ContentType = "Content-Type" Authorization = "Authorization" + Upgrade = "Upgrade" + + //----------- + // Protocols + //----------- + + WebSocket = "websocket" ) var ( @@ -122,30 +131,12 @@ var ( RendererNotRegistered = errors.New("echo ⇒ renderer not registered") ) -func NewHTTPError(code int, msg ...string) *HTTPError { - he := &HTTPError{code: code, message: http.StatusText(code)} - for _, m := range msg { - he.message = m - } - return he -} - -// Code returns code. -func (e *HTTPError) Code() int { - return e.code -} - -// Error returns message. -func (e *HTTPError) Error() string { - return e.message -} - // New creates an Echo instance. func New() (e *Echo) { e = &Echo{ uris: make(map[Handler]string), } - e.Router = NewRouter(e) + e.router = NewRouter(e) e.pool.New = func() interface{} { return NewContext(nil, new(Response), e) } @@ -196,6 +187,11 @@ func (e *Echo) Group(pfx string, m ...Middleware) *Echo { return &g } +// Router returns router. +func (e *Echo) Router() *Router { + return e.router +} + // HTTP2 enables HTTP2 support. func (e *Echo) HTTP2(on bool) { e.http2 = on @@ -302,7 +298,7 @@ func (e *Echo) WebSocket(path string, h HandlerFunc) { func (e *Echo) add(method, path string, h Handler) { key := runtime.FuncForPC(reflect.ValueOf(h).Pointer()).Name() e.uris[key] = path - e.Router.Add(method, e.prefix+path, wrapHandler(h), e) + e.router.Add(method, e.prefix+path, wrapHandler(h), e) } // Index serves index file. @@ -361,7 +357,7 @@ func (e *Echo) URL(h Handler, params ...interface{}) string { func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) { c := e.pool.Get().(*Context) - h, echo := e.Router.Find(r.Method, r.URL.Path, c) + h, echo := e.router.Find(r.Method, r.URL.Path, c) if echo != nil { e = echo } @@ -419,6 +415,24 @@ func (e *Echo) run(s *http.Server, files ...string) { } } +func NewHTTPError(code int, msg ...string) *HTTPError { + he := &HTTPError{code: code, message: http.StatusText(code)} + for _, m := range msg { + he.message = m + } + return he +} + +// Code returns code. +func (e *HTTPError) Code() int { + return e.code +} + +// Error returns message. +func (e *HTTPError) Error() string { + return e.message +} + // wraps middleware func wrapMiddleware(m Middleware) MiddlewareFunc { switch m := m.(type) { diff --git a/group.go b/group.go new file mode 100644 index 00000000..18ad367e --- /dev/null +++ b/group.go @@ -0,0 +1,7 @@ +package echo + +type ( + Group struct { + *Echo + } +) diff --git a/middleware/auth.go b/middleware/auth.go index bd3a1469..d441e8fc 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -21,6 +21,11 @@ const ( // For invalid credentials, it sends "401 - Unauthorized" response. func BasicAuth(fn AuthFunc) echo.HandlerFunc { return func(c *echo.Context) error { + // Skip for WebSocket + if (c.Request().Header.Get(echo.Upgrade)) == echo.WebSocket { + return nil + } + auth := c.Request().Header.Get(echo.Authorization) i := 0 code := http.StatusBadRequest diff --git a/middleware/compress.go b/middleware/compress.go index 33736a27..2c4326e8 100644 --- a/middleware/compress.go +++ b/middleware/compress.go @@ -5,8 +5,9 @@ import ( "io" "strings" - "github.com/labstack/echo" "net/http" + + "github.com/labstack/echo" ) type ( @@ -27,7 +28,8 @@ func Gzip() echo.MiddlewareFunc { return func(h echo.HandlerFunc) echo.HandlerFunc { return func(c *echo.Context) error { - if strings.Contains(c.Request().Header.Get(echo.AcceptEncoding), scheme) { + if (c.Request().Header.Get(echo.Upgrade)) != echo.WebSocket && // Skip for WebSocket + strings.Contains(c.Request().Header.Get(echo.AcceptEncoding), scheme) { w := gzip.NewWriter(c.Response().Writer()) defer w.Close() gw := gzipWriter{Writer: w, ResponseWriter: c.Response().Writer()} diff --git a/router.go b/router.go index 3cb4fff8..0422546b 100644 --- a/router.go +++ b/router.go @@ -3,7 +3,7 @@ package echo import "net/http" type ( - router struct { + Router struct { trees map[string]*node echo *Echo } @@ -27,8 +27,8 @@ const ( mtype ) -func NewRouter(e *Echo) (r *router) { - r = &router{ +func NewRouter(e *Echo) (r *Router) { + r = &Router{ trees: make(map[string]*node), echo: e, } @@ -41,7 +41,7 @@ func NewRouter(e *Echo) (r *router) { return } -func (r *router) Add(method, path string, h HandlerFunc, echo *Echo) { +func (r *Router) Add(method, path string, h HandlerFunc, echo *Echo) { var pnames []string // Param names for i, l := 0, len(path); i < l; i++ { @@ -71,7 +71,7 @@ func (r *router) Add(method, path string, h HandlerFunc, echo *Echo) { r.insert(method, path, h, stype, pnames, echo) } -func (r *router) insert(method, path string, h HandlerFunc, t ntype, pnames []string, echo *Echo) { +func (r *Router) insert(method, path string, h HandlerFunc, t ntype, pnames []string, echo *Echo) { cn := r.trees[method] // Current node as root search := path @@ -201,7 +201,7 @@ func lcp(a, b string) (i int) { return } -func (r *router) Find(method, path string, ctx *Context) (h HandlerFunc, echo *Echo) { +func (r *Router) Find(method, path string, ctx *Context) (h HandlerFunc, echo *Echo) { cn := r.trees[method] // Current node as root search := path @@ -305,7 +305,7 @@ func (r *router) Find(method, path string, ctx *Context) (h HandlerFunc, echo *E } } -func (r *router) ServeHTTP(w http.ResponseWriter, req *http.Request) { +func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { c := r.echo.pool.Get().(*Context) h, _ := r.Find(req.Method, req.URL.Path, c) c.reset(w, req, r.echo) diff --git a/router_test.go b/router_test.go index b1b4ff8e..a1b671c3 100644 --- a/router_test.go +++ b/router_test.go @@ -280,7 +280,7 @@ var ( ) func TestRouterStatic(t *testing.T) { - r := New().Router + r := New().router b := new(bytes.Buffer) path := "/folders/a/files/echo.gif" r.Add(GET, path, func(*Context) error { @@ -299,7 +299,7 @@ func TestRouterStatic(t *testing.T) { } func TestRouterParam(t *testing.T) { - r := New().Router + r := New().router r.Add(GET, "/users/:id", func(c *Context) error { return nil }, nil) @@ -314,7 +314,7 @@ func TestRouterParam(t *testing.T) { } func TestRouterTwoParam(t *testing.T) { - r := New().Router + r := New().router r.Add(GET, "/users/:uid/files/:fid", func(*Context) error { return nil }, nil) @@ -338,7 +338,7 @@ func TestRouterTwoParam(t *testing.T) { } func TestRouterMatchAny(t *testing.T) { - r := New().Router + r := New().router r.Add(GET, "/users/*", func(*Context) error { return nil }, nil) @@ -363,7 +363,7 @@ func TestRouterMatchAny(t *testing.T) { } func TestRouterMicroParam(t *testing.T) { - r := New().Router + r := New().router r.Add(GET, "/:a/:b/:c", func(c *Context) error { return nil }, nil) @@ -384,7 +384,7 @@ func TestRouterMicroParam(t *testing.T) { } func TestRouterMultiRoute(t *testing.T) { - r := New().Router + r := New().router b := new(bytes.Buffer) // Routes @@ -425,7 +425,7 @@ func TestRouterMultiRoute(t *testing.T) { } func TestRouterPriority(t *testing.T) { - r := New().Router + r := New().router // Routes r.Add(GET, "/users", func(c *Context) error { @@ -536,7 +536,7 @@ func TestRouterPriority(t *testing.T) { } func TestRouterParamNames(t *testing.T) { - r := New().Router + r := New().router b := new(bytes.Buffer) // Routes @@ -596,7 +596,7 @@ func TestRouterParamNames(t *testing.T) { } func TestRouterAPI(t *testing.T) { - r := New().Router + r := New().router for _, route := range api { r.Add(route.method, route.path, func(c *Context) error { for i, n := range c.pnames { @@ -618,7 +618,7 @@ func TestRouterAPI(t *testing.T) { } func TestRouterServeHTTP(t *testing.T) { - r := New().Router + r := New().router r.Add(GET, "/users", func(*Context) error { return nil }, nil)