1
0
mirror of https://github.com/go-kratos/kratos.git synced 2026-05-22 10:15:24 +02:00

add form for codec (#1158)

* add form for http codec
This commit is contained in:
longxboy
2021-07-08 13:05:21 +08:00
committed by GitHub
parent 9280af7165
commit 23a7f15541
25 changed files with 799 additions and 784 deletions
+44 -11
View File
@@ -1,30 +1,49 @@
package form
import (
"github.com/go-kratos/kratos/v2/encoding"
"github.com/gorilla/schema"
"net/url"
"reflect"
"github.com/go-kratos/kratos/v2/encoding"
"github.com/go-playground/form/v4"
"google.golang.org/protobuf/proto"
)
// Name is form codec name
const Name = "x-www-form-urlencoded"
func init() {
decoder := schema.NewDecoder()
decoder.SetAliasTag("json")
encoder := schema.NewEncoder()
encoder.SetAliasTag("json")
decoder := form.NewDecoder()
decoder.SetTagName("json")
encoder := form.NewEncoder()
encoder.SetTagName("json")
encoding.RegisterCodec(codec{encoder: encoder, decoder: decoder})
}
type codec struct {
encoder *schema.Encoder
decoder *schema.Decoder
encoder *form.Encoder
decoder *form.Decoder
}
func (c codec) Marshal(v interface{}) ([]byte, error) {
var vs = url.Values{}
if err := c.encoder.Encode(v, vs); err != nil {
return nil, err
var vs url.Values
var err error
if m, ok := v.(proto.Message); ok {
vs, err = EncodeMap(m)
if err != nil {
return nil, err
}
} else {
vs, err = c.encoder.Encode(v)
if err != nil {
return nil, err
}
}
for k, v := range vs {
if len(v) == 0 {
delete(vs, k)
}
}
return []byte(vs.Encode()), nil
}
@@ -34,6 +53,20 @@ func (c codec) Unmarshal(data []byte, v interface{}) error {
if err != nil {
return err
}
rv := reflect.ValueOf(v)
for rv.Kind() == reflect.Ptr {
if rv.IsNil() {
rv.Set(reflect.New(rv.Type().Elem()))
}
rv = rv.Elem()
}
if m, ok := v.(proto.Message); ok {
return MapProto(m, vs)
} else if m, ok := reflect.Indirect(reflect.ValueOf(v)).Interface().(proto.Message); ok {
return MapProto(m, vs)
}
if err := c.decoder.Decode(v, vs); err != nil {
return err
}
+27 -2
View File
@@ -1,9 +1,11 @@
package form
import (
"github.com/go-kratos/kratos/v2/encoding"
"github.com/stretchr/testify/require"
"testing"
"github.com/go-kratos/kratos/v2/encoding"
"github.com/go-kratos/kratos/v2/internal/testproto/complex"
"github.com/stretchr/testify/require"
)
type LoginRequest struct {
@@ -45,3 +47,26 @@ func TestFormCodecUnmarshal(t *testing.T) {
require.Equal(t, "kratos", bindReq.Username)
require.Equal(t, "kratos_pwd", bindReq.Password)
}
func TestProtoEncodeDecode(t *testing.T) {
in := &complex.Complex{
Id: 2233,
NoOne: "2233",
Simple: &complex.Simple{Component: "5566"},
Simples: []string{"3344", "5566"},
}
content, err := encoding.GetCodec(contentType).Marshal(in)
require.NoError(t, err)
require.Equal(t, "id=2233&numberOne=2233&simples=3344&simples=5566&very_simple.component=5566", string(content))
var in2 = &complex.Complex{}
err = encoding.GetCodec(contentType).Unmarshal(content, in2)
require.NoError(t, err)
require.Equal(t, int64(2233), in2.Id)
require.Equal(t, "2233", in2.NoOne)
require.NotEmpty(t, in2.Simple)
require.Equal(t, "5566", in2.Simple.Component)
require.NotEmpty(t, in2.Simples)
require.Len(t, in2.Simples, 2)
require.Equal(t, "3344", in2.Simples[0])
require.Equal(t, "5566", in2.Simples[1])
}
+269
View File
@@ -0,0 +1,269 @@
package form
import (
"encoding/base64"
"errors"
"fmt"
"strconv"
"strings"
"time"
"google.golang.org/genproto/protobuf/field_mask"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/timestamppb"
"google.golang.org/protobuf/types/known/wrapperspb"
)
func MapProto(msg proto.Message, values map[string][]string) error {
for key, values := range values {
if err := populateFieldValues(msg.ProtoReflect(), strings.Split(key, "."), values); err != nil {
return err
}
}
return nil
}
func populateFieldValues(v protoreflect.Message, fieldPath []string, values []string) error {
if len(fieldPath) < 1 {
return errors.New("no field path")
}
if len(values) < 1 {
return errors.New("no value provided")
}
var fd protoreflect.FieldDescriptor
for i, fieldName := range fieldPath {
fields := v.Descriptor().Fields()
if fd = fields.ByName(protoreflect.Name(fieldName)); fd == nil {
fd = fields.ByJSONName(fieldName)
if fd == nil {
// ignore unexpected field.
return nil
}
}
if i == len(fieldPath)-1 {
break
}
if fd.Message() == nil || fd.Cardinality() == protoreflect.Repeated {
return fmt.Errorf("invalid path: %q is not a message", fieldName)
}
v = v.Mutable(fd).Message()
}
if of := fd.ContainingOneof(); of != nil {
if f := v.WhichOneof(of); f != nil {
return fmt.Errorf("field already set for oneof %q", of.FullName().Name())
}
}
switch {
case fd.IsList():
return populateRepeatedField(fd, v.Mutable(fd).List(), values)
case fd.IsMap():
return populateMapField(fd, v.Mutable(fd).Map(), values)
}
if len(values) > 1 {
return fmt.Errorf("too many values for field %q: %s", fd.FullName().Name(), strings.Join(values, ", "))
}
return populateField(fd, v, values[0])
}
func populateField(fd protoreflect.FieldDescriptor, v protoreflect.Message, value string) error {
val, err := parseField(fd, value)
if err != nil {
return fmt.Errorf("parsing field %q: %w", fd.FullName().Name(), err)
}
v.Set(fd, val)
return nil
}
func populateRepeatedField(fd protoreflect.FieldDescriptor, list protoreflect.List, values []string) error {
for _, value := range values {
v, err := parseField(fd, value)
if err != nil {
return fmt.Errorf("parsing list %q: %w", fd.FullName().Name(), err)
}
list.Append(v)
}
return nil
}
func populateMapField(fd protoreflect.FieldDescriptor, mp protoreflect.Map, values []string) error {
if len(values) != 2 {
return fmt.Errorf("more than one value provided for key %q in map %q", values[0], fd.FullName())
}
key, err := parseField(fd.MapKey(), values[0])
if err != nil {
return fmt.Errorf("parsing map key %q: %w", fd.FullName().Name(), err)
}
value, err := parseField(fd.MapValue(), values[1])
if err != nil {
return fmt.Errorf("parsing map value %q: %w", fd.FullName().Name(), err)
}
mp.Set(key.MapKey(), value)
return nil
}
func parseField(fd protoreflect.FieldDescriptor, value string) (protoreflect.Value, error) {
switch fd.Kind() {
case protoreflect.BoolKind:
v, err := strconv.ParseBool(value)
if err != nil {
return protoreflect.Value{}, err
}
return protoreflect.ValueOfBool(v), nil
case protoreflect.EnumKind:
enum, err := protoregistry.GlobalTypes.FindEnumByName(fd.Enum().FullName())
switch {
case errors.Is(err, protoregistry.NotFound):
return protoreflect.Value{}, fmt.Errorf("enum %q is not registered", fd.Enum().FullName())
case err != nil:
return protoreflect.Value{}, fmt.Errorf("failed to look up enum: %w", err)
}
v := enum.Descriptor().Values().ByName(protoreflect.Name(value))
if v == nil {
i, err := strconv.ParseInt(value, 10, 32)
if err != nil {
return protoreflect.Value{}, fmt.Errorf("%q is not a valid value", value)
}
v = enum.Descriptor().Values().ByNumber(protoreflect.EnumNumber(i))
if v == nil {
return protoreflect.Value{}, fmt.Errorf("%q is not a valid value", value)
}
}
return protoreflect.ValueOfEnum(v.Number()), nil
case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
v, err := strconv.ParseInt(value, 10, 32)
if err != nil {
return protoreflect.Value{}, err
}
return protoreflect.ValueOfInt32(int32(v)), nil
case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
v, err := strconv.ParseInt(value, 10, 64)
if err != nil {
return protoreflect.Value{}, err
}
return protoreflect.ValueOfInt64(v), nil
case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
v, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return protoreflect.Value{}, err
}
return protoreflect.ValueOfUint32(uint32(v)), nil
case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
v, err := strconv.ParseUint(value, 10, 64)
if err != nil {
return protoreflect.Value{}, err
}
return protoreflect.ValueOfUint64(v), nil
case protoreflect.FloatKind:
v, err := strconv.ParseFloat(value, 32)
if err != nil {
return protoreflect.Value{}, err
}
return protoreflect.ValueOfFloat32(float32(v)), nil
case protoreflect.DoubleKind:
v, err := strconv.ParseFloat(value, 64)
if err != nil {
return protoreflect.Value{}, err
}
return protoreflect.ValueOfFloat64(v), nil
case protoreflect.StringKind:
return protoreflect.ValueOfString(value), nil
case protoreflect.BytesKind:
v, err := base64.StdEncoding.DecodeString(value)
if err != nil {
return protoreflect.Value{}, err
}
return protoreflect.ValueOfBytes(v), nil
case protoreflect.MessageKind, protoreflect.GroupKind:
return parseMessage(fd.Message(), value)
default:
panic(fmt.Sprintf("unknown field kind: %v", fd.Kind()))
}
}
func parseMessage(md protoreflect.MessageDescriptor, value string) (protoreflect.Value, error) {
var msg proto.Message
switch md.FullName() {
case "google.protobuf.Timestamp":
if value == "null" {
break
}
t, err := time.Parse(time.RFC3339Nano, value)
if err != nil {
return protoreflect.Value{}, err
}
msg = timestamppb.New(t)
case "google.protobuf.Duration":
if value == "null" {
break
}
d, err := time.ParseDuration(value)
if err != nil {
return protoreflect.Value{}, err
}
msg = durationpb.New(d)
case "google.protobuf.DoubleValue":
v, err := strconv.ParseFloat(value, 64)
if err != nil {
return protoreflect.Value{}, err
}
msg = wrapperspb.Double(v)
case "google.protobuf.FloatValue":
v, err := strconv.ParseFloat(value, 32)
if err != nil {
return protoreflect.Value{}, err
}
msg = wrapperspb.Float(float32(v))
case "google.protobuf.Int64Value":
v, err := strconv.ParseInt(value, 10, 64)
if err != nil {
return protoreflect.Value{}, err
}
msg = wrapperspb.Int64(v)
case "google.protobuf.Int32Value":
v, err := strconv.ParseInt(value, 10, 32)
if err != nil {
return protoreflect.Value{}, err
}
msg = wrapperspb.Int32(int32(v))
case "google.protobuf.UInt64Value":
v, err := strconv.ParseUint(value, 10, 64)
if err != nil {
return protoreflect.Value{}, err
}
msg = wrapperspb.UInt64(v)
case "google.protobuf.UInt32Value":
v, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return protoreflect.Value{}, err
}
msg = wrapperspb.UInt32(uint32(v))
case "google.protobuf.BoolValue":
v, err := strconv.ParseBool(value)
if err != nil {
return protoreflect.Value{}, err
}
msg = wrapperspb.Bool(v)
case "google.protobuf.StringValue":
msg = wrapperspb.String(value)
case "google.protobuf.BytesValue":
v, err := base64.StdEncoding.DecodeString(value)
if err != nil {
return protoreflect.Value{}, err
}
msg = wrapperspb.Bytes(v)
case "google.protobuf.FieldMask":
fm := &field_mask.FieldMask{}
fm.Paths = append(fm.Paths, strings.Split(value, ",")...)
msg = fm
default:
return protoreflect.Value{}, fmt.Errorf("unsupported message type: %q", string(md.FullName()))
}
return protoreflect.ValueOfMessage(msg.ProtoReflect()), nil
}
+187
View File
@@ -0,0 +1,187 @@
package form
import (
"encoding/base64"
"fmt"
"net/url"
"reflect"
"strconv"
"strings"
"time"
"google.golang.org/genproto/protobuf/field_mask"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/timestamppb"
"google.golang.org/protobuf/types/known/wrapperspb"
)
// EncodeMap encode proto message to url query.
func EncodeMap(msg proto.Message) (url.Values, error) {
if msg == nil || (reflect.ValueOf(msg).Kind() == reflect.Ptr && reflect.ValueOf(msg).IsNil()) {
return url.Values{}, nil
}
u := make(url.Values)
err := encodeByField(u, "", msg.ProtoReflect())
if err != nil {
return nil, err
}
return u, nil
}
func encodeByField(u url.Values, path string, v protoreflect.Message) error {
for i := 0; i < v.Descriptor().Fields().Len(); i++ {
fd := v.Descriptor().Fields().Get(i)
var key string
var newPath string
if fd.HasJSONName() {
key = fd.JSONName()
} else {
key = fd.TextName()
}
if path == "" {
newPath = key
} else {
newPath = path + "." + key
}
if of := fd.ContainingOneof(); of != nil {
if f := v.WhichOneof(of); f != nil {
if f != fd {
continue
}
}
continue
}
switch {
case fd.IsList():
if v.Get(fd).List().Len() > 0 {
list, err := encodeRepeatedField(fd, v.Get(fd).List())
if err != nil {
return err
}
u[newPath] = list
}
case fd.IsMap():
if v.Get(fd).Map().Len() > 0 {
m, err := encodeMapField(fd, v.Get(fd).Map())
if err != nil {
return err
}
for k, value := range m {
u[fmt.Sprintf("%s[%s]", newPath, k)] = []string{value}
}
}
case (fd.Kind() == protoreflect.MessageKind) || (fd.Kind() == protoreflect.GroupKind):
value, err := encodeMessage(fd.Message(), v.Get(fd))
if err == nil {
u[newPath] = []string{value}
continue
}
err = encodeByField(u, newPath, v.Get(fd).Message())
if err != nil {
return err
}
default:
value, err := encodeField(fd, v.Get(fd))
if err != nil {
return err
}
u[newPath] = []string{value}
}
}
return nil
}
func encodeRepeatedField(fieldDescriptor protoreflect.FieldDescriptor, list protoreflect.List) ([]string, error) {
var values []string
for i := 0; i < list.Len(); i++ {
value, err := encodeField(fieldDescriptor, list.Get(i))
if err != nil {
return nil, err
}
values = append(values, value)
}
return values, nil
}
func encodeMapField(fieldDescriptor protoreflect.FieldDescriptor, mp protoreflect.Map) (map[string]string, error) {
m := make(map[string]string)
mp.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool {
key, err := encodeField(fieldDescriptor.MapValue(), k.Value())
if err != nil {
return false
}
value, err := encodeField(fieldDescriptor.MapValue(), v)
if err != nil {
return false
}
m[key] = value
return true
})
return m, nil
}
func encodeField(fieldDescriptor protoreflect.FieldDescriptor, value protoreflect.Value) (string, error) {
switch fieldDescriptor.Kind() {
case protoreflect.BoolKind:
return strconv.FormatBool(value.Bool()), nil
case protoreflect.EnumKind:
if fieldDescriptor.Enum().FullName() == "google.protobuf.NullValue" {
return "null", nil
}
desc := fieldDescriptor.Enum().Values().ByNumber(value.Enum())
return string(desc.Name()), nil
case protoreflect.StringKind:
return value.String(), nil
case protoreflect.BytesKind:
return base64.URLEncoding.EncodeToString(value.Bytes()), nil
case protoreflect.MessageKind, protoreflect.GroupKind:
return encodeMessage(fieldDescriptor.Message(), value)
default:
return fmt.Sprintf("%v", value.Interface()), nil
}
}
// marshalMessage marshals the fields in the given protoreflect.Message.
// If the typeURL is non-empty, then a synthetic "@type" field is injected
// containing the URL as the value.
func encodeMessage(msgDescriptor protoreflect.MessageDescriptor, value protoreflect.Value) (string, error) {
switch msgDescriptor.FullName() {
case "google.protobuf.Timestamp":
t, ok := value.Interface().(*timestamppb.Timestamp)
if !ok {
return "", nil
}
return t.AsTime().Format(time.RFC3339Nano), nil
case "google.protobuf.Duration":
d, ok := value.Interface().(*durationpb.Duration)
if !ok {
return "", nil
}
return d.AsDuration().String(), nil
case "google.protobuf.BytesValue":
b, ok := value.Interface().(*wrapperspb.BytesValue)
if !ok {
return "", nil
}
return base64.StdEncoding.EncodeToString(b.Value), nil
case "google.protobuf.DoubleValue", "google.protobuf.FloatValue", "google.protobuf.Int64Value", "google.protobuf.Int32Value",
"google.protobuf.UInt64Value", "google.protobuf.UInt32Value", "google.protobuf.BoolValue", "google.protobuf.StringValue":
fd := msgDescriptor.Fields()
v := value.Message().Get(fd.ByName(protoreflect.Name("value"))).Message()
return fmt.Sprintf("%v", v.Interface()), nil
case "google.protobuf.FieldMask":
m, ok := value.Interface().(*field_mask.FieldMask)
if !ok {
return "", nil
}
return strings.Join(m.Paths, ","), nil
default:
return "", fmt.Errorf("unsupported message type: %q", string(msgDescriptor.FullName()))
}
}