1
0
mirror of https://github.com/labstack/echo.git synced 2024-12-24 20:14:31 +02:00

First commit to v3, #665

Signed-off-by: Vishal Rana <vr@labstack.com>
This commit is contained in:
Vishal Rana 2016-09-22 22:53:44 -07:00
parent 04f45046b1
commit 2aec0353f5
66 changed files with 656 additions and 3264 deletions

View File

@ -10,8 +10,6 @@ before_install:
script: script:
- go test -coverprofile=echo.coverprofile - go test -coverprofile=echo.coverprofile
- go test -coverprofile=middleware.coverprofile ./middleware - go test -coverprofile=middleware.coverprofile ./middleware
- go test -coverprofile=engine_standatd.coverprofile ./engine/standard
- go test -coverprofile=engine_fasthttp.coverprofile ./engine/fasthttp
- $HOME/gopath/bin/gover - $HOME/gopath/bin/gover
- $HOME/gopath/bin/goveralls -coverprofile=gover.coverprofile -service=travis-ci - $HOME/gopath/bin/goveralls -coverprofile=gover.coverprofile -service=travis-ci
matrix: matrix:

View File

@ -22,43 +22,47 @@ type (
func (b *binder) Bind(i interface{}, c Context) (err error) { func (b *binder) Bind(i interface{}, c Context) (err error) {
req := c.Request() req := c.Request()
if req.Method() == GET { if req.Method == GET {
if err = b.bindData(i, c.QueryParams()); err != nil { if err = b.bindData(i, c.QueryParams()); err != nil {
err = NewHTTPError(http.StatusBadRequest, err.Error()) return NewHTTPError(http.StatusBadRequest, err.Error())
} }
return return
} }
ctype := req.Header().Get(HeaderContentType) ctype := req.Header.Get(HeaderContentType)
if req.Body() == nil { if req.ContentLength == 0 {
err = NewHTTPError(http.StatusBadRequest, "request body can't be empty") return NewHTTPError(http.StatusBadRequest, "request body can't be empty")
return
} }
err = ErrUnsupportedMediaType
switch { switch {
case strings.HasPrefix(ctype, MIMEApplicationJSON): case strings.HasPrefix(ctype, MIMEApplicationJSON):
if err = json.NewDecoder(req.Body()).Decode(i); err != nil { if err = json.NewDecoder(req.Body).Decode(i); err != nil {
if ute, ok := err.(*json.UnmarshalTypeError); ok { if ute, ok := err.(*json.UnmarshalTypeError); ok {
err = NewHTTPError(http.StatusBadRequest, fmt.Sprintf("unmarshal type error: expected=%v, got=%v, offset=%v", ute.Type, ute.Value, ute.Offset)) return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("unmarshal type error: expected=%v, got=%v, offset=%v", ute.Type, ute.Value, ute.Offset))
} else if se, ok := err.(*json.SyntaxError); ok { } else if se, ok := err.(*json.SyntaxError); ok {
err = NewHTTPError(http.StatusBadRequest, fmt.Sprintf("syntax error: offset=%v, error=%v", se.Offset, se.Error())) return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("syntax error: offset=%v, error=%v", se.Offset, se.Error()))
} else { } else {
err = NewHTTPError(http.StatusBadRequest, err.Error()) return NewHTTPError(http.StatusBadRequest, err.Error())
} }
} }
case strings.HasPrefix(ctype, MIMEApplicationXML): case strings.HasPrefix(ctype, MIMEApplicationXML):
if err = xml.NewDecoder(req.Body()).Decode(i); err != nil { if err = xml.NewDecoder(req.Body).Decode(i); err != nil {
if ute, ok := err.(*xml.UnsupportedTypeError); ok { if ute, ok := err.(*xml.UnsupportedTypeError); ok {
err = NewHTTPError(http.StatusBadRequest, fmt.Sprintf("unsupported type error: type=%v, error=%v", ute.Type, ute.Error())) return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("unsupported type error: type=%v, error=%v", ute.Type, ute.Error()))
} else if se, ok := err.(*xml.SyntaxError); ok { } else if se, ok := err.(*xml.SyntaxError); ok {
err = NewHTTPError(http.StatusBadRequest, fmt.Sprintf("syntax error: line=%v, error=%v", se.Line, se.Error())) return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("syntax error: line=%v, error=%v", se.Line, se.Error()))
} else { } else {
err = NewHTTPError(http.StatusBadRequest, err.Error()) return NewHTTPError(http.StatusBadRequest, err.Error())
} }
} }
case strings.HasPrefix(ctype, MIMEApplicationForm), strings.HasPrefix(ctype, MIMEMultipartForm): case strings.HasPrefix(ctype, MIMEApplicationForm), strings.HasPrefix(ctype, MIMEMultipartForm):
if err = b.bindData(i, req.FormParams()); err != nil { params, err := c.FormParams()
err = NewHTTPError(http.StatusBadRequest, err.Error()) if err != nil {
return NewHTTPError(http.StatusBadRequest, err.Error())
} }
if err = b.bindData(i, params); err != nil {
return NewHTTPError(http.StatusBadRequest, err.Error())
}
default:
return ErrUnsupportedMediaType
} }
return return
} }
@ -100,8 +104,8 @@ func (b *binder) bindData(ptr interface{}, data map[string][]string) error {
if structFieldKind == reflect.Slice && numElems > 0 { if structFieldKind == reflect.Slice && numElems > 0 {
sliceOf := structField.Type().Elem().Kind() sliceOf := structField.Type().Elem().Kind()
slice := reflect.MakeSlice(structField.Type(), numElems, numElems) slice := reflect.MakeSlice(structField.Type(), numElems, numElems)
for i := 0; i < numElems; i++ { for j := 0; j < numElems; j++ {
if err := setWithProperType(sliceOf, inputValue[i], slice.Index(i)); err != nil { if err := setWithProperType(sliceOf, inputValue[j], slice.Index(j)); err != nil {
return err return err
} }
} }

View File

@ -5,11 +5,11 @@ import (
"io" "io"
"mime/multipart" "mime/multipart"
"net/http" "net/http"
"net/http/httptest"
"reflect" "reflect"
"strings" "strings"
"testing" "testing"
"github.com/labstack/echo/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -70,19 +70,19 @@ func TestBinderForm(t *testing.T) {
testBinderOkay(t, strings.NewReader(userForm), MIMEApplicationForm) testBinderOkay(t, strings.NewReader(userForm), MIMEApplicationForm)
testBinderError(t, nil, MIMEApplicationForm) testBinderError(t, nil, MIMEApplicationForm)
e := New() e := New()
req := test.NewRequest(POST, "/", strings.NewReader(userForm)) req, _ := http.NewRequest(POST, "/", strings.NewReader(userForm))
rec := test.NewResponseRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
req.Header().Set(HeaderContentType, MIMEApplicationForm) req.Header.Set(HeaderContentType, MIMEApplicationForm)
var obj = make([]struct{ Field string }, 0) obj := []struct{ Field string }{}
err := c.Bind(&obj) err := c.Bind(&obj)
assert.Error(t, err) assert.Error(t, err)
} }
func TestBinderQueryParams(t *testing.T) { func TestBinderQueryParams(t *testing.T) {
e := New() e := New()
req := test.NewRequest(GET, "/?id=1&name=Jon Snow", nil) req, _ := http.NewRequest(GET, "/?id=1&name=Jon Snow", nil)
rec := test.NewResponseRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
u := new(user) u := new(user)
err := c.Bind(u) err := c.Bind(u)
@ -105,11 +105,6 @@ func TestBinderUnsupportedMediaType(t *testing.T) {
testBinderError(t, strings.NewReader(invalidContent), MIMEApplicationJSON) testBinderError(t, strings.NewReader(invalidContent), MIMEApplicationJSON)
} }
// func assertCustomer(t *testing.T, c *user) {
// assert.Equal(t, 1, c.ID)
// assert.Equal(t, "Joe", c.Name)
// }
func TestBinderbindForm(t *testing.T) { func TestBinderbindForm(t *testing.T) {
ts := new(binderTestStruct) ts := new(binderTestStruct)
b := new(binder) b := new(binder)
@ -201,10 +196,10 @@ func assertBinderTestStruct(t *testing.T, ts *binderTestStruct) {
func testBinderOkay(t *testing.T, r io.Reader, ctype string) { func testBinderOkay(t *testing.T, r io.Reader, ctype string) {
e := New() e := New()
req := test.NewRequest(POST, "/", r) req, _ := http.NewRequest(POST, "/", r)
rec := test.NewResponseRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
req.Header().Set(HeaderContentType, ctype) req.Header.Set(HeaderContentType, ctype)
u := new(user) u := new(user)
err := c.Bind(u) err := c.Bind(u)
if assert.NoError(t, err) { if assert.NoError(t, err) {
@ -215,10 +210,10 @@ func testBinderOkay(t *testing.T, r io.Reader, ctype string) {
func testBinderError(t *testing.T, r io.Reader, ctype string) { func testBinderError(t *testing.T, r io.Reader, ctype string) {
e := New() e := New()
req := test.NewRequest(POST, "/", r) req, _ := http.NewRequest(POST, "/", r)
rec := test.NewResponseRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
req.Header().Set(HeaderContentType, ctype) req.Header.Set(HeaderContentType, ctype)
u := new(user) u := new(user)
err := c.Bind(u) err := c.Bind(u)

View File

@ -7,12 +7,14 @@ import (
"io" "io"
"mime" "mime"
"mime/multipart" "mime/multipart"
"net"
"net/http" "net/http"
"net/url"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"time" "time"
"github.com/labstack/echo/engine"
"github.com/labstack/echo/log" "github.com/labstack/echo/log"
"bytes" "bytes"
@ -32,11 +34,21 @@ type (
// SetStdContext sets `context.Context`. // SetStdContext sets `context.Context`.
SetStdContext(context.Context) SetStdContext(context.Context)
// Request returns `engine.Request` interface. // Request returns `*http.Request`.
Request() engine.Request Request() *http.Request
// Request returns `engine.Response` interface. // Request returns `*Response`.
Response() engine.Response Response() *Response
// IsTLS returns true if HTTP connection is TLS otherwise false.
IsTLS() bool
// Scheme returns the HTTP protocol scheme, `http` or `https`.
Scheme() string
// RealIP returns the client's network address based on `X-Forwarded-For`
// or `X-Real-IP` request header.
RealIP() string
// Path returns the registered path for the handler. // Path returns the registered path for the handler.
Path() string Path() string
@ -62,41 +74,35 @@ type (
// SetParamValues sets path parameter values. // SetParamValues sets path parameter values.
SetParamValues(...string) SetParamValues(...string)
// QueryParam returns the query param for the provided name. It is an alias // QueryParam returns the query param for the provided name.
// for `engine.URL#QueryParam()`.
QueryParam(string) string QueryParam(string) string
// QueryParams returns the query parameters as map. // QueryParams returns the query parameters as `url.Values`.
// It is an alias for `engine.URL#QueryParams()`. QueryParams() url.Values
QueryParams() map[string][]string
// FormValue returns the form field value for the provided name. It is an // QueryString returns the URL query string.
// alias for `engine.Request#FormValue()`. QueryString() string
// FormValue returns the form field value for the provided name.
FormValue(string) string FormValue(string) string
// FormParams returns the form parameters as map. // FormParams returns the form parameters as `url.Values`.
// It is an alias for `engine.Request#FormParams()`. FormParams() (url.Values, error)
FormParams() map[string][]string
// FormFile returns the multipart form file for the provided name. It is an // FormFile returns the multipart form file for the provided name.
// alias for `engine.Request#FormFile()`.
FormFile(string) (*multipart.FileHeader, error) FormFile(string) (*multipart.FileHeader, error)
// MultipartForm returns the multipart form. // MultipartForm returns the multipart form.
// It is an alias for `engine.Request#MultipartForm()`.
MultipartForm() (*multipart.Form, error) MultipartForm() (*multipart.Form, error)
// Cookie returns the named cookie provided in the request. // Cookie returns the named cookie provided in the request.
// It is an alias for `engine.Request#Cookie()`. Cookie(string) (*http.Cookie, error)
Cookie(string) (engine.Cookie, error)
// SetCookie adds a `Set-Cookie` header in HTTP response. // SetCookie adds a `Set-Cookie` header in HTTP response.
// It is an alias for `engine.Response#SetCookie()`. SetCookie(*http.Cookie)
SetCookie(engine.Cookie)
// Cookies returns the HTTP cookies sent with the request. // Cookies returns the HTTP cookies sent with the request.
// It is an alias for `engine.Request#Cookies()`. Cookies() []*http.Cookie
Cookies() []engine.Cookie
// Get retrieves data from the context. // Get retrieves data from the context.
Get(string) interface{} Get(string) interface{}
@ -184,23 +190,25 @@ type (
// Reset resets the context after request completes. It must be called along // Reset resets the context after request completes. It must be called along
// with `Echo#AcquireContext()` and `Echo#ReleaseContext()`. // with `Echo#AcquireContext()` and `Echo#ReleaseContext()`.
// See `Echo#ServeHTTP()` // See `Echo#ServeHTTP()`
Reset(engine.Request, engine.Response) Reset(*http.Request, http.ResponseWriter)
} }
echoContext struct { echoContext struct {
context context.Context context context.Context
request engine.Request request *http.Request
response engine.Response response *Response
path string path string
pnames []string pnames []string
pvalues []string pvalues []string
query url.Values
handler HandlerFunc handler HandlerFunc
echo *Echo echo *Echo
} }
) )
const ( const (
indexPage = "index.html" defaultMemory = 32 << 20 // 32 MB
indexPage = "index.html"
) )
func (c *echoContext) StdContext() context.Context { func (c *echoContext) StdContext() context.Context {
@ -227,14 +235,39 @@ func (c *echoContext) Value(key interface{}) interface{} {
return c.context.Value(key) return c.context.Value(key)
} }
func (c *echoContext) Request() engine.Request { func (c *echoContext) Request() *http.Request {
return c.request return c.request
} }
func (c *echoContext) Response() engine.Response { func (c *echoContext) Response() *Response {
return c.response return c.response
} }
func (c *echoContext) IsTLS() bool {
return c.request.TLS != nil
}
func (c *echoContext) Scheme() string {
// Can't use `r.Request.URL.Scheme`
// See: https://groups.google.com/forum/#!topic/golang-nuts/pMUkBlQBDF0
if c.IsTLS() {
return "https"
}
return "http"
}
func (c *echoContext) RealIP() string {
ra := c.request.RemoteAddr
if ip := c.request.Header.Get(HeaderXForwardedFor); ip != "" {
ra = ip
} else if ip := c.request.Header.Get(HeaderXRealIP); ip != "" {
ra = ip
} else {
ra, _, _ = net.SplitHostPort(ra)
}
return ra
}
func (c *echoContext) Path() string { func (c *echoContext) Path() string {
return c.path return c.path
} }
@ -279,38 +312,59 @@ func (c *echoContext) SetParamValues(values ...string) {
} }
func (c *echoContext) QueryParam(name string) string { func (c *echoContext) QueryParam(name string) string {
return c.request.URL().QueryParam(name) if c.query == nil {
c.query = c.request.URL.Query()
}
return c.query.Get(name)
} }
func (c *echoContext) QueryParams() map[string][]string { func (c *echoContext) QueryParams() url.Values {
return c.request.URL().QueryParams() if c.query == nil {
c.query = c.request.URL.Query()
}
return c.query
}
func (c *echoContext) QueryString() string {
return c.request.URL.RawQuery
} }
func (c *echoContext) FormValue(name string) string { func (c *echoContext) FormValue(name string) string {
return c.request.FormValue(name) return c.request.FormValue(name)
} }
func (c *echoContext) FormParams() map[string][]string { func (c *echoContext) FormParams() (url.Values, error) {
return c.request.FormParams() if strings.HasPrefix(c.request.Header.Get(HeaderContentType), MIMEMultipartForm) {
if err := c.request.ParseMultipartForm(defaultMemory); err != nil {
return nil, err
}
} else {
if err := c.request.ParseForm(); err != nil {
return nil, err
}
}
return c.request.Form, nil
} }
func (c *echoContext) FormFile(name string) (*multipart.FileHeader, error) { func (c *echoContext) FormFile(name string) (*multipart.FileHeader, error) {
return c.request.FormFile(name) _, fh, err := c.request.FormFile(name)
return fh, err
} }
func (c *echoContext) MultipartForm() (*multipart.Form, error) { func (c *echoContext) MultipartForm() (*multipart.Form, error) {
return c.request.MultipartForm() err := c.request.ParseMultipartForm(defaultMemory)
return c.request.MultipartForm, err
} }
func (c *echoContext) Cookie(name string) (engine.Cookie, error) { func (c *echoContext) Cookie(name string) (*http.Cookie, error) {
return c.request.Cookie(name) return c.request.Cookie(name)
} }
func (c *echoContext) SetCookie(cookie engine.Cookie) { func (c *echoContext) SetCookie(cookie *http.Cookie) {
c.response.SetCookie(cookie) http.SetCookie(c.Response(), cookie)
} }
func (c *echoContext) Cookies() []engine.Cookie { func (c *echoContext) Cookies() []*http.Cookie {
return c.request.Cookies() return c.request.Cookies()
} }
@ -323,15 +377,15 @@ func (c *echoContext) Get(key string) interface{} {
} }
func (c *echoContext) Bind(i interface{}) error { func (c *echoContext) Bind(i interface{}) error {
return c.echo.binder.Bind(i, c) return c.echo.Binder.Bind(i, c)
} }
func (c *echoContext) Render(code int, name string, data interface{}) (err error) { func (c *echoContext) Render(code int, name string, data interface{}) (err error) {
if c.echo.renderer == nil { if c.echo.Renderer == nil {
return ErrRendererNotRegistered return ErrRendererNotRegistered
} }
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
if err = c.echo.renderer.Render(buf, name, data, c); err != nil { if err = c.echo.Renderer.Render(buf, name, data, c); err != nil {
return return
} }
c.response.Header().Set(HeaderContentType, MIMETextHTMLCharsetUTF8) c.response.Header().Set(HeaderContentType, MIMETextHTMLCharsetUTF8)
@ -356,7 +410,7 @@ func (c *echoContext) String(code int, s string) (err error) {
func (c *echoContext) JSON(code int, i interface{}) (err error) { func (c *echoContext) JSON(code int, i interface{}) (err error) {
b, err := json.Marshal(i) b, err := json.Marshal(i)
if c.echo.Debug() { if c.echo.Debug {
b, err = json.MarshalIndent(i, "", " ") b, err = json.MarshalIndent(i, "", " ")
} }
if err != nil { if err != nil {
@ -392,7 +446,7 @@ func (c *echoContext) JSONPBlob(code int, callback string, b []byte) (err error)
func (c *echoContext) XML(code int, i interface{}) (err error) { func (c *echoContext) XML(code int, i interface{}) (err error) {
b, err := xml.Marshal(i) b, err := xml.Marshal(i)
if c.echo.Debug() { if c.echo.Debug {
b, err = xml.MarshalIndent(i, "", " ") b, err = xml.MarshalIndent(i, "", " ")
} }
if err != nil { if err != nil {
@ -474,7 +528,7 @@ func (c *echoContext) Redirect(code int, url string) error {
} }
func (c *echoContext) Error(err error) { func (c *echoContext) Error(err error) {
c.echo.httpErrorHandler(err, c) c.echo.HTTPErrorHandler(err, c)
} }
func (c *echoContext) Echo() *Echo { func (c *echoContext) Echo() *Echo {
@ -490,14 +544,14 @@ func (c *echoContext) SetHandler(h HandlerFunc) {
} }
func (c *echoContext) Logger() log.Logger { func (c *echoContext) Logger() log.Logger {
return c.echo.logger return c.echo.Logger
} }
func (c *echoContext) ServeContent(content io.ReadSeeker, name string, modtime time.Time) error { func (c *echoContext) ServeContent(content io.ReadSeeker, name string, modtime time.Time) error {
req := c.Request() req := c.Request()
res := c.Response() res := c.Response()
if t, err := time.Parse(http.TimeFormat, req.Header().Get(HeaderIfModifiedSince)); err == nil && modtime.Before(t.Add(1*time.Second)) { if t, err := time.Parse(http.TimeFormat, req.Header.Get(HeaderIfModifiedSince)); err == nil && modtime.Before(t.Add(1*time.Second)) {
res.Header().Del(HeaderContentType) res.Header().Del(HeaderContentType)
res.Header().Del(HeaderContentLength) res.Header().Del(HeaderContentLength)
return c.NoContent(http.StatusNotModified) return c.NoContent(http.StatusNotModified)
@ -520,9 +574,10 @@ func ContentTypeByExtension(name string) (t string) {
return return
} }
func (c *echoContext) Reset(req engine.Request, res engine.Response) { func (c *echoContext) Reset(r *http.Request, w http.ResponseWriter) {
// c.query = nil
c.context = context.Background() c.context = context.Background()
c.request = req c.request = r
c.response = res c.response.reset(w)
c.handler = NotFoundHandler c.handler = NotFoundHandler
} }

View File

@ -6,6 +6,7 @@ import (
"io" "io"
"mime/multipart" "mime/multipart"
"net/http" "net/http"
"net/http/httptest"
"os" "os"
"testing" "testing"
"text/template" "text/template"
@ -19,7 +20,6 @@ import (
"encoding/xml" "encoding/xml"
"github.com/labstack/echo/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -35,185 +35,182 @@ func (t *Template) Render(w io.Writer, name string, data interface{}, c Context)
func TestContext(t *testing.T) { func TestContext(t *testing.T) {
e := New() e := New()
req := test.NewRequest(POST, "/", strings.NewReader(userJSON)) req, _ := http.NewRequest(POST, "/", strings.NewReader(userJSON))
rec := test.NewResponseRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec).(*echoContext) c := e.NewContext(req, rec).(*echoContext)
// Echo // Echo
assert.Equal(t, e, c.Echo()) assert.Equal(t, e, c.Echo())
// Request // Request
assert.Equal(t, req, c.Request()) assert.NotNil(t, c.Request())
// Response // Response
assert.Equal(t, rec, c.Response()) assert.NotNil(t, c.Response())
// Logger
assert.Equal(t, e.logger, c.Logger())
//-------- //--------
// Render // Render
//-------- //--------
tpl := &Template{ tmpl := &Template{
templates: template.Must(template.New("hello").Parse("Hello, {{.}}!")), templates: template.Must(template.New("hello").Parse("Hello, {{.}}!")),
} }
c.echo.SetRenderer(tpl) c.echo.Renderer = tmpl
err := c.Render(http.StatusOK, "hello", "Jon Snow") err := c.Render(http.StatusOK, "hello", "Jon Snow")
if assert.NoError(t, err) { if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Status()) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "Hello, Jon Snow!", rec.Body.String()) assert.Equal(t, "Hello, Jon Snow!", rec.Body.String())
} }
c.echo.renderer = nil c.echo.Renderer = nil
err = c.Render(http.StatusOK, "hello", "Jon Snow") err = c.Render(http.StatusOK, "hello", "Jon Snow")
assert.Error(t, err) assert.Error(t, err)
// JSON // JSON
rec = test.NewResponseRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*echoContext) c = e.NewContext(req, rec).(*echoContext)
err = c.JSON(http.StatusOK, user{1, "Jon Snow"}) err = c.JSON(http.StatusOK, user{1, "Jon Snow"})
if assert.NoError(t, err) { if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Status()) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(t, userJSON, rec.Body.String()) assert.Equal(t, userJSON, rec.Body.String())
} }
// JSON (error) // JSON (error)
rec = test.NewResponseRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*echoContext) c = e.NewContext(req, rec).(*echoContext)
err = c.JSON(http.StatusOK, make(chan bool)) err = c.JSON(http.StatusOK, make(chan bool))
assert.Error(t, err) assert.Error(t, err)
// JSONP // JSONP
rec = test.NewResponseRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*echoContext) c = e.NewContext(req, rec).(*echoContext)
callback := "callback" callback := "callback"
err = c.JSONP(http.StatusOK, callback, user{1, "Jon Snow"}) err = c.JSONP(http.StatusOK, callback, user{1, "Jon Snow"})
if assert.NoError(t, err) { if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Status()) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType)) assert.Equal(t, MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(t, callback+"("+userJSON+");", rec.Body.String()) assert.Equal(t, callback+"("+userJSON+");", rec.Body.String())
} }
// XML // XML
rec = test.NewResponseRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*echoContext) c = e.NewContext(req, rec).(*echoContext)
err = c.XML(http.StatusOK, user{1, "Jon Snow"}) err = c.XML(http.StatusOK, user{1, "Jon Snow"})
if assert.NoError(t, err) { if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Status()) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(t, xml.Header+userXML, rec.Body.String()) assert.Equal(t, xml.Header+userXML, rec.Body.String())
} }
// XML (error) // XML (error)
rec = test.NewResponseRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*echoContext) c = e.NewContext(req, rec).(*echoContext)
err = c.XML(http.StatusOK, make(chan bool)) err = c.XML(http.StatusOK, make(chan bool))
assert.Error(t, err) assert.Error(t, err)
// String // String
rec = test.NewResponseRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*echoContext) c = e.NewContext(req, rec).(*echoContext)
err = c.String(http.StatusOK, "Hello, World!") err = c.String(http.StatusOK, "Hello, World!")
if assert.NoError(t, err) { if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Status()) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, MIMETextPlainCharsetUTF8, rec.Header().Get(HeaderContentType)) assert.Equal(t, MIMETextPlainCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(t, "Hello, World!", rec.Body.String()) assert.Equal(t, "Hello, World!", rec.Body.String())
} }
// HTML // HTML
rec = test.NewResponseRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*echoContext) c = e.NewContext(req, rec).(*echoContext)
err = c.HTML(http.StatusOK, "Hello, <strong>World!</strong>") err = c.HTML(http.StatusOK, "Hello, <strong>World!</strong>")
if assert.NoError(t, err) { if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Status()) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, MIMETextHTMLCharsetUTF8, rec.Header().Get(HeaderContentType)) assert.Equal(t, MIMETextHTMLCharsetUTF8, rec.Header().Get(HeaderContentType))
assert.Equal(t, "Hello, <strong>World!</strong>", rec.Body.String()) assert.Equal(t, "Hello, <strong>World!</strong>", rec.Body.String())
} }
// Stream // Stream
rec = test.NewResponseRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*echoContext) c = e.NewContext(req, rec).(*echoContext)
r := strings.NewReader("response from a stream") r := strings.NewReader("response from a stream")
err = c.Stream(http.StatusOK, "application/octet-stream", r) err = c.Stream(http.StatusOK, "application/octet-stream", r)
if assert.NoError(t, err) { if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Status()) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "application/octet-stream", rec.Header().Get(HeaderContentType)) assert.Equal(t, "application/octet-stream", rec.Header().Get(HeaderContentType))
assert.Equal(t, "response from a stream", rec.Body.String()) assert.Equal(t, "response from a stream", rec.Body.String())
} }
// Attachment // Attachment
rec = test.NewResponseRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*echoContext) c = e.NewContext(req, rec).(*echoContext)
file, err := os.Open("_fixture/images/walle.png") file, err := os.Open("_fixture/images/walle.png")
if assert.NoError(t, err) { if assert.NoError(t, err) {
err = c.Attachment(file, "walle.png") err = c.Attachment(file, "walle.png")
if assert.NoError(t, err) { if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Status()) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "attachment; filename=walle.png", rec.Header().Get(HeaderContentDisposition)) assert.Equal(t, "attachment; filename=walle.png", rec.Header().Get(HeaderContentDisposition))
assert.Equal(t, 219885, rec.Body.Len()) assert.Equal(t, 219885, rec.Body.Len())
} }
} }
// Inline // Inline
rec = test.NewResponseRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*echoContext) c = e.NewContext(req, rec).(*echoContext)
file, err = os.Open("_fixture/images/walle.png") file, err = os.Open("_fixture/images/walle.png")
if assert.NoError(t, err) { if assert.NoError(t, err) {
err = c.Inline(file, "walle.png") err = c.Inline(file, "walle.png")
if assert.NoError(t, err) { if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Status()) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "inline; filename=walle.png", rec.Header().Get(HeaderContentDisposition)) assert.Equal(t, "inline; filename=walle.png", rec.Header().Get(HeaderContentDisposition))
assert.Equal(t, 219885, rec.Body.Len()) assert.Equal(t, 219885, rec.Body.Len())
} }
} }
// NoContent // NoContent
rec = test.NewResponseRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*echoContext) c = e.NewContext(req, rec).(*echoContext)
c.NoContent(http.StatusOK) c.NoContent(http.StatusOK)
assert.Equal(t, http.StatusOK, rec.Status()) assert.Equal(t, http.StatusOK, rec.Code)
// Error // Error
rec = test.NewResponseRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec).(*echoContext) c = e.NewContext(req, rec).(*echoContext)
c.Error(errors.New("error")) c.Error(errors.New("error"))
assert.Equal(t, http.StatusInternalServerError, rec.Status()) assert.Equal(t, http.StatusInternalServerError, rec.Code)
// Reset // Reset
c.Reset(req, test.NewResponseRecorder()) c.Reset(req, httptest.NewRecorder())
} }
func TestContextCookie(t *testing.T) { func TestContextCookie(t *testing.T) {
e := New() e := New()
req := test.NewRequest(GET, "/", nil) req, _ := http.NewRequest(GET, "/", nil)
theme := "theme=light" theme := "theme=light"
user := "user=Jon Snow" user := "user=Jon Snow"
req.Header().Add(HeaderCookie, theme) req.Header.Add(HeaderCookie, theme)
req.Header().Add(HeaderCookie, user) req.Header.Add(HeaderCookie, user)
rec := test.NewResponseRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec).(*echoContext) c := e.NewContext(req, rec).(*echoContext)
// Read single // Read single
cookie, err := c.Cookie("theme") cookie, err := c.Cookie("theme")
if assert.NoError(t, err) { if assert.NoError(t, err) {
assert.Equal(t, "theme", cookie.Name()) assert.Equal(t, "theme", cookie.Name)
assert.Equal(t, "light", cookie.Value()) assert.Equal(t, "light", cookie.Value)
} }
// Read multiple // Read multiple
for _, cookie := range c.Cookies() { for _, cookie := range c.Cookies() {
switch cookie.Name() { switch cookie.Name {
case "theme": case "theme":
assert.Equal(t, "light", cookie.Value()) assert.Equal(t, "light", cookie.Value)
case "user": case "user":
assert.Equal(t, "Jon Snow", cookie.Value()) assert.Equal(t, "Jon Snow", cookie.Value)
} }
} }
// Write // Write
cookie = &test.Cookie{Cookie: &http.Cookie{ cookie = &http.Cookie{
Name: "SSID", Name: "SSID",
Value: "Ap4PGTEq", Value: "Ap4PGTEq",
Domain: "labstack.com", Domain: "labstack.com",
@ -221,7 +218,7 @@ func TestContextCookie(t *testing.T) {
Expires: time.Now(), Expires: time.Now(),
Secure: true, Secure: true,
HttpOnly: true, HttpOnly: true,
}} }
c.SetCookie(cookie) c.SetCookie(cookie)
assert.Contains(t, rec.Header().Get(HeaderSetCookie), "SSID") assert.Contains(t, rec.Header().Get(HeaderSetCookie), "SSID")
assert.Contains(t, rec.Header().Get(HeaderSetCookie), "Ap4PGTEq") assert.Contains(t, rec.Header().Get(HeaderSetCookie), "Ap4PGTEq")
@ -247,7 +244,7 @@ func TestContextPath(t *testing.T) {
func TestContextPathParam(t *testing.T) { func TestContextPathParam(t *testing.T) {
e := New() e := New()
req := test.NewRequest(GET, "/", nil) req, _ := http.NewRequest(GET, "/", nil)
c := e.NewContext(req, nil) c := e.NewContext(req, nil)
// ParamNames // ParamNames
@ -271,8 +268,8 @@ func TestContextFormValue(t *testing.T) {
f.Set("email", "jon@labstack.com") f.Set("email", "jon@labstack.com")
e := New() e := New()
req := test.NewRequest(POST, "/", strings.NewReader(f.Encode())) req, _ := http.NewRequest(POST, "/", strings.NewReader(f.Encode()))
req.Header().Add(HeaderContentType, MIMEApplicationForm) req.Header.Add(HeaderContentType, MIMEApplicationForm)
c := e.NewContext(req, nil) c := e.NewContext(req, nil)
// FormValue // FormValue
@ -280,17 +277,20 @@ func TestContextFormValue(t *testing.T) {
assert.Equal(t, "jon@labstack.com", c.FormValue("email")) assert.Equal(t, "jon@labstack.com", c.FormValue("email"))
// FormParams // FormParams
assert.Equal(t, map[string][]string{ params, err := c.FormParams()
"name": []string{"Jon Snow"}, if assert.NoError(t, err) {
"email": []string{"jon@labstack.com"}, assert.Equal(t, url.Values{
}, c.FormParams()) "name": []string{"Jon Snow"},
"email": []string{"jon@labstack.com"},
}, params)
}
} }
func TestContextQueryParam(t *testing.T) { func TestContextQueryParam(t *testing.T) {
q := make(url.Values) q := make(url.Values)
q.Set("name", "Jon Snow") q.Set("name", "Jon Snow")
q.Set("email", "jon@labstack.com") q.Set("email", "jon@labstack.com")
req := test.NewRequest(GET, "/?"+q.Encode(), nil) req, _ := http.NewRequest(GET, "/?"+q.Encode(), nil)
e := New() e := New()
c := e.NewContext(req, nil) c := e.NewContext(req, nil)
@ -299,7 +299,7 @@ func TestContextQueryParam(t *testing.T) {
assert.Equal(t, "jon@labstack.com", c.QueryParam("email")) assert.Equal(t, "jon@labstack.com", c.QueryParam("email"))
// QueryParams // QueryParams
assert.Equal(t, map[string][]string{ assert.Equal(t, url.Values{
"name": []string{"Jon Snow"}, "name": []string{"Jon Snow"},
"email": []string{"jon@labstack.com"}, "email": []string{"jon@labstack.com"},
}, c.QueryParams()) }, c.QueryParams())
@ -314,9 +314,9 @@ func TestContextFormFile(t *testing.T) {
w.Write([]byte("test")) w.Write([]byte("test"))
} }
mr.Close() mr.Close()
req := test.NewRequest(POST, "/", buf) req, _ := http.NewRequest(POST, "/", buf)
req.Header().Set(HeaderContentType, mr.FormDataContentType()) req.Header.Set(HeaderContentType, mr.FormDataContentType())
rec := test.NewResponseRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
f, err := c.FormFile("file") f, err := c.FormFile("file")
if assert.NoError(t, err) { if assert.NoError(t, err) {
@ -330,9 +330,9 @@ func TestContextMultipartForm(t *testing.T) {
mw := multipart.NewWriter(buf) mw := multipart.NewWriter(buf)
mw.WriteField("name", "Jon Snow") mw.WriteField("name", "Jon Snow")
mw.Close() mw.Close()
req := test.NewRequest(POST, "/", buf) req, _ := http.NewRequest(POST, "/", buf)
req.Header().Set(HeaderContentType, mw.FormDataContentType()) req.Header.Set(HeaderContentType, mw.FormDataContentType())
rec := test.NewResponseRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
f, err := c.MultipartForm() f, err := c.MultipartForm()
if assert.NoError(t, err) { if assert.NoError(t, err) {
@ -342,11 +342,11 @@ func TestContextMultipartForm(t *testing.T) {
func TestContextRedirect(t *testing.T) { func TestContextRedirect(t *testing.T) {
e := New() e := New()
req := test.NewRequest(GET, "/", nil) req, _ := http.NewRequest(GET, "/", nil)
rec := test.NewResponseRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
assert.Equal(t, nil, c.Redirect(http.StatusMovedPermanently, "http://labstack.github.io/echo")) assert.Equal(t, nil, c.Redirect(http.StatusMovedPermanently, "http://labstack.github.io/echo"))
assert.Equal(t, http.StatusMovedPermanently, rec.Status()) assert.Equal(t, http.StatusMovedPermanently, rec.Code)
assert.Equal(t, "http://labstack.github.io/echo", rec.Header().Get(HeaderLocation)) assert.Equal(t, "http://labstack.github.io/echo", rec.Header().Get(HeaderLocation))
assert.Error(t, c.Redirect(310, "http://labstack.github.io/echo")) assert.Error(t, c.Redirect(310, "http://labstack.github.io/echo"))
} }
@ -374,8 +374,8 @@ func TestContextStore(t *testing.T) {
func TestContextServeContent(t *testing.T) { func TestContextServeContent(t *testing.T) {
e := New() e := New()
req := test.NewRequest(GET, "/", nil) req, _ := http.NewRequest(GET, "/", nil)
rec := test.NewResponseRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
fs := http.Dir("_fixture/images") fs := http.Dir("_fixture/images")
@ -385,15 +385,15 @@ func TestContextServeContent(t *testing.T) {
if assert.NoError(t, err) { if assert.NoError(t, err) {
// Not cached // Not cached
if assert.NoError(t, c.ServeContent(f, fi.Name(), fi.ModTime())) { if assert.NoError(t, c.ServeContent(f, fi.Name(), fi.ModTime())) {
assert.Equal(t, http.StatusOK, rec.Status()) assert.Equal(t, http.StatusOK, rec.Code)
} }
// Cached // Cached
rec = test.NewResponseRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec) c = e.NewContext(req, rec)
req.Header().Set(HeaderIfModifiedSince, fi.ModTime().UTC().Format(http.TimeFormat)) req.Header.Set(HeaderIfModifiedSince, fi.ModTime().UTC().Format(http.TimeFormat))
if assert.NoError(t, c.ServeContent(f, fi.Name(), fi.ModTime())) { if assert.NoError(t, c.ServeContent(f, fi.Name(), fi.ModTime())) {
assert.Equal(t, http.StatusNotModified, rec.Status()) assert.Equal(t, http.StatusNotModified, rec.Code)
} }
} }
} }

218
echo.go
View File

@ -39,18 +39,20 @@ package echo
import ( import (
"bytes" "bytes"
"crypto/tls"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net"
"net/http" "net/http"
"path" "path"
"reflect" "reflect"
"runtime" "runtime"
"sync" "sync"
"time"
"golang.org/x/net/context" "golang.org/x/net/context"
"github.com/labstack/echo/engine"
"github.com/labstack/echo/log" "github.com/labstack/echo/log"
glog "github.com/labstack/gommon/log" glog "github.com/labstack/gommon/log"
) )
@ -58,18 +60,24 @@ import (
type ( type (
// Echo is the top-level framework instance. // Echo is the top-level framework instance.
Echo struct { Echo struct {
server engine.Server Server *http.Server
premiddleware []MiddlewareFunc TLSCertFile string
middleware []MiddlewareFunc TLSKeyFile string
maxParam *int Listener net.Listener
notFoundHandler HandlerFunc // DisableHTTP2 disables HTTP2
httpErrorHandler HTTPErrorHandler DisableHTTP2 bool
binder Binder // Debug mode
renderer Renderer Debug bool
pool sync.Pool HTTPErrorHandler
debug bool Binder Binder
router *Router Renderer Renderer
logger log.Logger Logger log.Logger
premiddleware []MiddlewareFunc
middleware []MiddlewareFunc
maxParam *int
notFoundHandler HandlerFunc
pool sync.Pool
router *Router
} }
// Route contains a handler and information for matching against requests. // Route contains a handler and information for matching against requests.
@ -226,22 +234,21 @@ func New() (e *Echo) {
return e.NewContext(nil, nil) return e.NewContext(nil, nil)
} }
e.router = NewRouter(e) e.router = NewRouter(e)
// Defaults // Defaults
e.SetHTTPErrorHandler(e.DefaultHTTPErrorHandler) e.HTTPErrorHandler = e.DefaultHTTPErrorHandler
e.SetBinder(&binder{}) e.Binder = &binder{}
l := glog.New("echo") l := glog.New("echo")
l.SetLevel(glog.OFF) l.SetLevel(glog.OFF)
e.SetLogger(l) e.Logger = l
return return
} }
// NewContext returns a Context instance. // NewContext returns a Context instance.
func (e *Echo) NewContext(req engine.Request, res engine.Response) Context { func (e *Echo) NewContext(r *http.Request, w http.ResponseWriter) Context {
return &echoContext{ return &echoContext{
context: context.Background(), context: context.Background(),
request: req, request: r,
response: res, response: NewResponse(w, e),
echo: e, echo: e,
pvalues: make([]string, *e.maxParam), pvalues: make([]string, *e.maxParam),
handler: NotFoundHandler, handler: NotFoundHandler,
@ -253,26 +260,6 @@ func (e *Echo) Router() *Router {
return e.router return e.router
} }
// Logger returns the logger instance.
func (e *Echo) Logger() log.Logger {
return e.logger
}
// SetLogger defines a custom logger.
func (e *Echo) SetLogger(l log.Logger) {
e.logger = l
}
// SetLogOutput sets the output destination for the logger. Default value is `os.Std*`
func (e *Echo) SetLogOutput(w io.Writer) {
e.logger.SetOutput(w)
}
// SetLogLevel sets the log level for the logger. Default value ERROR.
func (e *Echo) SetLogLevel(l glog.Lvl) {
e.logger.SetLevel(l)
}
// DefaultHTTPErrorHandler invokes the default HTTP error handler. // DefaultHTTPErrorHandler invokes the default HTTP error handler.
func (e *Echo) DefaultHTTPErrorHandler(err error, c Context) { func (e *Echo) DefaultHTTPErrorHandler(err error, c Context) {
code := http.StatusInternalServerError code := http.StatusInternalServerError
@ -281,47 +268,17 @@ func (e *Echo) DefaultHTTPErrorHandler(err error, c Context) {
code = he.Code code = he.Code
msg = he.Message msg = he.Message
} }
if e.debug { if e.Debug {
msg = err.Error() msg = err.Error()
} }
if !c.Response().Committed() { if !c.Response().Committed {
if c.Request().Method() == HEAD { // Issue #608 if c.Request().Method == HEAD { // Issue #608
c.NoContent(code) c.NoContent(code)
} else { } else {
c.String(code, msg) c.String(code, msg)
} }
} }
e.logger.Error(err) e.Logger.Error(err)
}
// SetHTTPErrorHandler registers a custom Echo.HTTPErrorHandler.
func (e *Echo) SetHTTPErrorHandler(h HTTPErrorHandler) {
e.httpErrorHandler = h
}
// SetBinder registers a custom binder. It's invoked by `Context#Bind()`.
func (e *Echo) SetBinder(b Binder) {
e.binder = b
}
// Binder returns the binder instance.
func (e *Echo) Binder() Binder {
return e.binder
}
// SetRenderer registers an HTML template renderer. It's invoked by `Context#Render()`.
func (e *Echo) SetRenderer(r Renderer) {
e.renderer = r
}
// SetDebug enables/disables debug mode.
func (e *Echo) SetDebug(on bool) {
e.debug = on
}
// Debug returns debug mode (enabled or disabled).
func (e *Echo) Debug() bool {
return e.debug
} }
// Pre adds middleware to the chain which is run before router. // Pre adds middleware to the chain which is run before router.
@ -340,99 +297,54 @@ func (e *Echo) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) {
e.add(CONNECT, path, h, m...) e.add(CONNECT, path, h, m...)
} }
// Connect is deprecated, use `CONNECT()` instead.
func (e *Echo) Connect(path string, h HandlerFunc, m ...MiddlewareFunc) {
e.CONNECT(path, h, m...)
}
// DELETE registers a new DELETE route for a path with matching handler in the router // DELETE registers a new DELETE route for a path with matching handler in the router
// with optional route-level middleware. // with optional route-level middleware.
func (e *Echo) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) { func (e *Echo) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) {
e.add(DELETE, path, h, m...) e.add(DELETE, path, h, m...)
} }
// Delete is deprecated, use `DELETE()` instead.
func (e *Echo) Delete(path string, h HandlerFunc, m ...MiddlewareFunc) {
e.DELETE(path, h, m...)
}
// GET registers a new GET route for a path with matching handler in the router // GET registers a new GET route for a path with matching handler in the router
// with optional route-level middleware. // with optional route-level middleware.
func (e *Echo) GET(path string, h HandlerFunc, m ...MiddlewareFunc) { func (e *Echo) GET(path string, h HandlerFunc, m ...MiddlewareFunc) {
e.add(GET, path, h, m...) e.add(GET, path, h, m...)
} }
// Get is deprecated, use `GET()` instead.
func (e *Echo) Get(path string, h HandlerFunc, m ...MiddlewareFunc) {
e.GET(path, h, m...)
}
// HEAD registers a new HEAD route for a path with matching handler in the // HEAD registers a new HEAD route for a path with matching handler in the
// router with optional route-level middleware. // router with optional route-level middleware.
func (e *Echo) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) { func (e *Echo) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) {
e.add(HEAD, path, h, m...) e.add(HEAD, path, h, m...)
} }
// Head is deprecated, use `HEAD()` instead.
func (e *Echo) Head(path string, h HandlerFunc, m ...MiddlewareFunc) {
e.HEAD(path, h, m...)
}
// OPTIONS registers a new OPTIONS route for a path with matching handler in the // OPTIONS registers a new OPTIONS route for a path with matching handler in the
// router with optional route-level middleware. // router with optional route-level middleware.
func (e *Echo) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) { func (e *Echo) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) {
e.add(OPTIONS, path, h, m...) e.add(OPTIONS, path, h, m...)
} }
// Options is deprecated, use `OPTIONS()` instead.
func (e *Echo) Options(path string, h HandlerFunc, m ...MiddlewareFunc) {
e.OPTIONS(path, h, m...)
}
// PATCH registers a new PATCH route for a path with matching handler in the // PATCH registers a new PATCH route for a path with matching handler in the
// router with optional route-level middleware. // router with optional route-level middleware.
func (e *Echo) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) { func (e *Echo) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) {
e.add(PATCH, path, h, m...) e.add(PATCH, path, h, m...)
} }
// Patch is deprecated, use `PATCH()` instead.
func (e *Echo) Patch(path string, h HandlerFunc, m ...MiddlewareFunc) {
e.PATCH(path, h, m...)
}
// POST registers a new POST route for a path with matching handler in the // POST registers a new POST route for a path with matching handler in the
// router with optional route-level middleware. // router with optional route-level middleware.
func (e *Echo) POST(path string, h HandlerFunc, m ...MiddlewareFunc) { func (e *Echo) POST(path string, h HandlerFunc, m ...MiddlewareFunc) {
e.add(POST, path, h, m...) e.add(POST, path, h, m...)
} }
// Post is deprecated, use `POST()` instead.
func (e *Echo) Post(path string, h HandlerFunc, m ...MiddlewareFunc) {
e.POST(path, h, m...)
}
// PUT registers a new PUT route for a path with matching handler in the // PUT registers a new PUT route for a path with matching handler in the
// router with optional route-level middleware. // router with optional route-level middleware.
func (e *Echo) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) { func (e *Echo) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) {
e.add(PUT, path, h, m...) e.add(PUT, path, h, m...)
} }
// Put is deprecated, use `PUT()` instead.
func (e *Echo) Put(path string, h HandlerFunc, m ...MiddlewareFunc) {
e.PUT(path, h, m...)
}
// TRACE registers a new TRACE route for a path with matching handler in the // TRACE registers a new TRACE route for a path with matching handler in the
// router with optional route-level middleware. // router with optional route-level middleware.
func (e *Echo) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) { func (e *Echo) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) {
e.add(TRACE, path, h, m...) e.add(TRACE, path, h, m...)
} }
// Trace is deprecated, use `TRACE()` instead.
func (e *Echo) Trace(path string, h HandlerFunc, m ...MiddlewareFunc) {
e.TRACE(path, h, m...)
}
// Any registers a new route for all HTTP methods and path with matching handler // Any registers a new route for all HTTP methods and path with matching handler
// in the router with optional route-level middleware. // in the router with optional route-level middleware.
func (e *Echo) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) { func (e *Echo) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) {
@ -540,14 +452,15 @@ func (e *Echo) ReleaseContext(c Context) {
e.pool.Put(c) e.pool.Put(c)
} }
func (e *Echo) ServeHTTP(req engine.Request, res engine.Response) { // ServeHTTP implements `http.Handler` interface, which serves HTTP requests.
func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) {
c := e.pool.Get().(*echoContext) c := e.pool.Get().(*echoContext)
c.Reset(req, res) c.Reset(r, w)
// Middleware // Middleware
h := func(Context) error { h := func(Context) error {
method := req.Method() method := r.Method
path := req.URL().Path() path := r.URL.Path
e.router.Find(method, path, c) e.router.Find(method, path, c)
h := c.handler h := c.handler
for i := len(e.middleware) - 1; i >= 0; i-- { for i := len(e.middleware) - 1; i >= 0; i-- {
@ -563,27 +476,64 @@ func (e *Echo) ServeHTTP(req engine.Request, res engine.Response) {
// Execute chain // Execute chain
if err := h(c); err != nil { if err := h(c); err != nil {
e.httpErrorHandler(err, c) e.HTTPErrorHandler(err, c)
} }
e.pool.Put(c) e.pool.Put(c)
} }
// Run starts the HTTP server. // Run starts the HTTP server.
func (e *Echo) Run(s engine.Server) error { func (e *Echo) Run(address string) (err error) {
e.server = s if e.Server == nil {
s.SetHandler(e) e.Server = &http.Server{Handler: e}
s.SetLogger(e.logger)
if e.Debug() {
e.SetLogLevel(glog.DEBUG)
e.logger.Debug("running in debug mode")
} }
return s.Start() if e.Listener == nil {
e.Listener, err = net.Listen("tcp", address)
if err != nil {
return
}
if e.TLSCertFile != "" && e.TLSKeyFile != "" {
// TODO: https://github.com/golang/go/commit/d24f446a90ea94b87591bf16228d7d871fec3d92
config := &tls.Config{
NextProtos: []string{"http/1.1"},
}
if !e.DisableHTTP2 {
config.NextProtos = append(config.NextProtos, "h2")
}
config.Certificates = make([]tls.Certificate, 1)
config.Certificates[0], err = tls.LoadX509KeyPair(e.TLSCertFile, e.TLSKeyFile)
if err != nil {
return
}
e.Listener = tls.NewListener(tcpKeepAliveListener{e.Listener.(*net.TCPListener)}, config)
} else {
e.Listener = tcpKeepAliveListener{e.Listener.(*net.TCPListener)}
}
}
return e.Server.Serve(e.Listener)
} }
// Stop stops the HTTP server. // Stop stops the HTTP server
func (e *Echo) Stop() error { func (e *Echo) Stop() error {
return e.server.Stop() return e.Listener.Close()
}
// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
// connections. It's used by ListenAndServe and ListenAndServeTLS so
// dead TCP connections (e.g. closing laptop mid-download) eventually
// go away.
type tcpKeepAliveListener struct {
*net.TCPListener
}
func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
tc, err := ln.AcceptTCP()
if err != nil {
return
}
tc.SetKeepAlive(true)
tc.SetKeepAlivePeriod(3 * time.Minute)
return tc, nil
} }
// NewHTTPError creates a new HTTPError instance. // NewHTTPError creates a new HTTPError instance.

View File

@ -2,9 +2,8 @@ package echo
import ( import (
"bytes" "bytes"
"fmt"
"io/ioutil"
"net/http" "net/http"
"net/http/httptest"
"testing" "testing"
"reflect" "reflect"
@ -12,8 +11,6 @@ import (
"errors" "errors"
"github.com/labstack/echo/test"
"github.com/labstack/gommon/log"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -33,20 +30,16 @@ const (
func TestEcho(t *testing.T) { func TestEcho(t *testing.T) {
e := New() e := New()
req := test.NewRequest(GET, "/", nil) req, _ := http.NewRequest(GET, "/", nil)
rec := test.NewResponseRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
// Router // Router
assert.NotNil(t, e.Router()) assert.NotNil(t, e.Router())
// Debug
e.SetDebug(true)
assert.True(t, e.debug)
// DefaultHTTPErrorHandler // DefaultHTTPErrorHandler
e.DefaultHTTPErrorHandler(errors.New("error"), c) e.DefaultHTTPErrorHandler(errors.New("error"), c)
assert.Equal(t, http.StatusInternalServerError, rec.Status()) assert.Equal(t, http.StatusInternalServerError, rec.Code)
} }
func TestEchoStatic(t *testing.T) { func TestEchoStatic(t *testing.T) {
@ -306,10 +299,10 @@ func TestEchoGroup(t *testing.T) {
func TestEchoNotFound(t *testing.T) { func TestEchoNotFound(t *testing.T) {
e := New() e := New()
req := test.NewRequest(GET, "/files", nil) req, _ := http.NewRequest(GET, "/files", nil)
rec := test.NewResponseRecorder() rec := httptest.NewRecorder()
e.ServeHTTP(req, rec) e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusNotFound, rec.Status()) assert.Equal(t, http.StatusNotFound, rec.Code)
} }
func TestEchoMethodNotAllowed(t *testing.T) { func TestEchoMethodNotAllowed(t *testing.T) {
@ -317,10 +310,10 @@ func TestEchoMethodNotAllowed(t *testing.T) {
e.GET("/", func(c Context) error { e.GET("/", func(c Context) error {
return c.String(http.StatusOK, "Echo!") return c.String(http.StatusOK, "Echo!")
}) })
req := test.NewRequest(POST, "/", nil) req, _ := http.NewRequest(POST, "/", nil)
rec := test.NewResponseRecorder() rec := httptest.NewRecorder()
e.ServeHTTP(req, rec) e.ServeHTTP(rec, req)
assert.Equal(t, http.StatusMethodNotAllowed, rec.Status()) assert.Equal(t, http.StatusMethodNotAllowed, rec.Code)
} }
func TestEchoHTTPError(t *testing.T) { func TestEchoHTTPError(t *testing.T) {
@ -337,25 +330,13 @@ func TestEchoContext(t *testing.T) {
e.ReleaseContext(c) e.ReleaseContext(c)
} }
func TestEchoLogger(t *testing.T) {
e := New()
l := log.New("test")
e.SetLogger(l)
assert.Equal(t, l, e.Logger())
e.SetLogOutput(ioutil.Discard)
assert.Equal(t, l.Output(), ioutil.Discard)
e.SetLogLevel(log.OFF)
assert.Equal(t, l.Level(), log.OFF)
}
func testMethod(t *testing.T, method, path string, e *Echo) { func testMethod(t *testing.T, method, path string, e *Echo) {
m := fmt.Sprintf("%c%s", method[0], strings.ToLower(method[1:]))
p := reflect.ValueOf(path) p := reflect.ValueOf(path)
h := reflect.ValueOf(func(c Context) error { h := reflect.ValueOf(func(c Context) error {
return c.String(http.StatusOK, method) return c.String(http.StatusOK, method)
}) })
i := interface{}(e) i := interface{}(e)
reflect.ValueOf(i).MethodByName(m).Call([]reflect.Value{p, h}) reflect.ValueOf(i).MethodByName(method).Call([]reflect.Value{p, h})
_, body := request(method, path, e) _, body := request(method, path, e)
if body != method { if body != method {
t.Errorf("expected body `%s`, got %s.", method, body) t.Errorf("expected body `%s`, got %s.", method, body)
@ -363,15 +344,8 @@ func testMethod(t *testing.T, method, path string, e *Echo) {
} }
func request(method, path string, e *Echo) (int, string) { func request(method, path string, e *Echo) (int, string) {
req := test.NewRequest(method, path, nil) req, _ := http.NewRequest(method, path, nil)
rec := test.NewResponseRecorder() rec := httptest.NewRecorder()
e.ServeHTTP(req, rec) e.ServeHTTP(rec, req)
return rec.Status(), rec.Body.String() return rec.Code, rec.Body.String()
}
func TestEchoBinder(t *testing.T) {
e := New()
b := &binder{}
e.SetBinder(b)
assert.Equal(t, b, e.Binder())
} }

View File

@ -1,233 +0,0 @@
package engine
import (
"io"
"mime/multipart"
"time"
"net"
"github.com/labstack/echo/log"
)
type (
// Server defines the interface for HTTP server.
Server interface {
// SetHandler sets the handler for the HTTP server.
SetHandler(Handler)
// SetLogger sets the logger for the HTTP server.
SetLogger(log.Logger)
// Start starts the HTTP server.
Start() error
// Stop stops the HTTP server by closing underlying TCP connection.
Stop() error
}
// Request defines the interface for HTTP request.
Request interface {
// IsTLS returns true if HTTP connection is TLS otherwise false.
IsTLS() bool
// Scheme returns the HTTP protocol scheme, `http` or `https`.
Scheme() string
// Host returns HTTP request host. Per RFC 2616, this is either the value of
// the `Host` header or the host name given in the URL itself.
Host() string
// SetHost sets the host of the request.
SetHost(string)
// URI returns the unmodified `Request-URI` sent by the client.
URI() string
// SetURI sets the URI of the request.
SetURI(string)
// URL returns `engine.URL`.
URL() URL
// Header returns `engine.Header`.
Header() Header
// Referer returns the referring URL, if sent in the request.
Referer() string
// Protocol returns the protocol version string of the HTTP request.
// Protocol() string
// ProtocolMajor returns the major protocol version of the HTTP request.
// ProtocolMajor() int
// ProtocolMinor returns the minor protocol version of the HTTP request.
// ProtocolMinor() int
// ContentLength returns the size of request's body.
ContentLength() int64
// UserAgent returns the client's `User-Agent`.
UserAgent() string
// RemoteAddress returns the client's network address.
RemoteAddress() string
// RealIP returns the client's network address based on `X-Forwarded-For`
// or `X-Real-IP` request header.
RealIP() string
// Method returns the request's HTTP function.
Method() string
// SetMethod sets the HTTP method of the request.
SetMethod(string)
// Body returns request's body.
Body() io.Reader
// Body sets request's body.
SetBody(io.Reader)
// FormValue returns the form field value for the provided name.
FormValue(string) string
// FormParams returns the form parameters.
FormParams() map[string][]string
// FormFile returns the multipart form file for the provided name.
FormFile(string) (*multipart.FileHeader, error)
// MultipartForm returns the multipart form.
MultipartForm() (*multipart.Form, error)
// Cookie returns the named cookie provided in the request.
Cookie(string) (Cookie, error)
// Cookies returns the HTTP cookies sent with the request.
Cookies() []Cookie
}
// Response defines the interface for HTTP response.
Response interface {
// Header returns `engine.Header`
Header() Header
// WriteHeader sends an HTTP response header with status code.
WriteHeader(int)
// Write writes the data to the connection as part of an HTTP reply.
Write(b []byte) (int, error)
// SetCookie adds a `Set-Cookie` header in HTTP response.
SetCookie(Cookie)
// Status returns the HTTP response status.
Status() int
// Size returns the number of bytes written to HTTP response.
Size() int64
// Committed returns true if HTTP response header is written, otherwise false.
Committed() bool
// Write returns the HTTP response writer.
Writer() io.Writer
// SetWriter sets the HTTP response writer.
SetWriter(io.Writer)
}
// Header defines the interface for HTTP header.
Header interface {
// Add adds the key, value pair to the header. It appends to any existing values
// associated with key.
Add(string, string)
// Del deletes the values associated with key.
Del(string)
// Set sets the header entries associated with key to the single element value.
// It replaces any existing values associated with key.
Set(string, string)
// Get gets the first value associated with the given key. If there are
// no values associated with the key, Get returns "".
Get(string) string
// Keys returns the header keys.
Keys() []string
// Contains checks if the header is set.
Contains(string) bool
}
// URL defines the interface for HTTP request url.
URL interface {
// Path returns the request URL path.
Path() string
// SetPath sets the request URL path.
SetPath(string)
// QueryParam returns the query param for the provided name.
QueryParam(string) string
// QueryParam returns the query parameters as map.
QueryParams() map[string][]string
// QueryString returns the URL query string.
QueryString() string
}
// Cookie defines the interface for HTTP cookie.
Cookie interface {
// Name returns the name of the cookie.
Name() string
// Value returns the value of the cookie.
Value() string
// Path returns the path of the cookie.
Path() string
// Domain returns the domain of the cookie.
Domain() string
// Expires returns the expiry time of the cookie.
Expires() time.Time
// Secure indicates if cookie is secured.
Secure() bool
// HTTPOnly indicate if cookies is HTTP only.
HTTPOnly() bool
}
// Config defines engine config.
Config struct {
Address string // TCP address to listen on.
Listener net.Listener // Custom `net.Listener`. If set, server accepts connections on it.
TLSCertFile string // TLS certificate file path.
TLSKeyFile string // TLS key file path.
DisableHTTP2 bool // Disables HTTP/2.
ReadTimeout time.Duration // Maximum duration before timing out read of the request.
WriteTimeout time.Duration // Maximum duration before timing out write of the response.
}
// Handler defines an interface to server HTTP requests via `ServeHTTP(Request, Response)`
// function.
Handler interface {
ServeHTTP(Request, Response)
}
// HandlerFunc is an adapter to allow the use of `func(Request, Response)` as
// an HTTP handler.
HandlerFunc func(Request, Response)
)
// ServeHTTP serves HTTP request.
func (h HandlerFunc) ServeHTTP(req Request, res Response) {
h(req, res)
}

View File

@ -1,49 +0,0 @@
package fasthttp
import (
"time"
"github.com/valyala/fasthttp"
)
type (
// Cookie implements `engine.Cookie`.
Cookie struct {
*fasthttp.Cookie
}
)
// Name implements `engine.Cookie#Name` function.
func (c *Cookie) Name() string {
return string(c.Cookie.Key())
}
// Value implements `engine.Cookie#Value` function.
func (c *Cookie) Value() string {
return string(c.Cookie.Value())
}
// Path implements `engine.Cookie#Path` function.
func (c *Cookie) Path() string {
return string(c.Cookie.Path())
}
// Domain implements `engine.Cookie#Domain` function.
func (c *Cookie) Domain() string {
return string(c.Cookie.Domain())
}
// Expires implements `engine.Cookie#Expires` function.
func (c *Cookie) Expires() time.Time {
return c.Cookie.Expire()
}
// Secure implements `engine.Cookie#Secure` function.
func (c *Cookie) Secure() bool {
return c.Cookie.Secure()
}
// HTTPOnly implements `engine.Cookie#HTTPOnly` function.
func (c *Cookie) HTTPOnly() bool {
return c.Cookie.HTTPOnly()
}

View File

@ -1,24 +0,0 @@
package fasthttp
import (
"github.com/labstack/echo/engine/test"
fast "github.com/valyala/fasthttp"
"testing"
"time"
)
func TestCookie(t *testing.T) {
fCookie := &fast.Cookie{}
fCookie.SetKey("session")
fCookie.SetValue("securetoken")
fCookie.SetPath("/")
fCookie.SetDomain("github.com")
fCookie.SetExpire(time.Date(2016, time.January, 1, 0, 0, 0, 0, time.UTC))
fCookie.SetSecure(true)
fCookie.SetHTTPOnly(true)
cookie := &Cookie{
fCookie,
}
test.CookieTest(t, cookie)
}

View File

@ -1,97 +0,0 @@
// +build !appengine
package fasthttp
import "github.com/valyala/fasthttp"
type (
// RequestHeader holds `fasthttp.RequestHeader`.
RequestHeader struct {
*fasthttp.RequestHeader
}
// ResponseHeader holds `fasthttp.ResponseHeader`.
ResponseHeader struct {
*fasthttp.ResponseHeader
}
)
// Add implements `engine.Header#Add` function.
func (h *RequestHeader) Add(key, val string) {
h.RequestHeader.Add(key, val)
}
// Del implements `engine.Header#Del` function.
func (h *RequestHeader) Del(key string) {
h.RequestHeader.Del(key)
}
// Set implements `engine.Header#Set` function.
func (h *RequestHeader) Set(key, val string) {
h.RequestHeader.Set(key, val)
}
// Get implements `engine.Header#Get` function.
func (h *RequestHeader) Get(key string) string {
return string(h.Peek(key))
}
// Keys implements `engine.Header#Keys` function.
func (h *RequestHeader) Keys() (keys []string) {
keys = make([]string, h.Len())
i := 0
h.VisitAll(func(k, v []byte) {
keys[i] = string(k)
i++
})
return
}
// Contains implements `engine.Header#Contains` function.
func (h *RequestHeader) Contains(key string) bool {
return h.Peek(key) != nil
}
func (h *RequestHeader) reset(hdr *fasthttp.RequestHeader) {
h.RequestHeader = hdr
}
// Add implements `engine.Header#Add` function.
func (h *ResponseHeader) Add(key, val string) {
h.ResponseHeader.Add(key, val)
}
// Del implements `engine.Header#Del` function.
func (h *ResponseHeader) Del(key string) {
h.ResponseHeader.Del(key)
}
// Get implements `engine.Header#Get` function.
func (h *ResponseHeader) Get(key string) string {
return string(h.Peek(key))
}
// Set implements `engine.Header#Set` function.
func (h *ResponseHeader) Set(key, val string) {
h.ResponseHeader.Set(key, val)
}
// Keys implements `engine.Header#Keys` function.
func (h *ResponseHeader) Keys() (keys []string) {
keys = make([]string, h.Len())
i := 0
h.VisitAll(func(k, v []byte) {
keys[i] = string(k)
i++
})
return
}
// Contains implements `engine.Header#Contains` function.
func (h *ResponseHeader) Contains(key string) bool {
return h.Peek(key) != nil
}
func (h *ResponseHeader) reset(hdr *fasthttp.ResponseHeader) {
h.ResponseHeader = hdr
}

View File

@ -1,24 +0,0 @@
package fasthttp
import (
"github.com/labstack/echo/engine/test"
"github.com/stretchr/testify/assert"
fast "github.com/valyala/fasthttp"
"testing"
)
func TestRequestHeader(t *testing.T) {
header := &RequestHeader{&fast.RequestHeader{}}
test.HeaderTest(t, header)
header.reset(&fast.RequestHeader{})
assert.Len(t, header.Keys(), 0)
}
func TestResponseHeader(t *testing.T) {
header := &ResponseHeader{&fast.ResponseHeader{}}
test.HeaderTest(t, header)
header.reset(&fast.ResponseHeader{})
assert.Len(t, header.Keys(), 1)
}

View File

@ -1,198 +0,0 @@
// +build !appengine
package fasthttp
import (
"bytes"
"io"
"mime/multipart"
"net"
"github.com/labstack/echo"
"github.com/labstack/echo/engine"
"github.com/labstack/echo/log"
"github.com/valyala/fasthttp"
)
type (
// Request implements `engine.Request`.
Request struct {
*fasthttp.RequestCtx
header engine.Header
url engine.URL
logger log.Logger
}
)
// NewRequest returns `Request` instance.
func NewRequest(c *fasthttp.RequestCtx, l log.Logger) *Request {
return &Request{
RequestCtx: c,
url: &URL{URI: c.URI()},
header: &RequestHeader{RequestHeader: &c.Request.Header},
logger: l,
}
}
// IsTLS implements `engine.Request#TLS` function.
func (r *Request) IsTLS() bool {
return r.RequestCtx.IsTLS()
}
// Scheme implements `engine.Request#Scheme` function.
func (r *Request) Scheme() string {
return string(r.RequestCtx.URI().Scheme())
}
// Host implements `engine.Request#Host` function.
func (r *Request) Host() string {
return string(r.RequestCtx.Host())
}
// SetHost implements `engine.Request#SetHost` function.
func (r *Request) SetHost(host string) {
r.RequestCtx.Request.SetHost(host)
}
// URL implements `engine.Request#URL` function.
func (r *Request) URL() engine.URL {
return r.url
}
// Header implements `engine.Request#Header` function.
func (r *Request) Header() engine.Header {
return r.header
}
// Referer implements `engine.Request#Referer` function.
func (r *Request) Referer() string {
return string(r.Request.Header.Referer())
}
// ContentLength implements `engine.Request#ContentLength` function.
func (r *Request) ContentLength() int64 {
return int64(r.Request.Header.ContentLength())
}
// UserAgent implements `engine.Request#UserAgent` function.
func (r *Request) UserAgent() string {
return string(r.RequestCtx.UserAgent())
}
// RemoteAddress implements `engine.Request#RemoteAddress` function.
func (r *Request) RemoteAddress() string {
return r.RemoteAddr().String()
}
// RealIP implements `engine.Request#RealIP` function.
func (r *Request) RealIP() string {
ra := r.RemoteAddress()
if ip := r.Header().Get(echo.HeaderXForwardedFor); ip != "" {
ra = ip
} else if ip := r.Header().Get(echo.HeaderXRealIP); ip != "" {
ra = ip
} else {
ra, _, _ = net.SplitHostPort(ra)
}
return ra
}
// Method implements `engine.Request#Method` function.
func (r *Request) Method() string {
return string(r.RequestCtx.Method())
}
// SetMethod implements `engine.Request#SetMethod` function.
func (r *Request) SetMethod(method string) {
r.Request.Header.SetMethodBytes([]byte(method))
}
// URI implements `engine.Request#URI` function.
func (r *Request) URI() string {
return string(r.RequestURI())
}
// SetURI implements `engine.Request#SetURI` function.
func (r *Request) SetURI(uri string) {
r.Request.Header.SetRequestURI(uri)
}
// Body implements `engine.Request#Body` function.
func (r *Request) Body() io.Reader {
return bytes.NewBuffer(r.Request.Body())
}
// SetBody implements `engine.Request#SetBody` function.
func (r *Request) SetBody(reader io.Reader) {
r.Request.SetBodyStream(reader, 0)
}
// FormValue implements `engine.Request#FormValue` function.
func (r *Request) FormValue(name string) string {
return string(r.RequestCtx.FormValue(name))
}
// FormParams implements `engine.Request#FormParams` function.
func (r *Request) FormParams() (params map[string][]string) {
params = make(map[string][]string)
mf, err := r.RequestCtx.MultipartForm()
if err == fasthttp.ErrNoMultipartForm {
r.PostArgs().VisitAll(func(k, v []byte) {
key := string(k)
if _, ok := params[key]; ok {
params[key] = append(params[key], string(v))
} else {
params[string(k)] = []string{string(v)}
}
})
} else if err == nil {
for k, v := range mf.Value {
if len(v) > 0 {
params[k] = v
}
}
}
return
}
// FormFile implements `engine.Request#FormFile` function.
func (r *Request) FormFile(name string) (*multipart.FileHeader, error) {
return r.RequestCtx.FormFile(name)
}
// MultipartForm implements `engine.Request#MultipartForm` function.
func (r *Request) MultipartForm() (*multipart.Form, error) {
return r.RequestCtx.MultipartForm()
}
// Cookie implements `engine.Request#Cookie` function.
func (r *Request) Cookie(name string) (engine.Cookie, error) {
c := new(fasthttp.Cookie)
b := r.Request.Header.Cookie(name)
if b == nil {
return nil, echo.ErrCookieNotFound
}
c.SetKey(name)
c.SetValueBytes(b)
return &Cookie{c}, nil
}
// Cookies implements `engine.Request#Cookies` function.
func (r *Request) Cookies() []engine.Cookie {
cookies := []engine.Cookie{}
r.Request.Header.VisitAllCookie(func(name, value []byte) {
c := new(fasthttp.Cookie)
c.SetKeyBytes(name)
c.SetValueBytes(value)
cookies = append(cookies, &Cookie{c})
})
return cookies
}
func (r *Request) reset(c *fasthttp.RequestCtx, h engine.Header, u engine.URL) {
r.RequestCtx = c
r.header = h
r.url = u
}

View File

@ -1,31 +0,0 @@
package fasthttp
import (
"bufio"
"bytes"
"net"
"net/url"
"testing"
"github.com/labstack/echo/engine/test"
"github.com/labstack/gommon/log"
fast "github.com/valyala/fasthttp"
)
type fakeAddr struct {
addr string
net.Addr
}
func (a fakeAddr) String() string {
return a.addr
}
func TestRequest(t *testing.T) {
ctx := new(fast.RequestCtx)
url, _ := url.Parse("http://github.com/labstack/echo")
ctx.Init(&fast.Request{}, fakeAddr{addr: "127.0.0.1"}, nil)
ctx.Request.Read(bufio.NewReader(bytes.NewBufferString(test.MultipartRequest)))
ctx.Request.SetRequestURI(url.String())
test.RequestTest(t, NewRequest(ctx, log.New("echo")))
}

View File

@ -1,108 +0,0 @@
// +build !appengine
package fasthttp
import (
"io"
"net/http"
"github.com/labstack/echo/engine"
"github.com/labstack/echo/log"
"github.com/valyala/fasthttp"
)
type (
// Response implements `engine.Response`.
Response struct {
*fasthttp.RequestCtx
header engine.Header
status int
size int64
committed bool
writer io.Writer
logger log.Logger
}
)
// NewResponse returns `Response` instance.
func NewResponse(c *fasthttp.RequestCtx, l log.Logger) *Response {
return &Response{
RequestCtx: c,
header: &ResponseHeader{ResponseHeader: &c.Response.Header},
writer: c,
logger: l,
}
}
// Header implements `engine.Response#Header` function.
func (r *Response) Header() engine.Header {
return r.header
}
// WriteHeader implements `engine.Response#WriteHeader` function.
func (r *Response) WriteHeader(code int) {
if r.committed {
r.logger.Warn("response already committed")
return
}
r.status = code
r.SetStatusCode(code)
r.committed = true
}
// Write implements `engine.Response#Write` function.
func (r *Response) Write(b []byte) (n int, err error) {
if !r.committed {
r.WriteHeader(http.StatusOK)
}
n, err = r.writer.Write(b)
r.size += int64(n)
return
}
// SetCookie implements `engine.Response#SetCookie` function.
func (r *Response) SetCookie(c engine.Cookie) {
cookie := new(fasthttp.Cookie)
cookie.SetKey(c.Name())
cookie.SetValue(c.Value())
cookie.SetPath(c.Path())
cookie.SetDomain(c.Domain())
cookie.SetExpire(c.Expires())
cookie.SetSecure(c.Secure())
cookie.SetHTTPOnly(c.HTTPOnly())
r.Response.Header.SetCookie(cookie)
}
// Status implements `engine.Response#Status` function.
func (r *Response) Status() int {
return r.status
}
// Size implements `engine.Response#Size` function.
func (r *Response) Size() int64 {
return r.size
}
// Committed implements `engine.Response#Committed` function.
func (r *Response) Committed() bool {
return r.committed
}
// Writer implements `engine.Response#Writer` function.
func (r *Response) Writer() io.Writer {
return r.writer
}
// SetWriter implements `engine.Response#SetWriter` function.
func (r *Response) SetWriter(w io.Writer) {
r.writer = w
}
func (r *Response) reset(c *fasthttp.RequestCtx, h engine.Header) {
r.RequestCtx = c
r.header = h
r.status = http.StatusOK
r.size = 0
r.committed = false
r.writer = c
}

View File

@ -1,41 +0,0 @@
package fasthttp
import (
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/valyala/fasthttp"
"github.com/labstack/gommon/log"
)
func TestResponseWriteHeader(t *testing.T) {
c := new(fasthttp.RequestCtx)
res := NewResponse(c, log.New("test"))
res.WriteHeader(http.StatusOK)
assert.True(t, res.Committed())
assert.Equal(t, http.StatusOK, res.Status())
}
func TestResponseWrite(t *testing.T) {
c := new(fasthttp.RequestCtx)
res := NewResponse(c, log.New("test"))
res.Write([]byte("test"))
assert.Equal(t, int64(4), res.Size())
assert.Equal(t, "test", string(c.Response.Body()))
}
func TestResponseSetCookie(t *testing.T) {
c := new(fasthttp.RequestCtx)
res := NewResponse(c, log.New("test"))
cookie := new(fasthttp.Cookie)
cookie.SetKey("name")
cookie.SetValue("Jon Snow")
res.SetCookie(&Cookie{cookie})
c.Response.Header.SetCookie(cookie)
ck := new(fasthttp.Cookie)
ck.SetKey("name")
assert.True(t, c.Response.Header.Cookie(ck))
assert.Equal(t, "Jon Snow", string(ck.Value()))
}

View File

@ -1,188 +0,0 @@
// +build !appengine
package fasthttp
import (
"sync"
"github.com/labstack/echo"
"github.com/labstack/echo/engine"
"github.com/labstack/echo/log"
glog "github.com/labstack/gommon/log"
"github.com/valyala/fasthttp"
)
type (
// Server implements `engine.Server`.
Server struct {
*fasthttp.Server
config engine.Config
handler engine.Handler
logger log.Logger
pool *pool
}
pool struct {
request sync.Pool
response sync.Pool
requestHeader sync.Pool
responseHeader sync.Pool
url sync.Pool
}
)
// New returns `Server` with provided listen address.
func New(addr string) *Server {
c := engine.Config{Address: addr}
return WithConfig(c)
}
// WithTLS returns `Server` with provided TLS config.
func WithTLS(addr, certFile, keyFile string) *Server {
c := engine.Config{
Address: addr,
TLSCertFile: certFile,
TLSKeyFile: keyFile,
}
return WithConfig(c)
}
// WithConfig returns `Server` with provided config.
func WithConfig(c engine.Config) (s *Server) {
s = &Server{
Server: new(fasthttp.Server),
config: c,
pool: &pool{
request: sync.Pool{
New: func() interface{} {
return &Request{logger: s.logger}
},
},
response: sync.Pool{
New: func() interface{} {
return &Response{logger: s.logger}
},
},
requestHeader: sync.Pool{
New: func() interface{} {
return &RequestHeader{}
},
},
responseHeader: sync.Pool{
New: func() interface{} {
return &ResponseHeader{}
},
},
url: sync.Pool{
New: func() interface{} {
return &URL{}
},
},
},
handler: engine.HandlerFunc(func(req engine.Request, res engine.Response) {
panic("echo: handler not set, use `Server#SetHandler()` to set it.")
}),
logger: glog.New("echo"),
}
s.ReadTimeout = c.ReadTimeout
s.WriteTimeout = c.WriteTimeout
s.Handler = s.ServeHTTP
return
}
// SetHandler implements `engine.Server#SetHandler` function.
func (s *Server) SetHandler(h engine.Handler) {
s.handler = h
}
// SetLogger implements `engine.Server#SetLogger` function.
func (s *Server) SetLogger(l log.Logger) {
s.logger = l
}
// Start implements `engine.Server#Start` function.
func (s *Server) Start() error {
if s.config.Listener == nil {
return s.startDefaultListener()
}
return s.startCustomListener()
}
// Stop implements `engine.Server#Stop` function.
func (s *Server) Stop() error {
// TODO: implement `engine.Server#Stop` function
return nil
}
func (s *Server) startDefaultListener() error {
c := s.config
if c.TLSCertFile != "" && c.TLSKeyFile != "" {
return s.ListenAndServeTLS(c.Address, c.TLSCertFile, c.TLSKeyFile)
}
return s.ListenAndServe(c.Address)
}
func (s *Server) startCustomListener() error {
c := s.config
if c.TLSCertFile != "" && c.TLSKeyFile != "" {
return s.ServeTLS(c.Listener, c.TLSCertFile, c.TLSKeyFile)
}
return s.Serve(c.Listener)
}
func (s *Server) ServeHTTP(c *fasthttp.RequestCtx) {
// Request
req := s.pool.request.Get().(*Request)
reqHdr := s.pool.requestHeader.Get().(*RequestHeader)
reqURL := s.pool.url.Get().(*URL)
reqHdr.reset(&c.Request.Header)
reqURL.reset(c.URI())
req.reset(c, reqHdr, reqURL)
// Response
res := s.pool.response.Get().(*Response)
resHdr := s.pool.responseHeader.Get().(*ResponseHeader)
resHdr.reset(&c.Response.Header)
res.reset(c, resHdr)
s.handler.ServeHTTP(req, res)
// Return to pool
s.pool.request.Put(req)
s.pool.requestHeader.Put(reqHdr)
s.pool.url.Put(reqURL)
s.pool.response.Put(res)
s.pool.responseHeader.Put(resHdr)
}
// WrapHandler wraps `fasthttp.RequestHandler` into `echo.HandlerFunc`.
func WrapHandler(h fasthttp.RequestHandler) echo.HandlerFunc {
return func(c echo.Context) error {
req := c.Request().(*Request)
res := c.Response().(*Response)
ctx := req.RequestCtx
h(ctx)
res.status = ctx.Response.StatusCode()
res.size = int64(ctx.Response.Header.ContentLength())
return nil
}
}
// WrapMiddleware wraps `func(fasthttp.RequestHandler) fasthttp.RequestHandler`
// into `echo.MiddlewareFunc`
func WrapMiddleware(m func(fasthttp.RequestHandler) fasthttp.RequestHandler) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) (err error) {
req := c.Request().(*Request)
res := c.Response().(*Response)
ctx := req.RequestCtx
m(func(ctx *fasthttp.RequestCtx) {
next(c)
})(ctx)
res.status = ctx.Response.StatusCode()
res.size = int64(ctx.Response.Header.ContentLength())
return
}
}
}

View File

@ -1,59 +0,0 @@
package fasthttp
import (
"bytes"
"net/http"
"testing"
"github.com/labstack/echo"
"github.com/labstack/echo/engine"
"github.com/stretchr/testify/assert"
"github.com/valyala/fasthttp"
)
// TODO: Fix me
func TestServer(t *testing.T) {
s := New("")
s.SetHandler(engine.HandlerFunc(func(req engine.Request, res engine.Response) {
}))
ctx := new(fasthttp.RequestCtx)
s.ServeHTTP(ctx)
}
func TestServerWrapHandler(t *testing.T) {
e := echo.New()
ctx := new(fasthttp.RequestCtx)
req := NewRequest(ctx, nil)
res := NewResponse(ctx, nil)
c := e.NewContext(req, res)
h := WrapHandler(func(ctx *fasthttp.RequestCtx) {
ctx.Write([]byte("test"))
})
if assert.NoError(t, h(c)) {
assert.Equal(t, http.StatusOK, ctx.Response.StatusCode())
assert.Equal(t, "test", string(ctx.Response.Body()))
}
}
func TestServerWrapMiddleware(t *testing.T) {
e := echo.New()
ctx := new(fasthttp.RequestCtx)
req := NewRequest(ctx, nil)
res := NewResponse(ctx, nil)
c := e.NewContext(req, res)
buf := new(bytes.Buffer)
mw := WrapMiddleware(func(h fasthttp.RequestHandler) fasthttp.RequestHandler {
return func(ctx *fasthttp.RequestCtx) {
buf.Write([]byte("mw"))
h(ctx)
}
})
h := mw(func(c echo.Context) error {
return c.String(http.StatusOK, "OK")
})
if assert.NoError(t, h(c)) {
assert.Equal(t, "mw", buf.String())
assert.Equal(t, http.StatusOK, ctx.Response.StatusCode())
assert.Equal(t, "OK", string(ctx.Response.Body()))
}
}

View File

@ -1,49 +0,0 @@
// +build !appengine
package fasthttp
import "github.com/valyala/fasthttp"
type (
// URL implements `engine.URL`.
URL struct {
*fasthttp.URI
}
)
// Path implements `engine.URL#Path` function.
func (u *URL) Path() string {
return string(u.URI.PathOriginal())
}
// SetPath implements `engine.URL#SetPath` function.
func (u *URL) SetPath(path string) {
u.URI.SetPath(path)
}
// QueryParam implements `engine.URL#QueryParam` function.
func (u *URL) QueryParam(name string) string {
return string(u.QueryArgs().Peek(name))
}
// QueryParams implements `engine.URL#QueryParams` function.
func (u *URL) QueryParams() (params map[string][]string) {
params = make(map[string][]string)
u.QueryArgs().VisitAll(func(k, v []byte) {
_, ok := params[string(k)]
if !ok {
params[string(k)] = make([]string, 0)
}
params[string(k)] = append(params[string(k)], string(v))
})
return
}
// QueryString implements `engine.URL#QueryString` function.
func (u *URL) QueryString() string {
return string(u.URI.QueryString())
}
func (u *URL) reset(uri *fasthttp.URI) {
u.URI = uri
}

View File

@ -1,18 +0,0 @@
package fasthttp
import (
"github.com/labstack/echo/engine/test"
"github.com/stretchr/testify/assert"
fast "github.com/valyala/fasthttp"
"testing"
)
func TestURL(t *testing.T) {
uri := &fast.URI{}
uri.Parse([]byte("github.com"), []byte("/labstack/echo?param1=value1&param1=value2&param2=value3"))
mUrl := &URL{uri}
test.URLTest(t, mUrl)
mUrl.reset(&fast.URI{})
assert.Equal(t, "", string(mUrl.Host()))
}

View File

@ -1,48 +0,0 @@
package standard
import (
"net/http"
"time"
)
type (
// Cookie implements `engine.Cookie`.
Cookie struct {
*http.Cookie
}
)
// Name implements `engine.Cookie#Name` function.
func (c *Cookie) Name() string {
return c.Cookie.Name
}
// Value implements `engine.Cookie#Value` function.
func (c *Cookie) Value() string {
return c.Cookie.Value
}
// Path implements `engine.Cookie#Path` function.
func (c *Cookie) Path() string {
return c.Cookie.Path
}
// Domain implements `engine.Cookie#Domain` function.
func (c *Cookie) Domain() string {
return c.Cookie.Domain
}
// Expires implements `engine.Cookie#Expires` function.
func (c *Cookie) Expires() time.Time {
return c.Cookie.Expires
}
// Secure implements `engine.Cookie#Secure` function.
func (c *Cookie) Secure() bool {
return c.Cookie.Secure
}
// HTTPOnly implements `engine.Cookie#HTTPOnly` function.
func (c *Cookie) HTTPOnly() bool {
return c.Cookie.HttpOnly
}

View File

@ -1,21 +0,0 @@
package standard
import (
"github.com/labstack/echo/engine/test"
"net/http"
"testing"
"time"
)
func TestCookie(t *testing.T) {
cookie := &Cookie{&http.Cookie{
Name: "session",
Value: "securetoken",
Path: "/",
Domain: "github.com",
Expires: time.Date(2016, time.January, 1, 0, 0, 0, 0, time.UTC),
Secure: true,
HttpOnly: true,
}}
test.CookieTest(t, cookie)
}

View File

@ -1,51 +0,0 @@
package standard
import "net/http"
type (
// Header implements `engine.Header`.
Header struct {
http.Header
}
)
// Add implements `engine.Header#Add` function.
func (h *Header) Add(key, val string) {
h.Header.Add(key, val)
}
// Del implements `engine.Header#Del` function.
func (h *Header) Del(key string) {
h.Header.Del(key)
}
// Set implements `engine.Header#Set` function.
func (h *Header) Set(key, val string) {
h.Header.Set(key, val)
}
// Get implements `engine.Header#Get` function.
func (h *Header) Get(key string) string {
return h.Header.Get(key)
}
// Keys implements `engine.Header#Keys` function.
func (h *Header) Keys() (keys []string) {
keys = make([]string, len(h.Header))
i := 0
for k := range h.Header {
keys[i] = k
i++
}
return
}
// Contains implements `engine.Header#Contains` function.
func (h *Header) Contains(key string) bool {
_, ok := h.Header[key]
return ok
}
func (h *Header) reset(hdr http.Header) {
h.Header = hdr
}

View File

@ -1,16 +0,0 @@
package standard
import (
"github.com/labstack/echo/engine/test"
"github.com/stretchr/testify/assert"
"net/http"
"testing"
)
func TestHeader(t *testing.T) {
header := &Header{http.Header{}}
test.HeaderTest(t, header)
header.reset(http.Header{})
assert.Len(t, header.Keys(), 0)
}

View File

@ -1,205 +0,0 @@
package standard
import (
"fmt"
"io"
"io/ioutil"
"mime/multipart"
"net"
"net/http"
"strings"
"github.com/labstack/echo"
"github.com/labstack/echo/engine"
"github.com/labstack/echo/log"
)
type (
// Request implements `engine.Request`.
Request struct {
*http.Request
header engine.Header
url engine.URL
logger log.Logger
}
)
const (
defaultMemory = 32 << 20 // 32 MB
)
// NewRequest returns `Request` instance.
func NewRequest(r *http.Request, l log.Logger) *Request {
return &Request{
Request: r,
url: &URL{URL: r.URL},
header: &Header{Header: r.Header},
logger: l,
}
}
// IsTLS implements `engine.Request#TLS` function.
func (r *Request) IsTLS() bool {
return r.Request.TLS != nil
}
// Scheme implements `engine.Request#Scheme` function.
func (r *Request) Scheme() string {
// Can't use `r.Request.URL.Scheme`
// See: https://groups.google.com/forum/#!topic/golang-nuts/pMUkBlQBDF0
if r.IsTLS() {
return "https"
}
return "http"
}
// Host implements `engine.Request#Host` function.
func (r *Request) Host() string {
return r.Request.Host
}
// SetHost implements `engine.Request#SetHost` function.
func (r *Request) SetHost(host string) {
r.Request.Host = host
}
// URL implements `engine.Request#URL` function.
func (r *Request) URL() engine.URL {
return r.url
}
// Header implements `engine.Request#Header` function.
func (r *Request) Header() engine.Header {
return r.header
}
// Referer implements `engine.Request#Referer` function.
func (r *Request) Referer() string {
return r.Request.Referer()
}
// func Proto() string {
// return r.request.Proto()
// }
//
// func ProtoMajor() int {
// return r.request.ProtoMajor()
// }
//
// func ProtoMinor() int {
// return r.request.ProtoMinor()
// }
// ContentLength implements `engine.Request#ContentLength` function.
func (r *Request) ContentLength() int64 {
return r.Request.ContentLength
}
// UserAgent implements `engine.Request#UserAgent` function.
func (r *Request) UserAgent() string {
return r.Request.UserAgent()
}
// RemoteAddress implements `engine.Request#RemoteAddress` function.
func (r *Request) RemoteAddress() string {
return r.RemoteAddr
}
// RealIP implements `engine.Request#RealIP` function.
func (r *Request) RealIP() string {
ra := r.RemoteAddress()
if ip := r.Header().Get(echo.HeaderXForwardedFor); ip != "" {
ra = ip
} else if ip := r.Header().Get(echo.HeaderXRealIP); ip != "" {
ra = ip
} else {
ra, _, _ = net.SplitHostPort(ra)
}
return ra
}
// Method implements `engine.Request#Method` function.
func (r *Request) Method() string {
return r.Request.Method
}
// SetMethod implements `engine.Request#SetMethod` function.
func (r *Request) SetMethod(method string) {
r.Request.Method = method
}
// URI implements `engine.Request#URI` function.
func (r *Request) URI() string {
return r.RequestURI
}
// SetURI implements `engine.Request#SetURI` function.
func (r *Request) SetURI(uri string) {
r.RequestURI = uri
}
// Body implements `engine.Request#Body` function.
func (r *Request) Body() io.Reader {
return r.Request.Body
}
// SetBody implements `engine.Request#SetBody` function.
func (r *Request) SetBody(reader io.Reader) {
r.Request.Body = ioutil.NopCloser(reader)
}
// FormValue implements `engine.Request#FormValue` function.
func (r *Request) FormValue(name string) string {
return r.Request.FormValue(name)
}
// FormParams implements `engine.Request#FormParams` function.
func (r *Request) FormParams() map[string][]string {
if strings.HasPrefix(r.header.Get(echo.HeaderContentType), echo.MIMEMultipartForm) {
if err := r.ParseMultipartForm(defaultMemory); err != nil {
panic(fmt.Sprintf("echo: %v", err))
}
} else {
if err := r.ParseForm(); err != nil {
panic(fmt.Sprintf("echo: %v", err))
}
}
return map[string][]string(r.Request.Form)
}
// FormFile implements `engine.Request#FormFile` function.
func (r *Request) FormFile(name string) (*multipart.FileHeader, error) {
_, fh, err := r.Request.FormFile(name)
return fh, err
}
// MultipartForm implements `engine.Request#MultipartForm` function.
func (r *Request) MultipartForm() (*multipart.Form, error) {
err := r.ParseMultipartForm(defaultMemory)
return r.Request.MultipartForm, err
}
// Cookie implements `engine.Request#Cookie` function.
func (r *Request) Cookie(name string) (engine.Cookie, error) {
c, err := r.Request.Cookie(name)
if err != nil {
return nil, echo.ErrCookieNotFound
}
return &Cookie{c}, nil
}
// Cookies implements `engine.Request#Cookies` function.
func (r *Request) Cookies() []engine.Cookie {
cs := r.Request.Cookies()
cookies := make([]engine.Cookie, len(cs))
for i, c := range cs {
cookies[i] = &Cookie{c}
}
return cookies
}
func (r *Request) reset(req *http.Request, h engine.Header, u engine.URL) {
r.Request = req
r.header = h
r.url = u
}

View File

@ -1,25 +0,0 @@
package standard
import (
"bufio"
"net/http"
"net/url"
"strings"
"testing"
"github.com/labstack/echo/engine/test"
"github.com/labstack/gommon/log"
"github.com/stretchr/testify/assert"
)
func TestRequest(t *testing.T) {
httpReq, _ := http.ReadRequest(bufio.NewReader(strings.NewReader(test.MultipartRequest)))
url, _ := url.Parse("http://github.com/labstack/echo")
httpReq.URL = url
httpReq.RemoteAddr = "127.0.0.1"
req := NewRequest(httpReq, log.New("echo"))
test.RequestTest(t, req)
nr, _ := http.NewRequest("GET", "/", nil)
req.reset(nr, nil, nil)
assert.Equal(t, "", req.Host())
}

View File

@ -1,146 +0,0 @@
package standard
import (
"bufio"
"io"
"net"
"net/http"
"github.com/labstack/echo/engine"
"github.com/labstack/echo/log"
)
type (
// Response implements `engine.Response`.
Response struct {
http.ResponseWriter
adapter *responseAdapter
header engine.Header
status int
size int64
committed bool
writer io.Writer
logger log.Logger
}
responseAdapter struct {
*Response
}
)
// NewResponse returns `Response` instance.
func NewResponse(w http.ResponseWriter, l log.Logger) (r *Response) {
r = &Response{
ResponseWriter: w,
header: &Header{Header: w.Header()},
writer: w,
logger: l,
}
r.adapter = &responseAdapter{Response: r}
return
}
// Header implements `engine.Response#Header` function.
func (r *Response) Header() engine.Header {
return r.header
}
// WriteHeader implements `engine.Response#WriteHeader` function.
func (r *Response) WriteHeader(code int) {
if r.committed {
r.logger.Warn("response already committed")
return
}
r.status = code
r.ResponseWriter.WriteHeader(code)
r.committed = true
}
// Write implements `engine.Response#Write` function.
func (r *Response) Write(b []byte) (n int, err error) {
if !r.committed {
r.WriteHeader(http.StatusOK)
}
n, err = r.writer.Write(b)
r.size += int64(n)
return
}
// SetCookie implements `engine.Response#SetCookie` function.
func (r *Response) SetCookie(c engine.Cookie) {
http.SetCookie(r.ResponseWriter, &http.Cookie{
Name: c.Name(),
Value: c.Value(),
Path: c.Path(),
Domain: c.Domain(),
Expires: c.Expires(),
Secure: c.Secure(),
HttpOnly: c.HTTPOnly(),
})
}
// Status implements `engine.Response#Status` function.
func (r *Response) Status() int {
return r.status
}
// Size implements `engine.Response#Size` function.
func (r *Response) Size() int64 {
return r.size
}
// Committed implements `engine.Response#Committed` function.
func (r *Response) Committed() bool {
return r.committed
}
// Writer implements `engine.Response#Writer` function.
func (r *Response) Writer() io.Writer {
return r.writer
}
// SetWriter implements `engine.Response#SetWriter` function.
func (r *Response) SetWriter(w io.Writer) {
r.writer = w
}
// Flush implements the http.Flusher interface to allow an HTTP handler to flush
// buffered data to the client.
// See https://golang.org/pkg/net/http/#Flusher
func (r *Response) Flush() {
r.ResponseWriter.(http.Flusher).Flush()
}
// Hijack implements the http.Hijacker interface to allow an HTTP handler to
// take over the connection.
// See https://golang.org/pkg/net/http/#Hijacker
func (r *Response) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return r.ResponseWriter.(http.Hijacker).Hijack()
}
// CloseNotify implements the http.CloseNotifier interface to allow detecting
// when the underlying connection has gone away.
// This mechanism can be used to cancel long operations on the server if the
// client has disconnected before the response is ready.
// See https://golang.org/pkg/net/http/#CloseNotifier
func (r *Response) CloseNotify() <-chan bool {
return r.ResponseWriter.(http.CloseNotifier).CloseNotify()
}
func (r *Response) reset(w http.ResponseWriter, a *responseAdapter, h engine.Header) {
r.ResponseWriter = w
r.adapter = a
r.header = h
r.status = http.StatusOK
r.size = 0
r.committed = false
r.writer = w
}
func (r *responseAdapter) Header() http.Header {
return r.ResponseWriter.Header()
}
func (r *responseAdapter) reset(res *Response) {
r.Response = res
}

View File

@ -1,38 +0,0 @@
package standard
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/labstack/gommon/log"
"github.com/stretchr/testify/assert"
)
func TestResponseWriteHeader(t *testing.T) {
rec := httptest.NewRecorder()
res := NewResponse(rec, log.New("test"))
res.WriteHeader(http.StatusOK)
assert.True(t, res.Committed())
assert.Equal(t, http.StatusOK, res.Status())
}
func TestResponseWrite(t *testing.T) {
rec := httptest.NewRecorder()
res := NewResponse(rec, log.New("test"))
res.Write([]byte("test"))
assert.Equal(t, int64(4), res.Size())
assert.Equal(t, "test", rec.Body.String())
res.Flush()
assert.True(t, rec.Flushed)
}
func TestResponseSetCookie(t *testing.T) {
rec := httptest.NewRecorder()
res := NewResponse(rec, log.New("test"))
res.SetCookie(&Cookie{&http.Cookie{
Name: "name",
Value: "Jon Snow",
}})
assert.Equal(t, "name=Jon Snow", rec.Header().Get("Set-Cookie"))
}

View File

@ -1,208 +0,0 @@
package standard
import (
"crypto/tls"
"net"
"net/http"
"sync"
"time"
"github.com/labstack/echo"
"github.com/labstack/echo/engine"
"github.com/labstack/echo/log"
glog "github.com/labstack/gommon/log"
)
type (
// Server implements `engine.Server`.
Server struct {
*http.Server
config engine.Config
handler engine.Handler
logger log.Logger
pool *pool
}
pool struct {
request sync.Pool
response sync.Pool
responseAdapter sync.Pool
header sync.Pool
url sync.Pool
}
)
// New returns `Server` instance with provided listen address.
func New(addr string) *Server {
c := engine.Config{Address: addr}
return WithConfig(c)
}
// WithTLS returns `Server` instance with provided TLS config.
func WithTLS(addr, certFile, keyFile string) *Server {
c := engine.Config{
Address: addr,
TLSCertFile: certFile,
TLSKeyFile: keyFile,
}
return WithConfig(c)
}
// WithConfig returns `Server` instance with provided config.
func WithConfig(c engine.Config) (s *Server) {
s = &Server{
Server: new(http.Server),
config: c,
pool: &pool{
request: sync.Pool{
New: func() interface{} {
return &Request{logger: s.logger}
},
},
response: sync.Pool{
New: func() interface{} {
return &Response{logger: s.logger}
},
},
responseAdapter: sync.Pool{
New: func() interface{} {
return &responseAdapter{}
},
},
header: sync.Pool{
New: func() interface{} {
return &Header{}
},
},
url: sync.Pool{
New: func() interface{} {
return &URL{}
},
},
},
handler: engine.HandlerFunc(func(req engine.Request, res engine.Response) {
panic("echo: handler not set, use `Server#SetHandler()` to set it.")
}),
logger: glog.New("echo"),
}
s.ReadTimeout = c.ReadTimeout
s.WriteTimeout = c.WriteTimeout
s.Addr = c.Address
s.Handler = s
return
}
// SetHandler implements `engine.Server#SetHandler` function.
func (s *Server) SetHandler(h engine.Handler) {
s.handler = h
}
// SetLogger implements `engine.Server#SetLogger` function.
func (s *Server) SetLogger(l log.Logger) {
s.logger = l
}
// Start implements `engine.Server#Start` function.
func (s *Server) Start() error {
if s.config.Listener == nil {
ln, err := net.Listen("tcp", s.config.Address)
if err != nil {
return err
}
if s.config.TLSCertFile != "" && s.config.TLSKeyFile != "" {
// TODO: https://github.com/golang/go/commit/d24f446a90ea94b87591bf16228d7d871fec3d92
config := &tls.Config{
NextProtos: []string{"http/1.1"},
}
if !s.config.DisableHTTP2 {
config.NextProtos = append(config.NextProtos, "h2")
}
config.Certificates = make([]tls.Certificate, 1)
config.Certificates[0], err = tls.LoadX509KeyPair(s.config.TLSCertFile, s.config.TLSKeyFile)
if err != nil {
return err
}
s.config.Listener = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, config)
} else {
s.config.Listener = tcpKeepAliveListener{ln.(*net.TCPListener)}
}
}
return s.Serve(s.config.Listener)
}
// Stop implements `engine.Server#Stop` function.
func (s *Server) Stop() error {
return s.config.Listener.Close()
}
// ServeHTTP implements `http.Handler` interface.
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Request
req := s.pool.request.Get().(*Request)
reqHdr := s.pool.header.Get().(*Header)
reqURL := s.pool.url.Get().(*URL)
reqHdr.reset(r.Header)
reqURL.reset(r.URL)
req.reset(r, reqHdr, reqURL)
// Response
res := s.pool.response.Get().(*Response)
resAdpt := s.pool.responseAdapter.Get().(*responseAdapter)
resAdpt.reset(res)
resHdr := s.pool.header.Get().(*Header)
resHdr.reset(w.Header())
res.reset(w, resAdpt, resHdr)
s.handler.ServeHTTP(req, res)
// Return to pool
s.pool.request.Put(req)
s.pool.header.Put(reqHdr)
s.pool.url.Put(reqURL)
s.pool.response.Put(res)
s.pool.header.Put(resHdr)
}
// WrapHandler wraps `http.Handler` into `echo.HandlerFunc`.
func WrapHandler(h http.Handler) echo.HandlerFunc {
return func(c echo.Context) error {
req := c.Request().(*Request)
res := c.Response().(*Response)
h.ServeHTTP(res.adapter, req.Request)
return nil
}
}
// WrapMiddleware wraps `func(http.Handler) http.Handler` into `echo.MiddlewareFunc`
func WrapMiddleware(m func(http.Handler) http.Handler) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) (err error) {
req := c.Request().(*Request)
res := c.Response().(*Response)
m(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
err = next(c)
})).ServeHTTP(res.adapter, req.Request)
return
}
}
}
// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
// connections. It's used by ListenAndServe and ListenAndServeTLS so
// dead TCP connections (e.g. closing laptop mid-download) eventually
// go away.
type tcpKeepAliveListener struct {
*net.TCPListener
}
func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
tc, err := ln.AcceptTCP()
if err != nil {
return
}
tc.SetKeepAlive(true)
tc.SetKeepAlivePeriod(3 * time.Minute)
return tc, nil
}

View File

@ -1,61 +0,0 @@
package standard
import (
"bytes"
"net/http"
"net/http/httptest"
"testing"
"github.com/labstack/echo"
"github.com/labstack/echo/engine"
"github.com/stretchr/testify/assert"
)
// TODO: Fix me
func TestServer(t *testing.T) {
s := New("")
s.SetHandler(engine.HandlerFunc(func(req engine.Request, res engine.Response) {
}))
rec := httptest.NewRecorder()
req := new(http.Request)
s.ServeHTTP(rec, req)
}
func TestServerWrapHandler(t *testing.T) {
e := echo.New()
req := NewRequest(new(http.Request), nil)
rec := httptest.NewRecorder()
res := NewResponse(rec, nil)
c := e.NewContext(req, res)
h := WrapHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("test"))
}))
if assert.NoError(t, h(c)) {
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "test", rec.Body.String())
}
}
func TestServerWrapMiddleware(t *testing.T) {
e := echo.New()
req := NewRequest(new(http.Request), nil)
rec := httptest.NewRecorder()
res := NewResponse(rec, nil)
c := e.NewContext(req, res)
buf := new(bytes.Buffer)
mw := WrapMiddleware(func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
buf.Write([]byte("mw"))
h.ServeHTTP(w, r)
})
})
h := mw(func(c echo.Context) error {
return c.String(http.StatusOK, "OK")
})
if assert.NoError(t, h(c)) {
assert.Equal(t, "mw", buf.String())
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "OK", rec.Body.String())
}
}

View File

@ -1,47 +0,0 @@
package standard
import "net/url"
type (
// URL implements `engine.URL`.
URL struct {
*url.URL
query url.Values
}
)
// Path implements `engine.URL#Path` function.
func (u *URL) Path() string {
return u.URL.EscapedPath()
}
// SetPath implements `engine.URL#SetPath` function.
func (u *URL) SetPath(path string) {
u.URL.Path = path
}
// QueryParam implements `engine.URL#QueryParam` function.
func (u *URL) QueryParam(name string) string {
if u.query == nil {
u.query = u.Query()
}
return u.query.Get(name)
}
// QueryParams implements `engine.URL#QueryParams` function.
func (u *URL) QueryParams() map[string][]string {
if u.query == nil {
u.query = u.Query()
}
return map[string][]string(u.query)
}
// QueryString implements `engine.URL#QueryString` function.
func (u *URL) QueryString() string {
return u.URL.RawQuery
}
func (u *URL) reset(url *url.URL) {
u.URL = url
u.query = nil
}

View File

@ -1,17 +0,0 @@
package standard
import (
"github.com/labstack/echo/engine/test"
"github.com/stretchr/testify/assert"
"net/url"
"testing"
)
func TestURL(t *testing.T) {
u, _ := url.Parse("https://github.com/labstack/echo?param1=value1&param1=value2&param2=value3")
mUrl := &URL{u, nil}
test.URLTest(t, mUrl)
mUrl.reset(&url.URL{})
assert.Equal(t, "", mUrl.Host)
}

View File

@ -1,62 +0,0 @@
package test
import (
"testing"
"time"
"github.com/labstack/echo/engine"
"github.com/stretchr/testify/assert"
)
func HeaderTest(t *testing.T, header engine.Header) {
h := "X-My-Header"
v := "value"
nv := "new value"
h1 := "X-Another-Header"
header.Add(h, v)
assert.Equal(t, v, header.Get(h))
header.Set(h, nv)
assert.Equal(t, nv, header.Get(h))
assert.True(t, header.Contains(h))
header.Del(h)
assert.False(t, header.Contains(h))
header.Add(h, v)
header.Add(h1, v)
for _, expected := range []string{h, h1} {
found := false
for _, actual := range header.Keys() {
if actual == expected {
found = true
break
}
}
if !found {
t.Errorf("Header %s not found", expected)
}
}
}
func URLTest(t *testing.T, url engine.URL) {
path := "/echo/test"
url.SetPath(path)
assert.Equal(t, path, url.Path())
assert.Equal(t, map[string][]string{"param1": []string{"value1", "value2"}, "param2": []string{"value3"}}, url.QueryParams())
assert.Equal(t, "value1", url.QueryParam("param1"))
assert.Equal(t, "param1=value1&param1=value2&param2=value3", url.QueryString())
}
func CookieTest(t *testing.T, cookie engine.Cookie) {
assert.Equal(t, "github.com", cookie.Domain())
assert.Equal(t, time.Date(2016, time.January, 1, 0, 0, 0, 0, time.UTC), cookie.Expires())
assert.True(t, cookie.HTTPOnly())
assert.True(t, cookie.Secure())
assert.Equal(t, "session", cookie.Name())
assert.Equal(t, "/", cookie.Path())
assert.Equal(t, "securetoken", cookie.Value())
}

View File

@ -1,97 +0,0 @@
package test
import (
"io/ioutil"
"strings"
"testing"
"github.com/labstack/echo/engine"
"github.com/stretchr/testify/assert"
)
const MultipartRequest = `POST /labstack/echo HTTP/1.1
Host: github.com
Connection: close
User-Agent: Mozilla/5.0 (Macintosh; U; Intel Mac OS X; de-de) AppleWebKit/523.10.3 (KHTML, like Gecko) Version/3.0.4 Safari/523.10
Content-Type: multipart/form-data; boundary=Asrf456BGe4h
Content-Length: 261
Accept-Encoding: gzip
Accept-Charset: ISO-8859-1,UTF-8;q=0.7,*;q=0.7
Cache-Control: no-cache
Accept-Language: de,en;q=0.7,en-us;q=0.3
Referer: https://github.com/
Cookie: session=securetoken; user=123
X-Real-IP: 192.168.1.1
--Asrf456BGe4h
Content-Disposition: form-data; name="foo"
bar
--Asrf456BGe4h
Content-Disposition: form-data; name="baz"
bat
--Asrf456BGe4h
Content-Disposition: form-data; name="note"; filename="note.txt"
Content-Type: text/plain
Hello world!
--Asrf456BGe4h--
`
func RequestTest(t *testing.T, request engine.Request) {
assert.Equal(t, "github.com", request.Host())
request.SetHost("labstack.com")
assert.Equal(t, "labstack.com", request.Host())
request.SetURI("/labstack/echo?token=54321")
assert.Equal(t, "/labstack/echo?token=54321", request.URI())
assert.Equal(t, "/labstack/echo", request.URL().Path())
assert.Equal(t, "https://github.com/", request.Referer())
assert.Equal(t, "192.168.1.1", request.Header().Get("X-Real-IP"))
assert.Equal(t, "http", request.Scheme())
assert.Equal(t, "Mozilla/5.0 (Macintosh; U; Intel Mac OS X; de-de) AppleWebKit/523.10.3 (KHTML, like Gecko) Version/3.0.4 Safari/523.10", request.UserAgent())
assert.Equal(t, "127.0.0.1", request.RemoteAddress())
assert.Equal(t, "192.168.1.1", request.RealIP())
assert.Equal(t, "POST", request.Method())
assert.Equal(t, int64(261), request.ContentLength())
assert.Equal(t, "bar", request.FormValue("foo"))
if fHeader, err := request.FormFile("note"); assert.NoError(t, err) {
if file, err := fHeader.Open(); assert.NoError(t, err) {
text, _ := ioutil.ReadAll(file)
assert.Equal(t, "Hello world!", string(text))
}
}
assert.Equal(t, map[string][]string{"baz": []string{"bat"}, "foo": []string{"bar"}}, request.FormParams())
if form, err := request.MultipartForm(); assert.NoError(t, err) {
_, ok := form.File["note"]
assert.True(t, ok)
}
request.SetMethod("PUT")
assert.Equal(t, "PUT", request.Method())
request.SetBody(strings.NewReader("Hello"))
if body, err := ioutil.ReadAll(request.Body()); assert.NoError(t, err) {
assert.Equal(t, "Hello", string(body))
}
if cookie, err := request.Cookie("session"); assert.NoError(t, err) {
assert.Equal(t, "session", cookie.Name())
assert.Equal(t, "securetoken", cookie.Value())
}
_, err := request.Cookie("foo")
assert.Error(t, err)
// Cookies
cs := request.Cookies()
if assert.Len(t, cs, 2) {
assert.Equal(t, "session", cs[0].Name())
assert.Equal(t, "securetoken", cs[0].Value())
assert.Equal(t, "user", cs[1].Name())
assert.Equal(t, "123", cs[1].Value())
}
}

View File

@ -60,7 +60,7 @@ func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc {
return next(c) return next(c)
} }
auth := c.Request().Header().Get(echo.HeaderAuthorization) auth := c.Request().Header.Get(echo.HeaderAuthorization)
l := len(basic) l := len(basic)
if len(auth) > l+1 && auth[:l] == basic { if len(auth) > l+1 && auth[:l] == basic {

View File

@ -3,17 +3,17 @@ package middleware
import ( import (
"encoding/base64" "encoding/base64"
"net/http" "net/http"
"net/http/httptest"
"testing" "testing"
"github.com/labstack/echo" "github.com/labstack/echo"
"github.com/labstack/echo/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestBasicAuth(t *testing.T) { func TestBasicAuth(t *testing.T) {
e := echo.New() e := echo.New()
req := test.NewRequest(echo.GET, "/", nil) req, _ := http.NewRequest(echo.GET, "/", nil)
res := test.NewResponseRecorder() res := httptest.NewRecorder()
c := e.NewContext(req, res) c := e.NewContext(req, res)
f := func(u, p string) bool { f := func(u, p string) bool {
if u == "joe" && p == "secret" { if u == "joe" && p == "secret" {
@ -27,24 +27,24 @@ func TestBasicAuth(t *testing.T) {
// Valid credentials // Valid credentials
auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
req.Header().Set(echo.HeaderAuthorization, auth) req.Header.Set(echo.HeaderAuthorization, auth)
assert.NoError(t, h(c)) assert.NoError(t, h(c))
// Incorrect password // Incorrect password
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:password")) auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:password"))
req.Header().Set(echo.HeaderAuthorization, auth) req.Header.Set(echo.HeaderAuthorization, auth)
he := h(c).(*echo.HTTPError) he := h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusUnauthorized, he.Code) assert.Equal(t, http.StatusUnauthorized, he.Code)
assert.Equal(t, basic+" realm=Restricted", res.Header().Get(echo.HeaderWWWAuthenticate)) assert.Equal(t, basic+" realm=Restricted", res.Header().Get(echo.HeaderWWWAuthenticate))
// Empty Authorization header // Empty Authorization header
req.Header().Set(echo.HeaderAuthorization, "") req.Header.Set(echo.HeaderAuthorization, "")
he = h(c).(*echo.HTTPError) he = h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusUnauthorized, he.Code) assert.Equal(t, http.StatusUnauthorized, he.Code)
// Invalid Authorization header // Invalid Authorization header
auth = base64.StdEncoding.EncodeToString([]byte("invalid")) auth = base64.StdEncoding.EncodeToString([]byte("invalid"))
req.Header().Set(echo.HeaderAuthorization, auth) req.Header.Set(echo.HeaderAuthorization, auth)
he = h(c).(*echo.HTTPError) he = h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusUnauthorized, he.Code) assert.Equal(t, http.StatusUnauthorized, he.Code)
} }

View File

@ -23,7 +23,7 @@ type (
limitedReader struct { limitedReader struct {
BodyLimitConfig BodyLimitConfig
reader io.Reader reader io.ReadCloser
read int64 read int64
context echo.Context context echo.Context
} }
@ -74,15 +74,15 @@ func BodyLimitWithConfig(config BodyLimitConfig) echo.MiddlewareFunc {
req := c.Request() req := c.Request()
// Based on content length // Based on content length
if req.ContentLength() > config.limit { if req.ContentLength > config.limit {
return echo.ErrStatusRequestEntityTooLarge return echo.ErrStatusRequestEntityTooLarge
} }
// Based on content read // Based on content read
r := pool.Get().(*limitedReader) r := pool.Get().(*limitedReader)
r.Reset(req.Body(), c) r.Reset(req.Body, c)
defer pool.Put(r) defer pool.Put(r)
req.SetBody(r) req.Body = r
return next(c) return next(c)
} }
@ -98,7 +98,11 @@ func (r *limitedReader) Read(b []byte) (n int, err error) {
return return
} }
func (r *limitedReader) Reset(reader io.Reader, context echo.Context) { func (r *limitedReader) Close() error {
return r.reader.Close()
}
func (r *limitedReader) Reset(reader io.ReadCloser, context echo.Context) {
r.reader = reader r.reader = reader
r.context = context r.context = context
} }

View File

@ -4,21 +4,22 @@ import (
"bytes" "bytes"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/http/httptest"
"testing" "testing"
"github.com/labstack/echo" "github.com/labstack/echo"
"github.com/labstack/echo/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestBodyLimit(t *testing.T) { func TestBodyLimit(t *testing.T) {
return
e := echo.New() e := echo.New()
hw := []byte("Hello, World!") hw := []byte("Hello, World!")
req := test.NewRequest(echo.POST, "/", bytes.NewReader(hw)) req, _ := http.NewRequest(echo.POST, "/", bytes.NewReader(hw))
rec := test.NewResponseRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
h := func(c echo.Context) error { h := func(c echo.Context) error {
body, err := ioutil.ReadAll(c.Request().Body()) body, err := ioutil.ReadAll(c.Request().Body)
if err != nil { if err != nil {
return err return err
} }
@ -27,8 +28,8 @@ func TestBodyLimit(t *testing.T) {
// Based on content length (within limit) // Based on content length (within limit)
if assert.NoError(t, BodyLimit("2M")(h)(c)) { if assert.NoError(t, BodyLimit("2M")(h)(c)) {
assert.Equal(t, http.StatusOK, rec.Status()) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, hw, rec.Body.Bytes()) assert.Equal(t, hw, rec.Body.Bytes)
} }
// Based on content read (overlimit) // Based on content read (overlimit)
@ -36,17 +37,17 @@ func TestBodyLimit(t *testing.T) {
assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code) assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code)
// Based on content read (within limit) // Based on content read (within limit)
req = test.NewRequest(echo.POST, "/", bytes.NewReader(hw)) req, _ = http.NewRequest(echo.POST, "/", bytes.NewReader(hw))
rec = test.NewResponseRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec) c = e.NewContext(req, rec)
if assert.NoError(t, BodyLimit("2M")(h)(c)) { if assert.NoError(t, BodyLimit("2M")(h)(c)) {
assert.Equal(t, http.StatusOK, rec.Status()) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "Hello, World!", rec.Body.String()) assert.Equal(t, "Hello, World!", rec.Body.String())
} }
// Based on content read (overlimit) // Based on content read (overlimit)
req = test.NewRequest(echo.POST, "/", bytes.NewReader(hw)) req, _ = http.NewRequest(echo.POST, "/", bytes.NewReader(hw))
rec = test.NewResponseRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec) c = e.NewContext(req, rec)
he = BodyLimit("2B")(h)(c).(*echo.HTTPError) he = BodyLimit("2B")(h)(c).(*echo.HTTPError)
assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code) assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code)

View File

@ -1,15 +1,16 @@
package middleware package middleware
import ( import (
"bufio"
"compress/gzip" "compress/gzip"
"io" "io"
"io/ioutil" "io/ioutil"
"net"
"net/http" "net/http"
"strings" "strings"
"sync" "sync"
"github.com/labstack/echo" "github.com/labstack/echo"
"github.com/labstack/echo/engine"
) )
type ( type (
@ -24,8 +25,8 @@ type (
} }
gzipResponseWriter struct { gzipResponseWriter struct {
engine.Response
io.Writer io.Writer
http.ResponseWriter
} }
) )
@ -65,36 +66,51 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
res := c.Response() res := c.Response()
res.Header().Add(echo.HeaderVary, echo.HeaderAcceptEncoding) res.Header().Add(echo.HeaderVary, echo.HeaderAcceptEncoding)
if strings.Contains(c.Request().Header().Get(echo.HeaderAcceptEncoding), scheme) { if strings.Contains(c.Request().Header.Get(echo.HeaderAcceptEncoding), scheme) {
rw := res.Writer() rw := res.Writer()
gw := pool.Get().(*gzip.Writer) w := pool.Get().(*gzip.Writer)
gw.Reset(rw) w.Reset(c.Response().Writer())
// rw := res.Writer()
// gw := pool.Get().(*gzip.Writer)
// gw.Reset(rw)
defer func() { defer func() {
if res.Size() == 0 { if res.Size == 0 {
// We have to reset response to it's pristine state when // We have to reset response to it's pristine state when
// nothing is written to body or error is returned. // nothing is written to body or error is returned.
// See issue #424, #407. // See issue #424, #407.
res.SetWriter(rw) res.SetWriter(rw)
res.Header().Del(echo.HeaderContentEncoding) res.Header().Del(echo.HeaderContentEncoding)
gw.Reset(ioutil.Discard) w.Reset(ioutil.Discard)
} }
gw.Close() w.Close()
pool.Put(gw) pool.Put(w)
}() }()
g := gzipResponseWriter{Response: res, Writer: gw} grw := gzipResponseWriter{Writer: w, ResponseWriter: res.Writer()}
res.Header().Set(echo.HeaderContentEncoding, scheme) res.Header().Set(echo.HeaderContentEncoding, scheme)
res.SetWriter(g) res.SetWriter(grw)
} }
return next(c) return next(c)
} }
} }
} }
func (g gzipResponseWriter) Write(b []byte) (int, error) { func (w gzipResponseWriter) Write(b []byte) (int, error) {
if g.Header().Get(echo.HeaderContentType) == "" { if w.Header().Get(echo.HeaderContentType) == "" {
g.Header().Set(echo.HeaderContentType, http.DetectContentType(b)) w.Header().Set(echo.HeaderContentType, http.DetectContentType(b))
} }
return g.Writer.Write(b) return w.Writer.Write(b)
}
func (w gzipResponseWriter) Flush() error {
return w.Writer.(*gzip.Writer).Flush()
}
func (w gzipResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return w.ResponseWriter.(http.Hijacker).Hijack()
}
func (w *gzipResponseWriter) CloseNotify() <-chan bool {
return w.ResponseWriter.(http.CloseNotifier).CloseNotify()
} }
func gzipPool(config GzipConfig) sync.Pool { func gzipPool(config GzipConfig) sync.Pool {

View File

@ -4,17 +4,17 @@ import (
"bytes" "bytes"
"compress/gzip" "compress/gzip"
"net/http" "net/http"
"net/http/httptest"
"testing" "testing"
"github.com/labstack/echo" "github.com/labstack/echo"
"github.com/labstack/echo/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestGzip(t *testing.T) { func TestGzip(t *testing.T) {
e := echo.New() e := echo.New()
req := test.NewRequest(echo.GET, "/", nil) req, _ := http.NewRequest(echo.GET, "/", nil)
rec := test.NewResponseRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
// Skip if no Accept-Encoding header // Skip if no Accept-Encoding header
@ -25,9 +25,9 @@ func TestGzip(t *testing.T) {
h(c) h(c)
assert.Equal(t, "test", rec.Body.String()) assert.Equal(t, "test", rec.Body.String())
req = test.NewRequest(echo.GET, "/", nil) req, _ = http.NewRequest(echo.GET, "/", nil)
req.Header().Set(echo.HeaderAcceptEncoding, "gzip") req.Header.Set(echo.HeaderAcceptEncoding, "gzip")
rec = test.NewResponseRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec) c = e.NewContext(req, rec)
// Gzip // Gzip
@ -45,8 +45,8 @@ func TestGzip(t *testing.T) {
func TestGzipNoContent(t *testing.T) { func TestGzipNoContent(t *testing.T) {
e := echo.New() e := echo.New()
req := test.NewRequest(echo.GET, "/", nil) req, _ := http.NewRequest(echo.GET, "/", nil)
rec := test.NewResponseRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
h := Gzip()(func(c echo.Context) error { h := Gzip()(func(c echo.Context) error {
return c.NoContent(http.StatusOK) return c.NoContent(http.StatusOK)
@ -64,9 +64,9 @@ func TestGzipErrorReturned(t *testing.T) {
e.GET("/", func(c echo.Context) error { e.GET("/", func(c echo.Context) error {
return echo.NewHTTPError(http.StatusInternalServerError, "error") return echo.NewHTTPError(http.StatusInternalServerError, "error")
}) })
req := test.NewRequest(echo.GET, "/", nil) req, _ := http.NewRequest(echo.GET, "/", nil)
rec := test.NewResponseRecorder() rec := httptest.NewRecorder()
e.ServeHTTP(req, rec) e.ServeHTTP(rec, req)
assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding)) assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding))
assert.Equal(t, "error", rec.Body.String()) assert.Equal(t, "error", rec.Body.String())
} }

View File

@ -88,8 +88,8 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
req := c.Request() req := c.Request()
res := c.Response() res := c.Response()
origin := req.Header().Get(echo.HeaderOrigin) origin := req.Header.Get(echo.HeaderOrigin)
originSet := req.Header().Contains(echo.HeaderOrigin) // Issue #517 _, originSet := req.Header[echo.HeaderOrigin]
// Check allowed origins // Check allowed origins
allowedOrigin := "" allowedOrigin := ""
@ -101,7 +101,7 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
} }
// Simple request // Simple request
if req.Method() != echo.OPTIONS { if req.Method != echo.OPTIONS {
res.Header().Add(echo.HeaderVary, echo.HeaderOrigin) res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)
if !originSet || allowedOrigin == "" { if !originSet || allowedOrigin == "" {
return next(c) return next(c)
@ -131,7 +131,7 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
if allowHeaders != "" { if allowHeaders != "" {
res.Header().Set(echo.HeaderAccessControlAllowHeaders, allowHeaders) res.Header().Set(echo.HeaderAccessControlAllowHeaders, allowHeaders)
} else { } else {
h := req.Header().Get(echo.HeaderAccessControlRequestHeaders) h := req.Header.Get(echo.HeaderAccessControlRequestHeaders)
if h != "" { if h != "" {
res.Header().Set(echo.HeaderAccessControlAllowHeaders, h) res.Header().Set(echo.HeaderAccessControlAllowHeaders, h)
} }

View File

@ -2,17 +2,17 @@ package middleware
import ( import (
"net/http" "net/http"
"net/http/httptest"
"testing" "testing"
"github.com/labstack/echo" "github.com/labstack/echo"
"github.com/labstack/echo/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestCORS(t *testing.T) { func TestCORS(t *testing.T) {
e := echo.New() e := echo.New()
req := test.NewRequest(echo.GET, "/", nil) req, _ := http.NewRequest(echo.GET, "/", nil)
rec := test.NewResponseRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
cors := CORSWithConfig(CORSConfig{ cors := CORSWithConfig(CORSConfig{
AllowCredentials: true, AllowCredentials: true,
@ -26,26 +26,26 @@ func TestCORS(t *testing.T) {
assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
// Empty origin header // Empty origin header
req = test.NewRequest(echo.GET, "/", nil) req, _ = http.NewRequest(echo.GET, "/", nil)
rec = test.NewResponseRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec) c = e.NewContext(req, rec)
req.Header().Set(echo.HeaderOrigin, "") req.Header.Set(echo.HeaderOrigin, "")
h(c) h(c)
assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
// Wildcard origin // Wildcard origin
req = test.NewRequest(echo.GET, "/", nil) req, _ = http.NewRequest(echo.GET, "/", nil)
rec = test.NewResponseRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec) c = e.NewContext(req, rec)
req.Header().Set(echo.HeaderOrigin, "localhost") req.Header.Set(echo.HeaderOrigin, "localhost")
h(c) h(c)
assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
// Simple request // Simple request
req = test.NewRequest(echo.GET, "/", nil) req, _ = http.NewRequest(echo.GET, "/", nil)
rec = test.NewResponseRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec) c = e.NewContext(req, rec)
req.Header().Set(echo.HeaderOrigin, "localhost") req.Header.Set(echo.HeaderOrigin, "localhost")
cors = CORSWithConfig(CORSConfig{ cors = CORSWithConfig(CORSConfig{
AllowOrigins: []string{"localhost"}, AllowOrigins: []string{"localhost"},
AllowCredentials: true, AllowCredentials: true,
@ -58,11 +58,11 @@ func TestCORS(t *testing.T) {
assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
// Preflight request // Preflight request
req = test.NewRequest(echo.OPTIONS, "/", nil) req, _ = http.NewRequest(echo.OPTIONS, "/", nil)
rec = test.NewResponseRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec) c = e.NewContext(req, rec)
req.Header().Set(echo.HeaderOrigin, "localhost") req.Header.Set(echo.HeaderOrigin, "localhost")
req.Header().Set(echo.HeaderContentType, echo.MIMEApplicationJSON) req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON)
h(c) h(c)
assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) assert.Equal(t, "localhost", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods)) assert.NotEmpty(t, rec.Header().Get(echo.HeaderAccessControlAllowMethods))

View File

@ -131,10 +131,10 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
token = random.String(config.TokenLength) token = random.String(config.TokenLength)
} else { } else {
// Reuse token // Reuse token
token = k.Value() token = k.Value
} }
switch req.Method() { switch req.Method {
case echo.GET, echo.HEAD, echo.OPTIONS, echo.TRACE: case echo.GET, echo.HEAD, echo.OPTIONS, echo.TRACE:
default: default:
// Validate token only for requests which are not defined as 'safe' by RFC7231 // Validate token only for requests which are not defined as 'safe' by RFC7231
@ -148,18 +148,18 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
} }
// Set CSRF cookie // Set CSRF cookie
cookie := new(echo.Cookie) cookie := new(http.Cookie)
cookie.SetName(config.CookieName) cookie.Name = config.CookieName
cookie.SetValue(token) cookie.Value = token
if config.CookiePath != "" { if config.CookiePath != "" {
cookie.SetPath(config.CookiePath) cookie.Path = config.CookiePath
} }
if config.CookieDomain != "" { if config.CookieDomain != "" {
cookie.SetDomain(config.CookieDomain) cookie.Domain = config.CookieDomain
} }
cookie.SetExpires(time.Now().Add(time.Duration(config.CookieMaxAge) * time.Second)) cookie.Expires = time.Now().Add(time.Duration(config.CookieMaxAge) * time.Second)
cookie.SetSecure(config.CookieSecure) cookie.Secure = config.CookieSecure
cookie.SetHTTPOnly(config.CookieHTTPOnly) cookie.HttpOnly = config.CookieHTTPOnly
c.SetCookie(cookie) c.SetCookie(cookie)
// Store token in the context // Store token in the context
@ -177,7 +177,7 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
// provided request header. // provided request header.
func csrfTokenFromHeader(header string) csrfTokenExtractor { func csrfTokenFromHeader(header string) csrfTokenExtractor {
return func(c echo.Context) (string, error) { return func(c echo.Context) (string, error) {
return c.Request().Header().Get(header), nil return c.Request().Header.Get(header), nil
} }
} }

View File

@ -2,20 +2,20 @@ package middleware
import ( import (
"net/http" "net/http"
"net/http/httptest"
"net/url" "net/url"
"strings" "strings"
"testing" "testing"
"github.com/labstack/echo" "github.com/labstack/echo"
"github.com/labstack/echo/test"
"github.com/labstack/gommon/random" "github.com/labstack/gommon/random"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestCSRF(t *testing.T) { func TestCSRF(t *testing.T) {
e := echo.New() e := echo.New()
req := test.NewRequest(echo.GET, "/", nil) req, _ := http.NewRequest(echo.GET, "/", nil)
rec := test.NewResponseRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
csrf := CSRFWithConfig(CSRFConfig{ csrf := CSRFWithConfig(CSRFConfig{
TokenLength: 16, TokenLength: 16,
@ -29,24 +29,24 @@ func TestCSRF(t *testing.T) {
assert.Contains(t, rec.Header().Get(echo.HeaderSetCookie), "_csrf") assert.Contains(t, rec.Header().Get(echo.HeaderSetCookie), "_csrf")
// Without CSRF cookie // Without CSRF cookie
req = test.NewRequest(echo.POST, "/", nil) req, _ = http.NewRequest(echo.POST, "/", nil)
rec = test.NewResponseRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec) c = e.NewContext(req, rec)
assert.Error(t, h(c)) assert.Error(t, h(c))
// Empty/invalid CSRF token // Empty/invalid CSRF token
req = test.NewRequest(echo.POST, "/", nil) req, _ = http.NewRequest(echo.POST, "/", nil)
rec = test.NewResponseRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec) c = e.NewContext(req, rec)
req.Header().Set(echo.HeaderXCSRFToken, "") req.Header.Set(echo.HeaderXCSRFToken, "")
assert.Error(t, h(c)) assert.Error(t, h(c))
// Valid CSRF token // Valid CSRF token
token := random.String(16) token := random.String(16)
req.Header().Set(echo.HeaderCookie, "_csrf="+token) req.Header.Set(echo.HeaderCookie, "_csrf="+token)
req.Header().Set(echo.HeaderXCSRFToken, token) req.Header.Set(echo.HeaderXCSRFToken, token)
if assert.NoError(t, h(c)) { if assert.NoError(t, h(c)) {
assert.Equal(t, http.StatusOK, rec.Status()) assert.Equal(t, http.StatusOK, rec.Code)
} }
} }
@ -54,8 +54,8 @@ func TestCSRFTokenFromForm(t *testing.T) {
f := make(url.Values) f := make(url.Values)
f.Set("csrf", "token") f.Set("csrf", "token")
e := echo.New() e := echo.New()
req := test.NewRequest(echo.POST, "/", strings.NewReader(f.Encode())) req, _ := http.NewRequest(echo.POST, "/", strings.NewReader(f.Encode()))
req.Header().Add(echo.HeaderContentType, echo.MIMEApplicationForm) req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm)
c := e.NewContext(req, nil) c := e.NewContext(req, nil)
token, err := csrfTokenFromForm("csrf")(c) token, err := csrfTokenFromForm("csrf")(c)
if assert.NoError(t, err) { if assert.NoError(t, err) {
@ -69,8 +69,8 @@ func TestCSRFTokenFromQuery(t *testing.T) {
q := make(url.Values) q := make(url.Values)
q.Set("csrf", "token") q.Set("csrf", "token")
e := echo.New() e := echo.New()
req := test.NewRequest(echo.GET, "/?"+q.Encode(), nil) req, _ := http.NewRequest(echo.GET, "/?"+q.Encode(), nil)
req.Header().Add(echo.HeaderContentType, echo.MIMEApplicationForm) req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm)
c := e.NewContext(req, nil) c := e.NewContext(req, nil)
token, err := csrfTokenFromQuery("csrf")(c) token, err := csrfTokenFromQuery("csrf")(c)
if assert.NoError(t, err) { if assert.NoError(t, err) {

View File

@ -153,7 +153,7 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
// jwtFromHeader returns a `jwtExtractor` that extracts token from request header. // jwtFromHeader returns a `jwtExtractor` that extracts token from request header.
func jwtFromHeader(header string) jwtExtractor { func jwtFromHeader(header string) jwtExtractor {
return func(c echo.Context) (string, error) { return func(c echo.Context) (string, error) {
auth := c.Request().Header().Get(header) auth := c.Request().Header.Get(header)
l := len(bearer) l := len(bearer)
if len(auth) > l+1 && auth[:l] == bearer { if len(auth) > l+1 && auth[:l] == bearer {
return auth[l+1:], nil return auth[l+1:], nil
@ -181,6 +181,6 @@ func jwtFromCookie(name string) jwtExtractor {
if err != nil { if err != nil {
return "", errors.New("empty jwt in cookie") return "", errors.New("empty jwt in cookie")
} }
return cookie.Value(), nil return cookie.Value, nil
} }
} }

View File

@ -2,11 +2,11 @@ package middleware
import ( import (
"net/http" "net/http"
"net/http/httptest"
"testing" "testing"
"github.com/dgrijalva/jwt-go" "github.com/dgrijalva/jwt-go"
"github.com/labstack/echo" "github.com/labstack/echo"
"github.com/labstack/echo/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -148,10 +148,10 @@ func TestJWT(t *testing.T) {
tc.reqURL = "/" tc.reqURL = "/"
} }
req := test.NewRequest(echo.GET, tc.reqURL, nil) req, _ := http.NewRequest(echo.GET, tc.reqURL, nil)
res := test.NewResponseRecorder() res := httptest.NewRecorder()
req.Header().Set(echo.HeaderAuthorization, tc.hdrAuth) req.Header.Set(echo.HeaderAuthorization, tc.hdrAuth)
req.Header().Set(echo.HeaderCookie, tc.hdrCookie) req.Header.Set(echo.HeaderCookie, tc.hdrCookie)
c := e.NewContext(req, res) c := e.NewContext(req, res)
if tc.expPanic { if tc.expPanic {

View File

@ -117,16 +117,16 @@ func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc {
case "time_rfc3339": case "time_rfc3339":
return w.Write([]byte(time.Now().Format(time.RFC3339))) return w.Write([]byte(time.Now().Format(time.RFC3339)))
case "remote_ip": case "remote_ip":
ra := req.RealIP() ra := c.RealIP()
return w.Write([]byte(ra)) return w.Write([]byte(ra))
case "host": case "host":
return w.Write([]byte(req.Host())) return w.Write([]byte(req.Host))
case "uri": case "uri":
return w.Write([]byte(req.URI())) return w.Write([]byte(req.RequestURI))
case "method": case "method":
return w.Write([]byte(req.Method())) return w.Write([]byte(req.Method))
case "path": case "path":
p := req.URL().Path() p := req.URL.Path
if p == "" { if p == "" {
p = "/" p = "/"
} }
@ -136,7 +136,7 @@ func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc {
case "user_agent": case "user_agent":
return w.Write([]byte(req.UserAgent())) return w.Write([]byte(req.UserAgent()))
case "status": case "status":
n := res.Status() n := res.Status
s := config.color.Green(n) s := config.color.Green(n)
switch { switch {
case n >= 500: case n >= 500:
@ -153,13 +153,13 @@ func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc {
case "latency_human": case "latency_human":
return w.Write([]byte(stop.Sub(start).String())) return w.Write([]byte(stop.Sub(start).String()))
case "bytes_in": case "bytes_in":
b := req.Header().Get(echo.HeaderContentLength) b := req.Header.Get(echo.HeaderContentLength)
if b == "" { if b == "" {
b = "0" b = "0"
} }
return w.Write([]byte(b)) return w.Write([]byte(b))
case "bytes_out": case "bytes_out":
return w.Write([]byte(strconv.FormatInt(res.Size(), 10))) return w.Write([]byte(strconv.FormatInt(res.Size, 10)))
} }
return 0, nil return 0, nil
}) })

View File

@ -4,18 +4,18 @@ import (
"bytes" "bytes"
"errors" "errors"
"net/http" "net/http"
"net/http/httptest"
"testing" "testing"
"github.com/labstack/echo" "github.com/labstack/echo"
"github.com/labstack/echo/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestLogger(t *testing.T) { func TestLogger(t *testing.T) {
// Note: Just for the test coverage, not a real test. // Note: Just for the test coverage, not a real test.
e := echo.New() e := echo.New()
req := test.NewRequest(echo.GET, "/", nil) req, _ := http.NewRequest(echo.GET, "/", nil)
rec := test.NewResponseRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
h := Logger()(func(c echo.Context) error { h := Logger()(func(c echo.Context) error {
return c.String(http.StatusOK, "test") return c.String(http.StatusOK, "test")
@ -25,7 +25,7 @@ func TestLogger(t *testing.T) {
h(c) h(c)
// Status 3xx // Status 3xx
rec = test.NewResponseRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec) c = e.NewContext(req, rec)
h = Logger()(func(c echo.Context) error { h = Logger()(func(c echo.Context) error {
return c.String(http.StatusTemporaryRedirect, "test") return c.String(http.StatusTemporaryRedirect, "test")
@ -33,7 +33,7 @@ func TestLogger(t *testing.T) {
h(c) h(c)
// Status 4xx // Status 4xx
rec = test.NewResponseRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec) c = e.NewContext(req, rec)
h = Logger()(func(c echo.Context) error { h = Logger()(func(c echo.Context) error {
return c.String(http.StatusNotFound, "test") return c.String(http.StatusNotFound, "test")
@ -41,8 +41,8 @@ func TestLogger(t *testing.T) {
h(c) h(c)
// Status 5xx with empty path // Status 5xx with empty path
req = test.NewRequest(echo.GET, "", nil) req, _ = http.NewRequest(echo.GET, "/", nil)
rec = test.NewResponseRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec) c = e.NewContext(req, rec)
h = Logger()(func(c echo.Context) error { h = Logger()(func(c echo.Context) error {
return errors.New("error") return errors.New("error")
@ -52,25 +52,25 @@ func TestLogger(t *testing.T) {
func TestLoggerIPAddress(t *testing.T) { func TestLoggerIPAddress(t *testing.T) {
e := echo.New() e := echo.New()
req := test.NewRequest(echo.GET, "/", nil) req, _ := http.NewRequest(echo.GET, "/", nil)
rec := test.NewResponseRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
e.Logger().SetOutput(buf) e.Logger.SetOutput(buf)
ip := "127.0.0.1" ip := "127.0.0.1"
h := Logger()(func(c echo.Context) error { h := Logger()(func(c echo.Context) error {
return c.String(http.StatusOK, "test") return c.String(http.StatusOK, "test")
}) })
// With X-Real-IP // With X-Real-IP
req.Header().Add(echo.HeaderXRealIP, ip) req.Header.Add(echo.HeaderXRealIP, ip)
h(c) h(c)
assert.Contains(t, ip, buf.String()) assert.Contains(t, ip, buf.String())
// With X-Forwarded-For // With X-Forwarded-For
buf.Reset() buf.Reset()
req.Header().Del(echo.HeaderXRealIP) req.Header.Del(echo.HeaderXRealIP)
req.Header().Add(echo.HeaderXForwardedFor, ip) req.Header.Add(echo.HeaderXForwardedFor, ip)
h(c) h(c)
assert.Contains(t, ip, buf.String()) assert.Contains(t, ip, buf.String())

View File

@ -52,10 +52,10 @@ func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc {
} }
req := c.Request() req := c.Request()
if req.Method() == echo.POST { if req.Method == echo.POST {
m := config.Getter(c) m := config.Getter(c)
if m != "" { if m != "" {
req.SetMethod(m) req.Method = m
} }
} }
return next(c) return next(c)
@ -67,7 +67,7 @@ func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc {
// the request header. // the request header.
func MethodFromHeader(header string) MethodOverrideGetter { func MethodFromHeader(header string) MethodOverrideGetter {
return func(c echo.Context) string { return func(c echo.Context) string {
return c.Request().Header().Get(header) return c.Request().Header.Get(header)
} }
} }

View File

@ -3,10 +3,10 @@ package middleware
import ( import (
"bytes" "bytes"
"net/http" "net/http"
"net/http/httptest"
"testing" "testing"
"github.com/labstack/echo" "github.com/labstack/echo"
"github.com/labstack/echo/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -18,32 +18,32 @@ func TestMethodOverride(t *testing.T) {
} }
// Override with http header // Override with http header
req := test.NewRequest(echo.POST, "/", nil) req, _ := http.NewRequest(echo.POST, "/", nil)
rec := test.NewResponseRecorder() rec := httptest.NewRecorder()
req.Header().Set(echo.HeaderXHTTPMethodOverride, echo.DELETE) req.Header.Set(echo.HeaderXHTTPMethodOverride, echo.DELETE)
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
m(h)(c) m(h)(c)
assert.Equal(t, echo.DELETE, req.Method()) assert.Equal(t, echo.DELETE, req.Method)
// Override with form parameter // Override with form parameter
m = MethodOverrideWithConfig(MethodOverrideConfig{Getter: MethodFromForm("_method")}) m = MethodOverrideWithConfig(MethodOverrideConfig{Getter: MethodFromForm("_method")})
req = test.NewRequest(echo.POST, "/", bytes.NewReader([]byte("_method="+echo.DELETE))) req, _ = http.NewRequest(echo.POST, "/", bytes.NewReader([]byte("_method="+echo.DELETE)))
rec = test.NewResponseRecorder() rec = httptest.NewRecorder()
req.Header().Set(echo.HeaderContentType, echo.MIMEApplicationForm) req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm)
c = e.NewContext(req, rec) c = e.NewContext(req, rec)
m(h)(c) m(h)(c)
assert.Equal(t, echo.DELETE, req.Method()) assert.Equal(t, echo.DELETE, req.Method)
// Override with query paramter // Override with query paramter
m = MethodOverrideWithConfig(MethodOverrideConfig{Getter: MethodFromQuery("_method")}) m = MethodOverrideWithConfig(MethodOverrideConfig{Getter: MethodFromQuery("_method")})
req = test.NewRequest(echo.POST, "/?_method="+echo.DELETE, nil) req, _ = http.NewRequest(echo.POST, "/?_method="+echo.DELETE, nil)
rec = test.NewResponseRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec) c = e.NewContext(req, rec)
m(h)(c) m(h)(c)
assert.Equal(t, echo.DELETE, req.Method()) assert.Equal(t, echo.DELETE, req.Method)
// Ignore `GET` // Ignore `GET`
req = test.NewRequest(echo.GET, "/", nil) req, _ = http.NewRequest(echo.GET, "/", nil)
req.Header().Set(echo.HeaderXHTTPMethodOverride, echo.DELETE) req.Header.Set(echo.HeaderXHTTPMethodOverride, echo.DELETE)
assert.Equal(t, echo.GET, req.Method()) assert.Equal(t, echo.GET, req.Method)
} }

View File

@ -3,24 +3,24 @@ package middleware
import ( import (
"bytes" "bytes"
"net/http" "net/http"
"net/http/httptest"
"testing" "testing"
"github.com/labstack/echo" "github.com/labstack/echo"
"github.com/labstack/echo/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestRecover(t *testing.T) { func TestRecover(t *testing.T) {
e := echo.New() e := echo.New()
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
e.SetLogOutput(buf) e.Logger.SetOutput(buf)
req := test.NewRequest(echo.GET, "/", nil) req, _ := http.NewRequest(echo.GET, "/", nil)
rec := test.NewResponseRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
h := Recover()(echo.HandlerFunc(func(c echo.Context) error { h := Recover()(echo.HandlerFunc(func(c echo.Context) error {
panic("test") panic("test")
})) }))
h(c) h(c)
assert.Equal(t, http.StatusInternalServerError, rec.Status()) assert.Equal(t, http.StatusInternalServerError, rec.Code)
assert.Contains(t, buf.String(), "PANIC RECOVER") assert.Contains(t, buf.String(), "PANIC RECOVER")
} }

View File

@ -52,9 +52,10 @@ func HTTPSRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
} }
req := c.Request() req := c.Request()
host := req.Host() host := req.Host
uri := req.URI() uri := req.RequestURI
if !req.IsTLS() { println(uri)
if !c.IsTLS() {
return c.Redirect(config.Code, "https://"+host+uri) return c.Redirect(config.Code, "https://"+host+uri)
} }
return next(c) return next(c)
@ -88,9 +89,9 @@ func HTTPSWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
} }
req := c.Request() req := c.Request()
host := req.Host() host := req.Host
uri := req.URI() uri := req.RequestURI
if !req.IsTLS() && host[:3] != "www" { if !c.IsTLS() && host[:3] != "www" {
return c.Redirect(http.StatusMovedPermanently, "https://www."+host+uri) return c.Redirect(http.StatusMovedPermanently, "https://www."+host+uri)
} }
return next(c) return next(c)
@ -124,10 +125,10 @@ func WWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
} }
req := c.Request() req := c.Request()
scheme := req.Scheme() scheme := c.Scheme()
host := req.Host() host := req.Host
if host[:3] != "www" { if host[:3] != "www" {
uri := req.URI() uri := req.RequestURI
return c.Redirect(http.StatusMovedPermanently, scheme+"://www."+host+uri) return c.Redirect(http.StatusMovedPermanently, scheme+"://www."+host+uri)
} }
return next(c) return next(c)
@ -160,10 +161,10 @@ func NonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc {
} }
req := c.Request() req := c.Request()
scheme := req.Scheme() scheme := c.Scheme()
host := req.Host() host := req.Host
if host[:3] == "www" { if host[:3] == "www" {
uri := req.URI() uri := req.RequestURI
return c.Redirect(http.StatusMovedPermanently, scheme+"://"+host[4:]+uri) return c.Redirect(http.StatusMovedPermanently, scheme+"://"+host[4:]+uri)
} }
return next(c) return next(c)

View File

@ -2,61 +2,61 @@ package middleware
import ( import (
"net/http" "net/http"
"net/http/httptest"
"testing" "testing"
"github.com/labstack/echo" "github.com/labstack/echo"
"github.com/labstack/echo/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestHTTPSRedirect(t *testing.T) { func TestRedirectHTTPSRedirect(t *testing.T) {
e := echo.New() e := echo.New()
next := func(c echo.Context) (err error) { next := func(c echo.Context) (err error) {
return c.NoContent(http.StatusOK) return c.NoContent(http.StatusOK)
} }
req := test.NewRequest(echo.GET, "http://labstack.com", nil) req, _ := http.NewRequest(echo.GET, "http://labstack.com", nil)
res := test.NewResponseRecorder() res := httptest.NewRecorder()
c := e.NewContext(req, res) c := e.NewContext(req, res)
HTTPSRedirect()(next)(c) HTTPSRedirect()(next)(c)
assert.Equal(t, http.StatusMovedPermanently, res.Status()) assert.Equal(t, http.StatusMovedPermanently, res.Code)
assert.Equal(t, "https://labstack.com", res.Header().Get(echo.HeaderLocation)) assert.Equal(t, "https://labstack.com", res.Header().Get(echo.HeaderLocation))
} }
func TestHTTPSWWWRedirect(t *testing.T) { func TestRedirectHTTPSWWWRedirect(t *testing.T) {
e := echo.New() e := echo.New()
next := func(c echo.Context) (err error) { next := func(c echo.Context) (err error) {
return c.NoContent(http.StatusOK) return c.NoContent(http.StatusOK)
} }
req := test.NewRequest(echo.GET, "http://labstack.com", nil) req, _ := http.NewRequest(echo.GET, "http://labstack.com", nil)
res := test.NewResponseRecorder() res := httptest.NewRecorder()
c := e.NewContext(req, res) c := e.NewContext(req, res)
HTTPSWWWRedirect()(next)(c) HTTPSWWWRedirect()(next)(c)
assert.Equal(t, http.StatusMovedPermanently, res.Status()) assert.Equal(t, http.StatusMovedPermanently, res.Code)
assert.Equal(t, "https://www.labstack.com", res.Header().Get(echo.HeaderLocation)) assert.Equal(t, "https://www.labstack.com", res.Header().Get(echo.HeaderLocation))
} }
func TestWWWRedirect(t *testing.T) { func TestRedirectWWWRedirect(t *testing.T) {
e := echo.New() e := echo.New()
next := func(c echo.Context) (err error) { next := func(c echo.Context) (err error) {
return c.NoContent(http.StatusOK) return c.NoContent(http.StatusOK)
} }
req := test.NewRequest(echo.GET, "http://labstack.com", nil) req, _ := http.NewRequest(echo.GET, "http://labstack.com", nil)
res := test.NewResponseRecorder() res := httptest.NewRecorder()
c := e.NewContext(req, res) c := e.NewContext(req, res)
WWWRedirect()(next)(c) WWWRedirect()(next)(c)
assert.Equal(t, http.StatusMovedPermanently, res.Status()) assert.Equal(t, http.StatusMovedPermanently, res.Code)
assert.Equal(t, "http://www.labstack.com", res.Header().Get(echo.HeaderLocation)) assert.Equal(t, "http://www.labstack.com", res.Header().Get(echo.HeaderLocation))
} }
func TestNonWWWRedirect(t *testing.T) { func TestRedirectNonWWWRedirect(t *testing.T) {
e := echo.New() e := echo.New()
next := func(c echo.Context) (err error) { next := func(c echo.Context) (err error) {
return c.NoContent(http.StatusOK) return c.NoContent(http.StatusOK)
} }
req := test.NewRequest(echo.GET, "http://www.labstack.com", nil) req, _ := http.NewRequest(echo.GET, "http://www.labstack.com", nil)
res := test.NewResponseRecorder() res := httptest.NewRecorder()
c := e.NewContext(req, res) c := e.NewContext(req, res)
NonWWWRedirect()(next)(c) NonWWWRedirect()(next)(c)
assert.Equal(t, http.StatusMovedPermanently, res.Status()) assert.Equal(t, http.StatusMovedPermanently, res.Code)
assert.Equal(t, "http://labstack.com", res.Header().Get(echo.HeaderLocation)) assert.Equal(t, "http://labstack.com", res.Header().Get(echo.HeaderLocation))
} }

View File

@ -100,7 +100,7 @@ func SecureWithConfig(config SecureConfig) echo.MiddlewareFunc {
if config.XFrameOptions != "" { if config.XFrameOptions != "" {
res.Header().Set(echo.HeaderXFrameOptions, config.XFrameOptions) res.Header().Set(echo.HeaderXFrameOptions, config.XFrameOptions)
} }
if (req.IsTLS() || (req.Header().Get(echo.HeaderXForwardedProto) == "https")) && config.HSTSMaxAge != 0 { if (c.IsTLS() || (req.Header.Get(echo.HeaderXForwardedProto) == "https")) && config.HSTSMaxAge != 0 {
subdomains := "" subdomains := ""
if !config.HSTSExcludeSubdomains { if !config.HSTSExcludeSubdomains {
subdomains = "; includeSubdomains" subdomains = "; includeSubdomains"

View File

@ -2,17 +2,17 @@ package middleware
import ( import (
"net/http" "net/http"
"net/http/httptest"
"testing" "testing"
"github.com/labstack/echo" "github.com/labstack/echo"
"github.com/labstack/echo/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestSecure(t *testing.T) { func TestSecure(t *testing.T) {
e := echo.New() e := echo.New()
req := test.NewRequest(echo.GET, "/", nil) req, _ := http.NewRequest(echo.GET, "/", nil)
rec := test.NewResponseRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
h := func(c echo.Context) error { h := func(c echo.Context) error {
return c.String(http.StatusOK, "test") return c.String(http.StatusOK, "test")
@ -27,8 +27,8 @@ func TestSecure(t *testing.T) {
assert.Equal(t, "", rec.Header().Get(echo.HeaderContentSecurityPolicy)) assert.Equal(t, "", rec.Header().Get(echo.HeaderContentSecurityPolicy))
// Custom // Custom
req.Header().Set(echo.HeaderXForwardedProto, "https") req.Header.Set(echo.HeaderXForwardedProto, "https")
rec = test.NewResponseRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec) c = e.NewContext(req, rec)
SecureWithConfig(SecureConfig{ SecureWithConfig(SecureConfig{
XSSProtection: "", XSSProtection: "",

View File

@ -46,9 +46,9 @@ func AddTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc
} }
req := c.Request() req := c.Request()
url := req.URL() url := req.URL
path := url.Path() path := url.Path
qs := url.QueryString() qs := c.QueryString()
if path != "/" && path[len(path)-1] != '/' { if path != "/" && path[len(path)-1] != '/' {
path += "/" path += "/"
uri := path uri := path
@ -62,8 +62,8 @@ func AddTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc
} }
// Forward // Forward
req.SetURI(uri) req.RequestURI = uri
url.SetPath(path) url.Path = path
} }
return next(c) return next(c)
} }
@ -93,9 +93,9 @@ func RemoveTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFu
} }
req := c.Request() req := c.Request()
url := req.URL() url := req.URL
path := url.Path() path := url.Path
qs := url.QueryString() qs := c.QueryString()
l := len(path) - 1 l := len(path) - 1
if l >= 0 && path != "/" && path[l] == '/' { if l >= 0 && path != "/" && path[l] == '/' {
path = path[:l] path = path[:l]
@ -110,8 +110,8 @@ func RemoveTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFu
} }
// Forward // Forward
req.SetURI(uri) req.RequestURI = uri
url.SetPath(path) url.Path = path
} }
return next(c) return next(c)
} }

View File

@ -2,28 +2,28 @@ package middleware
import ( import (
"net/http" "net/http"
"net/http/httptest"
"testing" "testing"
"github.com/labstack/echo" "github.com/labstack/echo"
"github.com/labstack/echo/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestAddTrailingSlash(t *testing.T) { func TestAddTrailingSlash(t *testing.T) {
e := echo.New() e := echo.New()
req := test.NewRequest(echo.GET, "/add-slash", nil) req, _ := http.NewRequest(echo.GET, "/add-slash", nil)
rec := test.NewResponseRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
h := AddTrailingSlash()(func(c echo.Context) error { h := AddTrailingSlash()(func(c echo.Context) error {
return nil return nil
}) })
h(c) h(c)
assert.Equal(t, "/add-slash/", req.URL().Path()) assert.Equal(t, "/add-slash/", req.URL.Path)
assert.Equal(t, "/add-slash/", req.URI()) assert.Equal(t, "/add-slash/", req.RequestURI)
// With config // With config
req = test.NewRequest(echo.GET, "/add-slash?key=value", nil) req, _ = http.NewRequest(echo.GET, "/add-slash?key=value", nil)
rec = test.NewResponseRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec) c = e.NewContext(req, rec)
h = AddTrailingSlashWithConfig(TrailingSlashConfig{ h = AddTrailingSlashWithConfig(TrailingSlashConfig{
RedirectCode: http.StatusMovedPermanently, RedirectCode: http.StatusMovedPermanently,
@ -31,25 +31,25 @@ func TestAddTrailingSlash(t *testing.T) {
return nil return nil
}) })
h(c) h(c)
assert.Equal(t, http.StatusMovedPermanently, rec.Status()) assert.Equal(t, http.StatusMovedPermanently, rec.Code)
assert.Equal(t, "/add-slash/?key=value", rec.Header().Get(echo.HeaderLocation)) assert.Equal(t, "/add-slash/?key=value", rec.Header().Get(echo.HeaderLocation))
} }
func TestRemoveTrailingSlash(t *testing.T) { func TestRemoveTrailingSlash(t *testing.T) {
e := echo.New() e := echo.New()
req := test.NewRequest(echo.GET, "/remove-slash/", nil) req, _ := http.NewRequest(echo.GET, "/remove-slash/", nil)
rec := test.NewResponseRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
h := RemoveTrailingSlash()(func(c echo.Context) error { h := RemoveTrailingSlash()(func(c echo.Context) error {
return nil return nil
}) })
h(c) h(c)
assert.Equal(t, "/remove-slash", req.URL().Path()) assert.Equal(t, "/remove-slash", req.URL.Path)
assert.Equal(t, "/remove-slash", req.URI()) assert.Equal(t, "/remove-slash", req.RequestURI)
// With config // With config
req = test.NewRequest(echo.GET, "/remove-slash/?key=value", nil) req, _ = http.NewRequest(echo.GET, "/remove-slash/?key=value", nil)
rec = test.NewResponseRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec) c = e.NewContext(req, rec)
h = RemoveTrailingSlashWithConfig(TrailingSlashConfig{ h = RemoveTrailingSlashWithConfig(TrailingSlashConfig{
RedirectCode: http.StatusMovedPermanently, RedirectCode: http.StatusMovedPermanently,
@ -57,16 +57,16 @@ func TestRemoveTrailingSlash(t *testing.T) {
return nil return nil
}) })
h(c) h(c)
assert.Equal(t, http.StatusMovedPermanently, rec.Status()) assert.Equal(t, http.StatusMovedPermanently, rec.Code)
assert.Equal(t, "/remove-slash?key=value", rec.Header().Get(echo.HeaderLocation)) assert.Equal(t, "/remove-slash?key=value", rec.Header().Get(echo.HeaderLocation))
// With bare URL // With bare URL
req = test.NewRequest(echo.GET, "http://localhost", nil) req, _ = http.NewRequest(echo.GET, "http://localhost", nil)
rec = test.NewResponseRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec) c = e.NewContext(req, rec)
h = RemoveTrailingSlash()(func(c echo.Context) error { h = RemoveTrailingSlash()(func(c echo.Context) error {
return nil return nil
}) })
h(c) h(c)
assert.Equal(t, "", req.URL().Path()) assert.Equal(t, "", req.URL.Path)
} }

View File

@ -68,7 +68,7 @@ func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc {
} }
fs := http.Dir(config.Root) fs := http.Dir(config.Root)
p := c.Request().URL().Path() p := c.Request().URL.Path
if strings.Contains(c.Path(), "*") { // If serving from a group, e.g. `/static*`. if strings.Contains(c.Path(), "*") { // If serving from a group, e.g. `/static*`.
p = c.P(0) p = c.P(0)
} }

View File

@ -2,17 +2,17 @@ package middleware
import ( import (
"net/http" "net/http"
"net/http/httptest"
"testing" "testing"
"github.com/labstack/echo" "github.com/labstack/echo"
"github.com/labstack/echo/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestStatic(t *testing.T) { func TestStatic(t *testing.T) {
e := echo.New() e := echo.New()
req := test.NewRequest(echo.GET, "/", nil) req, _ := http.NewRequest(echo.GET, "/", nil)
rec := test.NewResponseRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
h := Static("../_fixture")(func(c echo.Context) error { h := Static("../_fixture")(func(c echo.Context) error {
return echo.ErrNotFound return echo.ErrNotFound
@ -24,8 +24,8 @@ func TestStatic(t *testing.T) {
} }
// HTML5 mode // HTML5 mode
req = test.NewRequest(echo.GET, "/client", nil) req, _ = http.NewRequest(echo.GET, "/client", nil)
rec = test.NewResponseRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec) c = e.NewContext(req, rec)
static := StaticWithConfig(StaticConfig{ static := StaticWithConfig(StaticConfig{
Root: "../_fixture", Root: "../_fixture",
@ -35,12 +35,12 @@ func TestStatic(t *testing.T) {
return echo.ErrNotFound return echo.ErrNotFound
}) })
if assert.NoError(t, h(c)) { if assert.NoError(t, h(c)) {
assert.Equal(t, http.StatusOK, rec.Status()) assert.Equal(t, http.StatusOK, rec.Code)
} }
// Browse // Browse
req = test.NewRequest(echo.GET, "/", nil) req, _ = http.NewRequest(echo.GET, "/", nil)
rec = test.NewResponseRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec) c = e.NewContext(req, rec)
static = StaticWithConfig(StaticConfig{ static = StaticWithConfig(StaticConfig{
Root: "../_fixture/images", Root: "../_fixture/images",
@ -54,8 +54,8 @@ func TestStatic(t *testing.T) {
} }
// Not found // Not found
req = test.NewRequest(echo.GET, "/not-found", nil) req, _ = http.NewRequest(echo.GET, "/not-found", nil)
rec = test.NewResponseRecorder() rec = httptest.NewRecorder()
c = e.NewContext(req, rec) c = e.NewContext(req, rec)
static = StaticWithConfig(StaticConfig{ static = StaticWithConfig(StaticConfig{
Root: "../_fixture/images", Root: "../_fixture/images",

99
response.go Normal file
View File

@ -0,0 +1,99 @@
package echo
import (
"bufio"
"net"
"net/http"
)
type (
// Response wraps an http.ResponseWriter and implements its interface to be used
// by an HTTP handler to construct an HTTP response.
// See: https://golang.org/pkg/net/http/#ResponseWriter
Response struct {
writer http.ResponseWriter
Status int
Size int64
Committed bool
echo *Echo
}
)
// NewResponse creates a new instance of Response.
func NewResponse(w http.ResponseWriter, e *Echo) (r *Response) {
return &Response{writer: w, echo: e}
}
// SetWriter sets the http.ResponseWriter instance for this Response.
func (r *Response) SetWriter(w http.ResponseWriter) {
r.writer = w
}
// Writer returns the http.ResponseWriter instance for this Response.
func (r *Response) Writer() http.ResponseWriter {
return r.writer
}
// Header returns the header map for the writer that will be sent by
// WriteHeader. Changing the header after a call to WriteHeader (or Write) has
// no effect unless the modified headers were declared as trailers by setting
// the "Trailer" header before the call to WriteHeader (see example)
// To suppress implicit response headers, set their value to nil.
// Example: https://golang.org/pkg/net/http/#example_ResponseWriter_trailers
func (r *Response) Header() http.Header {
return r.writer.Header()
}
// WriteHeader sends an HTTP response header with status code. If WriteHeader is
// not called explicitly, the first call to Write will trigger an implicit
// WriteHeader(http.StatusOK). Thus explicit calls to WriteHeader are mainly
// used to send error codes.
func (r *Response) WriteHeader(code int) {
if r.Committed {
r.echo.Logger.Warn("response already committed")
return
}
r.Status = code
r.writer.WriteHeader(code)
r.Committed = true
}
// Write writes the data to the connection as part of an HTTP reply.
func (r *Response) Write(b []byte) (n int, err error) {
if !r.Committed {
r.WriteHeader(http.StatusOK)
}
n, err = r.writer.Write(b)
r.Size += int64(n)
return
}
// Flush implements the http.Flusher interface to allow an HTTP handler to flush
// buffered data to the client.
// See [http.Flusher](https://golang.org/pkg/net/http/#Flusher)
func (r *Response) Flush() {
r.writer.(http.Flusher).Flush()
}
// Hijack implements the http.Hijacker interface to allow an HTTP handler to
// take over the connection.
// See [http.Hijacker](https://golang.org/pkg/net/http/#Hijacker)
func (r *Response) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return r.writer.(http.Hijacker).Hijack()
}
// CloseNotify implements the http.CloseNotifier interface to allow detecting
// when the underlying connection has gone away.
// This mechanism can be used to cancel long operations on the server if the
// client has disconnected before the response is ready.
// See [http.CloseNotifier](https://golang.org/pkg/net/http/#CloseNotifier)
func (r *Response) CloseNotify() <-chan bool {
return r.writer.(http.CloseNotifier).CloseNotify()
}
func (r *Response) reset(w http.ResponseWriter) {
r.writer = w
r.Size = 0
r.Status = http.StatusOK
r.Committed = false
}

View File

@ -1,48 +0,0 @@
package test
import (
"net/http"
"time"
)
type (
// Cookie implements `engine.Cookie`.
Cookie struct {
*http.Cookie
}
)
// Name implements `engine.Cookie#Name` function.
func (c *Cookie) Name() string {
return c.Cookie.Name
}
// Value implements `engine.Cookie#Value` function.
func (c *Cookie) Value() string {
return c.Cookie.Value
}
// Path implements `engine.Cookie#Path` function.
func (c *Cookie) Path() string {
return c.Cookie.Path
}
// Domain implements `engine.Cookie#Domain` function.
func (c *Cookie) Domain() string {
return c.Cookie.Domain
}
// Expires implements `engine.Cookie#Expires` function.
func (c *Cookie) Expires() time.Time {
return c.Cookie.Expires
}
// Secure implements `engine.Cookie#Secure` function.
func (c *Cookie) Secure() bool {
return c.Cookie.Secure
}
// HTTPOnly implements `engine.Cookie#HTTPOnly` function.
func (c *Cookie) HTTPOnly() bool {
return c.Cookie.HttpOnly
}

View File

@ -1,44 +0,0 @@
package test
import "net/http"
type (
Header struct {
header http.Header
}
)
func (h *Header) Add(key, val string) {
h.header.Add(key, val)
}
func (h *Header) Del(key string) {
h.header.Del(key)
}
func (h *Header) Get(key string) string {
return h.header.Get(key)
}
func (h *Header) Set(key, val string) {
h.header.Set(key, val)
}
func (h *Header) Keys() (keys []string) {
keys = make([]string, len(h.header))
i := 0
for k := range h.header {
keys[i] = k
i++
}
return
}
func (h *Header) Contains(key string) bool {
_, ok := h.header[key]
return ok
}
func (h *Header) reset(hdr http.Header) {
h.header = hdr
}

View File

@ -1,176 +0,0 @@
package test
import (
"errors"
"io"
"io/ioutil"
"mime/multipart"
"net"
"net/http"
"strings"
"github.com/labstack/echo/engine"
)
type (
Request struct {
request *http.Request
url engine.URL
header engine.Header
}
)
const (
defaultMemory = 32 << 20 // 32 MB
)
func NewRequest(method, url string, body io.Reader) engine.Request {
r, _ := http.NewRequest(method, url, body)
return &Request{
request: r,
url: &URL{url: r.URL},
header: &Header{r.Header},
}
}
func (r *Request) IsTLS() bool {
return r.request.TLS != nil
}
func (r *Request) Scheme() string {
if r.IsTLS() {
return "https"
}
return "http"
}
func (r *Request) Host() string {
return r.request.Host
}
func (r *Request) SetHost(host string) {
r.request.Host = host
}
func (r *Request) URL() engine.URL {
return r.url
}
func (r *Request) Header() engine.Header {
return r.header
}
func (r *Request) Referer() string {
return r.request.Referer()
}
// func Proto() string {
// return r.request.Proto()
// }
//
// func ProtoMajor() int {
// return r.request.ProtoMajor()
// }
//
// func ProtoMinor() int {
// return r.request.ProtoMinor()
// }
func (r *Request) ContentLength() int64 {
return r.request.ContentLength
}
func (r *Request) UserAgent() string {
return r.request.UserAgent()
}
func (r *Request) RemoteAddress() string {
return r.request.RemoteAddr
}
func (r *Request) RealIP() string {
ra := r.RemoteAddress()
if ip := r.Header().Get("X-Forwarded-For"); ip != "" {
ra = ip
} else if ip := r.Header().Get("X-Real-IP"); ip != "" {
ra = ip
} else {
ra, _, _ = net.SplitHostPort(ra)
}
return ra
}
func (r *Request) Method() string {
return r.request.Method
}
func (r *Request) SetMethod(method string) {
r.request.Method = method
}
func (r *Request) URI() string {
return r.request.RequestURI
}
func (r *Request) SetURI(uri string) {
r.request.RequestURI = uri
}
func (r *Request) Body() io.Reader {
return r.request.Body
}
func (r *Request) SetBody(reader io.Reader) {
r.request.Body = ioutil.NopCloser(reader)
}
func (r *Request) FormValue(name string) string {
return r.request.FormValue(name)
}
func (r *Request) FormParams() map[string][]string {
if strings.HasPrefix(r.header.Get("Content-Type"), "multipart/form-data") {
if err := r.request.ParseMultipartForm(defaultMemory); err != nil {
panic(err)
}
} else {
if err := r.request.ParseForm(); err != nil {
panic(err)
}
}
return map[string][]string(r.request.Form)
}
func (r *Request) FormFile(name string) (*multipart.FileHeader, error) {
_, fh, err := r.request.FormFile(name)
return fh, err
}
func (r *Request) MultipartForm() (*multipart.Form, error) {
err := r.request.ParseMultipartForm(defaultMemory)
return r.request.MultipartForm, err
}
func (r *Request) Cookie(name string) (engine.Cookie, error) {
c, err := r.request.Cookie(name)
if err != nil {
return nil, errors.New("cookie not found")
}
return &Cookie{c}, nil
}
// Cookies implements `engine.Request#Cookies` function.
func (r *Request) Cookies() []engine.Cookie {
cs := r.request.Cookies()
cookies := make([]engine.Cookie, len(cs))
for i, c := range cs {
cookies[i] = &Cookie{c}
}
return cookies
}
func (r *Request) reset(req *http.Request, h engine.Header, u engine.URL) {
r.request = req
r.header = h
r.url = u
}

View File

@ -1,103 +0,0 @@
package test
import (
"bytes"
"io"
"net/http"
"net/http/httptest"
"github.com/labstack/echo/engine"
"github.com/labstack/gommon/log"
)
type (
Response struct {
response http.ResponseWriter
header engine.Header
status int
size int64
committed bool
writer io.Writer
logger *log.Logger
}
ResponseRecorder struct {
engine.Response
Body *bytes.Buffer
}
)
func NewResponseRecorder() *ResponseRecorder {
rec := httptest.NewRecorder()
return &ResponseRecorder{
Response: &Response{
response: rec,
header: &Header{rec.Header()},
writer: rec,
logger: log.New("test"),
},
Body: rec.Body,
}
}
func (r *Response) Header() engine.Header {
return r.header
}
func (r *Response) WriteHeader(code int) {
if r.committed {
r.logger.Warn("response already committed")
return
}
r.status = code
r.response.WriteHeader(code)
r.committed = true
}
func (r *Response) Write(b []byte) (n int, err error) {
n, err = r.writer.Write(b)
r.size += int64(n)
return
}
// SetCookie implements `engine.Response#SetCookie` function.
func (r *Response) SetCookie(c engine.Cookie) {
http.SetCookie(r.response, &http.Cookie{
Name: c.Name(),
Value: c.Value(),
Path: c.Path(),
Domain: c.Domain(),
Expires: c.Expires(),
Secure: c.Secure(),
HttpOnly: c.HTTPOnly(),
})
}
func (r *Response) Status() int {
return r.status
}
func (r *Response) Size() int64 {
return r.size
}
func (r *Response) Committed() bool {
return r.committed
}
func (r *Response) SetWriter(w io.Writer) {
r.writer = w
}
func (r *Response) Writer() io.Writer {
return r.writer
}
func (r *Response) reset(w http.ResponseWriter, h engine.Header) {
r.response = w
r.header = h
r.status = http.StatusOK
r.size = 0
r.committed = false
r.writer = w
}

View File

@ -1,129 +0,0 @@
package test
import (
"net/http"
"sync"
"github.com/labstack/echo/engine"
"github.com/labstack/gommon/log"
)
type (
Server struct {
*http.Server
config *engine.Config
handler engine.Handler
pool *Pool
logger *log.Logger
}
Pool struct {
request sync.Pool
response sync.Pool
header sync.Pool
url sync.Pool
}
)
func New(addr string) *Server {
c := &engine.Config{Address: addr}
return NewConfig(c)
}
func NewTLS(addr, certFile, keyFile string) *Server {
c := &engine.Config{
Address: addr,
TLSCertFile: certFile,
TLSKeyFile: keyFile,
}
return NewConfig(c)
}
func NewConfig(c *engine.Config) (s *Server) {
s = &Server{
Server: new(http.Server),
config: c,
pool: &Pool{
request: sync.Pool{
New: func() interface{} {
return &Request{}
},
},
response: sync.Pool{
New: func() interface{} {
return &Response{logger: s.logger}
},
},
header: sync.Pool{
New: func() interface{} {
return &Header{}
},
},
url: sync.Pool{
New: func() interface{} {
return &URL{}
},
},
},
handler: engine.HandlerFunc(func(req engine.Request, res engine.Response) {
panic("echo: handler not set, use `Server#SetHandler()` to set it.")
}),
logger: log.New("echo"),
}
return
}
func (s *Server) SetHandler(h engine.Handler) {
s.handler = h
}
func (s *Server) SetLogger(l *log.Logger) {
s.logger = l
}
func (s *Server) Start() error {
if s.config.Listener == nil {
return s.startDefaultListener()
}
return s.startCustomListener()
}
func (s *Server) Stop() error {
return nil
}
func (s *Server) startDefaultListener() error {
c := s.config
if c.TLSCertFile != "" && c.TLSKeyFile != "" {
return s.ListenAndServeTLS(c.TLSCertFile, c.TLSKeyFile)
}
return s.ListenAndServe()
}
func (s *Server) startCustomListener() error {
return s.Serve(s.config.Listener)
}
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Request
req := s.pool.request.Get().(*Request)
reqHdr := s.pool.header.Get().(*Header)
reqURL := s.pool.url.Get().(*URL)
reqHdr.reset(r.Header)
reqURL.reset(r.URL)
req.reset(r, reqHdr, reqURL)
// Response
res := s.pool.response.Get().(*Response)
resHdr := s.pool.header.Get().(*Header)
resHdr.reset(w.Header())
res.reset(w, resHdr)
s.handler.ServeHTTP(req, res)
s.pool.request.Put(req)
s.pool.header.Put(reqHdr)
s.pool.url.Put(reqURL)
s.pool.response.Put(res)
s.pool.header.Put(resHdr)
}

View File

@ -1,44 +0,0 @@
package test
import "net/url"
type (
URL struct {
url *url.URL
query url.Values
}
)
func (u *URL) URL() *url.URL {
return u.url
}
func (u *URL) SetPath(path string) {
u.url.Path = path
}
func (u *URL) Path() string {
return u.url.Path
}
func (u *URL) QueryParam(name string) string {
if u.query == nil {
u.query = u.url.Query()
}
return u.query.Get(name)
}
func (u *URL) QueryParams() map[string][]string {
if u.query == nil {
u.query = u.url.Query()
}
return map[string][]string(u.query)
}
func (u *URL) QueryString() string {
return u.url.RawQuery
}
func (u *URL) reset(url *url.URL) {
u.url = url
}