mirror of
https://github.com/labstack/echo.git
synced 2024-11-28 08:38:39 +02:00
Better handling in middleware for WebSocket
Signed-off-by: Vishal Rana <vr@labstack.com>
This commit is contained in:
parent
95f72a5170
commit
e0364caf36
58
echo.go
58
echo.go
@ -20,7 +20,7 @@ import (
|
|||||||
|
|
||||||
type (
|
type (
|
||||||
Echo struct {
|
Echo struct {
|
||||||
Router *router
|
router *Router
|
||||||
prefix string
|
prefix string
|
||||||
middleware []MiddlewareFunc
|
middleware []MiddlewareFunc
|
||||||
http2 bool
|
http2 bool
|
||||||
@ -33,10 +33,12 @@ type (
|
|||||||
pool sync.Pool
|
pool sync.Pool
|
||||||
debug bool
|
debug bool
|
||||||
}
|
}
|
||||||
|
|
||||||
HTTPError struct {
|
HTTPError struct {
|
||||||
code int
|
code int
|
||||||
message string
|
message string
|
||||||
}
|
}
|
||||||
|
|
||||||
Middleware interface{}
|
Middleware interface{}
|
||||||
MiddlewareFunc func(HandlerFunc) HandlerFunc
|
MiddlewareFunc func(HandlerFunc) HandlerFunc
|
||||||
Handler interface{}
|
Handler interface{}
|
||||||
@ -99,6 +101,13 @@ const (
|
|||||||
ContentLength = "Content-Length"
|
ContentLength = "Content-Length"
|
||||||
ContentType = "Content-Type"
|
ContentType = "Content-Type"
|
||||||
Authorization = "Authorization"
|
Authorization = "Authorization"
|
||||||
|
Upgrade = "Upgrade"
|
||||||
|
|
||||||
|
//-----------
|
||||||
|
// Protocols
|
||||||
|
//-----------
|
||||||
|
|
||||||
|
WebSocket = "websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -122,30 +131,12 @@ var (
|
|||||||
RendererNotRegistered = errors.New("echo ⇒ renderer not registered")
|
RendererNotRegistered = errors.New("echo ⇒ renderer not registered")
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewHTTPError(code int, msg ...string) *HTTPError {
|
|
||||||
he := &HTTPError{code: code, message: http.StatusText(code)}
|
|
||||||
for _, m := range msg {
|
|
||||||
he.message = m
|
|
||||||
}
|
|
||||||
return he
|
|
||||||
}
|
|
||||||
|
|
||||||
// Code returns code.
|
|
||||||
func (e *HTTPError) Code() int {
|
|
||||||
return e.code
|
|
||||||
}
|
|
||||||
|
|
||||||
// Error returns message.
|
|
||||||
func (e *HTTPError) Error() string {
|
|
||||||
return e.message
|
|
||||||
}
|
|
||||||
|
|
||||||
// New creates an Echo instance.
|
// New creates an Echo instance.
|
||||||
func New() (e *Echo) {
|
func New() (e *Echo) {
|
||||||
e = &Echo{
|
e = &Echo{
|
||||||
uris: make(map[Handler]string),
|
uris: make(map[Handler]string),
|
||||||
}
|
}
|
||||||
e.Router = NewRouter(e)
|
e.router = NewRouter(e)
|
||||||
e.pool.New = func() interface{} {
|
e.pool.New = func() interface{} {
|
||||||
return NewContext(nil, new(Response), e)
|
return NewContext(nil, new(Response), e)
|
||||||
}
|
}
|
||||||
@ -196,6 +187,11 @@ func (e *Echo) Group(pfx string, m ...Middleware) *Echo {
|
|||||||
return &g
|
return &g
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Router returns router.
|
||||||
|
func (e *Echo) Router() *Router {
|
||||||
|
return e.router
|
||||||
|
}
|
||||||
|
|
||||||
// HTTP2 enables HTTP2 support.
|
// HTTP2 enables HTTP2 support.
|
||||||
func (e *Echo) HTTP2(on bool) {
|
func (e *Echo) HTTP2(on bool) {
|
||||||
e.http2 = on
|
e.http2 = on
|
||||||
@ -302,7 +298,7 @@ func (e *Echo) WebSocket(path string, h HandlerFunc) {
|
|||||||
func (e *Echo) add(method, path string, h Handler) {
|
func (e *Echo) add(method, path string, h Handler) {
|
||||||
key := runtime.FuncForPC(reflect.ValueOf(h).Pointer()).Name()
|
key := runtime.FuncForPC(reflect.ValueOf(h).Pointer()).Name()
|
||||||
e.uris[key] = path
|
e.uris[key] = path
|
||||||
e.Router.Add(method, e.prefix+path, wrapHandler(h), e)
|
e.router.Add(method, e.prefix+path, wrapHandler(h), e)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Index serves index file.
|
// Index serves index file.
|
||||||
@ -361,7 +357,7 @@ func (e *Echo) URL(h Handler, params ...interface{}) string {
|
|||||||
|
|
||||||
func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
c := e.pool.Get().(*Context)
|
c := e.pool.Get().(*Context)
|
||||||
h, echo := e.Router.Find(r.Method, r.URL.Path, c)
|
h, echo := e.router.Find(r.Method, r.URL.Path, c)
|
||||||
if echo != nil {
|
if echo != nil {
|
||||||
e = echo
|
e = echo
|
||||||
}
|
}
|
||||||
@ -419,6 +415,24 @@ func (e *Echo) run(s *http.Server, files ...string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewHTTPError(code int, msg ...string) *HTTPError {
|
||||||
|
he := &HTTPError{code: code, message: http.StatusText(code)}
|
||||||
|
for _, m := range msg {
|
||||||
|
he.message = m
|
||||||
|
}
|
||||||
|
return he
|
||||||
|
}
|
||||||
|
|
||||||
|
// Code returns code.
|
||||||
|
func (e *HTTPError) Code() int {
|
||||||
|
return e.code
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error returns message.
|
||||||
|
func (e *HTTPError) Error() string {
|
||||||
|
return e.message
|
||||||
|
}
|
||||||
|
|
||||||
// wraps middleware
|
// wraps middleware
|
||||||
func wrapMiddleware(m Middleware) MiddlewareFunc {
|
func wrapMiddleware(m Middleware) MiddlewareFunc {
|
||||||
switch m := m.(type) {
|
switch m := m.(type) {
|
||||||
|
@ -21,6 +21,11 @@ const (
|
|||||||
// For invalid credentials, it sends "401 - Unauthorized" response.
|
// For invalid credentials, it sends "401 - Unauthorized" response.
|
||||||
func BasicAuth(fn AuthFunc) echo.HandlerFunc {
|
func BasicAuth(fn AuthFunc) echo.HandlerFunc {
|
||||||
return func(c *echo.Context) error {
|
return func(c *echo.Context) error {
|
||||||
|
// Skip for WebSocket
|
||||||
|
if (c.Request().Header.Get(echo.Upgrade)) == echo.WebSocket {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
auth := c.Request().Header.Get(echo.Authorization)
|
auth := c.Request().Header.Get(echo.Authorization)
|
||||||
i := 0
|
i := 0
|
||||||
code := http.StatusBadRequest
|
code := http.StatusBadRequest
|
||||||
|
@ -5,8 +5,9 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/labstack/echo"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/labstack/echo"
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
@ -27,7 +28,8 @@ func Gzip() echo.MiddlewareFunc {
|
|||||||
|
|
||||||
return func(h echo.HandlerFunc) echo.HandlerFunc {
|
return func(h echo.HandlerFunc) echo.HandlerFunc {
|
||||||
return func(c *echo.Context) error {
|
return func(c *echo.Context) error {
|
||||||
if strings.Contains(c.Request().Header.Get(echo.AcceptEncoding), scheme) {
|
if (c.Request().Header.Get(echo.Upgrade)) != echo.WebSocket && // Skip for WebSocket
|
||||||
|
strings.Contains(c.Request().Header.Get(echo.AcceptEncoding), scheme) {
|
||||||
w := gzip.NewWriter(c.Response().Writer())
|
w := gzip.NewWriter(c.Response().Writer())
|
||||||
defer w.Close()
|
defer w.Close()
|
||||||
gw := gzipWriter{Writer: w, ResponseWriter: c.Response().Writer()}
|
gw := gzipWriter{Writer: w, ResponseWriter: c.Response().Writer()}
|
||||||
|
14
router.go
14
router.go
@ -3,7 +3,7 @@ package echo
|
|||||||
import "net/http"
|
import "net/http"
|
||||||
|
|
||||||
type (
|
type (
|
||||||
router struct {
|
Router struct {
|
||||||
trees map[string]*node
|
trees map[string]*node
|
||||||
echo *Echo
|
echo *Echo
|
||||||
}
|
}
|
||||||
@ -27,8 +27,8 @@ const (
|
|||||||
mtype
|
mtype
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewRouter(e *Echo) (r *router) {
|
func NewRouter(e *Echo) (r *Router) {
|
||||||
r = &router{
|
r = &Router{
|
||||||
trees: make(map[string]*node),
|
trees: make(map[string]*node),
|
||||||
echo: e,
|
echo: e,
|
||||||
}
|
}
|
||||||
@ -41,7 +41,7 @@ func NewRouter(e *Echo) (r *router) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) Add(method, path string, h HandlerFunc, echo *Echo) {
|
func (r *Router) Add(method, path string, h HandlerFunc, echo *Echo) {
|
||||||
var pnames []string // Param names
|
var pnames []string // Param names
|
||||||
|
|
||||||
for i, l := 0, len(path); i < l; i++ {
|
for i, l := 0, len(path); i < l; i++ {
|
||||||
@ -71,7 +71,7 @@ func (r *router) Add(method, path string, h HandlerFunc, echo *Echo) {
|
|||||||
r.insert(method, path, h, stype, pnames, echo)
|
r.insert(method, path, h, stype, pnames, echo)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) insert(method, path string, h HandlerFunc, t ntype, pnames []string, echo *Echo) {
|
func (r *Router) insert(method, path string, h HandlerFunc, t ntype, pnames []string, echo *Echo) {
|
||||||
cn := r.trees[method] // Current node as root
|
cn := r.trees[method] // Current node as root
|
||||||
search := path
|
search := path
|
||||||
|
|
||||||
@ -201,7 +201,7 @@ func lcp(a, b string) (i int) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) Find(method, path string, ctx *Context) (h HandlerFunc, echo *Echo) {
|
func (r *Router) Find(method, path string, ctx *Context) (h HandlerFunc, echo *Echo) {
|
||||||
cn := r.trees[method] // Current node as root
|
cn := r.trees[method] // Current node as root
|
||||||
search := path
|
search := path
|
||||||
|
|
||||||
@ -305,7 +305,7 @@ func (r *router) Find(method, path string, ctx *Context) (h HandlerFunc, echo *E
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||||
c := r.echo.pool.Get().(*Context)
|
c := r.echo.pool.Get().(*Context)
|
||||||
h, _ := r.Find(req.Method, req.URL.Path, c)
|
h, _ := r.Find(req.Method, req.URL.Path, c)
|
||||||
c.reset(w, req, r.echo)
|
c.reset(w, req, r.echo)
|
||||||
|
@ -280,7 +280,7 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestRouterStatic(t *testing.T) {
|
func TestRouterStatic(t *testing.T) {
|
||||||
r := New().Router
|
r := New().router
|
||||||
b := new(bytes.Buffer)
|
b := new(bytes.Buffer)
|
||||||
path := "/folders/a/files/echo.gif"
|
path := "/folders/a/files/echo.gif"
|
||||||
r.Add(GET, path, func(*Context) error {
|
r.Add(GET, path, func(*Context) error {
|
||||||
@ -299,7 +299,7 @@ func TestRouterStatic(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRouterParam(t *testing.T) {
|
func TestRouterParam(t *testing.T) {
|
||||||
r := New().Router
|
r := New().router
|
||||||
r.Add(GET, "/users/:id", func(c *Context) error {
|
r.Add(GET, "/users/:id", func(c *Context) error {
|
||||||
return nil
|
return nil
|
||||||
}, nil)
|
}, nil)
|
||||||
@ -314,7 +314,7 @@ func TestRouterParam(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRouterTwoParam(t *testing.T) {
|
func TestRouterTwoParam(t *testing.T) {
|
||||||
r := New().Router
|
r := New().router
|
||||||
r.Add(GET, "/users/:uid/files/:fid", func(*Context) error {
|
r.Add(GET, "/users/:uid/files/:fid", func(*Context) error {
|
||||||
return nil
|
return nil
|
||||||
}, nil)
|
}, nil)
|
||||||
@ -338,7 +338,7 @@ func TestRouterTwoParam(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRouterMatchAny(t *testing.T) {
|
func TestRouterMatchAny(t *testing.T) {
|
||||||
r := New().Router
|
r := New().router
|
||||||
r.Add(GET, "/users/*", func(*Context) error {
|
r.Add(GET, "/users/*", func(*Context) error {
|
||||||
return nil
|
return nil
|
||||||
}, nil)
|
}, nil)
|
||||||
@ -363,7 +363,7 @@ func TestRouterMatchAny(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRouterMicroParam(t *testing.T) {
|
func TestRouterMicroParam(t *testing.T) {
|
||||||
r := New().Router
|
r := New().router
|
||||||
r.Add(GET, "/:a/:b/:c", func(c *Context) error {
|
r.Add(GET, "/:a/:b/:c", func(c *Context) error {
|
||||||
return nil
|
return nil
|
||||||
}, nil)
|
}, nil)
|
||||||
@ -384,7 +384,7 @@ func TestRouterMicroParam(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRouterMultiRoute(t *testing.T) {
|
func TestRouterMultiRoute(t *testing.T) {
|
||||||
r := New().Router
|
r := New().router
|
||||||
b := new(bytes.Buffer)
|
b := new(bytes.Buffer)
|
||||||
|
|
||||||
// Routes
|
// Routes
|
||||||
@ -425,7 +425,7 @@ func TestRouterMultiRoute(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRouterPriority(t *testing.T) {
|
func TestRouterPriority(t *testing.T) {
|
||||||
r := New().Router
|
r := New().router
|
||||||
|
|
||||||
// Routes
|
// Routes
|
||||||
r.Add(GET, "/users", func(c *Context) error {
|
r.Add(GET, "/users", func(c *Context) error {
|
||||||
@ -536,7 +536,7 @@ func TestRouterPriority(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRouterParamNames(t *testing.T) {
|
func TestRouterParamNames(t *testing.T) {
|
||||||
r := New().Router
|
r := New().router
|
||||||
b := new(bytes.Buffer)
|
b := new(bytes.Buffer)
|
||||||
|
|
||||||
// Routes
|
// Routes
|
||||||
@ -596,7 +596,7 @@ func TestRouterParamNames(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRouterAPI(t *testing.T) {
|
func TestRouterAPI(t *testing.T) {
|
||||||
r := New().Router
|
r := New().router
|
||||||
for _, route := range api {
|
for _, route := range api {
|
||||||
r.Add(route.method, route.path, func(c *Context) error {
|
r.Add(route.method, route.path, func(c *Context) error {
|
||||||
for i, n := range c.pnames {
|
for i, n := range c.pnames {
|
||||||
@ -618,7 +618,7 @@ func TestRouterAPI(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRouterServeHTTP(t *testing.T) {
|
func TestRouterServeHTTP(t *testing.T) {
|
||||||
r := New().Router
|
r := New().router
|
||||||
r.Add(GET, "/users", func(*Context) error {
|
r.Add(GET, "/users", func(*Context) error {
|
||||||
return nil
|
return nil
|
||||||
}, nil)
|
}, nil)
|
||||||
|
Loading…
Reference in New Issue
Block a user