diff --git a/binder.go b/binder.go new file mode 100644 index 00000000..c17a6507 --- /dev/null +++ b/binder.go @@ -0,0 +1,173 @@ +package echo + +import ( + "encoding/json" + "encoding/xml" + "errors" + "net/http" + "reflect" + "strconv" + "strings" +) + +type ( + // Binder is the interface that wraps the Bind method. + Binder interface { + Bind(interface{}, Context) error + } + + binder struct{} +) + +func (b *binder) Bind(i interface{}, c Context) (err error) { + req := c.Request() + ctype := req.Header().Get(HeaderContentType) + if req.Body() == nil { + err = NewHTTPError(http.StatusBadRequest, "request body can't be empty") + return + } + err = ErrUnsupportedMediaType + switch { + case strings.HasPrefix(ctype, MIMEApplicationJSON): + if err = json.NewDecoder(req.Body()).Decode(i); err != nil { + err = NewHTTPError(http.StatusBadRequest, err.Error()) + } + case strings.HasPrefix(ctype, MIMEApplicationXML): + if err = xml.NewDecoder(req.Body()).Decode(i); err != nil { + err = NewHTTPError(http.StatusBadRequest, err.Error()) + } + case strings.HasPrefix(ctype, MIMEApplicationForm), strings.HasPrefix(ctype, MIMEMultipartForm): + if err = b.bindForm(i, req.FormParams()); err != nil { + err = NewHTTPError(http.StatusBadRequest, err.Error()) + } + } + return +} + +func (b *binder) bindForm(ptr interface{}, form map[string][]string) error { + typ := reflect.TypeOf(ptr).Elem() + val := reflect.ValueOf(ptr).Elem() + + for i := 0; i < typ.NumField(); i++ { + typeField := typ.Field(i) + structField := val.Field(i) + if !structField.CanSet() { + continue + } + structFieldKind := structField.Kind() + inputFieldName := typeField.Tag.Get("form") + + if inputFieldName == "" { + inputFieldName = typeField.Name + // If "form" tag is nil, we inspect if the field is a struct. + if structFieldKind == reflect.Struct { + err := b.bindForm(structField.Addr().Interface(), form) + if err != nil { + return err + } + continue + } + } + inputValue, exists := form[inputFieldName] + if !exists { + continue + } + + numElems := len(inputValue) + if structFieldKind == reflect.Slice && numElems > 0 { + sliceOf := structField.Type().Elem().Kind() + slice := reflect.MakeSlice(structField.Type(), numElems, numElems) + for i := 0; i < numElems; i++ { + if err := setWithProperType(sliceOf, inputValue[i], slice.Index(i)); err != nil { + return err + } + } + val.Field(i).Set(slice) + } else { + if err := setWithProperType(typeField.Type.Kind(), inputValue[0], structField); err != nil { + return err + } + } + } + return nil +} + +func setWithProperType(valueKind reflect.Kind, val string, structField reflect.Value) error { + switch valueKind { + case reflect.Int: + return setIntField(val, 0, structField) + case reflect.Int8: + return setIntField(val, 8, structField) + case reflect.Int16: + return setIntField(val, 16, structField) + case reflect.Int32: + return setIntField(val, 32, structField) + case reflect.Int64: + return setIntField(val, 64, structField) + case reflect.Uint: + return setUintField(val, 0, structField) + case reflect.Uint8: + return setUintField(val, 8, structField) + case reflect.Uint16: + return setUintField(val, 16, structField) + case reflect.Uint32: + return setUintField(val, 32, structField) + case reflect.Uint64: + return setUintField(val, 64, structField) + case reflect.Bool: + return setBoolField(val, structField) + case reflect.Float32: + return setFloatField(val, 32, structField) + case reflect.Float64: + return setFloatField(val, 64, structField) + case reflect.String: + structField.SetString(val) + default: + return errors.New("unknown type") + } + return nil +} + +func setIntField(value string, bitSize int, field reflect.Value) error { + if value == "" { + value = "0" + } + intVal, err := strconv.ParseInt(value, 10, bitSize) + if err == nil { + field.SetInt(intVal) + } + return err +} + +func setUintField(value string, bitSize int, field reflect.Value) error { + if value == "" { + value = "0" + } + uintVal, err := strconv.ParseUint(value, 10, bitSize) + if err == nil { + field.SetUint(uintVal) + } + return err +} + +func setBoolField(value string, field reflect.Value) error { + if value == "" { + value = "false" + } + boolVal, err := strconv.ParseBool(value) + if err == nil { + field.SetBool(boolVal) + } + return err +} + +func setFloatField(value string, bitSize int, field reflect.Value) error { + if value == "" { + value = "0.0" + } + floatVal, err := strconv.ParseFloat(value, bitSize) + if err == nil { + field.SetFloat(floatVal) + } + return err +} diff --git a/binder_test.go b/binder_test.go new file mode 100644 index 00000000..a1508b80 --- /dev/null +++ b/binder_test.go @@ -0,0 +1,214 @@ +package echo + +import ( + "bytes" + "io" + "mime/multipart" + "net/http" + "reflect" + "strings" + "testing" + + "github.com/labstack/echo/test" + "github.com/stretchr/testify/assert" +) + +type ( + binderTestStruct struct { + I int + I8 int8 + I16 int16 + I32 int32 + I64 int64 + UI uint + UI8 uint8 + UI16 uint16 + UI32 uint32 + UI64 uint64 + B bool + F32 float32 + F64 float64 + S string + cantSet string + DoesntExist string + } +) + +func (t binderTestStruct) GetCantSet() string { + return t.cantSet +} + +var values = map[string][]string{ + "I": {"0"}, + "I8": {"8"}, + "I16": {"16"}, + "I32": {"32"}, + "I64": {"64"}, + "UI": {"0"}, + "UI8": {"8"}, + "UI16": {"16"}, + "UI32": {"32"}, + "UI64": {"64"}, + "B": {"true"}, + "F32": {"32.5"}, + "F64": {"64.5"}, + "S": {"test"}, + "cantSet": {"test"}, +} + +func TestBinderJSON(t *testing.T) { + testBinderOkay(t, strings.NewReader(userJSON), MIMEApplicationJSON) + testBinderError(t, strings.NewReader(invalidContent), MIMEApplicationJSON) +} + +func TestBinderXML(t *testing.T) { + testBinderOkay(t, strings.NewReader(userXML), MIMEApplicationXML) + testBinderError(t, strings.NewReader(invalidContent), MIMEApplicationXML) +} + +func TestBinderForm(t *testing.T) { + testBinderOkay(t, strings.NewReader(userForm), MIMEApplicationForm) +} + +func TestBinderMultipartForm(t *testing.T) { + body := new(bytes.Buffer) + mw := multipart.NewWriter(body) + mw.WriteField("id", "1") + mw.WriteField("name", "Jon Snow") + mw.Close() + testBinderOkay(t, body, mw.FormDataContentType()) +} + +func TestBinderUnsupportedMediaType(t *testing.T) { + testBinderError(t, strings.NewReader(invalidContent), MIMEApplicationJSON) +} + +// func assertCustomer(t *testing.T, c *user) { +// assert.Equal(t, 1, c.ID) +// assert.Equal(t, "Joe", c.Name) +// } + +func TestBinderbindForm(t *testing.T) { + ts := new(binderTestStruct) + b := new(binder) + b.bindForm(ts, values) + assertBinderTestStruct(t, ts) +} + +func TestBinderSetWithProperType(t *testing.T) { + ts := new(binderTestStruct) + typ := reflect.TypeOf(ts).Elem() + val := reflect.ValueOf(ts).Elem() + for i := 0; i < typ.NumField(); i++ { + typeField := typ.Field(i) + structField := val.Field(i) + if !structField.CanSet() { + continue + } + if len(values[typeField.Name]) == 0 { + continue + } + val := values[typeField.Name][0] + err := setWithProperType(typeField.Type.Kind(), val, structField) + assert.NoError(t, err) + } + assertBinderTestStruct(t, ts) + + type foo struct { + Bar bytes.Buffer + } + v := &foo{} + typ = reflect.TypeOf(v).Elem() + val = reflect.ValueOf(v).Elem() + assert.Error(t, setWithProperType(typ.Field(0).Type.Kind(), "5", val.Field(0))) +} + +func TestBinderSetFields(t *testing.T) { + ts := new(binderTestStruct) + val := reflect.ValueOf(ts).Elem() + // Int + if assert.NoError(t, setIntField("5", 0, val.FieldByName("I"))) { + assert.Equal(t, 5, ts.I) + } + if assert.NoError(t, setIntField("", 0, val.FieldByName("I"))) { + assert.Equal(t, 0, ts.I) + } + + // Uint + if assert.NoError(t, setUintField("10", 0, val.FieldByName("UI"))) { + assert.Equal(t, uint(10), ts.UI) + } + if assert.NoError(t, setUintField("", 0, val.FieldByName("UI"))) { + assert.Equal(t, uint(0), ts.UI) + } + + // Float + if assert.NoError(t, setFloatField("15.5", 0, val.FieldByName("F32"))) { + assert.Equal(t, float32(15.5), ts.F32) + } + if assert.NoError(t, setFloatField("", 0, val.FieldByName("F32"))) { + assert.Equal(t, float32(0.0), ts.F32) + } + + // Bool + if assert.NoError(t, setBoolField("true", val.FieldByName("B"))) { + assert.Equal(t, true, ts.B) + } + if assert.NoError(t, setBoolField("", val.FieldByName("B"))) { + assert.Equal(t, false, ts.B) + } +} + +func assertBinderTestStruct(t *testing.T, ts *binderTestStruct) { + assert.Equal(t, 0, ts.I) + assert.Equal(t, int8(8), ts.I8) + assert.Equal(t, int16(16), ts.I16) + assert.Equal(t, int32(32), ts.I32) + assert.Equal(t, int64(64), ts.I64) + assert.Equal(t, uint(0), ts.UI) + assert.Equal(t, uint8(8), ts.UI8) + assert.Equal(t, uint16(16), ts.UI16) + assert.Equal(t, uint32(32), ts.UI32) + assert.Equal(t, uint64(64), ts.UI64) + assert.Equal(t, true, ts.B) + assert.Equal(t, float32(32.5), ts.F32) + assert.Equal(t, float64(64.5), ts.F64) + assert.Equal(t, "test", ts.S) + assert.Equal(t, "", ts.GetCantSet()) +} + +func testBinderOkay(t *testing.T, r io.Reader, ctype string) { + e := New() + req := test.NewRequest(POST, "/", r) + rec := test.NewResponseRecorder() + c := e.NewContext(req, rec) + req.Header().Set(HeaderContentType, ctype) + u := new(user) + err := c.Bind(u) + if assert.NoError(t, err) { + assert.Equal(t, 1, u.ID) + assert.Equal(t, "Jon Snow", u.Name) + } +} + +func testBinderError(t *testing.T, r io.Reader, ctype string) { + e := New() + req := test.NewRequest(POST, "/", r) + rec := test.NewResponseRecorder() + c := e.NewContext(req, rec) + req.Header().Set(HeaderContentType, ctype) + u := new(user) + err := c.Bind(u) + + switch { + case strings.HasPrefix(ctype, MIMEApplicationJSON), strings.HasPrefix(ctype, MIMEApplicationXML), + strings.HasPrefix(ctype, MIMEApplicationForm), strings.HasPrefix(ctype, MIMEMultipartForm): + if assert.IsType(t, new(HTTPError), err) { + assert.Equal(t, http.StatusBadRequest, err.(*HTTPError).Code) + } + default: + if assert.IsType(t, new(HTTPError), err) { + assert.Equal(t, ErrUnsupportedMediaType, err) + } + } +} diff --git a/context_test.go b/context_test.go index 221c32c7..8059ddec 100644 --- a/context_test.go +++ b/context_test.go @@ -30,10 +30,6 @@ func (t *Template) Render(w io.Writer, name string, data interface{}, c Context) } func TestContext(t *testing.T) { - userJSON := `{"id":"1","name":"Joe"}` - userXML := `1Joe` - invalidContent := "invalid content" - e := New() req := test.NewRequest(POST, "/", strings.NewReader(userJSON)) rec := test.NewResponseRecorder() @@ -58,26 +54,8 @@ func TestContext(t *testing.T) { assert.Equal(t, "1", c.Param("id")) // Store - c.Set("user", "Joe") - assert.Equal(t, "Joe", c.Get("user")) - - //------ - // Bind - //------ - - // JSON - testBindOk(t, c, MIMEApplicationJSON) - c.request = test.NewRequest(POST, "/", strings.NewReader(invalidContent)) - testBindError(t, c, MIMEApplicationJSON) - - // XML - c.request = test.NewRequest(POST, "/", strings.NewReader(userXML)) - testBindOk(t, c, MIMEApplicationXML) - c.request = test.NewRequest(POST, "/", strings.NewReader(invalidContent)) - testBindError(t, c, MIMEApplicationXML) - - // Unsupported - testBindError(t, c, "") + c.Set("user", "Jon Snow") + assert.Equal(t, "Jon Snow", c.Get("user")) //-------- // Render @@ -87,20 +65,20 @@ func TestContext(t *testing.T) { templates: template.Must(template.New("hello").Parse("Hello, {{.}}!")), } c.echo.SetRenderer(tpl) - err := c.Render(http.StatusOK, "hello", "Joe") + err := c.Render(http.StatusOK, "hello", "Jon Snow") if assert.NoError(t, err) { assert.Equal(t, http.StatusOK, rec.Status()) - assert.Equal(t, "Hello, Joe!", rec.Body.String()) + assert.Equal(t, "Hello, Jon Snow!", rec.Body.String()) } c.echo.renderer = nil - err = c.Render(http.StatusOK, "hello", "Joe") + err = c.Render(http.StatusOK, "hello", "Jon Snow") assert.Error(t, err) // JSON rec = test.NewResponseRecorder() c = e.NewContext(req, rec).(*context) - err = c.JSON(http.StatusOK, user{"1", "Joe"}) + err = c.JSON(http.StatusOK, user{1, "Jon Snow"}) if assert.NoError(t, err) { assert.Equal(t, http.StatusOK, rec.Status()) assert.Equal(t, MIMEApplicationJSONCharsetUTF8, rec.Header().Get(HeaderContentType)) @@ -117,7 +95,7 @@ func TestContext(t *testing.T) { rec = test.NewResponseRecorder() c = e.NewContext(req, rec).(*context) callback := "callback" - err = c.JSONP(http.StatusOK, callback, user{"1", "Joe"}) + err = c.JSONP(http.StatusOK, callback, user{1, "Jon Snow"}) if assert.NoError(t, err) { assert.Equal(t, http.StatusOK, rec.Status()) assert.Equal(t, MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType)) @@ -127,7 +105,7 @@ func TestContext(t *testing.T) { // XML rec = test.NewResponseRecorder() c = e.NewContext(req, rec).(*context) - err = c.XML(http.StatusOK, user{"1", "Joe"}) + err = c.XML(http.StatusOK, user{1, "Jon Snow"}) if assert.NoError(t, err) { assert.Equal(t, http.StatusOK, rec.Status()) assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) @@ -283,31 +261,3 @@ func TestContextHandler(t *testing.T) { c.Handler()(c) assert.Equal(t, "handler", b.String()) } - -func testBindOk(t *testing.T, c Context, ct string) { - c.Request().Header().Set(HeaderContentType, ct) - u := new(user) - err := c.Bind(u) - if assert.NoError(t, err) { - assert.Equal(t, "1", u.ID) - assert.Equal(t, "Joe", u.Name) - } -} - -func testBindError(t *testing.T, c Context, ct string) { - c.Request().Header().Set(HeaderContentType, ct) - u := new(user) - err := c.Bind(u) - - switch ct { - case MIMEApplicationJSON, MIMEApplicationXML: - if assert.IsType(t, new(HTTPError), err) { - assert.Equal(t, http.StatusBadRequest, err.(*HTTPError).Code) - } - default: - if assert.IsType(t, new(HTTPError), err) { - assert.Equal(t, ErrUnsupportedMediaType, err) - } - - } -} diff --git a/echo.go b/echo.go index f229a7ef..e150b429 100644 --- a/echo.go +++ b/echo.go @@ -39,8 +39,6 @@ package echo import ( "bytes" - "encoding/json" - "encoding/xml" "errors" "fmt" "io" @@ -48,7 +46,6 @@ import ( "path" "reflect" "runtime" - "strings" "sync" "github.com/labstack/echo/engine" @@ -93,14 +90,6 @@ type ( // HTTPErrorHandler is a centralized HTTP error handler. HTTPErrorHandler func(error, Context) - // Binder is the interface that wraps the Bind function. - Binder interface { - Bind(interface{}, Context) error - } - - binder struct { - } - // Validator is the interface that wraps the Validate function. Validator interface { Validate() error @@ -581,22 +570,6 @@ func (e *HTTPError) Error() string { return e.Message } -func (b *binder) Bind(i interface{}, c Context) (err error) { - req := c.Request() - ct := req.Header().Get(HeaderContentType) - err = ErrUnsupportedMediaType - if strings.HasPrefix(ct, MIMEApplicationJSON) { - if err = json.NewDecoder(req.Body()).Decode(i); err != nil { - err = NewHTTPError(http.StatusBadRequest, err.Error()) - } - } else if strings.HasPrefix(ct, MIMEApplicationXML) { - if err = xml.NewDecoder(req.Body()).Decode(i); err != nil { - err = NewHTTPError(http.StatusBadRequest, err.Error()) - } - } - return -} - // WrapMiddleware wrap `echo.HandlerFunc` into `echo.MiddlewareFunc`. func WrapMiddleware(h HandlerFunc) MiddlewareFunc { return func(next HandlerFunc) HandlerFunc { diff --git a/echo_test.go b/echo_test.go index b174fdcd..c245679f 100644 --- a/echo_test.go +++ b/echo_test.go @@ -17,11 +17,18 @@ import ( type ( user struct { - ID string `json:"id" xml:"id"` - Name string `json:"name" xml:"name"` + ID int `json:"id" xml:"id" form:"id"` + Name string `json:"name" xml:"name" form:"name"` } ) +const ( + userJSON = `{"id":1,"name":"Jon Snow"}` + userXML = `1Jon Snow` + userForm = `id=1&name=Jon Snow` + invalidContent = "invalid content" +) + func TestEcho(t *testing.T) { e := New() req := test.NewRequest(GET, "/", nil) diff --git a/engine/standard/request.go b/engine/standard/request.go index 9d311a84..a955e873 100644 --- a/engine/standard/request.go +++ b/engine/standard/request.go @@ -5,7 +5,9 @@ import ( "io/ioutil" "mime/multipart" "net/http" + "strings" + "github.com/labstack/echo" "github.com/labstack/echo/engine" "github.com/labstack/gommon/log" ) @@ -20,6 +22,10 @@ type ( } ) +const ( + defaultMemory = 32 << 20 // 32 MB +) + // NewRequest returns `Request` instance. func NewRequest(r *http.Request, l *log.Logger) *Request { return &Request{ @@ -122,10 +128,16 @@ func (r *Request) FormValue(name string) string { // FormParams implements `engine.Request#FormParams` function. func (r *Request) FormParams() map[string][]string { - if err := r.ParseForm(); err != nil { - r.logger.Error(err) + if strings.HasPrefix(r.header.Get(echo.HeaderContentType), echo.MIMEMultipartForm) { + if err := r.ParseMultipartForm(defaultMemory); err != nil { + r.logger.Error(err) + } + } else { + if err := r.ParseForm(); err != nil { + r.logger.Error(err) + } } - return map[string][]string(r.Request.PostForm) + return map[string][]string(r.Request.Form) } // FormFile implements `engine.Request#FormFile` function. @@ -136,7 +148,7 @@ func (r *Request) FormFile(name string) (*multipart.FileHeader, error) { // MultipartForm implements `engine.Request#MultipartForm` function. func (r *Request) MultipartForm() (*multipart.Form, error) { - err := r.ParseMultipartForm(32 << 20) // 32 MB + err := r.ParseMultipartForm(defaultMemory) return r.Request.MultipartForm, err } diff --git a/middleware/static_test.go b/middleware/static_test.go new file mode 100644 index 00000000..c870d7c1 --- /dev/null +++ b/middleware/static_test.go @@ -0,0 +1 @@ +package middleware diff --git a/test/request.go b/test/request.go index 8240509c..93118089 100644 --- a/test/request.go +++ b/test/request.go @@ -5,6 +5,7 @@ import ( "io/ioutil" "mime/multipart" "net/http" + "strings" "github.com/labstack/echo/engine" ) @@ -17,6 +18,10 @@ type ( } ) +const ( + defaultMemory = 32 << 20 // 32 MB +) + func NewRequest(method, url string, body io.Reader) engine.Request { r, _ := http.NewRequest(method, url, body) r.RequestURI = url @@ -103,8 +108,16 @@ func (r *Request) FormValue(name string) string { } func (r *Request) FormParams() map[string][]string { - r.request.ParseForm() - return map[string][]string(r.request.PostForm) + if strings.HasPrefix(r.header.Get("Content-Type"), "multipart/form-data") { + if err := r.request.ParseMultipartForm(defaultMemory); err != nil { + panic(err) + } + } else { + if err := r.request.ParseForm(); err != nil { + panic(err) + } + } + return map[string][]string(r.request.Form) } func (r *Request) FormFile(name string) (*multipart.FileHeader, error) { @@ -113,7 +126,7 @@ func (r *Request) FormFile(name string) (*multipart.FileHeader, error) { } func (r *Request) MultipartForm() (*multipart.Form, error) { - err := r.request.ParseMultipartForm(32 << 20) // 32 MB + err := r.request.ParseMultipartForm(defaultMemory) return r.request.MultipartForm, err }