1
0
mirror of https://github.com/labstack/echo.git synced 2024-11-24 08:22:21 +02:00

Better handling in middleware for WebSocket

Signed-off-by: Vishal Rana <vr@labstack.com>
This commit is contained in:
Vishal Rana 2015-05-22 20:26:52 -07:00
parent 95f72a5170
commit e0364caf36
6 changed files with 69 additions and 41 deletions

58
echo.go
View File

@ -20,7 +20,7 @@ import (
type ( type (
Echo struct { Echo struct {
Router *router router *Router
prefix string prefix string
middleware []MiddlewareFunc middleware []MiddlewareFunc
http2 bool http2 bool
@ -33,10 +33,12 @@ type (
pool sync.Pool pool sync.Pool
debug bool debug bool
} }
HTTPError struct { HTTPError struct {
code int code int
message string message string
} }
Middleware interface{} Middleware interface{}
MiddlewareFunc func(HandlerFunc) HandlerFunc MiddlewareFunc func(HandlerFunc) HandlerFunc
Handler interface{} Handler interface{}
@ -99,6 +101,13 @@ const (
ContentLength = "Content-Length" ContentLength = "Content-Length"
ContentType = "Content-Type" ContentType = "Content-Type"
Authorization = "Authorization" Authorization = "Authorization"
Upgrade = "Upgrade"
//-----------
// Protocols
//-----------
WebSocket = "websocket"
) )
var ( var (
@ -122,30 +131,12 @@ var (
RendererNotRegistered = errors.New("echo ⇒ renderer not registered") RendererNotRegistered = errors.New("echo ⇒ renderer not registered")
) )
func NewHTTPError(code int, msg ...string) *HTTPError {
he := &HTTPError{code: code, message: http.StatusText(code)}
for _, m := range msg {
he.message = m
}
return he
}
// Code returns code.
func (e *HTTPError) Code() int {
return e.code
}
// Error returns message.
func (e *HTTPError) Error() string {
return e.message
}
// New creates an Echo instance. // New creates an Echo instance.
func New() (e *Echo) { func New() (e *Echo) {
e = &Echo{ e = &Echo{
uris: make(map[Handler]string), uris: make(map[Handler]string),
} }
e.Router = NewRouter(e) e.router = NewRouter(e)
e.pool.New = func() interface{} { e.pool.New = func() interface{} {
return NewContext(nil, new(Response), e) return NewContext(nil, new(Response), e)
} }
@ -196,6 +187,11 @@ func (e *Echo) Group(pfx string, m ...Middleware) *Echo {
return &g return &g
} }
// Router returns router.
func (e *Echo) Router() *Router {
return e.router
}
// HTTP2 enables HTTP2 support. // HTTP2 enables HTTP2 support.
func (e *Echo) HTTP2(on bool) { func (e *Echo) HTTP2(on bool) {
e.http2 = on e.http2 = on
@ -302,7 +298,7 @@ func (e *Echo) WebSocket(path string, h HandlerFunc) {
func (e *Echo) add(method, path string, h Handler) { func (e *Echo) add(method, path string, h Handler) {
key := runtime.FuncForPC(reflect.ValueOf(h).Pointer()).Name() key := runtime.FuncForPC(reflect.ValueOf(h).Pointer()).Name()
e.uris[key] = path e.uris[key] = path
e.Router.Add(method, e.prefix+path, wrapHandler(h), e) e.router.Add(method, e.prefix+path, wrapHandler(h), e)
} }
// Index serves index file. // Index serves index file.
@ -361,7 +357,7 @@ func (e *Echo) URL(h Handler, params ...interface{}) string {
func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) {
c := e.pool.Get().(*Context) c := e.pool.Get().(*Context)
h, echo := e.Router.Find(r.Method, r.URL.Path, c) h, echo := e.router.Find(r.Method, r.URL.Path, c)
if echo != nil { if echo != nil {
e = echo e = echo
} }
@ -419,6 +415,24 @@ func (e *Echo) run(s *http.Server, files ...string) {
} }
} }
func NewHTTPError(code int, msg ...string) *HTTPError {
he := &HTTPError{code: code, message: http.StatusText(code)}
for _, m := range msg {
he.message = m
}
return he
}
// Code returns code.
func (e *HTTPError) Code() int {
return e.code
}
// Error returns message.
func (e *HTTPError) Error() string {
return e.message
}
// wraps middleware // wraps middleware
func wrapMiddleware(m Middleware) MiddlewareFunc { func wrapMiddleware(m Middleware) MiddlewareFunc {
switch m := m.(type) { switch m := m.(type) {

7
group.go Normal file
View File

@ -0,0 +1,7 @@
package echo
type (
Group struct {
*Echo
}
)

View File

@ -21,6 +21,11 @@ const (
// For invalid credentials, it sends "401 - Unauthorized" response. // For invalid credentials, it sends "401 - Unauthorized" response.
func BasicAuth(fn AuthFunc) echo.HandlerFunc { func BasicAuth(fn AuthFunc) echo.HandlerFunc {
return func(c *echo.Context) error { return func(c *echo.Context) error {
// Skip for WebSocket
if (c.Request().Header.Get(echo.Upgrade)) == echo.WebSocket {
return nil
}
auth := c.Request().Header.Get(echo.Authorization) auth := c.Request().Header.Get(echo.Authorization)
i := 0 i := 0
code := http.StatusBadRequest code := http.StatusBadRequest

View File

@ -5,8 +5,9 @@ import (
"io" "io"
"strings" "strings"
"github.com/labstack/echo"
"net/http" "net/http"
"github.com/labstack/echo"
) )
type ( type (
@ -27,7 +28,8 @@ func Gzip() echo.MiddlewareFunc {
return func(h echo.HandlerFunc) echo.HandlerFunc { return func(h echo.HandlerFunc) echo.HandlerFunc {
return func(c *echo.Context) error { return func(c *echo.Context) error {
if strings.Contains(c.Request().Header.Get(echo.AcceptEncoding), scheme) { if (c.Request().Header.Get(echo.Upgrade)) != echo.WebSocket && // Skip for WebSocket
strings.Contains(c.Request().Header.Get(echo.AcceptEncoding), scheme) {
w := gzip.NewWriter(c.Response().Writer()) w := gzip.NewWriter(c.Response().Writer())
defer w.Close() defer w.Close()
gw := gzipWriter{Writer: w, ResponseWriter: c.Response().Writer()} gw := gzipWriter{Writer: w, ResponseWriter: c.Response().Writer()}

View File

@ -3,7 +3,7 @@ package echo
import "net/http" import "net/http"
type ( type (
router struct { Router struct {
trees map[string]*node trees map[string]*node
echo *Echo echo *Echo
} }
@ -27,8 +27,8 @@ const (
mtype mtype
) )
func NewRouter(e *Echo) (r *router) { func NewRouter(e *Echo) (r *Router) {
r = &router{ r = &Router{
trees: make(map[string]*node), trees: make(map[string]*node),
echo: e, echo: e,
} }
@ -41,7 +41,7 @@ func NewRouter(e *Echo) (r *router) {
return return
} }
func (r *router) Add(method, path string, h HandlerFunc, echo *Echo) { func (r *Router) Add(method, path string, h HandlerFunc, echo *Echo) {
var pnames []string // Param names var pnames []string // Param names
for i, l := 0, len(path); i < l; i++ { for i, l := 0, len(path); i < l; i++ {
@ -71,7 +71,7 @@ func (r *router) Add(method, path string, h HandlerFunc, echo *Echo) {
r.insert(method, path, h, stype, pnames, echo) r.insert(method, path, h, stype, pnames, echo)
} }
func (r *router) insert(method, path string, h HandlerFunc, t ntype, pnames []string, echo *Echo) { func (r *Router) insert(method, path string, h HandlerFunc, t ntype, pnames []string, echo *Echo) {
cn := r.trees[method] // Current node as root cn := r.trees[method] // Current node as root
search := path search := path
@ -201,7 +201,7 @@ func lcp(a, b string) (i int) {
return return
} }
func (r *router) Find(method, path string, ctx *Context) (h HandlerFunc, echo *Echo) { func (r *Router) Find(method, path string, ctx *Context) (h HandlerFunc, echo *Echo) {
cn := r.trees[method] // Current node as root cn := r.trees[method] // Current node as root
search := path search := path
@ -305,7 +305,7 @@ func (r *router) Find(method, path string, ctx *Context) (h HandlerFunc, echo *E
} }
} }
func (r *router) ServeHTTP(w http.ResponseWriter, req *http.Request) { func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
c := r.echo.pool.Get().(*Context) c := r.echo.pool.Get().(*Context)
h, _ := r.Find(req.Method, req.URL.Path, c) h, _ := r.Find(req.Method, req.URL.Path, c)
c.reset(w, req, r.echo) c.reset(w, req, r.echo)

View File

@ -280,7 +280,7 @@ var (
) )
func TestRouterStatic(t *testing.T) { func TestRouterStatic(t *testing.T) {
r := New().Router r := New().router
b := new(bytes.Buffer) b := new(bytes.Buffer)
path := "/folders/a/files/echo.gif" path := "/folders/a/files/echo.gif"
r.Add(GET, path, func(*Context) error { r.Add(GET, path, func(*Context) error {
@ -299,7 +299,7 @@ func TestRouterStatic(t *testing.T) {
} }
func TestRouterParam(t *testing.T) { func TestRouterParam(t *testing.T) {
r := New().Router r := New().router
r.Add(GET, "/users/:id", func(c *Context) error { r.Add(GET, "/users/:id", func(c *Context) error {
return nil return nil
}, nil) }, nil)
@ -314,7 +314,7 @@ func TestRouterParam(t *testing.T) {
} }
func TestRouterTwoParam(t *testing.T) { func TestRouterTwoParam(t *testing.T) {
r := New().Router r := New().router
r.Add(GET, "/users/:uid/files/:fid", func(*Context) error { r.Add(GET, "/users/:uid/files/:fid", func(*Context) error {
return nil return nil
}, nil) }, nil)
@ -338,7 +338,7 @@ func TestRouterTwoParam(t *testing.T) {
} }
func TestRouterMatchAny(t *testing.T) { func TestRouterMatchAny(t *testing.T) {
r := New().Router r := New().router
r.Add(GET, "/users/*", func(*Context) error { r.Add(GET, "/users/*", func(*Context) error {
return nil return nil
}, nil) }, nil)
@ -363,7 +363,7 @@ func TestRouterMatchAny(t *testing.T) {
} }
func TestRouterMicroParam(t *testing.T) { func TestRouterMicroParam(t *testing.T) {
r := New().Router r := New().router
r.Add(GET, "/:a/:b/:c", func(c *Context) error { r.Add(GET, "/:a/:b/:c", func(c *Context) error {
return nil return nil
}, nil) }, nil)
@ -384,7 +384,7 @@ func TestRouterMicroParam(t *testing.T) {
} }
func TestRouterMultiRoute(t *testing.T) { func TestRouterMultiRoute(t *testing.T) {
r := New().Router r := New().router
b := new(bytes.Buffer) b := new(bytes.Buffer)
// Routes // Routes
@ -425,7 +425,7 @@ func TestRouterMultiRoute(t *testing.T) {
} }
func TestRouterPriority(t *testing.T) { func TestRouterPriority(t *testing.T) {
r := New().Router r := New().router
// Routes // Routes
r.Add(GET, "/users", func(c *Context) error { r.Add(GET, "/users", func(c *Context) error {
@ -536,7 +536,7 @@ func TestRouterPriority(t *testing.T) {
} }
func TestRouterParamNames(t *testing.T) { func TestRouterParamNames(t *testing.T) {
r := New().Router r := New().router
b := new(bytes.Buffer) b := new(bytes.Buffer)
// Routes // Routes
@ -596,7 +596,7 @@ func TestRouterParamNames(t *testing.T) {
} }
func TestRouterAPI(t *testing.T) { func TestRouterAPI(t *testing.T) {
r := New().Router r := New().router
for _, route := range api { for _, route := range api {
r.Add(route.method, route.path, func(c *Context) error { r.Add(route.method, route.path, func(c *Context) error {
for i, n := range c.pnames { for i, n := range c.pnames {
@ -618,7 +618,7 @@ func TestRouterAPI(t *testing.T) {
} }
func TestRouterServeHTTP(t *testing.T) { func TestRouterServeHTTP(t *testing.T) {
r := New().Router r := New().router
r.Add(GET, "/users", func(*Context) error { r.Add(GET, "/users", func(*Context) error {
return nil return nil
}, nil) }, nil)