1
0
mirror of https://github.com/labstack/echo.git synced 2025-11-25 22:32:23 +02:00

Changes based on comments for #430

Signed-off-by: Vishal Rana <vr@labstack.com>
This commit is contained in:
Vishal Rana
2016-04-16 15:53:27 -07:00
parent fd104333f2
commit 467cf05b41
11 changed files with 191 additions and 169 deletions

2
.gitignore vendored
View File

@@ -11,3 +11,5 @@ node_modules
# Glide # Glide
vendor vendor
.DS_Store

View File

@@ -40,6 +40,9 @@ type (
// Path returns the registered path for the handler. // Path returns the registered path for the handler.
Path() string Path() string
// SetPath sets the registered path for the handler.
SetPath(string)
// P returns path parameter by index. // P returns path parameter by index.
P(int) string P(int) string
@@ -49,11 +52,20 @@ type (
// ParamNames returns path parameter names. // ParamNames returns path parameter names.
ParamNames() []string ParamNames() []string
// SetParamNames sets path parameter names.
SetParamNames([]string)
// ParamValues returns path parameter values.
ParamValues() []string
// SetParamValues sets path parameter values.
SetParamValues([]string)
// QueryParam returns the query param for the provided name. It is an alias // QueryParam returns the query param for the provided name. It is an alias
// for `engine.URL#QueryParam()`. // for `engine.URL#QueryParam()`.
QueryParam(string) string QueryParam(string) string
// QueryParam returns the query parameters as map. It is an alias for `engine.URL#QueryParams()`. // QueryParams returns the query parameters as map. It is an alias for `engine.URL#QueryParams()`.
QueryParams() map[string][]string QueryParams() map[string][]string
// FormValue returns the form field value for the provided name. It is an // FormValue returns the form field value for the provided name. It is an
@@ -76,6 +88,9 @@ type (
// Set saves data in the context. // Set saves data in the context.
Set(string, interface{}) Set(string, interface{})
// Del data from the context
Del(string)
// Bind binds the request body into provided type `i`. The default binder // Bind binds the request body into provided type `i`. The default binder
// does it based on Content-Type header. // does it based on Content-Type header.
Bind(interface{}) error Bind(interface{}) error
@@ -125,6 +140,9 @@ type (
// Handler returns the matched handler by router. // Handler returns the matched handler by router.
Handler() HandlerFunc Handler() HandlerFunc
// SetHandler sets the matched handler by router.
SetHandler(HandlerFunc)
// Logger returns the `Logger` instance. // Logger returns the `Logger` instance.
Logger() *log.Logger Logger() *log.Logger
@@ -136,9 +154,6 @@ type (
// and `Last-Modified` response headers. // and `Last-Modified` response headers.
ServeContent(io.ReadSeeker, string, time.Time) error ServeContent(io.ReadSeeker, string, time.Time) error
// Object returns the internal context implementation.
Object() *context
// Reset resets the context after request completes. It must be called along // Reset resets the context after request completes. It must be called along
// with `Echo#GetContext()` and `Echo#PutContext()`. See `Echo#ServeHTTP()` // with `Echo#GetContext()` and `Echo#PutContext()`. See `Echo#ServeHTTP()`
Reset(engine.Request, engine.Response) Reset(engine.Request, engine.Response)
@@ -163,32 +178,6 @@ const (
indexPage = "index.html" indexPage = "index.html"
) )
// NewContext creates a Context object.
func NewContext(rq engine.Request, rs engine.Response, e *Echo) Context {
return &context{
request: rq,
response: rs,
echo: e,
pvalues: make([]string, *e.maxParam),
store: make(store),
handler: notFoundHandler,
}
}
// MockContext returns `Context` for testing purpose.
func MockContext(request engine.Request, response engine.Response, path string, paramNames []string, paramValues []string) Context {
return &context{
request: request,
response: response,
echo: new(Echo),
path: path,
pnames: paramNames,
pvalues: paramValues,
store: make(store),
handler: notFoundHandler,
}
}
func (c *context) NetContext() netContext.Context { func (c *context) NetContext() netContext.Context {
return c.netContext return c.netContext
} }
@@ -225,6 +214,10 @@ func (c *context) Path() string {
return c.path return c.path
} }
func (c *context) SetPath(p string) {
c.path = p
}
func (c *context) P(i int) (value string) { func (c *context) P(i int) (value string) {
l := len(c.pnames) l := len(c.pnames)
if i < l { if i < l {
@@ -248,6 +241,18 @@ func (c *context) ParamNames() []string {
return c.pnames return c.pnames
} }
func (c *context) SetParamNames(names []string) {
c.pnames = names
}
func (c *context) ParamValues() []string {
return c.pvalues
}
func (c *context) SetParamValues(values []string) {
c.pvalues = values
}
func (c *context) QueryParam(name string) string { func (c *context) QueryParam(name string) string {
return c.request.URL().QueryParam(name) return c.request.URL().QueryParam(name)
} }
@@ -283,6 +288,10 @@ func (c *context) Get(key string) interface{} {
return c.store[key] return c.store[key]
} }
func (c *context) Del(key string) {
delete(c.store, key)
}
func (c *context) Bind(i interface{}) error { func (c *context) Bind(i interface{}) error {
return c.echo.binder.Bind(i, c) return c.echo.binder.Bind(i, c)
} }
@@ -426,12 +435,12 @@ func (c *context) Handler() HandlerFunc {
return c.handler return c.handler
} }
func (c *context) Logger() *log.Logger { func (c *context) SetHandler(h HandlerFunc) {
return c.echo.logger c.handler = h
} }
func (c *context) Object() *context { func (c *context) Logger() *log.Logger {
return c return c.echo.logger
} }
func (c *context) ServeContent(content io.ReadSeeker, name string, modtime time.Time) error { func (c *context) ServeContent(content io.ReadSeeker, name string, modtime time.Time) error {

View File

@@ -36,8 +36,8 @@ func TestContext(t *testing.T) {
e := New() e := New()
rq := test.NewRequest(POST, "/", strings.NewReader(userJSON)) rq := test.NewRequest(POST, "/", strings.NewReader(userJSON))
rec := test.NewResponseRecorder() rc := test.NewResponseRecorder()
c := NewContext(rq, rec, e) c := e.NewContext(rq, rc).(*context)
// Request // Request
assert.NotNil(t, c.Request()) assert.NotNil(t, c.Request())
@@ -46,12 +46,12 @@ func TestContext(t *testing.T) {
assert.NotNil(t, c.Response()) assert.NotNil(t, c.Response())
// ParamNames // ParamNames
c.Object().pnames = []string{"uid", "fid"} c.pnames = []string{"uid", "fid"}
assert.EqualValues(t, []string{"uid", "fid"}, c.ParamNames()) assert.EqualValues(t, []string{"uid", "fid"}, c.ParamNames())
// Param by id // Param by id
c.Object().pnames = []string{"id"} c.pnames = []string{"id"}
c.Object().pvalues = []string{"1"} c.pvalues = []string{"1"}
assert.Equal(t, "1", c.P(0)) assert.Equal(t, "1", c.P(0))
// Param by name // Param by name
@@ -67,13 +67,13 @@ func TestContext(t *testing.T) {
// JSON // JSON
testBindOk(t, c, MIMEApplicationJSON) testBindOk(t, c, MIMEApplicationJSON)
c.Object().request = test.NewRequest(POST, "/", strings.NewReader(invalidContent)) c.request = test.NewRequest(POST, "/", strings.NewReader(invalidContent))
testBindError(t, c, MIMEApplicationJSON) testBindError(t, c, MIMEApplicationJSON)
// XML // XML
c.Object().request = test.NewRequest(POST, "/", strings.NewReader(userXML)) c.request = test.NewRequest(POST, "/", strings.NewReader(userXML))
testBindOk(t, c, MIMEApplicationXML) testBindOk(t, c, MIMEApplicationXML)
c.Object().request = test.NewRequest(POST, "/", strings.NewReader(invalidContent)) c.request = test.NewRequest(POST, "/", strings.NewReader(invalidContent))
testBindError(t, c, MIMEApplicationXML) testBindError(t, c, MIMEApplicationXML)
// Unsupported // Unsupported
@@ -86,114 +86,114 @@ func TestContext(t *testing.T) {
tpl := &Template{ tpl := &Template{
templates: template.Must(template.New("hello").Parse("Hello, {{.}}!")), templates: template.Must(template.New("hello").Parse("Hello, {{.}}!")),
} }
c.Object().echo.SetRenderer(tpl) c.echo.SetRenderer(tpl)
err := c.Render(http.StatusOK, "hello", "Joe") err := c.Render(http.StatusOK, "hello", "Joe")
if assert.NoError(t, err) { if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Status()) assert.Equal(t, http.StatusOK, rc.Status())
assert.Equal(t, "Hello, Joe!", rec.Body.String()) assert.Equal(t, "Hello, Joe!", rc.Body.String())
} }
c.Object().echo.renderer = nil c.echo.renderer = nil
err = c.Render(http.StatusOK, "hello", "Joe") err = c.Render(http.StatusOK, "hello", "Joe")
assert.Error(t, err) assert.Error(t, err)
// JSON // JSON
rec = test.NewResponseRecorder() rc = test.NewResponseRecorder()
c = NewContext(rq, rec, e) c = e.NewContext(rq, rc).(*context)
err = c.JSON(http.StatusOK, user{"1", "Joe"}) err = c.JSON(http.StatusOK, user{"1", "Joe"})
if assert.NoError(t, err) { if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Status()) assert.Equal(t, http.StatusOK, rc.Status())
assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rc.Header().Get(HeaderContentType))
assert.Equal(t, userJSON, rec.Body.String()) assert.Equal(t, userJSON, rc.Body.String())
} }
// JSON (error) // JSON (error)
rec = test.NewResponseRecorder() rc = test.NewResponseRecorder()
c = NewContext(rq, rec, e) c = e.NewContext(rq, rc).(*context)
err = c.JSON(http.StatusOK, make(chan bool)) err = c.JSON(http.StatusOK, make(chan bool))
assert.Error(t, err) assert.Error(t, err)
// JSONP // JSONP
rec = test.NewResponseRecorder() rc = test.NewResponseRecorder()
c = NewContext(rq, rec, e) c = e.NewContext(rq, rc).(*context)
callback := "callback" callback := "callback"
err = c.JSONP(http.StatusOK, callback, user{"1", "Joe"}) err = c.JSONP(http.StatusOK, callback, user{"1", "Joe"})
if assert.NoError(t, err) { if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Status()) assert.Equal(t, http.StatusOK, rc.Status())
assert.Equal(t, MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType)) assert.Equal(t, MIMEApplicationJavaScriptCharsetUTF8, rc.Header().Get(HeaderContentType))
assert.Equal(t, callback+"("+userJSON+");", rec.Body.String()) assert.Equal(t, callback+"("+userJSON+");", rc.Body.String())
} }
// XML // XML
rec = test.NewResponseRecorder() rc = test.NewResponseRecorder()
c = NewContext(rq, rec, e) c = e.NewContext(rq, rc).(*context)
err = c.XML(http.StatusOK, user{"1", "Joe"}) err = c.XML(http.StatusOK, user{"1", "Joe"})
if assert.NoError(t, err) { if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Status()) assert.Equal(t, http.StatusOK, rc.Status())
assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rc.Header().Get(HeaderContentType))
assert.Equal(t, xml.Header+userXML, rec.Body.String()) assert.Equal(t, xml.Header+userXML, rc.Body.String())
} }
// XML (error) // XML (error)
rec = test.NewResponseRecorder() rc = test.NewResponseRecorder()
c = NewContext(rq, rec, e) c = e.NewContext(rq, rc).(*context)
err = c.XML(http.StatusOK, make(chan bool)) err = c.XML(http.StatusOK, make(chan bool))
assert.Error(t, err) assert.Error(t, err)
// String // String
rec = test.NewResponseRecorder() rc = test.NewResponseRecorder()
c = NewContext(rq, rec, e) c = e.NewContext(rq, rc).(*context)
err = c.String(http.StatusOK, "Hello, World!") err = c.String(http.StatusOK, "Hello, World!")
if assert.NoError(t, err) { if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Status()) assert.Equal(t, http.StatusOK, rc.Status())
assert.Equal(t, MIMETextPlainCharsetUTF8, rec.Header().Get(HeaderContentType)) assert.Equal(t, MIMETextPlainCharsetUTF8, rc.Header().Get(HeaderContentType))
assert.Equal(t, "Hello, World!", rec.Body.String()) assert.Equal(t, "Hello, World!", rc.Body.String())
} }
// HTML // HTML
rec = test.NewResponseRecorder() rc = test.NewResponseRecorder()
c = NewContext(rq, rec, e) c = e.NewContext(rq, rc).(*context)
err = c.HTML(http.StatusOK, "Hello, <strong>World!</strong>") err = c.HTML(http.StatusOK, "Hello, <strong>World!</strong>")
if assert.NoError(t, err) { if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Status()) assert.Equal(t, http.StatusOK, rc.Status())
assert.Equal(t, MIMETextHTMLCharsetUTF8, rec.Header().Get(HeaderContentType)) assert.Equal(t, MIMETextHTMLCharsetUTF8, rc.Header().Get(HeaderContentType))
assert.Equal(t, "Hello, <strong>World!</strong>", rec.Body.String()) assert.Equal(t, "Hello, <strong>World!</strong>", rc.Body.String())
} }
// Attachment // Attachment
rec = test.NewResponseRecorder() rc = test.NewResponseRecorder()
c = NewContext(rq, rec, e) c = e.NewContext(rq, rc).(*context)
file, err := os.Open("_fixture/images/walle.png") file, err := os.Open("_fixture/images/walle.png")
if assert.NoError(t, err) { if assert.NoError(t, err) {
err = c.Attachment(file, "walle.png") err = c.Attachment(file, "walle.png")
if assert.NoError(t, err) { if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Status()) assert.Equal(t, http.StatusOK, rc.Status())
assert.Equal(t, "attachment; filename=walle.png", rec.Header().Get(HeaderContentDisposition)) assert.Equal(t, "attachment; filename=walle.png", rc.Header().Get(HeaderContentDisposition))
assert.Equal(t, 219885, rec.Body.Len()) assert.Equal(t, 219885, rc.Body.Len())
} }
} }
// NoContent // NoContent
rec = test.NewResponseRecorder() rc = test.NewResponseRecorder()
c = NewContext(rq, rec, e) c = e.NewContext(rq, rc).(*context)
c.NoContent(http.StatusOK) c.NoContent(http.StatusOK)
assert.Equal(t, http.StatusOK, rec.Status()) assert.Equal(t, http.StatusOK, rc.Status())
// Redirect // Redirect
rec = test.NewResponseRecorder() rc = test.NewResponseRecorder()
c = NewContext(rq, rec, e) c = e.NewContext(rq, rc).(*context)
assert.Equal(t, nil, c.Redirect(http.StatusMovedPermanently, "http://labstack.github.io/echo")) assert.Equal(t, nil, c.Redirect(http.StatusMovedPermanently, "http://labstack.github.io/echo"))
assert.Equal(t, http.StatusMovedPermanently, rec.Status()) assert.Equal(t, http.StatusMovedPermanently, rc.Status())
assert.Equal(t, "http://labstack.github.io/echo", rec.Header().Get(HeaderLocation)) assert.Equal(t, "http://labstack.github.io/echo", rc.Header().Get(HeaderLocation))
// Error // Error
rec = test.NewResponseRecorder() rc = test.NewResponseRecorder()
c = NewContext(rq, rec, e) c = e.NewContext(rq, rc).(*context)
c.Error(errors.New("error")) c.Error(errors.New("error"))
assert.Equal(t, http.StatusInternalServerError, rec.Status()) assert.Equal(t, http.StatusInternalServerError, rc.Status())
// Reset // Reset
c.Object().Reset(rq, test.NewResponseRecorder()) c.Reset(rq, test.NewResponseRecorder())
} }
func TestContextPath(t *testing.T) { func TestContextPath(t *testing.T) {
@@ -201,12 +201,12 @@ func TestContextPath(t *testing.T) {
r := e.Router() r := e.Router()
r.Add(GET, "/users/:id", nil, e) r.Add(GET, "/users/:id", nil, e)
c := NewContext(nil, nil, e) c := e.NewContext(nil, nil)
r.Find(GET, "/users/1", c) r.Find(GET, "/users/1", c)
assert.Equal(t, "/users/:id", c.Path()) assert.Equal(t, "/users/:id", c.Path())
r.Add(GET, "/users/:uid/files/:fid", nil, e) r.Add(GET, "/users/:uid/files/:fid", nil, e)
c = NewContext(nil, nil, e) c = e.NewContext(nil, nil)
r.Find(GET, "/users/1/files/1", c) r.Find(GET, "/users/1/files/1", c)
assert.Equal(t, "/users/:uid/files/:fid", c.Path()) assert.Equal(t, "/users/:uid/files/:fid", c.Path())
} }
@@ -216,7 +216,8 @@ func TestContextQueryParam(t *testing.T) {
q.Set("name", "joe") q.Set("name", "joe")
q.Set("email", "joe@labstack.com") q.Set("email", "joe@labstack.com")
rq := test.NewRequest(GET, "/?"+q.Encode(), nil) rq := test.NewRequest(GET, "/?"+q.Encode(), nil)
c := NewContext(rq, nil, New()) e := New()
c := e.NewContext(rq, nil)
assert.Equal(t, "joe", c.QueryParam("name")) assert.Equal(t, "joe", c.QueryParam("name"))
assert.Equal(t, "joe@labstack.com", c.QueryParam("email")) assert.Equal(t, "joe@labstack.com", c.QueryParam("email"))
} }
@@ -226,10 +227,11 @@ func TestContextFormValue(t *testing.T) {
f.Set("name", "joe") f.Set("name", "joe")
f.Set("email", "joe@labstack.com") f.Set("email", "joe@labstack.com")
e := New()
rq := test.NewRequest(POST, "/", strings.NewReader(f.Encode())) rq := test.NewRequest(POST, "/", strings.NewReader(f.Encode()))
rq.Header().Add(HeaderContentType, MIMEApplicationForm) rq.Header().Add(HeaderContentType, MIMEApplicationForm)
c := NewContext(rq, nil, New()) c := e.NewContext(rq, nil)
assert.Equal(t, "joe", c.FormValue("name")) assert.Equal(t, "joe", c.FormValue("name"))
assert.Equal(t, "joe@labstack.com", c.FormValue("email")) assert.Equal(t, "joe@labstack.com", c.FormValue("email"))
} }
@@ -244,7 +246,7 @@ func TestContextServeContent(t *testing.T) {
e := New() e := New()
rq := test.NewRequest(GET, "/", nil) rq := test.NewRequest(GET, "/", nil)
rc := test.NewResponseRecorder() rc := test.NewResponseRecorder()
c := NewContext(rq, rc, e) c := e.NewContext(rq, rc)
fs := http.Dir("_fixture/images") fs := http.Dir("_fixture/images")
f, err := fs.Open("walle.png") f, err := fs.Open("walle.png")
@@ -258,7 +260,7 @@ func TestContextServeContent(t *testing.T) {
// Cached // Cached
rc = test.NewResponseRecorder() rc = test.NewResponseRecorder()
c = NewContext(rq, rc, e) c = e.NewContext(rq, rc)
rq.Header().Set(HeaderIfModifiedSince, fi.ModTime().UTC().Format(http.TimeFormat)) rq.Header().Set(HeaderIfModifiedSince, fi.ModTime().UTC().Format(http.TimeFormat))
if assert.NoError(t, c.ServeContent(f, fi.Name(), fi.ModTime())) { if assert.NoError(t, c.ServeContent(f, fi.Name(), fi.ModTime())) {
assert.Equal(t, http.StatusNotModified, rc.Status()) assert.Equal(t, http.StatusNotModified, rc.Status())
@@ -276,7 +278,7 @@ func TestContextHandler(t *testing.T) {
_, err := b.Write([]byte("handler")) _, err := b.Write([]byte("handler"))
return err return err
}, e) }, e)
c := NewContext(nil, nil, e) c := e.NewContext(nil, nil)
r.Find(GET, "/handler", c) r.Find(GET, "/handler", c)
c.Handler()(c) c.Handler()(c)
assert.Equal(t, "handler", b.String()) assert.Equal(t, "handler", b.String())

14
echo.go
View File

@@ -215,7 +215,7 @@ var (
func New() (e *Echo) { func New() (e *Echo) {
e = &Echo{maxParam: new(int)} e = &Echo{maxParam: new(int)}
e.pool.New = func() interface{} { e.pool.New = func() interface{} {
return NewContext(nil, nil, e) return e.NewContext(nil, nil)
} }
e.router = NewRouter(e) e.router = NewRouter(e)
@@ -228,6 +228,18 @@ func New() (e *Echo) {
return return
} }
// NewContext returns a Context instance.
func (e *Echo) NewContext(rq engine.Request, rs engine.Response) Context {
return &context{
request: rq,
response: rs,
echo: e,
pvalues: make([]string, *e.maxParam),
store: make(store),
handler: notFoundHandler,
}
}
// Router returns router. // Router returns router.
func (e *Echo) Router() *Router { func (e *Echo) Router() *Router {
return e.router return e.router

View File

@@ -25,8 +25,8 @@ type (
func TestEcho(t *testing.T) { func TestEcho(t *testing.T) {
e := New() e := New()
rq := test.NewRequest(GET, "/", nil) rq := test.NewRequest(GET, "/", nil)
rec := test.NewResponseRecorder() rc := test.NewResponseRecorder()
c := NewContext(rq, rec, e) c := e.NewContext(rq, rc)
// Router // Router
assert.NotNil(t, e.Router()) assert.NotNil(t, e.Router())
@@ -37,7 +37,7 @@ func TestEcho(t *testing.T) {
// DefaultHTTPErrorHandler // DefaultHTTPErrorHandler
e.DefaultHTTPErrorHandler(errors.New("error"), c) e.DefaultHTTPErrorHandler(errors.New("error"), c)
assert.Equal(t, http.StatusInternalServerError, rec.Status()) assert.Equal(t, http.StatusInternalServerError, rc.Status())
} }
func TestEchoStatic(t *testing.T) { func TestEchoStatic(t *testing.T) {

View File

@@ -22,14 +22,13 @@ type (
} }
) )
// MockRequest returns `Request` instance for testing purpose. // NewRequest returns `Request` instance.
func MockRequest() *Request { func NewRequest(c *fasthttp.RequestCtx, l *log.Logger) *Request {
ctx := new(fasthttp.RequestCtx)
return &Request{ return &Request{
RequestCtx: ctx, RequestCtx: c,
url: &URL{URI: ctx.URI()}, url: &URL{URI: c.URI()},
header: &RequestHeader{RequestHeader: &ctx.Request.Header}, header: &RequestHeader{RequestHeader: &c.Request.Header},
logger: log.New("test"), logger: l,
} }
} }

View File

@@ -24,14 +24,13 @@ type (
} }
) )
// MockResponse returns `Response` instance for testing purpose. // NewResponse returns `Response` instance.
func MockResponse() *Response { func NewResponse(c *fasthttp.RequestCtx, l *log.Logger) *Response {
ctx := new(fasthttp.RequestCtx)
return &Response{ return &Response{
RequestCtx: ctx, RequestCtx: c,
header: &ResponseHeader{ResponseHeader: &ctx.Response.Header}, header: &ResponseHeader{ResponseHeader: &c.Response.Header},
writer: ctx, writer: c,
logger: log.New("test"), logger: l,
} }
} }

View File

@@ -19,14 +19,13 @@ type (
} }
) )
// MockRequest returns `Request` instance for testing purpose. // NewRequest returns `Request` instance.
func MockRequest() *Request { func NewRequest(r *http.Request, l *log.Logger) *Request {
rq := new(http.Request)
return &Request{ return &Request{
Request: new(http.Request), Request: r,
url: &URL{URL: rq.URL}, url: &URL{URL: r.URL},
header: &Header{Header: rq.Header}, header: &Header{Header: r.Header},
logger: log.New("test"), logger: l,
} }
} }

View File

@@ -5,7 +5,6 @@ import (
"io" "io"
"net" "net"
"net/http" "net/http"
"net/http/httptest"
"github.com/labstack/echo/engine" "github.com/labstack/echo/engine"
"github.com/labstack/gommon/log" "github.com/labstack/gommon/log"
@@ -29,14 +28,13 @@ type (
} }
) )
// MockResponse returns `Response` instance for testing purpose. // NewResponse returns `Response` instance.
func MockResponse() *Response { func NewResponse(w http.ResponseWriter, l *log.Logger) *Response {
rc := httptest.NewRecorder()
return &Response{ return &Response{
ResponseWriter: rc, ResponseWriter: w,
header: &Header{Header: rc.Header()}, header: &Header{Header: w.Header()},
writer: rc, writer: w,
logger: log.New("test"), logger: l,
} }
} }

View File

@@ -287,16 +287,16 @@ func (n *node) checkMethodNotAllowed() HandlerFunc {
// - Reset it `Context#Reset()` // - Reset it `Context#Reset()`
// - Return it `Echo#PutContext()`. // - Return it `Echo#PutContext()`.
func (r *Router) Find(method, path string, context Context) { func (r *Router) Find(method, path string, context Context) {
ctx := context.Object()
cn := r.tree // Current node as root cn := r.tree // Current node as root
var ( var (
search = path search = path
c *node // Child node c *node // Child node
n int // Param counter n int // Param counter
nk kind // Next kind nk kind // Next kind
nn *node // Next node nn *node // Next node
ns string // Next search ns string // Next search
pvalues = context.ParamValues()
) )
// Search order static > param > any // Search order static > param > any
@@ -356,7 +356,7 @@ func (r *Router) Find(method, path string, context Context) {
Param: Param:
if c = cn.findChildByKind(pkind); c != nil { if c = cn.findChildByKind(pkind); c != nil {
// Issue #378 // Issue #378
if len(ctx.pvalues) == n { if len(pvalues) == n {
continue continue
} }
@@ -371,7 +371,7 @@ func (r *Router) Find(method, path string, context Context) {
i, l := 0, len(search) i, l := 0, len(search)
for ; i < l && search[i] != '/'; i++ { for ; i < l && search[i] != '/'; i++ {
} }
ctx.pvalues[n] = search[:i] pvalues[n] = search[:i]
n++ n++
search = search[i:] search = search[i:]
continue continue
@@ -393,30 +393,32 @@ func (r *Router) Find(method, path string, context Context) {
// Not found // Not found
return return
} }
ctx.pvalues[len(cn.pnames)-1] = search pvalues[len(cn.pnames)-1] = search
goto End goto End
} }
End: End:
ctx.handler = cn.findHandler(method) context.SetHandler(cn.findHandler(method))
ctx.path = cn.ppath context.SetPath(cn.ppath)
ctx.pnames = cn.pnames context.SetParamNames(cn.pnames)
// NOTE: Slow zone... // NOTE: Slow zone...
if ctx.handler == nil { if context.Handler() == nil {
ctx.handler = cn.checkMethodNotAllowed() context.SetHandler(cn.checkMethodNotAllowed())
// Dig further for any, might have an empty value for *, e.g. // Dig further for any, might have an empty value for *, e.g.
// serving a directory. Issue #207. // serving a directory. Issue #207.
if cn = cn.findChildByKind(akind); cn == nil { if cn = cn.findChildByKind(akind); cn == nil {
return return
} }
if ctx.handler = cn.findHandler(method); ctx.handler == nil { if h := cn.findHandler(method); h != nil {
ctx.handler = cn.checkMethodNotAllowed() context.SetHandler(h)
} else {
context.SetHandler(cn.checkMethodNotAllowed())
} }
ctx.path = cn.ppath context.SetPath(cn.ppath)
ctx.pnames = cn.pnames context.SetParamNames(cn.pnames)
ctx.pvalues[len(cn.pnames)-1] = "" pvalues[len(cn.pnames)-1] = ""
} }
return return

View File

@@ -281,7 +281,7 @@ func TestRouterStatic(t *testing.T) {
c.Set("path", path) c.Set("path", path)
return nil return nil
}, e) }, e)
c := NewContext(nil, nil, e).Object() c := e.NewContext(nil, nil).(*context)
r.Find(GET, path, c) r.Find(GET, path, c)
c.handler(c) c.handler(c)
assert.Equal(t, path, c.Get("path")) assert.Equal(t, path, c.Get("path"))
@@ -293,7 +293,7 @@ func TestRouterParam(t *testing.T) {
r.Add(GET, "/users/:id", func(c Context) error { r.Add(GET, "/users/:id", func(c Context) error {
return nil return nil
}, e) }, e)
c := NewContext(nil, nil, e) c := e.NewContext(nil, nil).(*context)
r.Find(GET, "/users/1", c) r.Find(GET, "/users/1", c)
assert.Equal(t, "1", c.P(0)) assert.Equal(t, "1", c.P(0))
} }
@@ -304,7 +304,7 @@ func TestRouterTwoParam(t *testing.T) {
r.Add(GET, "/users/:uid/files/:fid", func(Context) error { r.Add(GET, "/users/:uid/files/:fid", func(Context) error {
return nil return nil
}, e) }, e)
c := NewContext(nil, nil, e) c := e.NewContext(nil, nil).(*context)
r.Find(GET, "/users/1/files/1", c) r.Find(GET, "/users/1/files/1", c)
assert.Equal(t, "1", c.P(0)) assert.Equal(t, "1", c.P(0))
@@ -324,7 +324,7 @@ func TestRouterParamWithSlash(t *testing.T) {
return nil return nil
}, e) }, e)
c := NewContext(nil, nil, e) c := e.NewContext(nil, nil).(*context)
assert.NotPanics(t, func() { assert.NotPanics(t, func() {
r.Find(GET, "/a/1/c/d/2/3", c) r.Find(GET, "/a/1/c/d/2/3", c)
}) })
@@ -344,7 +344,7 @@ func TestRouterMatchAny(t *testing.T) {
r.Add(GET, "/users/*", func(Context) error { r.Add(GET, "/users/*", func(Context) error {
return nil return nil
}, e) }, e)
c := NewContext(nil, nil, e) c := e.NewContext(nil, nil).(*context)
r.Find(GET, "/", c) r.Find(GET, "/", c)
assert.Equal(t, "", c.P(0)) assert.Equal(t, "", c.P(0))
@@ -362,7 +362,7 @@ func TestRouterMicroParam(t *testing.T) {
r.Add(GET, "/:a/:b/:c", func(c Context) error { r.Add(GET, "/:a/:b/:c", func(c Context) error {
return nil return nil
}, e) }, e)
c := NewContext(nil, nil, e) c := e.NewContext(nil, nil).(*context)
r.Find(GET, "/1/2/3", c) r.Find(GET, "/1/2/3", c)
assert.Equal(t, "1", c.P(0)) assert.Equal(t, "1", c.P(0))
assert.Equal(t, "2", c.P(1)) assert.Equal(t, "2", c.P(1))
@@ -377,7 +377,7 @@ func TestRouterMixParamMatchAny(t *testing.T) {
r.Add(GET, "/users/:id/*", func(c Context) error { r.Add(GET, "/users/:id/*", func(c Context) error {
return nil return nil
}, e) }, e)
c := NewContext(nil, nil, e).Object() c := e.NewContext(nil, nil).(*context)
r.Find(GET, "/users/joe/comments", c) r.Find(GET, "/users/joe/comments", c)
c.handler(c) c.handler(c)
@@ -396,7 +396,7 @@ func TestRouterMultiRoute(t *testing.T) {
r.Add(GET, "/users/:id", func(c Context) error { r.Add(GET, "/users/:id", func(c Context) error {
return nil return nil
}, e) }, e)
c := NewContext(nil, nil, e).Object() c := e.NewContext(nil, nil).(*context)
// Route > /users // Route > /users
r.Find(GET, "/users", c) r.Find(GET, "/users", c)
@@ -408,7 +408,7 @@ func TestRouterMultiRoute(t *testing.T) {
assert.Equal(t, "1", c.P(0)) assert.Equal(t, "1", c.P(0))
// Route > /user // Route > /user
c = NewContext(nil, nil, e).Object() c = e.NewContext(nil, nil).(*context)
r.Find(GET, "/user", c) r.Find(GET, "/user", c)
he := c.handler(c).(*HTTPError) he := c.handler(c).(*HTTPError)
assert.Equal(t, http.StatusNotFound, he.Code) assert.Equal(t, http.StatusNotFound, he.Code)
@@ -447,7 +447,7 @@ func TestRouterPriority(t *testing.T) {
c.Set("g", 7) c.Set("g", 7)
return nil return nil
}, e) }, e)
c := NewContext(nil, nil, e).Object() c := e.NewContext(nil, nil).(*context)
// Route > /users // Route > /users
r.Find(GET, "/users", c) r.Find(GET, "/users", c)
@@ -490,7 +490,7 @@ func TestRouterPriority(t *testing.T) {
func TestRouterPriorityNotFound(t *testing.T) { func TestRouterPriorityNotFound(t *testing.T) {
e := New() e := New()
r := e.router r := e.router
c := NewContext(nil, nil, e).Object() c := e.NewContext(nil, nil).(*context)
// Add // Add
r.Add(GET, "/a/foo", func(c Context) error { r.Add(GET, "/a/foo", func(c Context) error {
@@ -511,7 +511,7 @@ func TestRouterPriorityNotFound(t *testing.T) {
c.handler(c) c.handler(c)
assert.Equal(t, 2, c.Get("b")) assert.Equal(t, 2, c.Get("b"))
c = NewContext(nil, nil, e).Object() c = e.NewContext(nil, nil).(*context)
r.Find(GET, "/abc/def", c) r.Find(GET, "/abc/def", c)
he := c.handler(c).(*HTTPError) he := c.handler(c).(*HTTPError)
assert.Equal(t, http.StatusNotFound, he.Code) assert.Equal(t, http.StatusNotFound, he.Code)
@@ -532,7 +532,7 @@ func TestRouterParamNames(t *testing.T) {
r.Add(GET, "/users/:uid/files/:fid", func(c Context) error { r.Add(GET, "/users/:uid/files/:fid", func(c Context) error {
return nil return nil
}, e) }, e)
c := NewContext(nil, nil, e).Object() c := e.NewContext(nil, nil).(*context)
// Route > /users // Route > /users
r.Find(GET, "/users", c) r.Find(GET, "/users", c)
@@ -541,14 +541,14 @@ func TestRouterParamNames(t *testing.T) {
// Route > /users/:id // Route > /users/:id
r.Find(GET, "/users/1", c) r.Find(GET, "/users/1", c)
assert.Equal(t, "id", c.Object().pnames[0]) assert.Equal(t, "id", c.pnames[0])
assert.Equal(t, "1", c.P(0)) assert.Equal(t, "1", c.P(0))
// Route > /users/:uid/files/:fid // Route > /users/:uid/files/:fid
r.Find(GET, "/users/1/files/1", c) r.Find(GET, "/users/1/files/1", c)
assert.Equal(t, "uid", c.Object().pnames[0]) assert.Equal(t, "uid", c.pnames[0])
assert.Equal(t, "1", c.P(0)) assert.Equal(t, "1", c.P(0))
assert.Equal(t, "fid", c.Object().pnames[1]) assert.Equal(t, "fid", c.pnames[1])
assert.Equal(t, "1", c.P(1)) assert.Equal(t, "1", c.P(1))
} }
@@ -561,10 +561,10 @@ func TestRouterAPI(t *testing.T) {
return nil return nil
}, e) }, e)
} }
c := NewContext(nil, nil, e) c := e.NewContext(nil, nil).(*context)
for _, route := range api { for _, route := range api {
r.Find(route.Method, route.Path, c) r.Find(route.Method, route.Path, c)
for i, n := range c.Object().pnames { for i, n := range c.pnames {
if assert.NotEmpty(t, n) { if assert.NotEmpty(t, n) {
assert.Equal(t, ":"+n, c.P(i)) assert.Equal(t, ":"+n, c.P(i))
} }