1
0
mirror of https://github.com/go-kratos/kratos.git synced 2025-01-24 03:46:37 +02:00
kratos/pkg/conf/dsn/query.go
2019-02-01 15:57:43 +08:00

423 lines
9.1 KiB
Go

package dsn
import (
"encoding"
"net/url"
"reflect"
"runtime"
"strconv"
"strings"
)
const (
_tagID = "dsn"
_queryPrefix = "query."
)
// InvalidBindError describes an invalid argument passed to DecodeQuery.
// (The argument to DecodeQuery must be a non-nil pointer.)
type InvalidBindError struct {
Type reflect.Type
}
func (e *InvalidBindError) Error() string {
if e.Type == nil {
return "Bind(nil)"
}
if e.Type.Kind() != reflect.Ptr {
return "Bind(non-pointer " + e.Type.String() + ")"
}
return "Bind(nil " + e.Type.String() + ")"
}
// BindTypeError describes a query value that was
// not appropriate for a value of a specific Go type.
type BindTypeError struct {
Value string
Type reflect.Type
}
func (e *BindTypeError) Error() string {
return "cannot decode " + e.Value + " into Go value of type " + e.Type.String()
}
type assignFunc func(v reflect.Value, to tagOpt) error
func stringsAssignFunc(val string) assignFunc {
return func(v reflect.Value, to tagOpt) error {
if v.Kind() != reflect.String || !v.CanSet() {
return &BindTypeError{Value: "string", Type: v.Type()}
}
if val == "" {
v.SetString(to.Default)
} else {
v.SetString(val)
}
return nil
}
}
// bindQuery parses url.Values and stores the result in the value pointed to by v.
// if v is nil or not a pointer, bindQuery returns an InvalidDecodeError
func bindQuery(query url.Values, v interface{}, assignFuncs map[string]assignFunc) (url.Values, error) {
if assignFuncs == nil {
assignFuncs = make(map[string]assignFunc)
}
d := decodeState{
data: query,
used: make(map[string]bool),
assignFuncs: assignFuncs,
}
err := d.decode(v)
ret := d.unused()
return ret, err
}
type tagOpt struct {
Name string
Default string
}
func parseTag(tag string) tagOpt {
vs := strings.SplitN(tag, ",", 2)
if len(vs) == 2 {
return tagOpt{Name: vs[0], Default: vs[1]}
}
return tagOpt{Name: vs[0]}
}
type decodeState struct {
data url.Values
used map[string]bool
assignFuncs map[string]assignFunc
}
func (d *decodeState) unused() url.Values {
ret := make(url.Values)
for k, v := range d.data {
if !d.used[k] {
ret[k] = v
}
}
return ret
}
func (d *decodeState) decode(v interface{}) (err error) {
defer func() {
if r := recover(); r != nil {
if _, ok := r.(runtime.Error); ok {
panic(r)
}
err = r.(error)
}
}()
rv := reflect.ValueOf(v)
if rv.Kind() != reflect.Ptr || rv.IsNil() {
return &InvalidBindError{reflect.TypeOf(v)}
}
return d.root(rv)
}
func (d *decodeState) root(v reflect.Value) error {
var tu encoding.TextUnmarshaler
tu, v = d.indirect(v)
if tu != nil {
return tu.UnmarshalText([]byte(d.data.Encode()))
}
// TODO support map, slice as root
if v.Kind() != reflect.Struct {
return &BindTypeError{Value: d.data.Encode(), Type: v.Type()}
}
tv := v.Type()
for i := 0; i < tv.NumField(); i++ {
fv := v.Field(i)
field := tv.Field(i)
to := parseTag(field.Tag.Get(_tagID))
if to.Name == "-" {
continue
}
if af, ok := d.assignFuncs[to.Name]; ok {
if err := af(fv, tagOpt{}); err != nil {
return err
}
continue
}
if !strings.HasPrefix(to.Name, _queryPrefix) {
continue
}
to.Name = to.Name[len(_queryPrefix):]
if err := d.value(fv, "", to); err != nil {
return err
}
}
return nil
}
func combinekey(prefix string, to tagOpt) string {
key := to.Name
if prefix != "" {
key = prefix + "." + key
}
return key
}
func (d *decodeState) value(v reflect.Value, prefix string, to tagOpt) (err error) {
key := combinekey(prefix, to)
d.used[key] = true
var tu encoding.TextUnmarshaler
tu, v = d.indirect(v)
if tu != nil {
if val, ok := d.data[key]; ok {
return tu.UnmarshalText([]byte(val[0]))
}
if to.Default != "" {
return tu.UnmarshalText([]byte(to.Default))
}
return
}
switch v.Kind() {
case reflect.Bool:
err = d.valueBool(v, prefix, to)
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
err = d.valueInt64(v, prefix, to)
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
err = d.valueUint64(v, prefix, to)
case reflect.Float32, reflect.Float64:
err = d.valueFloat64(v, prefix, to)
case reflect.String:
err = d.valueString(v, prefix, to)
case reflect.Slice:
err = d.valueSlice(v, prefix, to)
case reflect.Struct:
err = d.valueStruct(v, prefix, to)
case reflect.Ptr:
if !d.hasKey(combinekey(prefix, to)) {
break
}
if !v.CanSet() {
break
}
nv := reflect.New(v.Type().Elem())
v.Set(nv)
err = d.value(nv, prefix, to)
}
return
}
func (d *decodeState) hasKey(key string) bool {
for k := range d.data {
if strings.HasPrefix(k, key+".") || k == key {
return true
}
}
return false
}
func (d *decodeState) valueBool(v reflect.Value, prefix string, to tagOpt) error {
key := combinekey(prefix, to)
val := d.data.Get(key)
if val == "" {
if to.Default == "" {
return nil
}
val = to.Default
}
return d.setBool(v, val)
}
func (d *decodeState) setBool(v reflect.Value, val string) error {
bval, err := strconv.ParseBool(val)
if err != nil {
return &BindTypeError{Value: val, Type: v.Type()}
}
v.SetBool(bval)
return nil
}
func (d *decodeState) valueInt64(v reflect.Value, prefix string, to tagOpt) error {
key := combinekey(prefix, to)
val := d.data.Get(key)
if val == "" {
if to.Default == "" {
return nil
}
val = to.Default
}
return d.setInt64(v, val)
}
func (d *decodeState) setInt64(v reflect.Value, val string) error {
ival, err := strconv.ParseInt(val, 10, 64)
if err != nil {
return &BindTypeError{Value: val, Type: v.Type()}
}
v.SetInt(ival)
return nil
}
func (d *decodeState) valueUint64(v reflect.Value, prefix string, to tagOpt) error {
key := combinekey(prefix, to)
val := d.data.Get(key)
if val == "" {
if to.Default == "" {
return nil
}
val = to.Default
}
return d.setUint64(v, val)
}
func (d *decodeState) setUint64(v reflect.Value, val string) error {
uival, err := strconv.ParseUint(val, 10, 64)
if err != nil {
return &BindTypeError{Value: val, Type: v.Type()}
}
v.SetUint(uival)
return nil
}
func (d *decodeState) valueFloat64(v reflect.Value, prefix string, to tagOpt) error {
key := combinekey(prefix, to)
val := d.data.Get(key)
if val == "" {
if to.Default == "" {
return nil
}
val = to.Default
}
return d.setFloat64(v, val)
}
func (d *decodeState) setFloat64(v reflect.Value, val string) error {
fval, err := strconv.ParseFloat(val, 64)
if err != nil {
return &BindTypeError{Value: val, Type: v.Type()}
}
v.SetFloat(fval)
return nil
}
func (d *decodeState) valueString(v reflect.Value, prefix string, to tagOpt) error {
key := combinekey(prefix, to)
val := d.data.Get(key)
if val == "" {
if to.Default == "" {
return nil
}
val = to.Default
}
return d.setString(v, val)
}
func (d *decodeState) setString(v reflect.Value, val string) error {
v.SetString(val)
return nil
}
func (d *decodeState) valueSlice(v reflect.Value, prefix string, to tagOpt) error {
key := combinekey(prefix, to)
strs, ok := d.data[key]
if !ok {
strs = strings.Split(to.Default, ",")
}
if len(strs) == 0 {
return nil
}
et := v.Type().Elem()
var setFunc func(reflect.Value, string) error
switch et.Kind() {
case reflect.Bool:
setFunc = d.setBool
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
setFunc = d.setInt64
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
setFunc = d.setUint64
case reflect.Float32, reflect.Float64:
setFunc = d.setFloat64
case reflect.String:
setFunc = d.setString
default:
return &BindTypeError{Type: et, Value: strs[0]}
}
vals := reflect.MakeSlice(v.Type(), len(strs), len(strs))
for i, str := range strs {
if err := setFunc(vals.Index(i), str); err != nil {
return err
}
}
if v.CanSet() {
v.Set(vals)
}
return nil
}
func (d *decodeState) valueStruct(v reflect.Value, prefix string, to tagOpt) error {
tv := v.Type()
for i := 0; i < tv.NumField(); i++ {
fv := v.Field(i)
field := tv.Field(i)
fto := parseTag(field.Tag.Get(_tagID))
if fto.Name == "-" {
continue
}
if af, ok := d.assignFuncs[fto.Name]; ok {
if err := af(fv, tagOpt{}); err != nil {
return err
}
continue
}
if !strings.HasPrefix(fto.Name, _queryPrefix) {
continue
}
fto.Name = fto.Name[len(_queryPrefix):]
if err := d.value(fv, to.Name, fto); err != nil {
return err
}
}
return nil
}
func (d *decodeState) indirect(v reflect.Value) (encoding.TextUnmarshaler, reflect.Value) {
v0 := v
haveAddr := false
if v.Kind() != reflect.Ptr && v.Type().Name() != "" && v.CanAddr() {
haveAddr = true
v = v.Addr()
}
for {
if v.Kind() == reflect.Interface && !v.IsNil() {
e := v.Elem()
if e.Kind() == reflect.Ptr && !e.IsNil() && e.Elem().Kind() == reflect.Ptr {
haveAddr = false
v = e
continue
}
}
if v.Kind() != reflect.Ptr {
break
}
if v.Elem().Kind() != reflect.Ptr && v.CanSet() {
break
}
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
if v.Type().NumMethod() > 0 {
if u, ok := v.Interface().(encoding.TextUnmarshaler); ok {
return u, reflect.Value{}
}
}
if haveAddr {
v = v0
haveAddr = false
} else {
v = v.Elem()
}
}
return nil, v
}