diff --git a/binder.go b/binder.go new file mode 100644 index 00000000..e91d1dea --- /dev/null +++ b/binder.go @@ -0,0 +1,207 @@ +package echo + +import ( + "encoding/json" + "encoding/xml" + "errors" + "net/http" + "reflect" + "strconv" + "strings" +) + +type ( + // Binder is the interface that wraps the Bind method. + Binder interface { + Bind(*http.Request, interface{}) error + } + + binder struct { + maxMemory int64 + } +) + +const ( + defaultMaxMemory = 32 << 20 // 32 MB +) + +// SetMaxBodySize sets multipart forms max body size +func (b *binder) SetMaxMemory(size int64) { + b.maxMemory = size +} + +// MaxBodySize return multipart forms max body size +func (b *binder) MaxMemory() int64 { + return b.maxMemory +} + +func (b *binder) Bind(r *http.Request, i interface{}) (err error) { + ct := r.Header.Get(ContentType) + err = ErrUnsupportedMediaType + switch { + case strings.HasPrefix(ct, ApplicationJSON): + if err = json.NewDecoder(r.Body).Decode(i); err != nil { + err = NewHTTPError(http.StatusBadRequest, err.Error()) + } + case strings.HasPrefix(ct, ApplicationXML): + if err = xml.NewDecoder(r.Body).Decode(i); err != nil { + err = NewHTTPError(http.StatusBadRequest, err.Error()) + } + case strings.HasPrefix(ct, ApplicationForm): + if err = b.bindForm(r, i); err != nil { + err = NewHTTPError(http.StatusBadRequest, err.Error()) + } + case strings.HasPrefix(ct, MultipartForm): + if err = b.bindMultiPartForm(r, i); err != nil { + err = NewHTTPError(http.StatusBadRequest, err.Error()) + } + } + return +} + +func (binder) bindForm(r *http.Request, i interface{}) error { + if err := r.ParseForm(); err != nil { + return err + } + return mapForm(i, r.Form) +} + +func (b binder) bindMultiPartForm(r *http.Request, i interface{}) error { + if b.maxMemory == 0 { + b.maxMemory = defaultMaxMemory + } + if err := r.ParseMultipartForm(b.maxMemory); err != nil { + return err + } + return mapForm(i, r.Form) +} + +func mapForm(ptr interface{}, form map[string][]string) error { + typ := reflect.TypeOf(ptr).Elem() + val := reflect.ValueOf(ptr).Elem() + for i := 0; i < typ.NumField(); i++ { + typeField := typ.Field(i) + structField := val.Field(i) + if !structField.CanSet() { + continue + } + + structFieldKind := structField.Kind() + inputFieldName := typeField.Tag.Get("form") + if inputFieldName == "" { + inputFieldName = typeField.Name + + // if "form" tag is nil, we inspect if the field is a struct. + // this would not make sense for JSON parsing but it does for a form + // since data is flatten + if structFieldKind == reflect.Struct { + err := mapForm(structField.Addr().Interface(), form) + if err != nil { + return err + } + continue + } + } + inputValue, exists := form[inputFieldName] + if !exists { + continue + } + + numElems := len(inputValue) + if structFieldKind == reflect.Slice && numElems > 0 { + sliceOf := structField.Type().Elem().Kind() + slice := reflect.MakeSlice(structField.Type(), numElems, numElems) + for i := 0; i < numElems; i++ { + if err := setWithProperType(sliceOf, inputValue[i], slice.Index(i)); err != nil { + return err + } + } + val.Field(i).Set(slice) + } else { + if err := setWithProperType(typeField.Type.Kind(), inputValue[0], structField); err != nil { + return err + } + } + } + return nil +} + +func setWithProperType(valueKind reflect.Kind, val string, structField reflect.Value) error { + switch valueKind { + case reflect.Int: + return setIntField(val, 0, structField) + case reflect.Int8: + return setIntField(val, 8, structField) + case reflect.Int16: + return setIntField(val, 16, structField) + case reflect.Int32: + return setIntField(val, 32, structField) + case reflect.Int64: + return setIntField(val, 64, structField) + case reflect.Uint: + return setUintField(val, 0, structField) + case reflect.Uint8: + return setUintField(val, 8, structField) + case reflect.Uint16: + return setUintField(val, 16, structField) + case reflect.Uint32: + return setUintField(val, 32, structField) + case reflect.Uint64: + return setUintField(val, 64, structField) + case reflect.Bool: + return setBoolField(val, structField) + case reflect.Float32: + return setFloatField(val, 32, structField) + case reflect.Float64: + return setFloatField(val, 64, structField) + case reflect.String: + structField.SetString(val) + default: + return errors.New("Unknown type") + } + return nil +} + +func setIntField(val string, bitSize int, field reflect.Value) error { + if val == "" { + val = "0" + } + intVal, err := strconv.ParseInt(val, 10, bitSize) + if err == nil { + field.SetInt(intVal) + } + return err +} + +func setUintField(val string, bitSize int, field reflect.Value) error { + if val == "" { + val = "0" + } + uintVal, err := strconv.ParseUint(val, 10, bitSize) + if err == nil { + field.SetUint(uintVal) + } + return err +} + +func setBoolField(val string, field reflect.Value) error { + if val == "" { + val = "false" + } + boolVal, err := strconv.ParseBool(val) + if err == nil { + field.SetBool(boolVal) + } + return err +} + +func setFloatField(val string, bitSize int, field reflect.Value) error { + if val == "" { + val = "0.0" + } + floatVal, err := strconv.ParseFloat(val, bitSize) + if err == nil { + field.SetFloat(floatVal) + } + return err +} diff --git a/binder_test.go b/binder_test.go new file mode 100644 index 00000000..7ee2ca78 --- /dev/null +++ b/binder_test.go @@ -0,0 +1,96 @@ +package echo + +import ( + "bytes" + "github.com/stretchr/testify/assert" + "mime/multipart" + "net/http" + "strings" + "testing" +) + +type ( + customer struct { + ID string `json:"id" xml:"id" form:"id"` + Name string `json:"name" xml:"name" form:"name"` + } +) + +const ( + customerJSON = `{"id":"1","name":"Joe"}` + customerXML = `1Joe` + customerForm = `id=1&name=Joe` + incorrectContent = "this is incorrect content" +) + +func TestMaxMemory(t *testing.T) { + b := new(binder) + b.SetMaxMemory(20) + assert.Equal(t, int64(20), b.MaxMemory()) +} + +func TestJSONBinding(t *testing.T) { + r, _ := http.NewRequest(POST, "/", strings.NewReader(customerJSON)) + testBindOk(t, r, ApplicationJSON) + r, _ = http.NewRequest(POST, "/", strings.NewReader(incorrectContent)) + testBindError(t, r, ApplicationJSON) +} + +func TestXMLBinding(t *testing.T) { + r, _ := http.NewRequest(POST, "/", strings.NewReader(customerXML)) + testBindOk(t, r, ApplicationXML) + r, _ = http.NewRequest(POST, "/", strings.NewReader(incorrectContent)) + testBindError(t, r, ApplicationXML) +} + +func TestFormBinding(t *testing.T) { + r, _ := http.NewRequest(POST, "/", strings.NewReader(customerForm)) + testBindOk(t, r, ApplicationForm) +} + +func TestMultipartFormBinding(t *testing.T) { + body := new(bytes.Buffer) + mw := multipart.NewWriter(body) + mw.WriteField("id", "1") + mw.WriteField("name", "Joe") + mw.Close() + r, _ := http.NewRequest(POST, "/", body) + testBindOk(t, r, mw.FormDataContentType()) + r, _ = http.NewRequest(POST, "/", strings.NewReader(incorrectContent)) + testBindError(t, r, mw.FormDataContentType()) +} + +func TestUnsupportedMediaTypeBinding(t *testing.T) { + r, _ := http.NewRequest(POST, "/", strings.NewReader(customerJSON)) + + // Unsupported + testBindError(t, r, "") +} + +func testBindOk(t *testing.T, r *http.Request, ct string) { + r.Header.Set(ContentType, ct) + u := new(customer) + err := new(binder).Bind(r, u) + if assert.NoError(t, err) { + assert.Equal(t, "1", u.ID) + assert.Equal(t, "Joe", u.Name) + } +} + +func testBindError(t *testing.T, r *http.Request, ct string) { + r.Header.Set(ContentType, ct) + u := new(customer) + err := new(binder).Bind(r, u) + + switch { + case strings.HasPrefix(ct, ApplicationJSON), strings.HasPrefix(ct, ApplicationXML), strings.HasPrefix(ct, ApplicationForm), strings.HasPrefix(ct, MultipartForm): + if assert.IsType(t, new(HTTPError), err) { + assert.Equal(t, http.StatusBadRequest, err.(*HTTPError).code) + } + default: + if assert.IsType(t, new(HTTPError), err) { + assert.Equal(t, ErrUnsupportedMediaType, err) + } + + } +} diff --git a/context_test.go b/context_test.go index c65738c2..fb57cc62 100644 --- a/context_test.go +++ b/context_test.go @@ -34,7 +34,6 @@ func TestContext(t *testing.T) { userJSONIndent := "{\n_?\"id\": \"1\",\n_?\"name\": \"Joe\"\n_}" userXML := `1Joe` userXMLIndent := "_\n_?1\n_?Joe\n_" - incorrectContent := "this is incorrect content" var nonMarshallableChannel chan bool @@ -64,24 +63,6 @@ func TestContext(t *testing.T) { c.Set("user", "Joe") assert.Equal(t, "Joe", c.Get("user")) - //------ - // Bind - //------ - - // JSON - testBindOk(t, c, ApplicationJSON) - c.request, _ = http.NewRequest(POST, "/", strings.NewReader(incorrectContent)) - testBindError(t, c, ApplicationJSON) - - // XML - c.request, _ = http.NewRequest(POST, "/", strings.NewReader(userXML)) - testBindOk(t, c, ApplicationXML) - c.request, _ = http.NewRequest(POST, "/", strings.NewReader(incorrectContent)) - testBindError(t, c, ApplicationXML) - - // Unsupported - testBindError(t, c, "") - //-------- // Render //-------- @@ -295,31 +276,3 @@ func TestContextEcho(t *testing.T) { // Should be null when initialized without one assert.Nil(t, c.Echo()) } - -func testBindOk(t *testing.T, c *Context, ct string) { - c.request.Header.Set(ContentType, ct) - u := new(user) - err := c.Bind(u) - if assert.NoError(t, err) { - assert.Equal(t, "1", u.ID) - assert.Equal(t, "Joe", u.Name) - } -} - -func testBindError(t *testing.T, c *Context, ct string) { - c.request.Header.Set(ContentType, ct) - u := new(user) - err := c.Bind(u) - - switch ct { - case ApplicationJSON, ApplicationXML: - if assert.IsType(t, new(HTTPError), err) { - assert.Equal(t, http.StatusBadRequest, err.(*HTTPError).code) - } - default: - if assert.IsType(t, new(HTTPError), err) { - assert.Equal(t, ErrUnsupportedMediaType, err) - } - - } -} diff --git a/echo.go b/echo.go index 9654a6ff..7819c162 100644 --- a/echo.go +++ b/echo.go @@ -33,8 +33,6 @@ package echo import ( "bytes" - "encoding/json" - "encoding/xml" "errors" "fmt" "io" @@ -43,7 +41,6 @@ import ( "path/filepath" "reflect" "runtime" - "strings" "sync" "github.com/labstack/gommon/log" @@ -116,14 +113,6 @@ type ( // HTTPErrorHandler is a centralized HTTP error handler. HTTPErrorHandler func(error, *Context) - // Binder is the interface that wraps the Bind method. - Binder interface { - Bind(*http.Request, interface{}) error - } - - binder struct { - } - // Validator is the interface that wraps the Validate method. Validator interface { Validate() error @@ -736,19 +725,3 @@ func wrapHandler(h Handler) HandlerFunc { panic("unknown handler") } } - -func (binder) Bind(r *http.Request, i interface{}) (err error) { - ct := r.Header.Get(ContentType) - err = ErrUnsupportedMediaType - if strings.HasPrefix(ct, ApplicationJSON) { - if err = json.NewDecoder(r.Body).Decode(i); err != nil { - err = NewHTTPError(http.StatusBadRequest, err.Error()) - } - } else if strings.HasPrefix(ct, ApplicationXML) { - if err = xml.NewDecoder(r.Body).Decode(i); err != nil { - err = NewHTTPError(http.StatusBadRequest, err.Error()) - } - - } - return -}