diff --git a/pkg/cache/memcache/README.md b/pkg/cache/memcache/README.md deleted file mode 100644 index 892c5c32a..000000000 --- a/pkg/cache/memcache/README.md +++ /dev/null @@ -1,25 +0,0 @@ -# cache/memcache - -##### 项目简介 -1. 提供protobuf,gob,json序列化方式,gzip的memcache接口 - -#### 使用方式 -```golang -// 初始化 注意这里只是示例 展示用法 不能每次都New 只需要初始化一次 -mc := memcache.New(&memcache.Config{}) -// 程序关闭的时候调用close方法 -defer mc.Close() -// 增加 key -err = mc.Set(c, &memcache.Item{}) -// 删除key -err := mc.Delete(c,key) -// 获得某个key的内容 -err := mc.Get(c,key).Scan(&v) -// 获取多个key的内容 -replies, err := mc.GetMulti(c, keys) -for _, key := range replies.Keys() { - if err = replies.Scan(key, &v); err != nil { - return - } -} -``` diff --git a/pkg/cache/memcache/ascii_conn.go b/pkg/cache/memcache/ascii_conn.go new file mode 100644 index 000000000..327629e80 --- /dev/null +++ b/pkg/cache/memcache/ascii_conn.go @@ -0,0 +1,261 @@ +package memcache + +import ( + "bufio" + "bytes" + "context" + "fmt" + "io" + "net" + "strconv" + "strings" + "time" + + pkgerr "github.com/pkg/errors" +) + +var ( + crlf = []byte("\r\n") + space = []byte(" ") + replyOK = []byte("OK\r\n") + replyStored = []byte("STORED\r\n") + replyNotStored = []byte("NOT_STORED\r\n") + replyExists = []byte("EXISTS\r\n") + replyNotFound = []byte("NOT_FOUND\r\n") + replyDeleted = []byte("DELETED\r\n") + replyEnd = []byte("END\r\n") + replyTouched = []byte("TOUCHED\r\n") + replyClientErrorPrefix = []byte("CLIENT_ERROR ") + replyServerErrorPrefix = []byte("SERVER_ERROR ") +) + +var _ protocolConn = &asiiConn{} + +// asiiConn is the low-level implementation of Conn +type asiiConn struct { + err error + conn net.Conn + // Read & Write + readTimeout time.Duration + writeTimeout time.Duration + rw *bufio.ReadWriter +} + +func replyToError(line []byte) error { + switch { + case bytes.Equal(line, replyStored): + return nil + case bytes.Equal(line, replyOK): + return nil + case bytes.Equal(line, replyDeleted): + return nil + case bytes.Equal(line, replyTouched): + return nil + case bytes.Equal(line, replyNotStored): + return ErrNotStored + case bytes.Equal(line, replyExists): + return ErrCASConflict + case bytes.Equal(line, replyNotFound): + return ErrNotFound + case bytes.Equal(line, replyNotStored): + return ErrNotStored + case bytes.Equal(line, replyExists): + return ErrCASConflict + } + return pkgerr.WithStack(protocolError(string(line))) +} + +func (c *asiiConn) Populate(ctx context.Context, cmd string, key string, flags uint32, expiration int32, cas uint64, data []byte) error { + c.conn.SetWriteDeadline(shrinkDeadline(ctx, c.writeTimeout)) + // [noreply]\r\n + var err error + if cmd == "cas" { + _, err = fmt.Fprintf(c.rw, "%s %s %d %d %d %d\r\n", cmd, key, flags, expiration, len(data), cas) + } else { + _, err = fmt.Fprintf(c.rw, "%s %s %d %d %d\r\n", cmd, key, flags, expiration, len(data)) + } + if err != nil { + return c.fatal(err) + } + c.rw.Write(data) + c.rw.Write(crlf) + if err = c.rw.Flush(); err != nil { + return c.fatal(err) + } + c.conn.SetReadDeadline(shrinkDeadline(ctx, c.readTimeout)) + line, err := c.rw.ReadSlice('\n') + if err != nil { + return c.fatal(err) + } + return replyToError(line) +} + +// newConn returns a new memcache connection for the given net connection. +func newASCIIConn(netConn net.Conn, readTimeout, writeTimeout time.Duration) (protocolConn, error) { + if writeTimeout <= 0 || readTimeout <= 0 { + return nil, pkgerr.Errorf("readTimeout writeTimeout can't be zero") + } + c := &asiiConn{ + conn: netConn, + rw: bufio.NewReadWriter(bufio.NewReader(netConn), + bufio.NewWriter(netConn)), + readTimeout: readTimeout, + writeTimeout: writeTimeout, + } + return c, nil +} + +func (c *asiiConn) Close() error { + if c.err == nil { + c.err = pkgerr.New("memcache: closed") + } + return c.conn.Close() +} + +func (c *asiiConn) fatal(err error) error { + if c.err == nil { + c.err = pkgerr.WithStack(err) + // Close connection to force errors on subsequent calls and to unblock + // other reader or writer. + c.conn.Close() + } + return c.err +} + +func (c *asiiConn) Err() error { + return c.err +} + +func (c *asiiConn) Get(ctx context.Context, key string) (result *Item, err error) { + c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)) + if _, err = fmt.Fprintf(c.rw, "gets %s\r\n", key); err != nil { + return nil, c.fatal(err) + } + if err = c.rw.Flush(); err != nil { + return nil, c.fatal(err) + } + if err = c.parseGetReply(func(it *Item) { + result = it + }); err != nil { + return + } + if result == nil { + return nil, ErrNotFound + } + return +} + +func (c *asiiConn) GetMulti(ctx context.Context, keys ...string) (map[string]*Item, error) { + var err error + c.conn.SetWriteDeadline(shrinkDeadline(ctx, c.writeTimeout)) + if _, err = fmt.Fprintf(c.rw, "gets %s\r\n", strings.Join(keys, " ")); err != nil { + return nil, c.fatal(err) + } + if err = c.rw.Flush(); err != nil { + return nil, c.fatal(err) + } + results := make(map[string]*Item, len(keys)) + if err = c.parseGetReply(func(it *Item) { + results[it.Key] = it + }); err != nil { + return nil, err + } + return results, nil +} + +func (c *asiiConn) parseGetReply(f func(*Item)) error { + c.conn.SetReadDeadline(shrinkDeadline(context.TODO(), c.readTimeout)) + for { + line, err := c.rw.ReadSlice('\n') + if err != nil { + return c.fatal(err) + } + if bytes.Equal(line, replyEnd) { + return nil + } + if bytes.HasPrefix(line, replyServerErrorPrefix) { + errMsg := line[len(replyServerErrorPrefix):] + return c.fatal(protocolError(errMsg)) + } + it := new(Item) + size, err := scanGetReply(line, it) + if err != nil { + return c.fatal(err) + } + it.Value = make([]byte, size+2) + if _, err = io.ReadFull(c.rw, it.Value); err != nil { + return c.fatal(err) + } + if !bytes.HasSuffix(it.Value, crlf) { + return c.fatal(protocolError("corrupt get reply, no except CRLF")) + } + it.Value = it.Value[:size] + f(it) + } +} + +func scanGetReply(line []byte, item *Item) (size int, err error) { + pattern := "VALUE %s %d %d %d\r\n" + dest := []interface{}{&item.Key, &item.Flags, &size, &item.cas} + if bytes.Count(line, space) == 3 { + pattern = "VALUE %s %d %d\r\n" + dest = dest[:3] + } + n, err := fmt.Sscanf(string(line), pattern, dest...) + if err != nil || n != len(dest) { + return -1, fmt.Errorf("memcache: unexpected line in get response: %q", line) + } + return size, nil +} + +func (c *asiiConn) Touch(ctx context.Context, key string, expire int32) error { + line, err := c.writeReadLine("touch %s %d\r\n", key, expire) + if err != nil { + return err + } + return replyToError(line) +} + +func (c *asiiConn) IncrDecr(ctx context.Context, cmd, key string, delta uint64) (uint64, error) { + line, err := c.writeReadLine("%s %s %d\r\n", cmd, key, delta) + if err != nil { + return 0, err + } + switch { + case bytes.Equal(line, replyNotFound): + return 0, ErrNotFound + case bytes.HasPrefix(line, replyClientErrorPrefix): + errMsg := line[len(replyClientErrorPrefix):] + return 0, pkgerr.WithStack(protocolError(errMsg)) + } + val, err := strconv.ParseUint(string(line[:len(line)-2]), 10, 64) + if err != nil { + return 0, err + } + return val, nil +} + +func (c *asiiConn) Delete(ctx context.Context, key string) error { + line, err := c.writeReadLine("delete %s\r\n", key) + if err != nil { + return err + } + return replyToError(line) +} + +func (c *asiiConn) writeReadLine(format string, args ...interface{}) ([]byte, error) { + c.conn.SetWriteDeadline(shrinkDeadline(context.TODO(), c.writeTimeout)) + _, err := fmt.Fprintf(c.rw, format, args...) + if err != nil { + return nil, c.fatal(pkgerr.WithStack(err)) + } + if err = c.rw.Flush(); err != nil { + return nil, c.fatal(pkgerr.WithStack(err)) + } + c.conn.SetReadDeadline(shrinkDeadline(context.TODO(), c.readTimeout)) + line, err := c.rw.ReadSlice('\n') + if err != nil { + return line, c.fatal(pkgerr.WithStack(err)) + } + return line, nil +} diff --git a/pkg/cache/memcache/ascii_conn_test.go b/pkg/cache/memcache/ascii_conn_test.go new file mode 100644 index 000000000..54b1dc596 --- /dev/null +++ b/pkg/cache/memcache/ascii_conn_test.go @@ -0,0 +1,569 @@ +package memcache + +import ( + "bytes" + "strconv" + "strings" + + "testing" +) + +func TestASCIIConnAdd(t *testing.T) { + tests := []struct { + name string + a *Item + e error + }{ + { + "Add", + &Item{ + Key: "test_add", + Value: []byte("0"), + Flags: 0, + Expiration: 60, + }, + nil, + }, + { + "Add_Large", + &Item{ + Key: "test_add_large", + Value: bytes.Repeat(space, _largeValue+1), + Flags: 0, + Expiration: 60, + }, + nil, + }, + { + "Add_Exist", + &Item{ + Key: "test_add", + Value: []byte("0"), + Flags: 0, + Expiration: 60, + }, + ErrNotStored, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if err := testConnASCII.Add(test.a); err != test.e { + t.Fatal(err) + } + if b, err := testConnASCII.Get(test.a.Key); err != nil { + t.Fatal(err) + } else { + compareItem(t, test.a, b) + } + }) + } +} + +func TestASCIIConnGet(t *testing.T) { + tests := []struct { + name string + a *Item + k string + e error + }{ + { + "Get", + &Item{ + Key: "test_get", + Value: []byte("0"), + Flags: 0, + Expiration: 60, + }, + "test_get", + nil, + }, + { + "Get_NotExist", + &Item{ + Key: "test_get_not_exist", + Value: []byte("0"), + Flags: 0, + Expiration: 60, + }, + "test_get_not_exist!", + ErrNotFound, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if err := testConnASCII.Add(test.a); err != nil { + t.Fatal(err) + } + if b, err := testConnASCII.Get(test.a.Key); err != nil { + t.Fatal(err) + } else { + compareItem(t, test.a, b) + } + }) + } +} + +//func TestGetHasErr(t *testing.T) { +// prepareEnv(t) +// +// st := &TestItem{Name: "json", Age: 10} +// itemx := &Item{Key: "test", Object: st, Flags: FlagJSON} +// c.Set(itemx) +// +// expected := errors.New("some error") +// monkey.Patch(scanGetReply, func(line []byte, item *Item) (size int, err error) { +// return 0, expected +// }) +// +// if _, err := c.Get("test"); err.Error() != expected.Error() { +// t.Errorf("conn.Get() unexpected error(%v)", err) +// } +// if err := c.(*asciiConn).err; err.Error() != expected.Error() { +// t.Errorf("unexpected error(%v)", err) +// } +//} + +func TestASCIIConnGetMulti(t *testing.T) { + tests := []struct { + name string + a []*Item + k []string + e error + }{ + {"getMulti_Add", + []*Item{ + { + Key: "get_multi_1", + Value: []byte("test"), + Flags: FlagRAW, + Expiration: 60, + cas: 0, + }, + { + Key: "get_multi_2", + Value: []byte("test2"), + Flags: FlagRAW, + Expiration: 60, + cas: 0, + }, + }, + []string{"get_multi_1", "get_multi_2"}, + nil, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, i := range test.a { + if err := testConnASCII.Set(i); err != nil { + t.Fatal(err) + } + } + if r, err := testConnASCII.GetMulti(test.k); err != nil { + t.Fatal(err) + } else { + reply := r["get_multi_1"] + compareItem(t, reply, test.a[0]) + reply = r["get_multi_2"] + compareItem(t, reply, test.a[1]) + } + + }) + } + +} + +func TestASCIIConnSet(t *testing.T) { + tests := []struct { + name string + a *Item + e error + }{ + { + "SetLowerBound", + &Item{ + Key: strings.Repeat("a", 1), + Value: []byte("4"), + Flags: 0, + Expiration: 60, + }, + nil, + }, + { + "SetUpperBound", + &Item{ + Key: strings.Repeat("a", 250), + Value: []byte("3"), + Flags: 0, + Expiration: 60, + }, + nil, + }, + { + "SetIllegalKeyZeroLength", + &Item{ + Key: "", + Value: []byte("2"), + Flags: 0, + Expiration: 60, + }, + ErrMalformedKey, + }, + { + "SetIllegalKeyLengthExceededLimit", + &Item{ + Key: " ", + Value: []byte("1"), + Flags: 0, + Expiration: 60, + }, + ErrMalformedKey, + }, + { + "SeJsonItem", + &Item{ + Key: "set_obj", + Object: &struct { + Name string + Age int + }{"json", 10}, + Expiration: 60, + Flags: FlagJSON, + }, + nil, + }, + { + "SeErrItemJSONGzip", + &Item{ + Key: "set_err_item", + Expiration: 60, + Flags: FlagJSON | FlagGzip, + }, + ErrItem, + }, + { + "SeErrItemBytesValueWrongFlag", + &Item{ + Key: "set_err_item", + Value: []byte("2"), + Expiration: 60, + Flags: FlagJSON, + }, + ErrItem, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if err := testConnASCII.Set(test.a); err != test.e { + t.Fatal(err) + } + }) + } +} + +func TestASCIIConnCompareAndSwap(t *testing.T) { + tests := []struct { + name string + a *Item + b *Item + c *Item + k string + e error + }{ + { + "CompareAndSwap", + &Item{ + Key: "test_cas", + Value: []byte("2"), + Flags: 0, + Expiration: 60, + }, + nil, + &Item{ + Key: "test_cas", + Value: []byte("3"), + Flags: 0, + Expiration: 60, + }, + "test_cas", + nil, + }, + { + "CompareAndSwapErrCASConflict", + &Item{ + Key: "test_cas_conflict", + Value: []byte("2"), + Flags: 0, + Expiration: 60, + }, + &Item{ + Key: "test_cas_conflict", + Value: []byte("1"), + Flags: 0, + Expiration: 60, + }, + &Item{ + Key: "test_cas_conflict", + Value: []byte("3"), + Flags: 0, + Expiration: 60, + }, + "test_cas_conflict", + ErrCASConflict, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if err := testConnASCII.Set(test.a); err != nil { + t.Fatal(err) + } + r, err := testConnASCII.Get(test.k) + if err != nil { + t.Fatal(err) + } + + if test.b != nil { + if err := testConnASCII.Set(test.b); err != nil { + t.Fatal(err) + } + } + + r.Value = test.c.Value + if err := testConnASCII.CompareAndSwap(r); err != nil { + if err != test.e { + t.Fatal(err) + } + } else { + if fr, err := testConnASCII.Get(test.k); err != nil { + t.Fatal(err) + } else { + compareItem(t, fr, test.c) + } + } + }) + } + + t.Run("TestCompareAndSwapErrNotFound", func(t *testing.T) { + ti := &Item{ + Key: "test_cas_notfound", + Value: []byte("2"), + Flags: 0, + Expiration: 60, + } + if err := testConnASCII.Set(ti); err != nil { + t.Fatal(err) + } + r, err := testConnASCII.Get(ti.Key) + if err != nil { + t.Fatal(err) + } + + r.Key = "test_cas_notfound_boom" + r.Value = []byte("3") + if err := testConnASCII.CompareAndSwap(r); err != nil { + if err != ErrNotFound { + t.Fatal(err) + } + } + }) +} + +func TestASCIIConnReplace(t *testing.T) { + tests := []struct { + name string + a *Item + b *Item + e error + }{ + { + "TestReplace", + &Item{ + Key: "test_replace", + Value: []byte("2"), + Flags: 0, + Expiration: 60, + }, + &Item{ + Key: "test_replace", + Value: []byte("3"), + Flags: 0, + Expiration: 60, + }, + nil, + }, + { + "TestReplaceErrNotStored", + &Item{ + Key: "test_replace_not_stored", + Value: []byte("2"), + Flags: 0, + Expiration: 60, + }, + &Item{ + Key: "test_replace_not_stored_boom", + Value: []byte("3"), + Flags: 0, + Expiration: 60, + }, + ErrNotStored, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if err := testConnASCII.Set(test.a); err != nil { + t.Fatal(err) + } + if err := testConnASCII.Replace(test.b); err != nil { + if err == test.e { + return + } + t.Fatal(err) + } + if r, err := testConnASCII.Get(test.b.Key); err != nil { + t.Fatal(err) + } else { + compareItem(t, r, test.b) + } + }) + } +} + +func TestASCIIConnIncrDecr(t *testing.T) { + tests := []struct { + fn func(key string, delta uint64) (uint64, error) + name string + k string + v uint64 + w uint64 + }{ + { + testConnASCII.Increment, + "Incr_10", + "test_incr", + 10, + 10, + }, + { + testConnASCII.Increment, + "Incr_10(2)", + "test_incr", + 10, + 20, + }, + { + testConnASCII.Decrement, + "Decr_10", + "test_incr", + 10, + 10, + }, + } + if err := testConnASCII.Add(&Item{ + Key: "test_incr", + Value: []byte("0"), + }); err != nil { + t.Fatal(err) + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if a, err := test.fn(test.k, test.v); err != nil { + t.Fatal(err) + } else { + if a != test.w { + t.Fatalf("want %d, got %d", test.w, a) + } + } + if b, err := testConnASCII.Get(test.k); err != nil { + t.Fatal(err) + } else { + if string(b.Value) != strconv.FormatUint(test.w, 10) { + t.Fatalf("want %s, got %d", b.Value, test.w) + } + } + }) + } +} + +func TestASCIIConnTouch(t *testing.T) { + tests := []struct { + name string + k string + a *Item + e error + }{ + { + "Touch", + "test_touch", + &Item{ + Key: "test_touch", + Value: []byte("0"), + Expiration: 60, + }, + nil, + }, + { + "Touch_NotExist", + "test_touch_not_exist", + nil, + ErrNotFound, + }, + } + for _, test := range tests { + if test.a != nil { + if err := testConnASCII.Add(test.a); err != nil { + t.Fatal(err) + } + if err := testConnASCII.Touch(test.k, 1); err != test.e { + t.Fatal(err) + } + } + } +} + +func TestASCIIConnDelete(t *testing.T) { + tests := []struct { + name string + k string + a *Item + e error + }{ + { + "Delete", + "test_delete", + &Item{ + Key: "test_delete", + Value: []byte("0"), + Expiration: 60, + }, + nil, + }, + { + "Delete_NotExist", + "test_delete_not_exist", + nil, + ErrNotFound, + }, + } + for _, test := range tests { + if test.a != nil { + if err := testConnASCII.Add(test.a); err != nil { + t.Fatal(err) + } + if err := testConnASCII.Delete(test.k); err != test.e { + t.Fatal(err) + } + if _, err := testConnASCII.Get(test.k); err != ErrNotFound { + t.Fatal(err) + } + } + } +} + +func compareItem(t *testing.T, a, b *Item) { + if a.Key != b.Key || !bytes.Equal(a.Value, b.Value) || a.Flags != b.Flags { + t.Fatalf("compareItem: a(%s, %d, %d) : b(%s, %d, %d)", a.Key, len(a.Value), a.Flags, b.Key, len(b.Value), b.Flags) + } +} diff --git a/pkg/cache/memcache/client.go b/pkg/cache/memcache/client.go deleted file mode 100644 index d0b830051..000000000 --- a/pkg/cache/memcache/client.go +++ /dev/null @@ -1,187 +0,0 @@ -package memcache - -import ( - "context" -) - -// Memcache memcache client -type Memcache struct { - pool *Pool -} - -// Reply is the result of Get -type Reply struct { - err error - item *Item - conn Conn - closed bool -} - -// Replies is the result of GetMulti -type Replies struct { - err error - items map[string]*Item - usedItems map[string]struct{} - conn Conn - closed bool -} - -// New get a memcache client -func New(c *Config) *Memcache { - return &Memcache{pool: NewPool(c)} -} - -// Close close connection pool -func (mc *Memcache) Close() error { - return mc.pool.Close() -} - -// Conn direct get a connection -func (mc *Memcache) Conn(c context.Context) Conn { - return mc.pool.Get(c) -} - -// Set writes the given item, unconditionally. -func (mc *Memcache) Set(c context.Context, item *Item) (err error) { - conn := mc.pool.Get(c) - err = conn.Set(item) - conn.Close() - return -} - -// Add writes the given item, if no value already exists for its key. -// ErrNotStored is returned if that condition is not met. -func (mc *Memcache) Add(c context.Context, item *Item) (err error) { - conn := mc.pool.Get(c) - err = conn.Add(item) - conn.Close() - return -} - -// Replace writes the given item, but only if the server *does* already hold data for this key. -func (mc *Memcache) Replace(c context.Context, item *Item) (err error) { - conn := mc.pool.Get(c) - err = conn.Replace(item) - conn.Close() - return -} - -// CompareAndSwap writes the given item that was previously returned by Get -func (mc *Memcache) CompareAndSwap(c context.Context, item *Item) (err error) { - conn := mc.pool.Get(c) - err = conn.CompareAndSwap(item) - conn.Close() - return -} - -// Get sends a command to the server for gets data. -func (mc *Memcache) Get(c context.Context, key string) *Reply { - conn := mc.pool.Get(c) - item, err := conn.Get(key) - if err != nil { - conn.Close() - } - return &Reply{err: err, item: item, conn: conn} -} - -// Item get raw Item -func (r *Reply) Item() *Item { - return r.item -} - -// Scan converts value, read from the memcache -func (r *Reply) Scan(v interface{}) (err error) { - if r.err != nil { - return r.err - } - err = r.conn.Scan(r.item, v) - if !r.closed { - r.conn.Close() - r.closed = true - } - return -} - -// GetMulti is a batch version of Get -func (mc *Memcache) GetMulti(c context.Context, keys []string) (*Replies, error) { - conn := mc.pool.Get(c) - items, err := conn.GetMulti(keys) - rs := &Replies{err: err, items: items, conn: conn, usedItems: make(map[string]struct{}, len(keys))} - if (err != nil) || (len(items) == 0) { - rs.Close() - } - return rs, err -} - -// Close close rows. -func (rs *Replies) Close() (err error) { - if !rs.closed { - err = rs.conn.Close() - rs.closed = true - } - return -} - -// Item get Item from rows -func (rs *Replies) Item(key string) *Item { - return rs.items[key] -} - -// Scan converts value, read from key in rows -func (rs *Replies) Scan(key string, v interface{}) (err error) { - if rs.err != nil { - return rs.err - } - item, ok := rs.items[key] - if !ok { - rs.Close() - return ErrNotFound - } - rs.usedItems[key] = struct{}{} - err = rs.conn.Scan(item, v) - if (err != nil) || (len(rs.items) == len(rs.usedItems)) { - rs.Close() - } - return -} - -// Keys keys of result -func (rs *Replies) Keys() (keys []string) { - keys = make([]string, 0, len(rs.items)) - for key := range rs.items { - keys = append(keys, key) - } - return -} - -// Touch updates the expiry for the given key. -func (mc *Memcache) Touch(c context.Context, key string, timeout int32) (err error) { - conn := mc.pool.Get(c) - err = conn.Touch(key, timeout) - conn.Close() - return -} - -// Delete deletes the item with the provided key. -func (mc *Memcache) Delete(c context.Context, key string) (err error) { - conn := mc.pool.Get(c) - err = conn.Delete(key) - conn.Close() - return -} - -// Increment atomically increments key by delta. -func (mc *Memcache) Increment(c context.Context, key string, delta uint64) (newValue uint64, err error) { - conn := mc.pool.Get(c) - newValue, err = conn.Increment(key, delta) - conn.Close() - return -} - -// Decrement atomically decrements key by delta. -func (mc *Memcache) Decrement(c context.Context, key string, delta uint64) (newValue uint64, err error) { - conn := mc.pool.Get(c) - newValue, err = conn.Decrement(key, delta) - conn.Close() - return -} diff --git a/pkg/cache/memcache/conn.go b/pkg/cache/memcache/conn.go index cbd72f35f..77c2d232d 100644 --- a/pkg/cache/memcache/conn.go +++ b/pkg/cache/memcache/conn.go @@ -1,78 +1,30 @@ package memcache import ( - "bufio" - "bytes" - "compress/gzip" "context" - "encoding/gob" - "encoding/json" "fmt" - "io" "net" "strconv" - "strings" - "sync" "time" - "github.com/gogo/protobuf/proto" pkgerr "github.com/pkg/errors" ) -var ( - crlf = []byte("\r\n") - spaceStr = string(" ") - replyOK = []byte("OK\r\n") - replyStored = []byte("STORED\r\n") - replyNotStored = []byte("NOT_STORED\r\n") - replyExists = []byte("EXISTS\r\n") - replyNotFound = []byte("NOT_FOUND\r\n") - replyDeleted = []byte("DELETED\r\n") - replyEnd = []byte("END\r\n") - replyTouched = []byte("TOUCHED\r\n") - replyValueStr = "VALUE" - replyClientErrorPrefix = []byte("CLIENT_ERROR ") - replyServerErrorPrefix = []byte("SERVER_ERROR ") -) - const ( - _encodeBuf = 4096 // 4kb // 1024*1024 - 1, set error??? _largeValue = 1000 * 1000 // 1MB ) -type reader struct { - io.Reader -} - -func (r *reader) Reset(rd io.Reader) { - r.Reader = rd -} - -// conn is the low-level implementation of Conn -type conn struct { - // Shared - mu sync.Mutex - err error - conn net.Conn - // Read & Write - readTimeout time.Duration - writeTimeout time.Duration - rw *bufio.ReadWriter - // Item Reader - ir bytes.Reader - // Compress - gr gzip.Reader - gw *gzip.Writer - cb bytes.Buffer - // Encoding - edb bytes.Buffer - // json - jr reader - jd *json.Decoder - je *json.Encoder - // protobuffer - ped *proto.Buffer +// low level connection that implement memcache protocol provide basic operation. +type protocolConn interface { + Populate(ctx context.Context, cmd string, key string, flags uint32, expiration int32, cas uint64, data []byte) error + Get(ctx context.Context, key string) (*Item, error) + GetMulti(ctx context.Context, keys ...string) (map[string]*Item, error) + Touch(ctx context.Context, key string, expire int32) error + IncrDecr(ctx context.Context, cmd, key string, delta uint64) (uint64, error) + Delete(ctx context.Context, key string) error + Close() error + Err() error } // DialOption specifies an option for dialing a Memcache server. @@ -83,6 +35,7 @@ type DialOption struct { type dialOptions struct { readTimeout time.Duration writeTimeout time.Duration + protocol string dial func(network, addr string) (net.Conn, error) } @@ -130,556 +83,205 @@ func Dial(network, address string, options ...DialOption) (Conn, error) { if err != nil { return nil, pkgerr.WithStack(err) } - return NewConn(netConn, do.readTimeout, do.writeTimeout), nil + pconn, err := newASCIIConn(netConn, do.readTimeout, do.writeTimeout) + return &conn{pconn: pconn, ed: newEncodeDecoder()}, nil } -// NewConn returns a new memcache connection for the given net connection. -func NewConn(netConn net.Conn, readTimeout, writeTimeout time.Duration) Conn { - if writeTimeout <= 0 || readTimeout <= 0 { - panic("must config memcache timeout") - } - c := &conn{ - conn: netConn, - rw: bufio.NewReadWriter(bufio.NewReader(netConn), - bufio.NewWriter(netConn)), - readTimeout: readTimeout, - writeTimeout: writeTimeout, - } - c.jd = json.NewDecoder(&c.jr) - c.je = json.NewEncoder(&c.edb) - c.gw = gzip.NewWriter(&c.cb) - c.edb.Grow(_encodeBuf) - // NOTE reuse bytes.Buffer internal buf - // DON'T concurrency call Scan - c.ped = proto.NewBuffer(c.edb.Bytes()) - return c +type conn struct { + // low level connection. + pconn protocolConn + ed *encodeDecode } func (c *conn) Close() error { - c.mu.Lock() - err := c.err - if c.err == nil { - c.err = pkgerr.New("memcache: closed") - err = c.conn.Close() - } - c.mu.Unlock() - return err -} - -func (c *conn) fatal(err error) error { - c.mu.Lock() - if c.err == nil { - c.err = pkgerr.WithStack(err) - // Close connection to force errors on subsequent calls and to unblock - // other reader or writer. - c.conn.Close() - } - c.mu.Unlock() - return c.err + return c.pconn.Close() } func (c *conn) Err() error { - c.mu.Lock() - err := c.err - c.mu.Unlock() - return err + return c.pconn.Err() } -func (c *conn) Add(item *Item) error { - return c.populate("add", item) +func (c *conn) AddContext(ctx context.Context, item *Item) error { + return c.populate(ctx, "add", item) } -func (c *conn) Set(item *Item) error { - return c.populate("set", item) +func (c *conn) SetContext(ctx context.Context, item *Item) error { + return c.populate(ctx, "set", item) } -func (c *conn) Replace(item *Item) error { - return c.populate("replace", item) +func (c *conn) ReplaceContext(ctx context.Context, item *Item) error { + return c.populate(ctx, "replace", item) } -func (c *conn) CompareAndSwap(item *Item) error { - return c.populate("cas", item) +func (c *conn) CompareAndSwapContext(ctx context.Context, item *Item) error { + return c.populate(ctx, "cas", item) } -func (c *conn) populate(cmd string, item *Item) (err error) { +func (c *conn) populate(ctx context.Context, cmd string, item *Item) error { if !legalKey(item.Key) { - return pkgerr.WithStack(ErrMalformedKey) + return ErrMalformedKey } - var res []byte - if res, err = c.encode(item); err != nil { - return - } - l := len(res) - count := l/(_largeValue) + 1 - if count == 1 { - item.Value = res - return c.populateOne(cmd, item) - } - nItem := &Item{ - Key: item.Key, - Value: []byte(strconv.Itoa(l)), - Expiration: item.Expiration, - Flags: item.Flags | flagLargeValue, - } - err = c.populateOne(cmd, nItem) + data, err := c.ed.encode(item) if err != nil { - return + return err } - k := item.Key - nItem.Flags = item.Flags + length := len(data) + if length < _largeValue { + return c.pconn.Populate(ctx, cmd, item.Key, item.Flags, item.Expiration, item.cas, data) + } + count := length/_largeValue + 1 + if err = c.pconn.Populate(ctx, cmd, item.Key, item.Flags|flagLargeValue, item.Expiration, item.cas, []byte(strconv.Itoa(length))); err != nil { + return err + } + var chunk []byte for i := 1; i <= count; i++ { if i == count { - nItem.Value = res[_largeValue*(count-1):] + chunk = data[_largeValue*(count-1):] } else { - nItem.Value = res[_largeValue*(i-1) : _largeValue*i] + chunk = data[_largeValue*(i-1) : _largeValue*i] } - nItem.Key = fmt.Sprintf("%s%d", k, i) - if err = c.populateOne(cmd, nItem); err != nil { - return + key := fmt.Sprintf("%s%d", item.Key, i) + if err = c.pconn.Populate(ctx, cmd, key, item.Flags, item.Expiration, item.cas, chunk); err != nil { + return err } } - return + return nil } -func (c *conn) populateOne(cmd string, item *Item) (err error) { - if c.writeTimeout != 0 { - c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)) - } - - // [noreply]\r\n - if cmd == "cas" { - _, err = fmt.Fprintf(c.rw, "%s %s %d %d %d %d\r\n", - cmd, item.Key, item.Flags, item.Expiration, len(item.Value), item.cas) - } else { - _, err = fmt.Fprintf(c.rw, "%s %s %d %d %d\r\n", - cmd, item.Key, item.Flags, item.Expiration, len(item.Value)) - } - if err != nil { - return c.fatal(err) - } - c.rw.Write(item.Value) - c.rw.Write(crlf) - if err = c.rw.Flush(); err != nil { - return c.fatal(err) - } - if c.readTimeout != 0 { - c.conn.SetReadDeadline(time.Now().Add(c.readTimeout)) - } - line, err := c.rw.ReadSlice('\n') - if err != nil { - return c.fatal(err) - } - switch { - case bytes.Equal(line, replyStored): - return nil - case bytes.Equal(line, replyNotStored): - return ErrNotStored - case bytes.Equal(line, replyExists): - return ErrCASConflict - case bytes.Equal(line, replyNotFound): - return ErrNotFound - } - return pkgerr.WithStack(protocolError(string(line))) -} - -func (c *conn) Get(key string) (r *Item, err error) { +func (c *conn) GetContext(ctx context.Context, key string) (*Item, error) { if !legalKey(key) { - return nil, pkgerr.WithStack(ErrMalformedKey) + return nil, ErrMalformedKey } - if c.writeTimeout != 0 { - c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)) + result, err := c.pconn.Get(ctx, key) + if err != nil { + return nil, err } - if _, err = fmt.Fprintf(c.rw, "gets %s\r\n", key); err != nil { - return nil, c.fatal(err) + if result.Flags&flagLargeValue != flagLargeValue { + return result, err } - if err = c.rw.Flush(); err != nil { - return nil, c.fatal(err) - } - if err = c.parseGetReply(func(it *Item) { - r = it - }); err != nil { - return - } - if r == nil { - err = ErrNotFound - return - } - if r.Flags&flagLargeValue != flagLargeValue { - return - } - if r, err = c.getLargeValue(r); err != nil { - return - } - return + return c.getLargeItem(ctx, result) } -func (c *conn) GetMulti(keys []string) (res map[string]*Item, err error) { +func (c *conn) getLargeItem(ctx context.Context, result *Item) (*Item, error) { + length, err := strconv.Atoi(string(result.Value)) + if err != nil { + return nil, err + } + count := length/_largeValue + 1 + keys := make([]string, 0, count) + for i := 1; i <= count; i++ { + keys = append(keys, fmt.Sprintf("%s%d", result.Key, i)) + } + var results map[string]*Item + if results, err = c.pconn.GetMulti(ctx, keys...); err != nil { + return nil, err + } + if len(results) < count { + return nil, ErrNotFound + } + result.Value = make([]byte, 0, length) + for _, k := range keys { + ti := results[k] + if ti == nil || ti.Value == nil { + return nil, ErrNotFound + } + result.Value = append(result.Value, ti.Value...) + } + result.Flags = result.Flags ^ flagLargeValue + return result, nil +} + +func (c *conn) GetMultiContext(ctx context.Context, keys []string) (map[string]*Item, error) { + // TODO: move to protocolConn? for _, key := range keys { if !legalKey(key) { - return nil, pkgerr.WithStack(ErrMalformedKey) + return nil, ErrMalformedKey } } - if c.writeTimeout != 0 { - c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)) + results, err := c.pconn.GetMulti(ctx, keys...) + if err != nil { + return results, err } - if _, err = fmt.Fprintf(c.rw, "gets %s\r\n", strings.Join(keys, " ")); err != nil { - return nil, c.fatal(err) - } - if err = c.rw.Flush(); err != nil { - return nil, c.fatal(err) - } - res = make(map[string]*Item, len(keys)) - if err = c.parseGetReply(func(it *Item) { - res[it.Key] = it - }); err != nil { - return - } - for k, v := range res { + for k, v := range results { if v.Flags&flagLargeValue != flagLargeValue { continue } - r, err := c.getLargeValue(v) - if err != nil { - return res, err + if v, err = c.getLargeItem(ctx, v); err != nil { + return results, err } - res[k] = r + results[k] = v } - return + return results, nil } -func (c *conn) getMulti(keys []string) (res map[string]*Item, err error) { - for _, key := range keys { - if !legalKey(key) { - return nil, pkgerr.WithStack(ErrMalformedKey) - } - } - if c.writeTimeout != 0 { - c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)) - } - if _, err = fmt.Fprintf(c.rw, "gets %s\r\n", strings.Join(keys, " ")); err != nil { - return nil, c.fatal(err) - } - if err = c.rw.Flush(); err != nil { - return nil, c.fatal(err) - } - res = make(map[string]*Item, len(keys)) - err = c.parseGetReply(func(it *Item) { - res[it.Key] = it - }) - return -} - -func (c *conn) getLargeValue(it *Item) (r *Item, err error) { - l, err := strconv.Atoi(string(it.Value)) - if err != nil { - return - } - count := l/_largeValue + 1 - keys := make([]string, 0, count) - for i := 1; i <= count; i++ { - keys = append(keys, fmt.Sprintf("%s%d", it.Key, i)) - } - items, err := c.getMulti(keys) - if err != nil { - return - } - if len(items) < count { - err = ErrNotFound - return - } - v := make([]byte, 0, l) - for _, k := range keys { - if items[k] == nil || items[k].Value == nil { - err = ErrNotFound - return - } - v = append(v, items[k].Value...) - } - it.Value = v - it.Flags = it.Flags ^ flagLargeValue - r = it - return -} - -func (c *conn) parseGetReply(f func(*Item)) error { - if c.readTimeout != 0 { - c.conn.SetReadDeadline(time.Now().Add(c.readTimeout)) - } - for { - line, err := c.rw.ReadSlice('\n') - if err != nil { - return c.fatal(err) - } - if bytes.Equal(line, replyEnd) { - return nil - } - if bytes.HasPrefix(line, replyServerErrorPrefix) { - errMsg := line[len(replyServerErrorPrefix):] - return c.fatal(protocolError(errMsg)) - } - it := new(Item) - size, err := scanGetReply(line, it) - if err != nil { - return c.fatal(err) - } - it.Value = make([]byte, size+2) - if _, err = io.ReadFull(c.rw, it.Value); err != nil { - return c.fatal(err) - } - if !bytes.HasSuffix(it.Value, crlf) { - return c.fatal(protocolError("corrupt get reply, no except CRLF")) - } - it.Value = it.Value[:size] - f(it) - } -} - -func scanGetReply(line []byte, item *Item) (size int, err error) { - if !bytes.HasSuffix(line, crlf) { - return 0, protocolError("corrupt get reply, no except CRLF") - } - // VALUE [] - chunks := strings.Split(string(line[:len(line)-2]), spaceStr) - if len(chunks) < 4 { - return 0, protocolError("corrupt get reply") - } - if chunks[0] != replyValueStr { - return 0, protocolError("corrupt get reply, no except VALUE") - } - item.Key = chunks[1] - flags64, err := strconv.ParseUint(chunks[2], 10, 32) - if err != nil { - return 0, err - } - item.Flags = uint32(flags64) - if size, err = strconv.Atoi(chunks[3]); err != nil { - return - } - if len(chunks) > 4 { - item.cas, err = strconv.ParseUint(chunks[4], 10, 64) - } - return -} - -func (c *conn) Touch(key string, expire int32) (err error) { +func (c *conn) DeleteContext(ctx context.Context, key string) error { if !legalKey(key) { - return pkgerr.WithStack(ErrMalformedKey) - } - line, err := c.writeReadLine("touch %s %d\r\n", key, expire) - if err != nil { - return err - } - switch { - case bytes.Equal(line, replyTouched): - return nil - case bytes.Equal(line, replyNotFound): - return ErrNotFound - default: - return pkgerr.WithStack(protocolError(string(line))) + return ErrMalformedKey } + return c.pconn.Delete(ctx, key) } -func (c *conn) Increment(key string, delta uint64) (uint64, error) { - return c.incrDecr("incr", key, delta) +func (c *conn) IncrementContext(ctx context.Context, key string, delta uint64) (uint64, error) { + if !legalKey(key) { + return 0, ErrMalformedKey + } + return c.pconn.IncrDecr(ctx, "incr", key, delta) +} + +func (c *conn) DecrementContext(ctx context.Context, key string, delta uint64) (uint64, error) { + if !legalKey(key) { + return 0, ErrMalformedKey + } + return c.pconn.IncrDecr(ctx, "decr", key, delta) +} + +func (c *conn) TouchContext(ctx context.Context, key string, seconds int32) error { + if !legalKey(key) { + return ErrMalformedKey + } + return c.pconn.Touch(ctx, key, seconds) +} + +func (c *conn) Add(item *Item) error { + return c.AddContext(context.TODO(), item) +} + +func (c *conn) Set(item *Item) error { + return c.SetContext(context.TODO(), item) +} + +func (c *conn) Replace(item *Item) error { + return c.ReplaceContext(context.TODO(), item) +} + +func (c *conn) Get(key string) (*Item, error) { + return c.GetContext(context.TODO(), key) +} + +func (c *conn) GetMulti(keys []string) (map[string]*Item, error) { + return c.GetMultiContext(context.TODO(), keys) +} + +func (c *conn) Delete(key string) error { + return c.DeleteContext(context.TODO(), key) +} + +func (c *conn) Increment(key string, delta uint64) (newValue uint64, err error) { + return c.IncrementContext(context.TODO(), key, delta) } func (c *conn) Decrement(key string, delta uint64) (newValue uint64, err error) { - return c.incrDecr("decr", key, delta) + return c.DecrementContext(context.TODO(), key, delta) } -func (c *conn) incrDecr(cmd, key string, delta uint64) (uint64, error) { - if !legalKey(key) { - return 0, pkgerr.WithStack(ErrMalformedKey) - } - line, err := c.writeReadLine("%s %s %d\r\n", cmd, key, delta) - if err != nil { - return 0, err - } - switch { - case bytes.Equal(line, replyNotFound): - return 0, ErrNotFound - case bytes.HasPrefix(line, replyClientErrorPrefix): - errMsg := line[len(replyClientErrorPrefix):] - return 0, pkgerr.WithStack(protocolError(errMsg)) - } - val, err := strconv.ParseUint(string(line[:len(line)-2]), 10, 64) - if err != nil { - return 0, err - } - return val, nil +func (c *conn) CompareAndSwap(item *Item) error { + return c.CompareAndSwapContext(context.TODO(), item) } -func (c *conn) Delete(key string) (err error) { - if !legalKey(key) { - return pkgerr.WithStack(ErrMalformedKey) - } - line, err := c.writeReadLine("delete %s\r\n", key) - if err != nil { - return err - } - switch { - case bytes.Equal(line, replyOK): - return nil - case bytes.Equal(line, replyDeleted): - return nil - case bytes.Equal(line, replyNotStored): - return ErrNotStored - case bytes.Equal(line, replyExists): - return ErrCASConflict - case bytes.Equal(line, replyNotFound): - return ErrNotFound - } - return pkgerr.WithStack(protocolError(string(line))) -} - -func (c *conn) writeReadLine(format string, args ...interface{}) ([]byte, error) { - if c.writeTimeout != 0 { - c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout)) - } - _, err := fmt.Fprintf(c.rw, format, args...) - if err != nil { - return nil, c.fatal(pkgerr.WithStack(err)) - } - if err = c.rw.Flush(); err != nil { - return nil, c.fatal(pkgerr.WithStack(err)) - } - if c.readTimeout != 0 { - c.conn.SetReadDeadline(time.Now().Add(c.readTimeout)) - } - line, err := c.rw.ReadSlice('\n') - if err != nil { - return line, c.fatal(pkgerr.WithStack(err)) - } - return line, nil +func (c *conn) Touch(key string, seconds int32) (err error) { + return c.TouchContext(context.TODO(), key, seconds) } func (c *conn) Scan(item *Item, v interface{}) (err error) { - c.ir.Reset(item.Value) - if item.Flags&FlagGzip == FlagGzip { - if err = c.gr.Reset(&c.ir); err != nil { - return - } - if err = c.decode(&c.gr, item, v); err != nil { - err = pkgerr.WithStack(err) - return - } - err = c.gr.Close() - } else { - err = c.decode(&c.ir, item, v) - } - err = pkgerr.WithStack(err) - return -} - -func (c *conn) WithContext(ctx context.Context) Conn { - // FIXME: implement WithContext - return c -} - -func (c *conn) encode(item *Item) (data []byte, err error) { - if (item.Flags | _flagEncoding) == _flagEncoding { - if item.Value == nil { - return nil, ErrItem - } - } else if item.Object == nil { - return nil, ErrItem - } - // encoding - switch { - case item.Flags&FlagGOB == FlagGOB: - c.edb.Reset() - if err = gob.NewEncoder(&c.edb).Encode(item.Object); err != nil { - return - } - data = c.edb.Bytes() - case item.Flags&FlagProtobuf == FlagProtobuf: - c.edb.Reset() - c.ped.SetBuf(c.edb.Bytes()) - pb, ok := item.Object.(proto.Message) - if !ok { - err = ErrItemObject - return - } - if err = c.ped.Marshal(pb); err != nil { - return - } - data = c.ped.Bytes() - case item.Flags&FlagJSON == FlagJSON: - c.edb.Reset() - if err = c.je.Encode(item.Object); err != nil { - return - } - data = c.edb.Bytes() - default: - data = item.Value - } - // compress - if item.Flags&FlagGzip == FlagGzip { - c.cb.Reset() - c.gw.Reset(&c.cb) - if _, err = c.gw.Write(data); err != nil { - return - } - if err = c.gw.Close(); err != nil { - return - } - data = c.cb.Bytes() - } - if len(data) > 8000000 { - err = ErrValueSize - } - return -} - -func (c *conn) decode(rd io.Reader, item *Item, v interface{}) (err error) { - var data []byte - switch { - case item.Flags&FlagGOB == FlagGOB: - err = gob.NewDecoder(rd).Decode(v) - case item.Flags&FlagJSON == FlagJSON: - c.jr.Reset(rd) - err = c.jd.Decode(v) - default: - data = item.Value - if item.Flags&FlagGzip == FlagGzip { - c.edb.Reset() - if _, err = io.Copy(&c.edb, rd); err != nil { - return - } - data = c.edb.Bytes() - } - if item.Flags&FlagProtobuf == FlagProtobuf { - m, ok := v.(proto.Message) - if !ok { - err = ErrItemObject - return - } - c.ped.SetBuf(data) - err = c.ped.Unmarshal(m) - } else { - switch v.(type) { - case *[]byte: - d := v.(*[]byte) - *d = data - case *string: - d := v.(*string) - *d = string(data) - case interface{}: - err = json.Unmarshal(data, v) - } - } - } - return -} - -func legalKey(key string) bool { - if len(key) > 250 || len(key) == 0 { - return false - } - for i := 0; i < len(key); i++ { - if key[i] <= ' ' || key[i] == 0x7f { - return false - } - } - return true + return pkgerr.WithStack(c.ed.decode(item, v)) } diff --git a/pkg/cache/memcache/conn_test.go b/pkg/cache/memcache/conn_test.go new file mode 100644 index 000000000..0d48ad702 --- /dev/null +++ b/pkg/cache/memcache/conn_test.go @@ -0,0 +1,185 @@ +package memcache + +import ( + "bytes" + "encoding/json" + "testing" + + test "github.com/bilibili/kratos/pkg/cache/memcache/test" + "github.com/gogo/protobuf/proto" +) + +func TestConnRaw(t *testing.T) { + item := &Item{ + Key: "test", + Value: []byte("test"), + Flags: FlagRAW, + Expiration: 60, + cas: 0, + } + if err := testConnASCII.Set(item); err != nil { + t.Errorf("conn.Store() error(%v)", err) + } +} + +func TestConnSerialization(t *testing.T) { + type TestObj struct { + Name string + Age int32 + } + + tests := []struct { + name string + a *Item + e error + }{ + + { + "JSON", + &Item{ + Key: "test_json", + Object: &TestObj{"json", 1}, + Expiration: 60, + Flags: FlagJSON, + }, + nil, + }, + { + "JSONGzip", + &Item{ + Key: "test_json_gzip", + Object: &TestObj{"jsongzip", 2}, + Expiration: 60, + Flags: FlagJSON | FlagGzip, + }, + nil, + }, + { + "GOB", + &Item{ + Key: "test_gob", + Object: &TestObj{"gob", 3}, + Expiration: 60, + Flags: FlagGOB, + }, + nil, + }, + { + "GOBGzip", + &Item{ + Key: "test_gob_gzip", + Object: &TestObj{"gobgzip", 4}, + Expiration: 60, + Flags: FlagGOB | FlagGzip, + }, + nil, + }, + { + "Protobuf", + &Item{ + Key: "test_protobuf", + Object: &test.TestItem{Name: "protobuf", Age: 6}, + Expiration: 60, + Flags: FlagProtobuf, + }, + nil, + }, + { + "ProtobufGzip", + &Item{ + Key: "test_protobuf_gzip", + Object: &test.TestItem{Name: "protobufgzip", Age: 7}, + Expiration: 60, + Flags: FlagProtobuf | FlagGzip, + }, + nil, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if err := testConnASCII.Set(tc.a); err != nil { + t.Fatal(err) + } + if r, err := testConnASCII.Get(tc.a.Key); err != tc.e { + t.Fatal(err) + } else { + if (tc.a.Flags & FlagProtobuf) > 0 { + var no test.TestItem + if err := testConnASCII.Scan(r, &no); err != nil { + t.Fatal(err) + } + if (tc.a.Object.(*test.TestItem).Name != no.Name) || (tc.a.Object.(*test.TestItem).Age != no.Age) { + t.Fatalf("compare failed error, %v %v", tc.a.Object.(*test.TestItem), no) + } + } else { + var no TestObj + if err := testConnASCII.Scan(r, &no); err != nil { + t.Fatal(err) + } + if (tc.a.Object.(*TestObj).Name != no.Name) || (tc.a.Object.(*TestObj).Age != no.Age) { + t.Fatalf("compare failed error, %v %v", tc.a.Object.(*TestObj), no) + } + } + + } + }) + } +} + +func BenchmarkConnJSON(b *testing.B) { + st := &struct { + Name string + Age int + }{"json", 10} + itemx := &Item{Key: "json", Object: st, Flags: FlagJSON} + var ( + eb bytes.Buffer + je *json.Encoder + ir bytes.Reader + jd *json.Decoder + jr reader + nst test.TestItem + ) + jd = json.NewDecoder(&jr) + je = json.NewEncoder(&eb) + eb.Grow(_encodeBuf) + // NOTE reuse bytes.Buffer internal buf + // DON'T concurrency call Scan + b.ResetTimer() + for i := 0; i < b.N; i++ { + eb.Reset() + if err := je.Encode(itemx.Object); err != nil { + return + } + data := eb.Bytes() + ir.Reset(data) + jr.Reset(&ir) + jd.Decode(&nst) + } +} + +func BenchmarkConnProtobuf(b *testing.B) { + st := &test.TestItem{Name: "protobuf", Age: 10} + itemx := &Item{Key: "protobuf", Object: st, Flags: FlagJSON} + var ( + eb bytes.Buffer + nst test.TestItem + ped *proto.Buffer + ) + ped = proto.NewBuffer(eb.Bytes()) + eb.Grow(_encodeBuf) + b.ResetTimer() + for i := 0; i < b.N; i++ { + ped.Reset() + pb, ok := itemx.Object.(proto.Message) + if !ok { + return + } + if err := ped.Marshal(pb); err != nil { + return + } + data := ped.Bytes() + ped.SetBuf(data) + ped.Unmarshal(&nst) + } +} diff --git a/pkg/cache/memcache/encoding.go b/pkg/cache/memcache/encoding.go new file mode 100644 index 000000000..1a386af9b --- /dev/null +++ b/pkg/cache/memcache/encoding.go @@ -0,0 +1,162 @@ +package memcache + +import ( + "bytes" + "compress/gzip" + "encoding/gob" + "encoding/json" + "io" + + "github.com/gogo/protobuf/proto" +) + +type reader struct { + io.Reader +} + +func (r *reader) Reset(rd io.Reader) { + r.Reader = rd +} + +const _encodeBuf = 4096 // 4kb + +type encodeDecode struct { + // Item Reader + ir bytes.Reader + // Compress + gr gzip.Reader + gw *gzip.Writer + cb bytes.Buffer + // Encoding + edb bytes.Buffer + // json + jr reader + jd *json.Decoder + je *json.Encoder + // protobuffer + ped *proto.Buffer +} + +func newEncodeDecoder() *encodeDecode { + ed := &encodeDecode{} + ed.jd = json.NewDecoder(&ed.jr) + ed.je = json.NewEncoder(&ed.edb) + ed.gw = gzip.NewWriter(&ed.cb) + ed.edb.Grow(_encodeBuf) + // NOTE reuse bytes.Buffer internal buf + // DON'T concurrency call Scan + ed.ped = proto.NewBuffer(ed.edb.Bytes()) + return ed +} + +func (ed *encodeDecode) encode(item *Item) (data []byte, err error) { + if (item.Flags | _flagEncoding) == _flagEncoding { + if item.Value == nil { + return nil, ErrItem + } + } else if item.Object == nil { + return nil, ErrItem + } + // encoding + switch { + case item.Flags&FlagGOB == FlagGOB: + ed.edb.Reset() + if err = gob.NewEncoder(&ed.edb).Encode(item.Object); err != nil { + return + } + data = ed.edb.Bytes() + case item.Flags&FlagProtobuf == FlagProtobuf: + ed.edb.Reset() + ed.ped.SetBuf(ed.edb.Bytes()) + pb, ok := item.Object.(proto.Message) + if !ok { + err = ErrItemObject + return + } + if err = ed.ped.Marshal(pb); err != nil { + return + } + data = ed.ped.Bytes() + case item.Flags&FlagJSON == FlagJSON: + ed.edb.Reset() + if err = ed.je.Encode(item.Object); err != nil { + return + } + data = ed.edb.Bytes() + default: + data = item.Value + } + // compress + if item.Flags&FlagGzip == FlagGzip { + ed.cb.Reset() + ed.gw.Reset(&ed.cb) + if _, err = ed.gw.Write(data); err != nil { + return + } + if err = ed.gw.Close(); err != nil { + return + } + data = ed.cb.Bytes() + } + if len(data) > 8000000 { + err = ErrValueSize + } + return +} + +func (ed *encodeDecode) decode(item *Item, v interface{}) (err error) { + var ( + data []byte + rd io.Reader + ) + ed.ir.Reset(item.Value) + rd = &ed.ir + if item.Flags&FlagGzip == FlagGzip { + rd = &ed.gr + if err = ed.gr.Reset(&ed.ir); err != nil { + return + } + defer func() { + if e := ed.gr.Close(); e != nil { + err = e + } + }() + } + switch { + case item.Flags&FlagGOB == FlagGOB: + err = gob.NewDecoder(rd).Decode(v) + case item.Flags&FlagJSON == FlagJSON: + ed.jr.Reset(rd) + err = ed.jd.Decode(v) + default: + data = item.Value + if item.Flags&FlagGzip == FlagGzip { + ed.edb.Reset() + if _, err = io.Copy(&ed.edb, rd); err != nil { + return + } + data = ed.edb.Bytes() + } + if item.Flags&FlagProtobuf == FlagProtobuf { + m, ok := v.(proto.Message) + if !ok { + err = ErrItemObject + return + } + ed.ped.SetBuf(data) + err = ed.ped.Unmarshal(m) + } else { + switch v.(type) { + case *[]byte: + d := v.(*[]byte) + *d = data + case *string: + d := v.(*string) + *d = string(data) + case interface{}: + err = json.Unmarshal(data, v) + } + } + } + return +} diff --git a/pkg/cache/memcache/encoding_test.go b/pkg/cache/memcache/encoding_test.go new file mode 100644 index 000000000..3fadac5fb --- /dev/null +++ b/pkg/cache/memcache/encoding_test.go @@ -0,0 +1,220 @@ +package memcache + +import ( + "bytes" + "testing" + + mt "github.com/bilibili/kratos/pkg/cache/memcache/test" +) + +func TestEncode(t *testing.T) { + type TestObj struct { + Name string + Age int32 + } + testObj := TestObj{"abc", 1} + + ed := newEncodeDecoder() + tests := []struct { + name string + a *Item + r []byte + e error + }{ + { + "EncodeRawFlagErrItem", + &Item{ + Object: &TestObj{"abc", 1}, + Flags: FlagRAW, + }, + []byte{}, + ErrItem, + }, + { + "EncodeEncodeFlagErrItem", + &Item{ + Value: []byte("test"), + Flags: FlagJSON, + }, + []byte{}, + ErrItem, + }, + { + "EncodeEmpty", + &Item{ + Value: []byte(""), + Flags: FlagRAW, + }, + []byte(""), + nil, + }, + { + "EncodeMaxSize", + &Item{ + Value: bytes.Repeat([]byte("A"), 8000000), + Flags: FlagRAW, + }, + bytes.Repeat([]byte("A"), 8000000), + nil, + }, + { + "EncodeExceededMaxSize", + &Item{ + Value: bytes.Repeat([]byte("A"), 8000000+1), + Flags: FlagRAW, + }, + nil, + ErrValueSize, + }, + { + "EncodeGOB", + &Item{ + Object: testObj, + Flags: FlagGOB, + }, + []byte{38, 255, 131, 3, 1, 1, 7, 84, 101, 115, 116, 79, 98, 106, 1, 255, 132, 0, 1, 2, 1, 4, 78, 97, 109, 101, 1, 12, 0, 1, 3, 65, 103, 101, 1, 4, 0, 0, 0, 10, 255, 132, 1, 3, 97, 98, 99, 1, 2, 0}, + nil, + }, + { + "EncodeJSON", + &Item{ + Object: testObj, + Flags: FlagJSON, + }, + []byte{123, 34, 78, 97, 109, 101, 34, 58, 34, 97, 98, 99, 34, 44, 34, 65, 103, 101, 34, 58, 49, 125, 10}, + nil, + }, + { + "EncodeProtobuf", + &Item{ + Object: &mt.TestItem{Name: "abc", Age: 1}, + Flags: FlagProtobuf, + }, + []byte{10, 3, 97, 98, 99, 16, 1}, + nil, + }, + { + "EncodeGzip", + &Item{ + Value: bytes.Repeat([]byte("B"), 50), + Flags: FlagGzip, + }, + []byte{31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 114, 34, 25, 0, 2, 0, 0, 255, 255, 252, 253, 67, 209, 50, 0, 0, 0}, + nil, + }, + { + "EncodeGOBGzip", + &Item{ + Object: testObj, + Flags: FlagGOB | FlagGzip, + }, + []byte{31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 82, 251, 223, 204, 204, 200, 200, 30, 146, 90, 92, 226, 159, 148, 197, 248, 191, 133, 129, 145, 137, 145, 197, 47, 49, 55, 149, 145, 135, 129, 145, 217, 49, 61, 149, 145, 133, 129, 129, 129, 235, 127, 11, 35, 115, 98, 82, 50, 35, 19, 3, 32, 0, 0, 255, 255, 211, 249, 1, 154, 50, 0, 0, 0}, + nil, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if r, err := ed.encode(test.a); err != test.e { + t.Fatal(err) + } else { + if err == nil { + if !bytes.Equal(r, test.r) { + t.Fatalf("not equal, expect %v\n got %v", test.r, r) + } + } + } + }) + } +} + +func TestDecode(t *testing.T) { + type TestObj struct { + Name string + Age int32 + } + testObj := &TestObj{"abc", 1} + + ed := newEncodeDecoder() + tests := []struct { + name string + a *Item + r interface{} + e error + }{ + { + "DecodeGOB", + &Item{ + Flags: FlagGOB, + Value: []byte{38, 255, 131, 3, 1, 1, 7, 84, 101, 115, 116, 79, 98, 106, 1, 255, 132, 0, 1, 2, 1, 4, 78, 97, 109, 101, 1, 12, 0, 1, 3, 65, 103, 101, 1, 4, 0, 0, 0, 10, 255, 132, 1, 3, 97, 98, 99, 1, 2, 0}, + }, + testObj, + nil, + }, + { + "DecodeJSON", + &Item{ + Value: []byte{123, 34, 78, 97, 109, 101, 34, 58, 34, 97, 98, 99, 34, 44, 34, 65, 103, 101, 34, 58, 49, 125, 10}, + Flags: FlagJSON, + }, + testObj, + nil, + }, + { + "DecodeProtobuf", + &Item{ + Value: []byte{10, 3, 97, 98, 99, 16, 1}, + + Flags: FlagProtobuf, + }, + &mt.TestItem{Name: "abc", Age: 1}, + nil, + }, + { + "DecodeGzip", + &Item{ + Value: []byte{31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 114, 34, 25, 0, 2, 0, 0, 255, 255, 252, 253, 67, 209, 50, 0, 0, 0}, + Flags: FlagGzip, + }, + bytes.Repeat([]byte("B"), 50), + nil, + }, + { + "DecodeGOBGzip", + &Item{ + Value: []byte{31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 82, 251, 223, 204, 204, 200, 200, 30, 146, 90, 92, 226, 159, 148, 197, 248, 191, 133, 129, 145, 137, 145, 197, 47, 49, 55, 149, 145, 135, 129, 145, 217, 49, 61, 149, 145, 133, 129, 129, 129, 235, 127, 11, 35, 115, 98, 82, 50, 35, 19, 3, 32, 0, 0, 255, 255, 211, 249, 1, 154, 50, 0, 0, 0}, + Flags: FlagGOB | FlagGzip, + }, + testObj, + nil, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if (test.a.Flags & FlagProtobuf) > 0 { + var dd mt.TestItem + if err := ed.decode(test.a, &dd); err != nil { + t.Fatal(err) + } + if (test.r.(*mt.TestItem).Name != dd.Name) || (test.r.(*mt.TestItem).Age != dd.Age) { + t.Fatalf("compare failed error, expect %v\n got %v", test.r.(*mt.TestItem), dd) + } + } else if test.a.Flags == FlagGzip { + var dd []byte + if err := ed.decode(test.a, &dd); err != nil { + t.Fatal(err) + } + if !bytes.Equal(dd, test.r.([]byte)) { + t.Fatalf("compare failed error, expect %v\n got %v", test.r, dd) + } + } else { + var dd TestObj + if err := ed.decode(test.a, &dd); err != nil { + t.Fatal(err) + } + if (test.r.(*TestObj).Name != dd.Name) || (test.r.(*TestObj).Age != dd.Age) { + t.Fatalf("compare failed error, expect %v\n got %v", test.r.(*TestObj), dd) + } + } + }) + } +} diff --git a/pkg/cache/memcache/example_test.go b/pkg/cache/memcache/example_test.go new file mode 100644 index 000000000..bab5d1f00 --- /dev/null +++ b/pkg/cache/memcache/example_test.go @@ -0,0 +1,177 @@ +package memcache + +import ( + "encoding/json" + "fmt" + "time" +) + +var testExampleAddr string + +func ExampleConn_set() { + var ( + err error + value []byte + conn Conn + expire int32 = 100 + p = struct { + Name string + Age int64 + }{"golang", 10} + ) + cnop := DialConnectTimeout(time.Duration(time.Second)) + rdop := DialReadTimeout(time.Duration(time.Second)) + wrop := DialWriteTimeout(time.Duration(time.Second)) + if value, err = json.Marshal(p); err != nil { + fmt.Println(err) + return + } + if conn, err = Dial("tcp", testExampleAddr, cnop, rdop, wrop); err != nil { + fmt.Println(err) + return + } + // FlagRAW test + itemRaw := &Item{ + Key: "test_raw", + Value: value, + Expiration: expire, + } + if err = conn.Set(itemRaw); err != nil { + fmt.Println(err) + return + } + // FlagGzip + itemGZip := &Item{ + Key: "test_gzip", + Value: value, + Flags: FlagGzip, + Expiration: expire, + } + if err = conn.Set(itemGZip); err != nil { + fmt.Println(err) + return + } + // FlagGOB + itemGOB := &Item{ + Key: "test_gob", + Object: p, + Flags: FlagGOB, + Expiration: expire, + } + if err = conn.Set(itemGOB); err != nil { + fmt.Println(err) + return + } + // FlagJSON + itemJSON := &Item{ + Key: "test_json", + Object: p, + Flags: FlagJSON, + Expiration: expire, + } + if err = conn.Set(itemJSON); err != nil { + fmt.Println(err) + return + } + // FlagJSON | FlagGzip + itemJSONGzip := &Item{ + Key: "test_jsonGzip", + Object: p, + Flags: FlagJSON | FlagGzip, + Expiration: expire, + } + if err = conn.Set(itemJSONGzip); err != nil { + fmt.Println(err) + return + } + // Output: +} + +func ExampleConn_get() { + var ( + err error + item2 *Item + conn Conn + p struct { + Name string + Age int64 + } + ) + cnop := DialConnectTimeout(time.Duration(time.Second)) + rdop := DialReadTimeout(time.Duration(time.Second)) + wrop := DialWriteTimeout(time.Duration(time.Second)) + if conn, err = Dial("tcp", testExampleAddr, cnop, rdop, wrop); err != nil { + fmt.Println(err) + return + } + if item2, err = conn.Get("test_raw"); err != nil { + fmt.Println(err) + } else { + if err = conn.Scan(item2, &p); err != nil { + fmt.Printf("FlagRAW conn.Scan error(%v)\n", err) + return + } + } + // FlagGZip + if item2, err = conn.Get("test_gzip"); err != nil { + fmt.Println(err) + } else { + if err = conn.Scan(item2, &p); err != nil { + fmt.Printf("FlagGZip conn.Scan error(%v)\n", err) + return + } + } + // FlagGOB + if item2, err = conn.Get("test_gob"); err != nil { + fmt.Println(err) + } else { + if err = conn.Scan(item2, &p); err != nil { + fmt.Printf("FlagGOB conn.Scan error(%v)\n", err) + return + } + } + // FlagJSON + if item2, err = conn.Get("test_json"); err != nil { + fmt.Println(err) + } else { + if err = conn.Scan(item2, &p); err != nil { + fmt.Printf("FlagJSON conn.Scan error(%v)\n", err) + return + } + } + // Output: +} + +func ExampleConn_getMulti() { + var ( + err error + conn Conn + res map[string]*Item + keys = []string{"test_raw", "test_gzip"} + p struct { + Name string + Age int64 + } + ) + cnop := DialConnectTimeout(time.Duration(time.Second)) + rdop := DialReadTimeout(time.Duration(time.Second)) + wrop := DialWriteTimeout(time.Duration(time.Second)) + if conn, err = Dial("tcp", testExampleAddr, cnop, rdop, wrop); err != nil { + fmt.Println(err) + return + } + if res, err = conn.GetMulti(keys); err != nil { + fmt.Printf("conn.GetMulti(%v) error(%v)", keys, err) + return + } + for _, v := range res { + if err = conn.Scan(v, &p); err != nil { + fmt.Printf("conn.Scan error(%v)\n", err) + return + } + fmt.Println(p) + } + // Output: + //{golang 10} + //{golang 10} +} diff --git a/pkg/cache/memcache/main_test.go b/pkg/cache/memcache/main_test.go new file mode 100644 index 000000000..5d40535a6 --- /dev/null +++ b/pkg/cache/memcache/main_test.go @@ -0,0 +1,85 @@ +package memcache + +import ( + "log" + "os" + "testing" + "time" + + "github.com/bilibili/kratos/pkg/container/pool" + xtime "github.com/bilibili/kratos/pkg/time" +) + +var testConnASCII Conn +var testMemcache *Memcache +var testPool *Pool +var testMemcacheAddr string + +func setupTestConnASCII(addr string) { + var err error + cnop := DialConnectTimeout(time.Duration(2 * time.Second)) + rdop := DialReadTimeout(time.Duration(2 * time.Second)) + wrop := DialWriteTimeout(time.Duration(2 * time.Second)) + testConnASCII, err = Dial("tcp", addr, cnop, rdop, wrop) + if err != nil { + log.Fatal(err) + } + testConnASCII.Delete("test") + testConnASCII.Delete("test1") + testConnASCII.Delete("test2") + if err != nil { + log.Fatal(err) + } +} + +func setupTestMemcache(addr string) { + testConfig := &Config{ + Config: &pool.Config{ + Active: 10, + Idle: 10, + IdleTimeout: xtime.Duration(time.Second), + WaitTimeout: xtime.Duration(time.Second), + Wait: false, + }, + Addr: addr, + Proto: "tcp", + DialTimeout: xtime.Duration(time.Second), + ReadTimeout: xtime.Duration(time.Second), + WriteTimeout: xtime.Duration(time.Second), + } + testMemcache = New(testConfig) +} + +func setupTestPool(addr string) { + config := &Config{ + Name: "test", + Proto: "tcp", + Addr: addr, + DialTimeout: xtime.Duration(time.Second), + ReadTimeout: xtime.Duration(time.Second), + WriteTimeout: xtime.Duration(time.Second), + } + config.Config = &pool.Config{ + Active: 10, + Idle: 5, + IdleTimeout: xtime.Duration(90 * time.Second), + } + testPool = NewPool(config) +} + +func TestMain(m *testing.M) { + testMemcacheAddr = os.Getenv("TEST_MEMCACHE_ADDR") + if testExampleAddr == "" { + log.Print("TEST_MEMCACHE_ADDR not provide skip test.") + // ignored test. + os.Exit(0) + } + setupTestConnASCII(testMemcacheAddr) + setupTestMemcache(testMemcacheAddr) + setupTestPool(testMemcacheAddr) + // TODO: add setupexample? + testExampleAddr = testMemcacheAddr + + ret := m.Run() + os.Exit(ret) +} diff --git a/pkg/cache/memcache/memcache.go b/pkg/cache/memcache/memcache.go index 2847840ff..a0b745e4b 100644 --- a/pkg/cache/memcache/memcache.go +++ b/pkg/cache/memcache/memcache.go @@ -2,13 +2,11 @@ package memcache import ( "context" + + "github.com/bilibili/kratos/pkg/container/pool" + xtime "github.com/bilibili/kratos/pkg/time" ) -// Error represents an error returned in a command reply. -type Error string - -func (err Error) Error() string { return string(err) } - const ( // Flag, 15(encoding) bit+ 17(compress) bit @@ -87,20 +85,20 @@ type Conn interface { GetMulti(keys []string) (map[string]*Item, error) // Delete deletes the item with the provided key. - // The error ErrCacheMiss is returned if the item didn't already exist in + // The error ErrNotFound is returned if the item didn't already exist in // the cache. Delete(key string) error // Increment atomically increments key by delta. The return value is the // new value after being incremented or an error. If the value didn't exist - // in memcached the error is ErrCacheMiss. The value in memcached must be + // in memcached the error is ErrNotFound. The value in memcached must be // an decimal number, or an error will be returned. // On 64-bit overflow, the new value wraps around. Increment(key string, delta uint64) (newValue uint64, err error) // Decrement atomically decrements key by delta. The return value is the // new value after being decremented or an error. If the value didn't exist - // in memcached the error is ErrCacheMiss. The value in memcached must be + // in memcached the error is ErrNotFound. The value in memcached must be // an decimal number, or an error will be returned. On underflow, the new // value is capped at zero and does not wrap around. Decrement(key string, delta uint64) (newValue uint64, err error) @@ -116,7 +114,7 @@ type Conn interface { // Touch updates the expiry for the given key. The seconds parameter is // either a Unix timestamp or, if seconds is less than 1 month, the number // of seconds into the future at which time the item will expire. - //ErrCacheMiss is returned if the key is not in the cache. The key must be + // ErrNotFound is returned if the key is not in the cache. The key must be // at most 250 bytes in length. Touch(key string, seconds int32) (err error) @@ -129,8 +127,251 @@ type Conn interface { // Scan(item *Item, v interface{}) (err error) - // WithContext return a Conn with its context changed to ctx - // the context controls the entire lifetime of Conn before you change it - // NOTE: this method is not thread-safe - WithContext(ctx context.Context) Conn + // Add writes the given item, if no value already exists for its key. + // ErrNotStored is returned if that condition is not met. + AddContext(ctx context.Context, item *Item) error + + // Set writes the given item, unconditionally. + SetContext(ctx context.Context, item *Item) error + + // Replace writes the given item, but only if the server *does* already + // hold data for this key. + ReplaceContext(ctx context.Context, item *Item) error + + // Get sends a command to the server for gets data. + GetContext(ctx context.Context, key string) (*Item, error) + + // GetMulti is a batch version of Get. The returned map from keys to items + // may have fewer elements than the input slice, due to memcache cache + // misses. Each key must be at most 250 bytes in length. + // If no error is returned, the returned map will also be non-nil. + GetMultiContext(ctx context.Context, keys []string) (map[string]*Item, error) + + // Delete deletes the item with the provided key. + // The error ErrNotFound is returned if the item didn't already exist in + // the cache. + DeleteContext(ctx context.Context, key string) error + + // Increment atomically increments key by delta. The return value is the + // new value after being incremented or an error. If the value didn't exist + // in memcached the error is ErrNotFound. The value in memcached must be + // an decimal number, or an error will be returned. + // On 64-bit overflow, the new value wraps around. + IncrementContext(ctx context.Context, key string, delta uint64) (newValue uint64, err error) + + // Decrement atomically decrements key by delta. The return value is the + // new value after being decremented or an error. If the value didn't exist + // in memcached the error is ErrNotFound. The value in memcached must be + // an decimal number, or an error will be returned. On underflow, the new + // value is capped at zero and does not wrap around. + DecrementContext(ctx context.Context, key string, delta uint64) (newValue uint64, err error) + + // CompareAndSwap writes the given item that was previously returned by + // Get, if the value was neither modified or evicted between the Get and + // the CompareAndSwap calls. The item's Key should not change between calls + // but all other item fields may differ. ErrCASConflict is returned if the + // value was modified in between the calls. + // ErrNotStored is returned if the value was evicted in between the calls. + CompareAndSwapContext(ctx context.Context, item *Item) error + + // Touch updates the expiry for the given key. The seconds parameter is + // either a Unix timestamp or, if seconds is less than 1 month, the number + // of seconds into the future at which time the item will expire. + // ErrNotFound is returned if the key is not in the cache. The key must be + // at most 250 bytes in length. + TouchContext(ctx context.Context, key string, seconds int32) (err error) +} + +// Config memcache config. +type Config struct { + *pool.Config + + Name string // memcache name, for trace + Proto string + Addr string + DialTimeout xtime.Duration + ReadTimeout xtime.Duration + WriteTimeout xtime.Duration +} + +// Memcache memcache client +type Memcache struct { + pool *Pool +} + +// Reply is the result of Get +type Reply struct { + err error + item *Item + conn Conn + closed bool +} + +// Replies is the result of GetMulti +type Replies struct { + err error + items map[string]*Item + usedItems map[string]struct{} + conn Conn + closed bool +} + +// New get a memcache client +func New(cfg *Config) *Memcache { + return &Memcache{pool: NewPool(cfg)} +} + +// Close close connection pool +func (mc *Memcache) Close() error { + return mc.pool.Close() +} + +// Conn direct get a connection +func (mc *Memcache) Conn(ctx context.Context) Conn { + return mc.pool.Get(ctx) +} + +// Set writes the given item, unconditionally. +func (mc *Memcache) Set(ctx context.Context, item *Item) (err error) { + conn := mc.pool.Get(ctx) + err = conn.SetContext(ctx, item) + conn.Close() + return +} + +// Add writes the given item, if no value already exists for its key. +// ErrNotStored is returned if that condition is not met. +func (mc *Memcache) Add(ctx context.Context, item *Item) (err error) { + conn := mc.pool.Get(ctx) + err = conn.AddContext(ctx, item) + conn.Close() + return +} + +// Replace writes the given item, but only if the server *does* already hold data for this key. +func (mc *Memcache) Replace(ctx context.Context, item *Item) (err error) { + conn := mc.pool.Get(ctx) + err = conn.ReplaceContext(ctx, item) + conn.Close() + return +} + +// CompareAndSwap writes the given item that was previously returned by Get +func (mc *Memcache) CompareAndSwap(ctx context.Context, item *Item) (err error) { + conn := mc.pool.Get(ctx) + err = conn.CompareAndSwapContext(ctx, item) + conn.Close() + return +} + +// Get sends a command to the server for gets data. +func (mc *Memcache) Get(ctx context.Context, key string) *Reply { + conn := mc.pool.Get(ctx) + item, err := conn.GetContext(ctx, key) + if err != nil { + conn.Close() + } + return &Reply{err: err, item: item, conn: conn} +} + +// Item get raw Item +func (r *Reply) Item() *Item { + return r.item +} + +// Scan converts value, read from the memcache +func (r *Reply) Scan(v interface{}) (err error) { + if r.err != nil { + return r.err + } + err = r.conn.Scan(r.item, v) + if !r.closed { + r.conn.Close() + r.closed = true + } + return +} + +// GetMulti is a batch version of Get +func (mc *Memcache) GetMulti(ctx context.Context, keys []string) (*Replies, error) { + conn := mc.pool.Get(ctx) + items, err := conn.GetMultiContext(ctx, keys) + rs := &Replies{err: err, items: items, conn: conn, usedItems: make(map[string]struct{}, len(keys))} + if (err != nil) || (len(items) == 0) { + rs.Close() + } + return rs, err +} + +// Close close rows. +func (rs *Replies) Close() (err error) { + if !rs.closed { + err = rs.conn.Close() + rs.closed = true + } + return +} + +// Item get Item from rows +func (rs *Replies) Item(key string) *Item { + return rs.items[key] +} + +// Scan converts value, read from key in rows +func (rs *Replies) Scan(key string, v interface{}) (err error) { + if rs.err != nil { + return rs.err + } + item, ok := rs.items[key] + if !ok { + rs.Close() + return ErrNotFound + } + rs.usedItems[key] = struct{}{} + err = rs.conn.Scan(item, v) + if (err != nil) || (len(rs.items) == len(rs.usedItems)) { + rs.Close() + } + return +} + +// Keys keys of result +func (rs *Replies) Keys() (keys []string) { + keys = make([]string, 0, len(rs.items)) + for key := range rs.items { + keys = append(keys, key) + } + return +} + +// Touch updates the expiry for the given key. +func (mc *Memcache) Touch(ctx context.Context, key string, timeout int32) (err error) { + conn := mc.pool.Get(ctx) + err = conn.TouchContext(ctx, key, timeout) + conn.Close() + return +} + +// Delete deletes the item with the provided key. +func (mc *Memcache) Delete(ctx context.Context, key string) (err error) { + conn := mc.pool.Get(ctx) + err = conn.DeleteContext(ctx, key) + conn.Close() + return +} + +// Increment atomically increments key by delta. +func (mc *Memcache) Increment(ctx context.Context, key string, delta uint64) (newValue uint64, err error) { + conn := mc.pool.Get(ctx) + newValue, err = conn.IncrementContext(ctx, key, delta) + conn.Close() + return +} + +// Decrement atomically decrements key by delta. +func (mc *Memcache) Decrement(ctx context.Context, key string, delta uint64) (newValue uint64, err error) { + conn := mc.pool.Get(ctx) + newValue, err = conn.DecrementContext(ctx, key, delta) + conn.Close() + return } diff --git a/pkg/cache/memcache/memcache_test.go b/pkg/cache/memcache/memcache_test.go new file mode 100644 index 000000000..878841c6a --- /dev/null +++ b/pkg/cache/memcache/memcache_test.go @@ -0,0 +1,300 @@ +package memcache + +import ( + "context" + "fmt" + "reflect" + "testing" + "time" +) + +func Test_client_Set(t *testing.T) { + type args struct { + c context.Context + item *Item + } + tests := []struct { + name string + args args + wantErr bool + }{ + {name: "set value", args: args{c: context.Background(), item: &Item{Key: "Test_client_Set", Value: []byte("abc")}}, wantErr: false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := testMemcache.Set(tt.args.c, tt.args.item); (err != nil) != tt.wantErr { + t.Errorf("client.Set() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func Test_client_Add(t *testing.T) { + type args struct { + c context.Context + item *Item + } + key := fmt.Sprintf("Test_client_Add_%d", time.Now().Unix()) + tests := []struct { + name string + args args + wantErr bool + }{ + {name: "add not exist value", args: args{c: context.Background(), item: &Item{Key: key, Value: []byte("abc")}}, wantErr: false}, + {name: "add exist value", args: args{c: context.Background(), item: &Item{Key: key, Value: []byte("abc")}}, wantErr: true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := testMemcache.Add(tt.args.c, tt.args.item); (err != nil) != tt.wantErr { + t.Errorf("client.Add() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func Test_client_Replace(t *testing.T) { + key := fmt.Sprintf("Test_client_Replace_%d", time.Now().Unix()) + ekey := "Test_client_Replace_exist" + testMemcache.Set(context.Background(), &Item{Key: ekey, Value: []byte("ok")}) + type args struct { + c context.Context + item *Item + } + tests := []struct { + name string + args args + wantErr bool + }{ + {name: "not exist value", args: args{c: context.Background(), item: &Item{Key: key, Value: []byte("abc")}}, wantErr: true}, + {name: "exist value", args: args{c: context.Background(), item: &Item{Key: ekey, Value: []byte("abc")}}, wantErr: false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := testMemcache.Replace(tt.args.c, tt.args.item); (err != nil) != tt.wantErr { + t.Errorf("client.Replace() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func Test_client_CompareAndSwap(t *testing.T) { + key := fmt.Sprintf("Test_client_CompareAndSwap_%d", time.Now().Unix()) + ekey := "Test_client_CompareAndSwap_k" + testMemcache.Set(context.Background(), &Item{Key: ekey, Value: []byte("old")}) + cas := testMemcache.Get(context.Background(), ekey).Item().cas + type args struct { + c context.Context + item *Item + } + tests := []struct { + name string + args args + wantErr bool + }{ + {name: "not exist value", args: args{c: context.Background(), item: &Item{Key: key, Value: []byte("abc")}}, wantErr: true}, + {name: "exist value", args: args{c: context.Background(), item: &Item{Key: ekey, cas: cas, Value: []byte("abc")}}, wantErr: false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := testMemcache.CompareAndSwap(tt.args.c, tt.args.item); (err != nil) != tt.wantErr { + t.Errorf("client.CompareAndSwap() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func Test_client_Get(t *testing.T) { + key := fmt.Sprintf("Test_client_Get_%d", time.Now().Unix()) + ekey := "Test_client_Get_k" + testMemcache.Set(context.Background(), &Item{Key: ekey, Value: []byte("old")}) + type args struct { + c context.Context + key string + } + tests := []struct { + name string + args args + want string + wantErr bool + }{ + {name: "not exist value", args: args{c: context.Background(), key: key}, wantErr: true}, + {name: "exist value", args: args{c: context.Background(), key: ekey}, wantErr: false, want: "old"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var res string + if err := testMemcache.Get(tt.args.c, tt.args.key).Scan(&res); (err != nil) != tt.wantErr || res != tt.want { + t.Errorf("client.Get() = %v, want %v, got err: %v, want err: %v", err, tt.want, err, tt.wantErr) + } + }) + } +} + +func Test_client_Touch(t *testing.T) { + key := fmt.Sprintf("Test_client_Touch_%d", time.Now().Unix()) + ekey := "Test_client_Touch_k" + testMemcache.Set(context.Background(), &Item{Key: ekey, Value: []byte("old")}) + type args struct { + c context.Context + key string + timeout int32 + } + tests := []struct { + name string + args args + wantErr bool + }{ + {name: "not exist value", args: args{c: context.Background(), key: key, timeout: 100000}, wantErr: true}, + {name: "exist value", args: args{c: context.Background(), key: ekey, timeout: 100000}, wantErr: false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := testMemcache.Touch(tt.args.c, tt.args.key, tt.args.timeout); (err != nil) != tt.wantErr { + t.Errorf("client.Touch() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func Test_client_Delete(t *testing.T) { + key := fmt.Sprintf("Test_client_Delete_%d", time.Now().Unix()) + ekey := "Test_client_Delete_k" + testMemcache.Set(context.Background(), &Item{Key: ekey, Value: []byte("old")}) + type args struct { + c context.Context + key string + } + tests := []struct { + name string + args args + wantErr bool + }{ + {name: "not exist value", args: args{c: context.Background(), key: key}, wantErr: true}, + {name: "exist value", args: args{c: context.Background(), key: ekey}, wantErr: false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := testMemcache.Delete(tt.args.c, tt.args.key); (err != nil) != tt.wantErr { + t.Errorf("client.Delete() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func Test_client_Increment(t *testing.T) { + key := fmt.Sprintf("Test_client_Increment_%d", time.Now().Unix()) + ekey := "Test_client_Increment_k" + testMemcache.Set(context.Background(), &Item{Key: ekey, Value: []byte("1")}) + type args struct { + c context.Context + key string + delta uint64 + } + tests := []struct { + name string + args args + wantNewValue uint64 + wantErr bool + }{ + {name: "not exist value", args: args{c: context.Background(), key: key, delta: 10}, wantErr: true}, + {name: "exist value", args: args{c: context.Background(), key: ekey, delta: 10}, wantErr: false, wantNewValue: 11}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotNewValue, err := testMemcache.Increment(tt.args.c, tt.args.key, tt.args.delta) + if (err != nil) != tt.wantErr { + t.Errorf("client.Increment() error = %v, wantErr %v", err, tt.wantErr) + return + } + if gotNewValue != tt.wantNewValue { + t.Errorf("client.Increment() = %v, want %v", gotNewValue, tt.wantNewValue) + } + }) + } +} + +func Test_client_Decrement(t *testing.T) { + key := fmt.Sprintf("Test_client_Decrement_%d", time.Now().Unix()) + ekey := "Test_client_Decrement_k" + testMemcache.Set(context.Background(), &Item{Key: ekey, Value: []byte("100")}) + type args struct { + c context.Context + key string + delta uint64 + } + tests := []struct { + name string + args args + wantNewValue uint64 + wantErr bool + }{ + {name: "not exist value", args: args{c: context.Background(), key: key, delta: 10}, wantErr: true}, + {name: "exist value", args: args{c: context.Background(), key: ekey, delta: 10}, wantErr: false, wantNewValue: 90}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotNewValue, err := testMemcache.Decrement(tt.args.c, tt.args.key, tt.args.delta) + if (err != nil) != tt.wantErr { + t.Errorf("client.Decrement() error = %v, wantErr %v", err, tt.wantErr) + return + } + if gotNewValue != tt.wantNewValue { + t.Errorf("client.Decrement() = %v, want %v", gotNewValue, tt.wantNewValue) + } + }) + } +} + +func Test_client_GetMulti(t *testing.T) { + key := fmt.Sprintf("Test_client_GetMulti_%d", time.Now().Unix()) + ekey1 := "Test_client_GetMulti_k1" + ekey2 := "Test_client_GetMulti_k2" + testMemcache.Set(context.Background(), &Item{Key: ekey1, Value: []byte("1")}) + testMemcache.Set(context.Background(), &Item{Key: ekey2, Value: []byte("2")}) + keys := []string{key, ekey1, ekey2} + rows, err := testMemcache.GetMulti(context.Background(), keys) + if err != nil { + t.Errorf("client.GetMulti() error = %v, wantErr %v", err, nil) + } + tests := []struct { + key string + wantNewValue string + wantErr bool + nilItem bool + }{ + {key: ekey1, wantErr: false, wantNewValue: "1", nilItem: false}, + {key: ekey2, wantErr: false, wantNewValue: "2", nilItem: false}, + {key: key, wantErr: true, nilItem: true}, + } + if reflect.DeepEqual(keys, rows.Keys()) { + t.Errorf("got %v, expect: %v", rows.Keys(), keys) + } + for _, tt := range tests { + t.Run(tt.key, func(t *testing.T) { + var gotNewValue string + err = rows.Scan(tt.key, &gotNewValue) + if (err != nil) != tt.wantErr { + t.Errorf("rows.Scan() error = %v, wantErr %v", err, tt.wantErr) + return + } + if gotNewValue != tt.wantNewValue { + t.Errorf("rows.Value() = %v, want %v", gotNewValue, tt.wantNewValue) + } + if (rows.Item(tt.key) == nil) != tt.nilItem { + t.Errorf("rows.Item() = %v, want %v", rows.Item(tt.key) == nil, tt.nilItem) + } + }) + } + err = rows.Close() + if err != nil { + t.Errorf("client.Replies.Close() error = %v, wantErr %v", err, nil) + } +} + +func Test_client_Conn(t *testing.T) { + conn := testMemcache.Conn(context.Background()) + defer conn.Close() + if conn == nil { + t.Errorf("expect get conn, get nil") + } +} diff --git a/pkg/cache/memcache/mock.go b/pkg/cache/memcache/mock.go deleted file mode 100644 index 192b6cc1d..000000000 --- a/pkg/cache/memcache/mock.go +++ /dev/null @@ -1,59 +0,0 @@ -package memcache - -import ( - "context" -) - -// MockErr for unit test. -type MockErr struct { - Error error -} - -var _ Conn = MockErr{} - -// MockWith return a mock conn. -func MockWith(err error) MockErr { - return MockErr{Error: err} -} - -// Err . -func (m MockErr) Err() error { return m.Error } - -// Close . -func (m MockErr) Close() error { return m.Error } - -// Add . -func (m MockErr) Add(item *Item) error { return m.Error } - -// Set . -func (m MockErr) Set(item *Item) error { return m.Error } - -// Replace . -func (m MockErr) Replace(item *Item) error { return m.Error } - -// CompareAndSwap . -func (m MockErr) CompareAndSwap(item *Item) error { return m.Error } - -// Get . -func (m MockErr) Get(key string) (*Item, error) { return nil, m.Error } - -// GetMulti . -func (m MockErr) GetMulti(keys []string) (map[string]*Item, error) { return nil, m.Error } - -// Touch . -func (m MockErr) Touch(key string, timeout int32) error { return m.Error } - -// Delete . -func (m MockErr) Delete(key string) error { return m.Error } - -// Increment . -func (m MockErr) Increment(key string, delta uint64) (uint64, error) { return 0, m.Error } - -// Decrement . -func (m MockErr) Decrement(key string, delta uint64) (uint64, error) { return 0, m.Error } - -// Scan . -func (m MockErr) Scan(item *Item, v interface{}) error { return m.Error } - -// WithContext . -func (m MockErr) WithContext(ctx context.Context) Conn { return m } diff --git a/pkg/cache/memcache/pool.go b/pkg/cache/memcache/pool.go deleted file mode 100644 index 52e8fb4d9..000000000 --- a/pkg/cache/memcache/pool.go +++ /dev/null @@ -1,197 +0,0 @@ -package memcache - -import ( - "context" - "io" - "time" - - "github.com/bilibili/kratos/pkg/container/pool" - "github.com/bilibili/kratos/pkg/stat" - xtime "github.com/bilibili/kratos/pkg/time" -) - -var stats = stat.Cache - -// Config memcache config. -type Config struct { - *pool.Config - - Name string // memcache name, for trace - Proto string - Addr string - DialTimeout xtime.Duration - ReadTimeout xtime.Duration - WriteTimeout xtime.Duration -} - -// Pool memcache connection pool struct. -type Pool struct { - p pool.Pool - c *Config -} - -// NewPool new a memcache conn pool. -func NewPool(c *Config) (p *Pool) { - if c.DialTimeout <= 0 || c.ReadTimeout <= 0 || c.WriteTimeout <= 0 { - panic("must config memcache timeout") - } - p1 := pool.NewList(c.Config) - cnop := DialConnectTimeout(time.Duration(c.DialTimeout)) - rdop := DialReadTimeout(time.Duration(c.ReadTimeout)) - wrop := DialWriteTimeout(time.Duration(c.WriteTimeout)) - p1.New = func(ctx context.Context) (io.Closer, error) { - conn, err := Dial(c.Proto, c.Addr, cnop, rdop, wrop) - return &traceConn{Conn: conn, address: c.Addr}, err - } - p = &Pool{p: p1, c: c} - return -} - -// Get gets a connection. The application must close the returned connection. -// This method always returns a valid connection so that applications can defer -// error handling to the first use of the connection. If there is an error -// getting an underlying connection, then the connection Err, Do, Send, Flush -// and Receive methods return that error. -func (p *Pool) Get(ctx context.Context) Conn { - c, err := p.p.Get(ctx) - if err != nil { - return errorConnection{err} - } - c1, _ := c.(Conn) - return &pooledConnection{p: p, c: c1.WithContext(ctx), ctx: ctx} -} - -// Close release the resources used by the pool. -func (p *Pool) Close() error { - return p.p.Close() -} - -type pooledConnection struct { - p *Pool - c Conn - ctx context.Context -} - -func pstat(key string, t time.Time, err error) { - stats.Timing(key, int64(time.Since(t)/time.Millisecond)) - if err != nil { - if msg := formatErr(err); msg != "" { - stats.Incr("memcache", msg) - } - } -} - -func (pc *pooledConnection) Close() error { - c := pc.c - if _, ok := c.(errorConnection); ok { - return nil - } - pc.c = errorConnection{ErrConnClosed} - pc.p.p.Put(context.Background(), c, c.Err() != nil) - return nil -} - -func (pc *pooledConnection) Err() error { - return pc.c.Err() -} - -func (pc *pooledConnection) Set(item *Item) (err error) { - now := time.Now() - err = pc.c.Set(item) - pstat("memcache:set", now, err) - return -} - -func (pc *pooledConnection) Add(item *Item) (err error) { - now := time.Now() - err = pc.c.Add(item) - pstat("memcache:add", now, err) - return -} - -func (pc *pooledConnection) Replace(item *Item) (err error) { - now := time.Now() - err = pc.c.Replace(item) - pstat("memcache:replace", now, err) - return -} - -func (pc *pooledConnection) CompareAndSwap(item *Item) (err error) { - now := time.Now() - err = pc.c.CompareAndSwap(item) - pstat("memcache:cas", now, err) - return -} - -func (pc *pooledConnection) Get(key string) (r *Item, err error) { - now := time.Now() - r, err = pc.c.Get(key) - pstat("memcache:get", now, err) - return -} - -func (pc *pooledConnection) GetMulti(keys []string) (res map[string]*Item, err error) { - // if keys is empty slice returns empty map direct - if len(keys) == 0 { - return make(map[string]*Item), nil - } - now := time.Now() - res, err = pc.c.GetMulti(keys) - pstat("memcache:gets", now, err) - return -} - -func (pc *pooledConnection) Touch(key string, timeout int32) (err error) { - now := time.Now() - err = pc.c.Touch(key, timeout) - pstat("memcache:touch", now, err) - return -} - -func (pc *pooledConnection) Scan(item *Item, v interface{}) error { - return pc.c.Scan(item, v) -} - -func (pc *pooledConnection) WithContext(ctx context.Context) Conn { - // TODO: set context - pc.ctx = ctx - return pc -} - -func (pc *pooledConnection) Delete(key string) (err error) { - now := time.Now() - err = pc.c.Delete(key) - pstat("memcache:delete", now, err) - return -} - -func (pc *pooledConnection) Increment(key string, delta uint64) (newValue uint64, err error) { - now := time.Now() - newValue, err = pc.c.Increment(key, delta) - pstat("memcache:increment", now, err) - return -} - -func (pc *pooledConnection) Decrement(key string, delta uint64) (newValue uint64, err error) { - now := time.Now() - newValue, err = pc.c.Decrement(key, delta) - pstat("memcache:decrement", now, err) - return -} - -type errorConnection struct{ err error } - -func (ec errorConnection) Err() error { return ec.err } -func (ec errorConnection) Close() error { return ec.err } -func (ec errorConnection) Add(item *Item) error { return ec.err } -func (ec errorConnection) Set(item *Item) error { return ec.err } -func (ec errorConnection) Replace(item *Item) error { return ec.err } -func (ec errorConnection) CompareAndSwap(item *Item) error { return ec.err } -func (ec errorConnection) Get(key string) (*Item, error) { return nil, ec.err } -func (ec errorConnection) GetMulti(keys []string) (map[string]*Item, error) { return nil, ec.err } -func (ec errorConnection) Touch(key string, timeout int32) error { return ec.err } -func (ec errorConnection) Delete(key string) error { return ec.err } -func (ec errorConnection) Increment(key string, delta uint64) (uint64, error) { return 0, ec.err } -func (ec errorConnection) Decrement(key string, delta uint64) (uint64, error) { return 0, ec.err } -func (ec errorConnection) Scan(item *Item, v interface{}) error { return ec.err } -func (ec errorConnection) WithContext(ctx context.Context) Conn { return ec } diff --git a/pkg/cache/memcache/pool_conn.go b/pkg/cache/memcache/pool_conn.go new file mode 100644 index 000000000..4ccff2aac --- /dev/null +++ b/pkg/cache/memcache/pool_conn.go @@ -0,0 +1,204 @@ +package memcache + +import ( + "context" + "fmt" + "io" + "time" + + "github.com/bilibili/kratos/pkg/container/pool" + "github.com/bilibili/kratos/pkg/stat" +) + +var stats = stat.Cache + +// Pool memcache connection pool struct. +// Deprecated: Use Memcache instead +type Pool struct { + p pool.Pool + c *Config +} + +// NewPool new a memcache conn pool. +// Deprecated: Use New instead +func NewPool(cfg *Config) (p *Pool) { + if cfg.DialTimeout <= 0 || cfg.ReadTimeout <= 0 || cfg.WriteTimeout <= 0 { + panic("must config memcache timeout") + } + p1 := pool.NewList(cfg.Config) + cnop := DialConnectTimeout(time.Duration(cfg.DialTimeout)) + rdop := DialReadTimeout(time.Duration(cfg.ReadTimeout)) + wrop := DialWriteTimeout(time.Duration(cfg.WriteTimeout)) + p1.New = func(ctx context.Context) (io.Closer, error) { + conn, err := Dial(cfg.Proto, cfg.Addr, cnop, rdop, wrop) + return newTraceConn(conn, fmt.Sprintf("%s://%s", cfg.Proto, cfg.Addr)), err + } + p = &Pool{p: p1, c: cfg} + return +} + +// Get gets a connection. The application must close the returned connection. +// This method always returns a valid connection so that applications can defer +// error handling to the first use of the connection. If there is an error +// getting an underlying connection, then the connection Err, Do, Send, Flush +// and Receive methods return that error. +func (p *Pool) Get(ctx context.Context) Conn { + c, err := p.p.Get(ctx) + if err != nil { + return errConn{err} + } + c1, _ := c.(Conn) + return &poolConn{p: p, c: c1, ctx: ctx} +} + +// Close release the resources used by the pool. +func (p *Pool) Close() error { + return p.p.Close() +} + +type poolConn struct { + c Conn + p *Pool + ctx context.Context +} + +func pstat(key string, t time.Time, err error) { + stats.Timing(key, int64(time.Since(t)/time.Millisecond)) + if err != nil { + if msg := formatErr(err); msg != "" { + stats.Incr("memcache", msg) + } + } +} + +func (pc *poolConn) Close() error { + c := pc.c + if _, ok := c.(errConn); ok { + return nil + } + pc.c = errConn{ErrConnClosed} + pc.p.p.Put(context.Background(), c, c.Err() != nil) + return nil +} + +func (pc *poolConn) Err() error { + return pc.c.Err() +} + +func (pc *poolConn) Set(item *Item) (err error) { + return pc.c.SetContext(pc.ctx, item) +} + +func (pc *poolConn) Add(item *Item) (err error) { + return pc.AddContext(pc.ctx, item) +} + +func (pc *poolConn) Replace(item *Item) (err error) { + return pc.ReplaceContext(pc.ctx, item) +} + +func (pc *poolConn) CompareAndSwap(item *Item) (err error) { + return pc.CompareAndSwapContext(pc.ctx, item) +} + +func (pc *poolConn) Get(key string) (r *Item, err error) { + return pc.c.GetContext(pc.ctx, key) +} + +func (pc *poolConn) GetMulti(keys []string) (res map[string]*Item, err error) { + return pc.c.GetMultiContext(pc.ctx, keys) +} + +func (pc *poolConn) Touch(key string, timeout int32) (err error) { + return pc.c.TouchContext(pc.ctx, key, timeout) +} + +func (pc *poolConn) Scan(item *Item, v interface{}) error { + return pc.c.Scan(item, v) +} + +func (pc *poolConn) Delete(key string) (err error) { + return pc.c.DeleteContext(pc.ctx, key) +} + +func (pc *poolConn) Increment(key string, delta uint64) (newValue uint64, err error) { + return pc.c.IncrementContext(pc.ctx, key, delta) +} + +func (pc *poolConn) Decrement(key string, delta uint64) (newValue uint64, err error) { + return pc.c.DecrementContext(pc.ctx, key, delta) +} + +func (pc *poolConn) AddContext(ctx context.Context, item *Item) error { + now := time.Now() + err := pc.c.AddContext(ctx, item) + pstat("memcache:add", now, err) + return err +} + +func (pc *poolConn) SetContext(ctx context.Context, item *Item) error { + now := time.Now() + err := pc.c.SetContext(ctx, item) + pstat("memcache:set", now, err) + return err +} + +func (pc *poolConn) ReplaceContext(ctx context.Context, item *Item) error { + now := time.Now() + err := pc.c.ReplaceContext(ctx, item) + pstat("memcache:replace", now, err) + return err +} + +func (pc *poolConn) GetContext(ctx context.Context, key string) (*Item, error) { + now := time.Now() + item, err := pc.c.Get(key) + pstat("memcache:get", now, err) + return item, err +} + +func (pc *poolConn) GetMultiContext(ctx context.Context, keys []string) (map[string]*Item, error) { + // if keys is empty slice returns empty map direct + if len(keys) == 0 { + return make(map[string]*Item), nil + } + now := time.Now() + items, err := pc.c.GetMulti(keys) + pstat("memcache:gets", now, err) + return items, err +} + +func (pc *poolConn) DeleteContext(ctx context.Context, key string) error { + now := time.Now() + err := pc.c.Delete(key) + pstat("memcache:delete", now, err) + return err +} + +func (pc *poolConn) IncrementContext(ctx context.Context, key string, delta uint64) (uint64, error) { + now := time.Now() + newValue, err := pc.c.IncrementContext(ctx, key, delta) + pstat("memcache:increment", now, err) + return newValue, err +} + +func (pc *poolConn) DecrementContext(ctx context.Context, key string, delta uint64) (uint64, error) { + now := time.Now() + newValue, err := pc.c.DecrementContext(ctx, key, delta) + pstat("memcache:decrement", now, err) + return newValue, err +} + +func (pc *poolConn) CompareAndSwapContext(ctx context.Context, item *Item) error { + now := time.Now() + err := pc.c.CompareAndSwap(item) + pstat("memcache:cas", now, err) + return err +} + +func (pc *poolConn) TouchContext(ctx context.Context, key string, seconds int32) error { + now := time.Now() + err := pc.c.Touch(key, seconds) + pstat("memcache:touch", now, err) + return err +} diff --git a/pkg/cache/memcache/pool_conn_test.go b/pkg/cache/memcache/pool_conn_test.go new file mode 100644 index 000000000..a61bdb697 --- /dev/null +++ b/pkg/cache/memcache/pool_conn_test.go @@ -0,0 +1,545 @@ +package memcache + +import ( + "bytes" + "context" + "reflect" + "testing" + "time" + + "github.com/bilibili/kratos/pkg/container/pool" + xtime "github.com/bilibili/kratos/pkg/time" +) + +var itempool = &Item{ + Key: "testpool", + Value: []byte("testpool"), + Flags: 0, + Expiration: 60, + cas: 0, +} +var itempool2 = &Item{ + Key: "test_count", + Value: []byte("0"), + Flags: 0, + Expiration: 1000, + cas: 0, +} + +type testObject struct { + Mid int64 + Value []byte +} + +var largeValue = &Item{ + Key: "large_value", + Flags: FlagGOB | FlagGzip, + Expiration: 1000, + cas: 0, +} + +var largeValueBoundary = &Item{ + Key: "large_value", + Flags: FlagGOB | FlagGzip, + Expiration: 1000, + cas: 0, +} + +func TestPoolSet(t *testing.T) { + conn := testPool.Get(context.Background()) + defer conn.Close() + // set + if err := conn.Set(itempool); err != nil { + t.Errorf("memcache: set error(%v)", err) + } else { + t.Logf("memcache: set value: %s", itempool.Value) + } + if err := conn.Close(); err != nil { + t.Errorf("memcache: close error(%v)", err) + } +} + +func TestPoolGet(t *testing.T) { + key := "testpool" + conn := testPool.Get(context.Background()) + defer conn.Close() + // get + if res, err := conn.Get(key); err != nil { + t.Errorf("memcache: get error(%v)", err) + } else { + t.Logf("memcache: get value: %s", res.Value) + } + if _, err := conn.Get("not_found"); err != ErrNotFound { + t.Errorf("memcache: expceted err is not found but got: %v", err) + } + if err := conn.Close(); err != nil { + t.Errorf("memcache: close error(%v)", err) + } +} + +func TestPoolGetMulti(t *testing.T) { + conn := testPool.Get(context.Background()) + defer conn.Close() + s := []string{"testpool", "test1"} + // get + if res, err := conn.GetMulti(s); err != nil { + t.Errorf("memcache: gets error(%v)", err) + } else { + t.Logf("memcache: gets value: %d", len(res)) + } + if err := conn.Close(); err != nil { + t.Errorf("memcache: close error(%v)", err) + } +} + +func TestPoolTouch(t *testing.T) { + key := "testpool" + conn := testPool.Get(context.Background()) + defer conn.Close() + // touch + if err := conn.Touch(key, 10); err != nil { + t.Errorf("memcache: touch error(%v)", err) + } + if err := conn.Close(); err != nil { + t.Errorf("memcache: close error(%v)", err) + } +} + +func TestPoolIncrement(t *testing.T) { + key := "test_count" + conn := testPool.Get(context.Background()) + defer conn.Close() + // set + if err := conn.Set(itempool2); err != nil { + t.Errorf("memcache: set error(%v)", err) + } else { + t.Logf("memcache: set value: 0") + } + // incr + if res, err := conn.Increment(key, 1); err != nil { + t.Errorf("memcache: incr error(%v)", err) + } else { + t.Logf("memcache: incr n: %d", res) + if res != 1 { + t.Errorf("memcache: expected res=1 but got %d", res) + } + } + // decr + if res, err := conn.Decrement(key, 1); err != nil { + t.Errorf("memcache: decr error(%v)", err) + } else { + t.Logf("memcache: decr n: %d", res) + if res != 0 { + t.Errorf("memcache: expected res=0 but got %d", res) + } + } + if err := conn.Close(); err != nil { + t.Errorf("memcache: close error(%v)", err) + } +} + +func TestPoolErr(t *testing.T) { + conn := testPool.Get(context.Background()) + defer conn.Close() + if err := conn.Close(); err != nil { + t.Errorf("memcache: close error(%v)", err) + } + if err := conn.Err(); err == nil { + t.Errorf("memcache: err not nil") + } else { + t.Logf("memcache: err: %v", err) + } +} + +func TestPoolCompareAndSwap(t *testing.T) { + conn := testPool.Get(context.Background()) + defer conn.Close() + key := "testpool" + //cas + if r, err := conn.Get(key); err != nil { + t.Errorf("conn.Get() error(%v)", err) + } else { + r.Value = []byte("shit") + if err := conn.CompareAndSwap(r); err != nil { + t.Errorf("conn.Get() error(%v)", err) + } + r, _ := conn.Get("testpool") + if r.Key != "testpool" || !bytes.Equal(r.Value, []byte("shit")) || r.Flags != 0 { + t.Error("conn.Get() error, value") + } + if err := conn.Close(); err != nil { + t.Errorf("memcache: close error(%v)", err) + } + } +} + +func TestPoolDel(t *testing.T) { + key := "testpool" + conn := testPool.Get(context.Background()) + defer conn.Close() + // delete + if err := conn.Delete(key); err != nil { + t.Errorf("memcache: delete error(%v)", err) + } else { + t.Logf("memcache: delete key: %s", key) + } + if err := conn.Close(); err != nil { + t.Errorf("memcache: close error(%v)", err) + } +} + +func BenchmarkMemcache(b *testing.B) { + c := &Config{ + Name: "test", + Proto: "tcp", + Addr: testMemcacheAddr, + DialTimeout: xtime.Duration(time.Second), + ReadTimeout: xtime.Duration(time.Second), + WriteTimeout: xtime.Duration(time.Second), + } + c.Config = &pool.Config{ + Active: 10, + Idle: 5, + IdleTimeout: xtime.Duration(90 * time.Second), + } + testPool = NewPool(c) + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + conn := testPool.Get(context.Background()) + if err := conn.Close(); err != nil { + b.Errorf("memcache: close error(%v)", err) + } + } + }) + if err := testPool.Close(); err != nil { + b.Errorf("memcache: close error(%v)", err) + } +} + +func TestPoolSetLargeValue(t *testing.T) { + var b bytes.Buffer + for i := 0; i < 4000000; i++ { + b.WriteByte(1) + } + obj := &testObject{} + obj.Mid = 1000 + obj.Value = b.Bytes() + largeValue.Object = obj + conn := testPool.Get(context.Background()) + defer conn.Close() + // set + if err := conn.Set(largeValue); err != nil { + t.Errorf("memcache: set error(%v)", err) + } + if err := conn.Close(); err != nil { + t.Errorf("memcache: close error(%v)", err) + } +} + +func TestPoolGetLargeValue(t *testing.T) { + key := largeValue.Key + conn := testPool.Get(context.Background()) + defer conn.Close() + // get + var err error + if _, err = conn.Get(key); err != nil { + t.Errorf("memcache: large get error(%+v)", err) + } +} + +func TestPoolGetMultiLargeValue(t *testing.T) { + conn := testPool.Get(context.Background()) + defer conn.Close() + s := []string{largeValue.Key, largeValue.Key} + // get + if res, err := conn.GetMulti(s); err != nil { + t.Errorf("memcache: gets error(%v)", err) + } else { + t.Logf("memcache: gets value: %d", len(res)) + } + if err := conn.Close(); err != nil { + t.Errorf("memcache: close error(%v)", err) + } +} + +func TestPoolSetLargeValueBoundary(t *testing.T) { + var b bytes.Buffer + for i := 0; i < _largeValue; i++ { + b.WriteByte(1) + } + obj := &testObject{} + obj.Mid = 1000 + obj.Value = b.Bytes() + largeValueBoundary.Object = obj + conn := testPool.Get(context.Background()) + defer conn.Close() + // set + if err := conn.Set(largeValueBoundary); err != nil { + t.Errorf("memcache: set error(%v)", err) + } + if err := conn.Close(); err != nil { + t.Errorf("memcache: close error(%v)", err) + } +} + +func TestPoolGetLargeValueBoundary(t *testing.T) { + key := largeValueBoundary.Key + conn := testPool.Get(context.Background()) + defer conn.Close() + // get + var err error + if _, err = conn.Get(key); err != nil { + t.Errorf("memcache: large get error(%v)", err) + } +} + +func TestPoolAdd(t *testing.T) { + var ( + key = "test_add" + item = &Item{ + Key: key, + Value: []byte("0"), + Flags: 0, + Expiration: 60, + cas: 0, + } + conn = testPool.Get(context.Background()) + ) + defer conn.Close() + conn.Delete(key) + if err := conn.Add(item); err != nil { + t.Errorf("memcache: add error(%v)", err) + } + if err := conn.Add(item); err != ErrNotStored { + t.Errorf("memcache: add error(%v)", err) + } +} + +func TestNewPool(t *testing.T) { + type args struct { + cfg *Config + } + tests := []struct { + name string + args args + wantErr error + wantPanic bool + }{ + { + "NewPoolIllegalDialTimeout", + args{ + &Config{ + Name: "test_illegal_dial_timeout", + Proto: "tcp", + Addr: testMemcacheAddr, + DialTimeout: xtime.Duration(-time.Second), + ReadTimeout: xtime.Duration(time.Second), + WriteTimeout: xtime.Duration(time.Second), + }, + }, + nil, + true, + }, + { + "NewPoolIllegalReadTimeout", + args{ + &Config{ + Name: "test_illegal_read_timeout", + Proto: "tcp", + Addr: testMemcacheAddr, + DialTimeout: xtime.Duration(time.Second), + ReadTimeout: xtime.Duration(-time.Second), + WriteTimeout: xtime.Duration(time.Second), + }, + }, + nil, + true, + }, + { + "NewPoolIllegalWriteTimeout", + args{ + &Config{ + Name: "test_illegal_write_timeout", + Proto: "tcp", + Addr: testMemcacheAddr, + DialTimeout: xtime.Duration(time.Second), + ReadTimeout: xtime.Duration(time.Second), + WriteTimeout: xtime.Duration(-time.Second), + }, + }, + nil, + true, + }, + { + "NewPool", + args{ + &Config{ + Name: "test_new", + Proto: "tcp", + Addr: testMemcacheAddr, + DialTimeout: xtime.Duration(time.Second), + ReadTimeout: xtime.Duration(time.Second), + WriteTimeout: xtime.Duration(time.Second), + }, + }, + nil, + true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + r := recover() + if (r != nil) != tt.wantPanic { + t.Errorf("wantPanic recover = %v, wantPanic = %v", r, tt.wantPanic) + } + }() + + if gotP := NewPool(tt.args.cfg); gotP == nil { + t.Error("NewPool() failed, got nil") + } + }) + } +} + +func TestPool_Get(t *testing.T) { + + type args struct { + ctx context.Context + } + tests := []struct { + name string + p *Pool + args args + wantErr bool + n int + }{ + { + "Get", + NewPool(&Config{ + Config: &pool.Config{ + Active: 3, + Idle: 2, + }, + Name: "test_get", + Proto: "tcp", + Addr: testMemcacheAddr, + DialTimeout: xtime.Duration(time.Second), + ReadTimeout: xtime.Duration(time.Second), + WriteTimeout: xtime.Duration(time.Second), + }), + args{context.TODO()}, + false, + 3, + }, + { + "GetExceededPoolSize", + NewPool(&Config{ + Config: &pool.Config{ + Active: 3, + Idle: 2, + }, + Name: "test_get_out", + Proto: "tcp", + Addr: testMemcacheAddr, + DialTimeout: xtime.Duration(time.Second), + ReadTimeout: xtime.Duration(time.Second), + WriteTimeout: xtime.Duration(time.Second), + }), + args{context.TODO()}, + true, + 6, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + for i := 1; i <= tt.n; i++ { + got := tt.p.Get(tt.args.ctx) + if reflect.TypeOf(got) == reflect.TypeOf(errConn{}) { + if !tt.wantErr { + t.Errorf("got errConn, export Conn") + } + return + } else { + if tt.wantErr { + if i > tt.p.c.Active { + t.Errorf("got Conn, export errConn") + } + } + } + } + }) + } +} + +func TestPool_Close(t *testing.T) { + + type args struct { + ctx context.Context + } + tests := []struct { + name string + p *Pool + args args + wantErr bool + g int + c int + }{ + { + "Close", + NewPool(&Config{ + Config: &pool.Config{ + Active: 1, + Idle: 1, + }, + Name: "test_get", + Proto: "tcp", + Addr: testMemcacheAddr, + DialTimeout: xtime.Duration(time.Second), + ReadTimeout: xtime.Duration(time.Second), + WriteTimeout: xtime.Duration(time.Second), + }), + args{context.TODO()}, + false, + 3, + 3, + }, + { + "CloseExceededPoolSize", + NewPool(&Config{ + Config: &pool.Config{ + Active: 1, + Idle: 1, + }, + Name: "test_get_out", + Proto: "tcp", + Addr: testMemcacheAddr, + DialTimeout: xtime.Duration(time.Second), + ReadTimeout: xtime.Duration(time.Second), + WriteTimeout: xtime.Duration(time.Second), + }), + args{context.TODO()}, + true, + 5, + 3, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + for i := 1; i <= tt.g; i++ { + got := tt.p.Get(tt.args.ctx) + if err := got.Close(); err != nil { + if !tt.wantErr { + t.Error(err) + } + } + if i <= tt.c { + if err := got.Close(); err != nil { + t.Error(err) + } + } + } + }) + } +} diff --git a/pkg/cache/memcache/test/BUILD.bazel b/pkg/cache/memcache/test/BUILD.bazel new file mode 100644 index 000000000..0bf9680e6 --- /dev/null +++ b/pkg/cache/memcache/test/BUILD.bazel @@ -0,0 +1,48 @@ +load( + "@io_bazel_rules_go//go:def.bzl", + "go_library", +) +load( + "@io_bazel_rules_go//proto:def.bzl", + "go_proto_library", +) + +go_library( + name = "go_default_library", + srcs = [], + embed = [":proto_go_proto"], + importpath = "go-common/library/cache/memcache/test", + tags = ["automanaged"], + visibility = ["//visibility:public"], + deps = ["@com_github_golang_protobuf//proto:go_default_library"], +) + +filegroup( + name = "package-srcs", + srcs = glob(["**"]), + tags = ["automanaged"], + visibility = ["//visibility:private"], +) + +filegroup( + name = "all-srcs", + srcs = [":package-srcs"], + tags = ["automanaged"], + visibility = ["//visibility:public"], +) + +proto_library( + name = "test_proto", + srcs = ["test.proto"], + import_prefix = "go-common/library/cache/memcache/test", + strip_import_prefix = "", + tags = ["automanaged"], +) + +go_proto_library( + name = "proto_go_proto", + compilers = ["@io_bazel_rules_go//proto:go_proto"], + importpath = "go-common/library/cache/memcache/test", + proto = ":test_proto", + tags = ["automanaged"], +) diff --git a/pkg/cache/memcache/test/test.pb.go b/pkg/cache/memcache/test/test.pb.go new file mode 100644 index 000000000..1dc41aa00 --- /dev/null +++ b/pkg/cache/memcache/test/test.pb.go @@ -0,0 +1,375 @@ +// Code generated by protoc-gen-gogo. DO NOT EDIT. +// source: test.proto + +/* + Package proto is a generated protocol buffer package. + + It is generated from these files: + test.proto + + It has these top-level messages: + TestItem +*/ +package proto + +import proto1 "github.com/golang/protobuf/proto" +import fmt "fmt" +import math "math" + +import io "io" + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto1.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto1.ProtoPackageIsVersion2 // please upgrade the proto package + +type FOO int32 + +const ( + FOO_X FOO = 0 +) + +var FOO_name = map[int32]string{ + 0: "X", +} +var FOO_value = map[string]int32{ + "X": 0, +} + +func (x FOO) String() string { + return proto1.EnumName(FOO_name, int32(x)) +} +func (FOO) EnumDescriptor() ([]byte, []int) { return fileDescriptorTest, []int{0} } + +type TestItem struct { + Name string `protobuf:"bytes,1,opt,name=Name,proto3" json:"Name,omitempty"` + Age int32 `protobuf:"varint,2,opt,name=Age,proto3" json:"Age,omitempty"` +} + +func (m *TestItem) Reset() { *m = TestItem{} } +func (m *TestItem) String() string { return proto1.CompactTextString(m) } +func (*TestItem) ProtoMessage() {} +func (*TestItem) Descriptor() ([]byte, []int) { return fileDescriptorTest, []int{0} } + +func (m *TestItem) GetName() string { + if m != nil { + return m.Name + } + return "" +} + +func (m *TestItem) GetAge() int32 { + if m != nil { + return m.Age + } + return 0 +} + +func init() { + proto1.RegisterType((*TestItem)(nil), "proto.TestItem") + proto1.RegisterEnum("proto.FOO", FOO_name, FOO_value) +} +func (m *TestItem) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalTo(dAtA) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *TestItem) MarshalTo(dAtA []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if len(m.Name) > 0 { + dAtA[i] = 0xa + i++ + i = encodeVarintTest(dAtA, i, uint64(len(m.Name))) + i += copy(dAtA[i:], m.Name) + } + if m.Age != 0 { + dAtA[i] = 0x10 + i++ + i = encodeVarintTest(dAtA, i, uint64(m.Age)) + } + return i, nil +} + +func encodeFixed64Test(dAtA []byte, offset int, v uint64) int { + dAtA[offset] = uint8(v) + dAtA[offset+1] = uint8(v >> 8) + dAtA[offset+2] = uint8(v >> 16) + dAtA[offset+3] = uint8(v >> 24) + dAtA[offset+4] = uint8(v >> 32) + dAtA[offset+5] = uint8(v >> 40) + dAtA[offset+6] = uint8(v >> 48) + dAtA[offset+7] = uint8(v >> 56) + return offset + 8 +} +func encodeFixed32Test(dAtA []byte, offset int, v uint32) int { + dAtA[offset] = uint8(v) + dAtA[offset+1] = uint8(v >> 8) + dAtA[offset+2] = uint8(v >> 16) + dAtA[offset+3] = uint8(v >> 24) + return offset + 4 +} +func encodeVarintTest(dAtA []byte, offset int, v uint64) int { + for v >= 1<<7 { + dAtA[offset] = uint8(v&0x7f | 0x80) + v >>= 7 + offset++ + } + dAtA[offset] = uint8(v) + return offset + 1 +} +func (m *TestItem) Size() (n int) { + var l int + _ = l + l = len(m.Name) + if l > 0 { + n += 1 + l + sovTest(uint64(l)) + } + if m.Age != 0 { + n += 1 + sovTest(uint64(m.Age)) + } + return n +} + +func sovTest(x uint64) (n int) { + for { + n++ + x >>= 7 + if x == 0 { + break + } + } + return n +} +func sozTest(x uint64) (n int) { + return sovTest(uint64((x << 1) ^ uint64((int64(x) >> 63)))) +} +func (m *TestItem) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowTest + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: TestItem: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: TestItem: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Name", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowTest + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthTest + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Name = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 2: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Age", wireType) + } + m.Age = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowTest + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.Age |= (int32(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + default: + iNdEx = preIndex + skippy, err := skipTest(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthTest + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func skipTest(dAtA []byte) (n int, err error) { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowTest + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + wireType := int(wire & 0x7) + switch wireType { + case 0: + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowTest + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + iNdEx++ + if dAtA[iNdEx-1] < 0x80 { + break + } + } + return iNdEx, nil + case 1: + iNdEx += 8 + return iNdEx, nil + case 2: + var length int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowTest + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + length |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + iNdEx += length + if length < 0 { + return 0, ErrInvalidLengthTest + } + return iNdEx, nil + case 3: + for { + var innerWire uint64 + var start int = iNdEx + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowTest + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + innerWire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + innerWireType := int(innerWire & 0x7) + if innerWireType == 4 { + break + } + next, err := skipTest(dAtA[start:]) + if err != nil { + return 0, err + } + iNdEx = start + next + } + return iNdEx, nil + case 4: + return iNdEx, nil + case 5: + iNdEx += 4 + return iNdEx, nil + default: + return 0, fmt.Errorf("proto: illegal wireType %d", wireType) + } + } + panic("unreachable") +} + +var ( + ErrInvalidLengthTest = fmt.Errorf("proto: negative length found during unmarshaling") + ErrIntOverflowTest = fmt.Errorf("proto: integer overflow") +) + +func init() { proto1.RegisterFile("test.proto", fileDescriptorTest) } + +var fileDescriptorTest = []byte{ + // 122 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x2a, 0x49, 0x2d, 0x2e, + 0xd1, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x05, 0x53, 0x4a, 0x06, 0x5c, 0x1c, 0x21, 0xa9, + 0xc5, 0x25, 0x9e, 0x25, 0xa9, 0xb9, 0x42, 0x42, 0x5c, 0x2c, 0x7e, 0x89, 0xb9, 0xa9, 0x12, 0x8c, + 0x0a, 0x8c, 0x1a, 0x9c, 0x41, 0x60, 0xb6, 0x90, 0x00, 0x17, 0xb3, 0x63, 0x7a, 0xaa, 0x04, 0x93, + 0x02, 0xa3, 0x06, 0x6b, 0x10, 0x88, 0xa9, 0xc5, 0xc3, 0xc5, 0xec, 0xe6, 0xef, 0x2f, 0xc4, 0xca, + 0xc5, 0x18, 0x21, 0xc0, 0xe0, 0x24, 0x70, 0xe2, 0x91, 0x1c, 0xe3, 0x85, 0x47, 0x72, 0x8c, 0x0f, + 0x1e, 0xc9, 0x31, 0xce, 0x78, 0x2c, 0xc7, 0x90, 0xc4, 0x06, 0x36, 0xd8, 0x18, 0x10, 0x00, 0x00, + 0xff, 0xff, 0x16, 0x80, 0x60, 0x15, 0x6d, 0x00, 0x00, 0x00, +} diff --git a/pkg/cache/memcache/test/test.proto b/pkg/cache/memcache/test/test.proto new file mode 100644 index 000000000..adad15bea --- /dev/null +++ b/pkg/cache/memcache/test/test.proto @@ -0,0 +1,12 @@ +syntax = "proto3"; +package proto; + +enum FOO +{ + X = 0; +}; + +message TestItem{ + string Name = 1; + int32 Age = 2; +} \ No newline at end of file diff --git a/pkg/cache/memcache/trace.go b/pkg/cache/memcache/trace.go deleted file mode 100644 index 589ef0fe9..000000000 --- a/pkg/cache/memcache/trace.go +++ /dev/null @@ -1,109 +0,0 @@ -package memcache - -import ( - "context" - "strconv" - "strings" - "time" - - "github.com/bilibili/kratos/pkg/log" - "github.com/bilibili/kratos/pkg/net/trace" -) - -const ( - _traceFamily = "memcache" - _traceSpanKind = "client" - _traceComponentName = "library/cache/memcache" - _tracePeerService = "memcache" - _slowLogDuration = time.Millisecond * 250 -) - -type traceConn struct { - Conn - ctx context.Context - address string -} - -func (t *traceConn) setTrace(action, statement string) func(error) error { - now := time.Now() - parent, ok := trace.FromContext(t.ctx) - if !ok { - return func(err error) error { return err } - } - span := parent.Fork(_traceFamily, "Memcache:"+action) - span.SetTag( - trace.String(trace.TagSpanKind, _traceSpanKind), - trace.String(trace.TagComponent, _traceComponentName), - trace.String(trace.TagPeerService, _tracePeerService), - trace.String(trace.TagPeerAddress, t.address), - trace.String(trace.TagDBStatement, action+" "+statement), - ) - return func(err error) error { - span.Finish(&err) - t := time.Since(now) - if t > _slowLogDuration { - log.Warn("%s slow log action: %s key: %s time: %v", _traceFamily, action, statement, t) - } - return err - } -} - -func (t *traceConn) WithContext(ctx context.Context) Conn { - t.ctx = ctx - t.Conn = t.Conn.WithContext(ctx) - return t -} - -func (t *traceConn) Add(item *Item) error { - finishFn := t.setTrace("Add", item.Key) - return finishFn(t.Conn.Add(item)) -} - -func (t *traceConn) Set(item *Item) error { - finishFn := t.setTrace("Set", item.Key) - return finishFn(t.Conn.Set(item)) -} - -func (t *traceConn) Replace(item *Item) error { - finishFn := t.setTrace("Replace", item.Key) - return finishFn(t.Conn.Replace(item)) -} - -func (t *traceConn) Get(key string) (*Item, error) { - finishFn := t.setTrace("Get", key) - item, err := t.Conn.Get(key) - return item, finishFn(err) -} - -func (t *traceConn) GetMulti(keys []string) (map[string]*Item, error) { - finishFn := t.setTrace("GetMulti", strings.Join(keys, " ")) - items, err := t.Conn.GetMulti(keys) - return items, finishFn(err) -} - -func (t *traceConn) Delete(key string) error { - finishFn := t.setTrace("Delete", key) - return finishFn(t.Conn.Delete(key)) -} - -func (t *traceConn) Increment(key string, delta uint64) (newValue uint64, err error) { - finishFn := t.setTrace("Increment", key+" "+strconv.FormatUint(delta, 10)) - newValue, err = t.Conn.Increment(key, delta) - return newValue, finishFn(err) -} - -func (t *traceConn) Decrement(key string, delta uint64) (newValue uint64, err error) { - finishFn := t.setTrace("Decrement", key+" "+strconv.FormatUint(delta, 10)) - newValue, err = t.Conn.Decrement(key, delta) - return newValue, finishFn(err) -} - -func (t *traceConn) CompareAndSwap(item *Item) error { - finishFn := t.setTrace("CompareAndSwap", item.Key) - return finishFn(t.Conn.CompareAndSwap(item)) -} - -func (t *traceConn) Touch(key string, seconds int32) (err error) { - finishFn := t.setTrace("Touch", key+" "+strconv.Itoa(int(seconds))) - return finishFn(t.Conn.Touch(key, seconds)) -} diff --git a/pkg/cache/memcache/trace_conn.go b/pkg/cache/memcache/trace_conn.go new file mode 100644 index 000000000..086dab5c7 --- /dev/null +++ b/pkg/cache/memcache/trace_conn.go @@ -0,0 +1,103 @@ +package memcache + +import ( + "context" + "strconv" + "strings" + "time" + + "github.com/bilibili/kratos/pkg/log" + "github.com/bilibili/kratos/pkg/net/trace" +) + +const ( + _slowLogDuration = time.Millisecond * 250 +) + +func newTraceConn(conn Conn, address string) Conn { + tags := []trace.Tag{ + trace.String(trace.TagSpanKind, "client"), + trace.String(trace.TagComponent, "cache/memcache"), + trace.String(trace.TagPeerService, "memcache"), + trace.String(trace.TagPeerAddress, address), + } + return &traceConn{Conn: conn, tags: tags} +} + +type traceConn struct { + Conn + tags []trace.Tag +} + +func (t *traceConn) setTrace(ctx context.Context, action, statement string) func(error) error { + now := time.Now() + parent, ok := trace.FromContext(ctx) + if !ok { + return func(err error) error { return err } + } + span := parent.Fork("", "Memcache:"+action) + span.SetTag(t.tags...) + span.SetTag(trace.String(trace.TagDBStatement, action+" "+statement)) + return func(err error) error { + span.Finish(&err) + t := time.Since(now) + if t > _slowLogDuration { + log.Warn("memcache slow log action: %s key: %s time: %v", action, statement, t) + } + return err + } +} + +func (t *traceConn) AddContext(ctx context.Context, item *Item) error { + finishFn := t.setTrace(ctx, "Add", item.Key) + return finishFn(t.Conn.Add(item)) +} + +func (t *traceConn) SetContext(ctx context.Context, item *Item) error { + finishFn := t.setTrace(ctx, "Set", item.Key) + return finishFn(t.Conn.Set(item)) +} + +func (t *traceConn) ReplaceContext(ctx context.Context, item *Item) error { + finishFn := t.setTrace(ctx, "Replace", item.Key) + return finishFn(t.Conn.Replace(item)) +} + +func (t *traceConn) GetContext(ctx context.Context, key string) (*Item, error) { + finishFn := t.setTrace(ctx, "Get", key) + item, err := t.Conn.Get(key) + return item, finishFn(err) +} + +func (t *traceConn) GetMultiContext(ctx context.Context, keys []string) (map[string]*Item, error) { + finishFn := t.setTrace(ctx, "GetMulti", strings.Join(keys, " ")) + items, err := t.Conn.GetMulti(keys) + return items, finishFn(err) +} + +func (t *traceConn) DeleteContext(ctx context.Context, key string) error { + finishFn := t.setTrace(ctx, "Delete", key) + return finishFn(t.Conn.Delete(key)) +} + +func (t *traceConn) IncrementContext(ctx context.Context, key string, delta uint64) (newValue uint64, err error) { + finishFn := t.setTrace(ctx, "Increment", key+" "+strconv.FormatUint(delta, 10)) + newValue, err = t.Conn.Increment(key, delta) + return newValue, finishFn(err) +} + +func (t *traceConn) DecrementContext(ctx context.Context, key string, delta uint64) (newValue uint64, err error) { + finishFn := t.setTrace(ctx, "Decrement", key+" "+strconv.FormatUint(delta, 10)) + newValue, err = t.Conn.Decrement(key, delta) + return newValue, finishFn(err) +} + +func (t *traceConn) CompareAndSwapContext(ctx context.Context, item *Item) error { + finishFn := t.setTrace(ctx, "CompareAndSwap", item.Key) + return finishFn(t.Conn.CompareAndSwap(item)) +} + +func (t *traceConn) TouchContext(ctx context.Context, key string, seconds int32) (err error) { + finishFn := t.setTrace(ctx, "Touch", key+" "+strconv.Itoa(int(seconds))) + return finishFn(t.Conn.Touch(key, seconds)) +} diff --git a/pkg/cache/memcache/util.go b/pkg/cache/memcache/util.go index f35072941..ce64bf1fc 100644 --- a/pkg/cache/memcache/util.go +++ b/pkg/cache/memcache/util.go @@ -1,9 +1,57 @@ package memcache import ( + "context" + "time" + "github.com/gogo/protobuf/proto" ) +func legalKey(key string) bool { + if len(key) > 250 || len(key) == 0 { + return false + } + for i := 0; i < len(key); i++ { + if key[i] <= ' ' || key[i] == 0x7f { + return false + } + } + return true +} + +// MockWith error +func MockWith(err error) Conn { + return errConn{err} +} + +type errConn struct{ err error } + +func (c errConn) Err() error { return c.err } +func (c errConn) Close() error { return c.err } +func (c errConn) Add(*Item) error { return c.err } +func (c errConn) Set(*Item) error { return c.err } +func (c errConn) Replace(*Item) error { return c.err } +func (c errConn) CompareAndSwap(*Item) error { return c.err } +func (c errConn) Get(string) (*Item, error) { return nil, c.err } +func (c errConn) GetMulti([]string) (map[string]*Item, error) { return nil, c.err } +func (c errConn) Touch(string, int32) error { return c.err } +func (c errConn) Delete(string) error { return c.err } +func (c errConn) Increment(string, uint64) (uint64, error) { return 0, c.err } +func (c errConn) Decrement(string, uint64) (uint64, error) { return 0, c.err } +func (c errConn) Scan(*Item, interface{}) error { return c.err } +func (c errConn) AddContext(context.Context, *Item) error { return c.err } +func (c errConn) SetContext(context.Context, *Item) error { return c.err } +func (c errConn) ReplaceContext(context.Context, *Item) error { return c.err } +func (c errConn) GetContext(context.Context, string) (*Item, error) { return nil, c.err } +func (c errConn) DecrementContext(context.Context, string, uint64) (uint64, error) { return 0, c.err } +func (c errConn) CompareAndSwapContext(context.Context, *Item) error { return c.err } +func (c errConn) TouchContext(context.Context, string, int32) error { return c.err } +func (c errConn) DeleteContext(context.Context, string) error { return c.err } +func (c errConn) IncrementContext(context.Context, string, uint64) (uint64, error) { return 0, c.err } +func (c errConn) GetMultiContext(context.Context, []string) (map[string]*Item, error) { + return nil, c.err +} + // RawItem item with FlagRAW flag. // // Expiration is the cache expiration time, in seconds: either a relative @@ -30,3 +78,12 @@ func JSONItem(key string, v interface{}, flags uint32, expiration int32) *Item { func ProtobufItem(key string, message proto.Message, flags uint32, expiration int32) *Item { return &Item{Key: key, Flags: flags | FlagProtobuf, Object: message, Expiration: expiration} } + +func shrinkDeadline(ctx context.Context, timeout time.Duration) time.Time { + // TODO: ignored context deadline to compatible old behaviour. + //deadline, ok := ctx.Deadline() + //if ok { + // return deadline + //} + return time.Now().Add(timeout) +} diff --git a/pkg/cache/memcache/util_test.go b/pkg/cache/memcache/util_test.go new file mode 100644 index 000000000..34b66c290 --- /dev/null +++ b/pkg/cache/memcache/util_test.go @@ -0,0 +1,75 @@ +package memcache + +import ( + "testing" + + pb "github.com/bilibili/kratos/pkg/cache/memcache/test" + + "github.com/stretchr/testify/assert" +) + +func TestItemUtil(t *testing.T) { + item1 := RawItem("test", []byte("hh"), 0, 0) + assert.Equal(t, "test", item1.Key) + assert.Equal(t, []byte("hh"), item1.Value) + assert.Equal(t, FlagRAW, FlagRAW&item1.Flags) + + item1 = JSONItem("test", &Item{}, 0, 0) + assert.Equal(t, "test", item1.Key) + assert.NotNil(t, item1.Object) + assert.Equal(t, FlagJSON, FlagJSON&item1.Flags) + + item1 = ProtobufItem("test", &pb.TestItem{}, 0, 0) + assert.Equal(t, "test", item1.Key) + assert.NotNil(t, item1.Object) + assert.Equal(t, FlagProtobuf, FlagProtobuf&item1.Flags) +} + +func TestLegalKey(t *testing.T) { + type args struct { + key string + } + tests := []struct { + name string + args args + want bool + }{ + { + name: "test empty key", + want: false, + }, + { + name: "test too large key", + args: args{func() string { + var data []byte + for i := 0; i < 255; i++ { + data = append(data, 'k') + } + return string(data) + }()}, + want: false, + }, + { + name: "test invalid char", + args: args{"hello world"}, + want: false, + }, + { + name: "test invalid char", + args: args{string([]byte{0x7f})}, + want: false, + }, + { + name: "test normal key", + args: args{"hello"}, + want: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := legalKey(tt.args.key); got != tt.want { + t.Errorf("legalKey() = %v, want %v", got, tt.want) + } + }) + } +}