1
0
mirror of https://github.com/labstack/echo.git synced 2025-01-12 01:22:21 +02:00

- Moved static file serving to a new handler package

- Middleware at route level
- Group middleware is a in closure now

Signed-off-by: Vishal Rana <vr@labstack.com>
This commit is contained in:
Vishal Rana 2016-02-15 08:11:29 -08:00
parent 6bb871fe3a
commit 51acf465fe
17 changed files with 485 additions and 609 deletions

View File

@ -3,7 +3,10 @@ package echo
import ( import (
"encoding/json" "encoding/json"
"encoding/xml" "encoding/xml"
"io"
"mime"
"net/http" "net/http"
"os"
"path/filepath" "path/filepath"
"time" "time"
@ -42,10 +45,11 @@ type (
JSONP(int, string, interface{}) error JSONP(int, string, interface{}) error
XML(int, interface{}) error XML(int, interface{}) error
XMLBlob(int, []byte) error XMLBlob(int, []byte) error
File(string, string, bool) error Attachment(string) error
NoContent(int) error NoContent(int) error
Redirect(int, string) error Redirect(int, string) error
Error(err error) Error(err error)
Handle(Context) error
Logger() logger.Logger Logger() logger.Logger
Object() *context Object() *context
} }
@ -59,12 +63,17 @@ type (
pvalues []string pvalues []string
query url.Values query url.Values
store store store store
handler Handler
echo *Echo echo *Echo
} }
store map[string]interface{} store map[string]interface{}
) )
const (
indexPage = "index.html"
)
// NewContext creates a Context object. // NewContext creates a Context object.
func NewContext(req engine.Request, res engine.Response, e *Echo) Context { func NewContext(req engine.Request, res engine.Response, e *Echo) Context {
return &context{ return &context{
@ -73,9 +82,14 @@ func NewContext(req engine.Request, res engine.Response, e *Echo) Context {
echo: e, echo: e,
pvalues: make([]string, *e.maxParam), pvalues: make([]string, *e.maxParam),
store: make(store), store: make(store),
handler: notFoundHandler,
} }
} }
func (c *context) Handle(ctx Context) error {
return c.handler.Handle(ctx)
}
func (c *context) Deadline() (deadline time.Time, ok bool) { func (c *context) Deadline() (deadline time.Time, ok bool) {
return return
} }
@ -166,7 +180,7 @@ func (c *context) Bind(i interface{}) error {
// code. Templates can be registered using `Echo.SetRenderer()`. // code. Templates can be registered using `Echo.SetRenderer()`.
func (c *context) Render(code int, name string, data interface{}) (err error) { func (c *context) Render(code int, name string, data interface{}) (err error) {
if c.echo.renderer == nil { if c.echo.renderer == nil {
return RendererNotRegistered return ErrRendererNotRegistered
} }
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
if err = c.echo.renderer.Render(buf, name, data); err != nil { if err = c.echo.renderer.Render(buf, name, data); err != nil {
@ -250,17 +264,17 @@ func (c *context) XMLBlob(code int, b []byte) (err error) {
return return
} }
// File sends a response with the content of the file. If `attachment` is set // Attachment sends specified file as an attachment to the client.
// to true, the client is prompted to save the file with provided `name`, func (c *context) Attachment(file string) (err error) {
// name can be empty, in that case name of the file is used. f, err := os.Open(file)
func (c *context) File(path, name string, attachment bool) (err error) { if err != nil {
dir, file := filepath.Split(path) return
if attachment {
c.response.Header().Set(ContentDisposition, "attachment; filename="+name)
}
if err = c.echo.serveFile(dir, file, c); err != nil {
c.response.Header().Del(ContentDisposition)
} }
_, name := filepath.Split(file)
c.response.Header().Set(ContentDisposition, "attachment; filename="+name)
c.response.Header().Set(ContentType, c.detectContentType(file))
c.response.WriteHeader(http.StatusOK)
_, err = io.Copy(c.response, f)
return return
} }
@ -273,7 +287,7 @@ func (c *context) NoContent(code int) error {
// Redirect redirects the request using http.Redirect with status code. // Redirect redirects the request using http.Redirect with status code.
func (c *context) Redirect(code int, url string) error { func (c *context) Redirect(code int, url string) error {
if code < http.StatusMultipleChoices || code > http.StatusTemporaryRedirect { if code < http.StatusMultipleChoices || code > http.StatusTemporaryRedirect {
return InvalidRedirectCode return ErrInvalidRedirectCode
} }
// TODO: v2 // TODO: v2
// http.Redirect(c.response, c.request, url, code) // http.Redirect(c.response, c.request, url, code)
@ -295,10 +309,16 @@ func (c *context) Object() *context {
return c return c
} }
func (c *context) reset(req engine.Request, res engine.Response, e *Echo) { func (c *context) detectContentType(name string) (t string) {
if t = mime.TypeByExtension(filepath.Ext(name)); t == "" {
t = OctetStream
}
return
}
func (c *context) reset(req engine.Request, res engine.Response) {
c.request = req c.request = req
c.response = res c.response = res
c.query = nil c.query = nil
c.store = nil c.store = nil
c.echo = e
} }

View File

@ -157,22 +157,13 @@ func TestContext(t *testing.T) {
assert.Equal(t, "Hello, <strong>World!</strong>", rec.Body.String()) assert.Equal(t, "Hello, <strong>World!</strong>", rec.Body.String())
} }
// File // Attachment
rec = test.NewResponseRecorder() rec = test.NewResponseRecorder()
c = NewContext(req, rec, e) c = NewContext(req, rec, e)
err = c.File("_fixture/images/walle.png", "", false) err = c.Attachment("_fixture/images/walle.png")
if assert.NoError(t, err) { if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Status()) assert.Equal(t, http.StatusOK, rec.Status())
assert.Equal(t, 219885, rec.Body.Len()) assert.Equal(t, rec.Header().Get(ContentDisposition), "attachment; filename=walle.png")
}
// File as attachment
rec = test.NewResponseRecorder()
c = NewContext(req, rec, e)
err = c.File("_fixture/images/walle.png", "WALLE.PNG", true)
if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Status())
assert.Equal(t, rec.Header().Get(ContentDisposition), "attachment; filename=WALLE.PNG")
assert.Equal(t, 219885, rec.Body.Len()) assert.Equal(t, 219885, rec.Body.Len())
} }
@ -194,7 +185,7 @@ func TestContext(t *testing.T) {
assert.Equal(t, http.StatusInternalServerError, c.Response().Status()) assert.Equal(t, http.StatusInternalServerError, c.Response().Status())
// reset // reset
c.Object().reset(req, test.NewResponseRecorder(), e) c.Object().reset(req, test.NewResponseRecorder())
} }
func TestContextPath(t *testing.T) { func TestContextPath(t *testing.T) {
@ -263,7 +254,7 @@ func testBindError(t *testing.T, c Context, ct string) {
} }
default: default:
if assert.IsType(t, new(HTTPError), err) { if assert.IsType(t, new(HTTPError), err) {
assert.Equal(t, UnsupportedMediaType, err) assert.Equal(t, ErrUnsupportedMediaType, err)
} }
} }

269
echo.go
View File

@ -7,8 +7,6 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"path"
"path/filepath"
"reflect" "reflect"
"runtime" "runtime"
"strings" "strings"
@ -24,7 +22,7 @@ import (
type ( type (
Echo struct { Echo struct {
prefix string prefix string
middleware []MiddlewareFunc middleware []Middleware
http2 bool http2 bool
maxParam *int maxParam *int
notFoundHandler HandlerFunc notFoundHandler HandlerFunc
@ -33,8 +31,6 @@ type (
renderer Renderer renderer Renderer
pool sync.Pool pool sync.Pool
debug bool debug bool
hook engine.HandlerFunc
autoIndex bool
router *Router router *Router
logger logger.Logger logger logger.Logger
} }
@ -51,10 +47,11 @@ type (
} }
Middleware interface { Middleware interface {
Process(HandlerFunc) HandlerFunc Handle(Handler) Handler
Priority() int
} }
MiddlewareFunc func(HandlerFunc) HandlerFunc MiddlewareFunc func(Handler) Handler
Handler interface { Handler interface {
Handle(Context) error Handle(Context) error
@ -122,6 +119,7 @@ const (
TextPlain = "text/plain" TextPlain = "text/plain"
TextPlainCharsetUTF8 = TextPlain + "; " + CharsetUTF8 TextPlainCharsetUTF8 = TextPlain + "; " + CharsetUTF8
MultipartForm = "multipart/form-data" MultipartForm = "multipart/form-data"
OctetStream = "application/octet-stream"
//--------- //---------
// Charset // Charset
@ -150,8 +148,6 @@ const (
//----------- //-----------
WebSocket = "websocket" WebSocket = "websocket"
indexPage = "index.html"
) )
var ( var (
@ -171,9 +167,10 @@ var (
// Errors // Errors
//-------- //--------
UnsupportedMediaType = NewHTTPError(http.StatusUnsupportedMediaType) ErrUnsupportedMediaType = NewHTTPError(http.StatusUnsupportedMediaType)
RendererNotRegistered = errors.New("renderer not registered") ErrNotFound = NewHTTPError(http.StatusNotFound)
InvalidRedirectCode = errors.New("invalid redirect status code") ErrRendererNotRegistered = errors.New("renderer not registered")
ErrInvalidRedirectCode = errors.New("invalid redirect status code")
//---------------- //----------------
// Error handlers // Error handlers
@ -196,6 +193,7 @@ func New() (e *Echo) {
return NewContext(nil, nil, e) return NewContext(nil, nil, e)
} }
e.router = NewRouter(e) e.router = NewRouter(e)
e.middleware = []Middleware{e.router}
//---------- //----------
// Defaults // Defaults
@ -211,10 +209,14 @@ func New() (e *Echo) {
return return
} }
func (f MiddlewareFunc) Process(h HandlerFunc) HandlerFunc { func (f MiddlewareFunc) Handle(h Handler) Handler {
return f(h) return f(h)
} }
func (f MiddlewareFunc) Priority() int {
return 1
}
func (f HandlerFunc) Handle(c Context) error { func (f HandlerFunc) Handle(c Context) error {
return f(c) return f(c)
} }
@ -281,18 +283,6 @@ func (e *Echo) Debug() bool {
return e.debug return e.debug
} }
// AutoIndex enable/disable automatically creating an index page for the directory.
func (e *Echo) AutoIndex(on bool) {
e.autoIndex = on
}
// Hook registers a callback which is invoked from `Echo#ServerHTTP` as the first
// statement. Hook is useful if you want to modify response/response objects even
// before it hits the router or any middleware.
func (e *Echo) Hook(h engine.HandlerFunc) {
e.hook = h
}
// Use adds handler to the middleware chain. // Use adds handler to the middleware chain.
func (e *Echo) Use(middleware ...interface{}) { func (e *Echo) Use(middleware ...interface{}) {
for _, m := range middleware { for _, m := range middleware {
@ -301,190 +291,99 @@ func (e *Echo) Use(middleware ...interface{}) {
} }
// Connect adds a CONNECT route > handler to the router. // Connect adds a CONNECT route > handler to the router.
func (e *Echo) Connect(path string, handler interface{}) { func (e *Echo) Connect(path string, handler interface{}, middleware ...interface{}) {
e.add(CONNECT, path, handler) e.add(CONNECT, path, handler, middleware...)
} }
// Delete adds a DELETE route > handler to the router. // Delete adds a DELETE route > handler to the router.
func (e *Echo) Delete(path string, handler interface{}) { func (e *Echo) Delete(path string, handler interface{}, middleware ...interface{}) {
e.add(DELETE, path, handler) e.add(DELETE, path, handler, middleware...)
} }
// Get adds a GET route > handler to the router. // Get adds a GET route > handler to the router.
func (e *Echo) Get(path string, handler interface{}) { func (e *Echo) Get(path string, handler interface{}, middleware ...interface{}) {
e.add(GET, path, handler) e.add(GET, path, handler, middleware...)
} }
// Head adds a HEAD route > handler to the router. // Head adds a HEAD route > handler to the router.
func (e *Echo) Head(path string, handler interface{}) { func (e *Echo) Head(path string, handler interface{}, middleware ...interface{}) {
e.add(HEAD, path, handler) e.add(HEAD, path, handler, middleware...)
} }
// Options adds an OPTIONS route > handler to the router. // Options adds an OPTIONS route > handler to the router.
func (e *Echo) Options(path string, handler interface{}) { func (e *Echo) Options(path string, handler interface{}, middleware ...interface{}) {
e.add(OPTIONS, path, handler) e.add(OPTIONS, path, handler, middleware...)
} }
// Patch adds a PATCH route > handler to the router. // Patch adds a PATCH route > handler to the router.
func (e *Echo) Patch(path string, handler interface{}) { func (e *Echo) Patch(path string, handler interface{}, middleware ...interface{}) {
e.add(PATCH, path, handler) e.add(PATCH, path, handler, middleware...)
} }
// Post adds a POST route > handler to the router. // Post adds a POST route > handler to the router.
func (e *Echo) Post(path string, handler interface{}) { func (e *Echo) Post(path string, handler interface{}, middleware ...interface{}) {
e.add(POST, path, handler) e.add(POST, path, handler, middleware...)
} }
// Put adds a PUT route > handler to the router. // Put adds a PUT route > handler to the router.
func (e *Echo) Put(path string, handler interface{}) { func (e *Echo) Put(path string, handler interface{}, middleware ...interface{}) {
e.add(PUT, path, handler) e.add(PUT, path, handler, middleware...)
} }
// Trace adds a TRACE route > handler to the router. // Trace adds a TRACE route > handler to the router.
func (e *Echo) Trace(path string, handler interface{}) { func (e *Echo) Trace(path string, handler interface{}, middleware ...interface{}) {
e.add(TRACE, path, handler) e.add(TRACE, path, handler, middleware...)
} }
// Any adds a route > handler to the router for all HTTP methods. // Any adds a route > handler to the router for all HTTP methods.
func (e *Echo) Any(path string, handler interface{}) { func (e *Echo) Any(path string, handler interface{}, middleware ...interface{}) {
for _, m := range methods { for _, m := range methods {
e.add(m, path, handler) e.add(m, path, handler, middleware...)
} }
} }
// Match adds a route > handler to the router for multiple HTTP methods provided. // Match adds a route > handler to the router for multiple HTTP methods provided.
func (e *Echo) Match(methods []string, path string, handler interface{}) { func (e *Echo) Match(methods []string, path string, handler interface{}, middleware ...interface{}) {
for _, m := range methods { for _, m := range methods {
e.add(m, path, handler) e.add(m, path, handler, middleware...)
} }
} }
// NOTE: v2 // NOTE: v2
func (e *Echo) add(method, path string, h interface{}) { func (e *Echo) add(method, path string, handler interface{}, middleware ...interface{}) {
path = e.prefix + path h := wrapHandler(handler)
e.router.Add(method, path, wrapHandler(h), e) name := handlerName(handler)
e.router.Add(method, path, HandlerFunc(func(c Context) error {
for _, m := range middleware {
h = wrapMiddleware(m).Handle(h)
}
return h.Handle(c)
}), e)
r := Route{ r := Route{
Method: method, Method: method,
Path: path, Path: path,
Handler: runtime.FuncForPC(reflect.ValueOf(h).Pointer()).Name(), Handler: name,
} }
e.router.routes = append(e.router.routes, r) e.router.routes = append(e.router.routes, r)
} }
// Index serves index file. // Group creates a new sub-router with prefix.
func (e *Echo) Index(file string) { func (e *Echo) Group(prefix string, middleware ...interface{}) (g *Group) {
e.ServeFile("/", file) g = &Group{prefix: prefix, echo: e}
} g.Use(middleware...)
// Favicon serves the default favicon - GET /favicon.ico.
func (e *Echo) Favicon(file string) {
e.ServeFile("/favicon.ico", file)
}
// Static serves static files from a directory. It's an alias for `Echo.ServeDir`
func (e *Echo) Static(path, dir string) {
e.ServeDir(path, dir)
}
// ServeDir serves files from a directory.
func (e *Echo) ServeDir(path, dir string) {
e.Get(path+"*", func(c Context) error {
return e.serveFile(dir, c.P(0), c) // Param `_*`
})
}
// ServeFile serves a file.
func (e *Echo) ServeFile(path, file string) {
e.Get(path, func(c Context) error {
dir, file := filepath.Split(file)
return e.serveFile(dir, file, c)
})
}
func (e *Echo) serveFile(dir, file string, c Context) (err error) {
fs := http.Dir(dir)
f, err := fs.Open(file)
if err != nil {
return NewHTTPError(http.StatusNotFound)
}
defer f.Close()
fi, _ := f.Stat()
if fi.IsDir() {
/* NOTE:
Not checking the Last-Modified header as it caches the response `304` when
changing differnt directories for the same path.
*/
d := f
// Index file
file = path.Join(file, indexPage)
f, err = fs.Open(file)
if err != nil {
if e.autoIndex {
// Auto index
return listDir(d, c)
}
return NewHTTPError(http.StatusForbidden)
}
fi, _ = f.Stat() // Index file stat
}
c.Response().WriteHeader(http.StatusOK)
io.Copy(c.Response(), f)
// TODO:
// http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), f)
return return
} }
func listDir(d http.File, c Context) (err error) {
dirs, err := d.Readdir(-1)
if err != nil {
return err
}
// Create directory index
w := c.Response()
w.Header().Set(ContentType, TextHTMLCharsetUTF8)
fmt.Fprintf(w, "<pre>\n")
for _, d := range dirs {
name := d.Name()
color := "#212121"
if d.IsDir() {
color = "#e91e63"
name += "/"
}
fmt.Fprintf(w, "<a href=\"%s\" style=\"color: %s;\">%s</a>\n", name, color, name)
}
fmt.Fprintf(w, "</pre>\n")
return
}
// Group creates a new sub router with prefix. It inherits all properties from
// the parent. Passing middleware overrides parent middleware.
func (e *Echo) Group(prefix string, m ...MiddlewareFunc) *Group {
g := &Group{*e}
g.echo.prefix += prefix
if len(m) == 0 {
mw := make([]MiddlewareFunc, len(g.echo.middleware))
copy(mw, g.echo.middleware)
g.echo.middleware = mw
} else {
g.echo.middleware = nil
g.Use(m...)
}
return g
}
// URI generates a URI from handler. // URI generates a URI from handler.
func (e *Echo) URI(h HandlerFunc, params ...interface{}) string { func (e *Echo) URI(handler interface{}, params ...interface{}) string {
uri := new(bytes.Buffer) uri := new(bytes.Buffer)
pl := len(params) ln := len(params)
n := 0 n := 0
hn := runtime.FuncForPC(reflect.ValueOf(h).Pointer()).Name() name := handlerName(handler)
for _, r := range e.router.routes { for _, r := range e.router.routes {
if r.Handler == hn { if r.Handler == name {
for i, l := 0, len(r.Path); i < l; i++ { for i, l := 0, len(r.Path); i < l; i++ {
if r.Path[i] == ':' && n < pl { if r.Path[i] == ':' && n < ln {
for ; i < l && r.Path[i] != '/'; i++ { for ; i < l && r.Path[i] != '/'; i++ {
} }
uri.WriteString(fmt.Sprintf("%v", params[n])) uri.WriteString(fmt.Sprintf("%v", params[n]))
@ -501,8 +400,8 @@ func (e *Echo) URI(h HandlerFunc, params ...interface{}) string {
} }
// URL is an alias for `URI` function. // URL is an alias for `URI` function.
func (e *Echo) URL(h HandlerFunc, params ...interface{}) string { func (e *Echo) URL(handler interface{}, params ...interface{}) string {
return e.URI(h, params...) return e.URI(handler, params...)
} }
// Routes returns the registered routes. // Routes returns the registered routes.
@ -511,37 +410,23 @@ func (e *Echo) Routes() []Route {
} }
func (e *Echo) ServeHTTP(req engine.Request, res engine.Response) { func (e *Echo) ServeHTTP(req engine.Request, res engine.Response) {
if e.hook != nil {
e.hook(req, res)
}
c := e.pool.Get().(*context) c := e.pool.Get().(*context)
h, e := e.router.Find(req.Method(), req.URL().Path(), c) c.reset(req, res)
c.reset(req, res, e) h := Handler(c)
// Chain middleware with handler in the end // Chain middleware with handler in the end
for i := len(e.middleware) - 1; i >= 0; i-- { for i := len(e.middleware) - 1; i >= 0; i-- {
h = e.middleware[i](h) h = e.middleware[i].Handle(h)
} }
// Execute chain // Execute chain
if err := h(c); err != nil { if err := h.Handle(c); err != nil {
e.httpErrorHandler(err, c) e.httpErrorHandler(err, c)
} }
e.pool.Put(c) e.pool.Put(c)
} }
// Server returns the internal *http.Server.
// func (e *Echo) Server(addr string) *http.Server {
// s := &http.Server{Addr: addr, Handler: e}
// // TODO: Remove in Go 1.6+
// if e.http2 {
// http2.ConfigureServer(s, nil)
// }
// return s
// }
// Run starts the HTTP engine. // Run starts the HTTP engine.
func (e *Echo) Run(eng engine.Engine) { func (e *Echo) Run(eng engine.Engine) {
eng.SetHandler(e.ServeHTTP) eng.SetHandler(e.ServeHTTP)
@ -575,7 +460,7 @@ func (e *HTTPError) Error() string {
func (binder) Bind(r engine.Request, i interface{}) (err error) { func (binder) Bind(r engine.Request, i interface{}) (err error) {
ct := r.Header().Get(ContentType) ct := r.Header().Get(ContentType)
err = UnsupportedMediaType err = ErrUnsupportedMediaType
if strings.HasPrefix(ct, ApplicationJSON) { if strings.HasPrefix(ct, ApplicationJSON) {
if err = json.NewDecoder(r.Body()).Decode(i); err != nil { if err = json.NewDecoder(r.Body()).Decode(i); err != nil {
err = NewHTTPError(http.StatusBadRequest, err.Error()) err = NewHTTPError(http.StatusBadRequest, err.Error())
@ -588,28 +473,40 @@ func (binder) Bind(r engine.Request, i interface{}) (err error) {
return return
} }
func wrapMiddleware(m interface{}) MiddlewareFunc { func wrapMiddleware(m interface{}) Middleware {
switch m := m.(type) { switch m := m.(type) {
case Middleware: case Middleware:
return m.Process return m
case MiddlewareFunc: case MiddlewareFunc:
return m return m
case func(HandlerFunc) HandlerFunc: case func(Handler) Handler:
return m return MiddlewareFunc(m)
default: default:
panic("invalid middleware") panic("invalid middleware")
} }
} }
func wrapHandler(h interface{}) HandlerFunc { func wrapHandler(h interface{}) Handler {
switch h := h.(type) { switch h := h.(type) {
case Handler: case Handler:
return h.Handle return h
case HandlerFunc: case HandlerFunc:
return h return h
case func(Context) error: case func(Context) error:
return h return HandlerFunc(h)
default: default:
panic("invalid handler") panic("echo => invalid handler")
}
}
func handlerName(h interface{}) string {
switch h := h.(type) {
case Handler:
t := reflect.TypeOf(h)
return fmt.Sprintf("%s » %s", t.PkgPath(), t.Name())
case HandlerFunc, func(Context) error:
return runtime.FuncForPC(reflect.ValueOf(h).Pointer()).Name()
default:
panic("echo => invalid handler")
} }
} }

View File

@ -11,7 +11,6 @@ import (
"errors" "errors"
"github.com/labstack/echo/engine"
"github.com/labstack/echo/test" "github.com/labstack/echo/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -41,77 +40,29 @@ func TestEcho(t *testing.T) {
assert.Equal(t, http.StatusInternalServerError, rec.Status()) assert.Equal(t, http.StatusInternalServerError, rec.Status())
} }
func TestEchoIndex(t *testing.T) {
e := New()
e.Index("_fixture/index.html")
c, b := request(GET, "/", e)
assert.Equal(t, http.StatusOK, c)
assert.NotEmpty(t, b)
}
func TestEchoFavicon(t *testing.T) {
e := New()
e.Favicon("_fixture/favicon.ico")
c, b := request(GET, "/favicon.ico", e)
assert.Equal(t, http.StatusOK, c)
assert.NotEmpty(t, b)
}
func TestEchoStatic(t *testing.T) {
e := New()
// OK
e.Static("/images", "_fixture/images")
c, b := request(GET, "/images/walle.png", e)
assert.Equal(t, http.StatusOK, c)
assert.NotEmpty(t, b)
// No file
e.Static("/images", "_fixture/scripts")
c, _ = request(GET, "/images/bolt.png", e)
assert.Equal(t, http.StatusNotFound, c)
// Directory
e.Static("/images", "_fixture/images")
c, _ = request(GET, "/images", e)
assert.Equal(t, http.StatusForbidden, c)
// Directory with index.html
e.Static("/", "_fixture")
c, r := request(GET, "/", e)
assert.Equal(t, http.StatusOK, c)
assert.Equal(t, true, strings.HasPrefix(r, "<!doctype html>"))
// Sub-directory with index.html
c, r = request(GET, "/folder", e)
assert.Equal(t, http.StatusOK, c)
assert.Equal(t, true, strings.HasPrefix(r, "<!doctype html>"))
// assert.Equal(t, "sub directory", r)
}
func TestEchoMiddleware(t *testing.T) { func TestEchoMiddleware(t *testing.T) {
e := New() e := New()
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
e.Use(func(h HandlerFunc) HandlerFunc { e.Use(func(h Handler) Handler {
return func(c Context) error { return HandlerFunc(func(c Context) error {
buf.WriteString("a") buf.WriteString("a")
return h(c) return h.Handle(c)
} })
}) })
e.Use(func(h HandlerFunc) HandlerFunc { e.Use(func(h Handler) Handler {
return func(c Context) error { return HandlerFunc(func(c Context) error {
buf.WriteString("b") buf.WriteString("b")
return h(c) return h.Handle(c)
} })
}) })
e.Use(func(h HandlerFunc) HandlerFunc { e.Use(func(h Handler) Handler {
return func(c Context) error { return HandlerFunc(func(c Context) error {
buf.WriteString("c") buf.WriteString("c")
return h(c) return h.Handle(c)
} })
}) })
// Route // Route
@ -125,10 +76,10 @@ func TestEchoMiddleware(t *testing.T) {
assert.Equal(t, "OK", b) assert.Equal(t, "OK", b)
// Error // Error
e.Use(func(h HandlerFunc) HandlerFunc { e.Use(func(Handler) Handler {
return func(c Context) error { return HandlerFunc(func(c Context) error {
return errors.New("error") return errors.New("error")
} })
}) })
c, b = request(GET, "/", e) c, b = request(GET, "/", e)
assert.Equal(t, http.StatusInternalServerError, c) assert.Equal(t, http.StatusInternalServerError, c)
@ -138,9 +89,9 @@ func TestEchoHandler(t *testing.T) {
e := New() e := New()
// HandlerFunc // HandlerFunc
e.Get("/ok", HandlerFunc(func(c Context) error { e.Get("/ok", func(c Context) error {
return c.String(http.StatusOK, "OK") return c.String(http.StatusOK, "OK")
})) })
c, b := request(GET, "/ok", e) c, b := request(GET, "/ok", e)
assert.Equal(t, http.StatusOK, c) assert.Equal(t, http.StatusOK, c)
@ -208,7 +159,6 @@ func TestEchoMatch(t *testing.T) { // JFC
func TestEchoURL(t *testing.T) { func TestEchoURL(t *testing.T) {
e := New() e := New()
static := func(Context) error { return nil } static := func(Context) error { return nil }
getUser := func(Context) error { return nil } getUser := func(Context) error { return nil }
getFile := func(Context) error { return nil } getFile := func(Context) error { return nil }
@ -248,11 +198,11 @@ func TestEchoRoutes(t *testing.T) {
func TestEchoGroup(t *testing.T) { func TestEchoGroup(t *testing.T) {
e := New() e := New()
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
e.Use(func(h HandlerFunc) HandlerFunc { e.Use(func(h Handler) Handler {
return func(c Context) error { return HandlerFunc(func(c Context) error {
buf.WriteString("0") buf.WriteString("0")
return h(c) return h.Handle(c)
} })
}) })
h := func(c Context) error { h := func(c Context) error {
return c.NoContent(http.StatusOK) return c.NoContent(http.StatusOK)
@ -266,27 +216,18 @@ func TestEchoGroup(t *testing.T) {
// Group // Group
g1 := e.Group("/group1") g1 := e.Group("/group1")
g1.Use(func(h HandlerFunc) HandlerFunc { g1.Use(func(h Handler) Handler {
return func(c Context) error { return HandlerFunc(func(c Context) error {
buf.WriteString("1") buf.WriteString("1")
return h(c) return h.Handle(c)
} })
}) })
g1.Get("/", h) g1.Get("/", h)
// Group with no parent middleware
g2 := e.Group("/group2", func(h HandlerFunc) HandlerFunc {
return func(c Context) error {
buf.WriteString("2")
return h(c)
}
})
g2.Get("/", h)
// Nested groups // Nested groups
g3 := e.Group("/group3") g2 := e.Group("/group2")
g4 := g3.Group("/group4") g3 := g2.Group("/group3")
g4.Get("/", h) g3.Get("/", h)
request(GET, "/users", e) request(GET, "/users", e)
assert.Equal(t, "0", buf.String()) assert.Equal(t, "0", buf.String())
@ -296,11 +237,7 @@ func TestEchoGroup(t *testing.T) {
assert.Equal(t, "01", buf.String()) assert.Equal(t, "01", buf.String())
buf.Reset() buf.Reset()
request(GET, "/group2/", e) c, _ := request(GET, "/group2/group3/", e)
assert.Equal(t, "2", buf.String())
buf.Reset()
c, _ := request(GET, "/group3/group4/", e)
assert.Equal(t, http.StatusOK, c) assert.Equal(t, http.StatusOK, c)
} }
@ -330,30 +267,6 @@ func TestEchoHTTPError(t *testing.T) {
assert.Equal(t, m, he.Error()) assert.Equal(t, m, he.Error())
} }
func TestEchoServer(t *testing.T) {
// e := New()
// s := e.Server(":1323")
// assert.IsType(t, &http.Server{}, s)
}
func TestEchoHook(t *testing.T) {
e := New()
e.Get("/test", func(c Context) error {
return c.NoContent(http.StatusNoContent)
})
e.Hook(func(req engine.Request, res engine.Response) {
path := req.URL().Path()
l := len(path) - 1
if path != "/" && path[l] == '/' {
req.URL().SetPath(path[:l])
}
})
req := test.NewRequest(GET, "/test/", nil)
rec := test.NewResponseRecorder()
e.ServeHTTP(req, rec)
assert.Equal(t, req.URL().Path(), "/test")
}
func testMethod(t *testing.T, method, path string, e *Echo) { func testMethod(t *testing.T, method, path string, e *Echo) {
m := fmt.Sprintf("%c%s", method[0], strings.ToLower(method[1:])) m := fmt.Sprintf("%c%s", method[0], strings.ToLower(method[1:]))
p := reflect.ValueOf(path) p := reflect.ValueOf(path)

View File

@ -2,64 +2,72 @@ package echo
type ( type (
Group struct { Group struct {
echo Echo prefix string
middleware []Middleware
echo *Echo
} }
) )
func (g *Group) Use(m ...MiddlewareFunc) { func (g *Group) Use(middleware ...interface{}) {
for _, h := range m { for _, m := range middleware {
g.echo.middleware = append(g.echo.middleware, h) g.middleware = append(g.middleware, wrapMiddleware(m))
} }
} }
func (g *Group) Connect(path string, h HandlerFunc) { func (g *Group) Connect(path string, handler interface{}) {
g.echo.Connect(path, h) g.add(CONNECT, path, handler)
} }
func (g *Group) Delete(path string, h HandlerFunc) { func (g *Group) Delete(path string, handler interface{}) {
g.echo.Delete(path, h) g.add(DELETE, path, handler)
} }
func (g *Group) Get(path string, h HandlerFunc) { func (g *Group) Get(path string, handler interface{}) {
g.echo.Get(path, h) g.add(GET, path, handler)
} }
func (g *Group) Head(path string, h HandlerFunc) { func (g *Group) Head(path string, handler interface{}) {
g.echo.Head(path, h) g.add(HEAD, path, handler)
} }
func (g *Group) Options(path string, h HandlerFunc) { func (g *Group) Options(path string, handler interface{}) {
g.echo.Options(path, h) g.add(OPTIONS, path, handler)
} }
func (g *Group) Patch(path string, h HandlerFunc) { func (g *Group) Patch(path string, handler interface{}) {
g.echo.Patch(path, h) g.add(PATCH, path, handler)
} }
func (g *Group) Post(path string, h HandlerFunc) { func (g *Group) Post(path string, handler interface{}) {
g.echo.Post(path, h) g.add(POST, path, handler)
} }
func (g *Group) Put(path string, h HandlerFunc) { func (g *Group) Put(path string, handler interface{}) {
g.echo.Put(path, h) g.add(PUT, path, handler)
} }
func (g *Group) Trace(path string, h HandlerFunc) { func (g *Group) Trace(path string, handler interface{}) {
g.echo.Trace(path, h) g.add(TRACE, path, handler)
} }
func (g *Group) Static(path, root string) { func (g *Group) Group(prefix string, middleware ...interface{}) *Group {
g.echo.Static(path, root) return g.echo.Group(prefix, middleware...)
} }
func (g *Group) ServeDir(path, root string) { func (g *Group) add(method, path string, handler interface{}) {
g.echo.ServeDir(path, root) path = g.prefix + path
} h := wrapHandler(handler)
name := handlerName(handler)
func (g *Group) ServeFile(path, file string) { g.echo.router.Add(method, path, HandlerFunc(func(c Context) error {
g.echo.ServeFile(path, file) for i := len(g.middleware) - 1; i >= 0; i-- {
} h = g.middleware[i].Handle(h)
}
func (g *Group) Group(prefix string, m ...MiddlewareFunc) *Group { return h.Handle(c)
return g.echo.Group(prefix, m...) }), g.echo)
r := Route{
Method: method,
Path: path,
Handler: name,
}
g.echo.router.routes = append(g.echo.router.routes, r)
} }

View File

@ -14,7 +14,4 @@ func TestGroup(t *testing.T) {
g.Post("/", h) g.Post("/", h)
g.Put("/", h) g.Put("/", h)
g.Trace("/", h) g.Trace("/", h)
g.Static("/scripts", "scripts")
g.ServeDir("/scripts", "scripts")
g.ServeFile("/scripts/main.js", "scripts/main.js")
} }

92
handler/static.go Normal file
View File

@ -0,0 +1,92 @@
package handler
import (
"fmt"
"io"
"net/http"
"path"
"github.com/labstack/echo"
)
type (
Static struct {
Root string
Browse bool
Index string
}
FaviconOptions struct {
}
)
func (s Static) Handle(c echo.Context) error {
fs := http.Dir(s.Root)
file := c.P(0)
f, err := fs.Open(file)
if err != nil {
return echo.ErrNotFound
}
defer f.Close()
fi, err := f.Stat()
if err != nil {
return err
}
if fi.IsDir() {
/* NOTE:
Not checking the Last-Modified header as it caches the response `304` when
changing differnt directories for the same path.
*/
d := f
// Index file
file = path.Join(file, s.Index)
f, err = fs.Open(file)
if err != nil {
if s.Browse {
dirs, err := d.Readdir(-1)
if err != nil {
return err
}
// Create a directory index
w := c.Response()
w.Header().Set(echo.ContentType, echo.TextHTMLCharsetUTF8)
if _, err = fmt.Fprintf(w, "<pre>\n"); err != nil {
return err
}
for _, d := range dirs {
name := d.Name()
color := "#212121"
if d.IsDir() {
color = "#e91e63"
name += "/"
}
if _, err = fmt.Fprintf(w, "<a href=\"%s\" style=\"color: %s;\">%s</a>\n", name, color, name); err != nil {
return err
}
}
_, err = fmt.Fprintf(w, "</pre>\n")
return err
}
return echo.ErrNotFound
}
fi, _ = f.Stat() // Index file stat
}
c.Response().WriteHeader(http.StatusOK)
io.Copy(c.Response(), f)
return nil
// TODO:
// http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), f)
}
// Favicon serves the default favicon - GET /favicon.ico.
func Favicon(root string, options ...FaviconOptions) echo.MiddlewareFunc {
return func(h echo.Handler) echo.Handler {
return echo.HandlerFunc(func(c echo.Context) error {
return nil
})
}
}

View File

@ -20,8 +20,8 @@ const (
// For valid credentials it calls the next handler. // For valid credentials it calls the next handler.
// For invalid credentials, it sends "401 - Unauthorized" response. // For invalid credentials, it sends "401 - Unauthorized" response.
func BasicAuth(fn BasicValidateFunc) echo.MiddlewareFunc { func BasicAuth(fn BasicValidateFunc) echo.MiddlewareFunc {
return func(h echo.HandlerFunc) echo.HandlerFunc { return func(h echo.Handler) echo.Handler {
return func(c echo.Context) error { return echo.HandlerFunc(func(c echo.Context) error {
// Skip WebSocket // Skip WebSocket
if (c.Request().Header().Get(echo.Upgrade)) == echo.WebSocket { if (c.Request().Header().Get(echo.Upgrade)) == echo.WebSocket {
return nil return nil
@ -46,6 +46,6 @@ func BasicAuth(fn BasicValidateFunc) echo.MiddlewareFunc {
} }
c.Response().Header().Set(echo.WWWAuthenticate, basic+" realm=Restricted") c.Response().Header().Set(echo.WWWAuthenticate, basic+" realm=Restricted")
return echo.NewHTTPError(http.StatusUnauthorized) return echo.NewHTTPError(http.StatusUnauthorized)
} })
} }
} }

View File

@ -21,15 +21,14 @@ func TestBasicAuth(t *testing.T) {
} }
return false return false
} }
h := func(c echo.Context) error { h := BasicAuth(fn)(echo.HandlerFunc(func(c echo.Context) error {
return c.String(http.StatusOK, "test") return c.String(http.StatusOK, "test")
} }))
mw := BasicAuth(fn)(h)
// Valid credentials // Valid credentials
auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
req.Header().Set(echo.Authorization, auth) req.Header().Set(echo.Authorization, auth)
assert.NoError(t, mw(c)) assert.NoError(t, h.Handle(c))
//--------------------- //---------------------
// Invalid credentials // Invalid credentials
@ -38,24 +37,24 @@ func TestBasicAuth(t *testing.T) {
// Incorrect password // Incorrect password
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:password")) auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:password"))
req.Header().Set(echo.Authorization, auth) req.Header().Set(echo.Authorization, auth)
he := mw(c).(*echo.HTTPError) he := h.Handle(c).(*echo.HTTPError)
assert.Equal(t, http.StatusUnauthorized, he.Code()) assert.Equal(t, http.StatusUnauthorized, he.Code())
assert.Equal(t, basic+" realm=Restricted", res.Header().Get(echo.WWWAuthenticate)) assert.Equal(t, basic+" realm=Restricted", res.Header().Get(echo.WWWAuthenticate))
// Empty Authorization header // Empty Authorization header
req.Header().Set(echo.Authorization, "") req.Header().Set(echo.Authorization, "")
he = mw(c).(*echo.HTTPError) he = h.Handle(c).(*echo.HTTPError)
assert.Equal(t, http.StatusUnauthorized, he.Code()) assert.Equal(t, http.StatusUnauthorized, he.Code())
assert.Equal(t, basic+" realm=Restricted", res.Header().Get(echo.WWWAuthenticate)) assert.Equal(t, basic+" realm=Restricted", res.Header().Get(echo.WWWAuthenticate))
// Invalid Authorization header // Invalid Authorization header
auth = base64.StdEncoding.EncodeToString([]byte("invalid")) auth = base64.StdEncoding.EncodeToString([]byte("invalid"))
req.Header().Set(echo.Authorization, auth) req.Header().Set(echo.Authorization, auth)
he = mw(c).(*echo.HTTPError) he = h.Handle(c).(*echo.HTTPError)
assert.Equal(t, http.StatusUnauthorized, he.Code()) assert.Equal(t, http.StatusUnauthorized, he.Code())
assert.Equal(t, basic+" realm=Restricted", res.Header().Get(echo.WWWAuthenticate)) assert.Equal(t, basic+" realm=Restricted", res.Header().Get(echo.WWWAuthenticate))
// WebSocket // WebSocket
c.Request().Header().Set(echo.Upgrade, echo.WebSocket) c.Request().Header().Set(echo.Upgrade, echo.WebSocket)
assert.NoError(t, mw(c)) assert.NoError(t, h.Handle(c))
} }

View File

@ -49,10 +49,9 @@ var writerPool = sync.Pool{
// Gzip returns a middleware which compresses HTTP response using gzip compression // Gzip returns a middleware which compresses HTTP response using gzip compression
// scheme. // scheme.
func Gzip() echo.MiddlewareFunc { func Gzip() echo.MiddlewareFunc {
return func(h echo.HandlerFunc) echo.HandlerFunc { return func(h echo.Handler) echo.Handler {
scheme := "gzip" scheme := "gzip"
return echo.HandlerFunc(func(c echo.Context) error {
return func(c echo.Context) error {
c.Response().Header().Add(echo.Vary, echo.AcceptEncoding) c.Response().Header().Add(echo.Vary, echo.AcceptEncoding)
if strings.Contains(c.Request().Header().Get(echo.AcceptEncoding), scheme) { if strings.Contains(c.Request().Header().Get(echo.AcceptEncoding), scheme) {
w := writerPool.Get().(*gzip.Writer) w := writerPool.Get().(*gzip.Writer)
@ -69,6 +68,6 @@ func Gzip() echo.MiddlewareFunc {
c.Error(err) c.Error(err)
} }
return nil return nil
} })
} }
} }

View File

@ -35,13 +35,12 @@ func TestGzip(t *testing.T) {
req := test.NewRequest(echo.GET, "/", nil) req := test.NewRequest(echo.GET, "/", nil)
rec := test.NewResponseRecorder() rec := test.NewResponseRecorder()
c := echo.NewContext(req, rec, e) c := echo.NewContext(req, rec, e)
h := func(c echo.Context) error { // Skip if no Accept-Encoding header
h := Gzip()(echo.HandlerFunc(func(c echo.Context) error {
c.Response().Write([]byte("test")) // For Content-Type sniffing c.Response().Write([]byte("test")) // For Content-Type sniffing
return nil return nil
} }))
h.Handle(c)
// Skip if no Accept-Encoding header
Gzip()(h)(c)
// assert.Equal(t, http.StatusOK, rec.Status()) // assert.Equal(t, http.StatusOK, rec.Status())
assert.Equal(t, "test", rec.Body.String()) assert.Equal(t, "test", rec.Body.String())
@ -51,7 +50,7 @@ func TestGzip(t *testing.T) {
c = echo.NewContext(req, rec, e) c = echo.NewContext(req, rec, e)
// Gzip // Gzip
Gzip()(h)(c) h.Handle(c)
// assert.Equal(t, http.StatusOK, rec.Status()) // assert.Equal(t, http.StatusOK, rec.Status())
assert.Equal(t, "gzip", rec.Header().Get(echo.ContentEncoding)) assert.Equal(t, "gzip", rec.Header().Get(echo.ContentEncoding))
assert.Contains(t, rec.Header().Get(echo.ContentType), echo.TextPlain) assert.Contains(t, rec.Header().Get(echo.ContentType), echo.TextPlain)

View File

@ -9,8 +9,8 @@ import (
) )
func Log() echo.MiddlewareFunc { func Log() echo.MiddlewareFunc {
return func(h echo.HandlerFunc) echo.HandlerFunc { return func(h echo.Handler) echo.Handler {
return func(c echo.Context) error { return echo.HandlerFunc(func(c echo.Context) error {
req := c.Request() req := c.Request()
res := c.Response() res := c.Response()
logger := c.Logger() logger := c.Logger()
@ -49,6 +49,6 @@ func Log() echo.MiddlewareFunc {
logger.Infof("%s %s %s %s %s %d", remoteAddr, method, path, code, stop.Sub(start), size) logger.Infof("%s %s %s %s %s %d", remoteAddr, method, path, code, stop.Sub(start), size)
return nil return nil
} })
} }
} }

View File

@ -18,38 +18,37 @@ func TestLog(t *testing.T) {
req := test.NewRequest(echo.GET, "/", nil) req := test.NewRequest(echo.GET, "/", nil)
rec := test.NewResponseRecorder() rec := test.NewResponseRecorder()
c := echo.NewContext(req, rec, e) c := echo.NewContext(req, rec, e)
h := func(c echo.Context) error { h := Log()(echo.HandlerFunc(func(c echo.Context) error {
return c.String(http.StatusOK, "test") return c.String(http.StatusOK, "test")
} }))
mw := Log()(h)
// Status 2xx // Status 2xx
mw(c) h.Handle(c)
// Status 3xx // Status 3xx
rec = test.NewResponseRecorder() rec = test.NewResponseRecorder()
c = echo.NewContext(req, rec, e) c = echo.NewContext(req, rec, e)
h = func(c echo.Context) error { h = Log()(echo.HandlerFunc(func(c echo.Context) error {
return c.String(http.StatusTemporaryRedirect, "test") return c.String(http.StatusTemporaryRedirect, "test")
} }))
mw(c) h.Handle(c)
// Status 4xx // Status 4xx
rec = test.NewResponseRecorder() rec = test.NewResponseRecorder()
c = echo.NewContext(req, rec, e) c = echo.NewContext(req, rec, e)
h = func(c echo.Context) error { h = Log()(echo.HandlerFunc(func(c echo.Context) error {
return c.String(http.StatusNotFound, "test") return c.String(http.StatusNotFound, "test")
} }))
mw(c) h.Handle(c)
// Status 5xx with empty path // Status 5xx with empty path
req = test.NewRequest(echo.GET, "", nil) req = test.NewRequest(echo.GET, "", nil)
rec = test.NewResponseRecorder() rec = test.NewResponseRecorder()
c = echo.NewContext(req, rec, e) c = echo.NewContext(req, rec, e)
h = func(c echo.Context) error { h = Log()(echo.HandlerFunc(func(c echo.Context) error {
return errors.New("error") return errors.New("error")
} }))
mw(c) h.Handle(c)
} }
func TestLogIPAddress(t *testing.T) { func TestLogIPAddress(t *testing.T) {
@ -60,25 +59,24 @@ func TestLogIPAddress(t *testing.T) {
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
e.Logger().(*log.Logger).SetOutput(buf) e.Logger().(*log.Logger).SetOutput(buf)
ip := "127.0.0.1" ip := "127.0.0.1"
h := func(c echo.Context) error { h := Log()(echo.HandlerFunc(func(c echo.Context) error {
return c.String(http.StatusOK, "test") return c.String(http.StatusOK, "test")
} }))
mw := Log()(h)
// With X-Real-IP // With X-Real-IP
req.Header().Add(echo.XRealIP, ip) req.Header().Add(echo.XRealIP, ip)
mw(c) h.Handle(c)
assert.Contains(t, buf.String(), ip) assert.Contains(t, buf.String(), ip)
// With X-Forwarded-For // With X-Forwarded-For
buf.Reset() buf.Reset()
req.Header().Del(echo.XRealIP) req.Header().Del(echo.XRealIP)
req.Header().Add(echo.XForwardedFor, ip) req.Header().Add(echo.XForwardedFor, ip)
mw(c) h.Handle(c)
assert.Contains(t, buf.String(), ip) assert.Contains(t, buf.String(), ip)
// with req.RemoteAddr // with req.RemoteAddr
buf.Reset() buf.Reset()
mw(c) h.Handle(c)
assert.Contains(t, buf.String(), ip) assert.Contains(t, buf.String(), ip)
} }

View File

@ -11,9 +11,9 @@ import (
// Recover returns a middleware which recovers from panics anywhere in the chain // Recover returns a middleware which recovers from panics anywhere in the chain
// and handles the control to the centralized HTTPErrorHandler. // and handles the control to the centralized HTTPErrorHandler.
func Recover() echo.MiddlewareFunc { func Recover() echo.MiddlewareFunc {
return func(h echo.HandlerFunc) echo.HandlerFunc { return func(h echo.Handler) echo.Handler {
// TODO: Provide better stack trace `https://github.com/go-errors/errors` `https://github.com/docker/libcontainer/tree/master/stacktrace` // TODO: Provide better stack trace `https://github.com/go-errors/errors` `https://github.com/docker/libcontainer/tree/master/stacktrace`
return func(c echo.Context) error { return echo.HandlerFunc(func(c echo.Context) error {
defer func() { defer func() {
if err := recover(); err != nil { if err := recover(); err != nil {
trace := make([]byte, 1<<16) trace := make([]byte, 1<<16)
@ -23,6 +23,6 @@ func Recover() echo.MiddlewareFunc {
} }
}() }()
return h.Handle(c) return h.Handle(c)
} })
} }
} }

View File

@ -15,10 +15,10 @@ func TestRecover(t *testing.T) {
req := test.NewRequest(echo.GET, "/", nil) req := test.NewRequest(echo.GET, "/", nil)
rec := test.NewResponseRecorder() rec := test.NewResponseRecorder()
c := echo.NewContext(req, rec, e) c := echo.NewContext(req, rec, e)
h := func(c echo.Context) error { h := Recover()(echo.HandlerFunc(func(c echo.Context) error {
panic("test") panic("test")
} }))
Recover()(h)(c) h.Handle(c)
assert.Equal(t, http.StatusInternalServerError, rec.Status()) assert.Equal(t, http.StatusInternalServerError, rec.Status())
assert.Contains(t, rec.Body.String(), "panic recover") assert.Contains(t, rec.Body.String(), "panic recover")
} }

View File

@ -15,20 +15,19 @@ type (
ppath string ppath string
pnames []string pnames []string
methodHandler *methodHandler methodHandler *methodHandler
echo *Echo
} }
kind uint8 kind uint8
children []*node children []*node
methodHandler struct { methodHandler struct {
connect HandlerFunc connect Handler
delete HandlerFunc delete Handler
get HandlerFunc get Handler
head HandlerFunc head Handler
options HandlerFunc options Handler
patch HandlerFunc patch Handler
post HandlerFunc post Handler
put HandlerFunc put Handler
trace HandlerFunc trace Handler
} }
) )
@ -48,7 +47,20 @@ func NewRouter(e *Echo) *Router {
} }
} }
func (r *Router) Add(method, path string, h HandlerFunc, e *Echo) { func (r *Router) Handle(h Handler) Handler {
return HandlerFunc(func(c Context) error {
method := c.Request().Method()
path := c.Request().URL().Path()
r.Find(method, path, c)
return h.Handle(c)
})
}
func (r *Router) Priority() int {
return 0
}
func (r *Router) Add(method, path string, h Handler, e *Echo) {
ppath := path // Pristine path ppath := path // Pristine path
pnames := []string{} // Param names pnames := []string{} // Param names
@ -80,7 +92,7 @@ func (r *Router) Add(method, path string, h HandlerFunc, e *Echo) {
r.insert(method, path, h, skind, ppath, pnames, e) r.insert(method, path, h, skind, ppath, pnames, e)
} }
func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string, pnames []string, e *Echo) { func (r *Router) insert(method, path string, h Handler, t kind, ppath string, pnames []string, e *Echo) {
// Adjust max param // Adjust max param
l := len(pnames) l := len(pnames)
if *e.maxParam < l { if *e.maxParam < l {
@ -89,7 +101,7 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string
cn := r.tree // Current node as root cn := r.tree // Current node as root
if cn == nil { if cn == nil {
panic("echo => invalid method") panic("echo invalid method")
} }
search := path search := path
@ -115,11 +127,10 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string
cn.addHandler(method, h) cn.addHandler(method, h)
cn.ppath = ppath cn.ppath = ppath
cn.pnames = pnames cn.pnames = pnames
cn.echo = e
} }
} else if l < pl { } else if l < pl {
// Split node // Split node
n := newNode(cn.kind, cn.prefix[l:], cn, cn.children, cn.methodHandler, cn.ppath, cn.pnames, cn.echo) n := newNode(cn.kind, cn.prefix[l:], cn, cn.children, cn.methodHandler, cn.ppath, cn.pnames)
// Reset parent node // Reset parent node
cn.kind = skind cn.kind = skind
@ -129,7 +140,6 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string
cn.methodHandler = new(methodHandler) cn.methodHandler = new(methodHandler)
cn.ppath = "" cn.ppath = ""
cn.pnames = nil cn.pnames = nil
cn.echo = nil
cn.addChild(n) cn.addChild(n)
@ -139,10 +149,9 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string
cn.addHandler(method, h) cn.addHandler(method, h)
cn.ppath = ppath cn.ppath = ppath
cn.pnames = pnames cn.pnames = pnames
cn.echo = e
} else { } else {
// Create child node // Create child node
n = newNode(t, search[l:], cn, nil, new(methodHandler), ppath, pnames, e) n = newNode(t, search[l:], cn, nil, new(methodHandler), ppath, pnames)
n.addHandler(method, h) n.addHandler(method, h)
cn.addChild(n) cn.addChild(n)
} }
@ -155,7 +164,7 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string
continue continue
} }
// Create child node // Create child node
n := newNode(t, search, cn, nil, new(methodHandler), ppath, pnames, e) n := newNode(t, search, cn, nil, new(methodHandler), ppath, pnames)
n.addHandler(method, h) n.addHandler(method, h)
cn.addChild(n) cn.addChild(n)
} else { } else {
@ -164,14 +173,13 @@ func (r *Router) insert(method, path string, h HandlerFunc, t kind, ppath string
cn.addHandler(method, h) cn.addHandler(method, h)
cn.ppath = path cn.ppath = path
cn.pnames = pnames cn.pnames = pnames
cn.echo = e
} }
} }
return return
} }
} }
func newNode(t kind, pre string, p *node, c children, mh *methodHandler, ppath string, pnames []string, e *Echo) *node { func newNode(t kind, pre string, p *node, c children, mh *methodHandler, ppath string, pnames []string) *node {
return &node{ return &node{
kind: t, kind: t,
label: pre[0], label: pre[0],
@ -181,7 +189,6 @@ func newNode(t kind, pre string, p *node, c children, mh *methodHandler, ppath s
ppath: ppath, ppath: ppath,
pnames: pnames, pnames: pnames,
methodHandler: mh, methodHandler: mh,
echo: e,
} }
} }
@ -216,7 +223,7 @@ func (n *node) findChildByKind(t kind) *node {
return nil return nil
} }
func (n *node) addHandler(method string, h HandlerFunc) { func (n *node) addHandler(method string, h Handler) {
switch method { switch method {
case GET: case GET:
n.methodHandler.get = h n.methodHandler.get = h
@ -239,7 +246,7 @@ func (n *node) addHandler(method string, h HandlerFunc) {
} }
} }
func (n *node) findHandler(method string) HandlerFunc { func (n *node) findHandler(method string) Handler {
switch method { switch method {
case GET: case GET:
return n.methodHandler.get return n.methodHandler.get
@ -273,10 +280,10 @@ func (n *node) check405() HandlerFunc {
return notFoundHandler return notFoundHandler
} }
func (r *Router) Find(method, path string, context Context) (h HandlerFunc, e *Echo) { func (r *Router) Find(method, path string, context Context) {
x := context.Object() ctx := context.Object()
h = notFoundHandler // h = notFoundHandler
e = r.echo // e = r.echo
cn := r.tree // Current node as root cn := r.tree // Current node as root
var ( var (
@ -357,7 +364,7 @@ func (r *Router) Find(method, path string, context Context) (h HandlerFunc, e *E
i, l := 0, len(search) i, l := 0, len(search)
for ; i < l && search[i] != '/'; i++ { for ; i < l && search[i] != '/'; i++ {
} }
x.pvalues[n] = search[:i] ctx.pvalues[n] = search[:i]
n++ n++
search = search[i:] search = search[i:]
continue continue
@ -370,30 +377,27 @@ func (r *Router) Find(method, path string, context Context) (h HandlerFunc, e *E
// Not found // Not found
return return
} }
x.pvalues[len(cn.pnames)-1] = search ctx.pvalues[len(cn.pnames)-1] = search
goto End goto End
} }
End: End:
x.path = cn.ppath ctx.path = cn.ppath
x.pnames = cn.pnames ctx.pnames = cn.pnames
h = cn.findHandler(method) ctx.handler = cn.findHandler(method)
if cn.echo != nil {
e = cn.echo
}
// NOTE: Slow zone... // NOTE: Slow zone...
if h == nil { if ctx.handler == nil {
h = cn.check405() ctx.handler = cn.check405()
// Dig further for match-any, might have an empty value for *, e.g. // Dig further for match-any, might have an empty value for *, e.g.
// serving a directory. Issue #207. // serving a directory. Issue #207.
if cn = cn.findChildByKind(mkind); cn == nil { if cn = cn.findChildByKind(mkind); cn == nil {
return return
} }
x.pvalues[len(cn.pnames)-1] = "" ctx.pvalues[len(cn.pnames)-1] = ""
if h = cn.findHandler(method); h == nil { if ctx.handler = cn.findHandler(method); ctx.handler == nil {
h = cn.check405() ctx.handler = cn.check405()
} }
} }
return return

View File

@ -277,44 +277,38 @@ func TestRouterStatic(t *testing.T) {
e := New() e := New()
r := e.router r := e.router
path := "/folders/a/files/echo.gif" path := "/folders/a/files/echo.gif"
r.Add(GET, path, func(c Context) error { r.Add(GET, path, HandlerFunc(func(c Context) error {
c.Set("path", path) c.Set("path", path)
return nil return nil
}, e) }), e)
c := NewContext(nil, nil, e) c := NewContext(nil, nil, e)
h, _ := r.Find(GET, path, c) r.Find(GET, path, c)
if assert.NotNil(t, h) { c.Handle(c)
h(c) assert.Equal(t, path, c.Get("path"))
assert.Equal(t, path, c.Get("path"))
}
} }
func TestRouterParam(t *testing.T) { func TestRouterParam(t *testing.T) {
e := New() e := New()
r := e.router r := e.router
r.Add(GET, "/users/:id", func(c Context) error { r.Add(GET, "/users/:id", HandlerFunc(func(c Context) error {
return nil return nil
}, e) }), e)
c := NewContext(nil, nil, e) c := NewContext(nil, nil, e)
h, _ := r.Find(GET, "/users/1", c) r.Find(GET, "/users/1", c)
if assert.NotNil(t, h) { assert.Equal(t, "1", c.P(0))
assert.Equal(t, "1", c.P(0))
}
} }
func TestRouterTwoParam(t *testing.T) { func TestRouterTwoParam(t *testing.T) {
e := New() e := New()
r := e.router r := e.router
r.Add(GET, "/users/:uid/files/:fid", func(Context) error { r.Add(GET, "/users/:uid/files/:fid", HandlerFunc(func(Context) error {
return nil return nil
}, e) }), e)
c := NewContext(nil, nil, e) c := NewContext(nil, nil, e)
h, _ := r.Find(GET, "/users/1/files/1", c) r.Find(GET, "/users/1/files/1", c)
if assert.NotNil(t, h) { assert.Equal(t, "1", c.P(0))
assert.Equal(t, "1", c.P(0)) assert.Equal(t, "1", c.P(1))
assert.Equal(t, "1", c.P(1))
}
} }
func TestRouterMatchAny(t *testing.T) { func TestRouterMatchAny(t *testing.T) {
@ -322,46 +316,38 @@ func TestRouterMatchAny(t *testing.T) {
r := e.router r := e.router
// Routes // Routes
r.Add(GET, "/", func(Context) error { r.Add(GET, "/", HandlerFunc(func(Context) error {
return nil return nil
}, e) }), e)
r.Add(GET, "/*", func(Context) error { r.Add(GET, "/*", HandlerFunc(func(Context) error {
return nil return nil
}, e) }), e)
r.Add(GET, "/users/*", func(Context) error { r.Add(GET, "/users/*", HandlerFunc(func(Context) error {
return nil return nil
}, e) }), e)
c := NewContext(nil, nil, e) c := NewContext(nil, nil, e)
h, _ := r.Find(GET, "/", c) r.Find(GET, "/", c)
if assert.NotNil(t, h) { assert.Equal(t, "", c.P(0))
assert.Equal(t, "", c.P(0))
}
h, _ = r.Find(GET, "/download", c) r.Find(GET, "/download", c)
if assert.NotNil(t, h) { assert.Equal(t, "download", c.P(0))
assert.Equal(t, "download", c.P(0))
}
h, _ = r.Find(GET, "/users/joe", c) r.Find(GET, "/users/joe", c)
if assert.NotNil(t, h) { assert.Equal(t, "joe", c.P(0))
assert.Equal(t, "joe", c.P(0))
}
} }
func TestRouterMicroParam(t *testing.T) { func TestRouterMicroParam(t *testing.T) {
e := New() e := New()
r := e.router r := e.router
r.Add(GET, "/:a/:b/:c", func(c Context) error { r.Add(GET, "/:a/:b/:c", HandlerFunc(func(c Context) error {
return nil return nil
}, e) }), e)
c := NewContext(nil, nil, e) c := NewContext(nil, nil, e)
h, _ := r.Find(GET, "/1/2/3", c) r.Find(GET, "/1/2/3", c)
if assert.NotNil(t, h) { assert.Equal(t, "1", c.P(0))
assert.Equal(t, "1", c.P(0)) assert.Equal(t, "2", c.P(1))
assert.Equal(t, "2", c.P(1)) assert.Equal(t, "3", c.P(2))
assert.Equal(t, "3", c.P(2))
}
} }
func TestRouterMixParamMatchAny(t *testing.T) { func TestRouterMixParamMatchAny(t *testing.T) {
@ -369,16 +355,14 @@ func TestRouterMixParamMatchAny(t *testing.T) {
r := e.router r := e.router
// Route // Route
r.Add(GET, "/users/:id/*", func(c Context) error { r.Add(GET, "/users/:id/*", HandlerFunc(func(c Context) error {
return nil return nil
}, e) }), e)
c := NewContext(nil, nil, e) c := NewContext(nil, nil, e)
h, _ := r.Find(GET, "/users/joe/comments", c) r.Find(GET, "/users/joe/comments", c)
if assert.NotNil(t, h) { c.Handle(c)
h(c) assert.Equal(t, "joe", c.P(0))
assert.Equal(t, "joe", c.P(0))
}
} }
func TestRouterMultiRoute(t *testing.T) { func TestRouterMultiRoute(t *testing.T) {
@ -386,32 +370,29 @@ func TestRouterMultiRoute(t *testing.T) {
r := e.router r := e.router
// Routes // Routes
r.Add(GET, "/users", func(c Context) error { r.Add(GET, "/users", HandlerFunc(func(c Context) error {
c.Set("path", "/users") c.Set("path", "/users")
return nil return nil
}, e) }), e)
r.Add(GET, "/users/:id", func(c Context) error { r.Add(GET, "/users/:id", HandlerFunc(func(c Context) error {
return nil return nil
}, e) }), e)
c := NewContext(nil, nil, e) c := NewContext(nil, nil, e)
// Route > /users // Route > /users
h, _ := r.Find(GET, "/users", c) r.Find(GET, "/users", c)
if assert.NotNil(t, h) { c.Handle(c)
h(c) assert.Equal(t, "/users", c.Get("path"))
assert.Equal(t, "/users", c.Get("path"))
}
// Route > /users/:id // Route > /users/:id
h, _ = r.Find(GET, "/users/1", c) r.Find(GET, "/users/1", c)
if assert.NotNil(t, h) { assert.Equal(t, "1", c.P(0))
assert.Equal(t, "1", c.P(0))
}
// Route > /user // Route > /user
h, _ = r.Find(GET, "/user", c) c = NewContext(nil, nil, e)
if assert.IsType(t, new(HTTPError), h(c)) { r.Find(GET, "/user", c)
he := h(c).(*HTTPError) if assert.IsType(t, new(HTTPError), c.Handle(c)) {
he := c.Handle(c).(*HTTPError)
assert.Equal(t, http.StatusNotFound, he.code) assert.Equal(t, http.StatusNotFound, he.code)
} }
} }
@ -421,85 +402,71 @@ func TestRouterPriority(t *testing.T) {
r := e.router r := e.router
// Routes // Routes
r.Add(GET, "/users", func(c Context) error { r.Add(GET, "/users", HandlerFunc(func(c Context) error {
c.Set("a", 1) c.Set("a", 1)
return nil return nil
}, e) }), e)
r.Add(GET, "/users/new", func(c Context) error { r.Add(GET, "/users/new", HandlerFunc(func(c Context) error {
c.Set("b", 2) c.Set("b", 2)
return nil return nil
}, e) }), e)
r.Add(GET, "/users/:id", func(c Context) error { r.Add(GET, "/users/:id", HandlerFunc(func(c Context) error {
c.Set("c", 3) c.Set("c", 3)
return nil return nil
}, e) }), e)
r.Add(GET, "/users/dew", func(c Context) error { r.Add(GET, "/users/dew", HandlerFunc(func(c Context) error {
c.Set("d", 4) c.Set("d", 4)
return nil return nil
}, e) }), e)
r.Add(GET, "/users/:id/files", func(c Context) error { r.Add(GET, "/users/:id/files", HandlerFunc(func(c Context) error {
c.Set("e", 5) c.Set("e", 5)
return nil return nil
}, e) }), e)
r.Add(GET, "/users/newsee", func(c Context) error { r.Add(GET, "/users/newsee", HandlerFunc(func(c Context) error {
c.Set("f", 6) c.Set("f", 6)
return nil return nil
}, e) }), e)
r.Add(GET, "/users/*", func(c Context) error { r.Add(GET, "/users/*", HandlerFunc(func(c Context) error {
c.Set("g", 7) c.Set("g", 7)
return nil return nil
}, e) }), e)
c := NewContext(nil, nil, e) c := NewContext(nil, nil, e)
// Route > /users // Route > /users
h, _ := r.Find(GET, "/users", c) r.Find(GET, "/users", c)
if assert.NotNil(t, h) { c.Handle(c)
h(c) assert.Equal(t, 1, c.Get("a"))
assert.Equal(t, 1, c.Get("a"))
}
// Route > /users/new // Route > /users/new
h, _ = r.Find(GET, "/users/new", c) r.Find(GET, "/users/new", c)
if assert.NotNil(t, h) { c.Handle(c)
h(c) assert.Equal(t, 2, c.Get("b"))
assert.Equal(t, 2, c.Get("b"))
}
// Route > /users/:id // Route > /users/:id
h, _ = r.Find(GET, "/users/1", c) r.Find(GET, "/users/1", c)
if assert.NotNil(t, h) { c.Handle(c)
h(c) assert.Equal(t, 3, c.Get("c"))
assert.Equal(t, 3, c.Get("c"))
}
// Route > /users/dew // Route > /users/dew
h, _ = r.Find(GET, "/users/dew", c) r.Find(GET, "/users/dew", c)
if assert.NotNil(t, h) { c.Handle(c)
h(c) assert.Equal(t, 4, c.Get("d"))
assert.Equal(t, 4, c.Get("d"))
}
// Route > /users/:id/files // Route > /users/:id/files
h, _ = r.Find(GET, "/users/1/files", c) r.Find(GET, "/users/1/files", c)
if assert.NotNil(t, h) { c.Handle(c)
h(c) assert.Equal(t, 5, c.Get("e"))
assert.Equal(t, 5, c.Get("e"))
}
// Route > /users/:id // Route > /users/:id
h, _ = r.Find(GET, "/users/news", c) r.Find(GET, "/users/news", c)
if assert.NotNil(t, h) { c.Handle(c)
h(c) assert.Equal(t, 3, c.Get("c"))
assert.Equal(t, 3, c.Get("c"))
}
// Route > /users/* // Route > /users/*
h, _ = r.Find(GET, "/users/joe/books", c) r.Find(GET, "/users/joe/books", c)
if assert.NotNil(t, h) { c.Handle(c)
h(c) assert.Equal(t, 7, c.Get("g"))
assert.Equal(t, 7, c.Get("g")) assert.Equal(t, "joe/books", c.Param("_*"))
assert.Equal(t, "joe/books", c.Param("_*"))
}
} }
func TestRouterParamNames(t *testing.T) { func TestRouterParamNames(t *testing.T) {
@ -507,40 +474,34 @@ func TestRouterParamNames(t *testing.T) {
r := e.router r := e.router
// Routes // Routes
r.Add(GET, "/users", func(c Context) error { r.Add(GET, "/users", HandlerFunc(func(c Context) error {
c.Set("path", "/users") c.Set("path", "/users")
return nil return nil
}, e) }), e)
r.Add(GET, "/users/:id", func(c Context) error { r.Add(GET, "/users/:id", HandlerFunc(func(c Context) error {
return nil return nil
}, e) }), e)
r.Add(GET, "/users/:uid/files/:fid", func(c Context) error { r.Add(GET, "/users/:uid/files/:fid", HandlerFunc(func(c Context) error {
return nil return nil
}, e) }), e)
c := NewContext(nil, nil, e) c := NewContext(nil, nil, e)
// Route > /users // Route > /users
h, _ := r.Find(GET, "/users", c) r.Find(GET, "/users", c)
if assert.NotNil(t, h) { c.Handle(c)
h(c) assert.Equal(t, "/users", c.Get("path"))
assert.Equal(t, "/users", c.Get("path"))
}
// Route > /users/:id // Route > /users/:id
h, _ = r.Find(GET, "/users/1", c) r.Find(GET, "/users/1", c)
if assert.NotNil(t, h) { assert.Equal(t, "id", c.Object().pnames[0])
assert.Equal(t, "id", c.Object().pnames[0]) assert.Equal(t, "1", c.P(0))
assert.Equal(t, "1", c.P(0))
}
// Route > /users/:uid/files/:fid // Route > /users/:uid/files/:fid
h, _ = r.Find(GET, "/users/1/files/1", c) r.Find(GET, "/users/1/files/1", c)
if assert.NotNil(t, h) { assert.Equal(t, "uid", c.Object().pnames[0])
assert.Equal(t, "uid", c.Object().pnames[0]) assert.Equal(t, "1", c.P(0))
assert.Equal(t, "1", c.P(0)) assert.Equal(t, "fid", c.Object().pnames[1])
assert.Equal(t, "fid", c.Object().pnames[1]) assert.Equal(t, "1", c.P(1))
assert.Equal(t, "1", c.P(1))
}
} }
func TestRouterAPI(t *testing.T) { func TestRouterAPI(t *testing.T) {
@ -548,21 +509,19 @@ func TestRouterAPI(t *testing.T) {
r := e.router r := e.router
for _, route := range api { for _, route := range api {
r.Add(route.Method, route.Path, func(c Context) error { r.Add(route.Method, route.Path, HandlerFunc(func(c Context) error {
return nil return nil
}, e) }), e)
} }
c := NewContext(nil, nil, e) c := NewContext(nil, nil, e)
for _, route := range api { for _, route := range api {
h, _ := r.Find(route.Method, route.Path, c) r.Find(route.Method, route.Path, c)
if assert.NotNil(t, h) { for i, n := range c.Object().pnames {
for i, n := range c.Object().pnames { if assert.NotEmpty(t, n) {
if assert.NotEmpty(t, n) { assert.Equal(t, ":"+n, c.P(i))
assert.Equal(t, ":"+n, c.P(i))
}
} }
h(c)
} }
c.Handle(c)
} }
} }