diff --git a/echo.go b/echo.go index 75bacd44..b64b65f7 100644 --- a/echo.go +++ b/echo.go @@ -104,7 +104,10 @@ const ( TextPlainCharsetUTF8 = TextPlain + "; " + CharsetUTF8 MultipartForm = "multipart/form-data" + //--------- // Charset + //--------- + CharsetUTF8 = "charset=utf-8" //--------- @@ -150,6 +153,18 @@ var ( UnsupportedMediaType = errors.New("echo ⇒ unsupported media type") RendererNotRegistered = errors.New("echo ⇒ renderer not registered") + + //---------------- + // Error handlers + //---------------- + + notFoundHandler = func(c *Context) error { + return NewHTTPError(http.StatusNotFound) + } + + badRequestHandler = func(c *Context) error { + return NewHTTPError(http.StatusBadRequest) + } ) // New creates an instance of Echo. @@ -168,9 +183,6 @@ func New() (e *Echo) { e.ColoredLog(false) } e.HTTP2(false) - e.notFoundHandler = func(c *Context) error { - return NewHTTPError(http.StatusNotFound) - } e.defaultHTTPErrorHandler = func(err error, c *Context) { code := http.StatusInternalServerError msg := http.StatusText(code) @@ -426,9 +438,6 @@ func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) { e = echo } c.reset(r, w, e) - if h == nil { - h = e.notFoundHandler - } // Chain middleware with handler in the end for i := len(e.middleware) - 1; i >= 0; i-- { diff --git a/echo_test.go b/echo_test.go index 00d4d7a5..4c2d8254 100644 --- a/echo_test.go +++ b/echo_test.go @@ -368,6 +368,14 @@ func TestEchoNotFound(t *testing.T) { assert.Equal(t, http.StatusNotFound, w.Code) } +func TestEchoBadRequest(t *testing.T) { + e := New() + r, _ := http.NewRequest("INVALID", "/files", nil) + w := httptest.NewRecorder() + e.ServeHTTP(w, r) + assert.Equal(t, http.StatusBadRequest, w.Code) +} + func TestEchoHTTPError(t *testing.T) { m := http.StatusText(http.StatusBadRequest) he := NewHTTPError(http.StatusBadRequest, m) @@ -385,8 +393,7 @@ func testMethod(t *testing.T, method, path string, e *Echo) { m := fmt.Sprintf("%c%s", method[0], strings.ToLower(method[1:])) p := reflect.ValueOf(path) h := reflect.ValueOf(func(c *Context) error { - c.String(http.StatusOK, method) - return nil + return c.String(http.StatusOK, method) }) i := interface{}(e) reflect.ValueOf(i).MethodByName(m).Call([]reflect.Value{p, h}) diff --git a/router.go b/router.go index 1e6a84a2..8442ca42 100644 --- a/router.go +++ b/router.go @@ -1,12 +1,23 @@ package echo -import "net/http" +import ( + "encoding/binary" + "net/http" +) type ( Router struct { - trees [21]*node - routes []Route - echo *Echo + connectTree *node + deleteTree *node + getTree *node + headTree *node + optionsTree *node + patchTree *node + postTree *node + putTree *node + traceTree *node + routes []Route + echo *Echo } node struct { typ ntype @@ -30,15 +41,17 @@ const ( func NewRouter(e *Echo) (r *Router) { r = &Router{ - // trees: make(map[string]*node), - routes: []Route{}, - echo: e, - } - for _, m := range methods { - r.trees[r.treeIndex(m)] = &node{ - prefix: "", - children: children{}, - } + routes: []Route{}, + echo: e, + connectTree: new(node), + deleteTree: new(node), + getTree: new(node), + headTree: new(node), + optionsTree: new(node), + patchTree: new(node), + postTree: new(node), + putTree: new(node), + traceTree: new(node), } return } @@ -81,7 +94,10 @@ func (r *Router) insert(method, path string, h HandlerFunc, t ntype, pnames []st *e.maxParam = l } - cn := r.trees[r.treeIndex(method)] // Current node as root + cn := r.findTree(method) // Current node as root + if cn == nil { + panic("echo => invalid method") + } search := path for { @@ -208,12 +224,74 @@ func (r *Router) treeIndex(method string) uint8 { } } +func (r *Router) findTree(method string) (n *node) { + switch method[0] { + case 'G': // GET + m := uint32(method[2])<<8 | uint32(method[1])<<16 | uint32(method[0])<<24 + if m == 0x47455400 { + n = r.getTree + } + case 'P': + switch method[1] { + case 'O': // POST + m := binary.BigEndian.Uint32([]byte(method)) + if m == 0x504f5354 { + n = r.postTree + } + case 'U': // PUT + m := uint32(method[2])<<8 | uint32(method[1])<<16 | uint32(method[0])<<24 + if m == 0x50555400 { + n = r.putTree + } + case 'A': // PATCH + m := uint64(method[4])<<24 | uint64(method[3])<<32 | uint64(method[2])<<40 | + uint64(method[1])<<48 | uint64(method[0])<<56 + if m == 0x5041544348000000 { + n = r.patchTree + } + } + case 'D': + m := uint64(method[5])<<16 | uint64(method[4])<<24 | uint64(method[3])<<32 | + uint64(method[2])<<40 | uint64(method[1])<<48 | uint64(method[0])<<56 + if m == 0x44454c4554450000 { + n = r.deleteTree + } + case 'C': + m := uint64(method[6])<<8 | uint64(method[5])<<16 | uint64(method[4])<<24 | + uint64(method[3])<<32 | uint64(method[2])<<40 | uint64(method[1])<<48 | + uint64(method[0])<<56 + if m == 0x434f4e4e45435400 { + n = r.connectTree + } + case 'H': + m := binary.BigEndian.Uint32([]byte(method)) + if m == 0x48454144 { + n = r.headTree + } + case 'O': + m := uint64(method[6])<<8 | uint64(method[5])<<16 | uint64(method[4])<<24 | + uint64(method[3])<<32 | uint64(method[2])<<40 | uint64(method[1])<<48 | + uint64(method[0])<<56 + if m == 0x4f5054494f4e5300 { + n = r.connectTree + } + case 'T': + m := uint64(method[4])<<24 | uint64(method[3])<<32 | uint64(method[2])<<40 | + uint64(method[1])<<48 | uint64(method[0])<<56 + if m == 0x5452414345000000 { + n = r.traceTree + } + } + return +} + func (r *Router) Find(method, path string, ctx *Context) (h HandlerFunc, e *Echo) { - i := r.treeIndex(method) - if i > 20 { + h = notFoundHandler + cn := r.findTree(method) // Current node as root + if cn == nil { + h = badRequestHandler return } - cn := r.trees[i] // Current node as root search := path var ( @@ -330,10 +408,8 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { c := r.echo.pool.Get().(*Context) h, _ := r.Find(req.Method, req.URL.Path, c) c.reset(req, w, r.echo) - if h == nil { - c.Error(NewHTTPError(http.StatusNotFound)) - } else { - h(c) + if err := h(c); err != nil { + r.echo.httpErrorHandler(err, c) } r.echo.pool.Put(c) } diff --git a/router_test.go b/router_test.go index caaadd74..1f272791 100644 --- a/router_test.go +++ b/router_test.go @@ -384,7 +384,10 @@ func TestRouterMultiRoute(t *testing.T) { // Route > /user h, _ = r.Find(GET, "/user", c) - assert.Nil(t, h) + if assert.IsType(t, new(HTTPError), h(c)) { + he := h(c).(*HTTPError) + assert.Equal(t, http.StatusNotFound, he.code) + } } func TestRouterPriority(t *testing.T) { @@ -537,6 +540,16 @@ func TestRouterAPI(t *testing.T) { } } +func TestRouterAddInvalidMethod(t *testing.T) { + e := New() + r := e.router + assert.Panics(t, func() { + r.Add("INVALID", "/", func(*Context) error { + return nil + }, e) + }) +} + func TestRouterServeHTTP(t *testing.T) { e := New() r := e.router