1
0
mirror of https://github.com/labstack/echo.git synced 2024-12-24 20:14:31 +02:00

Add support for encoding.TextUnmarshaler in bind. (#1314)

This commit is contained in:
Garrett D'Amore 2019-06-09 09:39:54 -07:00 committed by Vishal Rana
parent 842fc8772f
commit c824b8ddc3
2 changed files with 61 additions and 0 deletions

21
bind.go
View File

@ -1,6 +1,7 @@
package echo package echo
import ( import (
"encoding"
"encoding/json" "encoding/json"
"encoding/xml" "encoding/xml"
"errors" "errors"
@ -21,6 +22,8 @@ type (
DefaultBinder struct{} DefaultBinder struct{}
// BindUnmarshaler is the interface used to wrap the UnmarshalParam method. // BindUnmarshaler is the interface used to wrap the UnmarshalParam method.
// Types that don't implement this, but do implement encoding.TextUnmarshaler
// will use that interface instead.
BindUnmarshaler interface { BindUnmarshaler interface {
// UnmarshalParam decodes and assigns a value from an form or query param. // UnmarshalParam decodes and assigns a value from an form or query param.
UnmarshalParam(param string) error UnmarshalParam(param string) error
@ -211,12 +214,30 @@ func bindUnmarshaler(field reflect.Value) (BindUnmarshaler, bool) {
return nil, false return nil, false
} }
// textUnmarshaler attempts to unmarshal a reflect.Value into a TextUnmarshaler
func textUnmarshaler(field reflect.Value) (encoding.TextUnmarshaler, bool) {
ptr := reflect.New(field.Type())
if ptr.CanInterface() {
iface := ptr.Interface()
if unmarshaler, ok := iface.(encoding.TextUnmarshaler); ok {
return unmarshaler, ok
}
}
return nil, false
}
func unmarshalFieldNonPtr(value string, field reflect.Value) (bool, error) { func unmarshalFieldNonPtr(value string, field reflect.Value) (bool, error) {
if unmarshaler, ok := bindUnmarshaler(field); ok { if unmarshaler, ok := bindUnmarshaler(field); ok {
err := unmarshaler.UnmarshalParam(value) err := unmarshaler.UnmarshalParam(value)
field.Set(reflect.ValueOf(unmarshaler).Elem()) field.Set(reflect.ValueOf(unmarshaler).Elem())
return true, err return true, err
} }
if unmarshaler, ok := textUnmarshaler(field); ok {
err := unmarshaler.UnmarshalText([]byte(value))
field.Set(reflect.ValueOf(unmarshaler).Elem())
return true, err
}
return false, nil return false, nil
} }

View File

@ -50,6 +50,8 @@ type (
PtrS *string PtrS *string
cantSet string cantSet string
DoesntExist string DoesntExist string
GoT time.Time
GoTptr *time.Time
T Timestamp T Timestamp
Tptr *Timestamp Tptr *Timestamp
SA StringArray SA StringArray
@ -116,6 +118,8 @@ var values = map[string][]string{
"cantSet": {"test"}, "cantSet": {"test"},
"T": {"2016-12-06T19:09:05+01:00"}, "T": {"2016-12-06T19:09:05+01:00"},
"Tptr": {"2016-12-06T19:09:05+01:00"}, "Tptr": {"2016-12-06T19:09:05+01:00"},
"GoT": {"2016-12-06T19:09:05+01:00"},
"GoTptr": {"2016-12-06T19:09:05+01:00"},
"ST": {"bar"}, "ST": {"bar"},
} }
@ -216,6 +220,28 @@ func TestBindUnmarshalParam(t *testing.T) {
} }
} }
func TestBindUnmarshalText(t *testing.T) {
e := New()
req := httptest.NewRequest(GET, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z&ST=baz", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
result := struct {
T time.Time `query:"ts"`
TA []time.Time `query:"ta"`
SA StringArray `query:"sa"`
ST Struct
}{}
err := c.Bind(&result)
ts := time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)
if assert.NoError(t, err) {
// assert.Equal(t, Timestamp(reflect.TypeOf(&Timestamp{}), time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), result.T)
assert.Equal(t, ts, result.T)
assert.Equal(t, StringArray([]string{"one", "two", "three"}), result.SA)
assert.Equal(t, []time.Time{ts, ts}, result.TA)
assert.Equal(t, Struct{"baz"}, result.ST)
}
}
func TestBindUnmarshalParamPtr(t *testing.T) { func TestBindUnmarshalParamPtr(t *testing.T) {
e := New() e := New()
req := httptest.NewRequest(http.MethodGet, "/?ts=2016-12-06T19:09:05Z", nil) req := httptest.NewRequest(http.MethodGet, "/?ts=2016-12-06T19:09:05Z", nil)
@ -230,6 +256,20 @@ func TestBindUnmarshalParamPtr(t *testing.T) {
} }
} }
func TestBindUnmarshalTextPtr(t *testing.T) {
e := New()
req := httptest.NewRequest(GET, "/?ts=2016-12-06T19:09:05Z", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
result := struct {
Tptr *time.Time `query:"ts"`
}{}
err := c.Bind(&result)
if assert.NoError(t, err) {
assert.Equal(t, time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC), *result.Tptr)
}
}
func TestBindMultipartForm(t *testing.T) { func TestBindMultipartForm(t *testing.T) {
body := new(bytes.Buffer) body := new(bytes.Buffer)
mw := multipart.NewWriter(body) mw := multipart.NewWriter(body)