mirror of
				https://github.com/labstack/echo.git
				synced 2025-10-30 23:57:38 +02:00 
			
		
		
		
	Support BindUnmarshaler for basic types (#786)
* Add a failing test for #784 * Change ordering of unmarshaler, to handle BindUnmarshalers first * Add test for arrays of BindUnmarshalers
This commit is contained in:
		
				
					committed by
					
						 Vishal Rana
						Vishal Rana
					
				
			
			
				
	
			
			
			
						parent
						
							869cdcd19a
						
					
				
				
					commit
					9cdc439f34
				
			
							
								
								
									
										36
									
								
								bind.go
									
									
									
									
									
								
							
							
						
						
									
										36
									
								
								bind.go
									
									
									
									
									
								
							| @@ -108,6 +108,14 @@ func (b *DefaultBinder) bindData(ptr interface{}, data map[string][]string, tag | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		// Call this first, in case we're dealing with an alias to an array type | ||||
| 		if ok, err := unmarshalField(typeField.Type.Kind(), inputValue[0], structField); ok { | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		numElems := len(inputValue) | ||||
| 		if structFieldKind == reflect.Slice && numElems > 0 { | ||||
| 			sliceOf := structField.Type().Elem().Kind() | ||||
| @@ -128,6 +136,11 @@ func (b *DefaultBinder) bindData(ptr interface{}, data map[string][]string, tag | ||||
| } | ||||
|  | ||||
| func setWithProperType(valueKind reflect.Kind, val string, structField reflect.Value) error { | ||||
| 	// But also call it here, in case we're dealing with an array of BindUnmarshalers | ||||
| 	if ok, err := unmarshalField(valueKind, val, structField); ok { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	switch valueKind { | ||||
| 	case reflect.Int: | ||||
| 		return setIntField(val, 0, structField) | ||||
| @@ -157,33 +170,40 @@ func setWithProperType(valueKind reflect.Kind, val string, structField reflect.V | ||||
| 		return setFloatField(val, 64, structField) | ||||
| 	case reflect.String: | ||||
| 		structField.SetString(val) | ||||
| 	case reflect.Ptr: | ||||
| 		return unmarshalFieldPtr(val, structField) | ||||
| 	default: | ||||
| 		return unmarshalField(val, structField) | ||||
| 		return errors.New("unknown type") | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func unmarshalField(value string, field reflect.Value) error { | ||||
| func unmarshalField(valueKind reflect.Kind, val string, field reflect.Value) (bool, error) { | ||||
| 	switch valueKind { | ||||
| 	case reflect.Ptr: | ||||
| 		return unmarshalFieldPtr(val, field) | ||||
| 	default: | ||||
| 		return unmarshalFieldNonPtr(val, field) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func unmarshalFieldNonPtr(value string, field reflect.Value) (bool, error) { | ||||
| 	ptr := reflect.New(field.Type()) | ||||
| 	if ptr.CanInterface() { | ||||
| 		iface := ptr.Interface() | ||||
| 		if unmarshaler, ok := iface.(BindUnmarshaler); ok { | ||||
| 			err := unmarshaler.UnmarshalParam(value) | ||||
| 			field.Set(ptr.Elem()) | ||||
| 			return err | ||||
| 			return true, err | ||||
| 		} | ||||
| 	} | ||||
| 	return errors.New("unknown type") | ||||
| 	return false, nil | ||||
| } | ||||
|  | ||||
| func unmarshalFieldPtr(value string, field reflect.Value) error { | ||||
| func unmarshalFieldPtr(value string, field reflect.Value) (bool, error) { | ||||
| 	if field.IsNil() { | ||||
| 		// Initialize the pointer to a nil value | ||||
| 		field.Set(reflect.New(field.Type().Elem())) | ||||
| 	} | ||||
| 	return unmarshalField(value, field.Elem()) | ||||
| 	return unmarshalFieldNonPtr(value, field.Elem()) | ||||
| } | ||||
|  | ||||
| func setIntField(value string, bitSize int, field reflect.Value) error { | ||||
|   | ||||
							
								
								
									
										25
									
								
								bind_test.go
									
									
									
									
									
								
							
							
						
						
									
										25
									
								
								bind_test.go
									
									
									
									
									
								
							| @@ -34,9 +34,12 @@ type ( | ||||
| 		DoesntExist string | ||||
| 		T           Timestamp | ||||
| 		Tptr        *Timestamp | ||||
| 		SA          StringArray | ||||
| 	} | ||||
|  | ||||
| 	Timestamp time.Time | ||||
| 	Timestamp   time.Time | ||||
| 	TA          []Timestamp | ||||
| 	StringArray []string | ||||
| ) | ||||
|  | ||||
| func (t *Timestamp) UnmarshalParam(src string) error { | ||||
| @@ -45,6 +48,11 @@ func (t *Timestamp) UnmarshalParam(src string) error { | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func (a *StringArray) UnmarshalParam(src string) error { | ||||
| 	*a = StringArray(strings.Split(src, ",")) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (t bindTestStruct) GetCantSet() string { | ||||
| 	return t.cantSet | ||||
| } | ||||
| @@ -107,16 +115,21 @@ func TestBindQueryParams(t *testing.T) { | ||||
|  | ||||
| func TestBindUnmarshalParam(t *testing.T) { | ||||
| 	e := New() | ||||
| 	req, _ := http.NewRequest(GET, "/?ts=2016-12-06T19:09:05Z", nil) | ||||
| 	req, _ := http.NewRequest(GET, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z", nil) | ||||
| 	rec := httptest.NewRecorder() | ||||
| 	c := e.NewContext(req, rec) | ||||
| 	result := struct { | ||||
| 		T Timestamp `query:"ts"` | ||||
| 		T  Timestamp   `query:"ts"` | ||||
| 		TA []Timestamp `query:"ta"` | ||||
| 		SA StringArray `query:"sa"` | ||||
| 	}{} | ||||
| 	err := c.Bind(&result) | ||||
| 	ts := Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)) | ||||
| 	if assert.NoError(t, err) { | ||||
| 		//		assert.Equal(t, Timestamp(reflect.TypeOf(&Timestamp{}), time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), result.T) | ||||
| 		assert.Equal(t, Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), result.T) | ||||
| 		assert.Equal(t, ts, result.T) | ||||
| 		assert.Equal(t, StringArray([]string{"one", "two", "three"}), result.SA) | ||||
| 		assert.Equal(t, []Timestamp{ts, ts}, result.TA) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @@ -217,7 +230,9 @@ func TestBindSetFields(t *testing.T) { | ||||
| 		assert.Equal(t, false, ts.B) | ||||
| 	} | ||||
|  | ||||
| 	if assert.NoError(t, unmarshalField("2016-12-06T19:09:05Z", val.FieldByName("T"))) { | ||||
| 	ok, err := unmarshalFieldNonPtr("2016-12-06T19:09:05Z", val.FieldByName("T")) | ||||
| 	if assert.NoError(t, err) { | ||||
| 		assert.Equal(t, ok, true) | ||||
| 		assert.Equal(t, Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)), ts.T) | ||||
| 	} | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user