1
0
mirror of https://github.com/labstack/echo.git synced 2024-12-24 20:14:31 +02:00

Better handling in middleware for WebSocket

Signed-off-by: Vishal Rana <vr@labstack.com>
This commit is contained in:
Vishal Rana 2015-05-22 20:26:52 -07:00
parent 95f72a5170
commit e0364caf36
6 changed files with 69 additions and 41 deletions

58
echo.go
View File

@ -20,7 +20,7 @@ import (
type (
Echo struct {
Router *router
router *Router
prefix string
middleware []MiddlewareFunc
http2 bool
@ -33,10 +33,12 @@ type (
pool sync.Pool
debug bool
}
HTTPError struct {
code int
message string
}
Middleware interface{}
MiddlewareFunc func(HandlerFunc) HandlerFunc
Handler interface{}
@ -99,6 +101,13 @@ const (
ContentLength = "Content-Length"
ContentType = "Content-Type"
Authorization = "Authorization"
Upgrade = "Upgrade"
//-----------
// Protocols
//-----------
WebSocket = "websocket"
)
var (
@ -122,30 +131,12 @@ var (
RendererNotRegistered = errors.New("echo ⇒ renderer not registered")
)
func NewHTTPError(code int, msg ...string) *HTTPError {
he := &HTTPError{code: code, message: http.StatusText(code)}
for _, m := range msg {
he.message = m
}
return he
}
// Code returns code.
func (e *HTTPError) Code() int {
return e.code
}
// Error returns message.
func (e *HTTPError) Error() string {
return e.message
}
// New creates an Echo instance.
func New() (e *Echo) {
e = &Echo{
uris: make(map[Handler]string),
}
e.Router = NewRouter(e)
e.router = NewRouter(e)
e.pool.New = func() interface{} {
return NewContext(nil, new(Response), e)
}
@ -196,6 +187,11 @@ func (e *Echo) Group(pfx string, m ...Middleware) *Echo {
return &g
}
// Router returns router.
func (e *Echo) Router() *Router {
return e.router
}
// HTTP2 enables HTTP2 support.
func (e *Echo) HTTP2(on bool) {
e.http2 = on
@ -302,7 +298,7 @@ func (e *Echo) WebSocket(path string, h HandlerFunc) {
func (e *Echo) add(method, path string, h Handler) {
key := runtime.FuncForPC(reflect.ValueOf(h).Pointer()).Name()
e.uris[key] = path
e.Router.Add(method, e.prefix+path, wrapHandler(h), e)
e.router.Add(method, e.prefix+path, wrapHandler(h), e)
}
// Index serves index file.
@ -361,7 +357,7 @@ func (e *Echo) URL(h Handler, params ...interface{}) string {
func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) {
c := e.pool.Get().(*Context)
h, echo := e.Router.Find(r.Method, r.URL.Path, c)
h, echo := e.router.Find(r.Method, r.URL.Path, c)
if echo != nil {
e = echo
}
@ -419,6 +415,24 @@ func (e *Echo) run(s *http.Server, files ...string) {
}
}
func NewHTTPError(code int, msg ...string) *HTTPError {
he := &HTTPError{code: code, message: http.StatusText(code)}
for _, m := range msg {
he.message = m
}
return he
}
// Code returns code.
func (e *HTTPError) Code() int {
return e.code
}
// Error returns message.
func (e *HTTPError) Error() string {
return e.message
}
// wraps middleware
func wrapMiddleware(m Middleware) MiddlewareFunc {
switch m := m.(type) {

7
group.go Normal file
View File

@ -0,0 +1,7 @@
package echo
type (
Group struct {
*Echo
}
)

View File

@ -21,6 +21,11 @@ const (
// For invalid credentials, it sends "401 - Unauthorized" response.
func BasicAuth(fn AuthFunc) echo.HandlerFunc {
return func(c *echo.Context) error {
// Skip for WebSocket
if (c.Request().Header.Get(echo.Upgrade)) == echo.WebSocket {
return nil
}
auth := c.Request().Header.Get(echo.Authorization)
i := 0
code := http.StatusBadRequest

View File

@ -5,8 +5,9 @@ import (
"io"
"strings"
"github.com/labstack/echo"
"net/http"
"github.com/labstack/echo"
)
type (
@ -27,7 +28,8 @@ func Gzip() echo.MiddlewareFunc {
return func(h echo.HandlerFunc) echo.HandlerFunc {
return func(c *echo.Context) error {
if strings.Contains(c.Request().Header.Get(echo.AcceptEncoding), scheme) {
if (c.Request().Header.Get(echo.Upgrade)) != echo.WebSocket && // Skip for WebSocket
strings.Contains(c.Request().Header.Get(echo.AcceptEncoding), scheme) {
w := gzip.NewWriter(c.Response().Writer())
defer w.Close()
gw := gzipWriter{Writer: w, ResponseWriter: c.Response().Writer()}

View File

@ -3,7 +3,7 @@ package echo
import "net/http"
type (
router struct {
Router struct {
trees map[string]*node
echo *Echo
}
@ -27,8 +27,8 @@ const (
mtype
)
func NewRouter(e *Echo) (r *router) {
r = &router{
func NewRouter(e *Echo) (r *Router) {
r = &Router{
trees: make(map[string]*node),
echo: e,
}
@ -41,7 +41,7 @@ func NewRouter(e *Echo) (r *router) {
return
}
func (r *router) Add(method, path string, h HandlerFunc, echo *Echo) {
func (r *Router) Add(method, path string, h HandlerFunc, echo *Echo) {
var pnames []string // Param names
for i, l := 0, len(path); i < l; i++ {
@ -71,7 +71,7 @@ func (r *router) Add(method, path string, h HandlerFunc, echo *Echo) {
r.insert(method, path, h, stype, pnames, echo)
}
func (r *router) insert(method, path string, h HandlerFunc, t ntype, pnames []string, echo *Echo) {
func (r *Router) insert(method, path string, h HandlerFunc, t ntype, pnames []string, echo *Echo) {
cn := r.trees[method] // Current node as root
search := path
@ -201,7 +201,7 @@ func lcp(a, b string) (i int) {
return
}
func (r *router) Find(method, path string, ctx *Context) (h HandlerFunc, echo *Echo) {
func (r *Router) Find(method, path string, ctx *Context) (h HandlerFunc, echo *Echo) {
cn := r.trees[method] // Current node as root
search := path
@ -305,7 +305,7 @@ func (r *router) Find(method, path string, ctx *Context) (h HandlerFunc, echo *E
}
}
func (r *router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
c := r.echo.pool.Get().(*Context)
h, _ := r.Find(req.Method, req.URL.Path, c)
c.reset(w, req, r.echo)

View File

@ -280,7 +280,7 @@ var (
)
func TestRouterStatic(t *testing.T) {
r := New().Router
r := New().router
b := new(bytes.Buffer)
path := "/folders/a/files/echo.gif"
r.Add(GET, path, func(*Context) error {
@ -299,7 +299,7 @@ func TestRouterStatic(t *testing.T) {
}
func TestRouterParam(t *testing.T) {
r := New().Router
r := New().router
r.Add(GET, "/users/:id", func(c *Context) error {
return nil
}, nil)
@ -314,7 +314,7 @@ func TestRouterParam(t *testing.T) {
}
func TestRouterTwoParam(t *testing.T) {
r := New().Router
r := New().router
r.Add(GET, "/users/:uid/files/:fid", func(*Context) error {
return nil
}, nil)
@ -338,7 +338,7 @@ func TestRouterTwoParam(t *testing.T) {
}
func TestRouterMatchAny(t *testing.T) {
r := New().Router
r := New().router
r.Add(GET, "/users/*", func(*Context) error {
return nil
}, nil)
@ -363,7 +363,7 @@ func TestRouterMatchAny(t *testing.T) {
}
func TestRouterMicroParam(t *testing.T) {
r := New().Router
r := New().router
r.Add(GET, "/:a/:b/:c", func(c *Context) error {
return nil
}, nil)
@ -384,7 +384,7 @@ func TestRouterMicroParam(t *testing.T) {
}
func TestRouterMultiRoute(t *testing.T) {
r := New().Router
r := New().router
b := new(bytes.Buffer)
// Routes
@ -425,7 +425,7 @@ func TestRouterMultiRoute(t *testing.T) {
}
func TestRouterPriority(t *testing.T) {
r := New().Router
r := New().router
// Routes
r.Add(GET, "/users", func(c *Context) error {
@ -536,7 +536,7 @@ func TestRouterPriority(t *testing.T) {
}
func TestRouterParamNames(t *testing.T) {
r := New().Router
r := New().router
b := new(bytes.Buffer)
// Routes
@ -596,7 +596,7 @@ func TestRouterParamNames(t *testing.T) {
}
func TestRouterAPI(t *testing.T) {
r := New().Router
r := New().router
for _, route := range api {
r.Add(route.method, route.path, func(c *Context) error {
for i, n := range c.pnames {
@ -618,7 +618,7 @@ func TestRouterAPI(t *testing.T) {
}
func TestRouterServeHTTP(t *testing.T) {
r := New().Router
r := New().router
r.Add(GET, "/users", func(*Context) error {
return nil
}, nil)