diff --git a/binder.go b/binder.go
new file mode 100644
index 00000000..c17a6507
--- /dev/null
+++ b/binder.go
@@ -0,0 +1,173 @@
+package echo
+
+import (
+ "encoding/json"
+ "encoding/xml"
+ "errors"
+ "net/http"
+ "reflect"
+ "strconv"
+ "strings"
+)
+
+type (
+ // Binder is the interface that wraps the Bind method.
+ Binder interface {
+ Bind(interface{}, Context) error
+ }
+
+ binder struct{}
+)
+
+func (b *binder) Bind(i interface{}, c Context) (err error) {
+ req := c.Request()
+ ctype := req.Header().Get(HeaderContentType)
+ if req.Body() == nil {
+ err = NewHTTPError(http.StatusBadRequest, "request body can't be empty")
+ return
+ }
+ err = ErrUnsupportedMediaType
+ switch {
+ case strings.HasPrefix(ctype, MIMEApplicationJSON):
+ if err = json.NewDecoder(req.Body()).Decode(i); err != nil {
+ err = NewHTTPError(http.StatusBadRequest, err.Error())
+ }
+ case strings.HasPrefix(ctype, MIMEApplicationXML):
+ if err = xml.NewDecoder(req.Body()).Decode(i); err != nil {
+ err = NewHTTPError(http.StatusBadRequest, err.Error())
+ }
+ case strings.HasPrefix(ctype, MIMEApplicationForm), strings.HasPrefix(ctype, MIMEMultipartForm):
+ if err = b.bindForm(i, req.FormParams()); err != nil {
+ err = NewHTTPError(http.StatusBadRequest, err.Error())
+ }
+ }
+ return
+}
+
+func (b *binder) bindForm(ptr interface{}, form map[string][]string) error {
+ typ := reflect.TypeOf(ptr).Elem()
+ val := reflect.ValueOf(ptr).Elem()
+
+ for i := 0; i < typ.NumField(); i++ {
+ typeField := typ.Field(i)
+ structField := val.Field(i)
+ if !structField.CanSet() {
+ continue
+ }
+ structFieldKind := structField.Kind()
+ inputFieldName := typeField.Tag.Get("form")
+
+ if inputFieldName == "" {
+ inputFieldName = typeField.Name
+ // If "form" tag is nil, we inspect if the field is a struct.
+ if structFieldKind == reflect.Struct {
+ err := b.bindForm(structField.Addr().Interface(), form)
+ if err != nil {
+ return err
+ }
+ continue
+ }
+ }
+ inputValue, exists := form[inputFieldName]
+ if !exists {
+ continue
+ }
+
+ numElems := len(inputValue)
+ if structFieldKind == reflect.Slice && numElems > 0 {
+ sliceOf := structField.Type().Elem().Kind()
+ slice := reflect.MakeSlice(structField.Type(), numElems, numElems)
+ for i := 0; i < numElems; i++ {
+ if err := setWithProperType(sliceOf, inputValue[i], slice.Index(i)); err != nil {
+ return err
+ }
+ }
+ val.Field(i).Set(slice)
+ } else {
+ if err := setWithProperType(typeField.Type.Kind(), inputValue[0], structField); err != nil {
+ return err
+ }
+ }
+ }
+ return nil
+}
+
+func setWithProperType(valueKind reflect.Kind, val string, structField reflect.Value) error {
+ switch valueKind {
+ case reflect.Int:
+ return setIntField(val, 0, structField)
+ case reflect.Int8:
+ return setIntField(val, 8, structField)
+ case reflect.Int16:
+ return setIntField(val, 16, structField)
+ case reflect.Int32:
+ return setIntField(val, 32, structField)
+ case reflect.Int64:
+ return setIntField(val, 64, structField)
+ case reflect.Uint:
+ return setUintField(val, 0, structField)
+ case reflect.Uint8:
+ return setUintField(val, 8, structField)
+ case reflect.Uint16:
+ return setUintField(val, 16, structField)
+ case reflect.Uint32:
+ return setUintField(val, 32, structField)
+ case reflect.Uint64:
+ return setUintField(val, 64, structField)
+ case reflect.Bool:
+ return setBoolField(val, structField)
+ case reflect.Float32:
+ return setFloatField(val, 32, structField)
+ case reflect.Float64:
+ return setFloatField(val, 64, structField)
+ case reflect.String:
+ structField.SetString(val)
+ default:
+ return errors.New("unknown type")
+ }
+ return nil
+}
+
+func setIntField(value string, bitSize int, field reflect.Value) error {
+ if value == "" {
+ value = "0"
+ }
+ intVal, err := strconv.ParseInt(value, 10, bitSize)
+ if err == nil {
+ field.SetInt(intVal)
+ }
+ return err
+}
+
+func setUintField(value string, bitSize int, field reflect.Value) error {
+ if value == "" {
+ value = "0"
+ }
+ uintVal, err := strconv.ParseUint(value, 10, bitSize)
+ if err == nil {
+ field.SetUint(uintVal)
+ }
+ return err
+}
+
+func setBoolField(value string, field reflect.Value) error {
+ if value == "" {
+ value = "false"
+ }
+ boolVal, err := strconv.ParseBool(value)
+ if err == nil {
+ field.SetBool(boolVal)
+ }
+ return err
+}
+
+func setFloatField(value string, bitSize int, field reflect.Value) error {
+ if value == "" {
+ value = "0.0"
+ }
+ floatVal, err := strconv.ParseFloat(value, bitSize)
+ if err == nil {
+ field.SetFloat(floatVal)
+ }
+ return err
+}
diff --git a/binder_test.go b/binder_test.go
new file mode 100644
index 00000000..a1508b80
--- /dev/null
+++ b/binder_test.go
@@ -0,0 +1,214 @@
+package echo
+
+import (
+ "bytes"
+ "io"
+ "mime/multipart"
+ "net/http"
+ "reflect"
+ "strings"
+ "testing"
+
+ "github.com/labstack/echo/test"
+ "github.com/stretchr/testify/assert"
+)
+
+type (
+ binderTestStruct struct {
+ I int
+ I8 int8
+ I16 int16
+ I32 int32
+ I64 int64
+ UI uint
+ UI8 uint8
+ UI16 uint16
+ UI32 uint32
+ UI64 uint64
+ B bool
+ F32 float32
+ F64 float64
+ S string
+ cantSet string
+ DoesntExist string
+ }
+)
+
+func (t binderTestStruct) GetCantSet() string {
+ return t.cantSet
+}
+
+var values = map[string][]string{
+ "I": {"0"},
+ "I8": {"8"},
+ "I16": {"16"},
+ "I32": {"32"},
+ "I64": {"64"},
+ "UI": {"0"},
+ "UI8": {"8"},
+ "UI16": {"16"},
+ "UI32": {"32"},
+ "UI64": {"64"},
+ "B": {"true"},
+ "F32": {"32.5"},
+ "F64": {"64.5"},
+ "S": {"test"},
+ "cantSet": {"test"},
+}
+
+func TestBinderJSON(t *testing.T) {
+ testBinderOkay(t, strings.NewReader(userJSON), MIMEApplicationJSON)
+ testBinderError(t, strings.NewReader(invalidContent), MIMEApplicationJSON)
+}
+
+func TestBinderXML(t *testing.T) {
+ testBinderOkay(t, strings.NewReader(userXML), MIMEApplicationXML)
+ testBinderError(t, strings.NewReader(invalidContent), MIMEApplicationXML)
+}
+
+func TestBinderForm(t *testing.T) {
+ testBinderOkay(t, strings.NewReader(userForm), MIMEApplicationForm)
+}
+
+func TestBinderMultipartForm(t *testing.T) {
+ body := new(bytes.Buffer)
+ mw := multipart.NewWriter(body)
+ mw.WriteField("id", "1")
+ mw.WriteField("name", "Jon Snow")
+ mw.Close()
+ testBinderOkay(t, body, mw.FormDataContentType())
+}
+
+func TestBinderUnsupportedMediaType(t *testing.T) {
+ testBinderError(t, strings.NewReader(invalidContent), MIMEApplicationJSON)
+}
+
+// func assertCustomer(t *testing.T, c *user) {
+// assert.Equal(t, 1, c.ID)
+// assert.Equal(t, "Joe", c.Name)
+// }
+
+func TestBinderbindForm(t *testing.T) {
+ ts := new(binderTestStruct)
+ b := new(binder)
+ b.bindForm(ts, values)
+ assertBinderTestStruct(t, ts)
+}
+
+func TestBinderSetWithProperType(t *testing.T) {
+ ts := new(binderTestStruct)
+ typ := reflect.TypeOf(ts).Elem()
+ val := reflect.ValueOf(ts).Elem()
+ for i := 0; i < typ.NumField(); i++ {
+ typeField := typ.Field(i)
+ structField := val.Field(i)
+ if !structField.CanSet() {
+ continue
+ }
+ if len(values[typeField.Name]) == 0 {
+ continue
+ }
+ val := values[typeField.Name][0]
+ err := setWithProperType(typeField.Type.Kind(), val, structField)
+ assert.NoError(t, err)
+ }
+ assertBinderTestStruct(t, ts)
+
+ type foo struct {
+ Bar bytes.Buffer
+ }
+ v := &foo{}
+ typ = reflect.TypeOf(v).Elem()
+ val = reflect.ValueOf(v).Elem()
+ assert.Error(t, setWithProperType(typ.Field(0).Type.Kind(), "5", val.Field(0)))
+}
+
+func TestBinderSetFields(t *testing.T) {
+ ts := new(binderTestStruct)
+ val := reflect.ValueOf(ts).Elem()
+ // Int
+ if assert.NoError(t, setIntField("5", 0, val.FieldByName("I"))) {
+ assert.Equal(t, 5, ts.I)
+ }
+ if assert.NoError(t, setIntField("", 0, val.FieldByName("I"))) {
+ assert.Equal(t, 0, ts.I)
+ }
+
+ // Uint
+ if assert.NoError(t, setUintField("10", 0, val.FieldByName("UI"))) {
+ assert.Equal(t, uint(10), ts.UI)
+ }
+ if assert.NoError(t, setUintField("", 0, val.FieldByName("UI"))) {
+ assert.Equal(t, uint(0), ts.UI)
+ }
+
+ // Float
+ if assert.NoError(t, setFloatField("15.5", 0, val.FieldByName("F32"))) {
+ assert.Equal(t, float32(15.5), ts.F32)
+ }
+ if assert.NoError(t, setFloatField("", 0, val.FieldByName("F32"))) {
+ assert.Equal(t, float32(0.0), ts.F32)
+ }
+
+ // Bool
+ if assert.NoError(t, setBoolField("true", val.FieldByName("B"))) {
+ assert.Equal(t, true, ts.B)
+ }
+ if assert.NoError(t, setBoolField("", val.FieldByName("B"))) {
+ assert.Equal(t, false, ts.B)
+ }
+}
+
+func assertBinderTestStruct(t *testing.T, ts *binderTestStruct) {
+ assert.Equal(t, 0, ts.I)
+ assert.Equal(t, int8(8), ts.I8)
+ assert.Equal(t, int16(16), ts.I16)
+ assert.Equal(t, int32(32), ts.I32)
+ assert.Equal(t, int64(64), ts.I64)
+ assert.Equal(t, uint(0), ts.UI)
+ assert.Equal(t, uint8(8), ts.UI8)
+ assert.Equal(t, uint16(16), ts.UI16)
+ assert.Equal(t, uint32(32), ts.UI32)
+ assert.Equal(t, uint64(64), ts.UI64)
+ assert.Equal(t, true, ts.B)
+ assert.Equal(t, float32(32.5), ts.F32)
+ assert.Equal(t, float64(64.5), ts.F64)
+ assert.Equal(t, "test", ts.S)
+ assert.Equal(t, "", ts.GetCantSet())
+}
+
+func testBinderOkay(t *testing.T, r io.Reader, ctype string) {
+ e := New()
+ req := test.NewRequest(POST, "/", r)
+ rec := test.NewResponseRecorder()
+ c := e.NewContext(req, rec)
+ req.Header().Set(HeaderContentType, ctype)
+ u := new(user)
+ err := c.Bind(u)
+ if assert.NoError(t, err) {
+ assert.Equal(t, 1, u.ID)
+ assert.Equal(t, "Jon Snow", u.Name)
+ }
+}
+
+func testBinderError(t *testing.T, r io.Reader, ctype string) {
+ e := New()
+ req := test.NewRequest(POST, "/", r)
+ rec := test.NewResponseRecorder()
+ c := e.NewContext(req, rec)
+ req.Header().Set(HeaderContentType, ctype)
+ u := new(user)
+ err := c.Bind(u)
+
+ switch {
+ case strings.HasPrefix(ctype, MIMEApplicationJSON), strings.HasPrefix(ctype, MIMEApplicationXML),
+ strings.HasPrefix(ctype, MIMEApplicationForm), strings.HasPrefix(ctype, MIMEMultipartForm):
+ if assert.IsType(t, new(HTTPError), err) {
+ assert.Equal(t, http.StatusBadRequest, err.(*HTTPError).Code)
+ }
+ default:
+ if assert.IsType(t, new(HTTPError), err) {
+ assert.Equal(t, ErrUnsupportedMediaType, err)
+ }
+ }
+}
diff --git a/context_test.go b/context_test.go
index 221c32c7..8059ddec 100644
--- a/context_test.go
+++ b/context_test.go
@@ -30,10 +30,6 @@ func (t *Template) Render(w io.Writer, name string, data interface{}, c Context)
}
func TestContext(t *testing.T) {
- userJSON := `{"id":"1","name":"Joe"}`
- userXML := `1Joe`
- invalidContent := "invalid content"
-
e := New()
req := test.NewRequest(POST, "/", strings.NewReader(userJSON))
rec := test.NewResponseRecorder()
@@ -58,26 +54,8 @@ func TestContext(t *testing.T) {
assert.Equal(t, "1", c.Param("id"))
// Store
- c.Set("user", "Joe")
- assert.Equal(t, "Joe", c.Get("user"))
-
- //------
- // Bind
- //------
-
- // JSON
- testBindOk(t, c, MIMEApplicationJSON)
- c.request = test.NewRequest(POST, "/", strings.NewReader(invalidContent))
- testBindError(t, c, MIMEApplicationJSON)
-
- // XML
- c.request = test.NewRequest(POST, "/", strings.NewReader(userXML))
- testBindOk(t, c, MIMEApplicationXML)
- c.request = test.NewRequest(POST, "/", strings.NewReader(invalidContent))
- testBindError(t, c, MIMEApplicationXML)
-
- // Unsupported
- testBindError(t, c, "")
+ c.Set("user", "Jon Snow")
+ assert.Equal(t, "Jon Snow", c.Get("user"))
//--------
// Render
@@ -87,20 +65,20 @@ func TestContext(t *testing.T) {
templates: template.Must(template.New("hello").Parse("Hello, {{.}}!")),
}
c.echo.SetRenderer(tpl)
- err := c.Render(http.StatusOK, "hello", "Joe")
+ err := c.Render(http.StatusOK, "hello", "Jon Snow")
if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Status())
- assert.Equal(t, "Hello, Joe!", rec.Body.String())
+ assert.Equal(t, "Hello, Jon Snow!", rec.Body.String())
}
c.echo.renderer = nil
- err = c.Render(http.StatusOK, "hello", "Joe")
+ err = c.Render(http.StatusOK, "hello", "Jon Snow")
assert.Error(t, err)
// JSON
rec = test.NewResponseRecorder()
c = e.NewContext(req, rec).(*context)
- err = c.JSON(http.StatusOK, user{"1", "Joe"})
+ err = c.JSON(http.StatusOK, user{1, "Jon Snow"})
if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Status())
assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType))
@@ -117,7 +95,7 @@ func TestContext(t *testing.T) {
rec = test.NewResponseRecorder()
c = e.NewContext(req, rec).(*context)
callback := "callback"
- err = c.JSONP(http.StatusOK, callback, user{"1", "Joe"})
+ err = c.JSONP(http.StatusOK, callback, user{1, "Jon Snow"})
if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Status())
assert.Equal(t, MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType))
@@ -127,7 +105,7 @@ func TestContext(t *testing.T) {
// XML
rec = test.NewResponseRecorder()
c = e.NewContext(req, rec).(*context)
- err = c.XML(http.StatusOK, user{"1", "Joe"})
+ err = c.XML(http.StatusOK, user{1, "Jon Snow"})
if assert.NoError(t, err) {
assert.Equal(t, http.StatusOK, rec.Status())
assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType))
@@ -283,31 +261,3 @@ func TestContextHandler(t *testing.T) {
c.Handler()(c)
assert.Equal(t, "handler", b.String())
}
-
-func testBindOk(t *testing.T, c Context, ct string) {
- c.Request().Header().Set(HeaderContentType, ct)
- u := new(user)
- err := c.Bind(u)
- if assert.NoError(t, err) {
- assert.Equal(t, "1", u.ID)
- assert.Equal(t, "Joe", u.Name)
- }
-}
-
-func testBindError(t *testing.T, c Context, ct string) {
- c.Request().Header().Set(HeaderContentType, ct)
- u := new(user)
- err := c.Bind(u)
-
- switch ct {
- case MIMEApplicationJSON, MIMEApplicationXML:
- if assert.IsType(t, new(HTTPError), err) {
- assert.Equal(t, http.StatusBadRequest, err.(*HTTPError).Code)
- }
- default:
- if assert.IsType(t, new(HTTPError), err) {
- assert.Equal(t, ErrUnsupportedMediaType, err)
- }
-
- }
-}
diff --git a/echo.go b/echo.go
index f229a7ef..e150b429 100644
--- a/echo.go
+++ b/echo.go
@@ -39,8 +39,6 @@ package echo
import (
"bytes"
- "encoding/json"
- "encoding/xml"
"errors"
"fmt"
"io"
@@ -48,7 +46,6 @@ import (
"path"
"reflect"
"runtime"
- "strings"
"sync"
"github.com/labstack/echo/engine"
@@ -93,14 +90,6 @@ type (
// HTTPErrorHandler is a centralized HTTP error handler.
HTTPErrorHandler func(error, Context)
- // Binder is the interface that wraps the Bind function.
- Binder interface {
- Bind(interface{}, Context) error
- }
-
- binder struct {
- }
-
// Validator is the interface that wraps the Validate function.
Validator interface {
Validate() error
@@ -581,22 +570,6 @@ func (e *HTTPError) Error() string {
return e.Message
}
-func (b *binder) Bind(i interface{}, c Context) (err error) {
- req := c.Request()
- ct := req.Header().Get(HeaderContentType)
- err = ErrUnsupportedMediaType
- if strings.HasPrefix(ct, MIMEApplicationJSON) {
- if err = json.NewDecoder(req.Body()).Decode(i); err != nil {
- err = NewHTTPError(http.StatusBadRequest, err.Error())
- }
- } else if strings.HasPrefix(ct, MIMEApplicationXML) {
- if err = xml.NewDecoder(req.Body()).Decode(i); err != nil {
- err = NewHTTPError(http.StatusBadRequest, err.Error())
- }
- }
- return
-}
-
// WrapMiddleware wrap `echo.HandlerFunc` into `echo.MiddlewareFunc`.
func WrapMiddleware(h HandlerFunc) MiddlewareFunc {
return func(next HandlerFunc) HandlerFunc {
diff --git a/echo_test.go b/echo_test.go
index b174fdcd..c245679f 100644
--- a/echo_test.go
+++ b/echo_test.go
@@ -17,11 +17,18 @@ import (
type (
user struct {
- ID string `json:"id" xml:"id"`
- Name string `json:"name" xml:"name"`
+ ID int `json:"id" xml:"id" form:"id"`
+ Name string `json:"name" xml:"name" form:"name"`
}
)
+const (
+ userJSON = `{"id":1,"name":"Jon Snow"}`
+ userXML = `1Jon Snow`
+ userForm = `id=1&name=Jon Snow`
+ invalidContent = "invalid content"
+)
+
func TestEcho(t *testing.T) {
e := New()
req := test.NewRequest(GET, "/", nil)
diff --git a/engine/standard/request.go b/engine/standard/request.go
index 9d311a84..a955e873 100644
--- a/engine/standard/request.go
+++ b/engine/standard/request.go
@@ -5,7 +5,9 @@ import (
"io/ioutil"
"mime/multipart"
"net/http"
+ "strings"
+ "github.com/labstack/echo"
"github.com/labstack/echo/engine"
"github.com/labstack/gommon/log"
)
@@ -20,6 +22,10 @@ type (
}
)
+const (
+ defaultMemory = 32 << 20 // 32 MB
+)
+
// NewRequest returns `Request` instance.
func NewRequest(r *http.Request, l *log.Logger) *Request {
return &Request{
@@ -122,10 +128,16 @@ func (r *Request) FormValue(name string) string {
// FormParams implements `engine.Request#FormParams` function.
func (r *Request) FormParams() map[string][]string {
- if err := r.ParseForm(); err != nil {
- r.logger.Error(err)
+ if strings.HasPrefix(r.header.Get(echo.HeaderContentType), echo.MIMEMultipartForm) {
+ if err := r.ParseMultipartForm(defaultMemory); err != nil {
+ r.logger.Error(err)
+ }
+ } else {
+ if err := r.ParseForm(); err != nil {
+ r.logger.Error(err)
+ }
}
- return map[string][]string(r.Request.PostForm)
+ return map[string][]string(r.Request.Form)
}
// FormFile implements `engine.Request#FormFile` function.
@@ -136,7 +148,7 @@ func (r *Request) FormFile(name string) (*multipart.FileHeader, error) {
// MultipartForm implements `engine.Request#MultipartForm` function.
func (r *Request) MultipartForm() (*multipart.Form, error) {
- err := r.ParseMultipartForm(32 << 20) // 32 MB
+ err := r.ParseMultipartForm(defaultMemory)
return r.Request.MultipartForm, err
}
diff --git a/middleware/static_test.go b/middleware/static_test.go
new file mode 100644
index 00000000..c870d7c1
--- /dev/null
+++ b/middleware/static_test.go
@@ -0,0 +1 @@
+package middleware
diff --git a/test/request.go b/test/request.go
index 8240509c..93118089 100644
--- a/test/request.go
+++ b/test/request.go
@@ -5,6 +5,7 @@ import (
"io/ioutil"
"mime/multipart"
"net/http"
+ "strings"
"github.com/labstack/echo/engine"
)
@@ -17,6 +18,10 @@ type (
}
)
+const (
+ defaultMemory = 32 << 20 // 32 MB
+)
+
func NewRequest(method, url string, body io.Reader) engine.Request {
r, _ := http.NewRequest(method, url, body)
r.RequestURI = url
@@ -103,8 +108,16 @@ func (r *Request) FormValue(name string) string {
}
func (r *Request) FormParams() map[string][]string {
- r.request.ParseForm()
- return map[string][]string(r.request.PostForm)
+ if strings.HasPrefix(r.header.Get("Content-Type"), "multipart/form-data") {
+ if err := r.request.ParseMultipartForm(defaultMemory); err != nil {
+ panic(err)
+ }
+ } else {
+ if err := r.request.ParseForm(); err != nil {
+ panic(err)
+ }
+ }
+ return map[string][]string(r.request.Form)
}
func (r *Request) FormFile(name string) (*multipart.FileHeader, error) {
@@ -113,7 +126,7 @@ func (r *Request) FormFile(name string) (*multipart.FileHeader, error) {
}
func (r *Request) MultipartForm() (*multipart.Form, error) {
- err := r.request.ParseMultipartForm(32 << 20) // 32 MB
+ err := r.request.ParseMultipartForm(defaultMemory)
return r.request.MultipartForm, err
}