diff --git a/bolt.go b/bolt.go index 34cb8a9a..5743b4ff 100644 --- a/bolt.go +++ b/bolt.go @@ -3,6 +3,7 @@ package bolt import ( "log" "net/http" + "strings" "sync" ) @@ -140,6 +141,14 @@ func (b *Bolt) Trace(path string, h ...HandlerFunc) { b.Handle("TRACE", path, h) } +// Static serves static files +func (b *Bolt) Static(path, root string) { + fs := http.StripPrefix(strings.TrimSuffix(path, "*"), http.FileServer(http.Dir(root))) + b.Get(path, func(c *Context) { + fs.ServeHTTP(c.Response, c.Request) + }) +} + func (b *Bolt) Handle(method, path string, h []HandlerFunc) { h = append(b.handlers, h...) l := len(h) @@ -153,8 +162,8 @@ func (b *Bolt) Handle(method, path string, h []HandlerFunc) { func (b *Bolt) ServeHTTP(rw http.ResponseWriter, r *http.Request) { // Find and execute handler h, c, s := b.Router.Find(r.Method, r.URL.Path) + c.reset(rw, r) if h != nil { - c.reset(rw, r) h(c) } else { if s == NotFound { diff --git a/example/main.go b/example/main.go index a33ce51d..87462709 100644 --- a/example/main.go +++ b/example/main.go @@ -43,5 +43,6 @@ func main() { b.Post("/users", createUser) b.Get("/users", getUsers) b.Get("/users/:id", getUser) + b.Static("/static/*", "/tmp/") b.Run(":8080") } diff --git a/router.go b/router.go index 0a777d77..e7765dce 100644 --- a/router.go +++ b/router.go @@ -32,7 +32,7 @@ type ( const ( snode ntype = iota // Static node pnode // Param node - wnode // Wildcard node + anode // Catch-all node ) const ( @@ -79,6 +79,8 @@ func (r *router) Add(method, path string, h HandlerFunc) { } else { r.insert(method, path[:i], nil, snode) } + } else if path[i] == '*' { + r.insert(method, path[:i], h, anode) } } r.insert(method, path, h, snode) @@ -206,6 +208,10 @@ func (r *router) Find(method, path string) (handler HandlerFunc, c *Context, s S // All param read continue } + case anode: + // End search + search = "" + continue } e := cn.findEdge(search[0]) if e == nil { diff --git a/router_test.go b/router_test.go index 19bf0f89..28094605 100644 --- a/router_test.go +++ b/router_test.go @@ -24,6 +24,15 @@ func TestParam(t *testing.T) { } } +func TestCatchAll(t *testing.T) { + r := New().Router + r.Add("GET", "/static/*", func(c *Context) {}) + h, _, _ := r.Find("GET", "/static/*") + if h == nil { + t.Fatal("handle not found") + } +} + func TestMicroParam(t *testing.T) { r := New().Router r.Add("GET", "/:a/:b/:c", func(c *Context) {})