mirror of
https://github.com/go-kratos/kratos.git
synced 2025-03-21 21:27:16 +02:00
修复container/group在相同的key并发get时,可能会初始化多次的bug (#606)
* 修复container/group在相同的key并发get时,可能会初始化多次的bug * update unit test case Co-authored-by: demons <lu.xu@zenjoy.net>
This commit is contained in:
parent
4160e34fb6
commit
e502a9f491
@ -8,7 +8,7 @@ import "sync"
|
|||||||
// Group is a lazy load container.
|
// Group is a lazy load container.
|
||||||
type Group struct {
|
type Group struct {
|
||||||
new func() interface{}
|
new func() interface{}
|
||||||
objs sync.Map
|
objs map[string]interface{}
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -19,19 +19,29 @@ func NewGroup(new func() interface{}) *Group {
|
|||||||
}
|
}
|
||||||
return &Group{
|
return &Group{
|
||||||
new: new,
|
new: new,
|
||||||
|
objs: make(map[string]interface{}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get gets the object by the given key.
|
// Get gets the object by the given key.
|
||||||
func (g *Group) Get(key string) interface{} {
|
func (g *Group) Get(key string) interface{} {
|
||||||
g.RLock()
|
g.RLock()
|
||||||
new := g.new
|
obj, ok := g.objs[key]
|
||||||
|
if ok {
|
||||||
g.RUnlock()
|
g.RUnlock()
|
||||||
obj, ok := g.objs.Load(key)
|
return obj
|
||||||
if !ok {
|
|
||||||
obj = new()
|
|
||||||
g.objs.Store(key, obj)
|
|
||||||
}
|
}
|
||||||
|
g.RUnlock()
|
||||||
|
|
||||||
|
// double check
|
||||||
|
g.Lock()
|
||||||
|
defer g.Unlock()
|
||||||
|
obj, ok = g.objs[key]
|
||||||
|
if ok {
|
||||||
|
return obj
|
||||||
|
}
|
||||||
|
obj = g.new()
|
||||||
|
g.objs[key] = obj
|
||||||
return obj
|
return obj
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -48,8 +58,7 @@ func (g *Group) Reset(new func() interface{}) {
|
|||||||
|
|
||||||
// Clear deletes all objects.
|
// Clear deletes all objects.
|
||||||
func (g *Group) Clear() {
|
func (g *Group) Clear() {
|
||||||
g.objs.Range(func(key, value interface{}) bool {
|
g.Lock()
|
||||||
g.objs.Delete(key)
|
g.objs = make(map[string]interface{})
|
||||||
return true
|
g.Unlock()
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
@ -36,10 +36,10 @@ func TestGroupReset(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
length := 0
|
length := 0
|
||||||
g.objs.Range(func(_, _ interface{}) bool {
|
for _,_ = range g.objs {
|
||||||
length++
|
length++
|
||||||
return true
|
}
|
||||||
})
|
|
||||||
assert.Equal(t, 0, length)
|
assert.Equal(t, 0, length)
|
||||||
|
|
||||||
g.Get("/x/internal/dummy/user")
|
g.Get("/x/internal/dummy/user")
|
||||||
@ -52,18 +52,16 @@ func TestGroupClear(t *testing.T) {
|
|||||||
})
|
})
|
||||||
g.Get("/x/internal/dummy/user")
|
g.Get("/x/internal/dummy/user")
|
||||||
length := 0
|
length := 0
|
||||||
g.objs.Range(func(_, _ interface{}) bool {
|
for _,_ = range g.objs {
|
||||||
length++
|
length++
|
||||||
return true
|
}
|
||||||
})
|
|
||||||
assert.Equal(t, 1, length)
|
assert.Equal(t, 1, length)
|
||||||
|
|
||||||
g.Clear()
|
g.Clear()
|
||||||
length = 0
|
length = 0
|
||||||
g.objs.Range(func(_, _ interface{}) bool {
|
for _,_ = range g.objs {
|
||||||
length++
|
length++
|
||||||
return true
|
}
|
||||||
})
|
|
||||||
assert.Equal(t, 0, length)
|
assert.Equal(t, 0, length)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user