diff --git a/CHANGELOG.md b/CHANGELOG.md index d284ce62b..d2cd98304 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -35,6 +35,7 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm - Metric aggregator Count() and histogram Bucket.Counts are consistently `uint64`. (1430) - `SamplingResult` now passed a `Tracestate` from the parent `SpanContext` (#1432) - Moved gRPC driver for OTLP exporter to `exporters/otlp/otlpgrpc`. (#1420) +- The `TraceContext` propagator now correctly propagates `TraceState` through the `SpanContext`. (#1447) ### Removed diff --git a/propagation/trace_context.go b/propagation/trace_context.go index ec99e0965..fa18b8540 100644 --- a/propagation/trace_context.go +++ b/propagation/trace_context.go @@ -19,7 +19,9 @@ import ( "encoding/hex" "fmt" "regexp" + "strings" + "go.opentelemetry.io/otel/label" "go.opentelemetry.io/otel/trace" ) @@ -30,12 +32,6 @@ const ( tracestateHeader = "tracestate" ) -type traceContextPropagatorKeyType uint - -const ( - tracestateKey traceContextPropagatorKeyType = 0 -) - // TraceContext is a propagator that supports the W3C Trace Context format // (https://www.w3.org/TR/trace-context/) // @@ -51,15 +47,13 @@ var traceCtxRegExp = regexp.MustCompile("^(?P[0-9a-f]{2})-(?P[ // Inject set tracecontext from the Context into the carrier. func (tc TraceContext) Inject(ctx context.Context, carrier TextMapCarrier) { - tracestate := ctx.Value(tracestateKey) - if state, ok := tracestate.(string); tracestate != nil && ok { - carrier.Set(tracestateHeader, state) - } - sc := trace.SpanContextFromContext(ctx) if !sc.IsValid() { return } + + carrier.Set(tracestateHeader, sc.TraceState.String()) + h := fmt.Sprintf("%.2x-%s-%s-%.2x", supportedVersion, sc.TraceID, @@ -70,11 +64,6 @@ func (tc TraceContext) Inject(ctx context.Context, carrier TextMapCarrier) { // Extract reads tracecontext from the carrier into a returned Context. func (tc TraceContext) Extract(ctx context.Context, carrier TextMapCarrier) context.Context { - state := carrier.Get(tracestateHeader) - if state != "" { - ctx = context.WithValue(ctx, tracestateKey, state) - } - sc := tc.extract(carrier) if !sc.IsValid() { return ctx @@ -143,6 +132,8 @@ func (tc TraceContext) extract(carrier TextMapCarrier) trace.SpanContext { // Clear all flags other than the trace-context supported sampling bit. sc.TraceFlags = opts[0] & trace.FlagsSampled + sc.TraceState = parseTraceState(carrier.Get(tracestateHeader)) + if !sc.IsValid() { return trace.SpanContext{} } @@ -154,3 +145,25 @@ func (tc TraceContext) extract(carrier TextMapCarrier) trace.SpanContext { func (tc TraceContext) Fields() []string { return []string{traceparentHeader, tracestateHeader} } + +func parseTraceState(in string) trace.TraceState { + if in == "" { + return trace.TraceState{} + } + + kvs := []label.KeyValue{} + for _, entry := range strings.Split(in, ",") { + parts := strings.SplitN(entry, "=", 2) + if len(parts) != 2 { + // Parse failure, abort! + return trace.TraceState{} + } + kvs = append(kvs, label.String(parts[0], parts[1])) + } + + // Ignoring error here as "failure to parse tracestate MUST NOT + // affect the parsing of traceparent." + // https://www.w3.org/TR/trace-context/#tracestate-header + ts, _ := trace.TraceStateFromKeyValues(kvs...) + return ts +} diff --git a/propagation/trace_context_test.go b/propagation/trace_context_test.go index aab3db339..54260f0fd 100644 --- a/propagation/trace_context_test.go +++ b/propagation/trace_context_test.go @@ -21,6 +21,7 @@ import ( "github.com/google/go-cmp/cmp" + "go.opentelemetry.io/otel/label" "go.opentelemetry.io/otel/oteltest" "go.opentelemetry.io/otel/propagation" "go.opentelemetry.io/otel/trace" @@ -277,17 +278,85 @@ func TestTraceContextPropagator_GetAllKeys(t *testing.T) { func TestTraceStatePropagation(t *testing.T) { prop := propagation.TraceContext{} - want := "opaquevalue" - headerName := "tracestate" + stateHeader := "tracestate" + parentHeader := "traceparent" + state, err := trace.TraceStateFromKeyValues(label.String("key1", "value1"), label.String("key2", "value2")) + if err != nil { + t.Fatalf("Unable to construct expected TraceState: %s", err.Error()) + } - inReq, _ := http.NewRequest(http.MethodGet, "http://example.com", nil) - inReq.Header.Add(headerName, want) - ctx := prop.Extract(context.Background(), inReq.Header) + tests := []struct { + name string + headers map[string]string + valid bool + wantSc trace.SpanContext + }{ + { + name: "valid parent and state", + headers: map[string]string{ + parentHeader: "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-00", + stateHeader: "key1=value1,key2=value2", + }, + valid: true, + wantSc: trace.SpanContext{ + TraceID: traceID, + SpanID: spanID, + TraceState: state, + }, + }, + { + name: "valid parent, invalid state", + headers: map[string]string{ + parentHeader: "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-00", + stateHeader: "key1=value1,invalid$@#=invalid", + }, + valid: false, + wantSc: trace.SpanContext{ + TraceID: traceID, + SpanID: spanID, + }, + }, + { + name: "valid parent, malformed state", + headers: map[string]string{ + parentHeader: "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-00", + stateHeader: "key1=value1,invalid", + }, + valid: false, + wantSc: trace.SpanContext{ + TraceID: traceID, + SpanID: spanID, + }, + }, + } - outReq, _ := http.NewRequest(http.MethodGet, "http://www.example.com", nil) - prop.Inject(ctx, outReq.Header) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + inReq, _ := http.NewRequest(http.MethodGet, "http://example.com", nil) + for hk, hv := range tt.headers { + inReq.Header.Add(hk, hv) + } - if diff := cmp.Diff(outReq.Header.Get(headerName), want); diff != "" { - t.Errorf("Propagate tracestate: -got +want %s", diff) + ctx := prop.Extract(context.Background(), inReq.Header) + if diff := cmp.Diff( + trace.RemoteSpanContextFromContext(ctx), + tt.wantSc, + cmp.AllowUnexported(label.Value{}), + cmp.AllowUnexported(trace.TraceState{}), + ); diff != "" { + t.Errorf("Extracted tracestate: -got +want %s", diff) + } + + if tt.valid { + mockTracer := oteltest.DefaultTracer() + ctx, _ = mockTracer.Start(ctx, "inject") + outReq, _ := http.NewRequest(http.MethodGet, "http://www.example.com", nil) + prop.Inject(ctx, outReq.Header) + + if diff := cmp.Diff(outReq.Header.Get(stateHeader), tt.headers[stateHeader]); diff != "" { + t.Errorf("Propagated tracestate: -got +want %s", diff) + } + } + }) } }