mirror of
https://github.com/labstack/echo.git
synced 2024-12-24 20:14:31 +02:00
Better handling in middleware for WebSocket
Signed-off-by: Vishal Rana <vr@labstack.com>
This commit is contained in:
parent
95f72a5170
commit
e0364caf36
58
echo.go
58
echo.go
@ -20,7 +20,7 @@ import (
|
||||
|
||||
type (
|
||||
Echo struct {
|
||||
Router *router
|
||||
router *Router
|
||||
prefix string
|
||||
middleware []MiddlewareFunc
|
||||
http2 bool
|
||||
@ -33,10 +33,12 @@ type (
|
||||
pool sync.Pool
|
||||
debug bool
|
||||
}
|
||||
|
||||
HTTPError struct {
|
||||
code int
|
||||
message string
|
||||
}
|
||||
|
||||
Middleware interface{}
|
||||
MiddlewareFunc func(HandlerFunc) HandlerFunc
|
||||
Handler interface{}
|
||||
@ -99,6 +101,13 @@ const (
|
||||
ContentLength = "Content-Length"
|
||||
ContentType = "Content-Type"
|
||||
Authorization = "Authorization"
|
||||
Upgrade = "Upgrade"
|
||||
|
||||
//-----------
|
||||
// Protocols
|
||||
//-----------
|
||||
|
||||
WebSocket = "websocket"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -122,30 +131,12 @@ var (
|
||||
RendererNotRegistered = errors.New("echo ⇒ renderer not registered")
|
||||
)
|
||||
|
||||
func NewHTTPError(code int, msg ...string) *HTTPError {
|
||||
he := &HTTPError{code: code, message: http.StatusText(code)}
|
||||
for _, m := range msg {
|
||||
he.message = m
|
||||
}
|
||||
return he
|
||||
}
|
||||
|
||||
// Code returns code.
|
||||
func (e *HTTPError) Code() int {
|
||||
return e.code
|
||||
}
|
||||
|
||||
// Error returns message.
|
||||
func (e *HTTPError) Error() string {
|
||||
return e.message
|
||||
}
|
||||
|
||||
// New creates an Echo instance.
|
||||
func New() (e *Echo) {
|
||||
e = &Echo{
|
||||
uris: make(map[Handler]string),
|
||||
}
|
||||
e.Router = NewRouter(e)
|
||||
e.router = NewRouter(e)
|
||||
e.pool.New = func() interface{} {
|
||||
return NewContext(nil, new(Response), e)
|
||||
}
|
||||
@ -196,6 +187,11 @@ func (e *Echo) Group(pfx string, m ...Middleware) *Echo {
|
||||
return &g
|
||||
}
|
||||
|
||||
// Router returns router.
|
||||
func (e *Echo) Router() *Router {
|
||||
return e.router
|
||||
}
|
||||
|
||||
// HTTP2 enables HTTP2 support.
|
||||
func (e *Echo) HTTP2(on bool) {
|
||||
e.http2 = on
|
||||
@ -302,7 +298,7 @@ func (e *Echo) WebSocket(path string, h HandlerFunc) {
|
||||
func (e *Echo) add(method, path string, h Handler) {
|
||||
key := runtime.FuncForPC(reflect.ValueOf(h).Pointer()).Name()
|
||||
e.uris[key] = path
|
||||
e.Router.Add(method, e.prefix+path, wrapHandler(h), e)
|
||||
e.router.Add(method, e.prefix+path, wrapHandler(h), e)
|
||||
}
|
||||
|
||||
// Index serves index file.
|
||||
@ -361,7 +357,7 @@ func (e *Echo) URL(h Handler, params ...interface{}) string {
|
||||
|
||||
func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
c := e.pool.Get().(*Context)
|
||||
h, echo := e.Router.Find(r.Method, r.URL.Path, c)
|
||||
h, echo := e.router.Find(r.Method, r.URL.Path, c)
|
||||
if echo != nil {
|
||||
e = echo
|
||||
}
|
||||
@ -419,6 +415,24 @@ func (e *Echo) run(s *http.Server, files ...string) {
|
||||
}
|
||||
}
|
||||
|
||||
func NewHTTPError(code int, msg ...string) *HTTPError {
|
||||
he := &HTTPError{code: code, message: http.StatusText(code)}
|
||||
for _, m := range msg {
|
||||
he.message = m
|
||||
}
|
||||
return he
|
||||
}
|
||||
|
||||
// Code returns code.
|
||||
func (e *HTTPError) Code() int {
|
||||
return e.code
|
||||
}
|
||||
|
||||
// Error returns message.
|
||||
func (e *HTTPError) Error() string {
|
||||
return e.message
|
||||
}
|
||||
|
||||
// wraps middleware
|
||||
func wrapMiddleware(m Middleware) MiddlewareFunc {
|
||||
switch m := m.(type) {
|
||||
|
@ -21,6 +21,11 @@ const (
|
||||
// For invalid credentials, it sends "401 - Unauthorized" response.
|
||||
func BasicAuth(fn AuthFunc) echo.HandlerFunc {
|
||||
return func(c *echo.Context) error {
|
||||
// Skip for WebSocket
|
||||
if (c.Request().Header.Get(echo.Upgrade)) == echo.WebSocket {
|
||||
return nil
|
||||
}
|
||||
|
||||
auth := c.Request().Header.Get(echo.Authorization)
|
||||
i := 0
|
||||
code := http.StatusBadRequest
|
||||
|
@ -5,8 +5,9 @@ import (
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/echo"
|
||||
"net/http"
|
||||
|
||||
"github.com/labstack/echo"
|
||||
)
|
||||
|
||||
type (
|
||||
@ -27,7 +28,8 @@ func Gzip() echo.MiddlewareFunc {
|
||||
|
||||
return func(h echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c *echo.Context) error {
|
||||
if strings.Contains(c.Request().Header.Get(echo.AcceptEncoding), scheme) {
|
||||
if (c.Request().Header.Get(echo.Upgrade)) != echo.WebSocket && // Skip for WebSocket
|
||||
strings.Contains(c.Request().Header.Get(echo.AcceptEncoding), scheme) {
|
||||
w := gzip.NewWriter(c.Response().Writer())
|
||||
defer w.Close()
|
||||
gw := gzipWriter{Writer: w, ResponseWriter: c.Response().Writer()}
|
||||
|
14
router.go
14
router.go
@ -3,7 +3,7 @@ package echo
|
||||
import "net/http"
|
||||
|
||||
type (
|
||||
router struct {
|
||||
Router struct {
|
||||
trees map[string]*node
|
||||
echo *Echo
|
||||
}
|
||||
@ -27,8 +27,8 @@ const (
|
||||
mtype
|
||||
)
|
||||
|
||||
func NewRouter(e *Echo) (r *router) {
|
||||
r = &router{
|
||||
func NewRouter(e *Echo) (r *Router) {
|
||||
r = &Router{
|
||||
trees: make(map[string]*node),
|
||||
echo: e,
|
||||
}
|
||||
@ -41,7 +41,7 @@ func NewRouter(e *Echo) (r *router) {
|
||||
return
|
||||
}
|
||||
|
||||
func (r *router) Add(method, path string, h HandlerFunc, echo *Echo) {
|
||||
func (r *Router) Add(method, path string, h HandlerFunc, echo *Echo) {
|
||||
var pnames []string // Param names
|
||||
|
||||
for i, l := 0, len(path); i < l; i++ {
|
||||
@ -71,7 +71,7 @@ func (r *router) Add(method, path string, h HandlerFunc, echo *Echo) {
|
||||
r.insert(method, path, h, stype, pnames, echo)
|
||||
}
|
||||
|
||||
func (r *router) insert(method, path string, h HandlerFunc, t ntype, pnames []string, echo *Echo) {
|
||||
func (r *Router) insert(method, path string, h HandlerFunc, t ntype, pnames []string, echo *Echo) {
|
||||
cn := r.trees[method] // Current node as root
|
||||
search := path
|
||||
|
||||
@ -201,7 +201,7 @@ func lcp(a, b string) (i int) {
|
||||
return
|
||||
}
|
||||
|
||||
func (r *router) Find(method, path string, ctx *Context) (h HandlerFunc, echo *Echo) {
|
||||
func (r *Router) Find(method, path string, ctx *Context) (h HandlerFunc, echo *Echo) {
|
||||
cn := r.trees[method] // Current node as root
|
||||
search := path
|
||||
|
||||
@ -305,7 +305,7 @@ func (r *router) Find(method, path string, ctx *Context) (h HandlerFunc, echo *E
|
||||
}
|
||||
}
|
||||
|
||||
func (r *router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
c := r.echo.pool.Get().(*Context)
|
||||
h, _ := r.Find(req.Method, req.URL.Path, c)
|
||||
c.reset(w, req, r.echo)
|
||||
|
@ -280,7 +280,7 @@ var (
|
||||
)
|
||||
|
||||
func TestRouterStatic(t *testing.T) {
|
||||
r := New().Router
|
||||
r := New().router
|
||||
b := new(bytes.Buffer)
|
||||
path := "/folders/a/files/echo.gif"
|
||||
r.Add(GET, path, func(*Context) error {
|
||||
@ -299,7 +299,7 @@ func TestRouterStatic(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRouterParam(t *testing.T) {
|
||||
r := New().Router
|
||||
r := New().router
|
||||
r.Add(GET, "/users/:id", func(c *Context) error {
|
||||
return nil
|
||||
}, nil)
|
||||
@ -314,7 +314,7 @@ func TestRouterParam(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRouterTwoParam(t *testing.T) {
|
||||
r := New().Router
|
||||
r := New().router
|
||||
r.Add(GET, "/users/:uid/files/:fid", func(*Context) error {
|
||||
return nil
|
||||
}, nil)
|
||||
@ -338,7 +338,7 @@ func TestRouterTwoParam(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRouterMatchAny(t *testing.T) {
|
||||
r := New().Router
|
||||
r := New().router
|
||||
r.Add(GET, "/users/*", func(*Context) error {
|
||||
return nil
|
||||
}, nil)
|
||||
@ -363,7 +363,7 @@ func TestRouterMatchAny(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRouterMicroParam(t *testing.T) {
|
||||
r := New().Router
|
||||
r := New().router
|
||||
r.Add(GET, "/:a/:b/:c", func(c *Context) error {
|
||||
return nil
|
||||
}, nil)
|
||||
@ -384,7 +384,7 @@ func TestRouterMicroParam(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRouterMultiRoute(t *testing.T) {
|
||||
r := New().Router
|
||||
r := New().router
|
||||
b := new(bytes.Buffer)
|
||||
|
||||
// Routes
|
||||
@ -425,7 +425,7 @@ func TestRouterMultiRoute(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRouterPriority(t *testing.T) {
|
||||
r := New().Router
|
||||
r := New().router
|
||||
|
||||
// Routes
|
||||
r.Add(GET, "/users", func(c *Context) error {
|
||||
@ -536,7 +536,7 @@ func TestRouterPriority(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRouterParamNames(t *testing.T) {
|
||||
r := New().Router
|
||||
r := New().router
|
||||
b := new(bytes.Buffer)
|
||||
|
||||
// Routes
|
||||
@ -596,7 +596,7 @@ func TestRouterParamNames(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRouterAPI(t *testing.T) {
|
||||
r := New().Router
|
||||
r := New().router
|
||||
for _, route := range api {
|
||||
r.Add(route.method, route.path, func(c *Context) error {
|
||||
for i, n := range c.pnames {
|
||||
@ -618,7 +618,7 @@ func TestRouterAPI(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRouterServeHTTP(t *testing.T) {
|
||||
r := New().Router
|
||||
r := New().router
|
||||
r.Add(GET, "/users", func(*Context) error {
|
||||
return nil
|
||||
}, nil)
|
||||
|
Loading…
Reference in New Issue
Block a user