mirror of
https://github.com/go-kratos/kratos.git
synced 2025-01-07 23:02:12 +02:00
parent
60b1e593f1
commit
51a3a32502
@ -17,6 +17,16 @@ type options struct {
|
||||
md metadata.Metadata
|
||||
}
|
||||
|
||||
func (o *options) hasPrefix(key string) bool {
|
||||
k := strings.ToLower(key)
|
||||
for _, prefix := range o.prefix {
|
||||
if strings.HasPrefix(k, prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// WithConstants with constant metadata key value.
|
||||
func WithConstants(md metadata.Metadata) Option {
|
||||
return func(o *options) {
|
||||
@ -33,22 +43,19 @@ func WithPropagatedPrefix(prefix ...string) Option {
|
||||
|
||||
// Server is middleware server-side metadata.
|
||||
func Server(opts ...Option) middleware.Middleware {
|
||||
options := options{
|
||||
options := &options{
|
||||
prefix: []string{"x-md-"}, // x-md-global-, x-md-local
|
||||
}
|
||||
for _, o := range opts {
|
||||
o(&options)
|
||||
o(options)
|
||||
}
|
||||
return func(handler middleware.Handler) middleware.Handler {
|
||||
return func(ctx context.Context, req interface{}) (reply interface{}, err error) {
|
||||
if tr, ok := transport.FromServerContext(ctx); ok {
|
||||
md := metadata.New()
|
||||
md := options.md.Clone()
|
||||
for _, k := range tr.Header().Keys() {
|
||||
for _, prefix := range options.prefix {
|
||||
if strings.HasPrefix(strings.ToLower(k), prefix) {
|
||||
md.Set(k, tr.Header().Get(k))
|
||||
break
|
||||
}
|
||||
if options.hasPrefix(k) {
|
||||
md.Set(k, tr.Header().Get(k))
|
||||
}
|
||||
}
|
||||
ctx = metadata.NewServerContext(ctx, md)
|
||||
@ -60,11 +67,11 @@ func Server(opts ...Option) middleware.Middleware {
|
||||
|
||||
// Client is middleware client-side metadata.
|
||||
func Client(opts ...Option) middleware.Middleware {
|
||||
options := options{
|
||||
options := &options{
|
||||
prefix: []string{"x-md-global-"},
|
||||
}
|
||||
for _, o := range opts {
|
||||
o(&options)
|
||||
o(options)
|
||||
}
|
||||
return func(handler middleware.Handler) middleware.Handler {
|
||||
return func(ctx context.Context, req interface{}) (reply interface{}, err error) {
|
||||
@ -81,11 +88,8 @@ func Client(opts ...Option) middleware.Middleware {
|
||||
// x-md-global-
|
||||
if md, ok := metadata.FromServerContext(ctx); ok {
|
||||
for k, v := range md {
|
||||
for _, prefix := range options.prefix {
|
||||
if strings.HasPrefix(k, prefix) {
|
||||
tr.Header().Set(k, v)
|
||||
break
|
||||
}
|
||||
if options.hasPrefix(k) {
|
||||
tr.Header().Set(k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
127
middleware/metadata/metadata_test.go
Normal file
127
middleware/metadata/metadata_test.go
Normal file
@ -0,0 +1,127 @@
|
||||
package metadata
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/go-kratos/kratos/v2/metadata"
|
||||
"github.com/go-kratos/kratos/v2/transport"
|
||||
)
|
||||
|
||||
type headerCarrier http.Header
|
||||
|
||||
func (hc headerCarrier) Get(key string) string { return http.Header(hc).Get(key) }
|
||||
|
||||
func (hc headerCarrier) Set(key string, value string) { http.Header(hc).Set(key, value) }
|
||||
|
||||
// Keys lists the keys stored in this carrier.
|
||||
func (hc headerCarrier) Keys() []string {
|
||||
keys := make([]string, 0, len(hc))
|
||||
for k := range http.Header(hc) {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
type testTransport struct{ header headerCarrier }
|
||||
|
||||
func (tr *testTransport) Kind() transport.Kind { return transport.KindHTTP }
|
||||
func (tr *testTransport) Endpoint() string { return "" }
|
||||
func (tr *testTransport) Operation() string { return "" }
|
||||
func (tr *testTransport) Header() transport.Header { return tr.header }
|
||||
|
||||
func TestSever(t *testing.T) {
|
||||
var (
|
||||
globalKey = "x-md-global-key"
|
||||
globalValue = "global-value"
|
||||
localKey = "x-md-local-key"
|
||||
localValue = "local-value"
|
||||
constKey = "x-md-local-const"
|
||||
constValue = "x-md-local-const"
|
||||
)
|
||||
hs := func(ctx context.Context, in interface{}) (interface{}, error) {
|
||||
md, ok := metadata.FromServerContext(ctx)
|
||||
if !ok {
|
||||
return nil, errors.New("no md")
|
||||
}
|
||||
if md.Get(constKey) != constValue {
|
||||
return nil, errors.New("const not equal")
|
||||
}
|
||||
if md.Get(globalKey) != globalValue {
|
||||
return nil, errors.New("global not equal")
|
||||
}
|
||||
if md.Get(localKey) != localValue {
|
||||
return nil, errors.New("local not equal")
|
||||
}
|
||||
return in, nil
|
||||
}
|
||||
hc := headerCarrier{}
|
||||
hc.Set(globalKey, globalValue)
|
||||
hc.Set(localKey, localValue)
|
||||
ctx := transport.NewServerContext(context.Background(), &testTransport{hc})
|
||||
// const md
|
||||
constMD := metadata.New()
|
||||
constMD.Set(constKey, constValue)
|
||||
reply, err := Server(WithConstants(constMD))(hs)(ctx, "foo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if reply.(string) != "foo" {
|
||||
t.Fatalf("want foo got %v", reply)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient(t *testing.T) {
|
||||
var (
|
||||
globalKey = "x-md-global-key"
|
||||
globalValue = "global-value"
|
||||
localKey = "x-md-local-key"
|
||||
localValue = "local-value"
|
||||
customKey = "x-md-local-custom"
|
||||
customValue = "custom-value"
|
||||
constKey = "x-md-local-const"
|
||||
constValue = "x-md-local-const"
|
||||
)
|
||||
hs := func(ctx context.Context, in interface{}) (interface{}, error) {
|
||||
tr, ok := transport.FromClientContext(ctx)
|
||||
if !ok {
|
||||
return nil, errors.New("no md")
|
||||
}
|
||||
if tr.Header().Get(constKey) != constValue {
|
||||
return nil, errors.New("const not equal")
|
||||
}
|
||||
if tr.Header().Get(customKey) != customValue {
|
||||
return nil, errors.New("custom not equal")
|
||||
}
|
||||
if tr.Header().Get(globalKey) != globalValue {
|
||||
return nil, errors.New("global not equal")
|
||||
}
|
||||
if tr.Header().Get(localKey) != "" {
|
||||
return nil, errors.New("local must empty")
|
||||
}
|
||||
return in, nil
|
||||
}
|
||||
// server md
|
||||
serverMD := metadata.New()
|
||||
serverMD.Set(globalKey, globalValue)
|
||||
serverMD.Set(localKey, localValue)
|
||||
ctx := metadata.NewServerContext(context.Background(), serverMD)
|
||||
// client md
|
||||
clientMD := metadata.New()
|
||||
clientMD.Set(customKey, customValue)
|
||||
ctx = metadata.NewClientContext(ctx, clientMD)
|
||||
// transport carrier
|
||||
ctx = transport.NewClientContext(ctx, &testTransport{headerCarrier{}})
|
||||
// const md
|
||||
constMD := metadata.New()
|
||||
constMD.Set(constKey, constValue)
|
||||
reply, err := Client(WithConstants(constMD))(hs)(ctx, "bar")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if reply.(string) != "bar" {
|
||||
t.Fatalf("want foo got %v", reply)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user