From a268e01746596652f225605349b4e029a6b585d3 Mon Sep 17 00:00:00 2001 From: Vishal Rana Date: Sun, 8 Mar 2015 23:58:10 -0700 Subject: [PATCH] Added serve index and file methods Signed-off-by: Vishal Rana --- bolt.go | 31 +++++++++++++++++++++---------- bolt_test.go | 34 +++++++++++++++++++++++++++++----- context.go | 8 ++++++-- example/main.go | 3 ++- router.go | 19 +++++++++---------- router_test.go | 8 ++++---- 6 files changed, 71 insertions(+), 32 deletions(-) diff --git a/bolt.go b/bolt.go index 5743b4ff..0648a79d 100644 --- a/bolt.go +++ b/bolt.go @@ -3,7 +3,6 @@ package bolt import ( "log" "net/http" - "strings" "sync" ) @@ -21,7 +20,7 @@ type ( ) const ( - MIME_JSON = "application/json" + MIMEJSON = "application/json" HeaderAccept = "Accept" HeaderContentDisposition = "Content-Disposition" @@ -141,14 +140,6 @@ func (b *Bolt) Trace(path string, h ...HandlerFunc) { b.Handle("TRACE", path, h) } -// Static serves static files -func (b *Bolt) Static(path, root string) { - fs := http.StripPrefix(strings.TrimSuffix(path, "*"), http.FileServer(http.Dir(root))) - b.Get(path, func(c *Context) { - fs.ServeHTTP(c.Response, c.Request) - }) -} - func (b *Bolt) Handle(method, path string, h []HandlerFunc) { h = append(b.handlers, h...) l := len(h) @@ -159,6 +150,26 @@ func (b *Bolt) Handle(method, path string, h []HandlerFunc) { }) } +// Static serves static files. +func (b *Bolt) Static(path, root string) { + fs := http.StripPrefix(path, http.FileServer(http.Dir(root))) + b.Get(path+"/*", func(c *Context) { + fs.ServeHTTP(c.Response, c.Request) + }) +} + +// ServeFile serves a file. +func (b *Bolt) ServeFile(path, file string) { + b.Get(path, func(c *Context) { + http.ServeFile(c.Response, c.Request, file) + }) +} + +// Index serves index file. +func (b *Bolt) Index(file string) { + b.ServeFile("/", file) +} + func (b *Bolt) ServeHTTP(rw http.ResponseWriter, r *http.Request) { // Find and execute handler h, c, s := b.Router.Find(r.Method, r.URL.Path) diff --git a/bolt_test.go b/bolt_test.go index b8efa68d..e688924a 100644 --- a/bolt_test.go +++ b/bolt_test.go @@ -4,21 +4,45 @@ import ( "encoding/binary" "encoding/json" "io" + "net/http" + "net/http/httptest" "testing" ) type ( user struct { - Id string + ID string Name string } ) var u = user{ - Id: "1", + ID: "1", Name: "Joe", } +func TestIndex(t *testing.T) { + b := New() + b.Index("example/index.html") + r, _ := http.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + b.ServeHTTP(w, r) + if w.Code != 200 { + t.Errorf("status code should be 200, found %d", w.Code) + } +} + +func TestStatic(t *testing.T) { + b := New() + b.Static("/static", "example/public") + r, _ := http.NewRequest("GET", "/static/main.js", nil) + w := httptest.NewRecorder() + b.ServeHTTP(w, r) + if w.Code != 200 { + t.Errorf("status code should be 200, found %d", w.Code) + } +} + func verifyUser(rd io.Reader, t *testing.T) { var l int64 err := binary.Read(rd, binary.BigEndian, &l) // Body length @@ -32,10 +56,10 @@ func verifyUser(rd io.Reader, t *testing.T) { if err != nil { t.Fatal(err) } - if u2.Id != u.Id { - t.Error("user id should be %s, found %s", u.Id, u2.Id) + if u2.ID != u.ID { + t.Errorf("user id should be %s, found %s", u.ID, u2.ID) } if u2.Name != u.Name { - t.Error("user name should be %s, found %s", u.Name, u2.Name) + t.Errorf("user name should be %s, found %s", u.Name, u2.Name) } } diff --git a/context.go b/context.go index 392a94cd..5d570a35 100644 --- a/context.go +++ b/context.go @@ -34,7 +34,7 @@ func (c *Context) Param(n string) string { func (c *Context) Bind(i interface{}) bool { var err error ct := c.Request.Header.Get(HeaderContentType) - if strings.HasPrefix(ct, MIME_JSON) { + if strings.HasPrefix(ct, MIMEJSON) { dec := json.NewDecoder(c.Request.Body) err = dec.Decode(i) } else { @@ -52,13 +52,17 @@ func (c *Context) Bind(i interface{}) bool { //************ func (c *Context) JSON(n int, i interface{}) { enc := json.NewEncoder(c.Response) - c.Response.Header().Set(HeaderContentType, MIME_JSON+"; charset=utf-8") + c.Response.Header().Set(HeaderContentType, MIMEJSON+"; charset=utf-8") c.Response.WriteHeader(n) if err := enc.Encode(i); err != nil { c.bolt.internalServerErrorHandler(c) } } +func (c *Context) File(n int, file, name string) { + +} + // Next executes the next handler in the chain. func (c *Context) Next() { c.i++ diff --git a/example/main.go b/example/main.go index 87462709..cb60ffca 100644 --- a/example/main.go +++ b/example/main.go @@ -40,9 +40,10 @@ func getUser(c *bolt.Context) { func main() { b := bolt.New() + b.Index("public/index.html") b.Post("/users", createUser) b.Get("/users", getUsers) b.Get("/users/:id", getUser) - b.Static("/static/*", "/tmp/") + b.Static("/static", "/tmp") b.Run(":8080") } diff --git a/router.go b/router.go index e7765dce..0f39689b 100644 --- a/router.go +++ b/router.go @@ -76,9 +76,8 @@ func (r *router) Add(method, path string, h HandlerFunc) { if i == l { r.insert(method, path[:i], h, snode) return - } else { - r.insert(method, path[:i], nil, snode) } + r.insert(method, path[:i], nil, snode) } else if path[i] == '*' { r.insert(method, path[:i], h, anode) } @@ -205,12 +204,14 @@ func (r *router) Find(method, path string) (handler HandlerFunc, c *Context, s S search = search[i:] if i == l { - // All param read + // All params read continue } case anode: - // End search - search = "" + p := c.params[:n+1] + p[n].Name = "_name" + p[n].Value = search + search = "" // End search continue } e := cn.findEdge(search[0]) @@ -218,10 +219,9 @@ func (r *router) Find(method, path string) (handler HandlerFunc, c *Context, s S // Not found s = NotFound return - } else { - cn = e - continue } + cn = e + continue } else { // Not found s = NotFound @@ -288,7 +288,6 @@ func (n *node) printTree(pfx string, tail bool) { func prefix(tail bool, p, on, off string) string { if tail { return fmt.Sprintf("%s%s", p, on) - } else { - return fmt.Sprintf("%s%s", p, off) } + return fmt.Sprintf("%s%s", p, off) } diff --git a/router_test.go b/router_test.go index 28094605..e9178a0e 100644 --- a/router_test.go +++ b/router_test.go @@ -2,7 +2,7 @@ package bolt import "testing" -func TestStatic(t *testing.T) { +func TestStaticRoute(t *testing.T) { r := New().Router r.Add("GET", "/users/joe/books", func(c *Context) {}) h, _, _ := r.Find("GET", "/users/joe/books") @@ -11,7 +11,7 @@ func TestStatic(t *testing.T) { } } -func TestParam(t *testing.T) { +func TestParamRoute(t *testing.T) { r := New().Router r.Add("GET", "/users/:name", func(c *Context) {}) h, c, _ := r.Find("GET", "/users/joe") @@ -24,7 +24,7 @@ func TestParam(t *testing.T) { } } -func TestCatchAll(t *testing.T) { +func TestCatchAllRoute(t *testing.T) { r := New().Router r.Add("GET", "/static/*", func(c *Context) {}) h, _, _ := r.Find("GET", "/static/*") @@ -33,7 +33,7 @@ func TestCatchAll(t *testing.T) { } } -func TestMicroParam(t *testing.T) { +func TestMicroParamRoute(t *testing.T) { r := New().Router r.Add("GET", "/:a/:b/:c", func(c *Context) {}) h, c, _ := r.Find("GET", "/a/b/c")