From 439cd31389fec380a160471feb55143df7b6c6df Mon Sep 17 00:00:00 2001 From: Matej Gera <38492574+matej-g@users.noreply.github.com> Date: Mon, 21 Dec 2020 22:11:48 +0100 Subject: [PATCH] Add TraceState to SpanContext in API (#1340) * Add TraceState to API * Add tests for TraceState * Update related tests - stdout exporter test - SDK test * Update OTLP span transform * Update CHANGELOG * Change TraceState to struct instead of pointer - Adjust tests for trace API - Adjust adjacent parts of codebase (test utils, SDK etc.) * Add methods to assert equality - for type SpanContext, if SpanID, TraceID, TraceFlag and TraceState are equal - for type TraceState, if entries of both respective trace states are equal Signed-off-by: Matej Gera * Copy values for new TraceState, adjust tests * Use IsEqualWith in remaining tests instead of assertion func * Further feedback, minor improvements - Move IsEqualWith method to be only in test package - Minor improvements, typos etc. Co-authored-by: Tyler Yahn --- CHANGELOG.md | 1 + bridge/opencensus/utils/utils_test.go | 4 +- exporters/otlp/internal/transform/span.go | 22 +- .../otlp/internal/transform/span_test.go | 7 +- exporters/stdout/trace_test.go | 13 +- oteltest/span.go | 12 +- oteltest/tracer.go | 23 +- oteltest/tracer_test.go | 13 +- propagation/trace_context_test.go | 4 +- sdk/trace/span.go | 11 +- sdk/trace/trace_test.go | 27 +- trace/trace.go | 160 +++++++++ trace/trace_noop_test.go | 2 +- trace/trace_test.go | 338 +++++++++++++++++- 14 files changed, 573 insertions(+), 64 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ba4d40dc4..a7062c744 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -54,6 +54,7 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm - An `EventOption` and the related `NewEventConfig` function are added to the `go.opentelemetry.io/otel` package to configure Span events. (#1254) - A `TextMapPropagator` and associated `TextMapCarrier` are added to the `go.opentelemetry.io/otel/oteltest` package to test `TextMap` type propagators and their use. (#1259) - `SpanContextFromContext` returns `SpanContext` from context. (#1255) +- `TraceState` has been added to `SpanContext`. (#1340) - `DeploymentEnvironmentKey` added to `go.opentelemetry.io/otel/semconv` package. (#1323) - Add an OpenCensus to OpenTelemetry tracing bridge. (#1305) - Add a parent context argument to `SpanProcessor.OnStart` to follow the specification. (#1333) diff --git a/bridge/opencensus/utils/utils_test.go b/bridge/opencensus/utils/utils_test.go index b21744503..c324b3420 100644 --- a/bridge/opencensus/utils/utils_test.go +++ b/bridge/opencensus/utils/utils_test.go @@ -130,7 +130,9 @@ func TestOCSpanContextToOTel(t *testing.T) { } { t.Run(tc.description, func(t *testing.T) { output := OCSpanContextToOTel(tc.input) - if output != tc.expected { + if output.SpanID != tc.expected.SpanID || + output.TraceID != tc.expected.TraceID || + output.TraceFlags != tc.expected.TraceFlags { t.Fatalf("Got %+v spancontext, exepected %+v.", output, tc.expected) } }) diff --git a/exporters/otlp/internal/transform/span.go b/exporters/otlp/internal/transform/span.go index 7e575a6d1..a71894c25 100644 --- a/exporters/otlp/internal/transform/span.go +++ b/exporters/otlp/internal/transform/span.go @@ -102,17 +102,17 @@ func span(sd *export.SpanSnapshot) *tracepb.Span { } s := &tracepb.Span{ - TraceId: sd.SpanContext.TraceID[:], - SpanId: sd.SpanContext.SpanID[:], - Status: status(sd.StatusCode, sd.StatusMessage), - StartTimeUnixNano: uint64(sd.StartTime.UnixNano()), - EndTimeUnixNano: uint64(sd.EndTime.UnixNano()), - Links: links(sd.Links), - Kind: spanKind(sd.SpanKind), - Name: sd.Name, - Attributes: Attributes(sd.Attributes), - Events: spanEvents(sd.MessageEvents), - // TODO (rghetia): Add Tracestate: when supported. + TraceId: sd.SpanContext.TraceID[:], + SpanId: sd.SpanContext.SpanID[:], + TraceState: sd.SpanContext.TraceState.String(), + Status: status(sd.StatusCode, sd.StatusMessage), + StartTimeUnixNano: uint64(sd.StartTime.UnixNano()), + EndTimeUnixNano: uint64(sd.EndTime.UnixNano()), + Links: links(sd.Links), + Kind: spanKind(sd.SpanKind), + Name: sd.Name, + Attributes: Attributes(sd.Attributes), + Events: spanEvents(sd.MessageEvents), DroppedAttributesCount: uint32(sd.DroppedAttributeCount), DroppedEventsCount: uint32(sd.DroppedMessageEventCount), DroppedLinksCount: uint32(sd.DroppedLinkCount), diff --git a/exporters/otlp/internal/transform/span_test.go b/exporters/otlp/internal/transform/span_test.go index 1a80a1894..e8d863fbf 100644 --- a/exporters/otlp/internal/transform/span_test.go +++ b/exporters/otlp/internal/transform/span_test.go @@ -199,10 +199,12 @@ func TestSpanData(t *testing.T) { // March 31, 2020 5:01:26 1234nanos (UTC) startTime := time.Unix(1585674086, 1234) endTime := startTime.Add(10 * time.Second) + traceState, _ := trace.TraceStateFromKeyValues(label.String("key1", "val1"), label.String("key2", "val2")) spanData := &export.SpanSnapshot{ SpanContext: trace.SpanContext{ - TraceID: trace.TraceID{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F}, - SpanID: trace.SpanID{0xFF, 0xFE, 0xFD, 0xFC, 0xFB, 0xFA, 0xF9, 0xF8}, + TraceID: trace.TraceID{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F}, + SpanID: trace.SpanID{0xFF, 0xFE, 0xFD, 0xFC, 0xFB, 0xFA, 0xF9, 0xF8}, + TraceState: traceState, }, SpanKind: trace.SpanKindServer, ParentSpanID: trace.SpanID{0xEF, 0xEE, 0xED, 0xEC, 0xEB, 0xEA, 0xE9, 0xE8}, @@ -266,6 +268,7 @@ func TestSpanData(t *testing.T) { TraceId: []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F}, SpanId: []byte{0xFF, 0xFE, 0xFD, 0xFC, 0xFB, 0xFA, 0xF9, 0xF8}, ParentSpanId: []byte{0xEF, 0xEE, 0xED, 0xEC, 0xEB, 0xEA, 0xE9, 0xE8}, + TraceState: "key1=val1,key2=val2", Name: spanData.Name, Kind: tracepb.Span_SPAN_KIND_SERVER, StartTimeUnixNano: uint64(startTime.UnixNano()), diff --git a/exporters/stdout/trace_test.go b/exporters/stdout/trace_test.go index ed2f4f310..4a3fabed9 100644 --- a/exporters/stdout/trace_test.go +++ b/exporters/stdout/trace_test.go @@ -42,14 +42,16 @@ func TestExporter_ExportSpan(t *testing.T) { now := time.Now() traceID, _ := trace.TraceIDFromHex("0102030405060708090a0b0c0d0e0f10") spanID, _ := trace.SpanIDFromHex("0102030405060708") + traceState, _ := trace.TraceStateFromKeyValues(label.String("key", "val")) keyValue := "value" doubleValue := 123.456 resource := resource.NewWithAttributes(label.String("rk1", "rv11")) testSpan := &export.SpanSnapshot{ SpanContext: trace.SpanContext{ - TraceID: traceID, - SpanID: spanID, + TraceID: traceID, + SpanID: spanID, + TraceState: traceState, }, Name: "/foo", StartTime: now, @@ -76,7 +78,12 @@ func TestExporter_ExportSpan(t *testing.T) { got := b.String() expectedOutput := `[{"SpanContext":{` + `"TraceID":"0102030405060708090a0b0c0d0e0f10",` + - `"SpanID":"0102030405060708","TraceFlags":0},` + + `"SpanID":"0102030405060708","TraceFlags":0,` + + `"TraceState":[` + + `{` + + `"Key":"key",` + + `"Value":{"Type":"STRING","Value":"val"}` + + `}]},` + `"ParentSpanID":"0000000000000000",` + `"SpanKind":1,` + `"Name":"/foo",` + diff --git a/oteltest/span.go b/oteltest/span.go index feb892fdd..9504a5af4 100644 --- a/oteltest/span.go +++ b/oteltest/span.go @@ -47,7 +47,7 @@ type Span struct { statusMessage string attributes map[label.Key]label.Value events []Event - links map[trace.SpanContext][]label.KeyValue + links []trace.Link spanKind trace.SpanKind } @@ -206,15 +206,7 @@ func (s *Span) Events() []Event { return s.events } // Links returns the links set on s at creation time. If multiple links for // the same SpanContext were set, the last link will be used. -func (s *Span) Links() map[trace.SpanContext][]label.KeyValue { - links := make(map[trace.SpanContext][]label.KeyValue) - - for sc, attributes := range s.links { - links[sc] = append([]label.KeyValue{}, attributes...) - } - - return links -} +func (s *Span) Links() []trace.Link { return s.links } // StartTime returns the time at which s was started. This will be the // wall-clock time unless a specific start time was provided. diff --git a/oteltest/tracer.go b/oteltest/tracer.go index 45446e70a..cf0cc2a66 100644 --- a/oteltest/tracer.go +++ b/oteltest/tracer.go @@ -47,7 +47,7 @@ func (t *Tracer) Start(ctx context.Context, name string, opts ...trace.SpanOptio tracer: t, startTime: startTime, attributes: make(map[label.Key]label.Value), - links: make(map[trace.SpanContext][]label.KeyValue), + links: []trace.Link{}, spanKind: c.SpanKind, } @@ -56,10 +56,16 @@ func (t *Tracer) Start(ctx context.Context, name string, opts ...trace.SpanOptio iodKey := label.Key("ignored-on-demand") if lsc := trace.SpanContextFromContext(ctx); lsc.IsValid() { - span.links[lsc] = []label.KeyValue{iodKey.String("current")} + span.links = append(span.links, trace.Link{ + SpanContext: lsc, + Attributes: []label.KeyValue{iodKey.String("current")}, + }) } if rsc := trace.RemoteSpanContextFromContext(ctx); rsc.IsValid() { - span.links[rsc] = []label.KeyValue{iodKey.String("remote")} + span.links = append(span.links, trace.Link{ + SpanContext: rsc, + Attributes: []label.KeyValue{iodKey.String("remote")}, + }) } } else { span.spanContext = t.config.SpanContextFunc(ctx) @@ -73,7 +79,16 @@ func (t *Tracer) Start(ctx context.Context, name string, opts ...trace.SpanOptio } for _, link := range c.Links { - span.links[link.SpanContext] = link.Attributes + for i, sl := range span.links { + if sl.SpanContext.SpanID == link.SpanContext.SpanID && + sl.SpanContext.TraceID == link.SpanContext.TraceID && + sl.SpanContext.TraceFlags == link.SpanContext.TraceFlags && + sl.SpanContext.TraceState.String() == link.SpanContext.TraceState.String() { + span.links[i].Attributes = link.Attributes + break + } + } + span.links = append(span.links, link) } span.SetName(name) diff --git a/oteltest/tracer_test.go b/oteltest/tracer_test.go index 1de96913a..f0e9ce699 100644 --- a/oteltest/tracer_test.go +++ b/oteltest/tracer_test.go @@ -211,14 +211,7 @@ func TestTracer(t *testing.T) { }, }, } - tsLinks := testSpan.Links() - gotLinks := make([]trace.Link, 0, len(tsLinks)) - for sc, attributes := range tsLinks { - gotLinks = append(gotLinks, trace.Link{ - SpanContext: sc, - Attributes: attributes, - }) - } + gotLinks := testSpan.Links() e.Expect(gotLinks).ToMatchInAnyOrder(expectedLinks) }) @@ -251,8 +244,8 @@ func TestTracer(t *testing.T) { e.Expect(ok).ToBeTrue() links := testSpan.Links() - e.Expect(links[link1.SpanContext]).ToEqual(link1.Attributes) - e.Expect(links[link2.SpanContext]).ToEqual(link2.Attributes) + e.Expect(links[0].Attributes).ToEqual(link1.Attributes) + e.Expect(links[1].Attributes).ToEqual(link2.Attributes) }) }) } diff --git a/propagation/trace_context_test.go b/propagation/trace_context_test.go index 1fc3ba154..aab3db339 100644 --- a/propagation/trace_context_test.go +++ b/propagation/trace_context_test.go @@ -113,7 +113,7 @@ func TestExtractValidTraceContextFromHTTPReq(t *testing.T) { ctx := context.Background() ctx = prop.Extract(ctx, req.Header) gotSc := trace.RemoteSpanContextFromContext(ctx) - if diff := cmp.Diff(gotSc, tt.wantSc); diff != "" { + if diff := cmp.Diff(gotSc, tt.wantSc, cmp.AllowUnexported(trace.TraceState{})); diff != "" { t.Errorf("Extract Tracecontext: %s: -got +want %s", tt.name, diff) } }) @@ -201,7 +201,7 @@ func TestExtractInvalidTraceContextFromHTTPReq(t *testing.T) { ctx := context.Background() ctx = prop.Extract(ctx, req.Header) gotSc := trace.RemoteSpanContextFromContext(ctx) - if diff := cmp.Diff(gotSc, wantSc); diff != "" { + if diff := cmp.Diff(gotSc, wantSc, cmp.AllowUnexported(trace.TraceState{})); diff != "" { t.Errorf("Extract Tracecontext: %s: -got +want %s", tt.name, diff) } }) diff --git a/sdk/trace/span.go b/sdk/trace/span.go index bf1f77b23..a053f8205 100644 --- a/sdk/trace/span.go +++ b/sdk/trace/span.go @@ -473,7 +473,7 @@ func startSpanInternal(ctx context.Context, tr *tracer, name string, parent trac cfg := tr.provider.config.Load().(*Config) - if parent == emptySpanContext { + if hasEmptySpanContext(parent) { // Generate both TraceID and SpanID span.spanContext.TraceID, span.spanContext.SpanID = cfg.IDGenerator.NewIDs(ctx) } else { @@ -486,7 +486,7 @@ func startSpanInternal(ctx context.Context, tr *tracer, name string, parent trac span.links = newEvictedQueue(cfg.MaxLinksPerSpan) data := samplingData{ - noParent: parent == emptySpanContext, + noParent: hasEmptySpanContext(parent), remoteParent: remoteParent, parent: parent, name: name, @@ -521,6 +521,13 @@ func startSpanInternal(ctx context.Context, tr *tracer, name string, parent trac return span } +func hasEmptySpanContext(parent trace.SpanContext) bool { + return parent.SpanID == emptySpanContext.SpanID && + parent.TraceID == emptySpanContext.TraceID && + parent.TraceFlags == emptySpanContext.TraceFlags && + parent.TraceState.IsEmpty() +} + type samplingData struct { noParent bool remoteParent bool diff --git a/sdk/trace/trace_test.go b/sdk/trace/trace_test.go index 22ff6b20f..f2f5e15cf 100644 --- a/sdk/trace/trace_test.go +++ b/sdk/trace/trace_test.go @@ -316,34 +316,38 @@ func TestStartSpanWithParent(t *testing.T) { TraceFlags: 0x1, } _, s1 := tr.Start(trace.ContextWithRemoteSpanContext(ctx, sc1), "span1-unsampled-parent1") - if err := checkChild(sc1, s1); err != nil { + if err := checkChild(t, sc1, s1); err != nil { t.Error(err) } _, s2 := tr.Start(trace.ContextWithRemoteSpanContext(ctx, sc1), "span2-unsampled-parent1") - if err := checkChild(sc1, s2); err != nil { + if err := checkChild(t, sc1, s2); err != nil { t.Error(err) } + ts, err := trace.TraceStateFromKeyValues(label.String("k", "v")) + if err != nil { + t.Error(err) + } sc2 := trace.SpanContext{ TraceID: tid, SpanID: sid, TraceFlags: 0x1, - //Tracestate: testTracestate, + TraceState: ts, } _, s3 := tr.Start(trace.ContextWithRemoteSpanContext(ctx, sc2), "span3-sampled-parent2") - if err := checkChild(sc2, s3); err != nil { + if err := checkChild(t, sc2, s3); err != nil { t.Error(err) } ctx2, s4 := tr.Start(trace.ContextWithRemoteSpanContext(ctx, sc2), "span4-sampled-parent2") - if err := checkChild(sc2, s4); err != nil { + if err := checkChild(t, sc2, s4); err != nil { t.Error(err) } s4Sc := s4.SpanContext() _, s5 := tr.Start(ctx2, "span5-implicit-childof-span4") - if err := checkChild(s4Sc, s5); err != nil { + if err := checkChild(t, s4Sc, s5); err != nil { t.Error(err) } } @@ -751,7 +755,8 @@ func TestSetSpanStatus(t *testing.T) { func cmpDiff(x, y interface{}) string { return cmp.Diff(x, y, cmp.AllowUnexported(label.Value{}), - cmp.AllowUnexported(export.Event{})) + cmp.AllowUnexported(export.Event{}), + cmp.AllowUnexported(trace.TraceState{})) } func remoteSpanContext() trace.SpanContext { @@ -764,7 +769,7 @@ func remoteSpanContext() trace.SpanContext { // checkChild is test utility function that tests that c has fields set appropriately, // given that it is a child span of p. -func checkChild(p trace.SpanContext, apiSpan trace.Span) error { +func checkChild(t *testing.T, p trace.SpanContext, apiSpan trace.Span) error { s := apiSpan.(*span) if s == nil { return fmt.Errorf("got nil child span, want non-nil") @@ -778,10 +783,8 @@ func checkChild(p trace.SpanContext, apiSpan trace.Span) error { if got, want := s.spanContext.TraceFlags, p.TraceFlags; got != want { return fmt.Errorf("got child trace options %d, want %d", got, want) } - // TODO [rgheita] : Fix tracestate test - //if got, want := c.spanContext.Tracestate, p.Tracestate; got != want { - // return fmt.Errorf("got child tracestate %v, want %v", got, want) - //} + got, want := s.spanContext.TraceState, p.TraceState + assert.Equal(t, want, got) return nil } diff --git a/trace/trace.go b/trace/trace.go index 76016e5bb..ced7f10ce 100644 --- a/trace/trace.go +++ b/trace/trace.go @@ -19,6 +19,8 @@ import ( "context" "encoding/hex" "encoding/json" + "regexp" + "strings" "go.opentelemetry.io/otel/codes" "go.opentelemetry.io/otel/label" @@ -42,6 +44,18 @@ const ( errInvalidSpanIDLength errorConst = "hex encoded span-id must have length equals to 16" errNilSpanID errorConst = "span-id can't be all zero" + + // based on the W3C Trace Context specification, see https://www.w3.org/TR/trace-context-1/#tracestate-header + traceStateKeyFormat = `[a-z][_0-9a-z\-\*\/]{0,255}` + traceStateKeyFormatWithMultiTenantVendor = `[a-z][_0-9a-z\-\*\/]{0,240}@[a-z][_0-9a-z\-\*\/]{0,13}` + traceStateValueFormat = `[\x20-\x2b\x2d-\x3c\x3e-\x7e]{0,255}[\x21-\x2b\x2d-\x3c\x3e-\x7e]` + + traceStateMaxListMembers = 32 + + errInvalidTraceStateKeyValue errorConst = "provided key or value is not valid according to the" + + " W3C Trace Context specification" + errInvalidTraceStateMembersNumber errorConst = "trace state would exceed the maximum limit of members (32)" + errInvalidTraceStateDuplicate errorConst = "trace state key/value pairs with duplicate keys provided" ) type errorConst string @@ -157,11 +171,157 @@ func decodeHex(h string, b []byte) error { return nil } +// TraceState provides additional vendor-specific trace identification information +// across different distributed tracing systems. It represents an immutable list consisting +// of key/value pairs. There can be a maximum of 32 entries in the list. +// +// Key and value of each list member must be valid according to the W3C Trace Context specification +// (see https://www.w3.org/TR/trace-context-1/#key and https://www.w3.org/TR/trace-context-1/#value +// respectively). +// +// Trace state must be valid according to the W3C Trace Context specification at all times. All +// mutating operations validate their input and, in case of valid parameters, return a new TraceState. +type TraceState struct { //nolint:golint + // TODO @matej-g: Consider implementing this as label.Set, see + // comment https://github.com/open-telemetry/opentelemetry-go/pull/1340#discussion_r540599226 + kvs []label.KeyValue +} + +var _ json.Marshaler = TraceState{} +var keyFormatRegExp = regexp.MustCompile( + `^((` + traceStateKeyFormat + `)|(` + traceStateKeyFormatWithMultiTenantVendor + `))$`, +) +var valueFormatRegExp = regexp.MustCompile(`^(` + traceStateValueFormat + `)$`) + +// MarshalJSON implements a custom marshal function to encode trace state. +func (ts TraceState) MarshalJSON() ([]byte, error) { + return json.Marshal(ts.kvs) +} + +// String returns trace state as a string valid according to the +// W3C Trace Context specification. +func (ts TraceState) String() string { + var sb strings.Builder + + for i, kv := range ts.kvs { + sb.WriteString((string)(kv.Key)) + sb.WriteByte('=') + sb.WriteString(kv.Value.Emit()) + + if i != len(ts.kvs)-1 { + sb.WriteByte(',') + } + } + + return sb.String() +} + +// Get returns a value for given key from the trace state. +// If no key is found or provided key is invalid, returns an empty value. +func (ts TraceState) Get(key label.Key) label.Value { + if !isTraceStateKeyValid(key) { + return label.Value{} + } + + for _, kv := range ts.kvs { + if kv.Key == key { + return kv.Value + } + } + + return label.Value{} +} + +// Insert adds a new key/value, if one doesn't exists; otherwise updates the existing entry. +// The new or updated entry is always inserted at the beginning of the TraceState, i.e. +// on the left side, as per the W3C Trace Context specification requirement. +func (ts TraceState) Insert(entry label.KeyValue) (TraceState, error) { + if !isTraceStateKeyValueValid(entry) { + return ts, errInvalidTraceStateKeyValue + } + + ckvs := ts.copyKVsAndDeleteEntry(entry.Key) + if len(ckvs)+1 > traceStateMaxListMembers { + return ts, errInvalidTraceStateMembersNumber + } + + ckvs = append(ckvs, label.KeyValue{}) + copy(ckvs[1:], ckvs) + ckvs[0] = entry + + return TraceState{ckvs}, nil +} + +// Delete removes specified entry from the trace state. +func (ts TraceState) Delete(key label.Key) (TraceState, error) { + if !isTraceStateKeyValid(key) { + return ts, errInvalidTraceStateKeyValue + } + + return TraceState{ts.copyKVsAndDeleteEntry(key)}, nil +} + +// IsEmpty returns true if the TraceState does not contain any entries +func (ts TraceState) IsEmpty() bool { + return len(ts.kvs) == 0 +} + +func (ts TraceState) copyKVsAndDeleteEntry(key label.Key) []label.KeyValue { + ckvs := make([]label.KeyValue, len(ts.kvs)) + copy(ckvs, ts.kvs) + for i, kv := range ts.kvs { + if kv.Key == key { + ckvs = append(ckvs[:i], ckvs[i+1:]...) + break + } + } + + return ckvs +} + +// TraceStateFromKeyValues is a convenience method to create a new TraceState from +// provided key/value pairs. +func TraceStateFromKeyValues(kvs ...label.KeyValue) (TraceState, error) { //nolint:golint + if len(kvs) == 0 { + return TraceState{}, nil + } + + if len(kvs) > traceStateMaxListMembers { + return TraceState{}, errInvalidTraceStateMembersNumber + } + + km := make(map[label.Key]bool) + for _, kv := range kvs { + if !isTraceStateKeyValueValid(kv) { + return TraceState{}, errInvalidTraceStateKeyValue + } + _, ok := km[kv.Key] + if ok { + return TraceState{}, errInvalidTraceStateDuplicate + } + km[kv.Key] = true + } + + ckvs := make([]label.KeyValue, len(kvs)) + copy(ckvs, kvs) + return TraceState{ckvs}, nil +} + +func isTraceStateKeyValid(key label.Key) bool { + return keyFormatRegExp.MatchString(string(key)) +} + +func isTraceStateKeyValueValid(kv label.KeyValue) bool { + return isTraceStateKeyValid(kv.Key) && + valueFormatRegExp.MatchString(kv.Value.Emit()) +} + // SpanContext contains identifying trace information about a Span. type SpanContext struct { TraceID TraceID SpanID SpanID TraceFlags byte + TraceState TraceState } // IsValid returns if the SpanContext is valid. A valid span context has a diff --git a/trace/trace_noop_test.go b/trace/trace_noop_test.go index f65d7d474..64a4d95ae 100644 --- a/trace/trace_noop_test.go +++ b/trace/trace_noop_test.go @@ -62,7 +62,7 @@ func TestNoopSpan(t *testing.T) { _, s := tracer.Start(context.Background(), "test span") span := s.(noopSpan) - if got, want := span.SpanContext(), (SpanContext{}); got != want { + if got, want := span.SpanContext(), (SpanContext{}); !assertSpanContextEqual(got, want) { t.Errorf("span.SpanContext() returned %#v, want %#v", got, want) } diff --git a/trace/trace_test.go b/trace/trace_test.go index 58ed137eb..2b6ecd64c 100644 --- a/trace/trace_test.go +++ b/trace/trace_test.go @@ -16,9 +16,13 @@ package trace import ( "context" + "fmt" "testing" + "go.opentelemetry.io/otel/label" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) type testSpan struct { @@ -69,7 +73,7 @@ func TestContextSpan(t *testing.T) { func TestContextRemoteSpanContext(t *testing.T) { ctx := context.Background() got, empty := RemoteSpanContextFromContext(ctx), SpanContext{} - if got != empty { + if !assertSpanContextEqual(got, empty) { t.Errorf("RemoteSpanContextFromContext returned %v from an empty context, want %v", got, empty) } @@ -77,11 +81,11 @@ func TestContextRemoteSpanContext(t *testing.T) { ctx = ContextWithRemoteSpanContext(ctx, want) if got, ok := ctx.Value(remoteContextKey).(SpanContext); !ok { t.Errorf("failed to set SpanContext with %#v", want) - } else if got != want { + } else if !assertSpanContextEqual(got, want) { t.Errorf("got %#v from context with remote set, want %#v", got, want) } - if got := RemoteSpanContextFromContext(ctx); got != want { + if got := RemoteSpanContextFromContext(ctx); !assertSpanContextEqual(got, want) { t.Errorf("RemoteSpanContextFromContext returned %v from a set context, want %v", got, want) } @@ -89,11 +93,12 @@ func TestContextRemoteSpanContext(t *testing.T) { ctx = ContextWithRemoteSpanContext(ctx, want) if got, ok := ctx.Value(remoteContextKey).(SpanContext); !ok { t.Errorf("failed to set SpanContext with %#v", want) - } else if got != want { - t.Errorf("got %#v from context with remote overridden, want %#v", got, want) + } else if !assertSpanContextEqual(got, want) { + t.Errorf("got %#v from context with remote set, want %#v", got, want) } - if got := RemoteSpanContextFromContext(ctx); got != want { + got = RemoteSpanContextFromContext(ctx) + if !assertSpanContextEqual(got, want) { t.Errorf("RemoteSpanContextFromContext returned %v from a set context, want %v", got, want) } } @@ -437,3 +442,324 @@ func TestSpanContextFromContext(t *testing.T) { }) } } + +func TestTraceStateString(t *testing.T) { + testCases := []struct { + name string + traceState TraceState + expectedStr string + }{ + { + name: "Non-empty trace state", + traceState: TraceState{ + kvs: []label.KeyValue{ + label.String("key1", "val1"), + label.String("key2", "val2"), + label.String("key3@vendor", "val3"), + }, + }, + expectedStr: "key1=val1,key2=val2,key3@vendor=val3", + }, + { + name: "Empty trace state", + traceState: TraceState{}, + expectedStr: "", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.expectedStr, tc.traceState.String()) + }) + } +} + +func TestTraceStateGet(t *testing.T) { + testCases := []struct { + name string + traceState TraceState + key label.Key + expectedValue string + }{ + { + name: "OK case", + traceState: TraceState{kvsWithMaxMembers}, + key: "key16", + expectedValue: "value16", + }, + { + name: "Not found", + traceState: TraceState{kvsWithMaxMembers}, + key: "keyxx", + expectedValue: "", + }, + { + name: "Invalid key", + traceState: TraceState{kvsWithMaxMembers}, + key: "key!", + expectedValue: "", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + kv := tc.traceState.Get(tc.key) + assert.Equal(t, tc.expectedValue, kv.AsString()) + }) + } +} + +func TestTraceStateDelete(t *testing.T) { + testCases := []struct { + name string + traceState TraceState + key label.Key + expectedTraceState TraceState + expectedErr error + }{ + { + name: "OK case", + traceState: TraceState{ + kvs: []label.KeyValue{ + label.String("key1", "val1"), + label.String("key2", "val2"), + label.String("key3", "val3"), + }, + }, + key: "key2", + expectedTraceState: TraceState{ + kvs: []label.KeyValue{ + label.String("key1", "val1"), + label.String("key3", "val3"), + }, + }, + }, + { + name: "Non-existing key", + traceState: TraceState{ + kvs: []label.KeyValue{ + label.String("key1", "val1"), + label.String("key2", "val2"), + label.String("key3", "val3"), + }, + }, + key: "keyx", + expectedTraceState: TraceState{ + kvs: []label.KeyValue{ + label.String("key1", "val1"), + label.String("key2", "val2"), + label.String("key3", "val3"), + }, + }, + }, + { + name: "Invalid key", + traceState: TraceState{ + kvs: []label.KeyValue{ + label.String("key1", "val1"), + label.String("key2", "val2"), + label.String("key3", "val3"), + }, + }, + key: "in va lid", + expectedTraceState: TraceState{ + kvs: []label.KeyValue{ + label.String("key1", "val1"), + label.String("key2", "val2"), + label.String("key3", "val3"), + }, + }, + expectedErr: errInvalidTraceStateKeyValue, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result, err := tc.traceState.Delete(tc.key) + if tc.expectedErr != nil { + require.Error(t, err) + assert.Equal(t, tc.expectedErr, err) + assert.Equal(t, tc.traceState, result) + } else { + require.NoError(t, err) + assert.Equal(t, tc.expectedTraceState, result) + } + }) + } +} + +func TestTraceStateInsert(t *testing.T) { + testCases := []struct { + name string + traceState TraceState + keyValue label.KeyValue + expectedTraceState TraceState + expectedErr error + }{ + { + name: "OK case - add new", + traceState: TraceState{ + kvs: []label.KeyValue{ + label.String("key1", "val1"), + label.String("key2", "val2"), + label.String("key3", "val3"), + }, + }, + keyValue: label.String("key4@vendor", "val4"), + expectedTraceState: TraceState{ + kvs: []label.KeyValue{ + label.String("key4@vendor", "val4"), + label.String("key1", "val1"), + label.String("key2", "val2"), + label.String("key3", "val3"), + }, + }, + }, + { + name: "OK case - replace", + traceState: TraceState{ + kvs: []label.KeyValue{ + label.String("key1", "val1"), + label.String("key2", "val2"), + label.String("key3", "val3"), + }, + }, + keyValue: label.String("key2", "valX"), + expectedTraceState: TraceState{ + kvs: []label.KeyValue{ + label.String("key2", "valX"), + label.String("key1", "val1"), + label.String("key3", "val3"), + }, + }, + }, + { + name: "Invalid key/value", + traceState: TraceState{ + kvs: []label.KeyValue{ + label.String("key1", "val1"), + }, + }, + keyValue: label.String("key!", "val!"), + expectedTraceState: TraceState{ + kvs: []label.KeyValue{ + label.String("key1", "val1"), + }, + }, + expectedErr: errInvalidTraceStateKeyValue, + }, + { + name: "Too many entries", + traceState: TraceState{kvsWithMaxMembers}, + keyValue: label.String("keyx", "valx"), + expectedTraceState: TraceState{kvsWithMaxMembers}, + expectedErr: errInvalidTraceStateMembersNumber, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result, err := tc.traceState.Insert(tc.keyValue) + if tc.expectedErr != nil { + require.Error(t, err) + assert.Equal(t, tc.expectedErr, err) + assert.Equal(t, tc.traceState, result) + } else { + require.NoError(t, err) + assert.Equal(t, tc.expectedTraceState, result) + } + }) + } +} + +func TestTraceStateFromKeyValues(t *testing.T) { + testCases := []struct { + name string + kvs []label.KeyValue + expectedTraceState TraceState + expectedErr error + }{ + { + name: "OK case", + kvs: kvsWithMaxMembers, + expectedTraceState: TraceState{kvsWithMaxMembers}, + }, + { + name: "OK case (empty)", + expectedTraceState: TraceState{}, + }, + { + name: "Too many entries", + kvs: func() []label.KeyValue { + kvs := kvsWithMaxMembers + kvs = append(kvs, label.String("keyx", "valX")) + return kvs + }(), + expectedTraceState: TraceState{}, + expectedErr: errInvalidTraceStateMembersNumber, + }, + { + name: "Duplicate", + kvs: []label.KeyValue{ + label.String("key1", "val1"), + label.String("key1", "val2"), + }, + expectedTraceState: TraceState{}, + expectedErr: errInvalidTraceStateDuplicate, + }, + { + name: "Invalid key/value", + kvs: []label.KeyValue{ + label.String("key!", "val!"), + }, + expectedTraceState: TraceState{}, + expectedErr: errInvalidTraceStateKeyValue, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result, err := TraceStateFromKeyValues(tc.kvs...) + if tc.expectedErr != nil { + require.Error(t, err) + assert.Equal(t, TraceState{}, result) + assert.Equal(t, tc.expectedErr, err) + } else { + require.NoError(t, err) + assert.NotNil(t, tc.expectedTraceState) + assert.Equal(t, tc.expectedTraceState, result) + } + }) + } + +} + +func assertSpanContextEqual(got SpanContext, want SpanContext) bool { + return got.SpanID == want.SpanID && + got.TraceID == want.TraceID && + got.TraceFlags == want.TraceFlags && + assertTraceStateEqual(got.TraceState, want.TraceState) +} + +func assertTraceStateEqual(got TraceState, want TraceState) bool { + if len(got.kvs) != len(want.kvs) { + return false + } + + for i, kv := range got.kvs { + if kv != want.kvs[i] { + return false + } + } + + return true +} + +var kvsWithMaxMembers = func() []label.KeyValue { + kvs := make([]label.KeyValue, traceStateMaxListMembers) + for i := 0; i < traceStateMaxListMembers; i++ { + kvs[i] = label.String(fmt.Sprintf("key%d", i+1), + fmt.Sprintf("value%d", i+1)) + } + return kvs +}()