1
0
mirror of https://github.com/labstack/echo.git synced 2025-01-24 03:16:14 +02:00

Fix binding of untagged struct fields (#812)

* Add failing test

A BindUnmarshaler struct with no tag is not decoded properly.

* Fix binding of untagged structs
This commit is contained in:
Jonathan Hall 2017-01-16 08:13:46 +01:00 committed by Vishal Rana
parent 80d5c96212
commit ed7353cf60
2 changed files with 27 additions and 6 deletions

18
bind.go
View File

@ -95,7 +95,7 @@ func (b *DefaultBinder) bindData(ptr interface{}, data map[string][]string, tag
if inputFieldName == "" {
inputFieldName = typeField.Name
// If tag is nil, we inspect if the field is a struct.
if structFieldKind == reflect.Struct {
if _, ok := bindUnmarshaler(structField); !ok && structFieldKind == reflect.Struct {
err := b.bindData(structField.Addr().Interface(), data, tag)
if err != nil {
return err
@ -185,16 +185,24 @@ func unmarshalField(valueKind reflect.Kind, val string, field reflect.Value) (bo
}
}
func unmarshalFieldNonPtr(value string, field reflect.Value) (bool, error) {
// bindUnmarshaler attempts to unmarshal a reflect.Value into a BindUnmarshaler
func bindUnmarshaler(field reflect.Value) (BindUnmarshaler, bool) {
ptr := reflect.New(field.Type())
if ptr.CanInterface() {
iface := ptr.Interface()
if unmarshaler, ok := iface.(BindUnmarshaler); ok {
err := unmarshaler.UnmarshalParam(value)
field.Set(ptr.Elem())
return true, err
return unmarshaler, ok
}
}
return nil, false
}
func unmarshalFieldNonPtr(value string, field reflect.Value) (bool, error) {
if unmarshaler, ok := bindUnmarshaler(field); ok {
err := unmarshaler.UnmarshalParam(value)
field.Set(reflect.ValueOf(unmarshaler).Elem())
return true, err
}
return false, nil
}

View File

@ -40,6 +40,9 @@ type (
Timestamp time.Time
TA []Timestamp
StringArray []string
Struct struct {
Foo string
}
)
func (t *Timestamp) UnmarshalParam(src string) error {
@ -53,6 +56,13 @@ func (a *StringArray) UnmarshalParam(src string) error {
return nil
}
func (s *Struct) UnmarshalParam(src string) error {
*s = Struct{
Foo: src,
}
return nil
}
func (t bindTestStruct) GetCantSet() string {
return t.cantSet
}
@ -75,6 +85,7 @@ var values = map[string][]string{
"cantSet": {"test"},
"T": {"2016-12-06T19:09:05+01:00"},
"Tptr": {"2016-12-06T19:09:05+01:00"},
"ST": {"bar"},
}
func TestBindJSON(t *testing.T) {
@ -115,13 +126,14 @@ func TestBindQueryParams(t *testing.T) {
func TestBindUnmarshalParam(t *testing.T) {
e := New()
req, _ := http.NewRequest(GET, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z", nil)
req, _ := http.NewRequest(GET, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z&ST=baz", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
result := struct {
T Timestamp `query:"ts"`
TA []Timestamp `query:"ta"`
SA StringArray `query:"sa"`
ST Struct
}{}
err := c.Bind(&result)
ts := Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC))
@ -130,6 +142,7 @@ func TestBindUnmarshalParam(t *testing.T) {
assert.Equal(t, ts, result.T)
assert.Equal(t, StringArray([]string{"one", "two", "three"}), result.SA)
assert.Equal(t, []Timestamp{ts, ts}, result.TA)
assert.Equal(t, Struct{"baz"}, result.ST)
}
}