1
0
mirror of https://github.com/go-kratos/kratos.git synced 2025-11-06 08:59:18 +02:00

add response header (#1119)

* add response header

Co-authored-by: chenzhihui <zhihui_chen@foxmail.com>
This commit is contained in:
longxboy
2021-06-29 15:33:18 +08:00
committed by GitHub
parent 493c11929f
commit 545ffd1084
18 changed files with 249 additions and 63 deletions

View File

@@ -24,16 +24,16 @@ type Transport struct {
func (tr *Transport) Kind() transport.Kind {
return tr.kind
}
func (tr *Transport) Endpoint() string {
return tr.endpoint
}
func (tr *Transport) Operation() string {
return tr.operation
}
func (tr *Transport) Header() transport.Header {
func (tr *Transport) RequestHeader() transport.Header {
return nil
}
func (tr *Transport) ReplyHeader() transport.Header {
return nil
}

View File

@@ -53,9 +53,10 @@ func Server(opts ...Option) middleware.Middleware {
return func(ctx context.Context, req interface{}) (reply interface{}, err error) {
if tr, ok := transport.FromServerContext(ctx); ok {
md := options.md.Clone()
for _, k := range tr.Header().Keys() {
header := tr.RequestHeader()
for _, k := range header.Keys() {
if options.hasPrefix(k) {
md.Set(k, tr.Header().Get(k))
md.Set(k, header.Get(k))
}
}
ctx = metadata.NewServerContext(ctx, md)
@@ -76,20 +77,21 @@ func Client(opts ...Option) middleware.Middleware {
return func(handler middleware.Handler) middleware.Handler {
return func(ctx context.Context, req interface{}) (reply interface{}, err error) {
if tr, ok := transport.FromClientContext(ctx); ok {
header := tr.RequestHeader()
// x-md-local-
for k, v := range options.md {
tr.Header().Set(k, v)
header.Set(k, v)
}
if md, ok := metadata.FromClientContext(ctx); ok {
for k, v := range md {
tr.Header().Set(k, v)
header.Set(k, v)
}
}
// x-md-global-
if md, ok := metadata.FromServerContext(ctx); ok {
for k, v := range md {
if options.hasPrefix(k) {
tr.Header().Set(k, v)
header.Set(k, v)
}
}
}

View File

@@ -27,10 +27,11 @@ func (hc headerCarrier) Keys() []string {
type testTransport struct{ header headerCarrier }
func (tr *testTransport) Kind() transport.Kind { return transport.KindHTTP }
func (tr *testTransport) Endpoint() string { return "" }
func (tr *testTransport) Operation() string { return "" }
func (tr *testTransport) Header() transport.Header { return tr.header }
func (tr *testTransport) Kind() transport.Kind { return transport.KindHTTP }
func (tr *testTransport) Endpoint() string { return "" }
func (tr *testTransport) Operation() string { return "" }
func (tr *testTransport) RequestHeader() transport.Header { return tr.header }
func (tr *testTransport) ReplyHeader() transport.Header { return tr.header }
func TestSever(t *testing.T) {
var (
@@ -89,16 +90,16 @@ func TestClient(t *testing.T) {
if !ok {
return nil, errors.New("no md")
}
if tr.Header().Get(constKey) != constValue {
if tr.RequestHeader().Get(constKey) != constValue {
return nil, errors.New("const not equal")
}
if tr.Header().Get(customKey) != customValue {
if tr.RequestHeader().Get(customKey) != customValue {
return nil, errors.New("custom not equal")
}
if tr.Header().Get(globalKey) != globalValue {
if tr.RequestHeader().Get(globalKey) != globalValue {
return nil, errors.New("global not equal")
}
if tr.Header().Get(localKey) != "" {
if tr.RequestHeader().Get(localKey) != "" {
return nil, errors.New("local must empty")
}
return in, nil

View File

@@ -38,7 +38,7 @@ func Server(opts ...Option) middleware.Middleware {
return func(ctx context.Context, req interface{}) (reply interface{}, err error) {
if tr, ok := transport.FromServerContext(ctx); ok {
var span trace.Span
ctx, span = tracer.Start(ctx, tr.Kind().String(), tr.Operation(), tr.Header())
ctx, span = tracer.Start(ctx, tr.Kind().String(), tr.Operation(), tr.RequestHeader())
defer func() { tracer.End(ctx, span, err) }()
}
return handler(ctx, req)
@@ -53,7 +53,7 @@ func Client(opts ...Option) middleware.Middleware {
return func(ctx context.Context, req interface{}) (reply interface{}, err error) {
if tr, ok := transport.FromClientContext(ctx); ok {
var span trace.Span
ctx, span = tracer.Start(ctx, tr.Kind().String(), tr.Operation(), tr.Header())
ctx, span = tracer.Start(ctx, tr.Kind().String(), tr.Operation(), tr.RequestHeader())
defer func() { tracer.End(ctx, span, err) }()
}
return handler(ctx, req)

View File

@@ -43,10 +43,11 @@ type Transport struct {
header headerCarrier
}
func (tr *Transport) Kind() transport.Kind { return tr.kind }
func (tr *Transport) Endpoint() string { return tr.endpoint }
func (tr *Transport) Operation() string { return tr.operation }
func (tr *Transport) Header() transport.Header { return tr.header }
func (tr *Transport) Kind() transport.Kind { return tr.kind }
func (tr *Transport) Endpoint() string { return tr.endpoint }
func (tr *Transport) Operation() string { return tr.operation }
func (tr *Transport) RequestHeader() transport.Header { return tr.header }
func (tr *Transport) ReplyHeader() transport.Header { return tr.header }
func TestTracing(t *testing.T) {
var carrier = headerCarrier{}
@@ -56,21 +57,21 @@ func TestTracing(t *testing.T) {
tracer := NewTracer(trace.SpanKindClient, WithTracerProvider(tp), WithPropagator(propagation.NewCompositeTextMapPropagator(propagation.Baggage{}, propagation.TraceContext{})))
ts := &Transport{kind: transport.KindHTTP, header: carrier}
ctx, aboveSpan := tracer.Start(transport.NewClientContext(context.Background(), ts), ts.Kind().String(), ts.Operation(), ts.Header())
ctx, aboveSpan := tracer.Start(transport.NewClientContext(context.Background(), ts), ts.Kind().String(), ts.Operation(), ts.RequestHeader())
defer tracer.End(ctx, aboveSpan, nil)
// server use Extract fetch traceInfo from carrier
tracer = NewTracer(trace.SpanKindServer, WithPropagator(propagation.NewCompositeTextMapPropagator(propagation.Baggage{}, propagation.TraceContext{})))
ts = &Transport{kind: transport.KindHTTP, header: carrier}
ctx, span := tracer.Start(transport.NewServerContext(ctx, ts), ts.Kind().String(), ts.Operation(), ts.Header())
ctx, span := tracer.Start(transport.NewServerContext(ctx, ts), ts.Kind().String(), ts.Operation(), ts.RequestHeader())
defer tracer.End(ctx, span, nil)
if aboveSpan.SpanContext().TraceID() != span.SpanContext().TraceID() {
t.Fatalf("TraceID failed to deliver")
}
if v, ok := transport.FromClientContext(ctx); !ok || len(v.Header().Keys()) == 0 {
if v, ok := transport.FromClientContext(ctx); !ok || len(v.RequestHeader().Keys()) == 0 {
t.Fatalf("traceHeader failed to deliver")
}
}