diff --git a/sdk/trace/tracetest/recorder.go b/sdk/trace/tracetest/recorder.go index b0f647493..dcf32c148 100644 --- a/sdk/trace/tracetest/recorder.go +++ b/sdk/trace/tracetest/recorder.go @@ -16,13 +16,17 @@ package tracetest // import "go.opentelemetry.io/otel/sdk/trace/tracetest" import ( "context" + "sync" sdktrace "go.opentelemetry.io/otel/sdk/trace" ) // SpanRecorder records started and ended spans. type SpanRecorder struct { - started []sdktrace.ReadWriteSpan + startedMu sync.RWMutex + started []sdktrace.ReadWriteSpan + + endedMu sync.RWMutex ended []sdktrace.ReadOnlySpan } @@ -33,34 +37,54 @@ func NewSpanRecorder() *SpanRecorder { } // OnStart records started spans. +// +// This method is safe to be called concurrently. func (sr *SpanRecorder) OnStart(_ context.Context, s sdktrace.ReadWriteSpan) { + sr.startedMu.Lock() + defer sr.startedMu.Unlock() sr.started = append(sr.started, s) } // OnEnd records completed spans. +// +// This method is safe to be called concurrently. func (sr *SpanRecorder) OnEnd(s sdktrace.ReadOnlySpan) { + sr.endedMu.Lock() + defer sr.endedMu.Unlock() sr.ended = append(sr.ended, s) } // Shutdown does nothing. +// +// This method is safe to be called concurrently. func (sr *SpanRecorder) Shutdown(context.Context) error { return nil } // ForceFlush does nothing. +// +// This method is safe to be called concurrently. func (sr *SpanRecorder) ForceFlush(context.Context) error { return nil } // Started returns a copy of all started spans that have been recorded. +// +// This method is safe to be called concurrently. func (sr *SpanRecorder) Started() []sdktrace.ReadWriteSpan { + sr.startedMu.RLock() + defer sr.startedMu.RUnlock() dst := make([]sdktrace.ReadWriteSpan, len(sr.started)) copy(dst, sr.started) return dst } // Ended returns a copy of all ended spans that have been recorded. +// +// This method is safe to be called concurrently. func (sr *SpanRecorder) Ended() []sdktrace.ReadOnlySpan { + sr.endedMu.RLock() + defer sr.endedMu.RUnlock() dst := make([]sdktrace.ReadOnlySpan, len(sr.ended)) copy(dst, sr.ended) return dst diff --git a/sdk/trace/tracetest/recorder_test.go b/sdk/trace/tracetest/recorder_test.go index 46f9a57c9..ef292a981 100644 --- a/sdk/trace/tracetest/recorder_test.go +++ b/sdk/trace/tracetest/recorder_test.go @@ -16,6 +16,7 @@ package tracetest import ( "context" + "sync" "testing" "github.com/stretchr/testify/assert" @@ -83,3 +84,42 @@ func TestSpanRecorderForceFlushNoError(t *testing.T) { c() assert.NoError(t, new(SpanRecorder).ForceFlush(ctx)) } + +func runConcurrently(funcs ...func()) { + var wg sync.WaitGroup + + for _, f := range funcs { + wg.Add(1) + go func(f func()) { + f() + wg.Done() + }(f) + } + + wg.Wait() +} + +func TestEndingConcurrency(t *testing.T) { + sr := NewSpanRecorder() + + runConcurrently( + func() { sr.OnEnd(new(roSpan)) }, + func() { sr.OnEnd(new(roSpan)) }, + func() { sr.Ended() }, + ) + + assert.Len(t, sr.Ended(), 2) +} + +func TestStartingConcurrency(t *testing.T) { + sr := NewSpanRecorder() + + ctx := context.Background() + runConcurrently( + func() { sr.OnStart(ctx, new(rwSpan)) }, + func() { sr.OnStart(ctx, new(rwSpan)) }, + func() { sr.Started() }, + ) + + assert.Len(t, sr.Started(), 2) +}