1
0
mirror of https://github.com/labstack/echo.git synced 2025-07-05 00:58:47 +02:00

Support BindUnmarshaler for basic types (#786)

* Add a failing test for #784

* Change ordering of unmarshaler, to handle BindUnmarshalers first

* Add test for arrays of BindUnmarshalers
This commit is contained in:
Jonathan Hall
2016-12-23 19:01:42 +01:00
committed by Vishal Rana
parent 869cdcd19a
commit 9cdc439f34
2 changed files with 48 additions and 13 deletions

36
bind.go
View File

@ -108,6 +108,14 @@ func (b *DefaultBinder) bindData(ptr interface{}, data map[string][]string, tag
continue continue
} }
// Call this first, in case we're dealing with an alias to an array type
if ok, err := unmarshalField(typeField.Type.Kind(), inputValue[0], structField); ok {
if err != nil {
return err
}
continue
}
numElems := len(inputValue) numElems := len(inputValue)
if structFieldKind == reflect.Slice && numElems > 0 { if structFieldKind == reflect.Slice && numElems > 0 {
sliceOf := structField.Type().Elem().Kind() sliceOf := structField.Type().Elem().Kind()
@ -128,6 +136,11 @@ func (b *DefaultBinder) bindData(ptr interface{}, data map[string][]string, tag
} }
func setWithProperType(valueKind reflect.Kind, val string, structField reflect.Value) error { func setWithProperType(valueKind reflect.Kind, val string, structField reflect.Value) error {
// But also call it here, in case we're dealing with an array of BindUnmarshalers
if ok, err := unmarshalField(valueKind, val, structField); ok {
return err
}
switch valueKind { switch valueKind {
case reflect.Int: case reflect.Int:
return setIntField(val, 0, structField) return setIntField(val, 0, structField)
@ -157,33 +170,40 @@ func setWithProperType(valueKind reflect.Kind, val string, structField reflect.V
return setFloatField(val, 64, structField) return setFloatField(val, 64, structField)
case reflect.String: case reflect.String:
structField.SetString(val) structField.SetString(val)
case reflect.Ptr:
return unmarshalFieldPtr(val, structField)
default: default:
return unmarshalField(val, structField) return errors.New("unknown type")
} }
return nil return nil
} }
func unmarshalField(value string, field reflect.Value) error { func unmarshalField(valueKind reflect.Kind, val string, field reflect.Value) (bool, error) {
switch valueKind {
case reflect.Ptr:
return unmarshalFieldPtr(val, field)
default:
return unmarshalFieldNonPtr(val, field)
}
}
func unmarshalFieldNonPtr(value string, field reflect.Value) (bool, error) {
ptr := reflect.New(field.Type()) ptr := reflect.New(field.Type())
if ptr.CanInterface() { if ptr.CanInterface() {
iface := ptr.Interface() iface := ptr.Interface()
if unmarshaler, ok := iface.(BindUnmarshaler); ok { if unmarshaler, ok := iface.(BindUnmarshaler); ok {
err := unmarshaler.UnmarshalParam(value) err := unmarshaler.UnmarshalParam(value)
field.Set(ptr.Elem()) field.Set(ptr.Elem())
return err return true, err
} }
} }
return errors.New("unknown type") return false, nil
} }
func unmarshalFieldPtr(value string, field reflect.Value) error { func unmarshalFieldPtr(value string, field reflect.Value) (bool, error) {
if field.IsNil() { if field.IsNil() {
// Initialize the pointer to a nil value // Initialize the pointer to a nil value
field.Set(reflect.New(field.Type().Elem())) field.Set(reflect.New(field.Type().Elem()))
} }
return unmarshalField(value, field.Elem()) return unmarshalFieldNonPtr(value, field.Elem())
} }
func setIntField(value string, bitSize int, field reflect.Value) error { func setIntField(value string, bitSize int, field reflect.Value) error {

View File

@ -34,9 +34,12 @@ type (
DoesntExist string DoesntExist string
T Timestamp T Timestamp
Tptr *Timestamp Tptr *Timestamp
SA StringArray
} }
Timestamp time.Time Timestamp time.Time
TA []Timestamp
StringArray []string
) )
func (t *Timestamp) UnmarshalParam(src string) error { func (t *Timestamp) UnmarshalParam(src string) error {
@ -45,6 +48,11 @@ func (t *Timestamp) UnmarshalParam(src string) error {
return err return err
} }
func (a *StringArray) UnmarshalParam(src string) error {
*a = StringArray(strings.Split(src, ","))
return nil
}
func (t bindTestStruct) GetCantSet() string { func (t bindTestStruct) GetCantSet() string {
return t.cantSet return t.cantSet
} }
@ -107,16 +115,21 @@ func TestBindQueryParams(t *testing.T) {
func TestBindUnmarshalParam(t *testing.T) { func TestBindUnmarshalParam(t *testing.T) {
e := New() e := New()
req, _ := http.NewRequest(GET, "/?ts=2016-12-06T19:09:05Z", nil) req, _ := http.NewRequest(GET, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z", nil)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
result := struct { result := struct {
T Timestamp `query:"ts"` T Timestamp `query:"ts"`
TA []Timestamp `query:"ta"`
SA StringArray `query:"sa"`
}{} }{}
err := c.Bind(&result) err := c.Bind(&result)
ts := Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC))
if assert.NoError(t, err) { if assert.NoError(t, err) {
// assert.Equal(t, Timestamp(reflect.TypeOf(&Timestamp{}), time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), result.T) // assert.Equal(t, Timestamp(reflect.TypeOf(&Timestamp{}), time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), result.T)
assert.Equal(t, Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), result.T) assert.Equal(t, ts, result.T)
assert.Equal(t, StringArray([]string{"one", "two", "three"}), result.SA)
assert.Equal(t, []Timestamp{ts, ts}, result.TA)
} }
} }
@ -217,7 +230,9 @@ func TestBindSetFields(t *testing.T) {
assert.Equal(t, false, ts.B) assert.Equal(t, false, ts.B)
} }
if assert.NoError(t, unmarshalField("2016-12-06T19:09:05Z", val.FieldByName("T"))) { ok, err := unmarshalFieldNonPtr("2016-12-06T19:09:05Z", val.FieldByName("T"))
if assert.NoError(t, err) {
assert.Equal(t, ok, true)
assert.Equal(t, Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), ts.T) assert.Equal(t, Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), ts.T)
} }
} }