1
0
mirror of https://github.com/labstack/echo.git synced 2024-12-24 20:14:31 +02:00

Enhanced method lookup in router

Signed-off-by: Vishal Rana <vr@labstack.com>
This commit is contained in:
Vishal Rana 2015-07-24 12:03:36 -07:00
parent 99f2868bcc
commit 4dcb57d42a
4 changed files with 135 additions and 30 deletions

21
echo.go
View File

@ -104,7 +104,10 @@ const (
TextPlainCharsetUTF8 = TextPlain + "; " + CharsetUTF8 TextPlainCharsetUTF8 = TextPlain + "; " + CharsetUTF8
MultipartForm = "multipart/form-data" MultipartForm = "multipart/form-data"
//---------
// Charset // Charset
//---------
CharsetUTF8 = "charset=utf-8" CharsetUTF8 = "charset=utf-8"
//--------- //---------
@ -150,6 +153,18 @@ var (
UnsupportedMediaType = errors.New("echo ⇒ unsupported media type") UnsupportedMediaType = errors.New("echo ⇒ unsupported media type")
RendererNotRegistered = errors.New("echo ⇒ renderer not registered") RendererNotRegistered = errors.New("echo ⇒ renderer not registered")
//----------------
// Error handlers
//----------------
notFoundHandler = func(c *Context) error {
return NewHTTPError(http.StatusNotFound)
}
badRequestHandler = func(c *Context) error {
return NewHTTPError(http.StatusBadRequest)
}
) )
// New creates an instance of Echo. // New creates an instance of Echo.
@ -168,9 +183,6 @@ func New() (e *Echo) {
e.ColoredLog(false) e.ColoredLog(false)
} }
e.HTTP2(false) e.HTTP2(false)
e.notFoundHandler = func(c *Context) error {
return NewHTTPError(http.StatusNotFound)
}
e.defaultHTTPErrorHandler = func(err error, c *Context) { e.defaultHTTPErrorHandler = func(err error, c *Context) {
code := http.StatusInternalServerError code := http.StatusInternalServerError
msg := http.StatusText(code) msg := http.StatusText(code)
@ -426,9 +438,6 @@ func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) {
e = echo e = echo
} }
c.reset(r, w, e) c.reset(r, w, e)
if h == nil {
h = e.notFoundHandler
}
// Chain middleware with handler in the end // Chain middleware with handler in the end
for i := len(e.middleware) - 1; i >= 0; i-- { for i := len(e.middleware) - 1; i >= 0; i-- {

View File

@ -368,6 +368,14 @@ func TestEchoNotFound(t *testing.T) {
assert.Equal(t, http.StatusNotFound, w.Code) assert.Equal(t, http.StatusNotFound, w.Code)
} }
func TestEchoBadRequest(t *testing.T) {
e := New()
r, _ := http.NewRequest("INVALID", "/files", nil)
w := httptest.NewRecorder()
e.ServeHTTP(w, r)
assert.Equal(t, http.StatusBadRequest, w.Code)
}
func TestEchoHTTPError(t *testing.T) { func TestEchoHTTPError(t *testing.T) {
m := http.StatusText(http.StatusBadRequest) m := http.StatusText(http.StatusBadRequest)
he := NewHTTPError(http.StatusBadRequest, m) he := NewHTTPError(http.StatusBadRequest, m)
@ -385,8 +393,7 @@ func testMethod(t *testing.T, method, path string, e *Echo) {
m := fmt.Sprintf("%c%s", method[0], strings.ToLower(method[1:])) m := fmt.Sprintf("%c%s", method[0], strings.ToLower(method[1:]))
p := reflect.ValueOf(path) p := reflect.ValueOf(path)
h := reflect.ValueOf(func(c *Context) error { h := reflect.ValueOf(func(c *Context) error {
c.String(http.StatusOK, method) return c.String(http.StatusOK, method)
return nil
}) })
i := interface{}(e) i := interface{}(e)
reflect.ValueOf(i).MethodByName(m).Call([]reflect.Value{p, h}) reflect.ValueOf(i).MethodByName(m).Call([]reflect.Value{p, h})

110
router.go
View File

@ -1,10 +1,21 @@
package echo package echo
import "net/http" import (
"encoding/binary"
"net/http"
)
type ( type (
Router struct { Router struct {
trees [21]*node connectTree *node
deleteTree *node
getTree *node
headTree *node
optionsTree *node
patchTree *node
postTree *node
putTree *node
traceTree *node
routes []Route routes []Route
echo *Echo echo *Echo
} }
@ -30,15 +41,17 @@ const (
func NewRouter(e *Echo) (r *Router) { func NewRouter(e *Echo) (r *Router) {
r = &Router{ r = &Router{
// trees: make(map[string]*node),
routes: []Route{}, routes: []Route{},
echo: e, echo: e,
} connectTree: new(node),
for _, m := range methods { deleteTree: new(node),
r.trees[r.treeIndex(m)] = &node{ getTree: new(node),
prefix: "", headTree: new(node),
children: children{}, optionsTree: new(node),
} patchTree: new(node),
postTree: new(node),
putTree: new(node),
traceTree: new(node),
} }
return return
} }
@ -81,7 +94,10 @@ func (r *Router) insert(method, path string, h HandlerFunc, t ntype, pnames []st
*e.maxParam = l *e.maxParam = l
} }
cn := r.trees[r.treeIndex(method)] // Current node as root cn := r.findTree(method) // Current node as root
if cn == nil {
panic("echo => invalid method")
}
search := path search := path
for { for {
@ -208,12 +224,74 @@ func (r *Router) treeIndex(method string) uint8 {
} }
} }
func (r *Router) findTree(method string) (n *node) {
switch method[0] {
case 'G': // GET
m := uint32(method[2])<<8 | uint32(method[1])<<16 | uint32(method[0])<<24
if m == 0x47455400 {
n = r.getTree
}
case 'P':
switch method[1] {
case 'O': // POST
m := binary.BigEndian.Uint32([]byte(method))
if m == 0x504f5354 {
n = r.postTree
}
case 'U': // PUT
m := uint32(method[2])<<8 | uint32(method[1])<<16 | uint32(method[0])<<24
if m == 0x50555400 {
n = r.putTree
}
case 'A': // PATCH
m := uint64(method[4])<<24 | uint64(method[3])<<32 | uint64(method[2])<<40 |
uint64(method[1])<<48 | uint64(method[0])<<56
if m == 0x5041544348000000 {
n = r.patchTree
}
}
case 'D':
m := uint64(method[5])<<16 | uint64(method[4])<<24 | uint64(method[3])<<32 |
uint64(method[2])<<40 | uint64(method[1])<<48 | uint64(method[0])<<56
if m == 0x44454c4554450000 {
n = r.deleteTree
}
case 'C':
m := uint64(method[6])<<8 | uint64(method[5])<<16 | uint64(method[4])<<24 |
uint64(method[3])<<32 | uint64(method[2])<<40 | uint64(method[1])<<48 |
uint64(method[0])<<56
if m == 0x434f4e4e45435400 {
n = r.connectTree
}
case 'H':
m := binary.BigEndian.Uint32([]byte(method))
if m == 0x48454144 {
n = r.headTree
}
case 'O':
m := uint64(method[6])<<8 | uint64(method[5])<<16 | uint64(method[4])<<24 |
uint64(method[3])<<32 | uint64(method[2])<<40 | uint64(method[1])<<48 |
uint64(method[0])<<56
if m == 0x4f5054494f4e5300 {
n = r.connectTree
}
case 'T':
m := uint64(method[4])<<24 | uint64(method[3])<<32 | uint64(method[2])<<40 |
uint64(method[1])<<48 | uint64(method[0])<<56
if m == 0x5452414345000000 {
n = r.traceTree
}
}
return
}
func (r *Router) Find(method, path string, ctx *Context) (h HandlerFunc, e *Echo) { func (r *Router) Find(method, path string, ctx *Context) (h HandlerFunc, e *Echo) {
i := r.treeIndex(method) h = notFoundHandler
if i > 20 { cn := r.findTree(method) // Current node as root
if cn == nil {
h = badRequestHandler
return return
} }
cn := r.trees[i] // Current node as root
search := path search := path
var ( var (
@ -330,10 +408,8 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
c := r.echo.pool.Get().(*Context) c := r.echo.pool.Get().(*Context)
h, _ := r.Find(req.Method, req.URL.Path, c) h, _ := r.Find(req.Method, req.URL.Path, c)
c.reset(req, w, r.echo) c.reset(req, w, r.echo)
if h == nil { if err := h(c); err != nil {
c.Error(NewHTTPError(http.StatusNotFound)) r.echo.httpErrorHandler(err, c)
} else {
h(c)
} }
r.echo.pool.Put(c) r.echo.pool.Put(c)
} }

View File

@ -384,7 +384,10 @@ func TestRouterMultiRoute(t *testing.T) {
// Route > /user // Route > /user
h, _ = r.Find(GET, "/user", c) h, _ = r.Find(GET, "/user", c)
assert.Nil(t, h) if assert.IsType(t, new(HTTPError), h(c)) {
he := h(c).(*HTTPError)
assert.Equal(t, http.StatusNotFound, he.code)
}
} }
func TestRouterPriority(t *testing.T) { func TestRouterPriority(t *testing.T) {
@ -537,6 +540,16 @@ func TestRouterAPI(t *testing.T) {
} }
} }
func TestRouterAddInvalidMethod(t *testing.T) {
e := New()
r := e.router
assert.Panics(t, func() {
r.Add("INVALID", "/", func(*Context) error {
return nil
}, e)
})
}
func TestRouterServeHTTP(t *testing.T) { func TestRouterServeHTTP(t *testing.T) {
e := New() e := New()
r := e.router r := e.router