mirror of
https://github.com/go-kratos/kratos.git
synced 2026-05-22 10:15:24 +02:00
fix: fix encode proto well known types in form and url query (#1559)
* fix encode proto well known types
This commit is contained in:
@@ -7,14 +7,10 @@ import (
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"google.golang.org/genproto/protobuf/field_mask"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/reflect/protoreflect"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
"google.golang.org/protobuf/types/known/wrapperspb"
|
||||
)
|
||||
|
||||
// EncodeMap encode proto message to url query.
|
||||
@@ -84,7 +80,7 @@ func encodeByField(u url.Values, path string, v protoreflect.Message) error {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
value, err := encodeField(fd, v.Get(fd))
|
||||
value, err := EncodeField(fd, v.Get(fd))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -98,7 +94,7 @@ func encodeByField(u url.Values, path string, v protoreflect.Message) error {
|
||||
func encodeRepeatedField(fieldDescriptor protoreflect.FieldDescriptor, list protoreflect.List) ([]string, error) {
|
||||
var values []string
|
||||
for i := 0; i < list.Len(); i++ {
|
||||
value, err := encodeField(fieldDescriptor, list.Get(i))
|
||||
value, err := EncodeField(fieldDescriptor, list.Get(i))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -111,11 +107,11 @@ func encodeRepeatedField(fieldDescriptor protoreflect.FieldDescriptor, list prot
|
||||
func encodeMapField(fieldDescriptor protoreflect.FieldDescriptor, mp protoreflect.Map) (map[string]string, error) {
|
||||
m := make(map[string]string)
|
||||
mp.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool {
|
||||
key, err := encodeField(fieldDescriptor.MapValue(), k.Value())
|
||||
key, err := EncodeField(fieldDescriptor.MapValue(), k.Value())
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
value, err := encodeField(fieldDescriptor.MapValue(), v)
|
||||
value, err := EncodeField(fieldDescriptor.MapValue(), v)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
@@ -126,7 +122,8 @@ func encodeMapField(fieldDescriptor protoreflect.FieldDescriptor, mp protoreflec
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func encodeField(fieldDescriptor protoreflect.FieldDescriptor, value protoreflect.Value) (string, error) {
|
||||
// EncodeField encode proto message filed
|
||||
func EncodeField(fieldDescriptor protoreflect.FieldDescriptor, value protoreflect.Value) (string, error) {
|
||||
switch fieldDescriptor.Kind() {
|
||||
case protoreflect.BoolKind:
|
||||
return strconv.FormatBool(value.Bool()), nil
|
||||
@@ -147,29 +144,17 @@ func encodeField(fieldDescriptor protoreflect.FieldDescriptor, value protoreflec
|
||||
}
|
||||
}
|
||||
|
||||
// marshalMessage marshals the fields in the given protoreflect.Message.
|
||||
// encodeMessage marshals the fields in the given protoreflect.Message.
|
||||
// If the typeURL is non-empty, then a synthetic "@type" field is injected
|
||||
// containing the URL as the value.
|
||||
func encodeMessage(msgDescriptor protoreflect.MessageDescriptor, value protoreflect.Value) (string, error) {
|
||||
switch msgDescriptor.FullName() {
|
||||
case "google.protobuf.Timestamp":
|
||||
t, ok := value.Interface().(*timestamppb.Timestamp)
|
||||
if !ok {
|
||||
return "", nil
|
||||
}
|
||||
return t.AsTime().Format(time.RFC3339Nano), nil
|
||||
case "google.protobuf.Duration":
|
||||
d, ok := value.Interface().(*durationpb.Duration)
|
||||
if !ok {
|
||||
return "", nil
|
||||
}
|
||||
return d.AsDuration().String(), nil
|
||||
case "google.protobuf.BytesValue":
|
||||
b, ok := value.Interface().(*wrapperspb.BytesValue)
|
||||
if !ok {
|
||||
return "", nil
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(b.Value), nil
|
||||
case timestampMessageFullname:
|
||||
return marshalTimestamp(value.Message())
|
||||
case durationMessageFullname:
|
||||
return marshalDuration(value.Message())
|
||||
case bytesMessageFullname:
|
||||
return marshalBytes(value.Message())
|
||||
case "google.protobuf.DoubleValue", "google.protobuf.FloatValue", "google.protobuf.Int64Value", "google.protobuf.Int32Value",
|
||||
"google.protobuf.UInt64Value", "google.protobuf.UInt32Value", "google.protobuf.BoolValue", "google.protobuf.StringValue":
|
||||
fd := msgDescriptor.Fields()
|
||||
|
||||
@@ -0,0 +1,88 @@
|
||||
package form
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"math"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"google.golang.org/protobuf/reflect/protoreflect"
|
||||
)
|
||||
|
||||
const (
|
||||
// timestamp
|
||||
timestampMessageFullname protoreflect.FullName = "google.protobuf.Timestamp"
|
||||
maxTimestampSeconds = 253402300799
|
||||
minTimestampSeconds = -6213559680013
|
||||
timestampSecondsFieldNumber protoreflect.FieldNumber = 1
|
||||
timestampNanosFieldNumber protoreflect.FieldNumber = 2
|
||||
|
||||
// duration
|
||||
durationMessageFullname protoreflect.FullName = "google.protobuf.Duration"
|
||||
secondsInNanos = 999999999
|
||||
durationSecondsFieldNumber protoreflect.FieldNumber = 1
|
||||
durationNanosFieldNumber protoreflect.FieldNumber = 2
|
||||
|
||||
// bytes
|
||||
bytesMessageFullname protoreflect.FullName = "google.protobuf.BytesValue"
|
||||
bytesValueFieldNumber protoreflect.FieldNumber = 1
|
||||
)
|
||||
|
||||
func marshalTimestamp(m protoreflect.Message) (string, error) {
|
||||
fds := m.Descriptor().Fields()
|
||||
fdSeconds := fds.ByNumber(timestampSecondsFieldNumber)
|
||||
fdNanos := fds.ByNumber(timestampNanosFieldNumber)
|
||||
|
||||
secsVal := m.Get(fdSeconds)
|
||||
nanosVal := m.Get(fdNanos)
|
||||
secs := secsVal.Int()
|
||||
nanos := nanosVal.Int()
|
||||
if secs < minTimestampSeconds || secs > maxTimestampSeconds {
|
||||
return "", fmt.Errorf("%s: seconds out of range %v", timestampMessageFullname, secs)
|
||||
}
|
||||
if nanos < 0 || nanos > secondsInNanos {
|
||||
return "", fmt.Errorf("%s: nanos out of range %v", timestampMessageFullname, nanos)
|
||||
}
|
||||
// Uses RFC 3339, where generated output will be Z-normalized and uses 0, 3,
|
||||
// 6 or 9 fractional digits.
|
||||
t := time.Unix(secs, nanos).UTC()
|
||||
x := t.Format("2006-01-02T15:04:05.000000000")
|
||||
x = strings.TrimSuffix(x, "000")
|
||||
x = strings.TrimSuffix(x, "000")
|
||||
x = strings.TrimSuffix(x, ".000")
|
||||
return x + "Z", nil
|
||||
}
|
||||
|
||||
func marshalDuration(m protoreflect.Message) (string, error) {
|
||||
fds := m.Descriptor().Fields()
|
||||
fdSeconds := fds.ByNumber(durationSecondsFieldNumber)
|
||||
fdNanos := fds.ByNumber(durationNanosFieldNumber)
|
||||
|
||||
secsVal := m.Get(fdSeconds)
|
||||
nanosVal := m.Get(fdNanos)
|
||||
secs := secsVal.Int()
|
||||
nanos := nanosVal.Int()
|
||||
d := time.Duration(secs) * time.Second
|
||||
overflow := d/time.Second != time.Duration(secs)
|
||||
d += time.Duration(nanos) * time.Nanosecond
|
||||
overflow = overflow || (secs < 0 && nanos < 0 && d > 0)
|
||||
overflow = overflow || (secs > 0 && nanos > 0 && d < 0)
|
||||
if overflow {
|
||||
switch {
|
||||
case secs < 0:
|
||||
return time.Duration(math.MinInt64).String(), nil
|
||||
case secs > 0:
|
||||
return time.Duration(math.MaxInt64).String(), nil
|
||||
}
|
||||
}
|
||||
return d.String(), nil
|
||||
}
|
||||
|
||||
func marshalBytes(m protoreflect.Message) (string, error) {
|
||||
fds := m.Descriptor().Fields()
|
||||
fdBytes := fds.ByNumber(bytesValueFieldNumber)
|
||||
bytesVal := m.Get(fdBytes)
|
||||
val := bytesVal.Bytes()
|
||||
return base64.StdEncoding.EncodeToString(val), nil
|
||||
}
|
||||
Reference in New Issue
Block a user