diff --git a/pkg/net/metadata/metadata.go b/pkg/net/metadata/metadata.go index fb46f3d57..83eb3657c 100644 --- a/pkg/net/metadata/metadata.go +++ b/pkg/net/metadata/metadata.go @@ -4,6 +4,8 @@ import ( "context" "fmt" "strconv" + + "github.com/pkg/errors" ) // MD is a mapping from metadata keys to values. @@ -132,3 +134,23 @@ func Bool(ctx context.Context, key string) bool { return false } } + +// Range range value from metadata in context filtered by filterFunc. +func Range(ctx context.Context, rangeFunc func(key string, value interface{}), filterFunc ...func(key string) bool) { + var filter func(key string) bool + filterLen := len(filterFunc) + if filterLen > 1 { + panic(errors.New("metadata: Range got the lenth of filterFunc must less than 2")) + } else if filterLen == 1 { + filter = filterFunc[0] + } + md, ok := ctx.Value(mdKey{}).(MD) + if !ok { + return + } + for key, value := range md { + if filter == nil || filter(key) { + rangeFunc(key, value) + } + } +} diff --git a/pkg/net/metadata/metadata_test.go b/pkg/net/metadata/metadata_test.go index db1e73647..2672da466 100644 --- a/pkg/net/metadata/metadata_test.go +++ b/pkg/net/metadata/metadata_test.go @@ -94,3 +94,56 @@ func TestInt64(t *testing.T) { mdcontext = NewContext(context.Background(), MD{Mid: 10}) assert.NotEqual(t, int64(10), Int64(mdcontext, Mid)) } + +func TestRange(t *testing.T) { + for _, test := range []struct { + filterFunc func(key string) bool + md MD + want MD + }{ + { + nil, + Pairs("foo", "bar"), + Pairs("foo", "bar"), + }, + { + IsOutgoingKey, + Pairs("foo", "bar", RemoteIP, "127.0.0.1", Color, "red", Mirror, "false"), + Pairs(RemoteIP, "127.0.0.1", Color, "red", Mirror, "false"), + }, + { + IsOutgoingKey, + Pairs("foo", "bar", Caller, "app-feed", RemoteIP, "127.0.0.1", Color, "red", Mirror, "true"), + Pairs(RemoteIP, "127.0.0.1", Color, "red", Mirror, "true"), + }, + { + IsIncomingKey, + Pairs("foo", "bar", Caller, "app-feed", RemoteIP, "127.0.0.1", Color, "red", Mirror, "true"), + Pairs(Caller, "app-feed", RemoteIP, "127.0.0.1", Color, "red", Mirror, "true"), + }, + } { + var mds []MD + c := NewContext(context.Background(), test.md) + ctx := WithContext(c) + Range(ctx, + func(key string, value interface{}) { + mds = append(mds, Pairs(key, value)) + }, + test.filterFunc) + rmd := Join(mds...) + if !reflect.DeepEqual(rmd, test.want) { + t.Fatalf("Range(%v) = %v, want %v", test.md, rmd, test.want) + } + if test.filterFunc == nil { + var mds []MD + Range(ctx, + func(key string, value interface{}) { + mds = append(mds, Pairs(key, value)) + }) + rmd := Join(mds...) + if !reflect.DeepEqual(rmd, test.want) { + t.Fatalf("Range(%v) = %v, want %v", test.md, rmd, test.want) + } + } + } +} diff --git a/pkg/net/rpc/warden/client.go b/pkg/net/rpc/warden/client.go index 1a0ff8752..8f54f02e4 100644 --- a/pkg/net/rpc/warden/client.go +++ b/pkg/net/rpc/warden/client.go @@ -83,7 +83,6 @@ func (c *Client) handle() grpc.UnaryClientInterceptor { return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) (err error) { var ( ok bool - cmd nmd.MD t trace.Trace gmd metadata.MD conf *ClientConfig @@ -114,17 +113,13 @@ func (c *Client) handle() grpc.UnaryClientInterceptor { defer onBreaker(brk, &err) _, ctx, cancel = conf.Timeout.Shrink(ctx) defer cancel() - if cmd, ok = nmd.FromContext(ctx); ok { - for netKey, val := range cmd { - if !nmd.IsOutgoingKey(netKey) { - continue + nmd.Range(ctx, + func(key string, value interface{}) { + if valstr, ok := value.(string); ok { + gmd[key] = []string{valstr} } - valstr, ok := val.(string) - if ok { - gmd[netKey] = []string{valstr} - } - } - } + }, + nmd.IsOutgoingKey) // merge with old matadata if exists if oldmd, ok := metadata.FromOutgoingContext(ctx); ok { gmd = metadata.Join(gmd, oldmd)