mirror of
https://github.com/go-kratos/kratos.git
synced 2025-02-09 13:36:57 +02:00
parent
50d0129461
commit
637a6a3628
2
.gitignore
vendored
2
.gitignore
vendored
@ -1,2 +1,4 @@
|
||||
go.sum
|
||||
BUILD
|
||||
.DS_Store
|
||||
tool/kratos/kratos
|
||||
|
18
pkg/conf/flagvar/flagvar.go
Normal file
18
pkg/conf/flagvar/flagvar.go
Normal file
@ -0,0 +1,18 @@
|
||||
package flagvar
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// StringVars []string implement flag.Value
|
||||
type StringVars []string
|
||||
|
||||
func (s StringVars) String() string {
|
||||
return strings.Join(s, ",")
|
||||
}
|
||||
|
||||
// Set implement flag.Value
|
||||
func (s *StringVars) Set(val string) error {
|
||||
*s = append(*s, val)
|
||||
return nil
|
||||
}
|
@ -12,6 +12,7 @@ var (
|
||||
NothingFound = add(-404) // 啥都木有
|
||||
MethodNotAllowed = add(-405) // 不支持该方法
|
||||
Conflict = add(-409) // 冲突
|
||||
Canceled = add(-498) // 客户端取消请求
|
||||
ServerErr = add(-500) // 服务器错误
|
||||
ServiceUnavailable = add(-503) // 过载保护,服务暂不可用
|
||||
Deadline = add(-504) // 服务调用超时
|
||||
|
48
pkg/ecode/pb/ecode.go
Normal file
48
pkg/ecode/pb/ecode.go
Normal file
@ -0,0 +1,48 @@
|
||||
package pb
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/bilibili/kratos/pkg/ecode"
|
||||
|
||||
any "github.com/golang/protobuf/ptypes/any"
|
||||
)
|
||||
|
||||
func (e *Error) Error() string {
|
||||
return strconv.FormatInt(int64(e.GetErrCode()), 10)
|
||||
}
|
||||
|
||||
// Code is the code of error.
|
||||
func (e *Error) Code() int {
|
||||
return int(e.GetErrCode())
|
||||
}
|
||||
|
||||
// Message is error message.
|
||||
func (e *Error) Message() string {
|
||||
return e.GetErrMessage()
|
||||
}
|
||||
|
||||
// Equal compare whether two errors are equal.
|
||||
func (e *Error) Equal(ec error) bool {
|
||||
return ecode.Cause(ec).Code() == e.Code()
|
||||
}
|
||||
|
||||
// Details return error details.
|
||||
func (e *Error) Details() []interface{} {
|
||||
return []interface{}{e.GetErrDetail()}
|
||||
}
|
||||
|
||||
// From will convert ecode.Codes to pb.Error.
|
||||
//
|
||||
// Deprecated: please use ecode.Error
|
||||
func From(ec ecode.Codes) *Error {
|
||||
var detail *any.Any
|
||||
if details := ec.Details(); len(details) > 0 {
|
||||
detail, _ = details[0].(*any.Any)
|
||||
}
|
||||
return &Error{
|
||||
ErrCode: int32(ec.Code()),
|
||||
ErrMessage: ec.Message(),
|
||||
ErrDetail: detail,
|
||||
}
|
||||
}
|
96
pkg/ecode/pb/ecode.pb.go
Normal file
96
pkg/ecode/pb/ecode.pb.go
Normal file
@ -0,0 +1,96 @@
|
||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// source: error.proto
|
||||
|
||||
package pb
|
||||
|
||||
import proto "github.com/golang/protobuf/proto"
|
||||
import fmt "fmt"
|
||||
import math "math"
|
||||
import any "github.com/golang/protobuf/ptypes/any"
|
||||
|
||||
// Reference imports to suppress errors if they are not otherwise used.
|
||||
var _ = proto.Marshal
|
||||
var _ = fmt.Errorf
|
||||
var _ = math.Inf
|
||||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the proto package it is being compiled against.
|
||||
// A compilation error at this line likely means your copy of the
|
||||
// proto package needs to be updated.
|
||||
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
|
||||
|
||||
// Deprecated: please use ecode.Error
|
||||
type Error struct {
|
||||
ErrCode int32 `protobuf:"varint,1,opt,name=err_code,json=errCode,proto3" json:"err_code,omitempty"`
|
||||
ErrMessage string `protobuf:"bytes,2,opt,name=err_message,json=errMessage,proto3" json:"err_message,omitempty"`
|
||||
ErrDetail *any.Any `protobuf:"bytes,3,opt,name=err_detail,json=errDetail,proto3" json:"err_detail,omitempty"`
|
||||
XXX_NoUnkeyedLiteral struct{} `json:"-"`
|
||||
XXX_unrecognized []byte `json:"-"`
|
||||
XXX_sizecache int32 `json:"-"`
|
||||
}
|
||||
|
||||
func (m *Error) Reset() { *m = Error{} }
|
||||
func (m *Error) String() string { return proto.CompactTextString(m) }
|
||||
func (*Error) ProtoMessage() {}
|
||||
func (*Error) Descriptor() ([]byte, []int) {
|
||||
return fileDescriptor_error_28aad86a4e53115b, []int{0}
|
||||
}
|
||||
func (m *Error) XXX_Unmarshal(b []byte) error {
|
||||
return xxx_messageInfo_Error.Unmarshal(m, b)
|
||||
}
|
||||
func (m *Error) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
|
||||
return xxx_messageInfo_Error.Marshal(b, m, deterministic)
|
||||
}
|
||||
func (dst *Error) XXX_Merge(src proto.Message) {
|
||||
xxx_messageInfo_Error.Merge(dst, src)
|
||||
}
|
||||
func (m *Error) XXX_Size() int {
|
||||
return xxx_messageInfo_Error.Size(m)
|
||||
}
|
||||
func (m *Error) XXX_DiscardUnknown() {
|
||||
xxx_messageInfo_Error.DiscardUnknown(m)
|
||||
}
|
||||
|
||||
var xxx_messageInfo_Error proto.InternalMessageInfo
|
||||
|
||||
func (m *Error) GetErrCode() int32 {
|
||||
if m != nil {
|
||||
return m.ErrCode
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (m *Error) GetErrMessage() string {
|
||||
if m != nil {
|
||||
return m.ErrMessage
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *Error) GetErrDetail() *any.Any {
|
||||
if m != nil {
|
||||
return m.ErrDetail
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
proto.RegisterType((*Error)(nil), "err.Error")
|
||||
}
|
||||
|
||||
func init() { proto.RegisterFile("error.proto", fileDescriptor_error_28aad86a4e53115b) }
|
||||
|
||||
var fileDescriptor_error_28aad86a4e53115b = []byte{
|
||||
// 164 bytes of a gzipped FileDescriptorProto
|
||||
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x34, 0x8d, 0xc1, 0xca, 0x82, 0x40,
|
||||
0x14, 0x85, 0x99, 0x5f, 0xfc, 0xcb, 0x71, 0x37, 0xb4, 0xd0, 0x36, 0x49, 0x2b, 0x57, 0x23, 0xe4,
|
||||
0x13, 0x44, 0xb5, 0x6c, 0xe3, 0x0b, 0x88, 0xe6, 0x49, 0x02, 0xf3, 0xc6, 0xd1, 0x20, 0xdf, 0x3e,
|
||||
0x1c, 0x69, 0x79, 0xcf, 0xf7, 0x71, 0x3f, 0x1d, 0x82, 0x14, 0xda, 0x17, 0x65, 0x14, 0xe3, 0x81,
|
||||
0xdc, 0xc6, 0xad, 0x48, 0xdb, 0x21, 0x73, 0x53, 0xfd, 0xbe, 0x67, 0x55, 0x3f, 0x2d, 0x7c, 0xff,
|
||||
0xd1, 0xfe, 0x65, 0xd6, 0x4d, 0xac, 0xd7, 0x20, 0xcb, 0x9b, 0x34, 0x88, 0x54, 0xa2, 0x52, 0xbf,
|
||||
0x58, 0x81, 0x3c, 0x49, 0x03, 0xb3, 0x73, 0x2f, 0xcb, 0x27, 0x86, 0xa1, 0x6a, 0x11, 0xfd, 0x25,
|
||||
0x2a, 0x0d, 0x0a, 0x0d, 0xf2, 0xba, 0x2c, 0x26, 0xd7, 0xf3, 0x55, 0x36, 0x18, 0xab, 0x47, 0x17,
|
||||
0x79, 0x89, 0x4a, 0xc3, 0xc3, 0xc6, 0x2e, 0x51, 0xfb, 0x8b, 0xda, 0x63, 0x3f, 0x15, 0x01, 0xc8,
|
||||
0xb3, 0xd3, 0xea, 0x7f, 0x07, 0xf2, 0x6f, 0x00, 0x00, 0x00, 0xff, 0xff, 0xf7, 0x41, 0x22, 0xfd,
|
||||
0xaf, 0x00, 0x00, 0x00,
|
||||
}
|
13
pkg/ecode/pb/ecode.proto
Normal file
13
pkg/ecode/pb/ecode.proto
Normal file
@ -0,0 +1,13 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package pb;
|
||||
|
||||
import "google/protobuf/any.proto";
|
||||
|
||||
option go_package = "go-common/library/ecode/pb";
|
||||
|
||||
message Error {
|
||||
int32 err_code = 1;
|
||||
string err_message = 2;
|
||||
google.protobuf.Any err_detail = 3;
|
||||
}
|
103
pkg/ecode/status.go
Normal file
103
pkg/ecode/status.go
Normal file
@ -0,0 +1,103 @@
|
||||
package ecode
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/bilibili/kratos/pkg/ecode/types"
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/golang/protobuf/ptypes"
|
||||
)
|
||||
|
||||
// Error new status with code and message
|
||||
func Error(code Code, message string) *Status {
|
||||
return &Status{s: &types.Status{Code: int32(code.Code()), Message: message}}
|
||||
}
|
||||
|
||||
// Errorf new status with code and message
|
||||
func Errorf(code Code, format string, args ...interface{}) *Status {
|
||||
return Error(code, fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
var _ Codes = &Status{}
|
||||
|
||||
// Status statusError is an alias of a status proto
|
||||
// implement ecode.Codes
|
||||
type Status struct {
|
||||
s *types.Status
|
||||
}
|
||||
|
||||
// Error implement error
|
||||
func (s *Status) Error() string {
|
||||
return s.Message()
|
||||
}
|
||||
|
||||
// Code return error code
|
||||
func (s *Status) Code() int {
|
||||
return int(s.s.Code)
|
||||
}
|
||||
|
||||
// Message return error message for developer
|
||||
func (s *Status) Message() string {
|
||||
if s.s.Message == "" {
|
||||
return strconv.Itoa(int(s.s.Code))
|
||||
}
|
||||
return s.s.Message
|
||||
}
|
||||
|
||||
// Details return error details
|
||||
func (s *Status) Details() []interface{} {
|
||||
if s == nil || s.s == nil {
|
||||
return nil
|
||||
}
|
||||
details := make([]interface{}, 0, len(s.s.Details))
|
||||
for _, any := range s.s.Details {
|
||||
detail := &ptypes.DynamicAny{}
|
||||
if err := ptypes.UnmarshalAny(any, detail); err != nil {
|
||||
details = append(details, err)
|
||||
continue
|
||||
}
|
||||
details = append(details, detail.Message)
|
||||
}
|
||||
return details
|
||||
}
|
||||
|
||||
// WithDetails WithDetails
|
||||
func (s *Status) WithDetails(pbs ...proto.Message) (*Status, error) {
|
||||
for _, pb := range pbs {
|
||||
anyMsg, err := ptypes.MarshalAny(pb)
|
||||
if err != nil {
|
||||
return s, err
|
||||
}
|
||||
s.s.Details = append(s.s.Details, anyMsg)
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Equal for compatible.
|
||||
// Deprecated: please use ecode.EqualError.
|
||||
func (s *Status) Equal(err error) bool {
|
||||
return EqualError(s, err)
|
||||
}
|
||||
|
||||
// Proto return origin protobuf message
|
||||
func (s *Status) Proto() *types.Status {
|
||||
return s.s
|
||||
}
|
||||
|
||||
// FromCode create status from ecode
|
||||
func FromCode(code Code) *Status {
|
||||
return &Status{s: &types.Status{Code: int32(code)}}
|
||||
}
|
||||
|
||||
// FromProto new status from grpc detail
|
||||
func FromProto(pbMsg proto.Message) Codes {
|
||||
if msg, ok := pbMsg.(*types.Status); ok {
|
||||
if msg.Message == "" {
|
||||
// NOTE: if message is empty convert to pure Code, will get message from config center.
|
||||
return Code(msg.Code)
|
||||
}
|
||||
return &Status{s: msg}
|
||||
}
|
||||
return Errorf(ServerErr, "invalid proto message get %v", pbMsg)
|
||||
}
|
66
pkg/ecode/status_test.go
Normal file
66
pkg/ecode/status_test.go
Normal file
@ -0,0 +1,66 @@
|
||||
package ecode
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/protobuf/ptypes/timestamp"
|
||||
"github.com/smartystreets/goconvey/convey"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/bilibili/kratos/pkg/ecode/types"
|
||||
)
|
||||
|
||||
func TestEqual(t *testing.T) {
|
||||
convey.Convey("Equal", t, func(ctx convey.C) {
|
||||
ctx.Convey("When err1=Error(RequestErr, 'test') and err2=Errorf(RequestErr, 'test')", func(ctx convey.C) {
|
||||
err1 := Error(RequestErr, "test")
|
||||
err2 := Errorf(RequestErr, "test")
|
||||
ctx.Convey("Then err1=err2, err1 != nil", func(ctx convey.C) {
|
||||
ctx.So(err1, convey.ShouldResemble, err2)
|
||||
ctx.So(err1, convey.ShouldNotBeNil)
|
||||
})
|
||||
})
|
||||
})
|
||||
// assert.True(t, OK.Equal(nil))
|
||||
// assert.True(t, err1.Equal(err2))
|
||||
// assert.False(t, err1.Equal(nil))
|
||||
// assert.True(t, Equal(nil, nil))
|
||||
}
|
||||
|
||||
func TestDetail(t *testing.T) {
|
||||
m := ×tamp.Timestamp{Seconds: time.Now().Unix()}
|
||||
st, _ := Error(RequestErr, "RequestErr").WithDetails(m)
|
||||
|
||||
assert.Equal(t, "RequestErr", st.Message())
|
||||
assert.Equal(t, int(RequestErr), st.Code())
|
||||
assert.IsType(t, m, st.Details()[0])
|
||||
}
|
||||
|
||||
func TestFromCode(t *testing.T) {
|
||||
err := FromCode(RequestErr)
|
||||
|
||||
assert.Equal(t, int(RequestErr), err.Code())
|
||||
assert.Equal(t, "-400", err.Message())
|
||||
}
|
||||
|
||||
func TestFromProto(t *testing.T) {
|
||||
msg := &types.Status{Code: 2233, Message: "error"}
|
||||
err := FromProto(msg)
|
||||
|
||||
assert.Equal(t, 2233, err.Code())
|
||||
assert.Equal(t, "error", err.Message())
|
||||
|
||||
m := ×tamp.Timestamp{Seconds: time.Now().Unix()}
|
||||
err = FromProto(m)
|
||||
assert.Equal(t, -500, err.Code())
|
||||
assert.Contains(t, err.Message(), "invalid proto message get")
|
||||
}
|
||||
|
||||
func TestEmpty(t *testing.T) {
|
||||
st := &Status{}
|
||||
assert.Len(t, st.Details(), 0)
|
||||
|
||||
st = nil
|
||||
assert.Len(t, st.Details(), 0)
|
||||
}
|
102
pkg/ecode/types/status.pb.go
Normal file
102
pkg/ecode/types/status.pb.go
Normal file
@ -0,0 +1,102 @@
|
||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// source: internal/types/status.proto
|
||||
|
||||
package types // import "github.com/bilibili/kratos/pkg/ecode/types"
|
||||
|
||||
import proto "github.com/golang/protobuf/proto"
|
||||
import fmt "fmt"
|
||||
import math "math"
|
||||
import any "github.com/golang/protobuf/ptypes/any"
|
||||
|
||||
// Reference imports to suppress errors if they are not otherwise used.
|
||||
var _ = proto.Marshal
|
||||
var _ = fmt.Errorf
|
||||
var _ = math.Inf
|
||||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the proto package it is being compiled against.
|
||||
// A compilation error at this line likely means your copy of the
|
||||
// proto package needs to be updated.
|
||||
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
|
||||
|
||||
type Status struct {
|
||||
// The error code see ecode.Code
|
||||
Code int32 `protobuf:"varint,1,opt,name=code" json:"code,omitempty"`
|
||||
// A developer-facing error message, which should be in English. Any
|
||||
Message string `protobuf:"bytes,2,opt,name=message" json:"message,omitempty"`
|
||||
// A list of messages that carry the error details. There is a common set of
|
||||
// message types for APIs to use.
|
||||
Details []*any.Any `protobuf:"bytes,3,rep,name=details" json:"details,omitempty"`
|
||||
XXX_NoUnkeyedLiteral struct{} `json:"-"`
|
||||
XXX_unrecognized []byte `json:"-"`
|
||||
XXX_sizecache int32 `json:"-"`
|
||||
}
|
||||
|
||||
func (m *Status) Reset() { *m = Status{} }
|
||||
func (m *Status) String() string { return proto.CompactTextString(m) }
|
||||
func (*Status) ProtoMessage() {}
|
||||
func (*Status) Descriptor() ([]byte, []int) {
|
||||
return fileDescriptor_status_88668d6b2bf80f08, []int{0}
|
||||
}
|
||||
func (m *Status) XXX_Unmarshal(b []byte) error {
|
||||
return xxx_messageInfo_Status.Unmarshal(m, b)
|
||||
}
|
||||
func (m *Status) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
|
||||
return xxx_messageInfo_Status.Marshal(b, m, deterministic)
|
||||
}
|
||||
func (dst *Status) XXX_Merge(src proto.Message) {
|
||||
xxx_messageInfo_Status.Merge(dst, src)
|
||||
}
|
||||
func (m *Status) XXX_Size() int {
|
||||
return xxx_messageInfo_Status.Size(m)
|
||||
}
|
||||
func (m *Status) XXX_DiscardUnknown() {
|
||||
xxx_messageInfo_Status.DiscardUnknown(m)
|
||||
}
|
||||
|
||||
var xxx_messageInfo_Status proto.InternalMessageInfo
|
||||
|
||||
func (m *Status) GetCode() int32 {
|
||||
if m != nil {
|
||||
return m.Code
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (m *Status) GetMessage() string {
|
||||
if m != nil {
|
||||
return m.Message
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *Status) GetDetails() []*any.Any {
|
||||
if m != nil {
|
||||
return m.Details
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
proto.RegisterType((*Status)(nil), "bilibili.rpc.Status")
|
||||
}
|
||||
|
||||
func init() { proto.RegisterFile("internal/types/status.proto", fileDescriptor_status_88668d6b2bf80f08) }
|
||||
|
||||
var fileDescriptor_status_88668d6b2bf80f08 = []byte{
|
||||
// 220 bytes of a gzipped FileDescriptorProto
|
||||
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x54, 0x8f, 0xb1, 0x4a, 0x04, 0x31,
|
||||
0x10, 0x86, 0xd9, 0x5b, 0xbd, 0xc3, 0x9c, 0x85, 0x04, 0x8b, 0x55, 0x9b, 0xc5, 0x6a, 0x0b, 0x4d,
|
||||
0x40, 0x4b, 0x2b, 0xcf, 0x17, 0x58, 0x22, 0x36, 0x76, 0x49, 0x6e, 0x2e, 0x04, 0x92, 0xcc, 0x92,
|
||||
0xe4, 0x8a, 0xbc, 0x8e, 0x4f, 0x2a, 0x9b, 0x65, 0x41, 0x8b, 0x19, 0x66, 0x98, 0xff, 0xe7, 0xfb,
|
||||
0x87, 0x3c, 0xd8, 0x90, 0x21, 0x06, 0xe9, 0x78, 0x2e, 0x13, 0x24, 0x9e, 0xb2, 0xcc, 0xe7, 0xc4,
|
||||
0xa6, 0x88, 0x19, 0xe9, 0xb5, 0xb2, 0xce, 0xce, 0xc5, 0xe2, 0xa4, 0xef, 0xef, 0x0c, 0xa2, 0x71,
|
||||
0xc0, 0xeb, 0x4d, 0x9d, 0x4f, 0x5c, 0x86, 0xb2, 0x08, 0x1f, 0x4f, 0x64, 0xfb, 0x59, 0x8d, 0x94,
|
||||
0x92, 0x0b, 0x8d, 0x47, 0xe8, 0x9a, 0xbe, 0x19, 0x2e, 0x45, 0x9d, 0x69, 0x47, 0x76, 0x1e, 0x52,
|
||||
0x92, 0x06, 0xba, 0x4d, 0xdf, 0x0c, 0x57, 0x62, 0x5d, 0x29, 0x23, 0xbb, 0x23, 0x64, 0x69, 0x5d,
|
||||
0xea, 0xda, 0xbe, 0x1d, 0xf6, 0x2f, 0xb7, 0x6c, 0x81, 0xb0, 0x15, 0xc2, 0xde, 0x43, 0x11, 0xab,
|
||||
0xe8, 0xf0, 0x45, 0x6e, 0x34, 0x7a, 0xf6, 0x37, 0xd6, 0x61, 0xbf, 0x90, 0xc7, 0xd9, 0x30, 0x36,
|
||||
0xdf, 0x4f, 0x06, 0x9f, 0x35, 0x7a, 0x8f, 0x81, 0x3b, 0xab, 0xa2, 0x8c, 0x85, 0xc3, 0x9c, 0x82,
|
||||
0xff, 0x7f, 0xf4, 0xad, 0xf6, 0x9f, 0x4d, 0x2b, 0xc6, 0x0f, 0xb5, 0xad, 0xb4, 0xd7, 0xdf, 0x00,
|
||||
0x00, 0x00, 0xff, 0xff, 0x80, 0xa3, 0xc1, 0x82, 0x0d, 0x01, 0x00, 0x00,
|
||||
}
|
23
pkg/ecode/types/status.proto
Normal file
23
pkg/ecode/types/status.proto
Normal file
@ -0,0 +1,23 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package bilibili.rpc;
|
||||
|
||||
import "google/protobuf/any.proto";
|
||||
|
||||
option go_package = "github.com/bilibili/Kratos/pkg/ecode/types;types";
|
||||
option java_multiple_files = true;
|
||||
option java_outer_classname = "StatusProto";
|
||||
option java_package = "com.bilibili.rpc";
|
||||
option objc_class_prefix = "RPC";
|
||||
|
||||
message Status {
|
||||
// The error code see ecode.Code
|
||||
int32 code = 1;
|
||||
|
||||
// A developer-facing error message, which should be in English. Any
|
||||
string message = 2;
|
||||
|
||||
// A list of messages that carry the error details. There is a common set of
|
||||
// message types for APIs to use.
|
||||
repeated google.protobuf.Any details = 3;
|
||||
}
|
@ -37,3 +37,30 @@ const (
|
||||
// Device 客户端信息
|
||||
Device = "device"
|
||||
)
|
||||
|
||||
var outgoingKey = map[string]struct{}{
|
||||
Color: struct{}{},
|
||||
RemoteIP: struct{}{},
|
||||
RemotePort: struct{}{},
|
||||
Mirror: struct{}{},
|
||||
}
|
||||
|
||||
var incomingKey = map[string]struct{}{
|
||||
Caller: struct{}{},
|
||||
}
|
||||
|
||||
// IsOutgoingKey represent this key should propagate by rpc.
|
||||
func IsOutgoingKey(key string) bool {
|
||||
_, ok := outgoingKey[key]
|
||||
return ok
|
||||
}
|
||||
|
||||
// IsIncomingKey represent this key should extract from rpc metadata.
|
||||
func IsIncomingKey(key string) (ok bool) {
|
||||
_, ok = outgoingKey[key]
|
||||
if ok {
|
||||
return
|
||||
}
|
||||
_, ok = incomingKey[key]
|
||||
return
|
||||
}
|
||||
|
62
pkg/net/rpc/warden/CHANGELOG.md
Normal file
62
pkg/net/rpc/warden/CHANGELOG.md
Normal file
@ -0,0 +1,62 @@
|
||||
### net/rpc/warden
|
||||
##### Version 1.1.12
|
||||
1. 设置 caller 为 no_user 如果 user 不存在
|
||||
|
||||
##### Version 1.1.12
|
||||
1. warden支持mirror传递
|
||||
|
||||
##### Version 1.1.11
|
||||
1. Validate RequestErr支持详细报错信息
|
||||
|
||||
##### Version 1.1.10
|
||||
1. 默认读取环境中的color
|
||||
|
||||
##### Version 1.1.9
|
||||
1. 增加NonBlock模式
|
||||
|
||||
##### Version 1.1.8
|
||||
1. 新增appid mock
|
||||
|
||||
##### Version 1.1.7
|
||||
1. 兼容cpu为0和wrr dt为0的情况
|
||||
|
||||
##### Version 1.1.6
|
||||
1. 修改caller传递和获取方式
|
||||
2. 添加error detail example
|
||||
|
||||
##### Version 1.1.5
|
||||
1. 增加server端json格式支持
|
||||
|
||||
##### Version 1.1.4
|
||||
1. 判断reosvler.builder为nil之后再注册
|
||||
|
||||
##### Version 1.1.3
|
||||
1. 支持zone和clusters
|
||||
|
||||
##### Version 1.1.2
|
||||
1. 业务错误日志记为 WARN
|
||||
|
||||
##### Version 1.1.1
|
||||
1. server实现了返回cpu信息
|
||||
|
||||
##### Version 1.1.0
|
||||
1. 增加ErrorDetail
|
||||
2. 修复日志打印error信息丢失问题
|
||||
|
||||
##### Version 1.0.3
|
||||
1. 给server增加keepalive参数
|
||||
|
||||
##### Version 1.0.2
|
||||
|
||||
1. 替代默认的timoue,使用durtaion.Shrink()来传递context
|
||||
2. 修复peer.Addr为nil时会panic的问题
|
||||
|
||||
##### Version 1.0.1
|
||||
|
||||
1. 去除timeout的手动传递,改为使用grpc默认自带的grpc-timeout
|
||||
2. 获取server address改为使用call option的方式,去除对balancer的依赖
|
||||
|
||||
##### Version 1.0.0
|
||||
|
||||
1. 使用NewClient来新建一个RPC客户端,并默认集成trace、log、recovery、moniter拦截器
|
||||
2. 使用NewServer来新建一个RPC服务端,并默认集成trace、log、recovery、moniter拦截器
|
10
pkg/net/rpc/warden/OWNERS
Normal file
10
pkg/net/rpc/warden/OWNERS
Normal file
@ -0,0 +1,10 @@
|
||||
# See the OWNERS docs at https://go.k8s.io/owners
|
||||
|
||||
approvers:
|
||||
- caoguoliang
|
||||
- maojian
|
||||
labels:
|
||||
- library
|
||||
reviewers:
|
||||
- caoguoliang
|
||||
- maojian
|
13
pkg/net/rpc/warden/README.md
Normal file
13
pkg/net/rpc/warden/README.md
Normal file
@ -0,0 +1,13 @@
|
||||
#### net/rcp/warden
|
||||
|
||||
##### 项目简介
|
||||
|
||||
来自 bilibili 主站技术部的 RPC 框架,融合主站技术部的核心科技,带来如飞一般的体验。
|
||||
|
||||
##### 编译环境
|
||||
|
||||
- **请只用 Golang v1.9.x 以上版本编译执行**
|
||||
|
||||
##### 依赖包
|
||||
|
||||
- [grpc](google.golang.org/grpc)
|
20
pkg/net/rpc/warden/balancer/p2c/CHANGELOG.md
Normal file
20
pkg/net/rpc/warden/balancer/p2c/CHANGELOG.md
Normal file
@ -0,0 +1,20 @@
|
||||
### business/warden/balancer/p2c
|
||||
|
||||
### Version 1.3.1
|
||||
1. add more test
|
||||
|
||||
### Version 1.3
|
||||
1. P2C替换smooth weighted round-robin
|
||||
|
||||
##### Version 1.2.1
|
||||
1. 删除了netflix ribbon的权重算法,改成了平方根算法
|
||||
|
||||
##### Version 1.2.0
|
||||
1. 实现了动态计算的调度轮询算法(使用了服务端的成功率数据,替换基于本地计算的成功率数据)
|
||||
|
||||
##### Version 1.1.0
|
||||
1. 实现了动态计算的调度轮询算法
|
||||
|
||||
##### Version 1.0.0
|
||||
|
||||
1. 实现了带权重可以识别Color的轮询算法
|
9
pkg/net/rpc/warden/balancer/p2c/OWNERS
Normal file
9
pkg/net/rpc/warden/balancer/p2c/OWNERS
Normal file
@ -0,0 +1,9 @@
|
||||
# See the OWNERS docs at https://go.k8s.io/owners
|
||||
|
||||
approvers:
|
||||
- caoguoliang
|
||||
labels:
|
||||
- library
|
||||
reviewers:
|
||||
- caoguoliang
|
||||
- maojian
|
13
pkg/net/rpc/warden/balancer/p2c/README.md
Normal file
13
pkg/net/rpc/warden/balancer/p2c/README.md
Normal file
@ -0,0 +1,13 @@
|
||||
#### business/warden/balancer/wrr
|
||||
|
||||
##### 项目简介
|
||||
|
||||
warden 的 weighted round robin负载均衡模块,主要用于为每个RPC请求返回一个Server节点以供调用
|
||||
|
||||
##### 编译环境
|
||||
|
||||
- **请只用 Golang v1.9.x 以上版本编译执行**
|
||||
|
||||
##### 依赖包
|
||||
|
||||
- [grpc](google.golang.org/grpc)
|
269
pkg/net/rpc/warden/balancer/p2c/p2c.go
Normal file
269
pkg/net/rpc/warden/balancer/p2c/p2c.go
Normal file
@ -0,0 +1,269 @@
|
||||
package p2c
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math"
|
||||
"math/rand"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/bilibili/kratos/pkg/conf/env"
|
||||
|
||||
"github.com/bilibili/kratos/pkg/log"
|
||||
nmd "github.com/bilibili/kratos/pkg/net/metadata"
|
||||
wmd "github.com/bilibili/kratos/pkg/net/rpc/warden/internal/metadata"
|
||||
|
||||
"google.golang.org/grpc/balancer"
|
||||
"google.golang.org/grpc/balancer/base"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/resolver"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
const (
|
||||
// The mean lifetime of `cost`, it reaches its half-life after Tau*ln(2).
|
||||
tau = int64(time.Millisecond * 600)
|
||||
// if statistic not collected,we add a big penalty to endpoint
|
||||
penalty = uint64(1000 * time.Millisecond * 250)
|
||||
|
||||
forceGap = int64(time.Second * 3)
|
||||
)
|
||||
|
||||
var _ base.PickerBuilder = &p2cPickerBuilder{}
|
||||
var _ balancer.Picker = &p2cPicker{}
|
||||
|
||||
// Name is the name of pick of two random choices balancer.
|
||||
const Name = "p2c"
|
||||
|
||||
// newBuilder creates a new weighted-roundrobin balancer builder.
|
||||
func newBuilder() balancer.Builder {
|
||||
return base.NewBalancerBuilder(Name, &p2cPickerBuilder{})
|
||||
}
|
||||
|
||||
func init() {
|
||||
balancer.Register(newBuilder())
|
||||
}
|
||||
|
||||
type subConn struct {
|
||||
// metadata
|
||||
conn balancer.SubConn
|
||||
addr resolver.Address
|
||||
meta wmd.MD
|
||||
|
||||
//client statistic data
|
||||
lag uint64
|
||||
success uint64
|
||||
inflight int64
|
||||
// server statistic data
|
||||
svrCPU uint64
|
||||
|
||||
//last collected timestamp
|
||||
stamp int64
|
||||
//last pick timestamp
|
||||
pick int64
|
||||
// request number in a period time
|
||||
reqs int64
|
||||
}
|
||||
|
||||
func (sc *subConn) health() uint64 {
|
||||
return atomic.LoadUint64(&sc.success)
|
||||
}
|
||||
|
||||
func (sc *subConn) cost() uint64 {
|
||||
load := atomic.LoadUint64(&sc.svrCPU) * atomic.LoadUint64(&sc.lag) * uint64(atomic.LoadInt64(&sc.inflight))
|
||||
if load == 0 {
|
||||
// penalty是初始化没有数据时的惩罚值,默认为1e9 * 250
|
||||
load = penalty
|
||||
}
|
||||
return load
|
||||
}
|
||||
|
||||
// statistics is info for log
|
||||
type statistic struct {
|
||||
addr string
|
||||
score float64
|
||||
cs uint64
|
||||
lantency uint64
|
||||
cpu uint64
|
||||
inflight int64
|
||||
reqs int64
|
||||
}
|
||||
|
||||
type p2cPickerBuilder struct{}
|
||||
|
||||
func (*p2cPickerBuilder) Build(readySCs map[resolver.Address]balancer.SubConn) balancer.Picker {
|
||||
p := &p2cPicker{
|
||||
colors: make(map[string]*p2cPicker),
|
||||
r: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
}
|
||||
for addr, sc := range readySCs {
|
||||
meta, ok := addr.Metadata.(wmd.MD)
|
||||
if !ok {
|
||||
meta = wmd.MD{
|
||||
Weight: 10,
|
||||
}
|
||||
}
|
||||
subc := &subConn{
|
||||
conn: sc,
|
||||
addr: addr,
|
||||
meta: meta,
|
||||
|
||||
svrCPU: 500,
|
||||
lag: 0,
|
||||
success: 1000,
|
||||
inflight: 1,
|
||||
}
|
||||
if meta.Color == "" {
|
||||
p.subConns = append(p.subConns, subc)
|
||||
continue
|
||||
}
|
||||
// if color not empty, use color picker
|
||||
cp, ok := p.colors[meta.Color]
|
||||
if !ok {
|
||||
cp = &p2cPicker{r: rand.New(rand.NewSource(time.Now().UnixNano()))}
|
||||
p.colors[meta.Color] = cp
|
||||
}
|
||||
cp.subConns = append(cp.subConns, subc)
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
type p2cPicker struct {
|
||||
// subConns is the snapshot of the weighted-roundrobin balancer when this picker was
|
||||
// created. The slice is immutable. Each Get() will do a round robin
|
||||
// selection from it and return the selected SubConn.
|
||||
subConns []*subConn
|
||||
colors map[string]*p2cPicker
|
||||
logTs int64
|
||||
r *rand.Rand
|
||||
lk sync.Mutex
|
||||
}
|
||||
|
||||
func (p *p2cPicker) Pick(ctx context.Context, opts balancer.PickOptions) (balancer.SubConn, func(balancer.DoneInfo), error) {
|
||||
// FIXME refactor to unify the color logic
|
||||
color := nmd.String(ctx, nmd.Color)
|
||||
if color == "" && env.Color != "" {
|
||||
color = env.Color
|
||||
}
|
||||
if color != "" {
|
||||
if cp, ok := p.colors[color]; ok {
|
||||
return cp.pick(ctx, opts)
|
||||
}
|
||||
}
|
||||
return p.pick(ctx, opts)
|
||||
}
|
||||
|
||||
func (p *p2cPicker) pick(ctx context.Context, opts balancer.PickOptions) (balancer.SubConn, func(balancer.DoneInfo), error) {
|
||||
var pc, upc *subConn
|
||||
start := time.Now().UnixNano()
|
||||
|
||||
if len(p.subConns) <= 0 {
|
||||
return nil, nil, balancer.ErrNoSubConnAvailable
|
||||
} else if len(p.subConns) == 1 {
|
||||
pc = p.subConns[0]
|
||||
} else {
|
||||
// choose two distinct nodes
|
||||
p.lk.Lock()
|
||||
a := p.r.Intn(len(p.subConns))
|
||||
b := p.r.Intn(len(p.subConns) - 1)
|
||||
p.lk.Unlock()
|
||||
if b >= a {
|
||||
b = b + 1
|
||||
}
|
||||
nodeA, nodeB := p.subConns[a], p.subConns[b]
|
||||
// meta.Weight为服务发布者在disocvery中设置的权重
|
||||
if nodeA.cost()*nodeB.health()*nodeB.meta.Weight > nodeB.cost()*nodeA.health()*nodeA.meta.Weight {
|
||||
pc, upc = nodeB, nodeA
|
||||
} else {
|
||||
pc, upc = nodeA, nodeB
|
||||
}
|
||||
// 如果选中的节点,在forceGap期间内没有被选中一次,那么强制一次
|
||||
// 利用强制的机会,来触发成功率、延迟的衰减
|
||||
// 原子锁conn.pick保证并发安全,放行一次
|
||||
pick := atomic.LoadInt64(&upc.pick)
|
||||
if start-pick > forceGap && atomic.CompareAndSwapInt64(&upc.pick, pick, start) {
|
||||
pc = upc
|
||||
}
|
||||
}
|
||||
|
||||
// 节点未发生切换才更新pick时间
|
||||
if pc != upc {
|
||||
atomic.StoreInt64(&pc.pick, start)
|
||||
}
|
||||
atomic.AddInt64(&pc.inflight, 1)
|
||||
atomic.AddInt64(&pc.reqs, 1)
|
||||
return pc.conn, func(di balancer.DoneInfo) {
|
||||
atomic.AddInt64(&pc.inflight, -1)
|
||||
now := time.Now().UnixNano()
|
||||
// get moving average ratio w
|
||||
stamp := atomic.SwapInt64(&pc.stamp, now)
|
||||
td := now - stamp
|
||||
if td < 0 {
|
||||
td = 0
|
||||
}
|
||||
w := math.Exp(float64(-td) / float64(tau))
|
||||
|
||||
lag := now - start
|
||||
if lag < 0 {
|
||||
lag = 0
|
||||
}
|
||||
oldLag := atomic.LoadUint64(&pc.lag)
|
||||
if oldLag == 0 {
|
||||
w = 0.0
|
||||
}
|
||||
lag = int64(float64(oldLag)*w + float64(lag)*(1.0-w))
|
||||
atomic.StoreUint64(&pc.lag, uint64(lag))
|
||||
|
||||
success := uint64(1000) // error value ,if error set 1
|
||||
if di.Err != nil {
|
||||
if st, ok := status.FromError(di.Err); ok {
|
||||
// only counter the local grpc error, ignore any business error
|
||||
if st.Code() != codes.Unknown && st.Code() != codes.OK {
|
||||
success = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
oldSuc := atomic.LoadUint64(&pc.success)
|
||||
success = uint64(float64(oldSuc)*w + float64(success)*(1.0-w))
|
||||
atomic.StoreUint64(&pc.success, success)
|
||||
|
||||
trailer := di.Trailer
|
||||
if strs, ok := trailer[wmd.CPUUsage]; ok {
|
||||
if cpu, err2 := strconv.ParseUint(strs[0], 10, 64); err2 == nil && cpu > 0 {
|
||||
atomic.StoreUint64(&pc.svrCPU, cpu)
|
||||
}
|
||||
}
|
||||
|
||||
logTs := atomic.LoadInt64(&p.logTs)
|
||||
if now-logTs > int64(time.Second*3) {
|
||||
if atomic.CompareAndSwapInt64(&p.logTs, logTs, now) {
|
||||
p.printStats()
|
||||
}
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *p2cPicker) printStats() {
|
||||
if len(p.subConns) <= 0 {
|
||||
return
|
||||
}
|
||||
stats := make([]statistic, 0, len(p.subConns))
|
||||
for _, conn := range p.subConns {
|
||||
var stat statistic
|
||||
stat.addr = conn.addr.Addr
|
||||
stat.cpu = atomic.LoadUint64(&conn.svrCPU)
|
||||
stat.cs = atomic.LoadUint64(&conn.success)
|
||||
stat.inflight = atomic.LoadInt64(&conn.inflight)
|
||||
stat.lantency = atomic.LoadUint64(&conn.lag)
|
||||
stat.reqs = atomic.SwapInt64(&conn.reqs, 0)
|
||||
load := stat.cpu * uint64(stat.inflight) * stat.lantency
|
||||
if load != 0 {
|
||||
stat.score = float64(stat.cs*conn.meta.Weight*1e8) / float64(load)
|
||||
}
|
||||
stats = append(stats, stat)
|
||||
}
|
||||
log.Info("p2c %s : %+v", p.subConns[0].addr.ServerName, stats)
|
||||
//fmt.Printf("%+v\n", stats)
|
||||
}
|
347
pkg/net/rpc/warden/balancer/p2c/p2c_test.go
Normal file
347
pkg/net/rpc/warden/balancer/p2c/p2c_test.go
Normal file
@ -0,0 +1,347 @@
|
||||
package p2c
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/bilibili/kratos/pkg/conf/env"
|
||||
|
||||
nmd "github.com/bilibili/kratos/pkg/net/metadata"
|
||||
wmeta "github.com/bilibili/kratos/pkg/net/rpc/warden/internal/metadata"
|
||||
|
||||
"google.golang.org/grpc/balancer"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/resolver"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
var serverNum int
|
||||
var cliNum int
|
||||
var concurrency int
|
||||
var extraLoad int64
|
||||
var extraDelay int64
|
||||
var extraWeight uint64
|
||||
|
||||
func init() {
|
||||
flag.IntVar(&serverNum, "snum", 5, "-snum 6")
|
||||
flag.IntVar(&cliNum, "cnum", 5, "-cnum 12")
|
||||
flag.IntVar(&concurrency, "concurrency", 5, "-cc 10")
|
||||
flag.Int64Var(&extraLoad, "exload", 3, "-exload 3")
|
||||
flag.Int64Var(&extraDelay, "exdelay", 0, "-exdelay 250")
|
||||
flag.Uint64Var(&extraWeight, "extraWeight", 0, "-exdelay 50")
|
||||
}
|
||||
|
||||
type testSubConn struct {
|
||||
addr resolver.Address
|
||||
wait chan struct{}
|
||||
//statics
|
||||
reqs int64
|
||||
usage int64
|
||||
cpu int64
|
||||
prevReq int64
|
||||
prevUsage int64
|
||||
//control params
|
||||
loadJitter int64
|
||||
delayJitter int64
|
||||
}
|
||||
|
||||
func newTestSubConn(addr string, weight uint64, color string) (sc *testSubConn) {
|
||||
sc = &testSubConn{
|
||||
addr: resolver.Address{
|
||||
Addr: addr,
|
||||
Metadata: wmeta.MD{
|
||||
Weight: weight,
|
||||
Color: color,
|
||||
},
|
||||
},
|
||||
wait: make(chan struct{}, 1000),
|
||||
}
|
||||
go func() {
|
||||
for {
|
||||
for i := 0; i < 210; i++ {
|
||||
<-sc.wait
|
||||
}
|
||||
time.Sleep(time.Millisecond * 20)
|
||||
}
|
||||
}()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (s *testSubConn) connect(ctx context.Context) {
|
||||
time.Sleep(time.Millisecond * 15)
|
||||
//add qps counter when request come in
|
||||
atomic.AddInt64(&s.reqs, 1)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case s.wait <- struct{}{}:
|
||||
atomic.AddInt64(&s.usage, 1)
|
||||
}
|
||||
load := atomic.LoadInt64(&s.loadJitter)
|
||||
if load > 0 {
|
||||
for i := 0; i <= rand.Intn(int(load)); i++ {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case s.wait <- struct{}{}:
|
||||
atomic.AddInt64(&s.usage, 1)
|
||||
}
|
||||
}
|
||||
}
|
||||
delay := atomic.LoadInt64(&s.delayJitter)
|
||||
if delay > 0 {
|
||||
delay = rand.Int63n(delay)
|
||||
time.Sleep(time.Millisecond * time.Duration(delay))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *testSubConn) UpdateAddresses([]resolver.Address) {
|
||||
|
||||
}
|
||||
|
||||
// Connect starts the connecting for this SubConn.
|
||||
func (s *testSubConn) Connect() {
|
||||
|
||||
}
|
||||
|
||||
func TestBalancerPick(t *testing.T) {
|
||||
scs := map[resolver.Address]balancer.SubConn{}
|
||||
sc1 := &testSubConn{
|
||||
addr: resolver.Address{
|
||||
Addr: "test1",
|
||||
Metadata: wmeta.MD{
|
||||
Weight: 8,
|
||||
},
|
||||
},
|
||||
}
|
||||
sc2 := &testSubConn{
|
||||
addr: resolver.Address{
|
||||
Addr: "test2",
|
||||
Metadata: wmeta.MD{
|
||||
Weight: 4,
|
||||
Color: "red",
|
||||
},
|
||||
},
|
||||
}
|
||||
sc3 := &testSubConn{
|
||||
addr: resolver.Address{
|
||||
Addr: "test3",
|
||||
Metadata: wmeta.MD{
|
||||
Weight: 2,
|
||||
Color: "red",
|
||||
},
|
||||
},
|
||||
}
|
||||
sc4 := &testSubConn{
|
||||
addr: resolver.Address{
|
||||
Addr: "test4",
|
||||
Metadata: wmeta.MD{
|
||||
Weight: 2,
|
||||
Color: "purple",
|
||||
},
|
||||
},
|
||||
}
|
||||
scs[sc1.addr] = sc1
|
||||
scs[sc2.addr] = sc2
|
||||
scs[sc3.addr] = sc3
|
||||
scs[sc4.addr] = sc4
|
||||
b := &p2cPickerBuilder{}
|
||||
picker := b.Build(scs)
|
||||
res := []string{"test1", "test1", "test1", "test1"}
|
||||
for i := 0; i < 3; i++ {
|
||||
conn, _, err := picker.Pick(context.Background(), balancer.PickOptions{})
|
||||
if err != nil {
|
||||
t.Fatalf("picker.Pick failed!idx:=%d", i)
|
||||
}
|
||||
sc := conn.(*testSubConn)
|
||||
if sc.addr.Addr != res[i] {
|
||||
t.Fatalf("the subconn picked(%s),but expected(%s)", sc.addr.Addr, res[i])
|
||||
}
|
||||
}
|
||||
|
||||
ctx := nmd.NewContext(context.Background(), nmd.New(map[string]interface{}{"color": "black"}))
|
||||
for i := 0; i < 4; i++ {
|
||||
conn, _, err := picker.Pick(ctx, balancer.PickOptions{})
|
||||
if err != nil {
|
||||
t.Fatalf("picker.Pick failed!idx:=%d", i)
|
||||
}
|
||||
sc := conn.(*testSubConn)
|
||||
if sc.addr.Addr != res[i] {
|
||||
t.Fatalf("the (%d) subconn picked(%s),but expected(%s)", i, sc.addr.Addr, res[i])
|
||||
}
|
||||
}
|
||||
|
||||
env.Color = "purple"
|
||||
ctx2 := context.Background()
|
||||
for i := 0; i < 4; i++ {
|
||||
conn, _, err := picker.Pick(ctx2, balancer.PickOptions{})
|
||||
if err != nil {
|
||||
t.Fatalf("picker.Pick failed!idx:=%d", i)
|
||||
}
|
||||
sc := conn.(*testSubConn)
|
||||
if sc.addr.Addr != "test4" {
|
||||
t.Fatalf("the (%d) subconn picked(%s),but expected(%s)", i, sc.addr.Addr, res[i])
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func Benchmark_Wrr(b *testing.B) {
|
||||
scs := map[resolver.Address]balancer.SubConn{}
|
||||
for i := 0; i < 50; i++ {
|
||||
addr := resolver.Address{
|
||||
Addr: fmt.Sprintf("addr_%d", i),
|
||||
Metadata: wmeta.MD{Weight: 10},
|
||||
}
|
||||
scs[addr] = &testSubConn{addr: addr}
|
||||
}
|
||||
wpb := &p2cPickerBuilder{}
|
||||
picker := wpb.Build(scs)
|
||||
opt := balancer.PickOptions{}
|
||||
ctx := context.Background()
|
||||
for idx := 0; idx < b.N; idx++ {
|
||||
_, done, err := picker.Pick(ctx, opt)
|
||||
if err != nil {
|
||||
done(balancer.DoneInfo{})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestChaosPick(t *testing.T) {
|
||||
flag.Parse()
|
||||
fmt.Printf("start chaos test!svrNum:%d cliNum:%d concurrency:%d exLoad:%d exDelay:%d\n", serverNum, cliNum, concurrency, extraLoad, extraDelay)
|
||||
c := newController(serverNum, cliNum)
|
||||
c.launch(concurrency)
|
||||
go c.updateStatics()
|
||||
go c.control(extraLoad, extraDelay)
|
||||
time.Sleep(time.Second * 50)
|
||||
}
|
||||
|
||||
func newController(svrNum int, cliNum int) *controller {
|
||||
//new servers
|
||||
servers := []*testSubConn{}
|
||||
var weight uint64 = 10
|
||||
if extraWeight > 0 {
|
||||
weight = extraWeight
|
||||
}
|
||||
for i := 0; i < svrNum; i++ {
|
||||
weight += extraWeight
|
||||
sc := newTestSubConn(fmt.Sprintf("addr_%d", i), weight, "")
|
||||
servers = append(servers, sc)
|
||||
}
|
||||
//new clients
|
||||
var clients []balancer.Picker
|
||||
scs := map[resolver.Address]balancer.SubConn{}
|
||||
for _, v := range servers {
|
||||
scs[v.addr] = v
|
||||
}
|
||||
for i := 0; i < cliNum; i++ {
|
||||
wpb := &p2cPickerBuilder{}
|
||||
picker := wpb.Build(scs)
|
||||
clients = append(clients, picker)
|
||||
}
|
||||
|
||||
c := &controller{
|
||||
servers: servers,
|
||||
clients: clients,
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
type controller struct {
|
||||
servers []*testSubConn
|
||||
clients []balancer.Picker
|
||||
}
|
||||
|
||||
func (c *controller) launch(concurrency int) {
|
||||
opt := balancer.PickOptions{}
|
||||
bkg := context.Background()
|
||||
for i := range c.clients {
|
||||
for j := 0; j < concurrency; j++ {
|
||||
picker := c.clients[i]
|
||||
go func() {
|
||||
for {
|
||||
ctx, cancel := context.WithTimeout(bkg, time.Millisecond*250)
|
||||
sc, done, _ := picker.Pick(ctx, opt)
|
||||
server := sc.(*testSubConn)
|
||||
server.connect(ctx)
|
||||
var err error
|
||||
if ctx.Err() != nil {
|
||||
err = status.Errorf(codes.DeadlineExceeded, "dead")
|
||||
}
|
||||
cancel()
|
||||
cpu := atomic.LoadInt64(&server.cpu)
|
||||
md := make(map[string]string)
|
||||
md[wmeta.CPUUsage] = strconv.FormatInt(cpu, 10)
|
||||
done(balancer.DoneInfo{Trailer: metadata.New(md), Err: err})
|
||||
time.Sleep(time.Millisecond * 10)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *controller) updateStatics() {
|
||||
for {
|
||||
time.Sleep(time.Millisecond * 500)
|
||||
for _, sc := range c.servers {
|
||||
usage := atomic.LoadInt64(&sc.usage)
|
||||
avgCpu := (usage - sc.prevUsage) * 2
|
||||
atomic.StoreInt64(&sc.cpu, avgCpu)
|
||||
sc.prevUsage = usage
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (c *controller) control(extraLoad, extraDelay int64) {
|
||||
var chaos int
|
||||
for {
|
||||
fmt.Printf("\n")
|
||||
//make some chaos
|
||||
n := rand.Intn(3)
|
||||
chaos = n + 1
|
||||
for i := 0; i < chaos; i++ {
|
||||
if extraLoad > 0 {
|
||||
degree := rand.Int63n(extraLoad)
|
||||
degree++
|
||||
atomic.StoreInt64(&c.servers[i].loadJitter, degree)
|
||||
fmt.Printf("set addr_%d load:%d ", i, degree)
|
||||
}
|
||||
if extraDelay > 0 {
|
||||
degree := rand.Int63n(extraDelay)
|
||||
atomic.StoreInt64(&c.servers[i].delayJitter, degree)
|
||||
fmt.Printf("set addr_%d delay:%dms ", i, degree)
|
||||
}
|
||||
}
|
||||
fmt.Printf("\n")
|
||||
sleep := int64(5)
|
||||
time.Sleep(time.Second * time.Duration(sleep))
|
||||
for _, sc := range c.servers {
|
||||
req := atomic.LoadInt64(&sc.reqs)
|
||||
qps := (req - sc.prevReq) / sleep
|
||||
wait := len(sc.wait)
|
||||
sc.prevReq = req
|
||||
fmt.Printf("%s qps:%d waits:%d\n", sc.addr.Addr, qps, wait)
|
||||
}
|
||||
for _, picker := range c.clients {
|
||||
p := picker.(*p2cPicker)
|
||||
p.printStats()
|
||||
}
|
||||
fmt.Printf("\n")
|
||||
//reset chaos
|
||||
for i := 0; i < chaos; i++ {
|
||||
atomic.StoreInt64(&c.servers[i].loadJitter, 0)
|
||||
atomic.StoreInt64(&c.servers[i].delayJitter, 0)
|
||||
}
|
||||
chaos = 0
|
||||
}
|
||||
}
|
17
pkg/net/rpc/warden/balancer/wrr/CHANGELOG.md
Normal file
17
pkg/net/rpc/warden/balancer/wrr/CHANGELOG.md
Normal file
@ -0,0 +1,17 @@
|
||||
### business/warden/balancer/wrr
|
||||
|
||||
##### Version 1.3.0
|
||||
1. 迁移 stat.Summary 到 metric.RollingCounter,metric.RollingGauge
|
||||
|
||||
##### Version 1.2.1
|
||||
1. 删除了netflix ribbon的权重算法,改成了平方根算法
|
||||
|
||||
##### Version 1.2.0
|
||||
1. 实现了动态计算的调度轮询算法(使用了服务端的成功率数据,替换基于本地计算的成功率数据)
|
||||
|
||||
##### Version 1.1.0
|
||||
1. 实现了动态计算的调度轮询算法
|
||||
|
||||
##### Version 1.0.0
|
||||
|
||||
1. 实现了带权重可以识别Color的轮询算法
|
9
pkg/net/rpc/warden/balancer/wrr/OWNERS
Normal file
9
pkg/net/rpc/warden/balancer/wrr/OWNERS
Normal file
@ -0,0 +1,9 @@
|
||||
# See the OWNERS docs at https://go.k8s.io/owners
|
||||
|
||||
approvers:
|
||||
- caoguoliang
|
||||
labels:
|
||||
- library
|
||||
reviewers:
|
||||
- caoguoliang
|
||||
- maojian
|
13
pkg/net/rpc/warden/balancer/wrr/README.md
Normal file
13
pkg/net/rpc/warden/balancer/wrr/README.md
Normal file
@ -0,0 +1,13 @@
|
||||
#### business/warden/balancer/wrr
|
||||
|
||||
##### 项目简介
|
||||
|
||||
warden 的 weighted round robin负载均衡模块,主要用于为每个RPC请求返回一个Server节点以供调用
|
||||
|
||||
##### 编译环境
|
||||
|
||||
- **请只用 Golang v1.9.x 以上版本编译执行**
|
||||
|
||||
##### 依赖包
|
||||
|
||||
- [grpc](google.golang.org/grpc)
|
302
pkg/net/rpc/warden/balancer/wrr/wrr.go
Normal file
302
pkg/net/rpc/warden/balancer/wrr/wrr.go
Normal file
@ -0,0 +1,302 @@
|
||||
package wrr
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/bilibili/kratos/pkg/conf/env"
|
||||
"github.com/bilibili/kratos/pkg/log"
|
||||
nmd "github.com/bilibili/kratos/pkg/net/metadata"
|
||||
wmeta "github.com/bilibili/kratos/pkg/net/rpc/warden/internal/metadata"
|
||||
"github.com/bilibili/kratos/pkg/stat/metric"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/balancer"
|
||||
"google.golang.org/grpc/balancer/base"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/resolver"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
var _ base.PickerBuilder = &wrrPickerBuilder{}
|
||||
var _ balancer.Picker = &wrrPicker{}
|
||||
|
||||
// var dwrrFeature feature.Feature = "dwrr"
|
||||
|
||||
// Name is the name of round_robin balancer.
|
||||
const Name = "wrr"
|
||||
|
||||
// newBuilder creates a new weighted-roundrobin balancer builder.
|
||||
func newBuilder() balancer.Builder {
|
||||
return base.NewBalancerBuilder(Name, &wrrPickerBuilder{})
|
||||
}
|
||||
|
||||
func init() {
|
||||
//feature.DefaultGate.Add(map[feature.Feature]feature.Spec{
|
||||
// dwrrFeature: {Default: false},
|
||||
//})
|
||||
|
||||
balancer.Register(newBuilder())
|
||||
}
|
||||
|
||||
type serverInfo struct {
|
||||
cpu int64
|
||||
success uint64 // float64 bits
|
||||
}
|
||||
|
||||
type subConn struct {
|
||||
conn balancer.SubConn
|
||||
addr resolver.Address
|
||||
meta wmeta.MD
|
||||
|
||||
err metric.RollingCounter
|
||||
latency metric.RollingGauge
|
||||
si serverInfo
|
||||
// effective weight
|
||||
ewt int64
|
||||
// current weight
|
||||
cwt int64
|
||||
// last score
|
||||
score float64
|
||||
}
|
||||
|
||||
func (c *subConn) errSummary() (err int64, req int64) {
|
||||
c.err.Reduce(func(iterator metric.Iterator) float64 {
|
||||
for iterator.Next() {
|
||||
bucket := iterator.Bucket()
|
||||
req += bucket.Count
|
||||
for _, p := range bucket.Points {
|
||||
err += int64(p)
|
||||
}
|
||||
}
|
||||
return 0
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (c *subConn) latencySummary() (latency float64, count int64) {
|
||||
c.latency.Reduce(func(iterator metric.Iterator) float64 {
|
||||
for iterator.Next() {
|
||||
bucket := iterator.Bucket()
|
||||
count += bucket.Count
|
||||
for _, p := range bucket.Points {
|
||||
latency += p
|
||||
}
|
||||
}
|
||||
return 0
|
||||
})
|
||||
return latency / float64(count), count
|
||||
}
|
||||
|
||||
// statistics is info for log
|
||||
type statistics struct {
|
||||
addr string
|
||||
ewt int64
|
||||
cs float64
|
||||
ss float64
|
||||
lantency float64
|
||||
cpu float64
|
||||
req int64
|
||||
}
|
||||
|
||||
// Stats is grpc Interceptor for client to collect server stats
|
||||
func Stats() grpc.UnaryClientInterceptor {
|
||||
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) (err error) {
|
||||
var (
|
||||
trailer metadata.MD
|
||||
md nmd.MD
|
||||
ok bool
|
||||
)
|
||||
if md, ok = nmd.FromContext(ctx); !ok {
|
||||
md = nmd.MD{}
|
||||
} else {
|
||||
md = md.Copy()
|
||||
}
|
||||
ctx = nmd.NewContext(ctx, md)
|
||||
opts = append(opts, grpc.Trailer(&trailer))
|
||||
|
||||
err = invoker(ctx, method, req, reply, cc, opts...)
|
||||
|
||||
conn, ok := md["conn"].(*subConn)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if strs, ok := trailer[wmeta.CPUUsage]; ok {
|
||||
if cpu, err2 := strconv.ParseInt(strs[0], 10, 64); err2 == nil && cpu > 0 {
|
||||
atomic.StoreInt64(&conn.si.cpu, cpu)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
type wrrPickerBuilder struct{}
|
||||
|
||||
func (*wrrPickerBuilder) Build(readySCs map[resolver.Address]balancer.SubConn) balancer.Picker {
|
||||
p := &wrrPicker{
|
||||
colors: make(map[string]*wrrPicker),
|
||||
}
|
||||
for addr, sc := range readySCs {
|
||||
meta, ok := addr.Metadata.(wmeta.MD)
|
||||
if !ok {
|
||||
meta = wmeta.MD{
|
||||
Weight: 10,
|
||||
}
|
||||
}
|
||||
subc := &subConn{
|
||||
conn: sc,
|
||||
addr: addr,
|
||||
|
||||
meta: meta,
|
||||
ewt: int64(meta.Weight),
|
||||
score: -1,
|
||||
|
||||
err: metric.NewRollingCounter(metric.RollingCounterOpts{
|
||||
Size: 10,
|
||||
BucketDuration: time.Millisecond * 100,
|
||||
}),
|
||||
latency: metric.NewRollingGauge(metric.RollingGaugeOpts{
|
||||
Size: 10,
|
||||
BucketDuration: time.Millisecond * 100,
|
||||
}),
|
||||
|
||||
si: serverInfo{cpu: 500, success: math.Float64bits(1)},
|
||||
}
|
||||
if meta.Color == "" {
|
||||
p.subConns = append(p.subConns, subc)
|
||||
continue
|
||||
}
|
||||
// if color not empty, use color picker
|
||||
cp, ok := p.colors[meta.Color]
|
||||
if !ok {
|
||||
cp = &wrrPicker{}
|
||||
p.colors[meta.Color] = cp
|
||||
}
|
||||
cp.subConns = append(cp.subConns, subc)
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
type wrrPicker struct {
|
||||
// subConns is the snapshot of the weighted-roundrobin balancer when this picker was
|
||||
// created. The slice is immutable. Each Get() will do a round robin
|
||||
// selection from it and return the selected SubConn.
|
||||
subConns []*subConn
|
||||
colors map[string]*wrrPicker
|
||||
updateAt int64
|
||||
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (p *wrrPicker) Pick(ctx context.Context, opts balancer.PickOptions) (balancer.SubConn, func(balancer.DoneInfo), error) {
|
||||
// FIXME refactor to unify the color logic
|
||||
color := nmd.String(ctx, nmd.Color)
|
||||
if color == "" && env.Color != "" {
|
||||
color = env.Color
|
||||
}
|
||||
if color != "" {
|
||||
if cp, ok := p.colors[color]; ok {
|
||||
return cp.pick(ctx, opts)
|
||||
}
|
||||
}
|
||||
return p.pick(ctx, opts)
|
||||
}
|
||||
|
||||
func (p *wrrPicker) pick(ctx context.Context, opts balancer.PickOptions) (balancer.SubConn, func(balancer.DoneInfo), error) {
|
||||
var (
|
||||
conn *subConn
|
||||
totalWeight int64
|
||||
)
|
||||
if len(p.subConns) <= 0 {
|
||||
return nil, nil, balancer.ErrNoSubConnAvailable
|
||||
}
|
||||
p.mu.Lock()
|
||||
// nginx wrr load balancing algorithm: http://blog.csdn.net/zhangskd/article/details/50194069
|
||||
for _, sc := range p.subConns {
|
||||
totalWeight += sc.ewt
|
||||
sc.cwt += sc.ewt
|
||||
if conn == nil || conn.cwt < sc.cwt {
|
||||
conn = sc
|
||||
}
|
||||
}
|
||||
conn.cwt -= totalWeight
|
||||
p.mu.Unlock()
|
||||
start := time.Now()
|
||||
if cmd, ok := nmd.FromContext(ctx); ok {
|
||||
cmd["conn"] = conn
|
||||
}
|
||||
//if !feature.DefaultGate.Enabled(dwrrFeature) {
|
||||
// return conn.conn, nil, nil
|
||||
//}
|
||||
return conn.conn, func(di balancer.DoneInfo) {
|
||||
ev := int64(0) // error value ,if error set 1
|
||||
if di.Err != nil {
|
||||
if st, ok := status.FromError(di.Err); ok {
|
||||
// only counter the local grpc error, ignore any business error
|
||||
if st.Code() != codes.Unknown && st.Code() != codes.OK {
|
||||
ev = 1
|
||||
}
|
||||
}
|
||||
}
|
||||
conn.err.Add(ev)
|
||||
|
||||
now := time.Now()
|
||||
conn.latency.Add(now.Sub(start).Nanoseconds() / 1e5)
|
||||
u := atomic.LoadInt64(&p.updateAt)
|
||||
if now.UnixNano()-u < int64(time.Second) {
|
||||
return
|
||||
}
|
||||
if !atomic.CompareAndSwapInt64(&p.updateAt, u, now.UnixNano()) {
|
||||
return
|
||||
}
|
||||
var (
|
||||
stats = make([]statistics, len(p.subConns))
|
||||
count int
|
||||
total float64
|
||||
)
|
||||
for i, conn := range p.subConns {
|
||||
cpu := float64(atomic.LoadInt64(&conn.si.cpu))
|
||||
ss := math.Float64frombits(atomic.LoadUint64(&conn.si.success))
|
||||
errc, req := conn.errSummary()
|
||||
lagv, lagc := conn.latencySummary()
|
||||
|
||||
if req > 0 && lagc > 0 && lagv > 0 {
|
||||
// client-side success ratio
|
||||
cs := 1 - (float64(errc) / float64(req))
|
||||
if cs <= 0 {
|
||||
cs = 0.1
|
||||
} else if cs <= 0.2 && req <= 5 {
|
||||
cs = 0.2
|
||||
}
|
||||
conn.score = math.Sqrt((cs * ss * ss * 1e9) / (lagv * cpu))
|
||||
stats[i] = statistics{cs: cs, ss: ss, lantency: lagv, cpu: cpu, req: req}
|
||||
}
|
||||
stats[i].addr = conn.addr.Addr
|
||||
|
||||
if conn.score > 0 {
|
||||
total += conn.score
|
||||
count++
|
||||
}
|
||||
}
|
||||
// count must be greater than 1,otherwise will lead ewt to 0
|
||||
if count < 2 {
|
||||
return
|
||||
}
|
||||
avgscore := total / float64(count)
|
||||
p.mu.Lock()
|
||||
for i, conn := range p.subConns {
|
||||
if conn.score <= 0 {
|
||||
conn.score = avgscore
|
||||
}
|
||||
conn.ewt = int64(conn.score * float64(conn.meta.Weight))
|
||||
stats[i].ewt = conn.ewt
|
||||
}
|
||||
p.mu.Unlock()
|
||||
log.Info("warden wrr(%s): %+v", conn.addr.ServerName, stats)
|
||||
}, nil
|
||||
|
||||
}
|
189
pkg/net/rpc/warden/balancer/wrr/wrr_test.go
Normal file
189
pkg/net/rpc/warden/balancer/wrr/wrr_test.go
Normal file
@ -0,0 +1,189 @@
|
||||
package wrr
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/bilibili/kratos/pkg/conf/env"
|
||||
nmd "github.com/bilibili/kratos/pkg/net/metadata"
|
||||
wmeta "github.com/bilibili/kratos/pkg/net/rpc/warden/internal/metadata"
|
||||
"github.com/bilibili/kratos/pkg/stat/metric"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc/balancer"
|
||||
"google.golang.org/grpc/resolver"
|
||||
)
|
||||
|
||||
type testSubConn struct {
|
||||
addr resolver.Address
|
||||
}
|
||||
|
||||
func (s *testSubConn) UpdateAddresses([]resolver.Address) {
|
||||
|
||||
}
|
||||
|
||||
// Connect starts the connecting for this SubConn.
|
||||
func (s *testSubConn) Connect() {
|
||||
fmt.Println(s.addr.Addr)
|
||||
}
|
||||
|
||||
func TestBalancerPick(t *testing.T) {
|
||||
scs := map[resolver.Address]balancer.SubConn{}
|
||||
sc1 := &testSubConn{
|
||||
addr: resolver.Address{
|
||||
Addr: "test1",
|
||||
Metadata: wmeta.MD{
|
||||
Weight: 8,
|
||||
},
|
||||
},
|
||||
}
|
||||
sc2 := &testSubConn{
|
||||
addr: resolver.Address{
|
||||
Addr: "test2",
|
||||
Metadata: wmeta.MD{
|
||||
Weight: 4,
|
||||
Color: "red",
|
||||
},
|
||||
},
|
||||
}
|
||||
sc3 := &testSubConn{
|
||||
addr: resolver.Address{
|
||||
Addr: "test3",
|
||||
Metadata: wmeta.MD{
|
||||
Weight: 2,
|
||||
Color: "red",
|
||||
},
|
||||
},
|
||||
}
|
||||
scs[sc1.addr] = sc1
|
||||
scs[sc2.addr] = sc2
|
||||
scs[sc3.addr] = sc3
|
||||
b := &wrrPickerBuilder{}
|
||||
picker := b.Build(scs)
|
||||
res := []string{"test1", "test1", "test1", "test1"}
|
||||
for i := 0; i < 3; i++ {
|
||||
conn, _, err := picker.Pick(context.Background(), balancer.PickOptions{})
|
||||
if err != nil {
|
||||
t.Fatalf("picker.Pick failed!idx:=%d", i)
|
||||
}
|
||||
sc := conn.(*testSubConn)
|
||||
if sc.addr.Addr != res[i] {
|
||||
t.Fatalf("the subconn picked(%s),but expected(%s)", sc.addr.Addr, res[i])
|
||||
}
|
||||
}
|
||||
res2 := []string{"test2", "test3", "test2", "test2", "test3", "test2"}
|
||||
ctx := nmd.NewContext(context.Background(), nmd.New(map[string]interface{}{"color": "red"}))
|
||||
for i := 0; i < 6; i++ {
|
||||
conn, _, err := picker.Pick(ctx, balancer.PickOptions{})
|
||||
if err != nil {
|
||||
t.Fatalf("picker.Pick failed!idx:=%d", i)
|
||||
}
|
||||
sc := conn.(*testSubConn)
|
||||
if sc.addr.Addr != res2[i] {
|
||||
t.Fatalf("the (%d) subconn picked(%s),but expected(%s)", i, sc.addr.Addr, res2[i])
|
||||
}
|
||||
}
|
||||
ctx = nmd.NewContext(context.Background(), nmd.New(map[string]interface{}{"color": "black"}))
|
||||
for i := 0; i < 4; i++ {
|
||||
conn, _, err := picker.Pick(ctx, balancer.PickOptions{})
|
||||
if err != nil {
|
||||
t.Fatalf("picker.Pick failed!idx:=%d", i)
|
||||
}
|
||||
sc := conn.(*testSubConn)
|
||||
if sc.addr.Addr != res[i] {
|
||||
t.Fatalf("the (%d) subconn picked(%s),but expected(%s)", i, sc.addr.Addr, res[i])
|
||||
}
|
||||
}
|
||||
|
||||
// test for env color
|
||||
ctx = context.Background()
|
||||
env.Color = "red"
|
||||
for i := 0; i < 6; i++ {
|
||||
conn, _, err := picker.Pick(ctx, balancer.PickOptions{})
|
||||
if err != nil {
|
||||
t.Fatalf("picker.Pick failed!idx:=%d", i)
|
||||
}
|
||||
sc := conn.(*testSubConn)
|
||||
if sc.addr.Addr != res2[i] {
|
||||
t.Fatalf("the (%d) subconn picked(%s),but expected(%s)", i, sc.addr.Addr, res2[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBalancerDone(t *testing.T) {
|
||||
scs := map[resolver.Address]balancer.SubConn{}
|
||||
sc1 := &testSubConn{
|
||||
addr: resolver.Address{
|
||||
Addr: "test1",
|
||||
Metadata: wmeta.MD{
|
||||
Weight: 8,
|
||||
},
|
||||
},
|
||||
}
|
||||
scs[sc1.addr] = sc1
|
||||
b := &wrrPickerBuilder{}
|
||||
picker := b.Build(scs)
|
||||
|
||||
_, done, _ := picker.Pick(context.Background(), balancer.PickOptions{})
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
done(balancer.DoneInfo{Err: status.Errorf(codes.Unknown, "test")})
|
||||
err, req := picker.(*wrrPicker).subConns[0].errSummary()
|
||||
assert.Equal(t, int64(0), err)
|
||||
assert.Equal(t, int64(1), req)
|
||||
|
||||
latency, count := picker.(*wrrPicker).subConns[0].latencySummary()
|
||||
expectLatency := float64(100*time.Millisecond) / 1e5
|
||||
if !(expectLatency < latency && latency < (expectLatency+100)) {
|
||||
t.Fatalf("latency is less than 100ms or greter than 100ms, %f", latency)
|
||||
}
|
||||
assert.Equal(t, int64(1), count)
|
||||
|
||||
_, done, _ = picker.Pick(context.Background(), balancer.PickOptions{})
|
||||
done(balancer.DoneInfo{Err: status.Errorf(codes.Aborted, "test")})
|
||||
err, req = picker.(*wrrPicker).subConns[0].errSummary()
|
||||
assert.Equal(t, int64(1), err)
|
||||
assert.Equal(t, int64(2), req)
|
||||
}
|
||||
|
||||
func TestErrSummary(t *testing.T) {
|
||||
sc := &subConn{
|
||||
err: metric.NewRollingCounter(metric.RollingCounterOpts{
|
||||
Size: 10,
|
||||
BucketDuration: time.Millisecond * 100,
|
||||
}),
|
||||
latency: metric.NewRollingGauge(metric.RollingGaugeOpts{
|
||||
Size: 10,
|
||||
BucketDuration: time.Millisecond * 100,
|
||||
}),
|
||||
}
|
||||
for i := 0; i < 10; i++ {
|
||||
sc.err.Add(0)
|
||||
sc.err.Add(1)
|
||||
}
|
||||
err, req := sc.errSummary()
|
||||
assert.Equal(t, int64(10), err)
|
||||
assert.Equal(t, int64(20), req)
|
||||
}
|
||||
|
||||
func TestLatencySummary(t *testing.T) {
|
||||
sc := &subConn{
|
||||
err: metric.NewRollingCounter(metric.RollingCounterOpts{
|
||||
Size: 10,
|
||||
BucketDuration: time.Millisecond * 100,
|
||||
}),
|
||||
latency: metric.NewRollingGauge(metric.RollingGaugeOpts{
|
||||
Size: 10,
|
||||
BucketDuration: time.Millisecond * 100,
|
||||
}),
|
||||
}
|
||||
for i := 1; i <= 100; i++ {
|
||||
sc.latency.Add(int64(i))
|
||||
}
|
||||
latency, count := sc.latencySummary()
|
||||
assert.Equal(t, 50.50, latency)
|
||||
assert.Equal(t, int64(100), count)
|
||||
}
|
334
pkg/net/rpc/warden/client.go
Normal file
334
pkg/net/rpc/warden/client.go
Normal file
@ -0,0 +1,334 @@
|
||||
package warden
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/bilibili/kratos/pkg/conf/env"
|
||||
"github.com/bilibili/kratos/pkg/conf/flagvar"
|
||||
"github.com/bilibili/kratos/pkg/ecode"
|
||||
"github.com/bilibili/kratos/pkg/naming"
|
||||
"github.com/bilibili/kratos/pkg/naming/discovery"
|
||||
nmd "github.com/bilibili/kratos/pkg/net/metadata"
|
||||
"github.com/bilibili/kratos/pkg/net/netutil/breaker"
|
||||
"github.com/bilibili/kratos/pkg/net/rpc/warden/balancer/p2c"
|
||||
"github.com/bilibili/kratos/pkg/net/rpc/warden/internal/status"
|
||||
"github.com/bilibili/kratos/pkg/net/rpc/warden/resolver"
|
||||
"github.com/bilibili/kratos/pkg/net/trace"
|
||||
xtime "github.com/bilibili/kratos/pkg/time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/peer"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
var _grpcTarget flagvar.StringVars
|
||||
|
||||
var (
|
||||
_once sync.Once
|
||||
_defaultCliConf = &ClientConfig{
|
||||
Dial: xtime.Duration(time.Second * 10),
|
||||
Timeout: xtime.Duration(time.Millisecond * 250),
|
||||
Subset: 50,
|
||||
}
|
||||
_defaultClient *Client
|
||||
)
|
||||
|
||||
func baseMetadata() metadata.MD {
|
||||
gmd := metadata.MD{nmd.Caller: []string{env.AppID}}
|
||||
if env.Color != "" {
|
||||
gmd[nmd.Color] = []string{env.Color}
|
||||
}
|
||||
return gmd
|
||||
}
|
||||
|
||||
// ClientConfig is rpc client conf.
|
||||
type ClientConfig struct {
|
||||
Dial xtime.Duration
|
||||
Timeout xtime.Duration
|
||||
Breaker *breaker.Config
|
||||
Method map[string]*ClientConfig
|
||||
Clusters []string
|
||||
Zone string
|
||||
Subset int
|
||||
NonBlock bool
|
||||
}
|
||||
|
||||
// Client is the framework's client side instance, it contains the ctx, opt and interceptors.
|
||||
// Create an instance of Client, by using NewClient().
|
||||
type Client struct {
|
||||
conf *ClientConfig
|
||||
breaker *breaker.Group
|
||||
mutex sync.RWMutex
|
||||
|
||||
opt []grpc.DialOption
|
||||
handlers []grpc.UnaryClientInterceptor
|
||||
}
|
||||
|
||||
// handle returns a new unary client interceptor for OpenTracing\Logging\LinkTimeout.
|
||||
func (c *Client) handle() grpc.UnaryClientInterceptor {
|
||||
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) (err error) {
|
||||
var (
|
||||
ok bool
|
||||
cmd nmd.MD
|
||||
t trace.Trace
|
||||
gmd metadata.MD
|
||||
conf *ClientConfig
|
||||
cancel context.CancelFunc
|
||||
addr string
|
||||
p peer.Peer
|
||||
)
|
||||
var ec ecode.Codes = ecode.OK
|
||||
// apm tracing
|
||||
if t, ok = trace.FromContext(ctx); ok {
|
||||
t = t.Fork("", method)
|
||||
defer t.Finish(&err)
|
||||
}
|
||||
|
||||
// setup metadata
|
||||
gmd = baseMetadata()
|
||||
trace.Inject(t, trace.GRPCFormat, gmd)
|
||||
c.mutex.RLock()
|
||||
if conf, ok = c.conf.Method[method]; !ok {
|
||||
conf = c.conf
|
||||
}
|
||||
c.mutex.RUnlock()
|
||||
brk := c.breaker.Get(method)
|
||||
if err = brk.Allow(); err != nil {
|
||||
statsClient.Incr(method, "breaker")
|
||||
return
|
||||
}
|
||||
defer onBreaker(brk, &err)
|
||||
_, ctx, cancel = conf.Timeout.Shrink(ctx)
|
||||
defer cancel()
|
||||
if cmd, ok = nmd.FromContext(ctx); ok {
|
||||
for netKey, val := range cmd {
|
||||
if !nmd.IsOutgoingKey(netKey) {
|
||||
continue
|
||||
}
|
||||
valstr, ok := val.(string)
|
||||
if ok {
|
||||
gmd[netKey] = []string{valstr}
|
||||
}
|
||||
}
|
||||
}
|
||||
// merge with old matadata if exists
|
||||
if oldmd, ok := metadata.FromOutgoingContext(ctx); ok {
|
||||
gmd = metadata.Join(gmd, oldmd)
|
||||
}
|
||||
ctx = metadata.NewOutgoingContext(ctx, gmd)
|
||||
|
||||
opts = append(opts, grpc.Peer(&p))
|
||||
if err = invoker(ctx, method, req, reply, cc, opts...); err != nil {
|
||||
gst, _ := gstatus.FromError(err)
|
||||
ec = status.ToEcode(gst)
|
||||
err = errors.WithMessage(ec, gst.Message())
|
||||
}
|
||||
if p.Addr != nil {
|
||||
addr = p.Addr.String()
|
||||
}
|
||||
if t != nil {
|
||||
t.SetTag(trace.String(trace.TagAddress, addr), trace.String(trace.TagComment, ""))
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func onBreaker(breaker breaker.Breaker, err *error) {
|
||||
if err != nil && *err != nil {
|
||||
if ecode.ServerErr.Equal(*err) || ecode.ServiceUnavailable.Equal(*err) || ecode.Deadline.Equal(*err) || ecode.LimitExceed.Equal(*err) {
|
||||
breaker.MarkFailed()
|
||||
return
|
||||
}
|
||||
}
|
||||
breaker.MarkSuccess()
|
||||
}
|
||||
|
||||
// NewConn will create a grpc conn by default config.
|
||||
func NewConn(target string, opt ...grpc.DialOption) (*grpc.ClientConn, error) {
|
||||
return DefaultClient().Dial(context.Background(), target, opt...)
|
||||
}
|
||||
|
||||
// NewClient returns a new blank Client instance with a default client interceptor.
|
||||
// opt can be used to add grpc dial options.
|
||||
func NewClient(conf *ClientConfig, opt ...grpc.DialOption) *Client {
|
||||
resolver.Register(discovery.Builder())
|
||||
c := new(Client)
|
||||
if err := c.SetConfig(conf); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
c.UseOpt(grpc.WithBalancerName(p2c.Name))
|
||||
c.UseOpt(opt...)
|
||||
c.Use(c.recovery(), clientLogging(), c.handle())
|
||||
return c
|
||||
}
|
||||
|
||||
// DefaultClient returns a new default Client instance with a default client interceptor and default dialoption.
|
||||
// opt can be used to add grpc dial options.
|
||||
func DefaultClient() *Client {
|
||||
resolver.Register(discovery.Builder())
|
||||
_once.Do(func() {
|
||||
_defaultClient = NewClient(nil)
|
||||
})
|
||||
return _defaultClient
|
||||
}
|
||||
|
||||
// SetConfig hot reloads client config
|
||||
func (c *Client) SetConfig(conf *ClientConfig) (err error) {
|
||||
if conf == nil {
|
||||
conf = _defaultCliConf
|
||||
}
|
||||
if conf.Dial <= 0 {
|
||||
conf.Dial = xtime.Duration(time.Second * 10)
|
||||
}
|
||||
if conf.Timeout <= 0 {
|
||||
conf.Timeout = xtime.Duration(time.Millisecond * 250)
|
||||
}
|
||||
if conf.Subset <= 0 {
|
||||
conf.Subset = 50
|
||||
}
|
||||
|
||||
// FIXME(maojian) check Method dial/timeout
|
||||
c.mutex.Lock()
|
||||
c.conf = conf
|
||||
if c.breaker == nil {
|
||||
c.breaker = breaker.NewGroup(conf.Breaker)
|
||||
} else {
|
||||
c.breaker.Reload(conf.Breaker)
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Use attachs a global inteceptor to the Client.
|
||||
// For example, this is the right place for a circuit breaker or error management inteceptor.
|
||||
func (c *Client) Use(handlers ...grpc.UnaryClientInterceptor) *Client {
|
||||
finalSize := len(c.handlers) + len(handlers)
|
||||
if finalSize >= int(_abortIndex) {
|
||||
panic("warden: client use too many handlers")
|
||||
}
|
||||
mergedHandlers := make([]grpc.UnaryClientInterceptor, finalSize)
|
||||
copy(mergedHandlers, c.handlers)
|
||||
copy(mergedHandlers[len(c.handlers):], handlers)
|
||||
c.handlers = mergedHandlers
|
||||
return c
|
||||
}
|
||||
|
||||
// UseOpt attachs a global grpc DialOption to the Client.
|
||||
func (c *Client) UseOpt(opt ...grpc.DialOption) *Client {
|
||||
c.opt = append(c.opt, opt...)
|
||||
return c
|
||||
}
|
||||
|
||||
// Dial creates a client connection to the given target.
|
||||
// Target format is scheme://authority/endpoint?query_arg=value
|
||||
// example: discovery://default/account.account.service?cluster=shfy01&cluster=shfy02
|
||||
func (c *Client) Dial(ctx context.Context, target string, opt ...grpc.DialOption) (conn *grpc.ClientConn, err error) {
|
||||
if !c.conf.NonBlock {
|
||||
c.opt = append(c.opt, grpc.WithBlock())
|
||||
}
|
||||
c.opt = append(c.opt, grpc.WithInsecure())
|
||||
c.opt = append(c.opt, grpc.WithUnaryInterceptor(c.chainUnaryClient()))
|
||||
c.opt = append(c.opt, opt...)
|
||||
c.mutex.RLock()
|
||||
conf := c.conf
|
||||
c.mutex.RUnlock()
|
||||
if conf.Dial > 0 {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, time.Duration(conf.Dial))
|
||||
defer cancel()
|
||||
}
|
||||
if u, e := url.Parse(target); e == nil {
|
||||
v := u.Query()
|
||||
for _, c := range c.conf.Clusters {
|
||||
v.Add(naming.MetaCluster, c)
|
||||
}
|
||||
if c.conf.Zone != "" {
|
||||
v.Add(naming.MetaZone, c.conf.Zone)
|
||||
}
|
||||
if v.Get("subset") == "" && c.conf.Subset > 0 {
|
||||
v.Add("subset", strconv.FormatInt(int64(c.conf.Subset), 10))
|
||||
}
|
||||
u.RawQuery = v.Encode()
|
||||
// 比较_grpcTarget中的appid是否等于u.path中的appid,并替换成mock的地址
|
||||
for _, t := range _grpcTarget {
|
||||
strs := strings.SplitN(t, "=", 2)
|
||||
if len(strs) == 2 && ("/"+strs[0]) == u.Path {
|
||||
u.Path = "/" + strs[1]
|
||||
u.Scheme = "passthrough"
|
||||
u.RawQuery = ""
|
||||
break
|
||||
}
|
||||
}
|
||||
target = u.String()
|
||||
}
|
||||
if conn, err = grpc.DialContext(ctx, target, c.opt...); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "warden client: dial %s error %v!", target, err)
|
||||
}
|
||||
err = errors.WithStack(err)
|
||||
return
|
||||
}
|
||||
|
||||
// DialTLS creates a client connection over tls transport to the given target.
|
||||
func (c *Client) DialTLS(ctx context.Context, target string, file string, name string) (conn *grpc.ClientConn, err error) {
|
||||
var creds credentials.TransportCredentials
|
||||
creds, err = credentials.NewClientTLSFromFile(file, name)
|
||||
if err != nil {
|
||||
err = errors.WithStack(err)
|
||||
return
|
||||
}
|
||||
c.opt = append(c.opt, grpc.WithBlock())
|
||||
c.opt = append(c.opt, grpc.WithTransportCredentials(creds))
|
||||
c.opt = append(c.opt, grpc.WithUnaryInterceptor(c.chainUnaryClient()))
|
||||
c.mutex.RLock()
|
||||
conf := c.conf
|
||||
c.mutex.RUnlock()
|
||||
if conf.Dial > 0 {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, time.Duration(conf.Dial))
|
||||
defer cancel()
|
||||
}
|
||||
conn, err = grpc.DialContext(ctx, target, c.opt...)
|
||||
err = errors.WithStack(err)
|
||||
return
|
||||
}
|
||||
|
||||
// chainUnaryClient creates a single interceptor out of a chain of many interceptors.
|
||||
//
|
||||
// Execution is done in left-to-right order, including passing of context.
|
||||
// For example ChainUnaryClient(one, two, three) will execute one before two before three.
|
||||
func (c *Client) chainUnaryClient() grpc.UnaryClientInterceptor {
|
||||
n := len(c.handlers)
|
||||
if n == 0 {
|
||||
return func(ctx context.Context, method string, req, reply interface{},
|
||||
cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
|
||||
return invoker(ctx, method, req, reply, cc, opts...)
|
||||
}
|
||||
}
|
||||
|
||||
return func(ctx context.Context, method string, req, reply interface{},
|
||||
cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
|
||||
var (
|
||||
i int
|
||||
chainHandler grpc.UnaryInvoker
|
||||
)
|
||||
chainHandler = func(ictx context.Context, imethod string, ireq, ireply interface{}, ic *grpc.ClientConn, iopts ...grpc.CallOption) error {
|
||||
if i == n-1 {
|
||||
return invoker(ictx, imethod, ireq, ireply, ic, iopts...)
|
||||
}
|
||||
i++
|
||||
return c.handlers[i](ictx, imethod, ireq, ireply, ic, chainHandler, iopts...)
|
||||
}
|
||||
|
||||
return c.handlers[0](ctx, method, req, reply, cc, chainHandler, opts...)
|
||||
}
|
||||
}
|
91
pkg/net/rpc/warden/exapmle_test.go
Normal file
91
pkg/net/rpc/warden/exapmle_test.go
Normal file
@ -0,0 +1,91 @@
|
||||
package warden_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/bilibili/kratos/pkg/log"
|
||||
"github.com/bilibili/kratos/pkg/net/netutil/breaker"
|
||||
"github.com/bilibili/kratos/pkg/net/rpc/warden"
|
||||
pb "github.com/bilibili/kratos/pkg/net/rpc/warden/internal/proto/testproto"
|
||||
xtime "github.com/bilibili/kratos/pkg/time"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
type helloServer struct {
|
||||
}
|
||||
|
||||
func (s *helloServer) SayHello(ctx context.Context, in *pb.HelloRequest) (*pb.HelloReply, error) {
|
||||
return &pb.HelloReply{Message: "Hello " + in.Name, Success: true}, nil
|
||||
}
|
||||
|
||||
func (s *helloServer) StreamHello(ss pb.Greeter_StreamHelloServer) error {
|
||||
for i := 0; i < 3; i++ {
|
||||
in, err := ss.Recv()
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ret := &pb.HelloReply{Message: "Hello " + in.Name, Success: true}
|
||||
err = ss.Send(ret)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
func ExampleServer() {
|
||||
s := warden.NewServer(&warden.ServerConfig{Timeout: xtime.Duration(time.Second), Addr: ":8080"})
|
||||
// apply server interceptor middleware
|
||||
s.Use(func(ctx context.Context, req interface{}, args *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
||||
newctx, cancel := context.WithTimeout(ctx, time.Second*10)
|
||||
defer cancel()
|
||||
resp, err := handler(newctx, req)
|
||||
return resp, err
|
||||
})
|
||||
pb.RegisterGreeterServer(s.Server(), &helloServer{})
|
||||
s.Start()
|
||||
}
|
||||
|
||||
func ExampleClient() {
|
||||
client := warden.NewClient(&warden.ClientConfig{
|
||||
Dial: xtime.Duration(time.Second * 10),
|
||||
Timeout: xtime.Duration(time.Second * 10),
|
||||
Breaker: &breaker.Config{
|
||||
Window: xtime.Duration(3 * time.Second),
|
||||
Sleep: xtime.Duration(3 * time.Second),
|
||||
Bucket: 10,
|
||||
Ratio: 0.3,
|
||||
Request: 20,
|
||||
},
|
||||
})
|
||||
// apply client interceptor middleware
|
||||
client.Use(func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) (ret error) {
|
||||
newctx, cancel := context.WithTimeout(ctx, time.Second*5)
|
||||
defer cancel()
|
||||
ret = invoker(newctx, method, req, reply, cc, opts...)
|
||||
return
|
||||
})
|
||||
conn, err := client.Dial(context.Background(), "127.0.0.1:8080")
|
||||
if err != nil {
|
||||
log.Error("did not connect: %v", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
c := pb.NewGreeterClient(conn)
|
||||
name := "2233"
|
||||
rp, err := c.SayHello(context.Background(), &pb.HelloRequest{Name: name, Age: 18})
|
||||
if err != nil {
|
||||
log.Error("could not greet: %v", err)
|
||||
return
|
||||
}
|
||||
fmt.Println("rp", *rp)
|
||||
}
|
189
pkg/net/rpc/warden/internal/benchmark/bench/client/client.go
Normal file
189
pkg/net/rpc/warden/internal/benchmark/bench/client/client.go
Normal file
@ -0,0 +1,189 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"log"
|
||||
"reflect"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/bilibili/kratos/pkg/net/netutil/breaker"
|
||||
"github.com/bilibili/kratos/pkg/net/rpc/warden"
|
||||
"github.com/bilibili/kratos/pkg/net/rpc/warden/internal/benchmark/bench/proto"
|
||||
xtime "github.com/bilibili/kratos/pkg/time"
|
||||
|
||||
goproto "github.com/gogo/protobuf/proto"
|
||||
"github.com/montanaflynn/stats"
|
||||
"golang.org/x/net/context"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
const (
|
||||
iws = 65535 * 1000
|
||||
iwsc = 65535 * 10000
|
||||
readBuffer = 32 * 1024
|
||||
writeBuffer = 32 * 1024
|
||||
)
|
||||
|
||||
var concurrency = flag.Int("c", 50, "concurrency")
|
||||
var total = flag.Int("t", 500000, "total requests for all clients")
|
||||
var host = flag.String("s", "127.0.0.1:8972", "server ip and port")
|
||||
var isWarden = flag.Bool("w", true, "is warden or grpc client")
|
||||
var strLen = flag.Int("l", 600, "the length of the str")
|
||||
|
||||
func wardenCli() proto.HelloClient {
|
||||
log.Println("start warden cli")
|
||||
client := warden.NewClient(&warden.ClientConfig{
|
||||
Dial: xtime.Duration(time.Second * 10),
|
||||
Timeout: xtime.Duration(time.Second * 10),
|
||||
Breaker: &breaker.Config{
|
||||
Window: xtime.Duration(3 * time.Second),
|
||||
Sleep: xtime.Duration(3 * time.Second),
|
||||
Bucket: 10,
|
||||
Ratio: 0.3,
|
||||
Request: 20,
|
||||
},
|
||||
},
|
||||
grpc.WithInitialWindowSize(iws),
|
||||
grpc.WithInitialConnWindowSize(iwsc),
|
||||
grpc.WithReadBufferSize(readBuffer),
|
||||
grpc.WithWriteBufferSize(writeBuffer))
|
||||
conn, err := client.Dial(context.Background(), *host)
|
||||
if err != nil {
|
||||
log.Fatalf("did not connect: %v", err)
|
||||
}
|
||||
cli := proto.NewHelloClient(conn)
|
||||
return cli
|
||||
}
|
||||
|
||||
func grpcCli() proto.HelloClient {
|
||||
log.Println("start grpc cli")
|
||||
conn, err := grpc.Dial(*host, grpc.WithInsecure(),
|
||||
grpc.WithInitialWindowSize(iws),
|
||||
grpc.WithInitialConnWindowSize(iwsc),
|
||||
grpc.WithReadBufferSize(readBuffer),
|
||||
grpc.WithWriteBufferSize(writeBuffer))
|
||||
if err != nil {
|
||||
log.Fatalf("did not connect: %v", err)
|
||||
}
|
||||
cli := proto.NewHelloClient(conn)
|
||||
return cli
|
||||
}
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
c := *concurrency
|
||||
m := *total / c
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(c)
|
||||
log.Printf("concurrency: %d\nrequests per client: %d\n\n", c, m)
|
||||
|
||||
args := prepareArgs()
|
||||
b, _ := goproto.Marshal(args)
|
||||
log.Printf("message size: %d bytes\n\n", len(b))
|
||||
|
||||
var trans uint64
|
||||
var transOK uint64
|
||||
d := make([][]int64, c)
|
||||
for i := 0; i < c; i++ {
|
||||
dt := make([]int64, 0, m)
|
||||
d = append(d, dt)
|
||||
}
|
||||
var cli proto.HelloClient
|
||||
if *isWarden {
|
||||
cli = wardenCli()
|
||||
} else {
|
||||
cli = grpcCli()
|
||||
}
|
||||
//warmup
|
||||
cli.Say(context.Background(), args)
|
||||
|
||||
totalT := time.Now().UnixNano()
|
||||
for i := 0; i < c; i++ {
|
||||
go func(i int) {
|
||||
for j := 0; j < m; j++ {
|
||||
t := time.Now().UnixNano()
|
||||
reply, err := cli.Say(context.Background(), args)
|
||||
t = time.Now().UnixNano() - t
|
||||
d[i] = append(d[i], t)
|
||||
if err == nil && reply.Field1 == "OK" {
|
||||
atomic.AddUint64(&transOK, 1)
|
||||
}
|
||||
atomic.AddUint64(&trans, 1)
|
||||
}
|
||||
wg.Done()
|
||||
}(i)
|
||||
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
totalT = time.Now().UnixNano() - totalT
|
||||
totalT = totalT / 1e6
|
||||
log.Printf("took %d ms for %d requests\n", totalT, *total)
|
||||
totalD := make([]int64, 0, *total)
|
||||
for _, k := range d {
|
||||
totalD = append(totalD, k...)
|
||||
}
|
||||
totalD2 := make([]float64, 0, *total)
|
||||
for _, k := range totalD {
|
||||
totalD2 = append(totalD2, float64(k))
|
||||
}
|
||||
|
||||
mean, _ := stats.Mean(totalD2)
|
||||
median, _ := stats.Median(totalD2)
|
||||
max, _ := stats.Max(totalD2)
|
||||
min, _ := stats.Min(totalD2)
|
||||
tp99, _ := stats.Percentile(totalD2, 99)
|
||||
tp999, _ := stats.Percentile(totalD2, 99.9)
|
||||
|
||||
log.Printf("sent requests : %d\n", *total)
|
||||
log.Printf("received requests_OK : %d\n", atomic.LoadUint64(&transOK))
|
||||
log.Printf("throughput (TPS) : %d\n", int64(c*m)*1000/totalT)
|
||||
log.Printf("mean: %v ms, median: %v ms, max: %v ms, min: %v ms, p99: %v ms, p999:%v ms\n", mean/1e6, median/1e6, max/1e6, min/1e6, tp99/1e6, tp999/1e6)
|
||||
|
||||
}
|
||||
|
||||
func prepareArgs() *proto.BenchmarkMessage {
|
||||
b := true
|
||||
var i int32 = 120000
|
||||
var i64 int64 = 98765432101234
|
||||
var s = "许多往事在眼前一幕一幕,变的那麼模糊"
|
||||
repeat := *strLen / (8 * 54)
|
||||
if repeat == 0 {
|
||||
repeat = 1
|
||||
}
|
||||
var str string
|
||||
for i := 0; i < repeat; i++ {
|
||||
str += s
|
||||
}
|
||||
var args proto.BenchmarkMessage
|
||||
|
||||
v := reflect.ValueOf(&args).Elem()
|
||||
num := v.NumField()
|
||||
for k := 0; k < num; k++ {
|
||||
field := v.Field(k)
|
||||
if field.Type().Kind() == reflect.Ptr {
|
||||
switch v.Field(k).Type().Elem().Kind() {
|
||||
case reflect.Int, reflect.Int32:
|
||||
field.Set(reflect.ValueOf(&i))
|
||||
case reflect.Int64:
|
||||
field.Set(reflect.ValueOf(&i64))
|
||||
case reflect.Bool:
|
||||
field.Set(reflect.ValueOf(&b))
|
||||
case reflect.String:
|
||||
field.Set(reflect.ValueOf(&str))
|
||||
}
|
||||
} else {
|
||||
switch field.Kind() {
|
||||
case reflect.Int, reflect.Int32, reflect.Int64:
|
||||
field.SetInt(9876543)
|
||||
case reflect.Bool:
|
||||
field.SetBool(true)
|
||||
case reflect.String:
|
||||
field.SetString(str)
|
||||
}
|
||||
}
|
||||
}
|
||||
return &args
|
||||
}
|
1686
pkg/net/rpc/warden/internal/benchmark/bench/proto/hello.pb.go
Normal file
1686
pkg/net/rpc/warden/internal/benchmark/bench/proto/hello.pb.go
Normal file
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,60 @@
|
||||
syntax = "proto3";
|
||||
package proto;
|
||||
|
||||
import "github.com/gogo/protobuf/gogoproto/gogo.proto";
|
||||
|
||||
option optimize_for = SPEED;
|
||||
option (gogoproto.goproto_enum_prefix_all) = false;
|
||||
option (gogoproto.goproto_getters_all) = false;
|
||||
option (gogoproto.unmarshaler_all) = true;
|
||||
option (gogoproto.marshaler_all) = true;
|
||||
option (gogoproto.sizer_all) = true;
|
||||
|
||||
service Hello {
|
||||
// Sends a greeting
|
||||
rpc Say (BenchmarkMessage) returns (BenchmarkMessage) {}
|
||||
}
|
||||
|
||||
|
||||
message BenchmarkMessage {
|
||||
string field1 = 1;
|
||||
string field9 = 9;
|
||||
string field18 = 18;
|
||||
bool field80 = 80;
|
||||
bool field81 = 81;
|
||||
int32 field2 = 2;
|
||||
int32 field3 = 3;
|
||||
int32 field280 = 280;
|
||||
int32 field6 = 6;
|
||||
int64 field22 = 22;
|
||||
string field4 = 4;
|
||||
fixed64 field5 = 5;
|
||||
bool field59 = 59;
|
||||
string field7 = 7;
|
||||
int32 field16 = 16;
|
||||
int32 field130 = 130;
|
||||
bool field12 = 12;
|
||||
bool field17 = 17;
|
||||
bool field13 = 13;
|
||||
bool field14 = 14;
|
||||
int32 field104 = 104;
|
||||
int32 field100 = 100;
|
||||
int32 field101 = 101;
|
||||
string field102 = 102;
|
||||
string field103 = 103;
|
||||
int32 field29 = 29;
|
||||
bool field30 = 30;
|
||||
int32 field60 = 60;
|
||||
int32 field271 = 271;
|
||||
int32 field272 = 272;
|
||||
int32 field150 = 150;
|
||||
int32 field23 = 23;
|
||||
bool field24 = 24 ;
|
||||
int32 field25 = 25 ;
|
||||
bool field78 = 78;
|
||||
int32 field67 = 67;
|
||||
int32 field68 = 68;
|
||||
int32 field128 = 128;
|
||||
string field129 = 129;
|
||||
int32 field131 = 131;
|
||||
}
|
103
pkg/net/rpc/warden/internal/benchmark/bench/server/server.go
Normal file
103
pkg/net/rpc/warden/internal/benchmark/bench/server/server.go
Normal file
@ -0,0 +1,103 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/bilibili/kratos/pkg/net/rpc/warden"
|
||||
"github.com/bilibili/kratos/pkg/net/rpc/warden/internal/benchmark/bench/proto"
|
||||
xtime "github.com/bilibili/kratos/pkg/time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
const (
|
||||
iws = 65535 * 1000
|
||||
iwsc = 65535 * 10000
|
||||
readBuffer = 32 * 1024
|
||||
writeBuffer = 32 * 1024
|
||||
)
|
||||
|
||||
var reqNum uint64
|
||||
|
||||
type Hello struct{}
|
||||
|
||||
func (t *Hello) Say(ctx context.Context, args *proto.BenchmarkMessage) (reply *proto.BenchmarkMessage, err error) {
|
||||
s := "OK"
|
||||
var i int32 = 100
|
||||
args.Field1 = s
|
||||
args.Field2 = i
|
||||
atomic.AddUint64(&reqNum, 1)
|
||||
return args, nil
|
||||
}
|
||||
|
||||
var host = flag.String("s", "0.0.0.0:8972", "listened ip and port")
|
||||
var isWarden = flag.Bool("w", true, "is warden or grpc client")
|
||||
|
||||
func main() {
|
||||
go func() {
|
||||
log.Println("run http at :6060")
|
||||
http.HandleFunc("/metrics", func(w http.ResponseWriter, r *http.Request) {
|
||||
h := promhttp.Handler()
|
||||
h.ServeHTTP(w, r)
|
||||
})
|
||||
log.Println(http.ListenAndServe("0.0.0.0:6060", nil))
|
||||
}()
|
||||
|
||||
flag.Parse()
|
||||
|
||||
go stat()
|
||||
if *isWarden {
|
||||
runWarden()
|
||||
} else {
|
||||
runGrpc()
|
||||
}
|
||||
}
|
||||
|
||||
func runGrpc() {
|
||||
log.Println("run grpc")
|
||||
lis, err := net.Listen("tcp", *host)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to listen: %v", err)
|
||||
}
|
||||
s := grpc.NewServer(grpc.InitialWindowSize(iws),
|
||||
grpc.InitialConnWindowSize(iwsc),
|
||||
grpc.ReadBufferSize(readBuffer),
|
||||
grpc.WriteBufferSize(writeBuffer))
|
||||
proto.RegisterHelloServer(s, &Hello{})
|
||||
s.Serve(lis)
|
||||
}
|
||||
|
||||
func runWarden() {
|
||||
log.Println("run warden")
|
||||
s := warden.NewServer(&warden.ServerConfig{Timeout: xtime.Duration(time.Second * 3)},
|
||||
grpc.InitialWindowSize(iws),
|
||||
grpc.InitialConnWindowSize(iwsc),
|
||||
grpc.ReadBufferSize(readBuffer),
|
||||
grpc.WriteBufferSize(writeBuffer))
|
||||
proto.RegisterHelloServer(s.Server(), &Hello{})
|
||||
s.Run(*host)
|
||||
}
|
||||
|
||||
func stat() {
|
||||
ticker := time.NewTicker(time.Second * 5)
|
||||
defer ticker.Stop()
|
||||
var last uint64
|
||||
lastTs := uint64(time.Now().UnixNano())
|
||||
for {
|
||||
<-ticker.C
|
||||
now := atomic.LoadUint64(&reqNum)
|
||||
nowTs := uint64(time.Now().UnixNano())
|
||||
qps := (now - last) * 1e6 / ((nowTs - lastTs) / 1e3)
|
||||
last = now
|
||||
lastTs = nowTs
|
||||
log.Println("qps:", qps)
|
||||
}
|
||||
}
|
15
pkg/net/rpc/warden/internal/benchmark/helloworld/client.sh
Executable file
15
pkg/net/rpc/warden/internal/benchmark/helloworld/client.sh
Executable file
@ -0,0 +1,15 @@
|
||||
#!/bin/bash
|
||||
go build -o client greeter_client.go
|
||||
echo size 100 concurrent 30
|
||||
./client -s 100 -c 30
|
||||
echo size 1000 concurrent 30
|
||||
./client -s 1000 -c 30
|
||||
echo size 10000 concurrent 30
|
||||
./client -s 10000 -c 30
|
||||
echo size 100 concurrent 300
|
||||
./client -s 100 -c 300
|
||||
echo size 1000 concurrent 300
|
||||
./client -s 1000 -c 300
|
||||
echo size 10000 concurrent 300
|
||||
./client -s 10000 -c 300
|
||||
rm client
|
@ -0,0 +1,85 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/bilibili/kratos/pkg/net/netutil/breaker"
|
||||
"github.com/bilibili/kratos/pkg/net/rpc/warden"
|
||||
pb "github.com/bilibili/kratos/pkg/net/rpc/warden/internal/proto/testproto"
|
||||
xtime "github.com/bilibili/kratos/pkg/time"
|
||||
)
|
||||
|
||||
var (
|
||||
ccf = &warden.ClientConfig{
|
||||
Dial: xtime.Duration(time.Second * 10),
|
||||
Timeout: xtime.Duration(time.Second * 10),
|
||||
Breaker: &breaker.Config{
|
||||
Window: xtime.Duration(3 * time.Second),
|
||||
Sleep: xtime.Duration(3 * time.Second),
|
||||
Bucket: 10,
|
||||
Ratio: 0.3,
|
||||
Request: 20,
|
||||
},
|
||||
}
|
||||
cli pb.GreeterClient
|
||||
wg sync.WaitGroup
|
||||
reqSize int
|
||||
concurrency int
|
||||
request int
|
||||
all int64
|
||||
)
|
||||
|
||||
func init() {
|
||||
flag.IntVar(&reqSize, "s", 10, "request size")
|
||||
flag.IntVar(&concurrency, "c", 10, "concurrency")
|
||||
flag.IntVar(&request, "r", 1000, "request per routine")
|
||||
}
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
name := randSeq(reqSize)
|
||||
cli = newClient()
|
||||
for i := 0; i < concurrency; i++ {
|
||||
wg.Add(1)
|
||||
go sayHello(&pb.HelloRequest{Name: name})
|
||||
}
|
||||
wg.Wait()
|
||||
fmt.Printf("per request cost %v\n", all/int64(request*concurrency))
|
||||
|
||||
}
|
||||
|
||||
func sayHello(in *pb.HelloRequest) {
|
||||
defer wg.Done()
|
||||
now := time.Now()
|
||||
for i := 0; i < request; i++ {
|
||||
cli.SayHello(context.TODO(), in)
|
||||
}
|
||||
delta := time.Since(now)
|
||||
atomic.AddInt64(&all, int64(delta/time.Millisecond))
|
||||
}
|
||||
|
||||
var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
|
||||
|
||||
func randSeq(n int) string {
|
||||
b := make([]rune, n)
|
||||
for i := range b {
|
||||
b[i] = letters[rand.Intn(len(letters))]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func newClient() (cli pb.GreeterClient) {
|
||||
client := warden.NewClient(ccf)
|
||||
conn, err := client.Dial(context.TODO(), "127.0.0.1:9999")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
cli = pb.NewGreeterClient(conn)
|
||||
return
|
||||
}
|
@ -0,0 +1,50 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/bilibili/kratos/pkg/net/rpc/warden"
|
||||
pb "github.com/bilibili/kratos/pkg/net/rpc/warden/internal/proto/testproto"
|
||||
xtime "github.com/bilibili/kratos/pkg/time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
)
|
||||
|
||||
var (
|
||||
config = &warden.ServerConfig{Timeout: xtime.Duration(time.Second)}
|
||||
)
|
||||
|
||||
func main() {
|
||||
newServer()
|
||||
}
|
||||
|
||||
type hello struct {
|
||||
}
|
||||
|
||||
func (s *hello) SayHello(c context.Context, in *pb.HelloRequest) (out *pb.HelloReply, err error) {
|
||||
out = new(pb.HelloReply)
|
||||
out.Message = in.Name
|
||||
return
|
||||
}
|
||||
|
||||
func (s *hello) StreamHello(ss pb.Greeter_StreamHelloServer) error {
|
||||
return nil
|
||||
}
|
||||
func newServer() {
|
||||
server := warden.NewServer(config)
|
||||
pb.RegisterGreeterServer(server.Server(), &hello{})
|
||||
go func() {
|
||||
http.HandleFunc("/metrics", func(w http.ResponseWriter, r *http.Request) {
|
||||
h := promhttp.Handler()
|
||||
h.ServeHTTP(w, r)
|
||||
})
|
||||
http.ListenAndServe("0.0.0.0:9998", nil)
|
||||
}()
|
||||
err := server.Run(":9999")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
}
|
53
pkg/net/rpc/warden/internal/encoding/json/json.go
Normal file
53
pkg/net/rpc/warden/internal/encoding/json/json.go
Normal file
@ -0,0 +1,53 @@
|
||||
package codec
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/gogo/protobuf/jsonpb"
|
||||
"github.com/gogo/protobuf/proto"
|
||||
"google.golang.org/grpc/encoding"
|
||||
)
|
||||
|
||||
//Reference https://jbrandhorst.com/post/grpc-json/
|
||||
func init() {
|
||||
encoding.RegisterCodec(JSON{
|
||||
Marshaler: jsonpb.Marshaler{
|
||||
EmitDefaults: true,
|
||||
OrigName: true,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// JSON is impl of encoding.Codec
|
||||
type JSON struct {
|
||||
jsonpb.Marshaler
|
||||
jsonpb.Unmarshaler
|
||||
}
|
||||
|
||||
// Name is name of JSON
|
||||
func (j JSON) Name() string {
|
||||
return "json"
|
||||
}
|
||||
|
||||
// Marshal is json marshal
|
||||
func (j JSON) Marshal(v interface{}) (out []byte, err error) {
|
||||
if pm, ok := v.(proto.Message); ok {
|
||||
b := new(bytes.Buffer)
|
||||
err := j.Marshaler.Marshal(b, pm)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b.Bytes(), nil
|
||||
}
|
||||
return json.Marshal(v)
|
||||
}
|
||||
|
||||
// Unmarshal is json unmarshal
|
||||
func (j JSON) Unmarshal(data []byte, v interface{}) (err error) {
|
||||
if pm, ok := v.(proto.Message); ok {
|
||||
b := bytes.NewBuffer(data)
|
||||
return j.Unmarshaler.Unmarshal(b, pm)
|
||||
}
|
||||
return json.Unmarshal(data, v)
|
||||
}
|
31
pkg/net/rpc/warden/internal/examples/client/client.go
Normal file
31
pkg/net/rpc/warden/internal/examples/client/client.go
Normal file
@ -0,0 +1,31 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
|
||||
"github.com/bilibili/kratos/pkg/log"
|
||||
"github.com/bilibili/kratos/pkg/net/rpc/warden"
|
||||
pb "github.com/bilibili/kratos/pkg/net/rpc/warden/internal/proto/testproto"
|
||||
)
|
||||
|
||||
// usage: ./client -grpc.target=test.service=127.0.0.1:8080
|
||||
func main() {
|
||||
log.Init(&log.Config{Stdout: true})
|
||||
flag.Parse()
|
||||
conn, err := warden.NewClient(nil).Dial(context.Background(), "127.0.0.1:8081")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
cli := pb.NewGreeterClient(conn)
|
||||
normalCall(cli)
|
||||
}
|
||||
|
||||
func normalCall(cli pb.GreeterClient) {
|
||||
reply, err := cli.SayHello(context.Background(), &pb.HelloRequest{Name: "tom", Age: 23})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
fmt.Println("get reply:", *reply)
|
||||
}
|
191
pkg/net/rpc/warden/internal/examples/grpcDebug/client.go
Normal file
191
pkg/net/rpc/warden/internal/examples/grpcDebug/client.go
Normal file
@ -0,0 +1,191 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/gogo/protobuf/jsonpb"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/encoding"
|
||||
)
|
||||
|
||||
// Reply for test
|
||||
type Reply struct {
|
||||
res []byte
|
||||
}
|
||||
|
||||
type Discovery struct {
|
||||
HttpClient *http.Client
|
||||
Nodes []string
|
||||
}
|
||||
|
||||
var (
|
||||
data string
|
||||
file string
|
||||
method string
|
||||
addr string
|
||||
tlsCert string
|
||||
tlsServerName string
|
||||
appID string
|
||||
env string
|
||||
)
|
||||
|
||||
//Reference https://jbrandhorst.com/post/grpc-json/
|
||||
func init() {
|
||||
encoding.RegisterCodec(JSON{
|
||||
Marshaler: jsonpb.Marshaler{
|
||||
EmitDefaults: true,
|
||||
OrigName: true,
|
||||
},
|
||||
})
|
||||
flag.StringVar(&data, "data", `{"name":"longxia","age":19}`, `{"name":"longxia","age":19}`)
|
||||
flag.StringVar(&file, "file", ``, `./data.json`)
|
||||
flag.StringVar(&method, "method", "/testproto.Greeter/SayHello", `/testproto.Greeter/SayHello`)
|
||||
flag.StringVar(&addr, "addr", "127.0.0.1:8080", `127.0.0.1:8080`)
|
||||
flag.StringVar(&tlsCert, "cert", "", `./cert.pem`)
|
||||
flag.StringVar(&tlsServerName, "server_name", "", `hello_server`)
|
||||
flag.StringVar(&appID, "appid", "", `appid`)
|
||||
flag.StringVar(&env, "env", "", `env`)
|
||||
}
|
||||
|
||||
// 该example因为使用的是json传输格式所以只能用于调试或测试,用于线上会导致性能下降
|
||||
// 使用方法:
|
||||
// ./grpcDebug -data='{"name":"xia","age":19}' -addr=127.0.0.1:8080 -method=/testproto.Greeter/SayHello
|
||||
// ./grpcDebug -file=data.json -addr=127.0.0.1:8080 -method=/testproto.Greeter/SayHello
|
||||
// DEPLOY_ENV=uat ./grpcDebug -appid=main.community.reply-service -method=/reply.service.v1.Reply/ReplyInfoCache -data='{"rp_id"=1493769244}'
|
||||
func main() {
|
||||
flag.Parse()
|
||||
opts := []grpc.DialOption{
|
||||
grpc.WithInsecure(),
|
||||
grpc.WithDefaultCallOptions(grpc.CallContentSubtype(JSON{}.Name())),
|
||||
}
|
||||
if tlsCert != "" {
|
||||
creds, err := credentials.NewClientTLSFromFile(tlsCert, tlsServerName)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
opts = append(opts, grpc.WithTransportCredentials(creds))
|
||||
}
|
||||
if file != "" {
|
||||
content, err := ioutil.ReadFile(file)
|
||||
if err != nil {
|
||||
fmt.Println("ioutil.ReadFile %s failed!err:=%v", file, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
if len(content) > 0 {
|
||||
data = string(content)
|
||||
}
|
||||
}
|
||||
if appID != "" {
|
||||
addr = ipFromDiscovery(appID, env)
|
||||
}
|
||||
conn, err := grpc.Dial(addr, opts...)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
var reply Reply
|
||||
err = grpc.Invoke(context.Background(), method, []byte(data), &reply, conn)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
fmt.Println(string(reply.res))
|
||||
}
|
||||
|
||||
func ipFromDiscovery(appID, env string) string {
|
||||
d := &Discovery{
|
||||
Nodes: []string{"discovery.bilibili.co", "api.bilibili.co"},
|
||||
HttpClient: http.DefaultClient,
|
||||
}
|
||||
deployEnv := os.Getenv("DEPLOY_ENV")
|
||||
if deployEnv != "" {
|
||||
env = deployEnv
|
||||
}
|
||||
return d.addr(appID, env, d.nodes())
|
||||
}
|
||||
|
||||
func (d *Discovery) nodes() (addrs []string) {
|
||||
res := new(struct {
|
||||
Code int `json:"code"`
|
||||
Data []struct {
|
||||
Addr string `json:"addr"`
|
||||
} `json:"data"`
|
||||
})
|
||||
resp, err := d.HttpClient.Get(fmt.Sprintf("http://%s/discovery/nodes", d.Nodes[rand.Intn(len(d.Nodes))]))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if err = json.NewDecoder(resp.Body).Decode(&res); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
for _, data := range res.Data {
|
||||
addrs = append(addrs, data.Addr)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (d *Discovery) addr(appID, env string, nodes []string) (ip string) {
|
||||
res := new(struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data map[string]*struct {
|
||||
ZoneInstances map[string][]*struct {
|
||||
AppID string `json:"appid"`
|
||||
Addrs []string `json:"addrs"`
|
||||
} `json:"zone_instances"`
|
||||
} `json:"data"`
|
||||
})
|
||||
host, _ := os.Hostname()
|
||||
resp, err := d.HttpClient.Get(fmt.Sprintf("http://%s/discovery/polls?appid=%s&env=%s&hostname=%s", nodes[rand.Intn(len(nodes))], appID, env, host))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if err = json.NewDecoder(resp.Body).Decode(&res); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
for _, data := range res.Data {
|
||||
for _, zoneInstance := range data.ZoneInstances {
|
||||
for _, instance := range zoneInstance {
|
||||
if instance.AppID == appID {
|
||||
for _, addr := range instance.Addrs {
|
||||
if strings.Contains(addr, "grpc://") {
|
||||
return strings.Replace(addr, "grpc://", "", -1)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// JSON is impl of encoding.Codec
|
||||
type JSON struct {
|
||||
jsonpb.Marshaler
|
||||
jsonpb.Unmarshaler
|
||||
}
|
||||
|
||||
// Name is name of JSON
|
||||
func (j JSON) Name() string {
|
||||
return "json"
|
||||
}
|
||||
|
||||
// Marshal is json marshal
|
||||
func (j JSON) Marshal(v interface{}) (out []byte, err error) {
|
||||
return v.([]byte), nil
|
||||
}
|
||||
|
||||
// Unmarshal is json unmarshal
|
||||
func (j JSON) Unmarshal(data []byte, v interface{}) (err error) {
|
||||
v.(*Reply).res = data
|
||||
return nil
|
||||
}
|
1
pkg/net/rpc/warden/internal/examples/grpcDebug/data.json
Normal file
1
pkg/net/rpc/warden/internal/examples/grpcDebug/data.json
Normal file
@ -0,0 +1 @@
|
||||
{"name":"xia","age":19}
|
108
pkg/net/rpc/warden/internal/examples/server/main.go
Normal file
108
pkg/net/rpc/warden/internal/examples/server/main.go
Normal file
@ -0,0 +1,108 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/bilibili/kratos/pkg/ecode"
|
||||
epb "github.com/bilibili/kratos/pkg/ecode/pb"
|
||||
"github.com/bilibili/kratos/pkg/log"
|
||||
"github.com/bilibili/kratos/pkg/net/rpc/warden"
|
||||
pb "github.com/bilibili/kratos/pkg/net/rpc/warden/internal/proto/testproto"
|
||||
xtime "github.com/bilibili/kratos/pkg/time"
|
||||
|
||||
"github.com/golang/protobuf/ptypes"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
type helloServer struct {
|
||||
addr string
|
||||
}
|
||||
|
||||
func (s *helloServer) SayHello(ctx context.Context, in *pb.HelloRequest) (*pb.HelloReply, error) {
|
||||
if in.Name == "err_detail_test" {
|
||||
any, _ := ptypes.MarshalAny(&pb.HelloReply{Success: true, Message: "this is test detail"})
|
||||
err := epb.From(ecode.AccessDenied)
|
||||
err.ErrDetail = any
|
||||
return nil, err
|
||||
}
|
||||
return &pb.HelloReply{Message: fmt.Sprintf("hello %s from %s", in.Name, s.addr)}, nil
|
||||
}
|
||||
|
||||
func (s *helloServer) StreamHello(ss pb.Greeter_StreamHelloServer) error {
|
||||
for i := 0; i < 3; i++ {
|
||||
in, err := ss.Recv()
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ret := &pb.HelloReply{Message: "Hello " + in.Name, Success: true}
|
||||
err = ss.Send(ret)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func runServer(addr string) *warden.Server {
|
||||
server := warden.NewServer(&warden.ServerConfig{
|
||||
//服务端每个请求的默认超时时间
|
||||
Timeout: xtime.Duration(time.Second),
|
||||
})
|
||||
server.Use(middleware())
|
||||
pb.RegisterGreeterServer(server.Server(), &helloServer{addr: addr})
|
||||
go func() {
|
||||
err := server.Run(addr)
|
||||
if err != nil {
|
||||
panic("run server failed!" + err.Error())
|
||||
}
|
||||
}()
|
||||
return server
|
||||
}
|
||||
|
||||
func main() {
|
||||
log.Init(&log.Config{Stdout: true})
|
||||
server := runServer("0.0.0.0:8081")
|
||||
signalHandler(server)
|
||||
}
|
||||
|
||||
//类似于中间件
|
||||
func middleware() grpc.UnaryServerInterceptor {
|
||||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
|
||||
//记录调用方法
|
||||
log.Info("method:%s", info.FullMethod)
|
||||
//call chain
|
||||
resp, err = handler(ctx, req)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func signalHandler(s *warden.Server) {
|
||||
var (
|
||||
ch = make(chan os.Signal, 1)
|
||||
)
|
||||
signal.Notify(ch, syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGTERM, syscall.SIGINT)
|
||||
for {
|
||||
si := <-ch
|
||||
switch si {
|
||||
case syscall.SIGQUIT, syscall.SIGTERM, syscall.SIGINT:
|
||||
log.Info("get a signal %s, stop the consume process", si.String())
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*3)
|
||||
defer cancel()
|
||||
//gracefully shutdown with timeout
|
||||
s.Shutdown(ctx)
|
||||
return
|
||||
case syscall.SIGHUP:
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
11
pkg/net/rpc/warden/internal/metadata/metadata.go
Normal file
11
pkg/net/rpc/warden/internal/metadata/metadata.go
Normal file
@ -0,0 +1,11 @@
|
||||
package metadata
|
||||
|
||||
const (
|
||||
CPUUsage = "cpu_usage"
|
||||
)
|
||||
|
||||
// MD is context metadata for balancer and resolver
|
||||
type MD struct {
|
||||
Weight uint64
|
||||
Color string
|
||||
}
|
642
pkg/net/rpc/warden/internal/proto/testproto/hello.pb.go
Normal file
642
pkg/net/rpc/warden/internal/proto/testproto/hello.pb.go
Normal file
@ -0,0 +1,642 @@
|
||||
// Code generated by protoc-gen-gogo. DO NOT EDIT.
|
||||
// source: hello.proto
|
||||
|
||||
/*
|
||||
Package testproto is a generated protocol buffer package.
|
||||
|
||||
It is generated from these files:
|
||||
hello.proto
|
||||
|
||||
It has these top-level messages:
|
||||
HelloRequest
|
||||
HelloReply
|
||||
*/
|
||||
package testproto
|
||||
|
||||
import proto "github.com/golang/protobuf/proto"
|
||||
import fmt "fmt"
|
||||
import math "math"
|
||||
import _ "github.com/gogo/protobuf/gogoproto"
|
||||
|
||||
import context "golang.org/x/net/context"
|
||||
import grpc "google.golang.org/grpc"
|
||||
|
||||
import io "io"
|
||||
|
||||
// Reference imports to suppress errors if they are not otherwise used.
|
||||
var _ = proto.Marshal
|
||||
var _ = fmt.Errorf
|
||||
var _ = math.Inf
|
||||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the proto package it is being compiled against.
|
||||
// A compilation error at this line likely means your copy of the
|
||||
// proto package needs to be updated.
|
||||
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
|
||||
|
||||
// The request message containing the user's name.
|
||||
type HelloRequest struct {
|
||||
Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name" validate:"required"`
|
||||
Age int32 `protobuf:"varint,2,opt,name=age,proto3" json:"age" validate:"min=0"`
|
||||
}
|
||||
|
||||
func (m *HelloRequest) Reset() { *m = HelloRequest{} }
|
||||
func (m *HelloRequest) String() string { return proto.CompactTextString(m) }
|
||||
func (*HelloRequest) ProtoMessage() {}
|
||||
func (*HelloRequest) Descriptor() ([]byte, []int) { return fileDescriptorHello, []int{0} }
|
||||
|
||||
// The response message containing the greetings
|
||||
type HelloReply struct {
|
||||
Message string `protobuf:"bytes,1,opt,name=message,proto3" json:"message,omitempty"`
|
||||
Success bool `protobuf:"varint,2,opt,name=success,proto3" json:"success,omitempty"`
|
||||
}
|
||||
|
||||
func (m *HelloReply) Reset() { *m = HelloReply{} }
|
||||
func (m *HelloReply) String() string { return proto.CompactTextString(m) }
|
||||
func (*HelloReply) ProtoMessage() {}
|
||||
func (*HelloReply) Descriptor() ([]byte, []int) { return fileDescriptorHello, []int{1} }
|
||||
|
||||
func init() {
|
||||
proto.RegisterType((*HelloRequest)(nil), "testproto.HelloRequest")
|
||||
proto.RegisterType((*HelloReply)(nil), "testproto.HelloReply")
|
||||
}
|
||||
|
||||
// Reference imports to suppress errors if they are not otherwise used.
|
||||
var _ context.Context
|
||||
var _ grpc.ClientConn
|
||||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the grpc package it is being compiled against.
|
||||
const _ = grpc.SupportPackageIsVersion4
|
||||
|
||||
// Client API for Greeter service
|
||||
|
||||
type GreeterClient interface {
|
||||
// Sends a greeting
|
||||
SayHello(ctx context.Context, in *HelloRequest, opts ...grpc.CallOption) (*HelloReply, error)
|
||||
// A bidirectional streaming RPC call recvice HelloRequest return HelloReply
|
||||
StreamHello(ctx context.Context, opts ...grpc.CallOption) (Greeter_StreamHelloClient, error)
|
||||
}
|
||||
|
||||
type greeterClient struct {
|
||||
cc *grpc.ClientConn
|
||||
}
|
||||
|
||||
func NewGreeterClient(cc *grpc.ClientConn) GreeterClient {
|
||||
return &greeterClient{cc}
|
||||
}
|
||||
|
||||
func (c *greeterClient) SayHello(ctx context.Context, in *HelloRequest, opts ...grpc.CallOption) (*HelloReply, error) {
|
||||
out := new(HelloReply)
|
||||
err := grpc.Invoke(ctx, "/testproto.Greeter/SayHello", in, out, c.cc, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *greeterClient) StreamHello(ctx context.Context, opts ...grpc.CallOption) (Greeter_StreamHelloClient, error) {
|
||||
stream, err := grpc.NewClientStream(ctx, &_Greeter_serviceDesc.Streams[0], c.cc, "/testproto.Greeter/StreamHello", opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
x := &greeterStreamHelloClient{stream}
|
||||
return x, nil
|
||||
}
|
||||
|
||||
type Greeter_StreamHelloClient interface {
|
||||
Send(*HelloRequest) error
|
||||
Recv() (*HelloReply, error)
|
||||
grpc.ClientStream
|
||||
}
|
||||
|
||||
type greeterStreamHelloClient struct {
|
||||
grpc.ClientStream
|
||||
}
|
||||
|
||||
func (x *greeterStreamHelloClient) Send(m *HelloRequest) error {
|
||||
return x.ClientStream.SendMsg(m)
|
||||
}
|
||||
|
||||
func (x *greeterStreamHelloClient) Recv() (*HelloReply, error) {
|
||||
m := new(HelloReply)
|
||||
if err := x.ClientStream.RecvMsg(m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Server API for Greeter service
|
||||
|
||||
type GreeterServer interface {
|
||||
// Sends a greeting
|
||||
SayHello(context.Context, *HelloRequest) (*HelloReply, error)
|
||||
// A bidirectional streaming RPC call recvice HelloRequest return HelloReply
|
||||
StreamHello(Greeter_StreamHelloServer) error
|
||||
}
|
||||
|
||||
func RegisterGreeterServer(s *grpc.Server, srv GreeterServer) {
|
||||
s.RegisterService(&_Greeter_serviceDesc, srv)
|
||||
}
|
||||
|
||||
func _Greeter_SayHello_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(HelloRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(GreeterServer).SayHello(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/testproto.Greeter/SayHello",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(GreeterServer).SayHello(ctx, req.(*HelloRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _Greeter_StreamHello_Handler(srv interface{}, stream grpc.ServerStream) error {
|
||||
return srv.(GreeterServer).StreamHello(&greeterStreamHelloServer{stream})
|
||||
}
|
||||
|
||||
type Greeter_StreamHelloServer interface {
|
||||
Send(*HelloReply) error
|
||||
Recv() (*HelloRequest, error)
|
||||
grpc.ServerStream
|
||||
}
|
||||
|
||||
type greeterStreamHelloServer struct {
|
||||
grpc.ServerStream
|
||||
}
|
||||
|
||||
func (x *greeterStreamHelloServer) Send(m *HelloReply) error {
|
||||
return x.ServerStream.SendMsg(m)
|
||||
}
|
||||
|
||||
func (x *greeterStreamHelloServer) Recv() (*HelloRequest, error) {
|
||||
m := new(HelloRequest)
|
||||
if err := x.ServerStream.RecvMsg(m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
var _Greeter_serviceDesc = grpc.ServiceDesc{
|
||||
ServiceName: "testproto.Greeter",
|
||||
HandlerType: (*GreeterServer)(nil),
|
||||
Methods: []grpc.MethodDesc{
|
||||
{
|
||||
MethodName: "SayHello",
|
||||
Handler: _Greeter_SayHello_Handler,
|
||||
},
|
||||
},
|
||||
Streams: []grpc.StreamDesc{
|
||||
{
|
||||
StreamName: "StreamHello",
|
||||
Handler: _Greeter_StreamHello_Handler,
|
||||
ServerStreams: true,
|
||||
ClientStreams: true,
|
||||
},
|
||||
},
|
||||
Metadata: "hello.proto",
|
||||
}
|
||||
|
||||
func (m *HelloRequest) Marshal() (dAtA []byte, err error) {
|
||||
size := m.Size()
|
||||
dAtA = make([]byte, size)
|
||||
n, err := m.MarshalTo(dAtA)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dAtA[:n], nil
|
||||
}
|
||||
|
||||
func (m *HelloRequest) MarshalTo(dAtA []byte) (int, error) {
|
||||
var i int
|
||||
_ = i
|
||||
var l int
|
||||
_ = l
|
||||
if len(m.Name) > 0 {
|
||||
dAtA[i] = 0xa
|
||||
i++
|
||||
i = encodeVarintHello(dAtA, i, uint64(len(m.Name)))
|
||||
i += copy(dAtA[i:], m.Name)
|
||||
}
|
||||
if m.Age != 0 {
|
||||
dAtA[i] = 0x10
|
||||
i++
|
||||
i = encodeVarintHello(dAtA, i, uint64(m.Age))
|
||||
}
|
||||
return i, nil
|
||||
}
|
||||
|
||||
func (m *HelloReply) Marshal() (dAtA []byte, err error) {
|
||||
size := m.Size()
|
||||
dAtA = make([]byte, size)
|
||||
n, err := m.MarshalTo(dAtA)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dAtA[:n], nil
|
||||
}
|
||||
|
||||
func (m *HelloReply) MarshalTo(dAtA []byte) (int, error) {
|
||||
var i int
|
||||
_ = i
|
||||
var l int
|
||||
_ = l
|
||||
if len(m.Message) > 0 {
|
||||
dAtA[i] = 0xa
|
||||
i++
|
||||
i = encodeVarintHello(dAtA, i, uint64(len(m.Message)))
|
||||
i += copy(dAtA[i:], m.Message)
|
||||
}
|
||||
if m.Success {
|
||||
dAtA[i] = 0x10
|
||||
i++
|
||||
if m.Success {
|
||||
dAtA[i] = 1
|
||||
} else {
|
||||
dAtA[i] = 0
|
||||
}
|
||||
i++
|
||||
}
|
||||
return i, nil
|
||||
}
|
||||
|
||||
func encodeVarintHello(dAtA []byte, offset int, v uint64) int {
|
||||
for v >= 1<<7 {
|
||||
dAtA[offset] = uint8(v&0x7f | 0x80)
|
||||
v >>= 7
|
||||
offset++
|
||||
}
|
||||
dAtA[offset] = uint8(v)
|
||||
return offset + 1
|
||||
}
|
||||
func (m *HelloRequest) Size() (n int) {
|
||||
var l int
|
||||
_ = l
|
||||
l = len(m.Name)
|
||||
if l > 0 {
|
||||
n += 1 + l + sovHello(uint64(l))
|
||||
}
|
||||
if m.Age != 0 {
|
||||
n += 1 + sovHello(uint64(m.Age))
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func (m *HelloReply) Size() (n int) {
|
||||
var l int
|
||||
_ = l
|
||||
l = len(m.Message)
|
||||
if l > 0 {
|
||||
n += 1 + l + sovHello(uint64(l))
|
||||
}
|
||||
if m.Success {
|
||||
n += 2
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func sovHello(x uint64) (n int) {
|
||||
for {
|
||||
n++
|
||||
x >>= 7
|
||||
if x == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
func sozHello(x uint64) (n int) {
|
||||
return sovHello(uint64((x << 1) ^ uint64((int64(x) >> 63))))
|
||||
}
|
||||
func (m *HelloRequest) Unmarshal(dAtA []byte) error {
|
||||
l := len(dAtA)
|
||||
iNdEx := 0
|
||||
for iNdEx < l {
|
||||
preIndex := iNdEx
|
||||
var wire uint64
|
||||
for shift := uint(0); ; shift += 7 {
|
||||
if shift >= 64 {
|
||||
return ErrIntOverflowHello
|
||||
}
|
||||
if iNdEx >= l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
b := dAtA[iNdEx]
|
||||
iNdEx++
|
||||
wire |= (uint64(b) & 0x7F) << shift
|
||||
if b < 0x80 {
|
||||
break
|
||||
}
|
||||
}
|
||||
fieldNum := int32(wire >> 3)
|
||||
wireType := int(wire & 0x7)
|
||||
if wireType == 4 {
|
||||
return fmt.Errorf("proto: HelloRequest: wiretype end group for non-group")
|
||||
}
|
||||
if fieldNum <= 0 {
|
||||
return fmt.Errorf("proto: HelloRequest: illegal tag %d (wire type %d)", fieldNum, wire)
|
||||
}
|
||||
switch fieldNum {
|
||||
case 1:
|
||||
if wireType != 2 {
|
||||
return fmt.Errorf("proto: wrong wireType = %d for field Name", wireType)
|
||||
}
|
||||
var stringLen uint64
|
||||
for shift := uint(0); ; shift += 7 {
|
||||
if shift >= 64 {
|
||||
return ErrIntOverflowHello
|
||||
}
|
||||
if iNdEx >= l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
b := dAtA[iNdEx]
|
||||
iNdEx++
|
||||
stringLen |= (uint64(b) & 0x7F) << shift
|
||||
if b < 0x80 {
|
||||
break
|
||||
}
|
||||
}
|
||||
intStringLen := int(stringLen)
|
||||
if intStringLen < 0 {
|
||||
return ErrInvalidLengthHello
|
||||
}
|
||||
postIndex := iNdEx + intStringLen
|
||||
if postIndex > l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
m.Name = string(dAtA[iNdEx:postIndex])
|
||||
iNdEx = postIndex
|
||||
case 2:
|
||||
if wireType != 0 {
|
||||
return fmt.Errorf("proto: wrong wireType = %d for field Age", wireType)
|
||||
}
|
||||
m.Age = 0
|
||||
for shift := uint(0); ; shift += 7 {
|
||||
if shift >= 64 {
|
||||
return ErrIntOverflowHello
|
||||
}
|
||||
if iNdEx >= l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
b := dAtA[iNdEx]
|
||||
iNdEx++
|
||||
m.Age |= (int32(b) & 0x7F) << shift
|
||||
if b < 0x80 {
|
||||
break
|
||||
}
|
||||
}
|
||||
default:
|
||||
iNdEx = preIndex
|
||||
skippy, err := skipHello(dAtA[iNdEx:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if skippy < 0 {
|
||||
return ErrInvalidLengthHello
|
||||
}
|
||||
if (iNdEx + skippy) > l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
iNdEx += skippy
|
||||
}
|
||||
}
|
||||
|
||||
if iNdEx > l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (m *HelloReply) Unmarshal(dAtA []byte) error {
|
||||
l := len(dAtA)
|
||||
iNdEx := 0
|
||||
for iNdEx < l {
|
||||
preIndex := iNdEx
|
||||
var wire uint64
|
||||
for shift := uint(0); ; shift += 7 {
|
||||
if shift >= 64 {
|
||||
return ErrIntOverflowHello
|
||||
}
|
||||
if iNdEx >= l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
b := dAtA[iNdEx]
|
||||
iNdEx++
|
||||
wire |= (uint64(b) & 0x7F) << shift
|
||||
if b < 0x80 {
|
||||
break
|
||||
}
|
||||
}
|
||||
fieldNum := int32(wire >> 3)
|
||||
wireType := int(wire & 0x7)
|
||||
if wireType == 4 {
|
||||
return fmt.Errorf("proto: HelloReply: wiretype end group for non-group")
|
||||
}
|
||||
if fieldNum <= 0 {
|
||||
return fmt.Errorf("proto: HelloReply: illegal tag %d (wire type %d)", fieldNum, wire)
|
||||
}
|
||||
switch fieldNum {
|
||||
case 1:
|
||||
if wireType != 2 {
|
||||
return fmt.Errorf("proto: wrong wireType = %d for field Message", wireType)
|
||||
}
|
||||
var stringLen uint64
|
||||
for shift := uint(0); ; shift += 7 {
|
||||
if shift >= 64 {
|
||||
return ErrIntOverflowHello
|
||||
}
|
||||
if iNdEx >= l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
b := dAtA[iNdEx]
|
||||
iNdEx++
|
||||
stringLen |= (uint64(b) & 0x7F) << shift
|
||||
if b < 0x80 {
|
||||
break
|
||||
}
|
||||
}
|
||||
intStringLen := int(stringLen)
|
||||
if intStringLen < 0 {
|
||||
return ErrInvalidLengthHello
|
||||
}
|
||||
postIndex := iNdEx + intStringLen
|
||||
if postIndex > l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
m.Message = string(dAtA[iNdEx:postIndex])
|
||||
iNdEx = postIndex
|
||||
case 2:
|
||||
if wireType != 0 {
|
||||
return fmt.Errorf("proto: wrong wireType = %d for field Success", wireType)
|
||||
}
|
||||
var v int
|
||||
for shift := uint(0); ; shift += 7 {
|
||||
if shift >= 64 {
|
||||
return ErrIntOverflowHello
|
||||
}
|
||||
if iNdEx >= l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
b := dAtA[iNdEx]
|
||||
iNdEx++
|
||||
v |= (int(b) & 0x7F) << shift
|
||||
if b < 0x80 {
|
||||
break
|
||||
}
|
||||
}
|
||||
m.Success = bool(v != 0)
|
||||
default:
|
||||
iNdEx = preIndex
|
||||
skippy, err := skipHello(dAtA[iNdEx:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if skippy < 0 {
|
||||
return ErrInvalidLengthHello
|
||||
}
|
||||
if (iNdEx + skippy) > l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
iNdEx += skippy
|
||||
}
|
||||
}
|
||||
|
||||
if iNdEx > l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func skipHello(dAtA []byte) (n int, err error) {
|
||||
l := len(dAtA)
|
||||
iNdEx := 0
|
||||
for iNdEx < l {
|
||||
var wire uint64
|
||||
for shift := uint(0); ; shift += 7 {
|
||||
if shift >= 64 {
|
||||
return 0, ErrIntOverflowHello
|
||||
}
|
||||
if iNdEx >= l {
|
||||
return 0, io.ErrUnexpectedEOF
|
||||
}
|
||||
b := dAtA[iNdEx]
|
||||
iNdEx++
|
||||
wire |= (uint64(b) & 0x7F) << shift
|
||||
if b < 0x80 {
|
||||
break
|
||||
}
|
||||
}
|
||||
wireType := int(wire & 0x7)
|
||||
switch wireType {
|
||||
case 0:
|
||||
for shift := uint(0); ; shift += 7 {
|
||||
if shift >= 64 {
|
||||
return 0, ErrIntOverflowHello
|
||||
}
|
||||
if iNdEx >= l {
|
||||
return 0, io.ErrUnexpectedEOF
|
||||
}
|
||||
iNdEx++
|
||||
if dAtA[iNdEx-1] < 0x80 {
|
||||
break
|
||||
}
|
||||
}
|
||||
return iNdEx, nil
|
||||
case 1:
|
||||
iNdEx += 8
|
||||
return iNdEx, nil
|
||||
case 2:
|
||||
var length int
|
||||
for shift := uint(0); ; shift += 7 {
|
||||
if shift >= 64 {
|
||||
return 0, ErrIntOverflowHello
|
||||
}
|
||||
if iNdEx >= l {
|
||||
return 0, io.ErrUnexpectedEOF
|
||||
}
|
||||
b := dAtA[iNdEx]
|
||||
iNdEx++
|
||||
length |= (int(b) & 0x7F) << shift
|
||||
if b < 0x80 {
|
||||
break
|
||||
}
|
||||
}
|
||||
iNdEx += length
|
||||
if length < 0 {
|
||||
return 0, ErrInvalidLengthHello
|
||||
}
|
||||
return iNdEx, nil
|
||||
case 3:
|
||||
for {
|
||||
var innerWire uint64
|
||||
var start int = iNdEx
|
||||
for shift := uint(0); ; shift += 7 {
|
||||
if shift >= 64 {
|
||||
return 0, ErrIntOverflowHello
|
||||
}
|
||||
if iNdEx >= l {
|
||||
return 0, io.ErrUnexpectedEOF
|
||||
}
|
||||
b := dAtA[iNdEx]
|
||||
iNdEx++
|
||||
innerWire |= (uint64(b) & 0x7F) << shift
|
||||
if b < 0x80 {
|
||||
break
|
||||
}
|
||||
}
|
||||
innerWireType := int(innerWire & 0x7)
|
||||
if innerWireType == 4 {
|
||||
break
|
||||
}
|
||||
next, err := skipHello(dAtA[start:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
iNdEx = start + next
|
||||
}
|
||||
return iNdEx, nil
|
||||
case 4:
|
||||
return iNdEx, nil
|
||||
case 5:
|
||||
iNdEx += 4
|
||||
return iNdEx, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("proto: illegal wireType %d", wireType)
|
||||
}
|
||||
}
|
||||
panic("unreachable")
|
||||
}
|
||||
|
||||
var (
|
||||
ErrInvalidLengthHello = fmt.Errorf("proto: negative length found during unmarshaling")
|
||||
ErrIntOverflowHello = fmt.Errorf("proto: integer overflow")
|
||||
)
|
||||
|
||||
func init() { proto.RegisterFile("hello.proto", fileDescriptorHello) }
|
||||
|
||||
var fileDescriptorHello = []byte{
|
||||
// 296 bytes of a gzipped FileDescriptorProto
|
||||
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x9c, 0x90, 0x3f, 0x4e, 0xc3, 0x30,
|
||||
0x14, 0xc6, 0x63, 0xfe, 0xb5, 0x75, 0x19, 0x90, 0x11, 0x22, 0x2a, 0x92, 0x53, 0x79, 0xca, 0xd2,
|
||||
0xb4, 0xa2, 0x1b, 0x02, 0x09, 0x85, 0x01, 0xe6, 0xf4, 0x04, 0x4e, 0xfa, 0x48, 0x23, 0x25, 0x75,
|
||||
0x6a, 0x3b, 0x48, 0xb9, 0x03, 0x07, 0xe0, 0x48, 0x1d, 0x7b, 0x82, 0x88, 0x86, 0xad, 0x63, 0x4f,
|
||||
0x80, 0x62, 0x28, 0x20, 0xb1, 0x75, 0x7b, 0x3f, 0x7f, 0xfa, 0x7e, 0x4f, 0x7e, 0xb8, 0x3b, 0x83,
|
||||
0x34, 0x15, 0x5e, 0x2e, 0x85, 0x16, 0xa4, 0xa3, 0x41, 0x69, 0x33, 0xf6, 0x06, 0x71, 0xa2, 0x67,
|
||||
0x45, 0xe8, 0x45, 0x22, 0x1b, 0xc6, 0x22, 0x16, 0x43, 0xf3, 0x1c, 0x16, 0xcf, 0x86, 0x0c, 0x98,
|
||||
0xe9, 0xab, 0xc9, 0x24, 0x3e, 0x7d, 0x6a, 0x44, 0x01, 0x2c, 0x0a, 0x50, 0x9a, 0x8c, 0xf1, 0xd1,
|
||||
0x9c, 0x67, 0x60, 0xa3, 0x3e, 0x72, 0x3b, 0xbe, 0xb3, 0xa9, 0x1c, 0xc3, 0xdb, 0xca, 0x39, 0x7f,
|
||||
0xe1, 0x69, 0x32, 0xe5, 0x1a, 0x6e, 0x98, 0x84, 0x45, 0x91, 0x48, 0x98, 0xb2, 0xc0, 0x84, 0x64,
|
||||
0x80, 0x0f, 0x79, 0x0c, 0xf6, 0x41, 0x1f, 0xb9, 0xc7, 0xfe, 0xd5, 0xa6, 0x72, 0x1a, 0xdc, 0x56,
|
||||
0xce, 0xd9, 0x6f, 0x25, 0x4b, 0xe6, 0x77, 0x23, 0x16, 0x34, 0x01, 0xbb, 0xc7, 0xf8, 0x7b, 0x67,
|
||||
0x9e, 0x96, 0xc4, 0xc6, 0xad, 0x0c, 0x94, 0x6a, 0x04, 0x66, 0x69, 0xb0, 0xc3, 0x26, 0x51, 0x45,
|
||||
0x14, 0x81, 0x52, 0x46, 0xdd, 0x0e, 0x76, 0x78, 0xfd, 0x8a, 0x70, 0xeb, 0x51, 0x02, 0x68, 0x90,
|
||||
0xe4, 0x16, 0xb7, 0x27, 0xbc, 0x34, 0x42, 0x72, 0xe9, 0xfd, 0x1c, 0xc2, 0xfb, 0xfb, 0xad, 0xde,
|
||||
0xc5, 0xff, 0x20, 0x4f, 0x4b, 0x66, 0x91, 0x07, 0xdc, 0x9d, 0x68, 0x09, 0x3c, 0xdb, 0x53, 0xe0,
|
||||
0xa2, 0x11, 0xf2, 0xed, 0xe5, 0x9a, 0x5a, 0xab, 0x35, 0xb5, 0x96, 0x35, 0x45, 0xab, 0x9a, 0xa2,
|
||||
0xf7, 0x9a, 0xa2, 0xb7, 0x0f, 0x6a, 0x85, 0x27, 0xa6, 0x31, 0xfe, 0x0c, 0x00, 0x00, 0xff, 0xff,
|
||||
0x13, 0x57, 0x88, 0x03, 0xae, 0x01, 0x00, 0x00,
|
||||
}
|
33
pkg/net/rpc/warden/internal/proto/testproto/hello.proto
Normal file
33
pkg/net/rpc/warden/internal/proto/testproto/hello.proto
Normal file
@ -0,0 +1,33 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package testproto;
|
||||
|
||||
import "github.com/gogo/protobuf/gogoproto/gogo.proto";
|
||||
|
||||
option (gogoproto.goproto_enum_prefix_all) = false;
|
||||
option (gogoproto.goproto_getters_all) = false;
|
||||
option (gogoproto.unmarshaler_all) = true;
|
||||
option (gogoproto.marshaler_all) = true;
|
||||
option (gogoproto.sizer_all) = true;
|
||||
option (gogoproto.goproto_registration) = true;
|
||||
|
||||
// The greeting service definition.
|
||||
service Greeter {
|
||||
// Sends a greeting
|
||||
rpc SayHello (HelloRequest) returns (HelloReply) {}
|
||||
|
||||
// A bidirectional streaming RPC call recvice HelloRequest return HelloReply
|
||||
rpc StreamHello(stream HelloRequest) returns (stream HelloReply) {}
|
||||
}
|
||||
|
||||
// The request message containing the user's name.
|
||||
message HelloRequest {
|
||||
string name = 1 [(gogoproto.jsontag) = "name", (gogoproto.moretags) = "validate:\"required\""];
|
||||
int32 age = 2 [(gogoproto.jsontag) = "age", (gogoproto.moretags) = "validate:\"min=0\""];
|
||||
}
|
||||
|
||||
// The response message containing the greetings
|
||||
message HelloReply {
|
||||
string message = 1;
|
||||
bool success = 2;
|
||||
}
|
151
pkg/net/rpc/warden/internal/status/status.go
Normal file
151
pkg/net/rpc/warden/internal/status/status.go
Normal file
@ -0,0 +1,151 @@
|
||||
package status
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/golang/protobuf/ptypes"
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/bilibili/kratos/pkg/ecode"
|
||||
"github.com/bilibili/kratos/pkg/ecode/pb"
|
||||
)
|
||||
|
||||
// togRPCCode convert ecode.Codo to gRPC code
|
||||
func togRPCCode(code ecode.Codes) codes.Code {
|
||||
switch code.Code() {
|
||||
case ecode.OK.Code():
|
||||
return codes.OK
|
||||
case ecode.RequestErr.Code():
|
||||
return codes.InvalidArgument
|
||||
case ecode.NothingFound.Code():
|
||||
return codes.NotFound
|
||||
case ecode.Unauthorized.Code():
|
||||
return codes.Unauthenticated
|
||||
case ecode.AccessDenied.Code():
|
||||
return codes.PermissionDenied
|
||||
case ecode.LimitExceed.Code():
|
||||
return codes.ResourceExhausted
|
||||
case ecode.MethodNotAllowed.Code():
|
||||
return codes.Unimplemented
|
||||
case ecode.Deadline.Code():
|
||||
return codes.DeadlineExceeded
|
||||
case ecode.ServiceUnavailable.Code():
|
||||
return codes.Unavailable
|
||||
}
|
||||
return codes.Unknown
|
||||
}
|
||||
|
||||
func toECode(gst *status.Status) ecode.Code {
|
||||
gcode := gst.Code()
|
||||
switch gcode {
|
||||
case codes.OK:
|
||||
return ecode.OK
|
||||
case codes.InvalidArgument:
|
||||
return ecode.RequestErr
|
||||
case codes.NotFound:
|
||||
return ecode.NothingFound
|
||||
case codes.PermissionDenied:
|
||||
return ecode.AccessDenied
|
||||
case codes.Unauthenticated:
|
||||
return ecode.Unauthorized
|
||||
case codes.ResourceExhausted:
|
||||
return ecode.LimitExceed
|
||||
case codes.Unimplemented:
|
||||
return ecode.MethodNotAllowed
|
||||
case codes.DeadlineExceeded:
|
||||
return ecode.Deadline
|
||||
case codes.Unavailable:
|
||||
return ecode.ServiceUnavailable
|
||||
case codes.Unknown:
|
||||
return ecode.String(gst.Message())
|
||||
}
|
||||
return ecode.ServerErr
|
||||
}
|
||||
|
||||
// FromError convert error for service reply and try to convert it to grpc.Status.
|
||||
func FromError(svrErr error) (gst *status.Status) {
|
||||
var err error
|
||||
svrErr = errors.Cause(svrErr)
|
||||
if code, ok := svrErr.(ecode.Codes); ok {
|
||||
// TODO: deal with err
|
||||
if gst, err = gRPCStatusFromEcode(code); err == nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
// for some special error convert context.Canceled to ecode.Canceled,
|
||||
// context.DeadlineExceeded to ecode.DeadlineExceeded only for raw error
|
||||
// if err be wrapped will not effect.
|
||||
switch svrErr {
|
||||
case context.Canceled:
|
||||
gst, _ = gRPCStatusFromEcode(ecode.Canceled)
|
||||
case context.DeadlineExceeded:
|
||||
gst, _ = gRPCStatusFromEcode(ecode.Deadline)
|
||||
default:
|
||||
gst, _ = status.FromError(svrErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func gRPCStatusFromEcode(code ecode.Codes) (*status.Status, error) {
|
||||
var st *ecode.Status
|
||||
switch v := code.(type) {
|
||||
// compatible old pb.Error remove it after nobody use pb.Error.
|
||||
case *pb.Error:
|
||||
return status.New(codes.Unknown, v.Error()).WithDetails(v)
|
||||
case *ecode.Status:
|
||||
st = v
|
||||
case ecode.Code:
|
||||
st = ecode.FromCode(v)
|
||||
default:
|
||||
st = ecode.Error(ecode.Code(code.Code()), code.Message())
|
||||
for _, detail := range code.Details() {
|
||||
if msg, ok := detail.(proto.Message); ok {
|
||||
st.WithDetails(msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
// gst := status.New(togRPCCode(st), st.Message())
|
||||
// NOTE: compatible with PHP swoole gRPC put code in status message as string.
|
||||
// gst := status.New(togRPCCode(st), strconv.Itoa(st.Code()))
|
||||
gst := status.New(codes.Unknown, strconv.Itoa(st.Code()))
|
||||
pbe := &pb.Error{ErrCode: int32(st.Code()), ErrMessage: gst.Message()}
|
||||
// NOTE: server return ecode.Status will be covert to pb.Error details will be ignored
|
||||
// and put it at details[0] for compatible old client
|
||||
return gst.WithDetails(pbe, st.Proto())
|
||||
}
|
||||
|
||||
// ToEcode convert grpc.status to ecode.Codes
|
||||
func ToEcode(gst *status.Status) ecode.Codes {
|
||||
details := gst.Details()
|
||||
// reverse range details, details may contain three case,
|
||||
// if details contain pb.Error and ecode.Status use eocde.Status first.
|
||||
//
|
||||
// Details layout:
|
||||
// pb.Error [0: pb.Error]
|
||||
// both pb.Error and ecode.Status [0: pb.Error, 1: ecode.Status]
|
||||
// ecode.Status [0: ecode.Status]
|
||||
for i := len(details) - 1; i >= 0; i-- {
|
||||
detail := details[i]
|
||||
// compatible with old pb.Error.
|
||||
if pe, ok := detail.(*pb.Error); ok {
|
||||
st := ecode.Error(ecode.Code(pe.ErrCode), pe.ErrMessage)
|
||||
if pe.ErrDetail != nil {
|
||||
dynMsg := new(ptypes.DynamicAny)
|
||||
// TODO deal with unmarshalAny error.
|
||||
if err := ptypes.UnmarshalAny(pe.ErrDetail, dynMsg); err == nil {
|
||||
st, _ = st.WithDetails(dynMsg.Message)
|
||||
}
|
||||
}
|
||||
return st
|
||||
}
|
||||
// convert detail to status only use first detail
|
||||
if pb, ok := detail.(proto.Message); ok {
|
||||
return ecode.FromProto(pb)
|
||||
}
|
||||
}
|
||||
return toECode(gst)
|
||||
}
|
164
pkg/net/rpc/warden/internal/status/status_test.go
Normal file
164
pkg/net/rpc/warden/internal/status/status_test.go
Normal file
@ -0,0 +1,164 @@
|
||||
package status
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/protobuf/ptypes"
|
||||
"github.com/golang/protobuf/ptypes/timestamp"
|
||||
pkgerr "github.com/pkg/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/bilibili/kratos/pkg/ecode"
|
||||
"github.com/bilibili/kratos/pkg/ecode/pb"
|
||||
)
|
||||
|
||||
func TestCodeConvert(t *testing.T) {
|
||||
var table = map[codes.Code]ecode.Code{
|
||||
codes.OK: ecode.OK,
|
||||
// codes.Canceled
|
||||
codes.Unknown: ecode.ServerErr,
|
||||
codes.InvalidArgument: ecode.RequestErr,
|
||||
codes.DeadlineExceeded: ecode.Deadline,
|
||||
codes.NotFound: ecode.NothingFound,
|
||||
// codes.AlreadyExists
|
||||
codes.PermissionDenied: ecode.AccessDenied,
|
||||
codes.ResourceExhausted: ecode.LimitExceed,
|
||||
// codes.FailedPrecondition
|
||||
// codes.Aborted
|
||||
// codes.OutOfRange
|
||||
codes.Unimplemented: ecode.MethodNotAllowed,
|
||||
codes.Unavailable: ecode.ServiceUnavailable,
|
||||
// codes.DataLoss
|
||||
codes.Unauthenticated: ecode.Unauthorized,
|
||||
}
|
||||
for k, v := range table {
|
||||
assert.Equal(t, toECode(status.New(k, "-500")), v)
|
||||
}
|
||||
for k, v := range table {
|
||||
assert.Equal(t, togRPCCode(v), k, fmt.Sprintf("togRPC code error: %d -> %d", v, k))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNoDetailsConvert(t *testing.T) {
|
||||
gst := status.New(codes.Unknown, "-2233")
|
||||
assert.Equal(t, toECode(gst).Code(), -2233)
|
||||
|
||||
gst = status.New(codes.Internal, "")
|
||||
assert.Equal(t, toECode(gst).Code(), -500)
|
||||
}
|
||||
|
||||
func TestFromError(t *testing.T) {
|
||||
t.Run("input general error", func(t *testing.T) {
|
||||
err := errors.New("general error")
|
||||
gst := FromError(err)
|
||||
|
||||
assert.Equal(t, codes.Unknown, gst.Code())
|
||||
assert.Contains(t, gst.Message(), "general")
|
||||
})
|
||||
t.Run("input wrap error", func(t *testing.T) {
|
||||
err := pkgerr.Wrap(ecode.RequestErr, "hh")
|
||||
gst := FromError(err)
|
||||
|
||||
assert.Equal(t, "-400", gst.Message())
|
||||
})
|
||||
t.Run("input ecode.Code", func(t *testing.T) {
|
||||
err := ecode.RequestErr
|
||||
gst := FromError(err)
|
||||
|
||||
//assert.Equal(t, codes.InvalidArgument, gst.Code())
|
||||
// NOTE: set all grpc.status as Unkown when error is ecode.Codes for compatible
|
||||
assert.Equal(t, codes.Unknown, gst.Code())
|
||||
// NOTE: gst.Message == str(ecode.Code) for compatible php leagcy code
|
||||
assert.Equal(t, err.Message(), gst.Message())
|
||||
})
|
||||
t.Run("input raw Canceled", func(t *testing.T) {
|
||||
gst := FromError(context.Canceled)
|
||||
|
||||
assert.Equal(t, codes.Unknown, gst.Code())
|
||||
assert.Equal(t, "-498", gst.Message())
|
||||
})
|
||||
t.Run("input raw DeadlineExceeded", func(t *testing.T) {
|
||||
gst := FromError(context.DeadlineExceeded)
|
||||
|
||||
assert.Equal(t, codes.Unknown, gst.Code())
|
||||
assert.Equal(t, "-504", gst.Message())
|
||||
})
|
||||
t.Run("input pb.Error", func(t *testing.T) {
|
||||
m := ×tamp.Timestamp{Seconds: time.Now().Unix()}
|
||||
detail, _ := ptypes.MarshalAny(m)
|
||||
err := &pb.Error{ErrCode: 2233, ErrMessage: "message", ErrDetail: detail}
|
||||
gst := FromError(err)
|
||||
|
||||
assert.Equal(t, codes.Unknown, gst.Code())
|
||||
assert.Len(t, gst.Details(), 1)
|
||||
assert.Equal(t, "2233", gst.Message())
|
||||
})
|
||||
t.Run("input ecode.Status", func(t *testing.T) {
|
||||
m := ×tamp.Timestamp{Seconds: time.Now().Unix()}
|
||||
err, _ := ecode.Error(ecode.Unauthorized, "unauthorized").WithDetails(m)
|
||||
gst := FromError(err)
|
||||
|
||||
//assert.Equal(t, codes.Unauthenticated, gst.Code())
|
||||
// NOTE: set all grpc.status as Unkown when error is ecode.Codes for compatible
|
||||
assert.Equal(t, codes.Unknown, gst.Code())
|
||||
assert.Len(t, gst.Details(), 2)
|
||||
details := gst.Details()
|
||||
assert.IsType(t, &pb.Error{}, details[0])
|
||||
assert.IsType(t, err.Proto(), details[1])
|
||||
})
|
||||
}
|
||||
|
||||
func TestToEcode(t *testing.T) {
|
||||
t.Run("input general grpc.Status", func(t *testing.T) {
|
||||
gst := status.New(codes.Unknown, "unknown")
|
||||
ec := ToEcode(gst)
|
||||
|
||||
assert.Equal(t, int(ecode.ServerErr), ec.Code())
|
||||
assert.Equal(t, "-500", ec.Message())
|
||||
assert.Len(t, ec.Details(), 0)
|
||||
})
|
||||
|
||||
t.Run("input pb.Error", func(t *testing.T) {
|
||||
m := ×tamp.Timestamp{Seconds: time.Now().Unix()}
|
||||
detail, _ := ptypes.MarshalAny(m)
|
||||
gst := status.New(codes.InvalidArgument, "requesterr")
|
||||
gst, _ = gst.WithDetails(&pb.Error{ErrCode: 1122, ErrMessage: "message", ErrDetail: detail})
|
||||
ec := ToEcode(gst)
|
||||
|
||||
assert.Equal(t, 1122, ec.Code())
|
||||
assert.Equal(t, "message", ec.Message())
|
||||
assert.Len(t, ec.Details(), 1)
|
||||
assert.IsType(t, m, ec.Details()[0])
|
||||
})
|
||||
|
||||
t.Run("input pb.Error and ecode.Status", func(t *testing.T) {
|
||||
gst := status.New(codes.InvalidArgument, "requesterr")
|
||||
gst, _ = gst.WithDetails(
|
||||
&pb.Error{ErrCode: 1122, ErrMessage: "message"},
|
||||
ecode.Errorf(ecode.AccessKeyErr, "AccessKeyErr").Proto(),
|
||||
)
|
||||
ec := ToEcode(gst)
|
||||
|
||||
assert.Equal(t, int(ecode.AccessKeyErr), ec.Code())
|
||||
assert.Equal(t, "AccessKeyErr", ec.Message())
|
||||
})
|
||||
|
||||
t.Run("input encode.Status", func(t *testing.T) {
|
||||
m := ×tamp.Timestamp{Seconds: time.Now().Unix()}
|
||||
st, _ := ecode.Errorf(ecode.AccessKeyErr, "AccessKeyErr").WithDetails(m)
|
||||
gst := status.New(codes.InvalidArgument, "requesterr")
|
||||
gst, _ = gst.WithDetails(st.Proto())
|
||||
ec := ToEcode(gst)
|
||||
|
||||
assert.Equal(t, int(ecode.AccessKeyErr), ec.Code())
|
||||
assert.Equal(t, "AccessKeyErr", ec.Message())
|
||||
assert.Len(t, ec.Details(), 1)
|
||||
assert.IsType(t, m, ec.Details()[0])
|
||||
})
|
||||
}
|
118
pkg/net/rpc/warden/logging.go
Normal file
118
pkg/net/rpc/warden/logging.go
Normal file
@ -0,0 +1,118 @@
|
||||
package warden
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/peer"
|
||||
|
||||
"github.com/bilibili/kratos/pkg/ecode"
|
||||
"github.com/bilibili/kratos/pkg/log"
|
||||
"github.com/bilibili/kratos/pkg/net/metadata"
|
||||
"github.com/bilibili/kratos/pkg/stat"
|
||||
)
|
||||
|
||||
var (
|
||||
statsClient = stat.RPCClient
|
||||
statsServer = stat.RPCServer
|
||||
)
|
||||
|
||||
func logFn(code int, dt time.Duration) func(context.Context, ...log.D) {
|
||||
switch {
|
||||
case code < 0:
|
||||
return log.Errorv
|
||||
case dt >= time.Millisecond*500:
|
||||
// TODO: slowlog make it configurable.
|
||||
return log.Warnv
|
||||
case code > 0:
|
||||
return log.Warnv
|
||||
}
|
||||
return log.Infov
|
||||
}
|
||||
|
||||
// clientLogging warden grpc logging
|
||||
func clientLogging() grpc.UnaryClientInterceptor {
|
||||
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
|
||||
startTime := time.Now()
|
||||
var peerInfo peer.Peer
|
||||
opts = append(opts, grpc.Peer(&peerInfo))
|
||||
|
||||
// invoker requests
|
||||
err := invoker(ctx, method, req, reply, cc, opts...)
|
||||
|
||||
// after request
|
||||
code := ecode.Cause(err).Code()
|
||||
duration := time.Since(startTime)
|
||||
// monitor
|
||||
statsClient.Timing(method, int64(duration/time.Millisecond))
|
||||
statsClient.Incr(method, strconv.Itoa(code))
|
||||
|
||||
var ip string
|
||||
if peerInfo.Addr != nil {
|
||||
ip = peerInfo.Addr.String()
|
||||
}
|
||||
logFields := []log.D{
|
||||
log.KVString("ip", ip),
|
||||
log.KVString("path", method),
|
||||
log.KVInt("ret", code),
|
||||
// TODO: it will panic if someone remove String method from protobuf message struct that auto generate from protoc.
|
||||
log.KVString("args", req.(fmt.Stringer).String()),
|
||||
log.KVFloat64("ts", duration.Seconds()),
|
||||
log.KVString("source", "grpc-access-log"),
|
||||
}
|
||||
if err != nil {
|
||||
logFields = append(logFields, log.KV("error", err.Error()), log.KVString("stack", fmt.Sprintf("%+v", err)))
|
||||
}
|
||||
logFn(code, duration)(ctx, logFields...)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// serverLogging warden grpc logging
|
||||
func serverLogging() grpc.UnaryServerInterceptor {
|
||||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
||||
startTime := time.Now()
|
||||
caller := metadata.String(ctx, metadata.Caller)
|
||||
if caller == "" {
|
||||
caller = "no_user"
|
||||
}
|
||||
var remoteIP string
|
||||
if peerInfo, ok := peer.FromContext(ctx); ok {
|
||||
remoteIP = peerInfo.Addr.String()
|
||||
}
|
||||
var quota float64
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
quota = time.Until(deadline).Seconds()
|
||||
}
|
||||
|
||||
// call server handler
|
||||
resp, err := handler(ctx, req)
|
||||
|
||||
// after server response
|
||||
code := ecode.Cause(err).Code()
|
||||
duration := time.Since(startTime)
|
||||
|
||||
// monitor
|
||||
statsServer.Timing(caller, int64(duration/time.Millisecond), info.FullMethod)
|
||||
statsServer.Incr(caller, info.FullMethod, strconv.Itoa(code))
|
||||
logFields := []log.D{
|
||||
log.KVString("user", caller),
|
||||
log.KVString("ip", remoteIP),
|
||||
log.KVString("path", info.FullMethod),
|
||||
log.KVInt("ret", code),
|
||||
// TODO: it will panic if someone remove String method from protobuf message struct that auto generate from protoc.
|
||||
log.KVString("args", req.(fmt.Stringer).String()),
|
||||
log.KVFloat64("ts", duration.Seconds()),
|
||||
log.KVFloat64("timeout_quota", quota),
|
||||
log.KVString("source", "grpc-access-log"),
|
||||
}
|
||||
if err != nil {
|
||||
logFields = append(logFields, log.KV("error", err.Error()), log.KV("stack", fmt.Sprintf("%+v", err)))
|
||||
}
|
||||
logFn(code, duration)(ctx, logFields...)
|
||||
return resp, err
|
||||
}
|
||||
}
|
55
pkg/net/rpc/warden/logging_test.go
Normal file
55
pkg/net/rpc/warden/logging_test.go
Normal file
@ -0,0 +1,55 @@
|
||||
package warden
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/bilibili/kratos/pkg/log"
|
||||
)
|
||||
|
||||
func Test_logFn(t *testing.T) {
|
||||
type args struct {
|
||||
code int
|
||||
dt time.Duration
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want func(context.Context, ...log.D)
|
||||
}{
|
||||
{
|
||||
name: "ok",
|
||||
args: args{code: 0, dt: time.Millisecond},
|
||||
want: log.Infov,
|
||||
},
|
||||
{
|
||||
name: "slowlog",
|
||||
args: args{code: 0, dt: time.Second},
|
||||
want: log.Warnv,
|
||||
},
|
||||
{
|
||||
name: "business error",
|
||||
args: args{code: 2233, dt: time.Millisecond},
|
||||
want: log.Warnv,
|
||||
},
|
||||
{
|
||||
name: "system error",
|
||||
args: args{code: -1, dt: 0},
|
||||
want: log.Errorv,
|
||||
},
|
||||
{
|
||||
name: "system error and slowlog",
|
||||
args: args{code: -1, dt: time.Second},
|
||||
want: log.Errorv,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := logFn(tt.args.code, tt.args.dt); reflect.ValueOf(got).Pointer() != reflect.ValueOf(tt.want).Pointer() {
|
||||
t.Errorf("unexpect log function!")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
61
pkg/net/rpc/warden/recovery.go
Normal file
61
pkg/net/rpc/warden/recovery.go
Normal file
@ -0,0 +1,61 @@
|
||||
package warden
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
|
||||
"github.com/bilibili/kratos/pkg/ecode"
|
||||
"github.com/bilibili/kratos/pkg/log"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// recovery is a server interceptor that recovers from any panics.
|
||||
func (s *Server) recovery() grpc.UnaryServerInterceptor {
|
||||
return func(ctx context.Context, req interface{}, args *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
|
||||
defer func() {
|
||||
if rerr := recover(); rerr != nil {
|
||||
const size = 64 << 10
|
||||
buf := make([]byte, size)
|
||||
rs := runtime.Stack(buf, false)
|
||||
if rs > size {
|
||||
rs = size
|
||||
}
|
||||
buf = buf[:rs]
|
||||
pl := fmt.Sprintf("grpc server panic: %v\n%v\n%s\n", req, rerr, buf)
|
||||
fmt.Fprintf(os.Stderr, pl)
|
||||
log.Error(pl)
|
||||
err = status.Errorf(codes.Unknown, ecode.ServerErr.Error())
|
||||
}
|
||||
}()
|
||||
resp, err = handler(ctx, req)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// recovery return a client interceptor that recovers from any panics.
|
||||
func (c *Client) recovery() grpc.UnaryClientInterceptor {
|
||||
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) (err error) {
|
||||
defer func() {
|
||||
if rerr := recover(); rerr != nil {
|
||||
const size = 64 << 10
|
||||
buf := make([]byte, size)
|
||||
rs := runtime.Stack(buf, false)
|
||||
if rs > size {
|
||||
rs = size
|
||||
}
|
||||
buf = buf[:rs]
|
||||
pl := fmt.Sprintf("grpc client panic: %v\n%v\n%v\n%s\n", req, reply, rerr, buf)
|
||||
fmt.Fprintf(os.Stderr, pl)
|
||||
log.Error(pl)
|
||||
err = ecode.ServerErr
|
||||
}
|
||||
}()
|
||||
err = invoker(ctx, method, req, reply, cc, opts...)
|
||||
return
|
||||
}
|
||||
}
|
17
pkg/net/rpc/warden/resolver/CHANGELOG.md
Normal file
17
pkg/net/rpc/warden/resolver/CHANGELOG.md
Normal file
@ -0,0 +1,17 @@
|
||||
### business/warden/resolver
|
||||
|
||||
##### Version 1.1.1
|
||||
1. add dial helper
|
||||
|
||||
##### Version 1.1.0
|
||||
1. 增加了子集选择算法
|
||||
|
||||
##### Version 1.0.2
|
||||
1. 增加GET接口
|
||||
|
||||
##### Version 1.0.1
|
||||
1. 支持zone和clusters
|
||||
|
||||
|
||||
##### Version 1.0.0
|
||||
1. 实现了基本的服务发现功能
|
9
pkg/net/rpc/warden/resolver/OWNERS
Normal file
9
pkg/net/rpc/warden/resolver/OWNERS
Normal file
@ -0,0 +1,9 @@
|
||||
# See the OWNERS docs at https://go.k8s.io/owners
|
||||
|
||||
approvers:
|
||||
- caoguoliang
|
||||
labels:
|
||||
- library
|
||||
reviewers:
|
||||
- caoguoliang
|
||||
- maojian
|
13
pkg/net/rpc/warden/resolver/README.md
Normal file
13
pkg/net/rpc/warden/resolver/README.md
Normal file
@ -0,0 +1,13 @@
|
||||
#### business/warden/resolver
|
||||
|
||||
##### 项目简介
|
||||
|
||||
warden 的 服务发现模块,用于从底层的注册中心中获取Server节点列表并返回给GRPC
|
||||
|
||||
##### 编译环境
|
||||
|
||||
- **请只用 Golang v1.9.x 以上版本编译执行**
|
||||
|
||||
##### 依赖包
|
||||
|
||||
- [grpc](google.golang.org/grpc)
|
6
pkg/net/rpc/warden/resolver/direct/CHANGELOG.md
Normal file
6
pkg/net/rpc/warden/resolver/direct/CHANGELOG.md
Normal file
@ -0,0 +1,6 @@
|
||||
### business/warden/resolver/direct
|
||||
|
||||
|
||||
##### Version 1.0.0
|
||||
|
||||
1. 实现了基本的服务发现直连功能
|
14
pkg/net/rpc/warden/resolver/direct/README.md
Normal file
14
pkg/net/rpc/warden/resolver/direct/README.md
Normal file
@ -0,0 +1,14 @@
|
||||
#### business/warden/resolver/direct
|
||||
|
||||
##### 项目简介
|
||||
|
||||
warden 的直连服务模块,用于通过IP地址列表直接连接后端服务
|
||||
连接字符串格式: direct://default/192.168.1.1:8080,192.168.1.2:8081
|
||||
|
||||
##### 编译环境
|
||||
|
||||
- **请只用 Golang v1.9.x 以上版本编译执行**
|
||||
|
||||
##### 依赖包
|
||||
|
||||
- [grpc](google.golang.org/grpc)
|
77
pkg/net/rpc/warden/resolver/direct/direct.go
Normal file
77
pkg/net/rpc/warden/resolver/direct/direct.go
Normal file
@ -0,0 +1,77 @@
|
||||
package direct
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/bilibili/kratos/pkg/conf/env"
|
||||
"github.com/bilibili/kratos/pkg/naming"
|
||||
"github.com/bilibili/kratos/pkg/net/rpc/warden/resolver"
|
||||
)
|
||||
|
||||
const (
|
||||
// Name is the name of direct resolver
|
||||
Name = "direct"
|
||||
)
|
||||
|
||||
var _ naming.Resolver = &Direct{}
|
||||
|
||||
// New return Direct
|
||||
func New() *Direct {
|
||||
return &Direct{}
|
||||
}
|
||||
|
||||
// Build build direct.
|
||||
func Build(id string) *Direct {
|
||||
return &Direct{id: id}
|
||||
}
|
||||
|
||||
// Direct is a resolver for conneting endpoints directly.
|
||||
// example format: direct://default/192.168.1.1:8080,192.168.1.2:8081
|
||||
type Direct struct {
|
||||
id string
|
||||
}
|
||||
|
||||
// Build direct build.
|
||||
func (d *Direct) Build(id string) naming.Resolver {
|
||||
return &Direct{id: id}
|
||||
}
|
||||
|
||||
// Scheme return the Scheme of Direct
|
||||
func (d *Direct) Scheme() string {
|
||||
return Name
|
||||
}
|
||||
|
||||
// Watch a tree
|
||||
func (d *Direct) Watch() <-chan struct{} {
|
||||
ch := make(chan struct{}, 1)
|
||||
ch <- struct{}{}
|
||||
return ch
|
||||
}
|
||||
|
||||
//Unwatch a tree
|
||||
func (d *Direct) Unwatch(id string) {
|
||||
}
|
||||
|
||||
//Fetch fetch isntances
|
||||
func (d *Direct) Fetch(ctx context.Context) (insMap map[string][]*naming.Instance, found bool) {
|
||||
var ins []*naming.Instance
|
||||
|
||||
addrs := strings.Split(d.id, ",")
|
||||
for _, addr := range addrs {
|
||||
ins = append(ins, &naming.Instance{
|
||||
Addrs: []string{fmt.Sprintf("%s://%s", resolver.Scheme, addr)},
|
||||
})
|
||||
}
|
||||
if len(ins) > 0 {
|
||||
found = true
|
||||
}
|
||||
insMap = map[string][]*naming.Instance{env.Zone: ins}
|
||||
return
|
||||
}
|
||||
|
||||
//Close close Direct
|
||||
func (d *Direct) Close() error {
|
||||
return nil
|
||||
}
|
85
pkg/net/rpc/warden/resolver/direct/direct_test.go
Normal file
85
pkg/net/rpc/warden/resolver/direct/direct_test.go
Normal file
@ -0,0 +1,85 @@
|
||||
package direct
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/bilibili/kratos/pkg/net/netutil/breaker"
|
||||
"github.com/bilibili/kratos/pkg/net/rpc/warden"
|
||||
pb "github.com/bilibili/kratos/pkg/net/rpc/warden/internal/proto/testproto"
|
||||
"github.com/bilibili/kratos/pkg/net/rpc/warden/resolver"
|
||||
xtime "github.com/bilibili/kratos/pkg/time"
|
||||
)
|
||||
|
||||
type testServer struct {
|
||||
name string
|
||||
}
|
||||
|
||||
func (ts *testServer) SayHello(context.Context, *pb.HelloRequest) (*pb.HelloReply, error) {
|
||||
return &pb.HelloReply{Message: ts.name, Success: true}, nil
|
||||
}
|
||||
|
||||
func (ts *testServer) StreamHello(ss pb.Greeter_StreamHelloServer) error {
|
||||
panic("not implement error")
|
||||
}
|
||||
|
||||
func createServer(name, listen string) *warden.Server {
|
||||
s := warden.NewServer(&warden.ServerConfig{Timeout: xtime.Duration(time.Second)})
|
||||
ts := &testServer{name}
|
||||
pb.RegisterGreeterServer(s.Server(), ts)
|
||||
go func() {
|
||||
if err := s.Run(listen); err != nil {
|
||||
panic(fmt.Sprintf("run warden server fail! err: %s", err))
|
||||
}
|
||||
}()
|
||||
return s
|
||||
}
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
resolver.Register(New())
|
||||
ctx := context.TODO()
|
||||
s1 := createServer("server1", "127.0.0.1:18081")
|
||||
s2 := createServer("server2", "127.0.0.1:18082")
|
||||
defer s1.Shutdown(ctx)
|
||||
defer s2.Shutdown(ctx)
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
func createTestClient(t *testing.T, connStr string) pb.GreeterClient {
|
||||
client := warden.NewClient(&warden.ClientConfig{
|
||||
Dial: xtime.Duration(time.Second * 10),
|
||||
Timeout: xtime.Duration(time.Second * 10),
|
||||
Breaker: &breaker.Config{
|
||||
Window: xtime.Duration(3 * time.Second),
|
||||
Sleep: xtime.Duration(3 * time.Second),
|
||||
Bucket: 10,
|
||||
Ratio: 0.3,
|
||||
Request: 20,
|
||||
},
|
||||
})
|
||||
conn, err := client.Dial(context.TODO(), connStr)
|
||||
if err != nil {
|
||||
t.Fatalf("create client fail!err%s", err)
|
||||
}
|
||||
return pb.NewGreeterClient(conn)
|
||||
}
|
||||
|
||||
func TestDirect(t *testing.T) {
|
||||
cli := createTestClient(t, "direct://default/127.0.0.1:18083,127.0.0.1:18082")
|
||||
count := 0
|
||||
for i := 0; i < 10; i++ {
|
||||
if resp, err := cli.SayHello(context.TODO(), &pb.HelloRequest{Age: 1, Name: "hello"}); err != nil {
|
||||
t.Fatalf("TestDirect: SayHello failed!err:=%v", err)
|
||||
} else {
|
||||
if resp.Message == "server2" {
|
||||
count++
|
||||
}
|
||||
}
|
||||
}
|
||||
if count != 10 {
|
||||
t.Fatalf("TestDirect: get server2 times must be 10")
|
||||
}
|
||||
}
|
204
pkg/net/rpc/warden/resolver/resolver.go
Normal file
204
pkg/net/rpc/warden/resolver/resolver.go
Normal file
@ -0,0 +1,204 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net/url"
|
||||
"os"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/bilibili/kratos/pkg/conf/env"
|
||||
"github.com/bilibili/kratos/pkg/log"
|
||||
"github.com/bilibili/kratos/pkg/naming"
|
||||
wmeta "github.com/bilibili/kratos/pkg/net/rpc/warden/internal/metadata"
|
||||
|
||||
"github.com/dgryski/go-farm"
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/grpc/resolver"
|
||||
)
|
||||
|
||||
const (
|
||||
// Scheme is the scheme of discovery address
|
||||
Scheme = "grpc"
|
||||
)
|
||||
|
||||
var (
|
||||
_ resolver.Resolver = &Resolver{}
|
||||
_ resolver.Builder = &Builder{}
|
||||
mu sync.Mutex
|
||||
)
|
||||
|
||||
// Register register resolver builder if nil.
|
||||
func Register(b naming.Builder) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if resolver.Get(b.Scheme()) == nil {
|
||||
resolver.Register(&Builder{b})
|
||||
}
|
||||
}
|
||||
|
||||
// Set override any registered builder
|
||||
func Set(b naming.Builder) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
resolver.Register(&Builder{b})
|
||||
}
|
||||
|
||||
// Builder is also a resolver builder.
|
||||
// It's build() function always returns itself.
|
||||
type Builder struct {
|
||||
naming.Builder
|
||||
}
|
||||
|
||||
// Build returns itself for Resolver, because it's both a builder and a resolver.
|
||||
func (b *Builder) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOption) (resolver.Resolver, error) {
|
||||
var zone = env.Zone
|
||||
ss := int64(50)
|
||||
clusters := map[string]struct{}{}
|
||||
str := strings.SplitN(target.Endpoint, "?", 2)
|
||||
if len(str) == 0 {
|
||||
return nil, errors.Errorf("warden resolver: parse target.Endpoint(%s) failed!err:=endpoint is empty", target.Endpoint)
|
||||
} else if len(str) == 2 {
|
||||
m, err := url.ParseQuery(str[1])
|
||||
if err == nil {
|
||||
for _, c := range m[naming.MetaCluster] {
|
||||
clusters[c] = struct{}{}
|
||||
}
|
||||
zones := m[naming.MetaZone]
|
||||
if len(zones) > 0 {
|
||||
zone = zones[0]
|
||||
}
|
||||
if sub, ok := m["subset"]; ok {
|
||||
if t, err := strconv.ParseInt(sub[0], 10, 64); err == nil {
|
||||
ss = t
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
r := &Resolver{
|
||||
nr: b.Builder.Build(str[0]),
|
||||
cc: cc,
|
||||
quit: make(chan struct{}, 1),
|
||||
clusters: clusters,
|
||||
zone: zone,
|
||||
subsetSize: ss,
|
||||
}
|
||||
go r.updateproc()
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// Resolver watches for the updates on the specified target.
|
||||
// Updates include address updates and service config updates.
|
||||
type Resolver struct {
|
||||
nr naming.Resolver
|
||||
cc resolver.ClientConn
|
||||
quit chan struct{}
|
||||
|
||||
clusters map[string]struct{}
|
||||
zone string
|
||||
subsetSize int64
|
||||
}
|
||||
|
||||
// Close is a noop for Resolver.
|
||||
func (r *Resolver) Close() {
|
||||
select {
|
||||
case r.quit <- struct{}{}:
|
||||
r.nr.Close()
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// ResolveNow is a noop for Resolver.
|
||||
func (r *Resolver) ResolveNow(o resolver.ResolveNowOption) {
|
||||
}
|
||||
|
||||
func (r *Resolver) updateproc() {
|
||||
event := r.nr.Watch()
|
||||
for {
|
||||
select {
|
||||
case <-r.quit:
|
||||
return
|
||||
case _, ok := <-event:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
}
|
||||
if insInfo, ok := r.nr.Fetch(context.Background()); ok {
|
||||
instances, ok := insInfo.Instances[r.zone]
|
||||
if !ok {
|
||||
for _, value := range insInfo.Instances {
|
||||
instances = append(instances, value...)
|
||||
}
|
||||
}
|
||||
if r.subsetSize > 0 && len(instances) > 0 {
|
||||
instances = r.subset(instances, env.Hostname, r.subsetSize)
|
||||
}
|
||||
r.newAddress(instances)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Resolver) subset(backends []*naming.Instance, clientID string, size int64) []*naming.Instance {
|
||||
if len(backends) <= int(size) {
|
||||
return backends
|
||||
}
|
||||
sort.Slice(backends, func(i, j int) bool {
|
||||
return backends[i].Hostname < backends[j].Hostname
|
||||
})
|
||||
count := int64(len(backends)) / size
|
||||
|
||||
id := farm.Fingerprint64([]byte(clientID))
|
||||
round := int64(id / uint64(count))
|
||||
|
||||
s := rand.NewSource(round)
|
||||
ra := rand.New(s)
|
||||
ra.Shuffle(len(backends), func(i, j int) {
|
||||
backends[i], backends[j] = backends[j], backends[i]
|
||||
})
|
||||
start := (id % uint64(count)) * uint64(size)
|
||||
return backends[int(start) : int(start)+int(size)]
|
||||
}
|
||||
|
||||
func (r *Resolver) newAddress(instances []*naming.Instance) {
|
||||
if len(instances) <= 0 {
|
||||
return
|
||||
}
|
||||
addrs := make([]resolver.Address, 0, len(instances))
|
||||
for _, ins := range instances {
|
||||
if len(r.clusters) > 0 {
|
||||
if _, ok := r.clusters[ins.Metadata[naming.MetaCluster]]; !ok {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
var weight int64
|
||||
if weight, _ = strconv.ParseInt(ins.Metadata[naming.MetaWeight], 10, 64); weight <= 0 {
|
||||
weight = 10
|
||||
}
|
||||
var rpc string
|
||||
for _, a := range ins.Addrs {
|
||||
u, err := url.Parse(a)
|
||||
if err == nil && u.Scheme == Scheme {
|
||||
rpc = u.Host
|
||||
}
|
||||
}
|
||||
if rpc == "" {
|
||||
fmt.Fprintf(os.Stderr, "warden/resolver: app(%s,%s) no valid grpc address(%v) found!", ins.AppID, ins.Hostname, ins.Addrs)
|
||||
log.Warn("warden/resolver: invalid rpc address(%s,%s,%v) found!", ins.AppID, ins.Hostname, ins.Addrs)
|
||||
continue
|
||||
}
|
||||
addr := resolver.Address{
|
||||
Addr: rpc,
|
||||
Type: resolver.Backend,
|
||||
ServerName: ins.AppID,
|
||||
Metadata: wmeta.MD{Weight: uint64(weight), Color: ins.Metadata[naming.MetaColor]},
|
||||
}
|
||||
addrs = append(addrs, addr)
|
||||
}
|
||||
r.cc.NewAddress(addrs)
|
||||
}
|
87
pkg/net/rpc/warden/resolver/test/mockdiscovery.go
Normal file
87
pkg/net/rpc/warden/resolver/test/mockdiscovery.go
Normal file
@ -0,0 +1,87 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/bilibili/kratos/pkg/conf/env"
|
||||
"github.com/bilibili/kratos/pkg/naming"
|
||||
)
|
||||
|
||||
type mockDiscoveryBuilder struct {
|
||||
instances map[string]*naming.Instance
|
||||
watchch map[string][]*mockDiscoveryResolver
|
||||
}
|
||||
|
||||
func (mb *mockDiscoveryBuilder) Build(id string) naming.Resolver {
|
||||
mr := &mockDiscoveryResolver{
|
||||
d: mb,
|
||||
watchch: make(chan struct{}, 1),
|
||||
}
|
||||
mb.watchch[id] = append(mb.watchch[id], mr)
|
||||
mr.watchch <- struct{}{}
|
||||
return mr
|
||||
}
|
||||
func (mb *mockDiscoveryBuilder) Scheme() string {
|
||||
return "mockdiscovery"
|
||||
}
|
||||
|
||||
type mockDiscoveryResolver struct {
|
||||
//instances map[string]*naming.Instance
|
||||
d *mockDiscoveryBuilder
|
||||
watchch chan struct{}
|
||||
}
|
||||
|
||||
var _ naming.Resolver = &mockDiscoveryResolver{}
|
||||
|
||||
func (md *mockDiscoveryResolver) Fetch(ctx context.Context) (map[string][]*naming.Instance, bool) {
|
||||
zones := make(map[string][]*naming.Instance)
|
||||
for _, v := range md.d.instances {
|
||||
zones[v.Zone] = append(zones[v.Zone], v)
|
||||
}
|
||||
return zones, len(zones) > 0
|
||||
}
|
||||
|
||||
func (md *mockDiscoveryResolver) Watch() <-chan struct{} {
|
||||
return md.watchch
|
||||
}
|
||||
|
||||
func (md *mockDiscoveryResolver) Close() error {
|
||||
close(md.watchch)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (md *mockDiscoveryResolver) Scheme() string {
|
||||
return "mockdiscovery"
|
||||
}
|
||||
|
||||
func (mb *mockDiscoveryBuilder) registry(appID string, hostname, rpc string, metadata map[string]string) {
|
||||
ins := &naming.Instance{
|
||||
AppID: appID,
|
||||
Env: "hello=world",
|
||||
Hostname: hostname,
|
||||
Addrs: []string{"grpc://" + rpc},
|
||||
Version: "1.1",
|
||||
Zone: env.Zone,
|
||||
Metadata: metadata,
|
||||
}
|
||||
mb.instances[hostname] = ins
|
||||
if ch, ok := mb.watchch[appID]; ok {
|
||||
var bullet struct{}
|
||||
for _, c := range ch {
|
||||
c.watchch <- bullet
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (mb *mockDiscoveryBuilder) cancel(hostname string) {
|
||||
ins, ok := mb.instances[hostname]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
delete(mb.instances, hostname)
|
||||
if ch, ok := mb.watchch[ins.AppID]; ok {
|
||||
var bullet struct{}
|
||||
for _, c := range ch {
|
||||
c.watchch <- bullet
|
||||
}
|
||||
}
|
||||
}
|
312
pkg/net/rpc/warden/resolver/test/resovler_test.go
Normal file
312
pkg/net/rpc/warden/resolver/test/resovler_test.go
Normal file
@ -0,0 +1,312 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/bilibili/kratos/pkg/conf/env"
|
||||
"github.com/bilibili/kratos/pkg/naming"
|
||||
"github.com/bilibili/kratos/pkg/net/netutil/breaker"
|
||||
"github.com/bilibili/kratos/pkg/net/rpc/warden"
|
||||
pb "github.com/bilibili/kratos/pkg/net/rpc/warden/internal/proto/testproto"
|
||||
"github.com/bilibili/kratos/pkg/net/rpc/warden/resolver"
|
||||
xtime "github.com/bilibili/kratos/pkg/time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var testServerMap map[string]*testServer
|
||||
|
||||
func init() {
|
||||
testServerMap = make(map[string]*testServer)
|
||||
}
|
||||
|
||||
const testAppID = "main.test"
|
||||
|
||||
type testServer struct {
|
||||
SayHelloCount int
|
||||
}
|
||||
|
||||
func resetCount() {
|
||||
for _, s := range testServerMap {
|
||||
s.SayHelloCount = 0
|
||||
}
|
||||
}
|
||||
|
||||
func (ts *testServer) SayHello(context.Context, *pb.HelloRequest) (*pb.HelloReply, error) {
|
||||
ts.SayHelloCount++
|
||||
return &pb.HelloReply{Message: "hello", Success: true}, nil
|
||||
}
|
||||
|
||||
func (ts *testServer) StreamHello(ss pb.Greeter_StreamHelloServer) error {
|
||||
panic("not implement error")
|
||||
}
|
||||
|
||||
func createServer(name, listen string) *warden.Server {
|
||||
s := warden.NewServer(&warden.ServerConfig{Timeout: xtime.Duration(time.Second)})
|
||||
ts := &testServer{}
|
||||
testServerMap[name] = ts
|
||||
pb.RegisterGreeterServer(s.Server(), ts)
|
||||
go func() {
|
||||
if err := s.Run(listen); err != nil {
|
||||
panic(fmt.Sprintf("run warden server fail! err: %s", err))
|
||||
}
|
||||
}()
|
||||
return s
|
||||
}
|
||||
|
||||
func NSayHello(c pb.GreeterClient, n int) func(*testing.T) {
|
||||
return func(t *testing.T) {
|
||||
for i := 0; i < n; i++ {
|
||||
if _, err := c.SayHello(context.TODO(), &pb.HelloRequest{Age: 1, Name: "hello"}); err != nil {
|
||||
t.Fatalf("call sayhello fail! err: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func createTestClient(t *testing.T) pb.GreeterClient {
|
||||
client := warden.NewClient(&warden.ClientConfig{
|
||||
Dial: xtime.Duration(time.Second * 10),
|
||||
Timeout: xtime.Duration(time.Second * 10),
|
||||
Breaker: &breaker.Config{
|
||||
Window: xtime.Duration(3 * time.Second),
|
||||
Sleep: xtime.Duration(3 * time.Second),
|
||||
Bucket: 10,
|
||||
Ratio: 0.3,
|
||||
Request: 20,
|
||||
},
|
||||
})
|
||||
conn, err := client.Dial(context.TODO(), "mockdiscovery://authority/main.test")
|
||||
if err != nil {
|
||||
t.Fatalf("create client fail!err%s", err)
|
||||
}
|
||||
return pb.NewGreeterClient(conn)
|
||||
}
|
||||
|
||||
var mockResolver *mockDiscoveryBuilder
|
||||
|
||||
func newMockDiscoveryBuilder() *mockDiscoveryBuilder {
|
||||
return &mockDiscoveryBuilder{
|
||||
instances: make(map[string]*naming.Instance),
|
||||
watchch: make(map[string][]*mockDiscoveryResolver),
|
||||
}
|
||||
}
|
||||
func TestMain(m *testing.M) {
|
||||
ctx := context.TODO()
|
||||
mockResolver = newMockDiscoveryBuilder()
|
||||
resolver.Set(mockResolver)
|
||||
s1 := createServer("server1", "127.0.0.1:18081")
|
||||
s2 := createServer("server2", "127.0.0.1:18082")
|
||||
s3 := createServer("server3", "127.0.0.1:18083")
|
||||
s4 := createServer("server4", "127.0.0.1:18084")
|
||||
s5 := createServer("server5", "127.0.0.1:18085")
|
||||
defer s1.Shutdown(ctx)
|
||||
defer s2.Shutdown(ctx)
|
||||
defer s3.Shutdown(ctx)
|
||||
defer s4.Shutdown(ctx)
|
||||
defer s5.Shutdown(ctx)
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
func TestAddResolver(t *testing.T) {
|
||||
mockResolver.registry(testAppID, "server1", "127.0.0.1:18081", map[string]string{})
|
||||
c := createTestClient(t)
|
||||
t.Run("test_say_hello", NSayHello(c, 10))
|
||||
assert.Equal(t, 10, testServerMap["server1"].SayHelloCount)
|
||||
resetCount()
|
||||
}
|
||||
|
||||
func TestDeleteResolver(t *testing.T) {
|
||||
mockResolver.registry(testAppID, "server1", "127.0.0.1:18081", map[string]string{})
|
||||
mockResolver.registry(testAppID, "server2", "127.0.0.1:18082", map[string]string{})
|
||||
c := createTestClient(t)
|
||||
t.Run("test_say_hello", NSayHello(c, 10))
|
||||
assert.Equal(t, 10, testServerMap["server1"].SayHelloCount+testServerMap["server2"].SayHelloCount)
|
||||
|
||||
mockResolver.cancel("server1")
|
||||
resetCount()
|
||||
time.Sleep(time.Millisecond * 10)
|
||||
t.Run("test_say_hello", NSayHello(c, 10))
|
||||
assert.Equal(t, 0, testServerMap["server1"].SayHelloCount)
|
||||
|
||||
resetCount()
|
||||
}
|
||||
|
||||
func TestUpdateResolver(t *testing.T) {
|
||||
mockResolver.registry(testAppID, "server1", "127.0.0.1:18081", map[string]string{})
|
||||
mockResolver.registry(testAppID, "server2", "127.0.0.1:18082", map[string]string{})
|
||||
|
||||
c := createTestClient(t)
|
||||
t.Run("test_say_hello", NSayHello(c, 10))
|
||||
assert.Equal(t, 10, testServerMap["server1"].SayHelloCount+testServerMap["server2"].SayHelloCount)
|
||||
|
||||
mockResolver.registry(testAppID, "server1", "127.0.0.1:18083", map[string]string{})
|
||||
mockResolver.registry(testAppID, "server2", "127.0.0.1:18084", map[string]string{})
|
||||
resetCount()
|
||||
time.Sleep(time.Millisecond * 10)
|
||||
t.Run("test_say_hello", NSayHello(c, 10))
|
||||
assert.Equal(t, 0, testServerMap["server1"].SayHelloCount+testServerMap["server2"].SayHelloCount)
|
||||
assert.Equal(t, 10, testServerMap["server3"].SayHelloCount+testServerMap["server4"].SayHelloCount)
|
||||
|
||||
resetCount()
|
||||
}
|
||||
|
||||
func TestErrorResolver(t *testing.T) {
|
||||
mockResolver := newMockDiscoveryBuilder()
|
||||
resolver.Set(mockResolver)
|
||||
mockResolver.registry(testAppID, "server1", "127.0.0.1:18081", map[string]string{})
|
||||
mockResolver.registry(testAppID, "server6", "127.0.0.1:18086", map[string]string{})
|
||||
|
||||
c := createTestClient(t)
|
||||
t.Run("test_say_hello", NSayHello(c, 10))
|
||||
assert.Equal(t, 10, testServerMap["server1"].SayHelloCount)
|
||||
|
||||
resetCount()
|
||||
}
|
||||
|
||||
func TestClusterResolver(t *testing.T) {
|
||||
mockResolver := newMockDiscoveryBuilder()
|
||||
resolver.Set(mockResolver)
|
||||
mockResolver.registry(testAppID, "server1", "127.0.0.1:18081", map[string]string{"cluster": "c1"})
|
||||
mockResolver.registry(testAppID, "server2", "127.0.0.1:18082", map[string]string{"cluster": "c1"})
|
||||
mockResolver.registry(testAppID, "server3", "127.0.0.1:18083", map[string]string{"cluster": "c2"})
|
||||
mockResolver.registry(testAppID, "server4", "127.0.0.1:18084", map[string]string{})
|
||||
mockResolver.registry(testAppID, "server5", "127.0.0.1:18084", map[string]string{})
|
||||
|
||||
client := warden.NewClient(&warden.ClientConfig{Clusters: []string{"c1"}})
|
||||
conn, err := client.Dial(context.TODO(), "mockdiscovery://authority/main.test?cluster=c2")
|
||||
if err != nil {
|
||||
t.Fatalf("create client fail!err%s", err)
|
||||
}
|
||||
time.Sleep(time.Millisecond * 10)
|
||||
cli := pb.NewGreeterClient(conn)
|
||||
if _, err := cli.SayHello(context.TODO(), &pb.HelloRequest{Age: 1, Name: "hello"}); err != nil {
|
||||
t.Fatalf("call sayhello fail! err: %s", err)
|
||||
}
|
||||
if _, err := cli.SayHello(context.TODO(), &pb.HelloRequest{Age: 1, Name: "hello"}); err != nil {
|
||||
t.Fatalf("call sayhello fail! err: %s", err)
|
||||
}
|
||||
if _, err := cli.SayHello(context.TODO(), &pb.HelloRequest{Age: 1, Name: "hello"}); err != nil {
|
||||
t.Fatalf("call sayhello fail! err: %s", err)
|
||||
}
|
||||
assert.Equal(t, 1, testServerMap["server1"].SayHelloCount)
|
||||
assert.Equal(t, 1, testServerMap["server2"].SayHelloCount)
|
||||
assert.Equal(t, 1, testServerMap["server3"].SayHelloCount)
|
||||
|
||||
resetCount()
|
||||
}
|
||||
|
||||
func TestNoClusterResolver(t *testing.T) {
|
||||
mockResolver := newMockDiscoveryBuilder()
|
||||
resolver.Set(mockResolver)
|
||||
mockResolver.registry(testAppID, "server1", "127.0.0.1:18081", map[string]string{"cluster": "c1"})
|
||||
mockResolver.registry(testAppID, "server2", "127.0.0.1:18082", map[string]string{"cluster": "c1"})
|
||||
mockResolver.registry(testAppID, "server3", "127.0.0.1:18083", map[string]string{"cluster": "c2"})
|
||||
mockResolver.registry(testAppID, "server4", "127.0.0.1:18084", map[string]string{})
|
||||
client := warden.NewClient(&warden.ClientConfig{})
|
||||
conn, err := client.Dial(context.TODO(), "mockdiscovery://authority/main.test")
|
||||
if err != nil {
|
||||
t.Fatalf("create client fail!err%s", err)
|
||||
}
|
||||
time.Sleep(time.Millisecond * 20)
|
||||
cli := pb.NewGreeterClient(conn)
|
||||
if _, err := cli.SayHello(context.TODO(), &pb.HelloRequest{Age: 1, Name: "hello"}); err != nil {
|
||||
t.Fatalf("call sayhello fail! err: %s", err)
|
||||
}
|
||||
if _, err := cli.SayHello(context.TODO(), &pb.HelloRequest{Age: 1, Name: "hello"}); err != nil {
|
||||
t.Fatalf("call sayhello fail! err: %s", err)
|
||||
}
|
||||
if _, err := cli.SayHello(context.TODO(), &pb.HelloRequest{Age: 1, Name: "hello"}); err != nil {
|
||||
t.Fatalf("call sayhello fail! err: %s", err)
|
||||
}
|
||||
if _, err := cli.SayHello(context.TODO(), &pb.HelloRequest{Age: 1, Name: "hello"}); err != nil {
|
||||
t.Fatalf("call sayhello fail! err: %s", err)
|
||||
}
|
||||
assert.Equal(t, 1, testServerMap["server1"].SayHelloCount)
|
||||
assert.Equal(t, 1, testServerMap["server2"].SayHelloCount)
|
||||
assert.Equal(t, 1, testServerMap["server3"].SayHelloCount)
|
||||
assert.Equal(t, 1, testServerMap["server4"].SayHelloCount)
|
||||
|
||||
resetCount()
|
||||
}
|
||||
|
||||
func TestZoneResolver(t *testing.T) {
|
||||
mockResolver := newMockDiscoveryBuilder()
|
||||
resolver.Set(mockResolver)
|
||||
mockResolver.registry(testAppID, "server1", "127.0.0.1:18081", map[string]string{})
|
||||
env.Zone = "testsh"
|
||||
mockResolver.registry(testAppID, "server2", "127.0.0.1:18082", map[string]string{})
|
||||
env.Zone = "hhhh"
|
||||
client := warden.NewClient(&warden.ClientConfig{Zone: "testsh"})
|
||||
conn, err := client.Dial(context.TODO(), "mockdiscovery://authority/main.test")
|
||||
if err != nil {
|
||||
t.Fatalf("create client fail!err%s", err)
|
||||
}
|
||||
time.Sleep(time.Millisecond * 10)
|
||||
cli := pb.NewGreeterClient(conn)
|
||||
if _, err := cli.SayHello(context.TODO(), &pb.HelloRequest{Age: 1, Name: "hello"}); err != nil {
|
||||
t.Fatalf("call sayhello fail! err: %s", err)
|
||||
}
|
||||
if _, err := cli.SayHello(context.TODO(), &pb.HelloRequest{Age: 1, Name: "hello"}); err != nil {
|
||||
t.Fatalf("call sayhello fail! err: %s", err)
|
||||
}
|
||||
if _, err := cli.SayHello(context.TODO(), &pb.HelloRequest{Age: 1, Name: "hello"}); err != nil {
|
||||
t.Fatalf("call sayhello fail! err: %s", err)
|
||||
}
|
||||
assert.Equal(t, 0, testServerMap["server1"].SayHelloCount)
|
||||
assert.Equal(t, 3, testServerMap["server2"].SayHelloCount)
|
||||
|
||||
resetCount()
|
||||
}
|
||||
|
||||
func TestSubsetConn(t *testing.T) {
|
||||
mockResolver := newMockDiscoveryBuilder()
|
||||
resolver.Set(mockResolver)
|
||||
mockResolver.registry(testAppID, "server1", "127.0.0.1:18081", map[string]string{})
|
||||
mockResolver.registry(testAppID, "server2", "127.0.0.1:18082", map[string]string{})
|
||||
mockResolver.registry(testAppID, "server3", "127.0.0.1:18083", map[string]string{})
|
||||
mockResolver.registry(testAppID, "server4", "127.0.0.1:18084", map[string]string{})
|
||||
mockResolver.registry(testAppID, "server5", "127.0.0.1:18085", map[string]string{})
|
||||
|
||||
client := warden.NewClient(nil)
|
||||
conn, err := client.Dial(context.TODO(), "mockdiscovery://authority/main.test?subset=3")
|
||||
if err != nil {
|
||||
t.Fatalf("create client fail!err%s", err)
|
||||
}
|
||||
time.Sleep(time.Millisecond * 20)
|
||||
cli := pb.NewGreeterClient(conn)
|
||||
for i := 0; i < 6; i++ {
|
||||
if _, err := cli.SayHello(context.TODO(), &pb.HelloRequest{Age: 1, Name: "hello"}); err != nil {
|
||||
t.Fatalf("call sayhello fail! err: %s", err)
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 2, testServerMap["server2"].SayHelloCount)
|
||||
assert.Equal(t, 2, testServerMap["server5"].SayHelloCount)
|
||||
assert.Equal(t, 2, testServerMap["server4"].SayHelloCount)
|
||||
resetCount()
|
||||
mockResolver.cancel("server4")
|
||||
time.Sleep(time.Millisecond * 20)
|
||||
for i := 0; i < 6; i++ {
|
||||
if _, err := cli.SayHello(context.TODO(), &pb.HelloRequest{Age: 1, Name: "hello"}); err != nil {
|
||||
t.Fatalf("call sayhello fail! err: %s", err)
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 2, testServerMap["server5"].SayHelloCount)
|
||||
assert.Equal(t, 2, testServerMap["server2"].SayHelloCount)
|
||||
assert.Equal(t, 2, testServerMap["server3"].SayHelloCount)
|
||||
resetCount()
|
||||
mockResolver.registry(testAppID, "server4", "127.0.0.1:18084", map[string]string{})
|
||||
time.Sleep(time.Millisecond * 20)
|
||||
for i := 0; i < 6; i++ {
|
||||
if _, err := cli.SayHello(context.TODO(), &pb.HelloRequest{Age: 1, Name: "hello"}); err != nil {
|
||||
t.Fatalf("call sayhello fail! err: %s", err)
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 2, testServerMap["server2"].SayHelloCount)
|
||||
assert.Equal(t, 2, testServerMap["server5"].SayHelloCount)
|
||||
assert.Equal(t, 2, testServerMap["server4"].SayHelloCount)
|
||||
}
|
16
pkg/net/rpc/warden/resolver/util.go
Normal file
16
pkg/net/rpc/warden/resolver/util.go
Normal file
@ -0,0 +1,16 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// RegisterTarget will register grpc discovery mock address flag
|
||||
func RegisterTarget(target *string, discoveryID string) {
|
||||
flag.CommandLine.StringVar(
|
||||
target,
|
||||
fmt.Sprintf("grpc.%s", discoveryID),
|
||||
fmt.Sprintf("discovery://default/%s", discoveryID),
|
||||
fmt.Sprintf("App's grpc target.\n example: -grpc.%s=\"127.0.0.1:9090\"", discoveryID),
|
||||
)
|
||||
}
|
332
pkg/net/rpc/warden/server.go
Normal file
332
pkg/net/rpc/warden/server.go
Normal file
@ -0,0 +1,332 @@
|
||||
package warden
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"math"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/bilibili/kratos/pkg/conf/dsn"
|
||||
"github.com/bilibili/kratos/pkg/log"
|
||||
nmd "github.com/bilibili/kratos/pkg/net/metadata"
|
||||
"github.com/bilibili/kratos/pkg/net/trace"
|
||||
xtime "github.com/bilibili/kratos/pkg/time"
|
||||
|
||||
//this package is for json format response
|
||||
_ "github.com/bilibili/kratos/pkg/net/rpc/warden/internal/encoding/json"
|
||||
"github.com/bilibili/kratos/pkg/net/rpc/warden/internal/status"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/peer"
|
||||
"google.golang.org/grpc/reflection"
|
||||
)
|
||||
|
||||
var (
|
||||
_grpcDSN string
|
||||
_defaultSerConf = &ServerConfig{
|
||||
Network: "tcp",
|
||||
Addr: "0.0.0.0:9000",
|
||||
Timeout: xtime.Duration(time.Second),
|
||||
IdleTimeout: xtime.Duration(time.Second * 60),
|
||||
MaxLifeTime: xtime.Duration(time.Hour * 2),
|
||||
ForceCloseWait: xtime.Duration(time.Second * 20),
|
||||
KeepAliveInterval: xtime.Duration(time.Second * 60),
|
||||
KeepAliveTimeout: xtime.Duration(time.Second * 20),
|
||||
}
|
||||
_abortIndex int8 = math.MaxInt8 / 2
|
||||
)
|
||||
|
||||
// ServerConfig is rpc server conf.
|
||||
type ServerConfig struct {
|
||||
// Network is grpc listen network,default value is tcp
|
||||
Network string `dsn:"network"`
|
||||
// Addr is grpc listen addr,default value is 0.0.0.0:9000
|
||||
Addr string `dsn:"address"`
|
||||
// Timeout is context timeout for per rpc call.
|
||||
Timeout xtime.Duration `dsn:"query.timeout"`
|
||||
// IdleTimeout is a duration for the amount of time after which an idle connection would be closed by sending a GoAway.
|
||||
// Idleness duration is defined since the most recent time the number of outstanding RPCs became zero or the connection establishment.
|
||||
IdleTimeout xtime.Duration `dsn:"query.idleTimeout"`
|
||||
// MaxLifeTime is a duration for the maximum amount of time a connection may exist before it will be closed by sending a GoAway.
|
||||
// A random jitter of +/-10% will be added to MaxConnectionAge to spread out connection storms.
|
||||
MaxLifeTime xtime.Duration `dsn:"query.maxLife"`
|
||||
// ForceCloseWait is an additive period after MaxLifeTime after which the connection will be forcibly closed.
|
||||
ForceCloseWait xtime.Duration `dsn:"query.closeWait"`
|
||||
// KeepAliveInterval is after a duration of this time if the server doesn't see any activity it pings the client to see if the transport is still alive.
|
||||
KeepAliveInterval xtime.Duration `dsn:"query.keepaliveInterval"`
|
||||
// KeepAliveTimeout is After having pinged for keepalive check, the server waits for a duration of Timeout and if no activity is seen even after that
|
||||
// the connection is closed.
|
||||
KeepAliveTimeout xtime.Duration `dsn:"query.keepaliveTimeout"`
|
||||
}
|
||||
|
||||
// Server is the framework's server side instance, it contains the GrpcServer, interceptor and interceptors.
|
||||
// Create an instance of Server, by using NewServer().
|
||||
type Server struct {
|
||||
conf *ServerConfig
|
||||
mutex sync.RWMutex
|
||||
|
||||
server *grpc.Server
|
||||
handlers []grpc.UnaryServerInterceptor
|
||||
}
|
||||
|
||||
// handle return a new unary server interceptor for OpenTracing\Logging\LinkTimeout.
|
||||
func (s *Server) handle() grpc.UnaryServerInterceptor {
|
||||
return func(ctx context.Context, req interface{}, args *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
|
||||
var (
|
||||
cancel func()
|
||||
addr string
|
||||
)
|
||||
s.mutex.RLock()
|
||||
conf := s.conf
|
||||
s.mutex.RUnlock()
|
||||
// get derived timeout from grpc context,
|
||||
// compare with the warden configured,
|
||||
// and use the minimum one
|
||||
timeout := time.Duration(conf.Timeout)
|
||||
if dl, ok := ctx.Deadline(); ok {
|
||||
ctimeout := time.Until(dl)
|
||||
if ctimeout-time.Millisecond*20 > 0 {
|
||||
ctimeout = ctimeout - time.Millisecond*20
|
||||
}
|
||||
if timeout > ctimeout {
|
||||
timeout = ctimeout
|
||||
}
|
||||
}
|
||||
ctx, cancel = context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
// get grpc metadata(trace & remote_ip & color)
|
||||
var t trace.Trace
|
||||
cmd := nmd.MD{}
|
||||
if gmd, ok := metadata.FromIncomingContext(ctx); ok {
|
||||
for key, vals := range gmd {
|
||||
if nmd.IsIncomingKey(key) {
|
||||
cmd[key] = vals[0]
|
||||
}
|
||||
}
|
||||
}
|
||||
if t == nil {
|
||||
t = trace.New(args.FullMethod)
|
||||
} else {
|
||||
t.SetTitle(args.FullMethod)
|
||||
}
|
||||
|
||||
if pr, ok := peer.FromContext(ctx); ok {
|
||||
addr = pr.Addr.String()
|
||||
t.SetTag(trace.String(trace.TagAddress, addr))
|
||||
}
|
||||
defer t.Finish(&err)
|
||||
|
||||
// use common meta data context instead of grpc context
|
||||
ctx = nmd.NewContext(ctx, cmd)
|
||||
ctx = trace.NewContext(ctx, t)
|
||||
|
||||
resp, err = handler(ctx, req)
|
||||
return resp, status.FromError(err).Err()
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
addFlag(flag.CommandLine)
|
||||
}
|
||||
|
||||
func addFlag(fs *flag.FlagSet) {
|
||||
v := os.Getenv("GRPC")
|
||||
if v == "" {
|
||||
v = "tcp://0.0.0.0:9000/?timeout=1s&idle_timeout=60s"
|
||||
}
|
||||
fs.StringVar(&_grpcDSN, "grpc", v, "listen grpc dsn, or use GRPC env variable.")
|
||||
fs.Var(&_grpcTarget, "grpc.target", "usage: -grpc.target=seq.service=127.0.0.1:9000 -grpc.target=fav.service=192.168.10.1:9000")
|
||||
}
|
||||
|
||||
func parseDSN(rawdsn string) *ServerConfig {
|
||||
conf := new(ServerConfig)
|
||||
d, err := dsn.Parse(rawdsn)
|
||||
if err != nil {
|
||||
panic(errors.WithMessage(err, fmt.Sprintf("warden: invalid dsn: %s", rawdsn)))
|
||||
}
|
||||
if _, err = d.Bind(conf); err != nil {
|
||||
panic(errors.WithMessage(err, fmt.Sprintf("warden: invalid dsn: %s", rawdsn)))
|
||||
}
|
||||
return conf
|
||||
}
|
||||
|
||||
// NewServer returns a new blank Server instance with a default server interceptor.
|
||||
func NewServer(conf *ServerConfig, opt ...grpc.ServerOption) (s *Server) {
|
||||
if conf == nil {
|
||||
if !flag.Parsed() {
|
||||
fmt.Fprint(os.Stderr, "[warden] please call flag.Parse() before Init warden server, some configure may not effect\n")
|
||||
}
|
||||
conf = parseDSN(_grpcDSN)
|
||||
}
|
||||
s = new(Server)
|
||||
if err := s.SetConfig(conf); err != nil {
|
||||
panic(errors.Errorf("warden: set config failed!err: %s", err.Error()))
|
||||
}
|
||||
keepParam := grpc.KeepaliveParams(keepalive.ServerParameters{
|
||||
MaxConnectionIdle: time.Duration(s.conf.IdleTimeout),
|
||||
MaxConnectionAgeGrace: time.Duration(s.conf.ForceCloseWait),
|
||||
Time: time.Duration(s.conf.KeepAliveInterval),
|
||||
Timeout: time.Duration(s.conf.KeepAliveTimeout),
|
||||
MaxConnectionAge: time.Duration(s.conf.MaxLifeTime),
|
||||
})
|
||||
opt = append(opt, keepParam, grpc.UnaryInterceptor(s.interceptor))
|
||||
s.server = grpc.NewServer(opt...)
|
||||
s.Use(s.recovery(), s.handle(), serverLogging(), s.stats(), s.validate())
|
||||
return
|
||||
}
|
||||
|
||||
// SetConfig hot reloads server config
|
||||
func (s *Server) SetConfig(conf *ServerConfig) (err error) {
|
||||
if conf == nil {
|
||||
conf = _defaultSerConf
|
||||
}
|
||||
if conf.Timeout <= 0 {
|
||||
conf.Timeout = xtime.Duration(time.Second)
|
||||
}
|
||||
if conf.IdleTimeout <= 0 {
|
||||
conf.IdleTimeout = xtime.Duration(time.Second * 60)
|
||||
}
|
||||
if conf.MaxLifeTime <= 0 {
|
||||
conf.MaxLifeTime = xtime.Duration(time.Hour * 2)
|
||||
}
|
||||
if conf.ForceCloseWait <= 0 {
|
||||
conf.ForceCloseWait = xtime.Duration(time.Second * 20)
|
||||
}
|
||||
if conf.KeepAliveInterval <= 0 {
|
||||
conf.KeepAliveInterval = xtime.Duration(time.Second * 60)
|
||||
}
|
||||
if conf.KeepAliveTimeout <= 0 {
|
||||
conf.KeepAliveTimeout = xtime.Duration(time.Second * 20)
|
||||
}
|
||||
if conf.Addr == "" {
|
||||
conf.Addr = "0.0.0.0:9000"
|
||||
}
|
||||
if conf.Network == "" {
|
||||
conf.Network = "tcp"
|
||||
}
|
||||
s.mutex.Lock()
|
||||
s.conf = conf
|
||||
s.mutex.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// interceptor is a single interceptor out of a chain of many interceptors.
|
||||
// Execution is done in left-to-right order, including passing of context.
|
||||
// For example ChainUnaryServer(one, two, three) will execute one before two before three, and three
|
||||
// will see context changes of one and two.
|
||||
func (s *Server) interceptor(ctx context.Context, req interface{}, args *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
||||
var (
|
||||
i int
|
||||
chain grpc.UnaryHandler
|
||||
)
|
||||
|
||||
n := len(s.handlers)
|
||||
if n == 0 {
|
||||
return handler(ctx, req)
|
||||
}
|
||||
|
||||
chain = func(ic context.Context, ir interface{}) (interface{}, error) {
|
||||
if i == n-1 {
|
||||
return handler(ic, ir)
|
||||
}
|
||||
i++
|
||||
return s.handlers[i](ic, ir, args, chain)
|
||||
}
|
||||
|
||||
return s.handlers[0](ctx, req, args, chain)
|
||||
}
|
||||
|
||||
// Server return the grpc server for registering service.
|
||||
func (s *Server) Server() *grpc.Server {
|
||||
return s.server
|
||||
}
|
||||
|
||||
// Use attachs a global inteceptor to the server.
|
||||
// For example, this is the right place for a rate limiter or error management inteceptor.
|
||||
func (s *Server) Use(handlers ...grpc.UnaryServerInterceptor) *Server {
|
||||
finalSize := len(s.handlers) + len(handlers)
|
||||
if finalSize >= int(_abortIndex) {
|
||||
panic("warden: server use too many handlers")
|
||||
}
|
||||
mergedHandlers := make([]grpc.UnaryServerInterceptor, finalSize)
|
||||
copy(mergedHandlers, s.handlers)
|
||||
copy(mergedHandlers[len(s.handlers):], handlers)
|
||||
s.handlers = mergedHandlers
|
||||
return s
|
||||
}
|
||||
|
||||
// Run create a tcp listener and start goroutine for serving each incoming request.
|
||||
// Run will return a non-nil error unless Stop or GracefulStop is called.
|
||||
func (s *Server) Run(addr string) error {
|
||||
lis, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
err = errors.WithStack(err)
|
||||
log.Error("failed to listen: %v", err)
|
||||
return err
|
||||
}
|
||||
reflection.Register(s.server)
|
||||
return s.Serve(lis)
|
||||
}
|
||||
|
||||
// RunUnix create a unix listener and start goroutine for serving each incoming request.
|
||||
// RunUnix will return a non-nil error unless Stop or GracefulStop is called.
|
||||
func (s *Server) RunUnix(file string) error {
|
||||
lis, err := net.Listen("unix", file)
|
||||
if err != nil {
|
||||
err = errors.WithStack(err)
|
||||
log.Error("failed to listen: %v", err)
|
||||
return err
|
||||
}
|
||||
reflection.Register(s.server)
|
||||
return s.Serve(lis)
|
||||
}
|
||||
|
||||
// Start create a new goroutine run server with configured listen addr
|
||||
// will panic if any error happend
|
||||
// return server itself
|
||||
func (s *Server) Start() (*Server, error) {
|
||||
lis, err := net.Listen(s.conf.Network, s.conf.Addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
reflection.Register(s.server)
|
||||
go func() {
|
||||
if err := s.Serve(lis); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}()
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Serve accepts incoming connections on the listener lis, creating a new
|
||||
// ServerTransport and service goroutine for each.
|
||||
// Serve will return a non-nil error unless Stop or GracefulStop is called.
|
||||
func (s *Server) Serve(lis net.Listener) error {
|
||||
return s.server.Serve(lis)
|
||||
}
|
||||
|
||||
// Shutdown stops the server gracefully. It stops the server from
|
||||
// accepting new connections and RPCs and blocks until all the pending RPCs are
|
||||
// finished or the context deadline is reached.
|
||||
func (s *Server) Shutdown(ctx context.Context) (err error) {
|
||||
ch := make(chan struct{})
|
||||
go func() {
|
||||
s.server.GracefulStop()
|
||||
close(ch)
|
||||
}()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
s.server.Stop()
|
||||
err = ctx.Err()
|
||||
case <-ch:
|
||||
}
|
||||
return
|
||||
}
|
570
pkg/net/rpc/warden/server_test.go
Normal file
570
pkg/net/rpc/warden/server_test.go
Normal file
@ -0,0 +1,570 @@
|
||||
package warden
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/bilibili/kratos/pkg/ecode"
|
||||
"github.com/bilibili/kratos/pkg/log"
|
||||
nmd "github.com/bilibili/kratos/pkg/net/metadata"
|
||||
"github.com/bilibili/kratos/pkg/net/netutil/breaker"
|
||||
pb "github.com/bilibili/kratos/pkg/net/rpc/warden/internal/proto/testproto"
|
||||
xtrace "github.com/bilibili/kratos/pkg/net/trace"
|
||||
xtime "github.com/bilibili/kratos/pkg/time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
const (
|
||||
_separator = "\001"
|
||||
)
|
||||
|
||||
var (
|
||||
outPut []string
|
||||
_testOnce sync.Once
|
||||
server *Server
|
||||
|
||||
clientConfig = ClientConfig{
|
||||
Dial: xtime.Duration(time.Second * 10),
|
||||
Timeout: xtime.Duration(time.Second * 10),
|
||||
Breaker: &breaker.Config{
|
||||
Window: xtime.Duration(3 * time.Second),
|
||||
Sleep: xtime.Duration(3 * time.Second),
|
||||
Bucket: 10,
|
||||
Ratio: 0.3,
|
||||
Request: 20,
|
||||
},
|
||||
}
|
||||
clientConfig2 = ClientConfig{
|
||||
Dial: xtime.Duration(time.Second * 10),
|
||||
Timeout: xtime.Duration(time.Second * 10),
|
||||
Breaker: &breaker.Config{
|
||||
Window: xtime.Duration(3 * time.Second),
|
||||
Sleep: xtime.Duration(3 * time.Second),
|
||||
Bucket: 10,
|
||||
Ratio: 0.3,
|
||||
Request: 20,
|
||||
},
|
||||
Method: map[string]*ClientConfig{`/testproto.Greeter/SayHello`: {Timeout: xtime.Duration(time.Millisecond * 200)}},
|
||||
}
|
||||
)
|
||||
|
||||
type helloServer struct {
|
||||
t *testing.T
|
||||
}
|
||||
|
||||
func (s *helloServer) SayHello(ctx context.Context, in *pb.HelloRequest) (*pb.HelloReply, error) {
|
||||
if in.Name == "trace_test" {
|
||||
t, isok := xtrace.FromContext(ctx)
|
||||
if !isok {
|
||||
t = xtrace.New("test title")
|
||||
s.t.Fatalf("no trace extracted from server context")
|
||||
}
|
||||
newCtx := xtrace.NewContext(ctx, t)
|
||||
if in.Age == 0 {
|
||||
runClient(newCtx, &clientConfig, s.t, "trace_test", 1)
|
||||
}
|
||||
} else if in.Name == "recovery_test" {
|
||||
panic("test recovery")
|
||||
} else if in.Name == "graceful_shutdown" {
|
||||
time.Sleep(time.Second * 3)
|
||||
} else if in.Name == "timeout_test" {
|
||||
if in.Age > 10 {
|
||||
s.t.Fatalf("can not deliver requests over 10 times because of link timeout")
|
||||
return &pb.HelloReply{Message: "Hello " + in.Name, Success: true}, nil
|
||||
}
|
||||
time.Sleep(time.Millisecond * 10)
|
||||
_, err := runClient(ctx, &clientConfig, s.t, "timeout_test", in.Age+1)
|
||||
return &pb.HelloReply{Message: "Hello " + in.Name, Success: true}, err
|
||||
} else if in.Name == "timeout_test2" {
|
||||
if in.Age > 10 {
|
||||
s.t.Fatalf("can not deliver requests over 10 times because of link timeout")
|
||||
return &pb.HelloReply{Message: "Hello " + in.Name, Success: true}, nil
|
||||
}
|
||||
time.Sleep(time.Millisecond * 10)
|
||||
_, err := runClient(ctx, &clientConfig2, s.t, "timeout_test2", in.Age+1)
|
||||
return &pb.HelloReply{Message: "Hello " + in.Name, Success: true}, err
|
||||
} else if in.Name == "color_test" {
|
||||
if in.Age == 0 {
|
||||
resp, err := runClient(ctx, &clientConfig, s.t, "color_test", in.Age+1)
|
||||
return resp, err
|
||||
}
|
||||
color := nmd.String(ctx, nmd.Color)
|
||||
return &pb.HelloReply{Message: "Hello " + color, Success: true}, nil
|
||||
} else if in.Name == "breaker_test" {
|
||||
if rand.Intn(100) <= 50 {
|
||||
return nil, status.Errorf(codes.ResourceExhausted, "test")
|
||||
}
|
||||
return &pb.HelloReply{Message: "Hello " + in.Name, Success: true}, nil
|
||||
} else if in.Name == "error_detail" {
|
||||
err, _ := ecode.Error(ecode.Code(123456), "test_error_detail").WithDetails(&pb.HelloReply{Success: true})
|
||||
return nil, err
|
||||
} else if in.Name == "ecode_status" {
|
||||
reply := &pb.HelloReply{Message: "status", Success: true}
|
||||
st, _ := ecode.Error(ecode.RequestErr, "RequestErr").WithDetails(reply)
|
||||
return nil, st
|
||||
} else if in.Name == "general_error" {
|
||||
return nil, fmt.Errorf("haha is error")
|
||||
} else if in.Name == "ecode_code_error" {
|
||||
return nil, ecode.Conflict
|
||||
} else if in.Name == "pb_error_error" {
|
||||
return nil, ecode.Error(ecode.Code(11122), "haha")
|
||||
} else if in.Name == "ecode_status_error" {
|
||||
return nil, ecode.Error(ecode.RequestErr, "RequestErr")
|
||||
} else if in.Name == "test_remote_port" {
|
||||
if strconv.Itoa(int(in.Age)) != nmd.String(ctx, nmd.RemotePort) {
|
||||
return nil, fmt.Errorf("error port %d", in.Age)
|
||||
}
|
||||
reply := &pb.HelloReply{Message: "status", Success: true}
|
||||
return reply, nil
|
||||
}
|
||||
return &pb.HelloReply{Message: "Hello " + in.Name, Success: true}, nil
|
||||
}
|
||||
|
||||
func (s *helloServer) StreamHello(ss pb.Greeter_StreamHelloServer) error {
|
||||
for i := 0; i < 3; i++ {
|
||||
in, err := ss.Recv()
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ret := &pb.HelloReply{Message: "Hello " + in.Name, Success: true}
|
||||
err = ss.Send(ret)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func runServer(t *testing.T, interceptors ...grpc.UnaryServerInterceptor) func() {
|
||||
return func() {
|
||||
server = NewServer(&ServerConfig{Addr: "127.0.0.1:8080", Timeout: xtime.Duration(time.Second)})
|
||||
pb.RegisterGreeterServer(server.Server(), &helloServer{t})
|
||||
server.Use(
|
||||
func(ctx context.Context, req interface{}, args *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
||||
outPut = append(outPut, "1")
|
||||
resp, err := handler(ctx, req)
|
||||
outPut = append(outPut, "2")
|
||||
return resp, err
|
||||
},
|
||||
func(ctx context.Context, req interface{}, args *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
||||
outPut = append(outPut, "3")
|
||||
resp, err := handler(ctx, req)
|
||||
outPut = append(outPut, "4")
|
||||
return resp, err
|
||||
})
|
||||
if _, err := server.Start(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func runClient(ctx context.Context, cc *ClientConfig, t *testing.T, name string, age int32, interceptors ...grpc.UnaryClientInterceptor) (resp *pb.HelloReply, err error) {
|
||||
client := NewClient(cc)
|
||||
client.Use(interceptors...)
|
||||
conn, err := client.Dial(context.Background(), "127.0.0.1:8080")
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("did not connect: %v,req: %v %v", err, name, age))
|
||||
}
|
||||
defer conn.Close()
|
||||
c := pb.NewGreeterClient(conn)
|
||||
resp, err = c.SayHello(ctx, &pb.HelloRequest{Name: name, Age: age})
|
||||
return
|
||||
}
|
||||
|
||||
func TestMain(t *testing.T) {
|
||||
log.Init(nil)
|
||||
}
|
||||
|
||||
func Test_Warden(t *testing.T) {
|
||||
xtrace.Init(&xtrace.Config{Addr: "127.0.0.1:9982", Timeout: xtime.Duration(time.Second * 3)})
|
||||
go _testOnce.Do(runServer(t))
|
||||
go runClient(context.Background(), &clientConfig, t, "trace_test", 0)
|
||||
testTrace(t, 9982, false)
|
||||
testInterceptorChain(t)
|
||||
testValidation(t)
|
||||
testServerRecovery(t)
|
||||
testClientRecovery(t)
|
||||
testErrorDetail(t)
|
||||
testECodeStatus(t)
|
||||
testColorPass(t)
|
||||
testRemotePort(t)
|
||||
testLinkTimeout(t)
|
||||
testClientConfig(t)
|
||||
testBreaker(t)
|
||||
testAllErrorCase(t)
|
||||
testGracefulShutDown(t)
|
||||
}
|
||||
|
||||
func testValidation(t *testing.T) {
|
||||
_, err := runClient(context.Background(), &clientConfig, t, "", 0)
|
||||
if !ecode.RequestErr.Equal(err) {
|
||||
t.Fatalf("testValidation should return ecode.RequestErr,but is %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func testAllErrorCase(t *testing.T) {
|
||||
// } else if in.Name == "general_error" {
|
||||
// return nil, fmt.Errorf("haha is error")
|
||||
// } else if in.Name == "ecode_code_error" {
|
||||
// return nil, ecode.CreativeArticleTagErr
|
||||
// } else if in.Name == "pb_error_error" {
|
||||
// return nil, &errpb.Error{ErrCode: 11122, ErrMessage: "haha"}
|
||||
// } else if in.Name == "ecode_status_error" {
|
||||
// return nil, ecode.Error(ecode.RequestErr, "RequestErr")
|
||||
// }
|
||||
ctx := context.Background()
|
||||
t.Run("general_error", func(t *testing.T) {
|
||||
_, err := runClient(ctx, &clientConfig, t, "general_error", 0)
|
||||
assert.Contains(t, err.Error(), "haha")
|
||||
ec := ecode.Cause(err)
|
||||
assert.Equal(t, -500, ec.Code())
|
||||
// remove this assert in future
|
||||
assert.Equal(t, "-500", ec.Message())
|
||||
})
|
||||
t.Run("ecode_code_error", func(t *testing.T) {
|
||||
_, err := runClient(ctx, &clientConfig, t, "ecode_code_error", 0)
|
||||
ec := ecode.Cause(err)
|
||||
assert.Equal(t, ecode.Conflict.Code(), ec.Code())
|
||||
// remove this assert in future
|
||||
assert.Equal(t, "20024", ec.Message())
|
||||
})
|
||||
t.Run("pb_error_error", func(t *testing.T) {
|
||||
_, err := runClient(ctx, &clientConfig, t, "pb_error_error", 0)
|
||||
ec := ecode.Cause(err)
|
||||
assert.Equal(t, 11122, ec.Code())
|
||||
assert.Equal(t, "haha", ec.Message())
|
||||
})
|
||||
t.Run("ecode_status_error", func(t *testing.T) {
|
||||
_, err := runClient(ctx, &clientConfig, t, "ecode_status_error", 0)
|
||||
ec := ecode.Cause(err)
|
||||
assert.Equal(t, ecode.RequestErr.Code(), ec.Code())
|
||||
assert.Equal(t, "RequestErr", ec.Message())
|
||||
})
|
||||
}
|
||||
|
||||
func testBreaker(t *testing.T) {
|
||||
client := NewClient(&clientConfig)
|
||||
conn, err := client.Dial(context.Background(), "127.0.0.1:8080")
|
||||
if err != nil {
|
||||
t.Fatalf("did not connect: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
c := pb.NewGreeterClient(conn)
|
||||
for i := 0; i < 35; i++ {
|
||||
_, err := c.SayHello(context.Background(), &pb.HelloRequest{Name: "breaker_test"})
|
||||
if err != nil {
|
||||
if ecode.ServiceUnavailable.Equal(err) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
t.Fatalf("testBreaker failed!No breaker was triggered")
|
||||
}
|
||||
|
||||
func testColorPass(t *testing.T) {
|
||||
ctx := nmd.NewContext(context.Background(), nmd.MD{
|
||||
nmd.Color: "red",
|
||||
})
|
||||
resp, err := runClient(ctx, &clientConfig, t, "color_test", 0)
|
||||
if err != nil {
|
||||
t.Fatalf("testColorPass return error %v", err)
|
||||
}
|
||||
if resp == nil || resp.Message != "Hello red" {
|
||||
t.Fatalf("testColorPass resp.Message must be red,%v", *resp)
|
||||
}
|
||||
}
|
||||
|
||||
func testRemotePort(t *testing.T) {
|
||||
ctx := nmd.NewContext(context.Background(), nmd.MD{
|
||||
nmd.RemotePort: "8000",
|
||||
})
|
||||
_, err := runClient(ctx, &clientConfig, t, "test_remote_port", 8000)
|
||||
if err != nil {
|
||||
t.Fatalf("testRemotePort return error %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func testLinkTimeout(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*200)
|
||||
defer cancel()
|
||||
_, err := runClient(ctx, &clientConfig, t, "timeout_test", 0)
|
||||
if err == nil {
|
||||
t.Fatalf("testLinkTimeout must return error")
|
||||
}
|
||||
if !ecode.Deadline.Equal(err) {
|
||||
t.Fatalf("testLinkTimeout must return error RPCDeadline,err:%v", err)
|
||||
}
|
||||
|
||||
}
|
||||
func testClientConfig(t *testing.T) {
|
||||
_, err := runClient(context.Background(), &clientConfig2, t, "timeout_test2", 0)
|
||||
if err == nil {
|
||||
t.Fatalf("testLinkTimeout must return error")
|
||||
}
|
||||
if !ecode.Deadline.Equal(err) {
|
||||
t.Fatalf("testLinkTimeout must return error RPCDeadline,err:%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func testGracefulShutDown(t *testing.T) {
|
||||
wg := sync.WaitGroup{}
|
||||
for i := 0; i < 10; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
resp, err := runClient(context.Background(), &clientConfig, t, "graceful_shutdown", 0)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("run graceful_shutdown client return(%v)", err))
|
||||
}
|
||||
if !resp.Success || resp.Message != "Hello graceful_shutdown" {
|
||||
panic(fmt.Errorf("run graceful_shutdown client return(%v,%v)", err, *resp))
|
||||
}
|
||||
}()
|
||||
}
|
||||
go func() {
|
||||
time.Sleep(time.Second)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*3)
|
||||
defer cancel()
|
||||
server.Shutdown(ctx)
|
||||
}()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func testClientRecovery(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := NewClient(&clientConfig)
|
||||
client.Use(func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) (ret error) {
|
||||
invoker(ctx, method, req, reply, cc, opts...)
|
||||
panic("client recovery test")
|
||||
})
|
||||
|
||||
conn, err := client.Dial(ctx, "127.0.0.1:8080")
|
||||
if err != nil {
|
||||
t.Fatalf("did not connect: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
c := pb.NewGreeterClient(conn)
|
||||
|
||||
_, err = c.SayHello(ctx, &pb.HelloRequest{Name: "other_test", Age: 0})
|
||||
if err == nil {
|
||||
t.Fatalf("recovery must return error")
|
||||
}
|
||||
e, ok := errors.Cause(err).(ecode.Codes)
|
||||
if !ok {
|
||||
t.Fatalf("recovery must return ecode error")
|
||||
}
|
||||
|
||||
if !ecode.ServerErr.Equal(e) {
|
||||
t.Fatalf("recovery must return ecode.RPCClientErr")
|
||||
}
|
||||
}
|
||||
|
||||
func testServerRecovery(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := NewClient(&clientConfig)
|
||||
|
||||
conn, err := client.Dial(ctx, "127.0.0.1:8080")
|
||||
if err != nil {
|
||||
t.Fatalf("did not connect: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
c := pb.NewGreeterClient(conn)
|
||||
|
||||
_, err = c.SayHello(ctx, &pb.HelloRequest{Name: "recovery_test", Age: 0})
|
||||
if err == nil {
|
||||
t.Fatalf("recovery must return error")
|
||||
}
|
||||
e, ok := errors.Cause(err).(ecode.Codes)
|
||||
if !ok {
|
||||
t.Fatalf("recovery must return ecode error")
|
||||
}
|
||||
|
||||
if e.Code() != ecode.ServerErr.Code() {
|
||||
t.Fatalf("recovery must return ecode.ServerErr")
|
||||
}
|
||||
}
|
||||
|
||||
func testInterceptorChain(t *testing.T) {
|
||||
// NOTE: don't delete this sleep
|
||||
time.Sleep(time.Millisecond)
|
||||
if outPut[0] != "1" || outPut[1] != "3" || outPut[2] != "1" || outPut[3] != "3" || outPut[4] != "4" || outPut[5] != "2" || outPut[6] != "4" || outPut[7] != "2" {
|
||||
t.Fatalf("outPut shoud be [1 3 1 3 4 2 4 2]!")
|
||||
}
|
||||
}
|
||||
|
||||
func testErrorDetail(t *testing.T) {
|
||||
_, err := runClient(context.Background(), &clientConfig2, t, "error_detail", 0)
|
||||
if err == nil {
|
||||
t.Fatalf("testErrorDetail must return error")
|
||||
}
|
||||
if ec, ok := errors.Cause(err).(ecode.Codes); !ok {
|
||||
t.Fatalf("testErrorDetail must return ecode error")
|
||||
} else if ec.Code() != 123456 || ec.Message() != "test_error_detail" || len(ec.Details()) == 0 {
|
||||
t.Fatalf("testErrorDetail must return code:123456 and message:test_error_detail, code: %d, message: %s, details length: %d", ec.Code(), ec.Message(), len(ec.Details()))
|
||||
} else if _, ok := ec.Details()[len(ec.Details())-1].(*pb.HelloReply); !ok {
|
||||
t.Fatalf("expect get pb.HelloReply %#v", ec.Details()[len(ec.Details())-1])
|
||||
}
|
||||
}
|
||||
|
||||
func testECodeStatus(t *testing.T) {
|
||||
_, err := runClient(context.Background(), &clientConfig2, t, "ecode_status", 0)
|
||||
if err == nil {
|
||||
t.Fatalf("testECodeStatus must return error")
|
||||
}
|
||||
st, ok := errors.Cause(err).(*ecode.Status)
|
||||
if !ok {
|
||||
t.Fatalf("testECodeStatus must return *ecode.Status")
|
||||
}
|
||||
if st.Code() != int(ecode.RequestErr) && st.Message() != "RequestErr" {
|
||||
t.Fatalf("testECodeStatus must return code: -400, message: RequestErr get: code: %d, message: %s", st.Code(), st.Message())
|
||||
}
|
||||
detail := st.Details()[0].(*pb.HelloReply)
|
||||
if !detail.Success || detail.Message != "status" {
|
||||
t.Fatalf("wrong detail")
|
||||
}
|
||||
}
|
||||
|
||||
func testTrace(t *testing.T, port int, isStream bool) {
|
||||
listener, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: port})
|
||||
if err != nil {
|
||||
t.Fatalf("listent udp failed, %v", err)
|
||||
return
|
||||
}
|
||||
data := make([]byte, 1024)
|
||||
strs := make([][]string, 0)
|
||||
for {
|
||||
var n int
|
||||
n, _, err = listener.ReadFromUDP(data)
|
||||
if err != nil {
|
||||
t.Fatalf("read from udp faild, %v", err)
|
||||
}
|
||||
str := strings.Split(string(data[:n]), _separator)
|
||||
strs = append(strs, str)
|
||||
|
||||
if len(strs) == 2 {
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(strs[0]) == 0 || len(strs[1]) == 0 {
|
||||
t.Fatalf("trace str's length must be greater than 0")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkServer(b *testing.B) {
|
||||
server := NewServer(&ServerConfig{Addr: "127.0.0.1:8080", Timeout: xtime.Duration(time.Second)})
|
||||
go func() {
|
||||
pb.RegisterGreeterServer(server.Server(), &helloServer{})
|
||||
if _, err := server.Start(); err != nil {
|
||||
os.Exit(0)
|
||||
return
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
server.Server().Stop()
|
||||
}()
|
||||
client := NewClient(&clientConfig)
|
||||
conn, err := client.Dial(context.Background(), "127.0.0.1:8080")
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
b.Fatalf("did not connect: %v", err)
|
||||
}
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(parab *testing.PB) {
|
||||
for parab.Next() {
|
||||
c := pb.NewGreeterClient(conn)
|
||||
resp, err := c.SayHello(context.Background(), &pb.HelloRequest{Name: "benchmark_test", Age: 1})
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
b.Fatalf("c.SayHello failed: %v,req: %v %v", err, "benchmark", 1)
|
||||
}
|
||||
if !resp.Success {
|
||||
b.Error("repsonse not success!")
|
||||
}
|
||||
}
|
||||
})
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
func TestParseDSN(t *testing.T) {
|
||||
dsn := "tcp://0.0.0.0:80/?timeout=100ms&idleTimeout=120s&keepaliveInterval=120s&keepaliveTimeout=20s&maxLife=4h&closeWait=3s"
|
||||
config := parseDSN(dsn)
|
||||
if config.Network != "tcp" || config.Addr != "0.0.0.0:80" || time.Duration(config.Timeout) != time.Millisecond*100 ||
|
||||
time.Duration(config.IdleTimeout) != time.Second*120 || time.Duration(config.KeepAliveInterval) != time.Second*120 ||
|
||||
time.Duration(config.MaxLifeTime) != time.Hour*4 || time.Duration(config.ForceCloseWait) != time.Second*3 || time.Duration(config.KeepAliveTimeout) != time.Second*20 {
|
||||
t.Fatalf("parseDSN(%s) not compare config result(%+v)", dsn, config)
|
||||
}
|
||||
|
||||
dsn = "unix:///temp/warden.sock?timeout=300ms"
|
||||
config = parseDSN(dsn)
|
||||
if config.Network != "unix" || config.Addr != "/temp/warden.sock" || time.Duration(config.Timeout) != time.Millisecond*300 {
|
||||
t.Fatalf("parseDSN(%s) not compare config result(%+v)", dsn, config)
|
||||
}
|
||||
}
|
||||
|
||||
type testServer struct {
|
||||
helloFn func(ctx context.Context, req *pb.HelloRequest) (*pb.HelloReply, error)
|
||||
}
|
||||
|
||||
func (t *testServer) SayHello(ctx context.Context, req *pb.HelloRequest) (*pb.HelloReply, error) {
|
||||
return t.helloFn(ctx, req)
|
||||
}
|
||||
|
||||
func (t *testServer) StreamHello(pb.Greeter_StreamHelloServer) error { panic("not implemented") }
|
||||
|
||||
// NewTestServerClient .
|
||||
func NewTestServerClient(invoker func(ctx context.Context, req *pb.HelloRequest) (*pb.HelloReply, error), svrcfg *ServerConfig, clicfg *ClientConfig) (pb.GreeterClient, func() error) {
|
||||
srv := NewServer(svrcfg)
|
||||
pb.RegisterGreeterServer(srv.Server(), &testServer{helloFn: invoker})
|
||||
|
||||
lis, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
ch := make(chan bool, 1)
|
||||
go func() {
|
||||
ch <- true
|
||||
srv.Serve(lis)
|
||||
}()
|
||||
<-ch
|
||||
println(lis.Addr().String())
|
||||
conn, err := NewConn(lis.Addr().String())
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return pb.NewGreeterClient(conn), func() error { return srv.Shutdown(context.Background()) }
|
||||
}
|
||||
|
||||
func TestMetadata(t *testing.T) {
|
||||
cli, cancel := NewTestServerClient(func(ctx context.Context, req *pb.HelloRequest) (*pb.HelloReply, error) {
|
||||
assert.Equal(t, "red", nmd.String(ctx, nmd.Color))
|
||||
assert.Equal(t, "2.2.3.3", nmd.String(ctx, nmd.RemoteIP))
|
||||
assert.Equal(t, "2233", nmd.String(ctx, nmd.RemotePort))
|
||||
return &pb.HelloReply{}, nil
|
||||
}, nil, nil)
|
||||
defer cancel()
|
||||
|
||||
ctx := nmd.NewContext(context.Background(), nmd.MD{
|
||||
nmd.Color: "red",
|
||||
nmd.RemoteIP: "2.2.3.3",
|
||||
nmd.RemotePort: "2233",
|
||||
})
|
||||
_, err := cli.SayHello(ctx, &pb.HelloRequest{Name: "test"})
|
||||
assert.Nil(t, err)
|
||||
}
|
25
pkg/net/rpc/warden/stats.go
Normal file
25
pkg/net/rpc/warden/stats.go
Normal file
@ -0,0 +1,25 @@
|
||||
package warden
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
|
||||
nmd "github.com/bilibili/kratos/pkg/net/rpc/warden/internal/metadata"
|
||||
"github.com/bilibili/kratos/pkg/stat/sys/cpu"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
gmd "google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
func (s *Server) stats() grpc.UnaryServerInterceptor {
|
||||
return func(ctx context.Context, req interface{}, args *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
|
||||
resp, err = handler(ctx, req)
|
||||
var cpustat cpu.Stat
|
||||
cpu.ReadStat(&cpustat)
|
||||
if cpustat.Usage != 0 {
|
||||
trailer := gmd.Pairs([]string{nmd.CPUUsage, strconv.FormatInt(int64(cpustat.Usage), 10)}...)
|
||||
grpc.SetTrailer(ctx, trailer)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
31
pkg/net/rpc/warden/validate.go
Normal file
31
pkg/net/rpc/warden/validate.go
Normal file
@ -0,0 +1,31 @@
|
||||
package warden
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/bilibili/kratos/pkg/ecode"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"gopkg.in/go-playground/validator.v9"
|
||||
)
|
||||
|
||||
var validate = validator.New()
|
||||
|
||||
// Validate return a client interceptor validate incoming request per RPC call.
|
||||
func (s *Server) validate() grpc.UnaryServerInterceptor {
|
||||
return func(ctx context.Context, req interface{}, args *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
|
||||
if err = validate.Struct(req); err != nil {
|
||||
err = ecode.Error(ecode.RequestErr, err.Error())
|
||||
return
|
||||
}
|
||||
resp, err = handler(ctx, req)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterValidation adds a validation Func to a Validate's map of validators denoted by the key
|
||||
// NOTE: if the key already exists, the previous validation function will be replaced.
|
||||
// NOTE: this method is not thread-safe it is intended that these all be registered prior to any validation
|
||||
func (s *Server) RegisterValidation(key string, fn validator.Func) error {
|
||||
return validate.RegisterValidation(key, fn)
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user