mirror of
https://github.com/labstack/echo.git
synced 2025-07-05 00:58:47 +02:00
Support BindUnmarshaler for basic types (#786)
* Add a failing test for #784 * Change ordering of unmarshaler, to handle BindUnmarshalers first * Add test for arrays of BindUnmarshalers
This commit is contained in:
committed by
Vishal Rana
parent
869cdcd19a
commit
9cdc439f34
36
bind.go
36
bind.go
@ -108,6 +108,14 @@ func (b *DefaultBinder) bindData(ptr interface{}, data map[string][]string, tag
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Call this first, in case we're dealing with an alias to an array type
|
||||||
|
if ok, err := unmarshalField(typeField.Type.Kind(), inputValue[0], structField); ok {
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
numElems := len(inputValue)
|
numElems := len(inputValue)
|
||||||
if structFieldKind == reflect.Slice && numElems > 0 {
|
if structFieldKind == reflect.Slice && numElems > 0 {
|
||||||
sliceOf := structField.Type().Elem().Kind()
|
sliceOf := structField.Type().Elem().Kind()
|
||||||
@ -128,6 +136,11 @@ func (b *DefaultBinder) bindData(ptr interface{}, data map[string][]string, tag
|
|||||||
}
|
}
|
||||||
|
|
||||||
func setWithProperType(valueKind reflect.Kind, val string, structField reflect.Value) error {
|
func setWithProperType(valueKind reflect.Kind, val string, structField reflect.Value) error {
|
||||||
|
// But also call it here, in case we're dealing with an array of BindUnmarshalers
|
||||||
|
if ok, err := unmarshalField(valueKind, val, structField); ok {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
switch valueKind {
|
switch valueKind {
|
||||||
case reflect.Int:
|
case reflect.Int:
|
||||||
return setIntField(val, 0, structField)
|
return setIntField(val, 0, structField)
|
||||||
@ -157,33 +170,40 @@ func setWithProperType(valueKind reflect.Kind, val string, structField reflect.V
|
|||||||
return setFloatField(val, 64, structField)
|
return setFloatField(val, 64, structField)
|
||||||
case reflect.String:
|
case reflect.String:
|
||||||
structField.SetString(val)
|
structField.SetString(val)
|
||||||
case reflect.Ptr:
|
|
||||||
return unmarshalFieldPtr(val, structField)
|
|
||||||
default:
|
default:
|
||||||
return unmarshalField(val, structField)
|
return errors.New("unknown type")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func unmarshalField(value string, field reflect.Value) error {
|
func unmarshalField(valueKind reflect.Kind, val string, field reflect.Value) (bool, error) {
|
||||||
|
switch valueKind {
|
||||||
|
case reflect.Ptr:
|
||||||
|
return unmarshalFieldPtr(val, field)
|
||||||
|
default:
|
||||||
|
return unmarshalFieldNonPtr(val, field)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func unmarshalFieldNonPtr(value string, field reflect.Value) (bool, error) {
|
||||||
ptr := reflect.New(field.Type())
|
ptr := reflect.New(field.Type())
|
||||||
if ptr.CanInterface() {
|
if ptr.CanInterface() {
|
||||||
iface := ptr.Interface()
|
iface := ptr.Interface()
|
||||||
if unmarshaler, ok := iface.(BindUnmarshaler); ok {
|
if unmarshaler, ok := iface.(BindUnmarshaler); ok {
|
||||||
err := unmarshaler.UnmarshalParam(value)
|
err := unmarshaler.UnmarshalParam(value)
|
||||||
field.Set(ptr.Elem())
|
field.Set(ptr.Elem())
|
||||||
return err
|
return true, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return errors.New("unknown type")
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func unmarshalFieldPtr(value string, field reflect.Value) error {
|
func unmarshalFieldPtr(value string, field reflect.Value) (bool, error) {
|
||||||
if field.IsNil() {
|
if field.IsNil() {
|
||||||
// Initialize the pointer to a nil value
|
// Initialize the pointer to a nil value
|
||||||
field.Set(reflect.New(field.Type().Elem()))
|
field.Set(reflect.New(field.Type().Elem()))
|
||||||
}
|
}
|
||||||
return unmarshalField(value, field.Elem())
|
return unmarshalFieldNonPtr(value, field.Elem())
|
||||||
}
|
}
|
||||||
|
|
||||||
func setIntField(value string, bitSize int, field reflect.Value) error {
|
func setIntField(value string, bitSize int, field reflect.Value) error {
|
||||||
|
21
bind_test.go
21
bind_test.go
@ -34,9 +34,12 @@ type (
|
|||||||
DoesntExist string
|
DoesntExist string
|
||||||
T Timestamp
|
T Timestamp
|
||||||
Tptr *Timestamp
|
Tptr *Timestamp
|
||||||
|
SA StringArray
|
||||||
}
|
}
|
||||||
|
|
||||||
Timestamp time.Time
|
Timestamp time.Time
|
||||||
|
TA []Timestamp
|
||||||
|
StringArray []string
|
||||||
)
|
)
|
||||||
|
|
||||||
func (t *Timestamp) UnmarshalParam(src string) error {
|
func (t *Timestamp) UnmarshalParam(src string) error {
|
||||||
@ -45,6 +48,11 @@ func (t *Timestamp) UnmarshalParam(src string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *StringArray) UnmarshalParam(src string) error {
|
||||||
|
*a = StringArray(strings.Split(src, ","))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (t bindTestStruct) GetCantSet() string {
|
func (t bindTestStruct) GetCantSet() string {
|
||||||
return t.cantSet
|
return t.cantSet
|
||||||
}
|
}
|
||||||
@ -107,16 +115,21 @@ func TestBindQueryParams(t *testing.T) {
|
|||||||
|
|
||||||
func TestBindUnmarshalParam(t *testing.T) {
|
func TestBindUnmarshalParam(t *testing.T) {
|
||||||
e := New()
|
e := New()
|
||||||
req, _ := http.NewRequest(GET, "/?ts=2016-12-06T19:09:05Z", nil)
|
req, _ := http.NewRequest(GET, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z", nil)
|
||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
c := e.NewContext(req, rec)
|
c := e.NewContext(req, rec)
|
||||||
result := struct {
|
result := struct {
|
||||||
T Timestamp `query:"ts"`
|
T Timestamp `query:"ts"`
|
||||||
|
TA []Timestamp `query:"ta"`
|
||||||
|
SA StringArray `query:"sa"`
|
||||||
}{}
|
}{}
|
||||||
err := c.Bind(&result)
|
err := c.Bind(&result)
|
||||||
|
ts := Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC))
|
||||||
if assert.NoError(t, err) {
|
if assert.NoError(t, err) {
|
||||||
// assert.Equal(t, Timestamp(reflect.TypeOf(&Timestamp{}), time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), result.T)
|
// assert.Equal(t, Timestamp(reflect.TypeOf(&Timestamp{}), time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), result.T)
|
||||||
assert.Equal(t, Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), result.T)
|
assert.Equal(t, ts, result.T)
|
||||||
|
assert.Equal(t, StringArray([]string{"one", "two", "three"}), result.SA)
|
||||||
|
assert.Equal(t, []Timestamp{ts, ts}, result.TA)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -217,7 +230,9 @@ func TestBindSetFields(t *testing.T) {
|
|||||||
assert.Equal(t, false, ts.B)
|
assert.Equal(t, false, ts.B)
|
||||||
}
|
}
|
||||||
|
|
||||||
if assert.NoError(t, unmarshalField("2016-12-06T19:09:05Z", val.FieldByName("T"))) {
|
ok, err := unmarshalFieldNonPtr("2016-12-06T19:09:05Z", val.FieldByName("T"))
|
||||||
|
if assert.NoError(t, err) {
|
||||||
|
assert.Equal(t, ok, true)
|
||||||
assert.Equal(t, Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), ts.T)
|
assert.Equal(t, Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), ts.T)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user