1
0
mirror of https://github.com/labstack/echo.git synced 2025-01-18 02:58:38 +02:00

Default binder can use UnmarshalParams(params []string) error interface to bind multiple input values at one go. (#2607)

This commit is contained in:
Martti T 2024-03-11 22:49:58 +02:00 committed by GitHub
parent a3b0ba24d3
commit c57fcb3746
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 301 additions and 65 deletions

64
bind.go
View File

@ -30,6 +30,13 @@ type BindUnmarshaler interface {
UnmarshalParam(param string) error
}
// bindMultipleUnmarshaler is used by binder to unmarshal multiple values from request at once to
// type implementing this interface. For example request could have multiple query fields `?a=1&a=2&b=test` in that case
// for `a` following slice `["1", "2"] will be passed to unmarshaller.
type bindMultipleUnmarshaler interface {
UnmarshalParams(params []string) error
}
// BindPathParams binds path params to bindable object
func (b *DefaultBinder) BindPathParams(c Context, i interface{}) error {
names := c.ParamNames()
@ -217,8 +224,15 @@ func (b *DefaultBinder) bindData(destination interface{}, data map[string][]stri
continue
}
if ok, err := unmarshalInputsToField(typeField.Type.Kind(), inputValue, structField); ok {
if err != nil {
return err
}
continue
}
// Call this first, in case we're dealing with an alias to an array type
if ok, err := unmarshalField(typeField.Type.Kind(), inputValue[0], structField); ok {
if ok, err := unmarshalInputToField(typeField.Type.Kind(), inputValue[0], structField); ok {
if err != nil {
return err
}
@ -245,7 +259,7 @@ func (b *DefaultBinder) bindData(destination interface{}, data map[string][]stri
func setWithProperType(valueKind reflect.Kind, val string, structField reflect.Value) error {
// But also call it here, in case we're dealing with an array of BindUnmarshalers
if ok, err := unmarshalField(valueKind, val, structField); ok {
if ok, err := unmarshalInputToField(valueKind, val, structField); ok {
return err
}
@ -286,35 +300,41 @@ func setWithProperType(valueKind reflect.Kind, val string, structField reflect.V
return nil
}
func unmarshalField(valueKind reflect.Kind, val string, field reflect.Value) (bool, error) {
switch valueKind {
case reflect.Ptr:
return unmarshalFieldPtr(val, field)
default:
return unmarshalFieldNonPtr(val, field)
func unmarshalInputsToField(valueKind reflect.Kind, values []string, field reflect.Value) (bool, error) {
if valueKind == reflect.Ptr {
if field.IsNil() {
field.Set(reflect.New(field.Type().Elem()))
}
field = field.Elem()
}
fieldIValue := field.Addr().Interface()
unmarshaler, ok := fieldIValue.(bindMultipleUnmarshaler)
if !ok {
return false, nil
}
return true, unmarshaler.UnmarshalParams(values)
}
func unmarshalFieldNonPtr(value string, field reflect.Value) (bool, error) {
fieldIValue := field.Addr().Interface()
if unmarshaler, ok := fieldIValue.(BindUnmarshaler); ok {
return true, unmarshaler.UnmarshalParam(value)
func unmarshalInputToField(valueKind reflect.Kind, val string, field reflect.Value) (bool, error) {
if valueKind == reflect.Ptr {
if field.IsNil() {
field.Set(reflect.New(field.Type().Elem()))
}
field = field.Elem()
}
if unmarshaler, ok := fieldIValue.(encoding.TextUnmarshaler); ok {
return true, unmarshaler.UnmarshalText([]byte(value))
fieldIValue := field.Addr().Interface()
switch unmarshaler := fieldIValue.(type) {
case BindUnmarshaler:
return true, unmarshaler.UnmarshalParam(val)
case encoding.TextUnmarshaler:
return true, unmarshaler.UnmarshalText([]byte(val))
}
return false, nil
}
func unmarshalFieldPtr(value string, field reflect.Value) (bool, error) {
if field.IsNil() {
// Initialize the pointer to a nil value
field.Set(reflect.New(field.Type().Elem()))
}
return unmarshalFieldNonPtr(value, field.Elem())
}
func setIntField(value string, bitSize int, field reflect.Value) error {
if value == "" {
value = "0"

View File

@ -8,6 +8,7 @@ import (
"encoding/json"
"encoding/xml"
"errors"
"fmt"
"io"
"mime/multipart"
"net/http"
@ -653,49 +654,6 @@ func TestBindSetWithProperType(t *testing.T) {
assert.Error(t, setWithProperType(typ.Field(0).Type.Kind(), "5", val.Field(0)))
}
func TestBindSetFields(t *testing.T) {
ts := new(bindTestStruct)
val := reflect.ValueOf(ts).Elem()
// Int
if assert.NoError(t, setIntField("5", 0, val.FieldByName("I"))) {
assert.Equal(t, 5, ts.I)
}
if assert.NoError(t, setIntField("", 0, val.FieldByName("I"))) {
assert.Equal(t, 0, ts.I)
}
// Uint
if assert.NoError(t, setUintField("10", 0, val.FieldByName("UI"))) {
assert.Equal(t, uint(10), ts.UI)
}
if assert.NoError(t, setUintField("", 0, val.FieldByName("UI"))) {
assert.Equal(t, uint(0), ts.UI)
}
// Float
if assert.NoError(t, setFloatField("15.5", 0, val.FieldByName("F32"))) {
assert.Equal(t, float32(15.5), ts.F32)
}
if assert.NoError(t, setFloatField("", 0, val.FieldByName("F32"))) {
assert.Equal(t, float32(0.0), ts.F32)
}
// Bool
if assert.NoError(t, setBoolField("true", val.FieldByName("B"))) {
assert.Equal(t, true, ts.B)
}
if assert.NoError(t, setBoolField("", val.FieldByName("B"))) {
assert.Equal(t, false, ts.B)
}
ok, err := unmarshalFieldNonPtr("2016-12-06T19:09:05Z", val.FieldByName("T"))
if assert.NoError(t, err) {
assert.Equal(t, ok, true)
assert.Equal(t, Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), ts.T)
}
}
func BenchmarkBindbindDataWithTags(b *testing.B) {
b.ReportAllocs()
ts := new(bindTestStructWithTags)
@ -1138,3 +1096,261 @@ func TestDefaultBinder_BindBody(t *testing.T) {
})
}
}
type unixTimestamp struct {
Time time.Time
}
func (t *unixTimestamp) UnmarshalParam(param string) error {
n, err := strconv.ParseInt(param, 10, 64)
if err != nil {
return fmt.Errorf("'%s' is not an integer", param)
}
*t = unixTimestamp{Time: time.Unix(n, 0)}
return err
}
type IntArrayA []int
// UnmarshalParam converts value to *Int64Slice. This allows the API to accept
// a comma-separated list of integers as a query parameter.
func (i *IntArrayA) UnmarshalParam(value string) error {
var values = strings.Split(value, ",")
var numbers = make([]int, 0, len(values))
for _, v := range values {
n, err := strconv.ParseInt(v, 10, 64)
if err != nil {
return fmt.Errorf("'%s' is not an integer", v)
}
numbers = append(numbers, int(n))
}
*i = append(*i, numbers...)
return nil
}
func TestBindUnmarshalParamExtras(t *testing.T) {
// this test documents how bind handles `BindUnmarshaler` interface:
// NOTE: BindUnmarshaler chooses first input value to be bound.
t.Run("nok, unmarshalling fails", func(t *testing.T) {
e := New()
req := httptest.NewRequest(http.MethodGet, "/?t=xxxx", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
result := struct {
V unixTimestamp `query:"t"`
}{}
err := c.Bind(&result)
assert.EqualError(t, err, "code=400, message='xxxx' is not an integer, internal='xxxx' is not an integer")
})
t.Run("ok, target is struct", func(t *testing.T) {
e := New()
req := httptest.NewRequest(http.MethodGet, "/?t=1710095540&t=1710095541", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
result := struct {
V unixTimestamp `query:"t"`
}{}
err := c.Bind(&result)
assert.NoError(t, err)
expect := unixTimestamp{
Time: time.Unix(1710095540, 0),
}
assert.Equal(t, expect, result.V)
})
t.Run("ok, target is an alias to slice and is nil, append only values from first", func(t *testing.T) {
e := New()
req := httptest.NewRequest(http.MethodGet, "/?a=1,2,3&a=4,5,6", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
result := struct {
V IntArrayA `query:"a"`
}{}
err := c.Bind(&result)
assert.NoError(t, err)
assert.Equal(t, IntArrayA([]int{1, 2, 3}), result.V)
})
t.Run("ok, target is an alias to slice and is nil, single input", func(t *testing.T) {
e := New()
req := httptest.NewRequest(http.MethodGet, "/?a=1,2", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
result := struct {
V IntArrayA `query:"a"`
}{}
err := c.Bind(&result)
assert.NoError(t, err)
assert.Equal(t, IntArrayA([]int{1, 2}), result.V)
})
t.Run("ok, target is pointer an alias to slice and is nil", func(t *testing.T) {
e := New()
req := httptest.NewRequest(http.MethodGet, "/?a=1&a=4,5,6", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
result := struct {
V *IntArrayA `query:"a"`
}{}
err := c.Bind(&result)
assert.NoError(t, err)
var expected = IntArrayA([]int{1})
assert.Equal(t, &expected, result.V)
})
t.Run("ok, target is pointer an alias to slice and is NOT nil", func(t *testing.T) {
e := New()
req := httptest.NewRequest(http.MethodGet, "/?a=1&a=4,5,6", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
result := struct {
V *IntArrayA `query:"a"`
}{}
result.V = new(IntArrayA) // NOT nil
err := c.Bind(&result)
assert.NoError(t, err)
var expected = IntArrayA([]int{1})
assert.Equal(t, &expected, result.V)
})
}
type unixTimestampLast struct {
Time time.Time
}
// this is silly example for `bindMultipleUnmarshaler` for type that uses last input value for unmarshalling
func (t *unixTimestampLast) UnmarshalParams(params []string) error {
lastInput := params[len(params)-1]
n, err := strconv.ParseInt(lastInput, 10, 64)
if err != nil {
return fmt.Errorf("'%s' is not an integer", lastInput)
}
*t = unixTimestampLast{Time: time.Unix(n, 0)}
return err
}
type IntArrayB []int
func (i *IntArrayB) UnmarshalParams(params []string) error {
var numbers = make([]int, 0, len(params))
for _, param := range params {
var values = strings.Split(param, ",")
for _, v := range values {
n, err := strconv.ParseInt(v, 10, 64)
if err != nil {
return fmt.Errorf("'%s' is not an integer", v)
}
numbers = append(numbers, int(n))
}
}
*i = append(*i, numbers...)
return nil
}
func TestBindUnmarshalParams(t *testing.T) {
// this test documents how bind handles `bindMultipleUnmarshaler` interface:
t.Run("nok, unmarshalling fails", func(t *testing.T) {
e := New()
req := httptest.NewRequest(http.MethodGet, "/?t=xxxx", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
result := struct {
V unixTimestampLast `query:"t"`
}{}
err := c.Bind(&result)
assert.EqualError(t, err, "code=400, message='xxxx' is not an integer, internal='xxxx' is not an integer")
})
t.Run("ok, target is struct", func(t *testing.T) {
e := New()
req := httptest.NewRequest(http.MethodGet, "/?t=1710095540&t=1710095541", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
result := struct {
V unixTimestampLast `query:"t"`
}{}
err := c.Bind(&result)
assert.NoError(t, err)
expect := unixTimestampLast{
Time: time.Unix(1710095541, 0),
}
assert.Equal(t, expect, result.V)
})
t.Run("ok, target is an alias to slice and is nil, append multiple inputs", func(t *testing.T) {
e := New()
req := httptest.NewRequest(http.MethodGet, "/?a=1,2,3&a=4,5,6", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
result := struct {
V IntArrayB `query:"a"`
}{}
err := c.Bind(&result)
assert.NoError(t, err)
assert.Equal(t, IntArrayB([]int{1, 2, 3, 4, 5, 6}), result.V)
})
t.Run("ok, target is an alias to slice and is nil, single input", func(t *testing.T) {
e := New()
req := httptest.NewRequest(http.MethodGet, "/?a=1,2", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
result := struct {
V IntArrayB `query:"a"`
}{}
err := c.Bind(&result)
assert.NoError(t, err)
assert.Equal(t, IntArrayB([]int{1, 2}), result.V)
})
t.Run("ok, target is pointer an alias to slice and is nil", func(t *testing.T) {
e := New()
req := httptest.NewRequest(http.MethodGet, "/?a=1&a=4,5,6", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
result := struct {
V *IntArrayB `query:"a"`
}{}
err := c.Bind(&result)
assert.NoError(t, err)
var expected = IntArrayB([]int{1, 4, 5, 6})
assert.Equal(t, &expected, result.V)
})
t.Run("ok, target is pointer an alias to slice and is NOT nil", func(t *testing.T) {
e := New()
req := httptest.NewRequest(http.MethodGet, "/?a=1&a=4,5,6", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
result := struct {
V *IntArrayB `query:"a"`
}{}
result.V = new(IntArrayB) // NOT nil
err := c.Bind(&result)
assert.NoError(t, err)
var expected = IntArrayB([]int{1, 4, 5, 6})
assert.Equal(t, &expected, result.V)
})
}