1
0
mirror of https://github.com/labstack/echo.git synced 2025-04-21 12:17:04 +02:00

Fix binding of untagged struct fields (#812)

* Add failing test

A BindUnmarshaler struct with no tag is not decoded properly.

* Fix binding of untagged structs
This commit is contained in:
Jonathan Hall 2017-01-16 08:13:46 +01:00 committed by Vishal Rana
parent 80d5c96212
commit ed7353cf60
2 changed files with 27 additions and 6 deletions

18
bind.go
View File

@ -95,7 +95,7 @@ func (b *DefaultBinder) bindData(ptr interface{}, data map[string][]string, tag
if inputFieldName == "" { if inputFieldName == "" {
inputFieldName = typeField.Name inputFieldName = typeField.Name
// If tag is nil, we inspect if the field is a struct. // If tag is nil, we inspect if the field is a struct.
if structFieldKind == reflect.Struct { if _, ok := bindUnmarshaler(structField); !ok && structFieldKind == reflect.Struct {
err := b.bindData(structField.Addr().Interface(), data, tag) err := b.bindData(structField.Addr().Interface(), data, tag)
if err != nil { if err != nil {
return err return err
@ -185,16 +185,24 @@ func unmarshalField(valueKind reflect.Kind, val string, field reflect.Value) (bo
} }
} }
func unmarshalFieldNonPtr(value string, field reflect.Value) (bool, error) { // bindUnmarshaler attempts to unmarshal a reflect.Value into a BindUnmarshaler
func bindUnmarshaler(field reflect.Value) (BindUnmarshaler, bool) {
ptr := reflect.New(field.Type()) ptr := reflect.New(field.Type())
if ptr.CanInterface() { if ptr.CanInterface() {
iface := ptr.Interface() iface := ptr.Interface()
if unmarshaler, ok := iface.(BindUnmarshaler); ok { if unmarshaler, ok := iface.(BindUnmarshaler); ok {
err := unmarshaler.UnmarshalParam(value) return unmarshaler, ok
field.Set(ptr.Elem())
return true, err
} }
} }
return nil, false
}
func unmarshalFieldNonPtr(value string, field reflect.Value) (bool, error) {
if unmarshaler, ok := bindUnmarshaler(field); ok {
err := unmarshaler.UnmarshalParam(value)
field.Set(reflect.ValueOf(unmarshaler).Elem())
return true, err
}
return false, nil return false, nil
} }

View File

@ -40,6 +40,9 @@ type (
Timestamp time.Time Timestamp time.Time
TA []Timestamp TA []Timestamp
StringArray []string StringArray []string
Struct struct {
Foo string
}
) )
func (t *Timestamp) UnmarshalParam(src string) error { func (t *Timestamp) UnmarshalParam(src string) error {
@ -53,6 +56,13 @@ func (a *StringArray) UnmarshalParam(src string) error {
return nil return nil
} }
func (s *Struct) UnmarshalParam(src string) error {
*s = Struct{
Foo: src,
}
return nil
}
func (t bindTestStruct) GetCantSet() string { func (t bindTestStruct) GetCantSet() string {
return t.cantSet return t.cantSet
} }
@ -75,6 +85,7 @@ var values = map[string][]string{
"cantSet": {"test"}, "cantSet": {"test"},
"T": {"2016-12-06T19:09:05+01:00"}, "T": {"2016-12-06T19:09:05+01:00"},
"Tptr": {"2016-12-06T19:09:05+01:00"}, "Tptr": {"2016-12-06T19:09:05+01:00"},
"ST": {"bar"},
} }
func TestBindJSON(t *testing.T) { func TestBindJSON(t *testing.T) {
@ -115,13 +126,14 @@ func TestBindQueryParams(t *testing.T) {
func TestBindUnmarshalParam(t *testing.T) { func TestBindUnmarshalParam(t *testing.T) {
e := New() e := New()
req, _ := http.NewRequest(GET, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z", nil) req, _ := http.NewRequest(GET, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z&ST=baz", nil)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
c := e.NewContext(req, rec) c := e.NewContext(req, rec)
result := struct { result := struct {
T Timestamp `query:"ts"` T Timestamp `query:"ts"`
TA []Timestamp `query:"ta"` TA []Timestamp `query:"ta"`
SA StringArray `query:"sa"` SA StringArray `query:"sa"`
ST Struct
}{} }{}
err := c.Bind(&result) err := c.Bind(&result)
ts := Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)) ts := Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC))
@ -130,6 +142,7 @@ func TestBindUnmarshalParam(t *testing.T) {
assert.Equal(t, ts, result.T) assert.Equal(t, ts, result.T)
assert.Equal(t, StringArray([]string{"one", "two", "three"}), result.SA) assert.Equal(t, StringArray([]string{"one", "two", "three"}), result.SA)
assert.Equal(t, []Timestamp{ts, ts}, result.TA) assert.Equal(t, []Timestamp{ts, ts}, result.TA)
assert.Equal(t, Struct{"baz"}, result.ST)
} }
} }