1
0
mirror of https://github.com/go-kratos/kratos.git synced 2025-01-16 02:47:03 +02:00

fix(log): WithContext() changed the ctx field of the parent log.Filter (#3069)

* fix(log): `WithContext()` changed the ctx field of the parent log.Filter

* test(log): add concurrence test for `WithContext()`

* fix(log): concurrence problem of `Filter`

---------

Co-authored-by: 包子 <baozhecheng@foxmail.com>
This commit is contained in:
ionling 2023-11-22 11:35:58 +08:00 committed by GitHub
parent 9adece088b
commit ff105d5bca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 62 additions and 34 deletions

View File

@ -1,7 +1,5 @@
package log package log
import "context"
// FilterOption is filter option. // FilterOption is filter option.
type FilterOption func(*Filter) type FilterOption func(*Filter)
@ -41,7 +39,6 @@ func FilterFunc(f func(level Level, keyvals ...interface{}) bool) FilterOption {
// Filter is a logger filter. // Filter is a logger filter.
type Filter struct { type Filter struct {
ctx context.Context
logger Logger logger Logger
level Level level Level
key map[interface{}]struct{} key map[interface{}]struct{}
@ -70,9 +67,6 @@ func (f *Filter) Log(level Level, keyvals ...interface{}) error {
// prefixkv contains the slice of arguments defined as prefixes during the log initialization // prefixkv contains the slice of arguments defined as prefixes during the log initialization
var prefixkv []interface{} var prefixkv []interface{}
l, ok := f.logger.(*logger) l, ok := f.logger.(*logger)
if ok {
l.ctx = f.ctx
}
if ok && len(l.prefix) > 0 { if ok && len(l.prefix) > 0 {
prefixkv = make([]interface{}, 0, len(l.prefix)) prefixkv = make([]interface{}, 0, len(l.prefix))
prefixkv = append(prefixkv, l.prefix...) prefixkv = append(prefixkv, l.prefix...)

View File

@ -5,7 +5,9 @@ import (
"context" "context"
"io" "io"
"strings" "strings"
"sync"
"testing" "testing"
"time"
) )
func TestFilterAll(_ *testing.T) { func TestFilterAll(_ *testing.T) {
@ -172,3 +174,52 @@ func TestFilterWithContext(t *testing.T) {
t.Error("don't read ctx value") t.Error("don't read ctx value")
} }
} }
type traceIDKey struct{}
func setTraceID(ctx context.Context, tid string) context.Context {
return context.WithValue(ctx, traceIDKey{}, tid)
}
func traceIDValuer() Valuer {
return func(ctx context.Context) any {
if ctx == nil {
return ""
}
if tid := ctx.Value(traceIDKey{}); tid != nil {
return tid
}
return ""
}
}
func TestFilterWithContextConcurrent(t *testing.T) {
var buf bytes.Buffer
pctx := context.Background()
l := NewFilter(
With(NewStdLogger(&buf), "trace-id", traceIDValuer()),
FilterLevel(LevelInfo),
)
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
time.Sleep(time.Second)
NewHelper(l).Info("done1")
}()
wg.Add(1)
go func() {
defer wg.Done()
tid := "world"
ctx := setTraceID(pctx, tid)
NewHelper((WithContext(ctx, l))).Info("done2")
}()
wg.Wait()
expected := "INFO trace-id=world msg=done2\nINFO trace-id= msg=done1\n"
if got := buf.String(); got != expected {
t.Errorf("got: %#v", got)
}
}

View File

@ -50,28 +50,16 @@ func With(l Logger, kv ...interface{}) Logger {
// WithContext returns a shallow copy of l with its context changed // WithContext returns a shallow copy of l with its context changed
// to ctx. The provided ctx must be non-nil. // to ctx. The provided ctx must be non-nil.
func WithContext(ctx context.Context, l Logger) Logger { func WithContext(ctx context.Context, l Logger) Logger {
c, ok := l.(*logger) switch v := l.(type) {
if ok { default:
return &logger{ return &logger{logger: l, ctx: ctx}
logger: c.logger, case *logger:
prefix: c.prefix, lv := *v
hasValuer: c.hasValuer, lv.ctx = ctx
ctx: ctx, return &lv
} case *Filter:
fv := *v
fv.logger = WithContext(ctx, fv.logger)
return &fv
} }
f, ok := l.(*Filter)
if ok {
f.ctx = ctx
return &Filter{
ctx: ctx,
logger: f.logger,
level: f.level,
key: f.key,
value: f.value,
filter: f.filter,
}
}
return &logger{logger: l, ctx: ctx}
} }

View File

@ -1,7 +1,6 @@
package log package log
import ( import (
"context"
"testing" "testing"
) )
@ -11,7 +10,3 @@ func TestInfo(_ *testing.T) {
logger = With(logger, "caller", DefaultCaller) logger = With(logger, "caller", DefaultCaller)
_ = logger.Log(LevelInfo, "key1", "value1") _ = logger.Log(LevelInfo, "key1", "value1")
} }
func TestWithContext(_ *testing.T) {
WithContext(context.Background(), nil)
}