1
0
mirror of https://github.com/go-kratos/kratos.git synced 2025-01-07 23:02:12 +02:00

Merge pull request #140 from bilibili/update/memcache

update memcache sdk
This commit is contained in:
Tony 2019-06-06 18:06:56 +08:00 committed by GitHub
commit 3f2c51b56e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 3782 additions and 1138 deletions

View File

@ -1,25 +0,0 @@
# cache/memcache
##### 项目简介
1. 提供protobuf,gob,json序列化方式,gzip的memcache接口
#### 使用方式
```golang
// 初始化 注意这里只是示例 展示用法 不能每次都New 只需要初始化一次
mc := memcache.New(&memcache.Config{})
// 程序关闭的时候调用close方法
defer mc.Close()
// 增加 key
err = mc.Set(c, &memcache.Item{})
// 删除key
err := mc.Delete(c,key)
// 获得某个key的内容
err := mc.Get(c,key).Scan(&v)
// 获取多个key的内容
replies, err := mc.GetMulti(c, keys)
for _, key := range replies.Keys() {
if err = replies.Scan(key, &v); err != nil {
return
}
}
```

261
pkg/cache/memcache/ascii_conn.go vendored Normal file
View File

@ -0,0 +1,261 @@
package memcache
import (
"bufio"
"bytes"
"context"
"fmt"
"io"
"net"
"strconv"
"strings"
"time"
pkgerr "github.com/pkg/errors"
)
var (
crlf = []byte("\r\n")
space = []byte(" ")
replyOK = []byte("OK\r\n")
replyStored = []byte("STORED\r\n")
replyNotStored = []byte("NOT_STORED\r\n")
replyExists = []byte("EXISTS\r\n")
replyNotFound = []byte("NOT_FOUND\r\n")
replyDeleted = []byte("DELETED\r\n")
replyEnd = []byte("END\r\n")
replyTouched = []byte("TOUCHED\r\n")
replyClientErrorPrefix = []byte("CLIENT_ERROR ")
replyServerErrorPrefix = []byte("SERVER_ERROR ")
)
var _ protocolConn = &asiiConn{}
// asiiConn is the low-level implementation of Conn
type asiiConn struct {
err error
conn net.Conn
// Read & Write
readTimeout time.Duration
writeTimeout time.Duration
rw *bufio.ReadWriter
}
func replyToError(line []byte) error {
switch {
case bytes.Equal(line, replyStored):
return nil
case bytes.Equal(line, replyOK):
return nil
case bytes.Equal(line, replyDeleted):
return nil
case bytes.Equal(line, replyTouched):
return nil
case bytes.Equal(line, replyNotStored):
return ErrNotStored
case bytes.Equal(line, replyExists):
return ErrCASConflict
case bytes.Equal(line, replyNotFound):
return ErrNotFound
case bytes.Equal(line, replyNotStored):
return ErrNotStored
case bytes.Equal(line, replyExists):
return ErrCASConflict
}
return pkgerr.WithStack(protocolError(string(line)))
}
func (c *asiiConn) Populate(ctx context.Context, cmd string, key string, flags uint32, expiration int32, cas uint64, data []byte) error {
c.conn.SetWriteDeadline(shrinkDeadline(ctx, c.writeTimeout))
// <command name> <key> <flags> <exptime> <bytes> [noreply]\r\n
var err error
if cmd == "cas" {
_, err = fmt.Fprintf(c.rw, "%s %s %d %d %d %d\r\n", cmd, key, flags, expiration, len(data), cas)
} else {
_, err = fmt.Fprintf(c.rw, "%s %s %d %d %d\r\n", cmd, key, flags, expiration, len(data))
}
if err != nil {
return c.fatal(err)
}
c.rw.Write(data)
c.rw.Write(crlf)
if err = c.rw.Flush(); err != nil {
return c.fatal(err)
}
c.conn.SetReadDeadline(shrinkDeadline(ctx, c.readTimeout))
line, err := c.rw.ReadSlice('\n')
if err != nil {
return c.fatal(err)
}
return replyToError(line)
}
// newConn returns a new memcache connection for the given net connection.
func newASCIIConn(netConn net.Conn, readTimeout, writeTimeout time.Duration) (protocolConn, error) {
if writeTimeout <= 0 || readTimeout <= 0 {
return nil, pkgerr.Errorf("readTimeout writeTimeout can't be zero")
}
c := &asiiConn{
conn: netConn,
rw: bufio.NewReadWriter(bufio.NewReader(netConn),
bufio.NewWriter(netConn)),
readTimeout: readTimeout,
writeTimeout: writeTimeout,
}
return c, nil
}
func (c *asiiConn) Close() error {
if c.err == nil {
c.err = pkgerr.New("memcache: closed")
}
return c.conn.Close()
}
func (c *asiiConn) fatal(err error) error {
if c.err == nil {
c.err = pkgerr.WithStack(err)
// Close connection to force errors on subsequent calls and to unblock
// other reader or writer.
c.conn.Close()
}
return c.err
}
func (c *asiiConn) Err() error {
return c.err
}
func (c *asiiConn) Get(ctx context.Context, key string) (result *Item, err error) {
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
if _, err = fmt.Fprintf(c.rw, "gets %s\r\n", key); err != nil {
return nil, c.fatal(err)
}
if err = c.rw.Flush(); err != nil {
return nil, c.fatal(err)
}
if err = c.parseGetReply(func(it *Item) {
result = it
}); err != nil {
return
}
if result == nil {
return nil, ErrNotFound
}
return
}
func (c *asiiConn) GetMulti(ctx context.Context, keys ...string) (map[string]*Item, error) {
var err error
c.conn.SetWriteDeadline(shrinkDeadline(ctx, c.writeTimeout))
if _, err = fmt.Fprintf(c.rw, "gets %s\r\n", strings.Join(keys, " ")); err != nil {
return nil, c.fatal(err)
}
if err = c.rw.Flush(); err != nil {
return nil, c.fatal(err)
}
results := make(map[string]*Item, len(keys))
if err = c.parseGetReply(func(it *Item) {
results[it.Key] = it
}); err != nil {
return nil, err
}
return results, nil
}
func (c *asiiConn) parseGetReply(f func(*Item)) error {
c.conn.SetReadDeadline(shrinkDeadline(context.TODO(), c.readTimeout))
for {
line, err := c.rw.ReadSlice('\n')
if err != nil {
return c.fatal(err)
}
if bytes.Equal(line, replyEnd) {
return nil
}
if bytes.HasPrefix(line, replyServerErrorPrefix) {
errMsg := line[len(replyServerErrorPrefix):]
return c.fatal(protocolError(errMsg))
}
it := new(Item)
size, err := scanGetReply(line, it)
if err != nil {
return c.fatal(err)
}
it.Value = make([]byte, size+2)
if _, err = io.ReadFull(c.rw, it.Value); err != nil {
return c.fatal(err)
}
if !bytes.HasSuffix(it.Value, crlf) {
return c.fatal(protocolError("corrupt get reply, no except CRLF"))
}
it.Value = it.Value[:size]
f(it)
}
}
func scanGetReply(line []byte, item *Item) (size int, err error) {
pattern := "VALUE %s %d %d %d\r\n"
dest := []interface{}{&item.Key, &item.Flags, &size, &item.cas}
if bytes.Count(line, space) == 3 {
pattern = "VALUE %s %d %d\r\n"
dest = dest[:3]
}
n, err := fmt.Sscanf(string(line), pattern, dest...)
if err != nil || n != len(dest) {
return -1, fmt.Errorf("memcache: unexpected line in get response: %q", line)
}
return size, nil
}
func (c *asiiConn) Touch(ctx context.Context, key string, expire int32) error {
line, err := c.writeReadLine("touch %s %d\r\n", key, expire)
if err != nil {
return err
}
return replyToError(line)
}
func (c *asiiConn) IncrDecr(ctx context.Context, cmd, key string, delta uint64) (uint64, error) {
line, err := c.writeReadLine("%s %s %d\r\n", cmd, key, delta)
if err != nil {
return 0, err
}
switch {
case bytes.Equal(line, replyNotFound):
return 0, ErrNotFound
case bytes.HasPrefix(line, replyClientErrorPrefix):
errMsg := line[len(replyClientErrorPrefix):]
return 0, pkgerr.WithStack(protocolError(errMsg))
}
val, err := strconv.ParseUint(string(line[:len(line)-2]), 10, 64)
if err != nil {
return 0, err
}
return val, nil
}
func (c *asiiConn) Delete(ctx context.Context, key string) error {
line, err := c.writeReadLine("delete %s\r\n", key)
if err != nil {
return err
}
return replyToError(line)
}
func (c *asiiConn) writeReadLine(format string, args ...interface{}) ([]byte, error) {
c.conn.SetWriteDeadline(shrinkDeadline(context.TODO(), c.writeTimeout))
_, err := fmt.Fprintf(c.rw, format, args...)
if err != nil {
return nil, c.fatal(pkgerr.WithStack(err))
}
if err = c.rw.Flush(); err != nil {
return nil, c.fatal(pkgerr.WithStack(err))
}
c.conn.SetReadDeadline(shrinkDeadline(context.TODO(), c.readTimeout))
line, err := c.rw.ReadSlice('\n')
if err != nil {
return line, c.fatal(pkgerr.WithStack(err))
}
return line, nil
}

569
pkg/cache/memcache/ascii_conn_test.go vendored Normal file
View File

@ -0,0 +1,569 @@
package memcache
import (
"bytes"
"strconv"
"strings"
"testing"
)
func TestASCIIConnAdd(t *testing.T) {
tests := []struct {
name string
a *Item
e error
}{
{
"Add",
&Item{
Key: "test_add",
Value: []byte("0"),
Flags: 0,
Expiration: 60,
},
nil,
},
{
"Add_Large",
&Item{
Key: "test_add_large",
Value: bytes.Repeat(space, _largeValue+1),
Flags: 0,
Expiration: 60,
},
nil,
},
{
"Add_Exist",
&Item{
Key: "test_add",
Value: []byte("0"),
Flags: 0,
Expiration: 60,
},
ErrNotStored,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
if err := testConnASCII.Add(test.a); err != test.e {
t.Fatal(err)
}
if b, err := testConnASCII.Get(test.a.Key); err != nil {
t.Fatal(err)
} else {
compareItem(t, test.a, b)
}
})
}
}
func TestASCIIConnGet(t *testing.T) {
tests := []struct {
name string
a *Item
k string
e error
}{
{
"Get",
&Item{
Key: "test_get",
Value: []byte("0"),
Flags: 0,
Expiration: 60,
},
"test_get",
nil,
},
{
"Get_NotExist",
&Item{
Key: "test_get_not_exist",
Value: []byte("0"),
Flags: 0,
Expiration: 60,
},
"test_get_not_exist!",
ErrNotFound,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
if err := testConnASCII.Add(test.a); err != nil {
t.Fatal(err)
}
if b, err := testConnASCII.Get(test.a.Key); err != nil {
t.Fatal(err)
} else {
compareItem(t, test.a, b)
}
})
}
}
//func TestGetHasErr(t *testing.T) {
// prepareEnv(t)
//
// st := &TestItem{Name: "json", Age: 10}
// itemx := &Item{Key: "test", Object: st, Flags: FlagJSON}
// c.Set(itemx)
//
// expected := errors.New("some error")
// monkey.Patch(scanGetReply, func(line []byte, item *Item) (size int, err error) {
// return 0, expected
// })
//
// if _, err := c.Get("test"); err.Error() != expected.Error() {
// t.Errorf("conn.Get() unexpected error(%v)", err)
// }
// if err := c.(*asciiConn).err; err.Error() != expected.Error() {
// t.Errorf("unexpected error(%v)", err)
// }
//}
func TestASCIIConnGetMulti(t *testing.T) {
tests := []struct {
name string
a []*Item
k []string
e error
}{
{"getMulti_Add",
[]*Item{
{
Key: "get_multi_1",
Value: []byte("test"),
Flags: FlagRAW,
Expiration: 60,
cas: 0,
},
{
Key: "get_multi_2",
Value: []byte("test2"),
Flags: FlagRAW,
Expiration: 60,
cas: 0,
},
},
[]string{"get_multi_1", "get_multi_2"},
nil,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
for _, i := range test.a {
if err := testConnASCII.Set(i); err != nil {
t.Fatal(err)
}
}
if r, err := testConnASCII.GetMulti(test.k); err != nil {
t.Fatal(err)
} else {
reply := r["get_multi_1"]
compareItem(t, reply, test.a[0])
reply = r["get_multi_2"]
compareItem(t, reply, test.a[1])
}
})
}
}
func TestASCIIConnSet(t *testing.T) {
tests := []struct {
name string
a *Item
e error
}{
{
"SetLowerBound",
&Item{
Key: strings.Repeat("a", 1),
Value: []byte("4"),
Flags: 0,
Expiration: 60,
},
nil,
},
{
"SetUpperBound",
&Item{
Key: strings.Repeat("a", 250),
Value: []byte("3"),
Flags: 0,
Expiration: 60,
},
nil,
},
{
"SetIllegalKeyZeroLength",
&Item{
Key: "",
Value: []byte("2"),
Flags: 0,
Expiration: 60,
},
ErrMalformedKey,
},
{
"SetIllegalKeyLengthExceededLimit",
&Item{
Key: " ",
Value: []byte("1"),
Flags: 0,
Expiration: 60,
},
ErrMalformedKey,
},
{
"SeJsonItem",
&Item{
Key: "set_obj",
Object: &struct {
Name string
Age int
}{"json", 10},
Expiration: 60,
Flags: FlagJSON,
},
nil,
},
{
"SeErrItemJSONGzip",
&Item{
Key: "set_err_item",
Expiration: 60,
Flags: FlagJSON | FlagGzip,
},
ErrItem,
},
{
"SeErrItemBytesValueWrongFlag",
&Item{
Key: "set_err_item",
Value: []byte("2"),
Expiration: 60,
Flags: FlagJSON,
},
ErrItem,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
if err := testConnASCII.Set(test.a); err != test.e {
t.Fatal(err)
}
})
}
}
func TestASCIIConnCompareAndSwap(t *testing.T) {
tests := []struct {
name string
a *Item
b *Item
c *Item
k string
e error
}{
{
"CompareAndSwap",
&Item{
Key: "test_cas",
Value: []byte("2"),
Flags: 0,
Expiration: 60,
},
nil,
&Item{
Key: "test_cas",
Value: []byte("3"),
Flags: 0,
Expiration: 60,
},
"test_cas",
nil,
},
{
"CompareAndSwapErrCASConflict",
&Item{
Key: "test_cas_conflict",
Value: []byte("2"),
Flags: 0,
Expiration: 60,
},
&Item{
Key: "test_cas_conflict",
Value: []byte("1"),
Flags: 0,
Expiration: 60,
},
&Item{
Key: "test_cas_conflict",
Value: []byte("3"),
Flags: 0,
Expiration: 60,
},
"test_cas_conflict",
ErrCASConflict,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
if err := testConnASCII.Set(test.a); err != nil {
t.Fatal(err)
}
r, err := testConnASCII.Get(test.k)
if err != nil {
t.Fatal(err)
}
if test.b != nil {
if err := testConnASCII.Set(test.b); err != nil {
t.Fatal(err)
}
}
r.Value = test.c.Value
if err := testConnASCII.CompareAndSwap(r); err != nil {
if err != test.e {
t.Fatal(err)
}
} else {
if fr, err := testConnASCII.Get(test.k); err != nil {
t.Fatal(err)
} else {
compareItem(t, fr, test.c)
}
}
})
}
t.Run("TestCompareAndSwapErrNotFound", func(t *testing.T) {
ti := &Item{
Key: "test_cas_notfound",
Value: []byte("2"),
Flags: 0,
Expiration: 60,
}
if err := testConnASCII.Set(ti); err != nil {
t.Fatal(err)
}
r, err := testConnASCII.Get(ti.Key)
if err != nil {
t.Fatal(err)
}
r.Key = "test_cas_notfound_boom"
r.Value = []byte("3")
if err := testConnASCII.CompareAndSwap(r); err != nil {
if err != ErrNotFound {
t.Fatal(err)
}
}
})
}
func TestASCIIConnReplace(t *testing.T) {
tests := []struct {
name string
a *Item
b *Item
e error
}{
{
"TestReplace",
&Item{
Key: "test_replace",
Value: []byte("2"),
Flags: 0,
Expiration: 60,
},
&Item{
Key: "test_replace",
Value: []byte("3"),
Flags: 0,
Expiration: 60,
},
nil,
},
{
"TestReplaceErrNotStored",
&Item{
Key: "test_replace_not_stored",
Value: []byte("2"),
Flags: 0,
Expiration: 60,
},
&Item{
Key: "test_replace_not_stored_boom",
Value: []byte("3"),
Flags: 0,
Expiration: 60,
},
ErrNotStored,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
if err := testConnASCII.Set(test.a); err != nil {
t.Fatal(err)
}
if err := testConnASCII.Replace(test.b); err != nil {
if err == test.e {
return
}
t.Fatal(err)
}
if r, err := testConnASCII.Get(test.b.Key); err != nil {
t.Fatal(err)
} else {
compareItem(t, r, test.b)
}
})
}
}
func TestASCIIConnIncrDecr(t *testing.T) {
tests := []struct {
fn func(key string, delta uint64) (uint64, error)
name string
k string
v uint64
w uint64
}{
{
testConnASCII.Increment,
"Incr_10",
"test_incr",
10,
10,
},
{
testConnASCII.Increment,
"Incr_10(2)",
"test_incr",
10,
20,
},
{
testConnASCII.Decrement,
"Decr_10",
"test_incr",
10,
10,
},
}
if err := testConnASCII.Add(&Item{
Key: "test_incr",
Value: []byte("0"),
}); err != nil {
t.Fatal(err)
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
if a, err := test.fn(test.k, test.v); err != nil {
t.Fatal(err)
} else {
if a != test.w {
t.Fatalf("want %d, got %d", test.w, a)
}
}
if b, err := testConnASCII.Get(test.k); err != nil {
t.Fatal(err)
} else {
if string(b.Value) != strconv.FormatUint(test.w, 10) {
t.Fatalf("want %s, got %d", b.Value, test.w)
}
}
})
}
}
func TestASCIIConnTouch(t *testing.T) {
tests := []struct {
name string
k string
a *Item
e error
}{
{
"Touch",
"test_touch",
&Item{
Key: "test_touch",
Value: []byte("0"),
Expiration: 60,
},
nil,
},
{
"Touch_NotExist",
"test_touch_not_exist",
nil,
ErrNotFound,
},
}
for _, test := range tests {
if test.a != nil {
if err := testConnASCII.Add(test.a); err != nil {
t.Fatal(err)
}
if err := testConnASCII.Touch(test.k, 1); err != test.e {
t.Fatal(err)
}
}
}
}
func TestASCIIConnDelete(t *testing.T) {
tests := []struct {
name string
k string
a *Item
e error
}{
{
"Delete",
"test_delete",
&Item{
Key: "test_delete",
Value: []byte("0"),
Expiration: 60,
},
nil,
},
{
"Delete_NotExist",
"test_delete_not_exist",
nil,
ErrNotFound,
},
}
for _, test := range tests {
if test.a != nil {
if err := testConnASCII.Add(test.a); err != nil {
t.Fatal(err)
}
if err := testConnASCII.Delete(test.k); err != test.e {
t.Fatal(err)
}
if _, err := testConnASCII.Get(test.k); err != ErrNotFound {
t.Fatal(err)
}
}
}
}
func compareItem(t *testing.T, a, b *Item) {
if a.Key != b.Key || !bytes.Equal(a.Value, b.Value) || a.Flags != b.Flags {
t.Fatalf("compareItem: a(%s, %d, %d) : b(%s, %d, %d)", a.Key, len(a.Value), a.Flags, b.Key, len(b.Value), b.Flags)
}
}

View File

@ -1,187 +0,0 @@
package memcache
import (
"context"
)
// Memcache memcache client
type Memcache struct {
pool *Pool
}
// Reply is the result of Get
type Reply struct {
err error
item *Item
conn Conn
closed bool
}
// Replies is the result of GetMulti
type Replies struct {
err error
items map[string]*Item
usedItems map[string]struct{}
conn Conn
closed bool
}
// New get a memcache client
func New(c *Config) *Memcache {
return &Memcache{pool: NewPool(c)}
}
// Close close connection pool
func (mc *Memcache) Close() error {
return mc.pool.Close()
}
// Conn direct get a connection
func (mc *Memcache) Conn(c context.Context) Conn {
return mc.pool.Get(c)
}
// Set writes the given item, unconditionally.
func (mc *Memcache) Set(c context.Context, item *Item) (err error) {
conn := mc.pool.Get(c)
err = conn.Set(item)
conn.Close()
return
}
// Add writes the given item, if no value already exists for its key.
// ErrNotStored is returned if that condition is not met.
func (mc *Memcache) Add(c context.Context, item *Item) (err error) {
conn := mc.pool.Get(c)
err = conn.Add(item)
conn.Close()
return
}
// Replace writes the given item, but only if the server *does* already hold data for this key.
func (mc *Memcache) Replace(c context.Context, item *Item) (err error) {
conn := mc.pool.Get(c)
err = conn.Replace(item)
conn.Close()
return
}
// CompareAndSwap writes the given item that was previously returned by Get
func (mc *Memcache) CompareAndSwap(c context.Context, item *Item) (err error) {
conn := mc.pool.Get(c)
err = conn.CompareAndSwap(item)
conn.Close()
return
}
// Get sends a command to the server for gets data.
func (mc *Memcache) Get(c context.Context, key string) *Reply {
conn := mc.pool.Get(c)
item, err := conn.Get(key)
if err != nil {
conn.Close()
}
return &Reply{err: err, item: item, conn: conn}
}
// Item get raw Item
func (r *Reply) Item() *Item {
return r.item
}
// Scan converts value, read from the memcache
func (r *Reply) Scan(v interface{}) (err error) {
if r.err != nil {
return r.err
}
err = r.conn.Scan(r.item, v)
if !r.closed {
r.conn.Close()
r.closed = true
}
return
}
// GetMulti is a batch version of Get
func (mc *Memcache) GetMulti(c context.Context, keys []string) (*Replies, error) {
conn := mc.pool.Get(c)
items, err := conn.GetMulti(keys)
rs := &Replies{err: err, items: items, conn: conn, usedItems: make(map[string]struct{}, len(keys))}
if (err != nil) || (len(items) == 0) {
rs.Close()
}
return rs, err
}
// Close close rows.
func (rs *Replies) Close() (err error) {
if !rs.closed {
err = rs.conn.Close()
rs.closed = true
}
return
}
// Item get Item from rows
func (rs *Replies) Item(key string) *Item {
return rs.items[key]
}
// Scan converts value, read from key in rows
func (rs *Replies) Scan(key string, v interface{}) (err error) {
if rs.err != nil {
return rs.err
}
item, ok := rs.items[key]
if !ok {
rs.Close()
return ErrNotFound
}
rs.usedItems[key] = struct{}{}
err = rs.conn.Scan(item, v)
if (err != nil) || (len(rs.items) == len(rs.usedItems)) {
rs.Close()
}
return
}
// Keys keys of result
func (rs *Replies) Keys() (keys []string) {
keys = make([]string, 0, len(rs.items))
for key := range rs.items {
keys = append(keys, key)
}
return
}
// Touch updates the expiry for the given key.
func (mc *Memcache) Touch(c context.Context, key string, timeout int32) (err error) {
conn := mc.pool.Get(c)
err = conn.Touch(key, timeout)
conn.Close()
return
}
// Delete deletes the item with the provided key.
func (mc *Memcache) Delete(c context.Context, key string) (err error) {
conn := mc.pool.Get(c)
err = conn.Delete(key)
conn.Close()
return
}
// Increment atomically increments key by delta.
func (mc *Memcache) Increment(c context.Context, key string, delta uint64) (newValue uint64, err error) {
conn := mc.pool.Get(c)
newValue, err = conn.Increment(key, delta)
conn.Close()
return
}
// Decrement atomically decrements key by delta.
func (mc *Memcache) Decrement(c context.Context, key string, delta uint64) (newValue uint64, err error) {
conn := mc.pool.Get(c)
newValue, err = conn.Decrement(key, delta)
conn.Close()
return
}

View File

@ -1,78 +1,30 @@
package memcache
import (
"bufio"
"bytes"
"compress/gzip"
"context"
"encoding/gob"
"encoding/json"
"fmt"
"io"
"net"
"strconv"
"strings"
"sync"
"time"
"github.com/gogo/protobuf/proto"
pkgerr "github.com/pkg/errors"
)
var (
crlf = []byte("\r\n")
spaceStr = string(" ")
replyOK = []byte("OK\r\n")
replyStored = []byte("STORED\r\n")
replyNotStored = []byte("NOT_STORED\r\n")
replyExists = []byte("EXISTS\r\n")
replyNotFound = []byte("NOT_FOUND\r\n")
replyDeleted = []byte("DELETED\r\n")
replyEnd = []byte("END\r\n")
replyTouched = []byte("TOUCHED\r\n")
replyValueStr = "VALUE"
replyClientErrorPrefix = []byte("CLIENT_ERROR ")
replyServerErrorPrefix = []byte("SERVER_ERROR ")
)
const (
_encodeBuf = 4096 // 4kb
// 1024*1024 - 1, set error???
_largeValue = 1000 * 1000 // 1MB
)
type reader struct {
io.Reader
}
func (r *reader) Reset(rd io.Reader) {
r.Reader = rd
}
// conn is the low-level implementation of Conn
type conn struct {
// Shared
mu sync.Mutex
err error
conn net.Conn
// Read & Write
readTimeout time.Duration
writeTimeout time.Duration
rw *bufio.ReadWriter
// Item Reader
ir bytes.Reader
// Compress
gr gzip.Reader
gw *gzip.Writer
cb bytes.Buffer
// Encoding
edb bytes.Buffer
// json
jr reader
jd *json.Decoder
je *json.Encoder
// protobuffer
ped *proto.Buffer
// low level connection that implement memcache protocol provide basic operation.
type protocolConn interface {
Populate(ctx context.Context, cmd string, key string, flags uint32, expiration int32, cas uint64, data []byte) error
Get(ctx context.Context, key string) (*Item, error)
GetMulti(ctx context.Context, keys ...string) (map[string]*Item, error)
Touch(ctx context.Context, key string, expire int32) error
IncrDecr(ctx context.Context, cmd, key string, delta uint64) (uint64, error)
Delete(ctx context.Context, key string) error
Close() error
Err() error
}
// DialOption specifies an option for dialing a Memcache server.
@ -83,6 +35,7 @@ type DialOption struct {
type dialOptions struct {
readTimeout time.Duration
writeTimeout time.Duration
protocol string
dial func(network, addr string) (net.Conn, error)
}
@ -130,556 +83,205 @@ func Dial(network, address string, options ...DialOption) (Conn, error) {
if err != nil {
return nil, pkgerr.WithStack(err)
}
return NewConn(netConn, do.readTimeout, do.writeTimeout), nil
pconn, err := newASCIIConn(netConn, do.readTimeout, do.writeTimeout)
return &conn{pconn: pconn, ed: newEncodeDecoder()}, nil
}
// NewConn returns a new memcache connection for the given net connection.
func NewConn(netConn net.Conn, readTimeout, writeTimeout time.Duration) Conn {
if writeTimeout <= 0 || readTimeout <= 0 {
panic("must config memcache timeout")
}
c := &conn{
conn: netConn,
rw: bufio.NewReadWriter(bufio.NewReader(netConn),
bufio.NewWriter(netConn)),
readTimeout: readTimeout,
writeTimeout: writeTimeout,
}
c.jd = json.NewDecoder(&c.jr)
c.je = json.NewEncoder(&c.edb)
c.gw = gzip.NewWriter(&c.cb)
c.edb.Grow(_encodeBuf)
// NOTE reuse bytes.Buffer internal buf
// DON'T concurrency call Scan
c.ped = proto.NewBuffer(c.edb.Bytes())
return c
type conn struct {
// low level connection.
pconn protocolConn
ed *encodeDecode
}
func (c *conn) Close() error {
c.mu.Lock()
err := c.err
if c.err == nil {
c.err = pkgerr.New("memcache: closed")
err = c.conn.Close()
}
c.mu.Unlock()
return err
}
func (c *conn) fatal(err error) error {
c.mu.Lock()
if c.err == nil {
c.err = pkgerr.WithStack(err)
// Close connection to force errors on subsequent calls and to unblock
// other reader or writer.
c.conn.Close()
}
c.mu.Unlock()
return c.err
return c.pconn.Close()
}
func (c *conn) Err() error {
c.mu.Lock()
err := c.err
c.mu.Unlock()
return err
return c.pconn.Err()
}
func (c *conn) Add(item *Item) error {
return c.populate("add", item)
func (c *conn) AddContext(ctx context.Context, item *Item) error {
return c.populate(ctx, "add", item)
}
func (c *conn) Set(item *Item) error {
return c.populate("set", item)
func (c *conn) SetContext(ctx context.Context, item *Item) error {
return c.populate(ctx, "set", item)
}
func (c *conn) Replace(item *Item) error {
return c.populate("replace", item)
func (c *conn) ReplaceContext(ctx context.Context, item *Item) error {
return c.populate(ctx, "replace", item)
}
func (c *conn) CompareAndSwap(item *Item) error {
return c.populate("cas", item)
func (c *conn) CompareAndSwapContext(ctx context.Context, item *Item) error {
return c.populate(ctx, "cas", item)
}
func (c *conn) populate(cmd string, item *Item) (err error) {
func (c *conn) populate(ctx context.Context, cmd string, item *Item) error {
if !legalKey(item.Key) {
return pkgerr.WithStack(ErrMalformedKey)
return ErrMalformedKey
}
var res []byte
if res, err = c.encode(item); err != nil {
return
}
l := len(res)
count := l/(_largeValue) + 1
if count == 1 {
item.Value = res
return c.populateOne(cmd, item)
}
nItem := &Item{
Key: item.Key,
Value: []byte(strconv.Itoa(l)),
Expiration: item.Expiration,
Flags: item.Flags | flagLargeValue,
}
err = c.populateOne(cmd, nItem)
data, err := c.ed.encode(item)
if err != nil {
return
return err
}
k := item.Key
nItem.Flags = item.Flags
length := len(data)
if length < _largeValue {
return c.pconn.Populate(ctx, cmd, item.Key, item.Flags, item.Expiration, item.cas, data)
}
count := length/_largeValue + 1
if err = c.pconn.Populate(ctx, cmd, item.Key, item.Flags|flagLargeValue, item.Expiration, item.cas, []byte(strconv.Itoa(length))); err != nil {
return err
}
var chunk []byte
for i := 1; i <= count; i++ {
if i == count {
nItem.Value = res[_largeValue*(count-1):]
chunk = data[_largeValue*(count-1):]
} else {
nItem.Value = res[_largeValue*(i-1) : _largeValue*i]
chunk = data[_largeValue*(i-1) : _largeValue*i]
}
nItem.Key = fmt.Sprintf("%s%d", k, i)
if err = c.populateOne(cmd, nItem); err != nil {
return
key := fmt.Sprintf("%s%d", item.Key, i)
if err = c.pconn.Populate(ctx, cmd, key, item.Flags, item.Expiration, item.cas, chunk); err != nil {
return err
}
}
return
return nil
}
func (c *conn) populateOne(cmd string, item *Item) (err error) {
if c.writeTimeout != 0 {
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
}
// <command name> <key> <flags> <exptime> <bytes> [noreply]\r\n
if cmd == "cas" {
_, err = fmt.Fprintf(c.rw, "%s %s %d %d %d %d\r\n",
cmd, item.Key, item.Flags, item.Expiration, len(item.Value), item.cas)
} else {
_, err = fmt.Fprintf(c.rw, "%s %s %d %d %d\r\n",
cmd, item.Key, item.Flags, item.Expiration, len(item.Value))
}
if err != nil {
return c.fatal(err)
}
c.rw.Write(item.Value)
c.rw.Write(crlf)
if err = c.rw.Flush(); err != nil {
return c.fatal(err)
}
if c.readTimeout != 0 {
c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
}
line, err := c.rw.ReadSlice('\n')
if err != nil {
return c.fatal(err)
}
switch {
case bytes.Equal(line, replyStored):
return nil
case bytes.Equal(line, replyNotStored):
return ErrNotStored
case bytes.Equal(line, replyExists):
return ErrCASConflict
case bytes.Equal(line, replyNotFound):
return ErrNotFound
}
return pkgerr.WithStack(protocolError(string(line)))
}
func (c *conn) Get(key string) (r *Item, err error) {
func (c *conn) GetContext(ctx context.Context, key string) (*Item, error) {
if !legalKey(key) {
return nil, pkgerr.WithStack(ErrMalformedKey)
return nil, ErrMalformedKey
}
if c.writeTimeout != 0 {
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
result, err := c.pconn.Get(ctx, key)
if err != nil {
return nil, err
}
if _, err = fmt.Fprintf(c.rw, "gets %s\r\n", key); err != nil {
return nil, c.fatal(err)
if result.Flags&flagLargeValue != flagLargeValue {
return result, err
}
if err = c.rw.Flush(); err != nil {
return nil, c.fatal(err)
}
if err = c.parseGetReply(func(it *Item) {
r = it
}); err != nil {
return
}
if r == nil {
err = ErrNotFound
return
}
if r.Flags&flagLargeValue != flagLargeValue {
return
}
if r, err = c.getLargeValue(r); err != nil {
return
}
return
return c.getLargeItem(ctx, result)
}
func (c *conn) GetMulti(keys []string) (res map[string]*Item, err error) {
func (c *conn) getLargeItem(ctx context.Context, result *Item) (*Item, error) {
length, err := strconv.Atoi(string(result.Value))
if err != nil {
return nil, err
}
count := length/_largeValue + 1
keys := make([]string, 0, count)
for i := 1; i <= count; i++ {
keys = append(keys, fmt.Sprintf("%s%d", result.Key, i))
}
var results map[string]*Item
if results, err = c.pconn.GetMulti(ctx, keys...); err != nil {
return nil, err
}
if len(results) < count {
return nil, ErrNotFound
}
result.Value = make([]byte, 0, length)
for _, k := range keys {
ti := results[k]
if ti == nil || ti.Value == nil {
return nil, ErrNotFound
}
result.Value = append(result.Value, ti.Value...)
}
result.Flags = result.Flags ^ flagLargeValue
return result, nil
}
func (c *conn) GetMultiContext(ctx context.Context, keys []string) (map[string]*Item, error) {
// TODO: move to protocolConn?
for _, key := range keys {
if !legalKey(key) {
return nil, pkgerr.WithStack(ErrMalformedKey)
return nil, ErrMalformedKey
}
}
if c.writeTimeout != 0 {
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
results, err := c.pconn.GetMulti(ctx, keys...)
if err != nil {
return results, err
}
if _, err = fmt.Fprintf(c.rw, "gets %s\r\n", strings.Join(keys, " ")); err != nil {
return nil, c.fatal(err)
}
if err = c.rw.Flush(); err != nil {
return nil, c.fatal(err)
}
res = make(map[string]*Item, len(keys))
if err = c.parseGetReply(func(it *Item) {
res[it.Key] = it
}); err != nil {
return
}
for k, v := range res {
for k, v := range results {
if v.Flags&flagLargeValue != flagLargeValue {
continue
}
r, err := c.getLargeValue(v)
if err != nil {
return res, err
if v, err = c.getLargeItem(ctx, v); err != nil {
return results, err
}
res[k] = r
results[k] = v
}
return
return results, nil
}
func (c *conn) getMulti(keys []string) (res map[string]*Item, err error) {
for _, key := range keys {
if !legalKey(key) {
return nil, pkgerr.WithStack(ErrMalformedKey)
}
}
if c.writeTimeout != 0 {
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
}
if _, err = fmt.Fprintf(c.rw, "gets %s\r\n", strings.Join(keys, " ")); err != nil {
return nil, c.fatal(err)
}
if err = c.rw.Flush(); err != nil {
return nil, c.fatal(err)
}
res = make(map[string]*Item, len(keys))
err = c.parseGetReply(func(it *Item) {
res[it.Key] = it
})
return
}
func (c *conn) getLargeValue(it *Item) (r *Item, err error) {
l, err := strconv.Atoi(string(it.Value))
if err != nil {
return
}
count := l/_largeValue + 1
keys := make([]string, 0, count)
for i := 1; i <= count; i++ {
keys = append(keys, fmt.Sprintf("%s%d", it.Key, i))
}
items, err := c.getMulti(keys)
if err != nil {
return
}
if len(items) < count {
err = ErrNotFound
return
}
v := make([]byte, 0, l)
for _, k := range keys {
if items[k] == nil || items[k].Value == nil {
err = ErrNotFound
return
}
v = append(v, items[k].Value...)
}
it.Value = v
it.Flags = it.Flags ^ flagLargeValue
r = it
return
}
func (c *conn) parseGetReply(f func(*Item)) error {
if c.readTimeout != 0 {
c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
}
for {
line, err := c.rw.ReadSlice('\n')
if err != nil {
return c.fatal(err)
}
if bytes.Equal(line, replyEnd) {
return nil
}
if bytes.HasPrefix(line, replyServerErrorPrefix) {
errMsg := line[len(replyServerErrorPrefix):]
return c.fatal(protocolError(errMsg))
}
it := new(Item)
size, err := scanGetReply(line, it)
if err != nil {
return c.fatal(err)
}
it.Value = make([]byte, size+2)
if _, err = io.ReadFull(c.rw, it.Value); err != nil {
return c.fatal(err)
}
if !bytes.HasSuffix(it.Value, crlf) {
return c.fatal(protocolError("corrupt get reply, no except CRLF"))
}
it.Value = it.Value[:size]
f(it)
}
}
func scanGetReply(line []byte, item *Item) (size int, err error) {
if !bytes.HasSuffix(line, crlf) {
return 0, protocolError("corrupt get reply, no except CRLF")
}
// VALUE <key> <flags> <bytes> [<cas unique>]
chunks := strings.Split(string(line[:len(line)-2]), spaceStr)
if len(chunks) < 4 {
return 0, protocolError("corrupt get reply")
}
if chunks[0] != replyValueStr {
return 0, protocolError("corrupt get reply, no except VALUE")
}
item.Key = chunks[1]
flags64, err := strconv.ParseUint(chunks[2], 10, 32)
if err != nil {
return 0, err
}
item.Flags = uint32(flags64)
if size, err = strconv.Atoi(chunks[3]); err != nil {
return
}
if len(chunks) > 4 {
item.cas, err = strconv.ParseUint(chunks[4], 10, 64)
}
return
}
func (c *conn) Touch(key string, expire int32) (err error) {
func (c *conn) DeleteContext(ctx context.Context, key string) error {
if !legalKey(key) {
return pkgerr.WithStack(ErrMalformedKey)
}
line, err := c.writeReadLine("touch %s %d\r\n", key, expire)
if err != nil {
return err
}
switch {
case bytes.Equal(line, replyTouched):
return nil
case bytes.Equal(line, replyNotFound):
return ErrNotFound
default:
return pkgerr.WithStack(protocolError(string(line)))
return ErrMalformedKey
}
return c.pconn.Delete(ctx, key)
}
func (c *conn) Increment(key string, delta uint64) (uint64, error) {
return c.incrDecr("incr", key, delta)
func (c *conn) IncrementContext(ctx context.Context, key string, delta uint64) (uint64, error) {
if !legalKey(key) {
return 0, ErrMalformedKey
}
return c.pconn.IncrDecr(ctx, "incr", key, delta)
}
func (c *conn) DecrementContext(ctx context.Context, key string, delta uint64) (uint64, error) {
if !legalKey(key) {
return 0, ErrMalformedKey
}
return c.pconn.IncrDecr(ctx, "decr", key, delta)
}
func (c *conn) TouchContext(ctx context.Context, key string, seconds int32) error {
if !legalKey(key) {
return ErrMalformedKey
}
return c.pconn.Touch(ctx, key, seconds)
}
func (c *conn) Add(item *Item) error {
return c.AddContext(context.TODO(), item)
}
func (c *conn) Set(item *Item) error {
return c.SetContext(context.TODO(), item)
}
func (c *conn) Replace(item *Item) error {
return c.ReplaceContext(context.TODO(), item)
}
func (c *conn) Get(key string) (*Item, error) {
return c.GetContext(context.TODO(), key)
}
func (c *conn) GetMulti(keys []string) (map[string]*Item, error) {
return c.GetMultiContext(context.TODO(), keys)
}
func (c *conn) Delete(key string) error {
return c.DeleteContext(context.TODO(), key)
}
func (c *conn) Increment(key string, delta uint64) (newValue uint64, err error) {
return c.IncrementContext(context.TODO(), key, delta)
}
func (c *conn) Decrement(key string, delta uint64) (newValue uint64, err error) {
return c.incrDecr("decr", key, delta)
return c.DecrementContext(context.TODO(), key, delta)
}
func (c *conn) incrDecr(cmd, key string, delta uint64) (uint64, error) {
if !legalKey(key) {
return 0, pkgerr.WithStack(ErrMalformedKey)
}
line, err := c.writeReadLine("%s %s %d\r\n", cmd, key, delta)
if err != nil {
return 0, err
}
switch {
case bytes.Equal(line, replyNotFound):
return 0, ErrNotFound
case bytes.HasPrefix(line, replyClientErrorPrefix):
errMsg := line[len(replyClientErrorPrefix):]
return 0, pkgerr.WithStack(protocolError(errMsg))
}
val, err := strconv.ParseUint(string(line[:len(line)-2]), 10, 64)
if err != nil {
return 0, err
}
return val, nil
func (c *conn) CompareAndSwap(item *Item) error {
return c.CompareAndSwapContext(context.TODO(), item)
}
func (c *conn) Delete(key string) (err error) {
if !legalKey(key) {
return pkgerr.WithStack(ErrMalformedKey)
}
line, err := c.writeReadLine("delete %s\r\n", key)
if err != nil {
return err
}
switch {
case bytes.Equal(line, replyOK):
return nil
case bytes.Equal(line, replyDeleted):
return nil
case bytes.Equal(line, replyNotStored):
return ErrNotStored
case bytes.Equal(line, replyExists):
return ErrCASConflict
case bytes.Equal(line, replyNotFound):
return ErrNotFound
}
return pkgerr.WithStack(protocolError(string(line)))
}
func (c *conn) writeReadLine(format string, args ...interface{}) ([]byte, error) {
if c.writeTimeout != 0 {
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
}
_, err := fmt.Fprintf(c.rw, format, args...)
if err != nil {
return nil, c.fatal(pkgerr.WithStack(err))
}
if err = c.rw.Flush(); err != nil {
return nil, c.fatal(pkgerr.WithStack(err))
}
if c.readTimeout != 0 {
c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
}
line, err := c.rw.ReadSlice('\n')
if err != nil {
return line, c.fatal(pkgerr.WithStack(err))
}
return line, nil
func (c *conn) Touch(key string, seconds int32) (err error) {
return c.TouchContext(context.TODO(), key, seconds)
}
func (c *conn) Scan(item *Item, v interface{}) (err error) {
c.ir.Reset(item.Value)
if item.Flags&FlagGzip == FlagGzip {
if err = c.gr.Reset(&c.ir); err != nil {
return
}
if err = c.decode(&c.gr, item, v); err != nil {
err = pkgerr.WithStack(err)
return
}
err = c.gr.Close()
} else {
err = c.decode(&c.ir, item, v)
}
err = pkgerr.WithStack(err)
return
}
func (c *conn) WithContext(ctx context.Context) Conn {
// FIXME: implement WithContext
return c
}
func (c *conn) encode(item *Item) (data []byte, err error) {
if (item.Flags | _flagEncoding) == _flagEncoding {
if item.Value == nil {
return nil, ErrItem
}
} else if item.Object == nil {
return nil, ErrItem
}
// encoding
switch {
case item.Flags&FlagGOB == FlagGOB:
c.edb.Reset()
if err = gob.NewEncoder(&c.edb).Encode(item.Object); err != nil {
return
}
data = c.edb.Bytes()
case item.Flags&FlagProtobuf == FlagProtobuf:
c.edb.Reset()
c.ped.SetBuf(c.edb.Bytes())
pb, ok := item.Object.(proto.Message)
if !ok {
err = ErrItemObject
return
}
if err = c.ped.Marshal(pb); err != nil {
return
}
data = c.ped.Bytes()
case item.Flags&FlagJSON == FlagJSON:
c.edb.Reset()
if err = c.je.Encode(item.Object); err != nil {
return
}
data = c.edb.Bytes()
default:
data = item.Value
}
// compress
if item.Flags&FlagGzip == FlagGzip {
c.cb.Reset()
c.gw.Reset(&c.cb)
if _, err = c.gw.Write(data); err != nil {
return
}
if err = c.gw.Close(); err != nil {
return
}
data = c.cb.Bytes()
}
if len(data) > 8000000 {
err = ErrValueSize
}
return
}
func (c *conn) decode(rd io.Reader, item *Item, v interface{}) (err error) {
var data []byte
switch {
case item.Flags&FlagGOB == FlagGOB:
err = gob.NewDecoder(rd).Decode(v)
case item.Flags&FlagJSON == FlagJSON:
c.jr.Reset(rd)
err = c.jd.Decode(v)
default:
data = item.Value
if item.Flags&FlagGzip == FlagGzip {
c.edb.Reset()
if _, err = io.Copy(&c.edb, rd); err != nil {
return
}
data = c.edb.Bytes()
}
if item.Flags&FlagProtobuf == FlagProtobuf {
m, ok := v.(proto.Message)
if !ok {
err = ErrItemObject
return
}
c.ped.SetBuf(data)
err = c.ped.Unmarshal(m)
} else {
switch v.(type) {
case *[]byte:
d := v.(*[]byte)
*d = data
case *string:
d := v.(*string)
*d = string(data)
case interface{}:
err = json.Unmarshal(data, v)
}
}
}
return
}
func legalKey(key string) bool {
if len(key) > 250 || len(key) == 0 {
return false
}
for i := 0; i < len(key); i++ {
if key[i] <= ' ' || key[i] == 0x7f {
return false
}
}
return true
return pkgerr.WithStack(c.ed.decode(item, v))
}

185
pkg/cache/memcache/conn_test.go vendored Normal file
View File

@ -0,0 +1,185 @@
package memcache
import (
"bytes"
"encoding/json"
"testing"
test "github.com/bilibili/kratos/pkg/cache/memcache/test"
"github.com/gogo/protobuf/proto"
)
func TestConnRaw(t *testing.T) {
item := &Item{
Key: "test",
Value: []byte("test"),
Flags: FlagRAW,
Expiration: 60,
cas: 0,
}
if err := testConnASCII.Set(item); err != nil {
t.Errorf("conn.Store() error(%v)", err)
}
}
func TestConnSerialization(t *testing.T) {
type TestObj struct {
Name string
Age int32
}
tests := []struct {
name string
a *Item
e error
}{
{
"JSON",
&Item{
Key: "test_json",
Object: &TestObj{"json", 1},
Expiration: 60,
Flags: FlagJSON,
},
nil,
},
{
"JSONGzip",
&Item{
Key: "test_json_gzip",
Object: &TestObj{"jsongzip", 2},
Expiration: 60,
Flags: FlagJSON | FlagGzip,
},
nil,
},
{
"GOB",
&Item{
Key: "test_gob",
Object: &TestObj{"gob", 3},
Expiration: 60,
Flags: FlagGOB,
},
nil,
},
{
"GOBGzip",
&Item{
Key: "test_gob_gzip",
Object: &TestObj{"gobgzip", 4},
Expiration: 60,
Flags: FlagGOB | FlagGzip,
},
nil,
},
{
"Protobuf",
&Item{
Key: "test_protobuf",
Object: &test.TestItem{Name: "protobuf", Age: 6},
Expiration: 60,
Flags: FlagProtobuf,
},
nil,
},
{
"ProtobufGzip",
&Item{
Key: "test_protobuf_gzip",
Object: &test.TestItem{Name: "protobufgzip", Age: 7},
Expiration: 60,
Flags: FlagProtobuf | FlagGzip,
},
nil,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if err := testConnASCII.Set(tc.a); err != nil {
t.Fatal(err)
}
if r, err := testConnASCII.Get(tc.a.Key); err != tc.e {
t.Fatal(err)
} else {
if (tc.a.Flags & FlagProtobuf) > 0 {
var no test.TestItem
if err := testConnASCII.Scan(r, &no); err != nil {
t.Fatal(err)
}
if (tc.a.Object.(*test.TestItem).Name != no.Name) || (tc.a.Object.(*test.TestItem).Age != no.Age) {
t.Fatalf("compare failed error, %v %v", tc.a.Object.(*test.TestItem), no)
}
} else {
var no TestObj
if err := testConnASCII.Scan(r, &no); err != nil {
t.Fatal(err)
}
if (tc.a.Object.(*TestObj).Name != no.Name) || (tc.a.Object.(*TestObj).Age != no.Age) {
t.Fatalf("compare failed error, %v %v", tc.a.Object.(*TestObj), no)
}
}
}
})
}
}
func BenchmarkConnJSON(b *testing.B) {
st := &struct {
Name string
Age int
}{"json", 10}
itemx := &Item{Key: "json", Object: st, Flags: FlagJSON}
var (
eb bytes.Buffer
je *json.Encoder
ir bytes.Reader
jd *json.Decoder
jr reader
nst test.TestItem
)
jd = json.NewDecoder(&jr)
je = json.NewEncoder(&eb)
eb.Grow(_encodeBuf)
// NOTE reuse bytes.Buffer internal buf
// DON'T concurrency call Scan
b.ResetTimer()
for i := 0; i < b.N; i++ {
eb.Reset()
if err := je.Encode(itemx.Object); err != nil {
return
}
data := eb.Bytes()
ir.Reset(data)
jr.Reset(&ir)
jd.Decode(&nst)
}
}
func BenchmarkConnProtobuf(b *testing.B) {
st := &test.TestItem{Name: "protobuf", Age: 10}
itemx := &Item{Key: "protobuf", Object: st, Flags: FlagJSON}
var (
eb bytes.Buffer
nst test.TestItem
ped *proto.Buffer
)
ped = proto.NewBuffer(eb.Bytes())
eb.Grow(_encodeBuf)
b.ResetTimer()
for i := 0; i < b.N; i++ {
ped.Reset()
pb, ok := itemx.Object.(proto.Message)
if !ok {
return
}
if err := ped.Marshal(pb); err != nil {
return
}
data := ped.Bytes()
ped.SetBuf(data)
ped.Unmarshal(&nst)
}
}

162
pkg/cache/memcache/encoding.go vendored Normal file
View File

@ -0,0 +1,162 @@
package memcache
import (
"bytes"
"compress/gzip"
"encoding/gob"
"encoding/json"
"io"
"github.com/gogo/protobuf/proto"
)
type reader struct {
io.Reader
}
func (r *reader) Reset(rd io.Reader) {
r.Reader = rd
}
const _encodeBuf = 4096 // 4kb
type encodeDecode struct {
// Item Reader
ir bytes.Reader
// Compress
gr gzip.Reader
gw *gzip.Writer
cb bytes.Buffer
// Encoding
edb bytes.Buffer
// json
jr reader
jd *json.Decoder
je *json.Encoder
// protobuffer
ped *proto.Buffer
}
func newEncodeDecoder() *encodeDecode {
ed := &encodeDecode{}
ed.jd = json.NewDecoder(&ed.jr)
ed.je = json.NewEncoder(&ed.edb)
ed.gw = gzip.NewWriter(&ed.cb)
ed.edb.Grow(_encodeBuf)
// NOTE reuse bytes.Buffer internal buf
// DON'T concurrency call Scan
ed.ped = proto.NewBuffer(ed.edb.Bytes())
return ed
}
func (ed *encodeDecode) encode(item *Item) (data []byte, err error) {
if (item.Flags | _flagEncoding) == _flagEncoding {
if item.Value == nil {
return nil, ErrItem
}
} else if item.Object == nil {
return nil, ErrItem
}
// encoding
switch {
case item.Flags&FlagGOB == FlagGOB:
ed.edb.Reset()
if err = gob.NewEncoder(&ed.edb).Encode(item.Object); err != nil {
return
}
data = ed.edb.Bytes()
case item.Flags&FlagProtobuf == FlagProtobuf:
ed.edb.Reset()
ed.ped.SetBuf(ed.edb.Bytes())
pb, ok := item.Object.(proto.Message)
if !ok {
err = ErrItemObject
return
}
if err = ed.ped.Marshal(pb); err != nil {
return
}
data = ed.ped.Bytes()
case item.Flags&FlagJSON == FlagJSON:
ed.edb.Reset()
if err = ed.je.Encode(item.Object); err != nil {
return
}
data = ed.edb.Bytes()
default:
data = item.Value
}
// compress
if item.Flags&FlagGzip == FlagGzip {
ed.cb.Reset()
ed.gw.Reset(&ed.cb)
if _, err = ed.gw.Write(data); err != nil {
return
}
if err = ed.gw.Close(); err != nil {
return
}
data = ed.cb.Bytes()
}
if len(data) > 8000000 {
err = ErrValueSize
}
return
}
func (ed *encodeDecode) decode(item *Item, v interface{}) (err error) {
var (
data []byte
rd io.Reader
)
ed.ir.Reset(item.Value)
rd = &ed.ir
if item.Flags&FlagGzip == FlagGzip {
rd = &ed.gr
if err = ed.gr.Reset(&ed.ir); err != nil {
return
}
defer func() {
if e := ed.gr.Close(); e != nil {
err = e
}
}()
}
switch {
case item.Flags&FlagGOB == FlagGOB:
err = gob.NewDecoder(rd).Decode(v)
case item.Flags&FlagJSON == FlagJSON:
ed.jr.Reset(rd)
err = ed.jd.Decode(v)
default:
data = item.Value
if item.Flags&FlagGzip == FlagGzip {
ed.edb.Reset()
if _, err = io.Copy(&ed.edb, rd); err != nil {
return
}
data = ed.edb.Bytes()
}
if item.Flags&FlagProtobuf == FlagProtobuf {
m, ok := v.(proto.Message)
if !ok {
err = ErrItemObject
return
}
ed.ped.SetBuf(data)
err = ed.ped.Unmarshal(m)
} else {
switch v.(type) {
case *[]byte:
d := v.(*[]byte)
*d = data
case *string:
d := v.(*string)
*d = string(data)
case interface{}:
err = json.Unmarshal(data, v)
}
}
}
return
}

220
pkg/cache/memcache/encoding_test.go vendored Normal file
View File

@ -0,0 +1,220 @@
package memcache
import (
"bytes"
"testing"
mt "github.com/bilibili/kratos/pkg/cache/memcache/test"
)
func TestEncode(t *testing.T) {
type TestObj struct {
Name string
Age int32
}
testObj := TestObj{"abc", 1}
ed := newEncodeDecoder()
tests := []struct {
name string
a *Item
r []byte
e error
}{
{
"EncodeRawFlagErrItem",
&Item{
Object: &TestObj{"abc", 1},
Flags: FlagRAW,
},
[]byte{},
ErrItem,
},
{
"EncodeEncodeFlagErrItem",
&Item{
Value: []byte("test"),
Flags: FlagJSON,
},
[]byte{},
ErrItem,
},
{
"EncodeEmpty",
&Item{
Value: []byte(""),
Flags: FlagRAW,
},
[]byte(""),
nil,
},
{
"EncodeMaxSize",
&Item{
Value: bytes.Repeat([]byte("A"), 8000000),
Flags: FlagRAW,
},
bytes.Repeat([]byte("A"), 8000000),
nil,
},
{
"EncodeExceededMaxSize",
&Item{
Value: bytes.Repeat([]byte("A"), 8000000+1),
Flags: FlagRAW,
},
nil,
ErrValueSize,
},
{
"EncodeGOB",
&Item{
Object: testObj,
Flags: FlagGOB,
},
[]byte{38, 255, 131, 3, 1, 1, 7, 84, 101, 115, 116, 79, 98, 106, 1, 255, 132, 0, 1, 2, 1, 4, 78, 97, 109, 101, 1, 12, 0, 1, 3, 65, 103, 101, 1, 4, 0, 0, 0, 10, 255, 132, 1, 3, 97, 98, 99, 1, 2, 0},
nil,
},
{
"EncodeJSON",
&Item{
Object: testObj,
Flags: FlagJSON,
},
[]byte{123, 34, 78, 97, 109, 101, 34, 58, 34, 97, 98, 99, 34, 44, 34, 65, 103, 101, 34, 58, 49, 125, 10},
nil,
},
{
"EncodeProtobuf",
&Item{
Object: &mt.TestItem{Name: "abc", Age: 1},
Flags: FlagProtobuf,
},
[]byte{10, 3, 97, 98, 99, 16, 1},
nil,
},
{
"EncodeGzip",
&Item{
Value: bytes.Repeat([]byte("B"), 50),
Flags: FlagGzip,
},
[]byte{31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 114, 34, 25, 0, 2, 0, 0, 255, 255, 252, 253, 67, 209, 50, 0, 0, 0},
nil,
},
{
"EncodeGOBGzip",
&Item{
Object: testObj,
Flags: FlagGOB | FlagGzip,
},
[]byte{31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 82, 251, 223, 204, 204, 200, 200, 30, 146, 90, 92, 226, 159, 148, 197, 248, 191, 133, 129, 145, 137, 145, 197, 47, 49, 55, 149, 145, 135, 129, 145, 217, 49, 61, 149, 145, 133, 129, 129, 129, 235, 127, 11, 35, 115, 98, 82, 50, 35, 19, 3, 32, 0, 0, 255, 255, 211, 249, 1, 154, 50, 0, 0, 0},
nil,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
if r, err := ed.encode(test.a); err != test.e {
t.Fatal(err)
} else {
if err == nil {
if !bytes.Equal(r, test.r) {
t.Fatalf("not equal, expect %v\n got %v", test.r, r)
}
}
}
})
}
}
func TestDecode(t *testing.T) {
type TestObj struct {
Name string
Age int32
}
testObj := &TestObj{"abc", 1}
ed := newEncodeDecoder()
tests := []struct {
name string
a *Item
r interface{}
e error
}{
{
"DecodeGOB",
&Item{
Flags: FlagGOB,
Value: []byte{38, 255, 131, 3, 1, 1, 7, 84, 101, 115, 116, 79, 98, 106, 1, 255, 132, 0, 1, 2, 1, 4, 78, 97, 109, 101, 1, 12, 0, 1, 3, 65, 103, 101, 1, 4, 0, 0, 0, 10, 255, 132, 1, 3, 97, 98, 99, 1, 2, 0},
},
testObj,
nil,
},
{
"DecodeJSON",
&Item{
Value: []byte{123, 34, 78, 97, 109, 101, 34, 58, 34, 97, 98, 99, 34, 44, 34, 65, 103, 101, 34, 58, 49, 125, 10},
Flags: FlagJSON,
},
testObj,
nil,
},
{
"DecodeProtobuf",
&Item{
Value: []byte{10, 3, 97, 98, 99, 16, 1},
Flags: FlagProtobuf,
},
&mt.TestItem{Name: "abc", Age: 1},
nil,
},
{
"DecodeGzip",
&Item{
Value: []byte{31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 114, 34, 25, 0, 2, 0, 0, 255, 255, 252, 253, 67, 209, 50, 0, 0, 0},
Flags: FlagGzip,
},
bytes.Repeat([]byte("B"), 50),
nil,
},
{
"DecodeGOBGzip",
&Item{
Value: []byte{31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 82, 251, 223, 204, 204, 200, 200, 30, 146, 90, 92, 226, 159, 148, 197, 248, 191, 133, 129, 145, 137, 145, 197, 47, 49, 55, 149, 145, 135, 129, 145, 217, 49, 61, 149, 145, 133, 129, 129, 129, 235, 127, 11, 35, 115, 98, 82, 50, 35, 19, 3, 32, 0, 0, 255, 255, 211, 249, 1, 154, 50, 0, 0, 0},
Flags: FlagGOB | FlagGzip,
},
testObj,
nil,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
if (test.a.Flags & FlagProtobuf) > 0 {
var dd mt.TestItem
if err := ed.decode(test.a, &dd); err != nil {
t.Fatal(err)
}
if (test.r.(*mt.TestItem).Name != dd.Name) || (test.r.(*mt.TestItem).Age != dd.Age) {
t.Fatalf("compare failed error, expect %v\n got %v", test.r.(*mt.TestItem), dd)
}
} else if test.a.Flags == FlagGzip {
var dd []byte
if err := ed.decode(test.a, &dd); err != nil {
t.Fatal(err)
}
if !bytes.Equal(dd, test.r.([]byte)) {
t.Fatalf("compare failed error, expect %v\n got %v", test.r, dd)
}
} else {
var dd TestObj
if err := ed.decode(test.a, &dd); err != nil {
t.Fatal(err)
}
if (test.r.(*TestObj).Name != dd.Name) || (test.r.(*TestObj).Age != dd.Age) {
t.Fatalf("compare failed error, expect %v\n got %v", test.r.(*TestObj), dd)
}
}
})
}
}

177
pkg/cache/memcache/example_test.go vendored Normal file
View File

@ -0,0 +1,177 @@
package memcache
import (
"encoding/json"
"fmt"
"time"
)
var testExampleAddr string
func ExampleConn_set() {
var (
err error
value []byte
conn Conn
expire int32 = 100
p = struct {
Name string
Age int64
}{"golang", 10}
)
cnop := DialConnectTimeout(time.Duration(time.Second))
rdop := DialReadTimeout(time.Duration(time.Second))
wrop := DialWriteTimeout(time.Duration(time.Second))
if value, err = json.Marshal(p); err != nil {
fmt.Println(err)
return
}
if conn, err = Dial("tcp", testExampleAddr, cnop, rdop, wrop); err != nil {
fmt.Println(err)
return
}
// FlagRAW test
itemRaw := &Item{
Key: "test_raw",
Value: value,
Expiration: expire,
}
if err = conn.Set(itemRaw); err != nil {
fmt.Println(err)
return
}
// FlagGzip
itemGZip := &Item{
Key: "test_gzip",
Value: value,
Flags: FlagGzip,
Expiration: expire,
}
if err = conn.Set(itemGZip); err != nil {
fmt.Println(err)
return
}
// FlagGOB
itemGOB := &Item{
Key: "test_gob",
Object: p,
Flags: FlagGOB,
Expiration: expire,
}
if err = conn.Set(itemGOB); err != nil {
fmt.Println(err)
return
}
// FlagJSON
itemJSON := &Item{
Key: "test_json",
Object: p,
Flags: FlagJSON,
Expiration: expire,
}
if err = conn.Set(itemJSON); err != nil {
fmt.Println(err)
return
}
// FlagJSON | FlagGzip
itemJSONGzip := &Item{
Key: "test_jsonGzip",
Object: p,
Flags: FlagJSON | FlagGzip,
Expiration: expire,
}
if err = conn.Set(itemJSONGzip); err != nil {
fmt.Println(err)
return
}
// Output:
}
func ExampleConn_get() {
var (
err error
item2 *Item
conn Conn
p struct {
Name string
Age int64
}
)
cnop := DialConnectTimeout(time.Duration(time.Second))
rdop := DialReadTimeout(time.Duration(time.Second))
wrop := DialWriteTimeout(time.Duration(time.Second))
if conn, err = Dial("tcp", testExampleAddr, cnop, rdop, wrop); err != nil {
fmt.Println(err)
return
}
if item2, err = conn.Get("test_raw"); err != nil {
fmt.Println(err)
} else {
if err = conn.Scan(item2, &p); err != nil {
fmt.Printf("FlagRAW conn.Scan error(%v)\n", err)
return
}
}
// FlagGZip
if item2, err = conn.Get("test_gzip"); err != nil {
fmt.Println(err)
} else {
if err = conn.Scan(item2, &p); err != nil {
fmt.Printf("FlagGZip conn.Scan error(%v)\n", err)
return
}
}
// FlagGOB
if item2, err = conn.Get("test_gob"); err != nil {
fmt.Println(err)
} else {
if err = conn.Scan(item2, &p); err != nil {
fmt.Printf("FlagGOB conn.Scan error(%v)\n", err)
return
}
}
// FlagJSON
if item2, err = conn.Get("test_json"); err != nil {
fmt.Println(err)
} else {
if err = conn.Scan(item2, &p); err != nil {
fmt.Printf("FlagJSON conn.Scan error(%v)\n", err)
return
}
}
// Output:
}
func ExampleConn_getMulti() {
var (
err error
conn Conn
res map[string]*Item
keys = []string{"test_raw", "test_gzip"}
p struct {
Name string
Age int64
}
)
cnop := DialConnectTimeout(time.Duration(time.Second))
rdop := DialReadTimeout(time.Duration(time.Second))
wrop := DialWriteTimeout(time.Duration(time.Second))
if conn, err = Dial("tcp", testExampleAddr, cnop, rdop, wrop); err != nil {
fmt.Println(err)
return
}
if res, err = conn.GetMulti(keys); err != nil {
fmt.Printf("conn.GetMulti(%v) error(%v)", keys, err)
return
}
for _, v := range res {
if err = conn.Scan(v, &p); err != nil {
fmt.Printf("conn.Scan error(%v)\n", err)
return
}
fmt.Println(p)
}
// Output:
//{golang 10}
//{golang 10}
}

85
pkg/cache/memcache/main_test.go vendored Normal file
View File

@ -0,0 +1,85 @@
package memcache
import (
"log"
"os"
"testing"
"time"
"github.com/bilibili/kratos/pkg/container/pool"
xtime "github.com/bilibili/kratos/pkg/time"
)
var testConnASCII Conn
var testMemcache *Memcache
var testPool *Pool
var testMemcacheAddr string
func setupTestConnASCII(addr string) {
var err error
cnop := DialConnectTimeout(time.Duration(2 * time.Second))
rdop := DialReadTimeout(time.Duration(2 * time.Second))
wrop := DialWriteTimeout(time.Duration(2 * time.Second))
testConnASCII, err = Dial("tcp", addr, cnop, rdop, wrop)
if err != nil {
log.Fatal(err)
}
testConnASCII.Delete("test")
testConnASCII.Delete("test1")
testConnASCII.Delete("test2")
if err != nil {
log.Fatal(err)
}
}
func setupTestMemcache(addr string) {
testConfig := &Config{
Config: &pool.Config{
Active: 10,
Idle: 10,
IdleTimeout: xtime.Duration(time.Second),
WaitTimeout: xtime.Duration(time.Second),
Wait: false,
},
Addr: addr,
Proto: "tcp",
DialTimeout: xtime.Duration(time.Second),
ReadTimeout: xtime.Duration(time.Second),
WriteTimeout: xtime.Duration(time.Second),
}
testMemcache = New(testConfig)
}
func setupTestPool(addr string) {
config := &Config{
Name: "test",
Proto: "tcp",
Addr: addr,
DialTimeout: xtime.Duration(time.Second),
ReadTimeout: xtime.Duration(time.Second),
WriteTimeout: xtime.Duration(time.Second),
}
config.Config = &pool.Config{
Active: 10,
Idle: 5,
IdleTimeout: xtime.Duration(90 * time.Second),
}
testPool = NewPool(config)
}
func TestMain(m *testing.M) {
testMemcacheAddr = os.Getenv("TEST_MEMCACHE_ADDR")
if testExampleAddr == "" {
log.Print("TEST_MEMCACHE_ADDR not provide skip test.")
// ignored test.
os.Exit(0)
}
setupTestConnASCII(testMemcacheAddr)
setupTestMemcache(testMemcacheAddr)
setupTestPool(testMemcacheAddr)
// TODO: add setupexample?
testExampleAddr = testMemcacheAddr
ret := m.Run()
os.Exit(ret)
}

View File

@ -2,13 +2,11 @@ package memcache
import (
"context"
"github.com/bilibili/kratos/pkg/container/pool"
xtime "github.com/bilibili/kratos/pkg/time"
)
// Error represents an error returned in a command reply.
type Error string
func (err Error) Error() string { return string(err) }
const (
// Flag, 15(encoding) bit+ 17(compress) bit
@ -87,20 +85,20 @@ type Conn interface {
GetMulti(keys []string) (map[string]*Item, error)
// Delete deletes the item with the provided key.
// The error ErrCacheMiss is returned if the item didn't already exist in
// The error ErrNotFound is returned if the item didn't already exist in
// the cache.
Delete(key string) error
// Increment atomically increments key by delta. The return value is the
// new value after being incremented or an error. If the value didn't exist
// in memcached the error is ErrCacheMiss. The value in memcached must be
// in memcached the error is ErrNotFound. The value in memcached must be
// an decimal number, or an error will be returned.
// On 64-bit overflow, the new value wraps around.
Increment(key string, delta uint64) (newValue uint64, err error)
// Decrement atomically decrements key by delta. The return value is the
// new value after being decremented or an error. If the value didn't exist
// in memcached the error is ErrCacheMiss. The value in memcached must be
// in memcached the error is ErrNotFound. The value in memcached must be
// an decimal number, or an error will be returned. On underflow, the new
// value is capped at zero and does not wrap around.
Decrement(key string, delta uint64) (newValue uint64, err error)
@ -116,7 +114,7 @@ type Conn interface {
// Touch updates the expiry for the given key. The seconds parameter is
// either a Unix timestamp or, if seconds is less than 1 month, the number
// of seconds into the future at which time the item will expire.
//ErrCacheMiss is returned if the key is not in the cache. The key must be
// ErrNotFound is returned if the key is not in the cache. The key must be
// at most 250 bytes in length.
Touch(key string, seconds int32) (err error)
@ -129,8 +127,251 @@ type Conn interface {
//
Scan(item *Item, v interface{}) (err error)
// WithContext return a Conn with its context changed to ctx
// the context controls the entire lifetime of Conn before you change it
// NOTE: this method is not thread-safe
WithContext(ctx context.Context) Conn
// Add writes the given item, if no value already exists for its key.
// ErrNotStored is returned if that condition is not met.
AddContext(ctx context.Context, item *Item) error
// Set writes the given item, unconditionally.
SetContext(ctx context.Context, item *Item) error
// Replace writes the given item, but only if the server *does* already
// hold data for this key.
ReplaceContext(ctx context.Context, item *Item) error
// Get sends a command to the server for gets data.
GetContext(ctx context.Context, key string) (*Item, error)
// GetMulti is a batch version of Get. The returned map from keys to items
// may have fewer elements than the input slice, due to memcache cache
// misses. Each key must be at most 250 bytes in length.
// If no error is returned, the returned map will also be non-nil.
GetMultiContext(ctx context.Context, keys []string) (map[string]*Item, error)
// Delete deletes the item with the provided key.
// The error ErrNotFound is returned if the item didn't already exist in
// the cache.
DeleteContext(ctx context.Context, key string) error
// Increment atomically increments key by delta. The return value is the
// new value after being incremented or an error. If the value didn't exist
// in memcached the error is ErrNotFound. The value in memcached must be
// an decimal number, or an error will be returned.
// On 64-bit overflow, the new value wraps around.
IncrementContext(ctx context.Context, key string, delta uint64) (newValue uint64, err error)
// Decrement atomically decrements key by delta. The return value is the
// new value after being decremented or an error. If the value didn't exist
// in memcached the error is ErrNotFound. The value in memcached must be
// an decimal number, or an error will be returned. On underflow, the new
// value is capped at zero and does not wrap around.
DecrementContext(ctx context.Context, key string, delta uint64) (newValue uint64, err error)
// CompareAndSwap writes the given item that was previously returned by
// Get, if the value was neither modified or evicted between the Get and
// the CompareAndSwap calls. The item's Key should not change between calls
// but all other item fields may differ. ErrCASConflict is returned if the
// value was modified in between the calls.
// ErrNotStored is returned if the value was evicted in between the calls.
CompareAndSwapContext(ctx context.Context, item *Item) error
// Touch updates the expiry for the given key. The seconds parameter is
// either a Unix timestamp or, if seconds is less than 1 month, the number
// of seconds into the future at which time the item will expire.
// ErrNotFound is returned if the key is not in the cache. The key must be
// at most 250 bytes in length.
TouchContext(ctx context.Context, key string, seconds int32) (err error)
}
// Config memcache config.
type Config struct {
*pool.Config
Name string // memcache name, for trace
Proto string
Addr string
DialTimeout xtime.Duration
ReadTimeout xtime.Duration
WriteTimeout xtime.Duration
}
// Memcache memcache client
type Memcache struct {
pool *Pool
}
// Reply is the result of Get
type Reply struct {
err error
item *Item
conn Conn
closed bool
}
// Replies is the result of GetMulti
type Replies struct {
err error
items map[string]*Item
usedItems map[string]struct{}
conn Conn
closed bool
}
// New get a memcache client
func New(cfg *Config) *Memcache {
return &Memcache{pool: NewPool(cfg)}
}
// Close close connection pool
func (mc *Memcache) Close() error {
return mc.pool.Close()
}
// Conn direct get a connection
func (mc *Memcache) Conn(ctx context.Context) Conn {
return mc.pool.Get(ctx)
}
// Set writes the given item, unconditionally.
func (mc *Memcache) Set(ctx context.Context, item *Item) (err error) {
conn := mc.pool.Get(ctx)
err = conn.SetContext(ctx, item)
conn.Close()
return
}
// Add writes the given item, if no value already exists for its key.
// ErrNotStored is returned if that condition is not met.
func (mc *Memcache) Add(ctx context.Context, item *Item) (err error) {
conn := mc.pool.Get(ctx)
err = conn.AddContext(ctx, item)
conn.Close()
return
}
// Replace writes the given item, but only if the server *does* already hold data for this key.
func (mc *Memcache) Replace(ctx context.Context, item *Item) (err error) {
conn := mc.pool.Get(ctx)
err = conn.ReplaceContext(ctx, item)
conn.Close()
return
}
// CompareAndSwap writes the given item that was previously returned by Get
func (mc *Memcache) CompareAndSwap(ctx context.Context, item *Item) (err error) {
conn := mc.pool.Get(ctx)
err = conn.CompareAndSwapContext(ctx, item)
conn.Close()
return
}
// Get sends a command to the server for gets data.
func (mc *Memcache) Get(ctx context.Context, key string) *Reply {
conn := mc.pool.Get(ctx)
item, err := conn.GetContext(ctx, key)
if err != nil {
conn.Close()
}
return &Reply{err: err, item: item, conn: conn}
}
// Item get raw Item
func (r *Reply) Item() *Item {
return r.item
}
// Scan converts value, read from the memcache
func (r *Reply) Scan(v interface{}) (err error) {
if r.err != nil {
return r.err
}
err = r.conn.Scan(r.item, v)
if !r.closed {
r.conn.Close()
r.closed = true
}
return
}
// GetMulti is a batch version of Get
func (mc *Memcache) GetMulti(ctx context.Context, keys []string) (*Replies, error) {
conn := mc.pool.Get(ctx)
items, err := conn.GetMultiContext(ctx, keys)
rs := &Replies{err: err, items: items, conn: conn, usedItems: make(map[string]struct{}, len(keys))}
if (err != nil) || (len(items) == 0) {
rs.Close()
}
return rs, err
}
// Close close rows.
func (rs *Replies) Close() (err error) {
if !rs.closed {
err = rs.conn.Close()
rs.closed = true
}
return
}
// Item get Item from rows
func (rs *Replies) Item(key string) *Item {
return rs.items[key]
}
// Scan converts value, read from key in rows
func (rs *Replies) Scan(key string, v interface{}) (err error) {
if rs.err != nil {
return rs.err
}
item, ok := rs.items[key]
if !ok {
rs.Close()
return ErrNotFound
}
rs.usedItems[key] = struct{}{}
err = rs.conn.Scan(item, v)
if (err != nil) || (len(rs.items) == len(rs.usedItems)) {
rs.Close()
}
return
}
// Keys keys of result
func (rs *Replies) Keys() (keys []string) {
keys = make([]string, 0, len(rs.items))
for key := range rs.items {
keys = append(keys, key)
}
return
}
// Touch updates the expiry for the given key.
func (mc *Memcache) Touch(ctx context.Context, key string, timeout int32) (err error) {
conn := mc.pool.Get(ctx)
err = conn.TouchContext(ctx, key, timeout)
conn.Close()
return
}
// Delete deletes the item with the provided key.
func (mc *Memcache) Delete(ctx context.Context, key string) (err error) {
conn := mc.pool.Get(ctx)
err = conn.DeleteContext(ctx, key)
conn.Close()
return
}
// Increment atomically increments key by delta.
func (mc *Memcache) Increment(ctx context.Context, key string, delta uint64) (newValue uint64, err error) {
conn := mc.pool.Get(ctx)
newValue, err = conn.IncrementContext(ctx, key, delta)
conn.Close()
return
}
// Decrement atomically decrements key by delta.
func (mc *Memcache) Decrement(ctx context.Context, key string, delta uint64) (newValue uint64, err error) {
conn := mc.pool.Get(ctx)
newValue, err = conn.DecrementContext(ctx, key, delta)
conn.Close()
return
}

300
pkg/cache/memcache/memcache_test.go vendored Normal file
View File

@ -0,0 +1,300 @@
package memcache
import (
"context"
"fmt"
"reflect"
"testing"
"time"
)
func Test_client_Set(t *testing.T) {
type args struct {
c context.Context
item *Item
}
tests := []struct {
name string
args args
wantErr bool
}{
{name: "set value", args: args{c: context.Background(), item: &Item{Key: "Test_client_Set", Value: []byte("abc")}}, wantErr: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := testMemcache.Set(tt.args.c, tt.args.item); (err != nil) != tt.wantErr {
t.Errorf("client.Set() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func Test_client_Add(t *testing.T) {
type args struct {
c context.Context
item *Item
}
key := fmt.Sprintf("Test_client_Add_%d", time.Now().Unix())
tests := []struct {
name string
args args
wantErr bool
}{
{name: "add not exist value", args: args{c: context.Background(), item: &Item{Key: key, Value: []byte("abc")}}, wantErr: false},
{name: "add exist value", args: args{c: context.Background(), item: &Item{Key: key, Value: []byte("abc")}}, wantErr: true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := testMemcache.Add(tt.args.c, tt.args.item); (err != nil) != tt.wantErr {
t.Errorf("client.Add() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func Test_client_Replace(t *testing.T) {
key := fmt.Sprintf("Test_client_Replace_%d", time.Now().Unix())
ekey := "Test_client_Replace_exist"
testMemcache.Set(context.Background(), &Item{Key: ekey, Value: []byte("ok")})
type args struct {
c context.Context
item *Item
}
tests := []struct {
name string
args args
wantErr bool
}{
{name: "not exist value", args: args{c: context.Background(), item: &Item{Key: key, Value: []byte("abc")}}, wantErr: true},
{name: "exist value", args: args{c: context.Background(), item: &Item{Key: ekey, Value: []byte("abc")}}, wantErr: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := testMemcache.Replace(tt.args.c, tt.args.item); (err != nil) != tt.wantErr {
t.Errorf("client.Replace() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func Test_client_CompareAndSwap(t *testing.T) {
key := fmt.Sprintf("Test_client_CompareAndSwap_%d", time.Now().Unix())
ekey := "Test_client_CompareAndSwap_k"
testMemcache.Set(context.Background(), &Item{Key: ekey, Value: []byte("old")})
cas := testMemcache.Get(context.Background(), ekey).Item().cas
type args struct {
c context.Context
item *Item
}
tests := []struct {
name string
args args
wantErr bool
}{
{name: "not exist value", args: args{c: context.Background(), item: &Item{Key: key, Value: []byte("abc")}}, wantErr: true},
{name: "exist value", args: args{c: context.Background(), item: &Item{Key: ekey, cas: cas, Value: []byte("abc")}}, wantErr: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := testMemcache.CompareAndSwap(tt.args.c, tt.args.item); (err != nil) != tt.wantErr {
t.Errorf("client.CompareAndSwap() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func Test_client_Get(t *testing.T) {
key := fmt.Sprintf("Test_client_Get_%d", time.Now().Unix())
ekey := "Test_client_Get_k"
testMemcache.Set(context.Background(), &Item{Key: ekey, Value: []byte("old")})
type args struct {
c context.Context
key string
}
tests := []struct {
name string
args args
want string
wantErr bool
}{
{name: "not exist value", args: args{c: context.Background(), key: key}, wantErr: true},
{name: "exist value", args: args{c: context.Background(), key: ekey}, wantErr: false, want: "old"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var res string
if err := testMemcache.Get(tt.args.c, tt.args.key).Scan(&res); (err != nil) != tt.wantErr || res != tt.want {
t.Errorf("client.Get() = %v, want %v, got err: %v, want err: %v", err, tt.want, err, tt.wantErr)
}
})
}
}
func Test_client_Touch(t *testing.T) {
key := fmt.Sprintf("Test_client_Touch_%d", time.Now().Unix())
ekey := "Test_client_Touch_k"
testMemcache.Set(context.Background(), &Item{Key: ekey, Value: []byte("old")})
type args struct {
c context.Context
key string
timeout int32
}
tests := []struct {
name string
args args
wantErr bool
}{
{name: "not exist value", args: args{c: context.Background(), key: key, timeout: 100000}, wantErr: true},
{name: "exist value", args: args{c: context.Background(), key: ekey, timeout: 100000}, wantErr: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := testMemcache.Touch(tt.args.c, tt.args.key, tt.args.timeout); (err != nil) != tt.wantErr {
t.Errorf("client.Touch() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func Test_client_Delete(t *testing.T) {
key := fmt.Sprintf("Test_client_Delete_%d", time.Now().Unix())
ekey := "Test_client_Delete_k"
testMemcache.Set(context.Background(), &Item{Key: ekey, Value: []byte("old")})
type args struct {
c context.Context
key string
}
tests := []struct {
name string
args args
wantErr bool
}{
{name: "not exist value", args: args{c: context.Background(), key: key}, wantErr: true},
{name: "exist value", args: args{c: context.Background(), key: ekey}, wantErr: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := testMemcache.Delete(tt.args.c, tt.args.key); (err != nil) != tt.wantErr {
t.Errorf("client.Delete() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func Test_client_Increment(t *testing.T) {
key := fmt.Sprintf("Test_client_Increment_%d", time.Now().Unix())
ekey := "Test_client_Increment_k"
testMemcache.Set(context.Background(), &Item{Key: ekey, Value: []byte("1")})
type args struct {
c context.Context
key string
delta uint64
}
tests := []struct {
name string
args args
wantNewValue uint64
wantErr bool
}{
{name: "not exist value", args: args{c: context.Background(), key: key, delta: 10}, wantErr: true},
{name: "exist value", args: args{c: context.Background(), key: ekey, delta: 10}, wantErr: false, wantNewValue: 11},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotNewValue, err := testMemcache.Increment(tt.args.c, tt.args.key, tt.args.delta)
if (err != nil) != tt.wantErr {
t.Errorf("client.Increment() error = %v, wantErr %v", err, tt.wantErr)
return
}
if gotNewValue != tt.wantNewValue {
t.Errorf("client.Increment() = %v, want %v", gotNewValue, tt.wantNewValue)
}
})
}
}
func Test_client_Decrement(t *testing.T) {
key := fmt.Sprintf("Test_client_Decrement_%d", time.Now().Unix())
ekey := "Test_client_Decrement_k"
testMemcache.Set(context.Background(), &Item{Key: ekey, Value: []byte("100")})
type args struct {
c context.Context
key string
delta uint64
}
tests := []struct {
name string
args args
wantNewValue uint64
wantErr bool
}{
{name: "not exist value", args: args{c: context.Background(), key: key, delta: 10}, wantErr: true},
{name: "exist value", args: args{c: context.Background(), key: ekey, delta: 10}, wantErr: false, wantNewValue: 90},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotNewValue, err := testMemcache.Decrement(tt.args.c, tt.args.key, tt.args.delta)
if (err != nil) != tt.wantErr {
t.Errorf("client.Decrement() error = %v, wantErr %v", err, tt.wantErr)
return
}
if gotNewValue != tt.wantNewValue {
t.Errorf("client.Decrement() = %v, want %v", gotNewValue, tt.wantNewValue)
}
})
}
}
func Test_client_GetMulti(t *testing.T) {
key := fmt.Sprintf("Test_client_GetMulti_%d", time.Now().Unix())
ekey1 := "Test_client_GetMulti_k1"
ekey2 := "Test_client_GetMulti_k2"
testMemcache.Set(context.Background(), &Item{Key: ekey1, Value: []byte("1")})
testMemcache.Set(context.Background(), &Item{Key: ekey2, Value: []byte("2")})
keys := []string{key, ekey1, ekey2}
rows, err := testMemcache.GetMulti(context.Background(), keys)
if err != nil {
t.Errorf("client.GetMulti() error = %v, wantErr %v", err, nil)
}
tests := []struct {
key string
wantNewValue string
wantErr bool
nilItem bool
}{
{key: ekey1, wantErr: false, wantNewValue: "1", nilItem: false},
{key: ekey2, wantErr: false, wantNewValue: "2", nilItem: false},
{key: key, wantErr: true, nilItem: true},
}
if reflect.DeepEqual(keys, rows.Keys()) {
t.Errorf("got %v, expect: %v", rows.Keys(), keys)
}
for _, tt := range tests {
t.Run(tt.key, func(t *testing.T) {
var gotNewValue string
err = rows.Scan(tt.key, &gotNewValue)
if (err != nil) != tt.wantErr {
t.Errorf("rows.Scan() error = %v, wantErr %v", err, tt.wantErr)
return
}
if gotNewValue != tt.wantNewValue {
t.Errorf("rows.Value() = %v, want %v", gotNewValue, tt.wantNewValue)
}
if (rows.Item(tt.key) == nil) != tt.nilItem {
t.Errorf("rows.Item() = %v, want %v", rows.Item(tt.key) == nil, tt.nilItem)
}
})
}
err = rows.Close()
if err != nil {
t.Errorf("client.Replies.Close() error = %v, wantErr %v", err, nil)
}
}
func Test_client_Conn(t *testing.T) {
conn := testMemcache.Conn(context.Background())
defer conn.Close()
if conn == nil {
t.Errorf("expect get conn, get nil")
}
}

View File

@ -1,59 +0,0 @@
package memcache
import (
"context"
)
// MockErr for unit test.
type MockErr struct {
Error error
}
var _ Conn = MockErr{}
// MockWith return a mock conn.
func MockWith(err error) MockErr {
return MockErr{Error: err}
}
// Err .
func (m MockErr) Err() error { return m.Error }
// Close .
func (m MockErr) Close() error { return m.Error }
// Add .
func (m MockErr) Add(item *Item) error { return m.Error }
// Set .
func (m MockErr) Set(item *Item) error { return m.Error }
// Replace .
func (m MockErr) Replace(item *Item) error { return m.Error }
// CompareAndSwap .
func (m MockErr) CompareAndSwap(item *Item) error { return m.Error }
// Get .
func (m MockErr) Get(key string) (*Item, error) { return nil, m.Error }
// GetMulti .
func (m MockErr) GetMulti(keys []string) (map[string]*Item, error) { return nil, m.Error }
// Touch .
func (m MockErr) Touch(key string, timeout int32) error { return m.Error }
// Delete .
func (m MockErr) Delete(key string) error { return m.Error }
// Increment .
func (m MockErr) Increment(key string, delta uint64) (uint64, error) { return 0, m.Error }
// Decrement .
func (m MockErr) Decrement(key string, delta uint64) (uint64, error) { return 0, m.Error }
// Scan .
func (m MockErr) Scan(item *Item, v interface{}) error { return m.Error }
// WithContext .
func (m MockErr) WithContext(ctx context.Context) Conn { return m }

View File

@ -1,197 +0,0 @@
package memcache
import (
"context"
"io"
"time"
"github.com/bilibili/kratos/pkg/container/pool"
"github.com/bilibili/kratos/pkg/stat"
xtime "github.com/bilibili/kratos/pkg/time"
)
var stats = stat.Cache
// Config memcache config.
type Config struct {
*pool.Config
Name string // memcache name, for trace
Proto string
Addr string
DialTimeout xtime.Duration
ReadTimeout xtime.Duration
WriteTimeout xtime.Duration
}
// Pool memcache connection pool struct.
type Pool struct {
p pool.Pool
c *Config
}
// NewPool new a memcache conn pool.
func NewPool(c *Config) (p *Pool) {
if c.DialTimeout <= 0 || c.ReadTimeout <= 0 || c.WriteTimeout <= 0 {
panic("must config memcache timeout")
}
p1 := pool.NewList(c.Config)
cnop := DialConnectTimeout(time.Duration(c.DialTimeout))
rdop := DialReadTimeout(time.Duration(c.ReadTimeout))
wrop := DialWriteTimeout(time.Duration(c.WriteTimeout))
p1.New = func(ctx context.Context) (io.Closer, error) {
conn, err := Dial(c.Proto, c.Addr, cnop, rdop, wrop)
return &traceConn{Conn: conn, address: c.Addr}, err
}
p = &Pool{p: p1, c: c}
return
}
// Get gets a connection. The application must close the returned connection.
// This method always returns a valid connection so that applications can defer
// error handling to the first use of the connection. If there is an error
// getting an underlying connection, then the connection Err, Do, Send, Flush
// and Receive methods return that error.
func (p *Pool) Get(ctx context.Context) Conn {
c, err := p.p.Get(ctx)
if err != nil {
return errorConnection{err}
}
c1, _ := c.(Conn)
return &pooledConnection{p: p, c: c1.WithContext(ctx), ctx: ctx}
}
// Close release the resources used by the pool.
func (p *Pool) Close() error {
return p.p.Close()
}
type pooledConnection struct {
p *Pool
c Conn
ctx context.Context
}
func pstat(key string, t time.Time, err error) {
stats.Timing(key, int64(time.Since(t)/time.Millisecond))
if err != nil {
if msg := formatErr(err); msg != "" {
stats.Incr("memcache", msg)
}
}
}
func (pc *pooledConnection) Close() error {
c := pc.c
if _, ok := c.(errorConnection); ok {
return nil
}
pc.c = errorConnection{ErrConnClosed}
pc.p.p.Put(context.Background(), c, c.Err() != nil)
return nil
}
func (pc *pooledConnection) Err() error {
return pc.c.Err()
}
func (pc *pooledConnection) Set(item *Item) (err error) {
now := time.Now()
err = pc.c.Set(item)
pstat("memcache:set", now, err)
return
}
func (pc *pooledConnection) Add(item *Item) (err error) {
now := time.Now()
err = pc.c.Add(item)
pstat("memcache:add", now, err)
return
}
func (pc *pooledConnection) Replace(item *Item) (err error) {
now := time.Now()
err = pc.c.Replace(item)
pstat("memcache:replace", now, err)
return
}
func (pc *pooledConnection) CompareAndSwap(item *Item) (err error) {
now := time.Now()
err = pc.c.CompareAndSwap(item)
pstat("memcache:cas", now, err)
return
}
func (pc *pooledConnection) Get(key string) (r *Item, err error) {
now := time.Now()
r, err = pc.c.Get(key)
pstat("memcache:get", now, err)
return
}
func (pc *pooledConnection) GetMulti(keys []string) (res map[string]*Item, err error) {
// if keys is empty slice returns empty map direct
if len(keys) == 0 {
return make(map[string]*Item), nil
}
now := time.Now()
res, err = pc.c.GetMulti(keys)
pstat("memcache:gets", now, err)
return
}
func (pc *pooledConnection) Touch(key string, timeout int32) (err error) {
now := time.Now()
err = pc.c.Touch(key, timeout)
pstat("memcache:touch", now, err)
return
}
func (pc *pooledConnection) Scan(item *Item, v interface{}) error {
return pc.c.Scan(item, v)
}
func (pc *pooledConnection) WithContext(ctx context.Context) Conn {
// TODO: set context
pc.ctx = ctx
return pc
}
func (pc *pooledConnection) Delete(key string) (err error) {
now := time.Now()
err = pc.c.Delete(key)
pstat("memcache:delete", now, err)
return
}
func (pc *pooledConnection) Increment(key string, delta uint64) (newValue uint64, err error) {
now := time.Now()
newValue, err = pc.c.Increment(key, delta)
pstat("memcache:increment", now, err)
return
}
func (pc *pooledConnection) Decrement(key string, delta uint64) (newValue uint64, err error) {
now := time.Now()
newValue, err = pc.c.Decrement(key, delta)
pstat("memcache:decrement", now, err)
return
}
type errorConnection struct{ err error }
func (ec errorConnection) Err() error { return ec.err }
func (ec errorConnection) Close() error { return ec.err }
func (ec errorConnection) Add(item *Item) error { return ec.err }
func (ec errorConnection) Set(item *Item) error { return ec.err }
func (ec errorConnection) Replace(item *Item) error { return ec.err }
func (ec errorConnection) CompareAndSwap(item *Item) error { return ec.err }
func (ec errorConnection) Get(key string) (*Item, error) { return nil, ec.err }
func (ec errorConnection) GetMulti(keys []string) (map[string]*Item, error) { return nil, ec.err }
func (ec errorConnection) Touch(key string, timeout int32) error { return ec.err }
func (ec errorConnection) Delete(key string) error { return ec.err }
func (ec errorConnection) Increment(key string, delta uint64) (uint64, error) { return 0, ec.err }
func (ec errorConnection) Decrement(key string, delta uint64) (uint64, error) { return 0, ec.err }
func (ec errorConnection) Scan(item *Item, v interface{}) error { return ec.err }
func (ec errorConnection) WithContext(ctx context.Context) Conn { return ec }

204
pkg/cache/memcache/pool_conn.go vendored Normal file
View File

@ -0,0 +1,204 @@
package memcache
import (
"context"
"fmt"
"io"
"time"
"github.com/bilibili/kratos/pkg/container/pool"
"github.com/bilibili/kratos/pkg/stat"
)
var stats = stat.Cache
// Pool memcache connection pool struct.
// Deprecated: Use Memcache instead
type Pool struct {
p pool.Pool
c *Config
}
// NewPool new a memcache conn pool.
// Deprecated: Use New instead
func NewPool(cfg *Config) (p *Pool) {
if cfg.DialTimeout <= 0 || cfg.ReadTimeout <= 0 || cfg.WriteTimeout <= 0 {
panic("must config memcache timeout")
}
p1 := pool.NewList(cfg.Config)
cnop := DialConnectTimeout(time.Duration(cfg.DialTimeout))
rdop := DialReadTimeout(time.Duration(cfg.ReadTimeout))
wrop := DialWriteTimeout(time.Duration(cfg.WriteTimeout))
p1.New = func(ctx context.Context) (io.Closer, error) {
conn, err := Dial(cfg.Proto, cfg.Addr, cnop, rdop, wrop)
return newTraceConn(conn, fmt.Sprintf("%s://%s", cfg.Proto, cfg.Addr)), err
}
p = &Pool{p: p1, c: cfg}
return
}
// Get gets a connection. The application must close the returned connection.
// This method always returns a valid connection so that applications can defer
// error handling to the first use of the connection. If there is an error
// getting an underlying connection, then the connection Err, Do, Send, Flush
// and Receive methods return that error.
func (p *Pool) Get(ctx context.Context) Conn {
c, err := p.p.Get(ctx)
if err != nil {
return errConn{err}
}
c1, _ := c.(Conn)
return &poolConn{p: p, c: c1, ctx: ctx}
}
// Close release the resources used by the pool.
func (p *Pool) Close() error {
return p.p.Close()
}
type poolConn struct {
c Conn
p *Pool
ctx context.Context
}
func pstat(key string, t time.Time, err error) {
stats.Timing(key, int64(time.Since(t)/time.Millisecond))
if err != nil {
if msg := formatErr(err); msg != "" {
stats.Incr("memcache", msg)
}
}
}
func (pc *poolConn) Close() error {
c := pc.c
if _, ok := c.(errConn); ok {
return nil
}
pc.c = errConn{ErrConnClosed}
pc.p.p.Put(context.Background(), c, c.Err() != nil)
return nil
}
func (pc *poolConn) Err() error {
return pc.c.Err()
}
func (pc *poolConn) Set(item *Item) (err error) {
return pc.c.SetContext(pc.ctx, item)
}
func (pc *poolConn) Add(item *Item) (err error) {
return pc.AddContext(pc.ctx, item)
}
func (pc *poolConn) Replace(item *Item) (err error) {
return pc.ReplaceContext(pc.ctx, item)
}
func (pc *poolConn) CompareAndSwap(item *Item) (err error) {
return pc.CompareAndSwapContext(pc.ctx, item)
}
func (pc *poolConn) Get(key string) (r *Item, err error) {
return pc.c.GetContext(pc.ctx, key)
}
func (pc *poolConn) GetMulti(keys []string) (res map[string]*Item, err error) {
return pc.c.GetMultiContext(pc.ctx, keys)
}
func (pc *poolConn) Touch(key string, timeout int32) (err error) {
return pc.c.TouchContext(pc.ctx, key, timeout)
}
func (pc *poolConn) Scan(item *Item, v interface{}) error {
return pc.c.Scan(item, v)
}
func (pc *poolConn) Delete(key string) (err error) {
return pc.c.DeleteContext(pc.ctx, key)
}
func (pc *poolConn) Increment(key string, delta uint64) (newValue uint64, err error) {
return pc.c.IncrementContext(pc.ctx, key, delta)
}
func (pc *poolConn) Decrement(key string, delta uint64) (newValue uint64, err error) {
return pc.c.DecrementContext(pc.ctx, key, delta)
}
func (pc *poolConn) AddContext(ctx context.Context, item *Item) error {
now := time.Now()
err := pc.c.AddContext(ctx, item)
pstat("memcache:add", now, err)
return err
}
func (pc *poolConn) SetContext(ctx context.Context, item *Item) error {
now := time.Now()
err := pc.c.SetContext(ctx, item)
pstat("memcache:set", now, err)
return err
}
func (pc *poolConn) ReplaceContext(ctx context.Context, item *Item) error {
now := time.Now()
err := pc.c.ReplaceContext(ctx, item)
pstat("memcache:replace", now, err)
return err
}
func (pc *poolConn) GetContext(ctx context.Context, key string) (*Item, error) {
now := time.Now()
item, err := pc.c.Get(key)
pstat("memcache:get", now, err)
return item, err
}
func (pc *poolConn) GetMultiContext(ctx context.Context, keys []string) (map[string]*Item, error) {
// if keys is empty slice returns empty map direct
if len(keys) == 0 {
return make(map[string]*Item), nil
}
now := time.Now()
items, err := pc.c.GetMulti(keys)
pstat("memcache:gets", now, err)
return items, err
}
func (pc *poolConn) DeleteContext(ctx context.Context, key string) error {
now := time.Now()
err := pc.c.Delete(key)
pstat("memcache:delete", now, err)
return err
}
func (pc *poolConn) IncrementContext(ctx context.Context, key string, delta uint64) (uint64, error) {
now := time.Now()
newValue, err := pc.c.IncrementContext(ctx, key, delta)
pstat("memcache:increment", now, err)
return newValue, err
}
func (pc *poolConn) DecrementContext(ctx context.Context, key string, delta uint64) (uint64, error) {
now := time.Now()
newValue, err := pc.c.DecrementContext(ctx, key, delta)
pstat("memcache:decrement", now, err)
return newValue, err
}
func (pc *poolConn) CompareAndSwapContext(ctx context.Context, item *Item) error {
now := time.Now()
err := pc.c.CompareAndSwap(item)
pstat("memcache:cas", now, err)
return err
}
func (pc *poolConn) TouchContext(ctx context.Context, key string, seconds int32) error {
now := time.Now()
err := pc.c.Touch(key, seconds)
pstat("memcache:touch", now, err)
return err
}

545
pkg/cache/memcache/pool_conn_test.go vendored Normal file
View File

@ -0,0 +1,545 @@
package memcache
import (
"bytes"
"context"
"reflect"
"testing"
"time"
"github.com/bilibili/kratos/pkg/container/pool"
xtime "github.com/bilibili/kratos/pkg/time"
)
var itempool = &Item{
Key: "testpool",
Value: []byte("testpool"),
Flags: 0,
Expiration: 60,
cas: 0,
}
var itempool2 = &Item{
Key: "test_count",
Value: []byte("0"),
Flags: 0,
Expiration: 1000,
cas: 0,
}
type testObject struct {
Mid int64
Value []byte
}
var largeValue = &Item{
Key: "large_value",
Flags: FlagGOB | FlagGzip,
Expiration: 1000,
cas: 0,
}
var largeValueBoundary = &Item{
Key: "large_value",
Flags: FlagGOB | FlagGzip,
Expiration: 1000,
cas: 0,
}
func TestPoolSet(t *testing.T) {
conn := testPool.Get(context.Background())
defer conn.Close()
// set
if err := conn.Set(itempool); err != nil {
t.Errorf("memcache: set error(%v)", err)
} else {
t.Logf("memcache: set value: %s", itempool.Value)
}
if err := conn.Close(); err != nil {
t.Errorf("memcache: close error(%v)", err)
}
}
func TestPoolGet(t *testing.T) {
key := "testpool"
conn := testPool.Get(context.Background())
defer conn.Close()
// get
if res, err := conn.Get(key); err != nil {
t.Errorf("memcache: get error(%v)", err)
} else {
t.Logf("memcache: get value: %s", res.Value)
}
if _, err := conn.Get("not_found"); err != ErrNotFound {
t.Errorf("memcache: expceted err is not found but got: %v", err)
}
if err := conn.Close(); err != nil {
t.Errorf("memcache: close error(%v)", err)
}
}
func TestPoolGetMulti(t *testing.T) {
conn := testPool.Get(context.Background())
defer conn.Close()
s := []string{"testpool", "test1"}
// get
if res, err := conn.GetMulti(s); err != nil {
t.Errorf("memcache: gets error(%v)", err)
} else {
t.Logf("memcache: gets value: %d", len(res))
}
if err := conn.Close(); err != nil {
t.Errorf("memcache: close error(%v)", err)
}
}
func TestPoolTouch(t *testing.T) {
key := "testpool"
conn := testPool.Get(context.Background())
defer conn.Close()
// touch
if err := conn.Touch(key, 10); err != nil {
t.Errorf("memcache: touch error(%v)", err)
}
if err := conn.Close(); err != nil {
t.Errorf("memcache: close error(%v)", err)
}
}
func TestPoolIncrement(t *testing.T) {
key := "test_count"
conn := testPool.Get(context.Background())
defer conn.Close()
// set
if err := conn.Set(itempool2); err != nil {
t.Errorf("memcache: set error(%v)", err)
} else {
t.Logf("memcache: set value: 0")
}
// incr
if res, err := conn.Increment(key, 1); err != nil {
t.Errorf("memcache: incr error(%v)", err)
} else {
t.Logf("memcache: incr n: %d", res)
if res != 1 {
t.Errorf("memcache: expected res=1 but got %d", res)
}
}
// decr
if res, err := conn.Decrement(key, 1); err != nil {
t.Errorf("memcache: decr error(%v)", err)
} else {
t.Logf("memcache: decr n: %d", res)
if res != 0 {
t.Errorf("memcache: expected res=0 but got %d", res)
}
}
if err := conn.Close(); err != nil {
t.Errorf("memcache: close error(%v)", err)
}
}
func TestPoolErr(t *testing.T) {
conn := testPool.Get(context.Background())
defer conn.Close()
if err := conn.Close(); err != nil {
t.Errorf("memcache: close error(%v)", err)
}
if err := conn.Err(); err == nil {
t.Errorf("memcache: err not nil")
} else {
t.Logf("memcache: err: %v", err)
}
}
func TestPoolCompareAndSwap(t *testing.T) {
conn := testPool.Get(context.Background())
defer conn.Close()
key := "testpool"
//cas
if r, err := conn.Get(key); err != nil {
t.Errorf("conn.Get() error(%v)", err)
} else {
r.Value = []byte("shit")
if err := conn.CompareAndSwap(r); err != nil {
t.Errorf("conn.Get() error(%v)", err)
}
r, _ := conn.Get("testpool")
if r.Key != "testpool" || !bytes.Equal(r.Value, []byte("shit")) || r.Flags != 0 {
t.Error("conn.Get() error, value")
}
if err := conn.Close(); err != nil {
t.Errorf("memcache: close error(%v)", err)
}
}
}
func TestPoolDel(t *testing.T) {
key := "testpool"
conn := testPool.Get(context.Background())
defer conn.Close()
// delete
if err := conn.Delete(key); err != nil {
t.Errorf("memcache: delete error(%v)", err)
} else {
t.Logf("memcache: delete key: %s", key)
}
if err := conn.Close(); err != nil {
t.Errorf("memcache: close error(%v)", err)
}
}
func BenchmarkMemcache(b *testing.B) {
c := &Config{
Name: "test",
Proto: "tcp",
Addr: testMemcacheAddr,
DialTimeout: xtime.Duration(time.Second),
ReadTimeout: xtime.Duration(time.Second),
WriteTimeout: xtime.Duration(time.Second),
}
c.Config = &pool.Config{
Active: 10,
Idle: 5,
IdleTimeout: xtime.Duration(90 * time.Second),
}
testPool = NewPool(c)
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
conn := testPool.Get(context.Background())
if err := conn.Close(); err != nil {
b.Errorf("memcache: close error(%v)", err)
}
}
})
if err := testPool.Close(); err != nil {
b.Errorf("memcache: close error(%v)", err)
}
}
func TestPoolSetLargeValue(t *testing.T) {
var b bytes.Buffer
for i := 0; i < 4000000; i++ {
b.WriteByte(1)
}
obj := &testObject{}
obj.Mid = 1000
obj.Value = b.Bytes()
largeValue.Object = obj
conn := testPool.Get(context.Background())
defer conn.Close()
// set
if err := conn.Set(largeValue); err != nil {
t.Errorf("memcache: set error(%v)", err)
}
if err := conn.Close(); err != nil {
t.Errorf("memcache: close error(%v)", err)
}
}
func TestPoolGetLargeValue(t *testing.T) {
key := largeValue.Key
conn := testPool.Get(context.Background())
defer conn.Close()
// get
var err error
if _, err = conn.Get(key); err != nil {
t.Errorf("memcache: large get error(%+v)", err)
}
}
func TestPoolGetMultiLargeValue(t *testing.T) {
conn := testPool.Get(context.Background())
defer conn.Close()
s := []string{largeValue.Key, largeValue.Key}
// get
if res, err := conn.GetMulti(s); err != nil {
t.Errorf("memcache: gets error(%v)", err)
} else {
t.Logf("memcache: gets value: %d", len(res))
}
if err := conn.Close(); err != nil {
t.Errorf("memcache: close error(%v)", err)
}
}
func TestPoolSetLargeValueBoundary(t *testing.T) {
var b bytes.Buffer
for i := 0; i < _largeValue; i++ {
b.WriteByte(1)
}
obj := &testObject{}
obj.Mid = 1000
obj.Value = b.Bytes()
largeValueBoundary.Object = obj
conn := testPool.Get(context.Background())
defer conn.Close()
// set
if err := conn.Set(largeValueBoundary); err != nil {
t.Errorf("memcache: set error(%v)", err)
}
if err := conn.Close(); err != nil {
t.Errorf("memcache: close error(%v)", err)
}
}
func TestPoolGetLargeValueBoundary(t *testing.T) {
key := largeValueBoundary.Key
conn := testPool.Get(context.Background())
defer conn.Close()
// get
var err error
if _, err = conn.Get(key); err != nil {
t.Errorf("memcache: large get error(%v)", err)
}
}
func TestPoolAdd(t *testing.T) {
var (
key = "test_add"
item = &Item{
Key: key,
Value: []byte("0"),
Flags: 0,
Expiration: 60,
cas: 0,
}
conn = testPool.Get(context.Background())
)
defer conn.Close()
conn.Delete(key)
if err := conn.Add(item); err != nil {
t.Errorf("memcache: add error(%v)", err)
}
if err := conn.Add(item); err != ErrNotStored {
t.Errorf("memcache: add error(%v)", err)
}
}
func TestNewPool(t *testing.T) {
type args struct {
cfg *Config
}
tests := []struct {
name string
args args
wantErr error
wantPanic bool
}{
{
"NewPoolIllegalDialTimeout",
args{
&Config{
Name: "test_illegal_dial_timeout",
Proto: "tcp",
Addr: testMemcacheAddr,
DialTimeout: xtime.Duration(-time.Second),
ReadTimeout: xtime.Duration(time.Second),
WriteTimeout: xtime.Duration(time.Second),
},
},
nil,
true,
},
{
"NewPoolIllegalReadTimeout",
args{
&Config{
Name: "test_illegal_read_timeout",
Proto: "tcp",
Addr: testMemcacheAddr,
DialTimeout: xtime.Duration(time.Second),
ReadTimeout: xtime.Duration(-time.Second),
WriteTimeout: xtime.Duration(time.Second),
},
},
nil,
true,
},
{
"NewPoolIllegalWriteTimeout",
args{
&Config{
Name: "test_illegal_write_timeout",
Proto: "tcp",
Addr: testMemcacheAddr,
DialTimeout: xtime.Duration(time.Second),
ReadTimeout: xtime.Duration(time.Second),
WriteTimeout: xtime.Duration(-time.Second),
},
},
nil,
true,
},
{
"NewPool",
args{
&Config{
Name: "test_new",
Proto: "tcp",
Addr: testMemcacheAddr,
DialTimeout: xtime.Duration(time.Second),
ReadTimeout: xtime.Duration(time.Second),
WriteTimeout: xtime.Duration(time.Second),
},
},
nil,
true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
defer func() {
r := recover()
if (r != nil) != tt.wantPanic {
t.Errorf("wantPanic recover = %v, wantPanic = %v", r, tt.wantPanic)
}
}()
if gotP := NewPool(tt.args.cfg); gotP == nil {
t.Error("NewPool() failed, got nil")
}
})
}
}
func TestPool_Get(t *testing.T) {
type args struct {
ctx context.Context
}
tests := []struct {
name string
p *Pool
args args
wantErr bool
n int
}{
{
"Get",
NewPool(&Config{
Config: &pool.Config{
Active: 3,
Idle: 2,
},
Name: "test_get",
Proto: "tcp",
Addr: testMemcacheAddr,
DialTimeout: xtime.Duration(time.Second),
ReadTimeout: xtime.Duration(time.Second),
WriteTimeout: xtime.Duration(time.Second),
}),
args{context.TODO()},
false,
3,
},
{
"GetExceededPoolSize",
NewPool(&Config{
Config: &pool.Config{
Active: 3,
Idle: 2,
},
Name: "test_get_out",
Proto: "tcp",
Addr: testMemcacheAddr,
DialTimeout: xtime.Duration(time.Second),
ReadTimeout: xtime.Duration(time.Second),
WriteTimeout: xtime.Duration(time.Second),
}),
args{context.TODO()},
true,
6,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
for i := 1; i <= tt.n; i++ {
got := tt.p.Get(tt.args.ctx)
if reflect.TypeOf(got) == reflect.TypeOf(errConn{}) {
if !tt.wantErr {
t.Errorf("got errConn, export Conn")
}
return
} else {
if tt.wantErr {
if i > tt.p.c.Active {
t.Errorf("got Conn, export errConn")
}
}
}
}
})
}
}
func TestPool_Close(t *testing.T) {
type args struct {
ctx context.Context
}
tests := []struct {
name string
p *Pool
args args
wantErr bool
g int
c int
}{
{
"Close",
NewPool(&Config{
Config: &pool.Config{
Active: 1,
Idle: 1,
},
Name: "test_get",
Proto: "tcp",
Addr: testMemcacheAddr,
DialTimeout: xtime.Duration(time.Second),
ReadTimeout: xtime.Duration(time.Second),
WriteTimeout: xtime.Duration(time.Second),
}),
args{context.TODO()},
false,
3,
3,
},
{
"CloseExceededPoolSize",
NewPool(&Config{
Config: &pool.Config{
Active: 1,
Idle: 1,
},
Name: "test_get_out",
Proto: "tcp",
Addr: testMemcacheAddr,
DialTimeout: xtime.Duration(time.Second),
ReadTimeout: xtime.Duration(time.Second),
WriteTimeout: xtime.Duration(time.Second),
}),
args{context.TODO()},
true,
5,
3,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
for i := 1; i <= tt.g; i++ {
got := tt.p.Get(tt.args.ctx)
if err := got.Close(); err != nil {
if !tt.wantErr {
t.Error(err)
}
}
if i <= tt.c {
if err := got.Close(); err != nil {
t.Error(err)
}
}
}
})
}
}

48
pkg/cache/memcache/test/BUILD.bazel vendored Normal file
View File

@ -0,0 +1,48 @@
load(
"@io_bazel_rules_go//go:def.bzl",
"go_library",
)
load(
"@io_bazel_rules_go//proto:def.bzl",
"go_proto_library",
)
go_library(
name = "go_default_library",
srcs = [],
embed = [":proto_go_proto"],
importpath = "go-common/library/cache/memcache/test",
tags = ["automanaged"],
visibility = ["//visibility:public"],
deps = ["@com_github_golang_protobuf//proto:go_default_library"],
)
filegroup(
name = "package-srcs",
srcs = glob(["**"]),
tags = ["automanaged"],
visibility = ["//visibility:private"],
)
filegroup(
name = "all-srcs",
srcs = [":package-srcs"],
tags = ["automanaged"],
visibility = ["//visibility:public"],
)
proto_library(
name = "test_proto",
srcs = ["test.proto"],
import_prefix = "go-common/library/cache/memcache/test",
strip_import_prefix = "",
tags = ["automanaged"],
)
go_proto_library(
name = "proto_go_proto",
compilers = ["@io_bazel_rules_go//proto:go_proto"],
importpath = "go-common/library/cache/memcache/test",
proto = ":test_proto",
tags = ["automanaged"],
)

375
pkg/cache/memcache/test/test.pb.go vendored Normal file
View File

@ -0,0 +1,375 @@
// Code generated by protoc-gen-gogo. DO NOT EDIT.
// source: test.proto
/*
Package proto is a generated protocol buffer package.
It is generated from these files:
test.proto
It has these top-level messages:
TestItem
*/
package proto
import proto1 "github.com/golang/protobuf/proto"
import fmt "fmt"
import math "math"
import io "io"
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto1.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto1.ProtoPackageIsVersion2 // please upgrade the proto package
type FOO int32
const (
FOO_X FOO = 0
)
var FOO_name = map[int32]string{
0: "X",
}
var FOO_value = map[string]int32{
"X": 0,
}
func (x FOO) String() string {
return proto1.EnumName(FOO_name, int32(x))
}
func (FOO) EnumDescriptor() ([]byte, []int) { return fileDescriptorTest, []int{0} }
type TestItem struct {
Name string `protobuf:"bytes,1,opt,name=Name,proto3" json:"Name,omitempty"`
Age int32 `protobuf:"varint,2,opt,name=Age,proto3" json:"Age,omitempty"`
}
func (m *TestItem) Reset() { *m = TestItem{} }
func (m *TestItem) String() string { return proto1.CompactTextString(m) }
func (*TestItem) ProtoMessage() {}
func (*TestItem) Descriptor() ([]byte, []int) { return fileDescriptorTest, []int{0} }
func (m *TestItem) GetName() string {
if m != nil {
return m.Name
}
return ""
}
func (m *TestItem) GetAge() int32 {
if m != nil {
return m.Age
}
return 0
}
func init() {
proto1.RegisterType((*TestItem)(nil), "proto.TestItem")
proto1.RegisterEnum("proto.FOO", FOO_name, FOO_value)
}
func (m *TestItem) Marshal() (dAtA []byte, err error) {
size := m.Size()
dAtA = make([]byte, size)
n, err := m.MarshalTo(dAtA)
if err != nil {
return nil, err
}
return dAtA[:n], nil
}
func (m *TestItem) MarshalTo(dAtA []byte) (int, error) {
var i int
_ = i
var l int
_ = l
if len(m.Name) > 0 {
dAtA[i] = 0xa
i++
i = encodeVarintTest(dAtA, i, uint64(len(m.Name)))
i += copy(dAtA[i:], m.Name)
}
if m.Age != 0 {
dAtA[i] = 0x10
i++
i = encodeVarintTest(dAtA, i, uint64(m.Age))
}
return i, nil
}
func encodeFixed64Test(dAtA []byte, offset int, v uint64) int {
dAtA[offset] = uint8(v)
dAtA[offset+1] = uint8(v >> 8)
dAtA[offset+2] = uint8(v >> 16)
dAtA[offset+3] = uint8(v >> 24)
dAtA[offset+4] = uint8(v >> 32)
dAtA[offset+5] = uint8(v >> 40)
dAtA[offset+6] = uint8(v >> 48)
dAtA[offset+7] = uint8(v >> 56)
return offset + 8
}
func encodeFixed32Test(dAtA []byte, offset int, v uint32) int {
dAtA[offset] = uint8(v)
dAtA[offset+1] = uint8(v >> 8)
dAtA[offset+2] = uint8(v >> 16)
dAtA[offset+3] = uint8(v >> 24)
return offset + 4
}
func encodeVarintTest(dAtA []byte, offset int, v uint64) int {
for v >= 1<<7 {
dAtA[offset] = uint8(v&0x7f | 0x80)
v >>= 7
offset++
}
dAtA[offset] = uint8(v)
return offset + 1
}
func (m *TestItem) Size() (n int) {
var l int
_ = l
l = len(m.Name)
if l > 0 {
n += 1 + l + sovTest(uint64(l))
}
if m.Age != 0 {
n += 1 + sovTest(uint64(m.Age))
}
return n
}
func sovTest(x uint64) (n int) {
for {
n++
x >>= 7
if x == 0 {
break
}
}
return n
}
func sozTest(x uint64) (n int) {
return sovTest(uint64((x << 1) ^ uint64((int64(x) >> 63))))
}
func (m *TestItem) Unmarshal(dAtA []byte) error {
l := len(dAtA)
iNdEx := 0
for iNdEx < l {
preIndex := iNdEx
var wire uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowTest
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
wire |= (uint64(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
fieldNum := int32(wire >> 3)
wireType := int(wire & 0x7)
if wireType == 4 {
return fmt.Errorf("proto: TestItem: wiretype end group for non-group")
}
if fieldNum <= 0 {
return fmt.Errorf("proto: TestItem: illegal tag %d (wire type %d)", fieldNum, wire)
}
switch fieldNum {
case 1:
if wireType != 2 {
return fmt.Errorf("proto: wrong wireType = %d for field Name", wireType)
}
var stringLen uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowTest
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
stringLen |= (uint64(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
intStringLen := int(stringLen)
if intStringLen < 0 {
return ErrInvalidLengthTest
}
postIndex := iNdEx + intStringLen
if postIndex > l {
return io.ErrUnexpectedEOF
}
m.Name = string(dAtA[iNdEx:postIndex])
iNdEx = postIndex
case 2:
if wireType != 0 {
return fmt.Errorf("proto: wrong wireType = %d for field Age", wireType)
}
m.Age = 0
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowTest
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
m.Age |= (int32(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
default:
iNdEx = preIndex
skippy, err := skipTest(dAtA[iNdEx:])
if err != nil {
return err
}
if skippy < 0 {
return ErrInvalidLengthTest
}
if (iNdEx + skippy) > l {
return io.ErrUnexpectedEOF
}
iNdEx += skippy
}
}
if iNdEx > l {
return io.ErrUnexpectedEOF
}
return nil
}
func skipTest(dAtA []byte) (n int, err error) {
l := len(dAtA)
iNdEx := 0
for iNdEx < l {
var wire uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return 0, ErrIntOverflowTest
}
if iNdEx >= l {
return 0, io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
wire |= (uint64(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
wireType := int(wire & 0x7)
switch wireType {
case 0:
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return 0, ErrIntOverflowTest
}
if iNdEx >= l {
return 0, io.ErrUnexpectedEOF
}
iNdEx++
if dAtA[iNdEx-1] < 0x80 {
break
}
}
return iNdEx, nil
case 1:
iNdEx += 8
return iNdEx, nil
case 2:
var length int
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return 0, ErrIntOverflowTest
}
if iNdEx >= l {
return 0, io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
length |= (int(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
iNdEx += length
if length < 0 {
return 0, ErrInvalidLengthTest
}
return iNdEx, nil
case 3:
for {
var innerWire uint64
var start int = iNdEx
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return 0, ErrIntOverflowTest
}
if iNdEx >= l {
return 0, io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
innerWire |= (uint64(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
innerWireType := int(innerWire & 0x7)
if innerWireType == 4 {
break
}
next, err := skipTest(dAtA[start:])
if err != nil {
return 0, err
}
iNdEx = start + next
}
return iNdEx, nil
case 4:
return iNdEx, nil
case 5:
iNdEx += 4
return iNdEx, nil
default:
return 0, fmt.Errorf("proto: illegal wireType %d", wireType)
}
}
panic("unreachable")
}
var (
ErrInvalidLengthTest = fmt.Errorf("proto: negative length found during unmarshaling")
ErrIntOverflowTest = fmt.Errorf("proto: integer overflow")
)
func init() { proto1.RegisterFile("test.proto", fileDescriptorTest) }
var fileDescriptorTest = []byte{
// 122 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x2a, 0x49, 0x2d, 0x2e,
0xd1, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x05, 0x53, 0x4a, 0x06, 0x5c, 0x1c, 0x21, 0xa9,
0xc5, 0x25, 0x9e, 0x25, 0xa9, 0xb9, 0x42, 0x42, 0x5c, 0x2c, 0x7e, 0x89, 0xb9, 0xa9, 0x12, 0x8c,
0x0a, 0x8c, 0x1a, 0x9c, 0x41, 0x60, 0xb6, 0x90, 0x00, 0x17, 0xb3, 0x63, 0x7a, 0xaa, 0x04, 0x93,
0x02, 0xa3, 0x06, 0x6b, 0x10, 0x88, 0xa9, 0xc5, 0xc3, 0xc5, 0xec, 0xe6, 0xef, 0x2f, 0xc4, 0xca,
0xc5, 0x18, 0x21, 0xc0, 0xe0, 0x24, 0x70, 0xe2, 0x91, 0x1c, 0xe3, 0x85, 0x47, 0x72, 0x8c, 0x0f,
0x1e, 0xc9, 0x31, 0xce, 0x78, 0x2c, 0xc7, 0x90, 0xc4, 0x06, 0x36, 0xd8, 0x18, 0x10, 0x00, 0x00,
0xff, 0xff, 0x16, 0x80, 0x60, 0x15, 0x6d, 0x00, 0x00, 0x00,
}

12
pkg/cache/memcache/test/test.proto vendored Normal file
View File

@ -0,0 +1,12 @@
syntax = "proto3";
package proto;
enum FOO
{
X = 0;
};
message TestItem{
string Name = 1;
int32 Age = 2;
}

View File

@ -1,109 +0,0 @@
package memcache
import (
"context"
"strconv"
"strings"
"time"
"github.com/bilibili/kratos/pkg/log"
"github.com/bilibili/kratos/pkg/net/trace"
)
const (
_traceFamily = "memcache"
_traceSpanKind = "client"
_traceComponentName = "library/cache/memcache"
_tracePeerService = "memcache"
_slowLogDuration = time.Millisecond * 250
)
type traceConn struct {
Conn
ctx context.Context
address string
}
func (t *traceConn) setTrace(action, statement string) func(error) error {
now := time.Now()
parent, ok := trace.FromContext(t.ctx)
if !ok {
return func(err error) error { return err }
}
span := parent.Fork(_traceFamily, "Memcache:"+action)
span.SetTag(
trace.String(trace.TagSpanKind, _traceSpanKind),
trace.String(trace.TagComponent, _traceComponentName),
trace.String(trace.TagPeerService, _tracePeerService),
trace.String(trace.TagPeerAddress, t.address),
trace.String(trace.TagDBStatement, action+" "+statement),
)
return func(err error) error {
span.Finish(&err)
t := time.Since(now)
if t > _slowLogDuration {
log.Warn("%s slow log action: %s key: %s time: %v", _traceFamily, action, statement, t)
}
return err
}
}
func (t *traceConn) WithContext(ctx context.Context) Conn {
t.ctx = ctx
t.Conn = t.Conn.WithContext(ctx)
return t
}
func (t *traceConn) Add(item *Item) error {
finishFn := t.setTrace("Add", item.Key)
return finishFn(t.Conn.Add(item))
}
func (t *traceConn) Set(item *Item) error {
finishFn := t.setTrace("Set", item.Key)
return finishFn(t.Conn.Set(item))
}
func (t *traceConn) Replace(item *Item) error {
finishFn := t.setTrace("Replace", item.Key)
return finishFn(t.Conn.Replace(item))
}
func (t *traceConn) Get(key string) (*Item, error) {
finishFn := t.setTrace("Get", key)
item, err := t.Conn.Get(key)
return item, finishFn(err)
}
func (t *traceConn) GetMulti(keys []string) (map[string]*Item, error) {
finishFn := t.setTrace("GetMulti", strings.Join(keys, " "))
items, err := t.Conn.GetMulti(keys)
return items, finishFn(err)
}
func (t *traceConn) Delete(key string) error {
finishFn := t.setTrace("Delete", key)
return finishFn(t.Conn.Delete(key))
}
func (t *traceConn) Increment(key string, delta uint64) (newValue uint64, err error) {
finishFn := t.setTrace("Increment", key+" "+strconv.FormatUint(delta, 10))
newValue, err = t.Conn.Increment(key, delta)
return newValue, finishFn(err)
}
func (t *traceConn) Decrement(key string, delta uint64) (newValue uint64, err error) {
finishFn := t.setTrace("Decrement", key+" "+strconv.FormatUint(delta, 10))
newValue, err = t.Conn.Decrement(key, delta)
return newValue, finishFn(err)
}
func (t *traceConn) CompareAndSwap(item *Item) error {
finishFn := t.setTrace("CompareAndSwap", item.Key)
return finishFn(t.Conn.CompareAndSwap(item))
}
func (t *traceConn) Touch(key string, seconds int32) (err error) {
finishFn := t.setTrace("Touch", key+" "+strconv.Itoa(int(seconds)))
return finishFn(t.Conn.Touch(key, seconds))
}

103
pkg/cache/memcache/trace_conn.go vendored Normal file
View File

@ -0,0 +1,103 @@
package memcache
import (
"context"
"strconv"
"strings"
"time"
"github.com/bilibili/kratos/pkg/log"
"github.com/bilibili/kratos/pkg/net/trace"
)
const (
_slowLogDuration = time.Millisecond * 250
)
func newTraceConn(conn Conn, address string) Conn {
tags := []trace.Tag{
trace.String(trace.TagSpanKind, "client"),
trace.String(trace.TagComponent, "cache/memcache"),
trace.String(trace.TagPeerService, "memcache"),
trace.String(trace.TagPeerAddress, address),
}
return &traceConn{Conn: conn, tags: tags}
}
type traceConn struct {
Conn
tags []trace.Tag
}
func (t *traceConn) setTrace(ctx context.Context, action, statement string) func(error) error {
now := time.Now()
parent, ok := trace.FromContext(ctx)
if !ok {
return func(err error) error { return err }
}
span := parent.Fork("", "Memcache:"+action)
span.SetTag(t.tags...)
span.SetTag(trace.String(trace.TagDBStatement, action+" "+statement))
return func(err error) error {
span.Finish(&err)
t := time.Since(now)
if t > _slowLogDuration {
log.Warn("memcache slow log action: %s key: %s time: %v", action, statement, t)
}
return err
}
}
func (t *traceConn) AddContext(ctx context.Context, item *Item) error {
finishFn := t.setTrace(ctx, "Add", item.Key)
return finishFn(t.Conn.Add(item))
}
func (t *traceConn) SetContext(ctx context.Context, item *Item) error {
finishFn := t.setTrace(ctx, "Set", item.Key)
return finishFn(t.Conn.Set(item))
}
func (t *traceConn) ReplaceContext(ctx context.Context, item *Item) error {
finishFn := t.setTrace(ctx, "Replace", item.Key)
return finishFn(t.Conn.Replace(item))
}
func (t *traceConn) GetContext(ctx context.Context, key string) (*Item, error) {
finishFn := t.setTrace(ctx, "Get", key)
item, err := t.Conn.Get(key)
return item, finishFn(err)
}
func (t *traceConn) GetMultiContext(ctx context.Context, keys []string) (map[string]*Item, error) {
finishFn := t.setTrace(ctx, "GetMulti", strings.Join(keys, " "))
items, err := t.Conn.GetMulti(keys)
return items, finishFn(err)
}
func (t *traceConn) DeleteContext(ctx context.Context, key string) error {
finishFn := t.setTrace(ctx, "Delete", key)
return finishFn(t.Conn.Delete(key))
}
func (t *traceConn) IncrementContext(ctx context.Context, key string, delta uint64) (newValue uint64, err error) {
finishFn := t.setTrace(ctx, "Increment", key+" "+strconv.FormatUint(delta, 10))
newValue, err = t.Conn.Increment(key, delta)
return newValue, finishFn(err)
}
func (t *traceConn) DecrementContext(ctx context.Context, key string, delta uint64) (newValue uint64, err error) {
finishFn := t.setTrace(ctx, "Decrement", key+" "+strconv.FormatUint(delta, 10))
newValue, err = t.Conn.Decrement(key, delta)
return newValue, finishFn(err)
}
func (t *traceConn) CompareAndSwapContext(ctx context.Context, item *Item) error {
finishFn := t.setTrace(ctx, "CompareAndSwap", item.Key)
return finishFn(t.Conn.CompareAndSwap(item))
}
func (t *traceConn) TouchContext(ctx context.Context, key string, seconds int32) (err error) {
finishFn := t.setTrace(ctx, "Touch", key+" "+strconv.Itoa(int(seconds)))
return finishFn(t.Conn.Touch(key, seconds))
}

View File

@ -1,9 +1,57 @@
package memcache
import (
"context"
"time"
"github.com/gogo/protobuf/proto"
)
func legalKey(key string) bool {
if len(key) > 250 || len(key) == 0 {
return false
}
for i := 0; i < len(key); i++ {
if key[i] <= ' ' || key[i] == 0x7f {
return false
}
}
return true
}
// MockWith error
func MockWith(err error) Conn {
return errConn{err}
}
type errConn struct{ err error }
func (c errConn) Err() error { return c.err }
func (c errConn) Close() error { return c.err }
func (c errConn) Add(*Item) error { return c.err }
func (c errConn) Set(*Item) error { return c.err }
func (c errConn) Replace(*Item) error { return c.err }
func (c errConn) CompareAndSwap(*Item) error { return c.err }
func (c errConn) Get(string) (*Item, error) { return nil, c.err }
func (c errConn) GetMulti([]string) (map[string]*Item, error) { return nil, c.err }
func (c errConn) Touch(string, int32) error { return c.err }
func (c errConn) Delete(string) error { return c.err }
func (c errConn) Increment(string, uint64) (uint64, error) { return 0, c.err }
func (c errConn) Decrement(string, uint64) (uint64, error) { return 0, c.err }
func (c errConn) Scan(*Item, interface{}) error { return c.err }
func (c errConn) AddContext(context.Context, *Item) error { return c.err }
func (c errConn) SetContext(context.Context, *Item) error { return c.err }
func (c errConn) ReplaceContext(context.Context, *Item) error { return c.err }
func (c errConn) GetContext(context.Context, string) (*Item, error) { return nil, c.err }
func (c errConn) DecrementContext(context.Context, string, uint64) (uint64, error) { return 0, c.err }
func (c errConn) CompareAndSwapContext(context.Context, *Item) error { return c.err }
func (c errConn) TouchContext(context.Context, string, int32) error { return c.err }
func (c errConn) DeleteContext(context.Context, string) error { return c.err }
func (c errConn) IncrementContext(context.Context, string, uint64) (uint64, error) { return 0, c.err }
func (c errConn) GetMultiContext(context.Context, []string) (map[string]*Item, error) {
return nil, c.err
}
// RawItem item with FlagRAW flag.
//
// Expiration is the cache expiration time, in seconds: either a relative
@ -30,3 +78,12 @@ func JSONItem(key string, v interface{}, flags uint32, expiration int32) *Item {
func ProtobufItem(key string, message proto.Message, flags uint32, expiration int32) *Item {
return &Item{Key: key, Flags: flags | FlagProtobuf, Object: message, Expiration: expiration}
}
func shrinkDeadline(ctx context.Context, timeout time.Duration) time.Time {
// TODO: ignored context deadline to compatible old behaviour.
//deadline, ok := ctx.Deadline()
//if ok {
// return deadline
//}
return time.Now().Add(timeout)
}

75
pkg/cache/memcache/util_test.go vendored Normal file
View File

@ -0,0 +1,75 @@
package memcache
import (
"testing"
pb "github.com/bilibili/kratos/pkg/cache/memcache/test"
"github.com/stretchr/testify/assert"
)
func TestItemUtil(t *testing.T) {
item1 := RawItem("test", []byte("hh"), 0, 0)
assert.Equal(t, "test", item1.Key)
assert.Equal(t, []byte("hh"), item1.Value)
assert.Equal(t, FlagRAW, FlagRAW&item1.Flags)
item1 = JSONItem("test", &Item{}, 0, 0)
assert.Equal(t, "test", item1.Key)
assert.NotNil(t, item1.Object)
assert.Equal(t, FlagJSON, FlagJSON&item1.Flags)
item1 = ProtobufItem("test", &pb.TestItem{}, 0, 0)
assert.Equal(t, "test", item1.Key)
assert.NotNil(t, item1.Object)
assert.Equal(t, FlagProtobuf, FlagProtobuf&item1.Flags)
}
func TestLegalKey(t *testing.T) {
type args struct {
key string
}
tests := []struct {
name string
args args
want bool
}{
{
name: "test empty key",
want: false,
},
{
name: "test too large key",
args: args{func() string {
var data []byte
for i := 0; i < 255; i++ {
data = append(data, 'k')
}
return string(data)
}()},
want: false,
},
{
name: "test invalid char",
args: args{"hello world"},
want: false,
},
{
name: "test invalid char",
args: args{string([]byte{0x7f})},
want: false,
},
{
name: "test normal key",
args: args{"hello"},
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := legalKey(tt.args.key); got != tt.want {
t.Errorf("legalKey() = %v, want %v", got, tt.want)
}
})
}
}