1
0
mirror of https://github.com/go-micro/go-micro.git synced 2024-12-24 10:07:04 +02:00

fixes for safe conversation and avoid panics (#1213)

* fixes for safe convertation

Signed-off-by: Vasiliy Tolstov <v.tolstov@unistack.org>

* fix client publish panic

If broker connect returns error we dont check it status and use
it later to publish message, mostly this is unexpected because
broker connection failed and we cant use it.
Also proposed solution have benefit - we flag connection status
only when we have succeseful broker connection

Signed-off-by: Vasiliy Tolstov <v.tolstov@unistack.org>

* api/handler/broker: fix possible broker publish panic

Signed-off-by: Vasiliy Tolstov <v.tolstov@unistack.org>
This commit is contained in:
Vasiliy Tolstov 2020-02-19 02:05:38 +03:00 committed by GitHub
parent 6248f05f74
commit 58598d0fe0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 108 additions and 102 deletions

View File

@ -8,6 +8,7 @@ import (
"net/url"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/gorilla/websocket"
@ -26,6 +27,7 @@ const (
)
type brokerHandler struct {
once atomic.Value
opts handler.Options
u websocket.Upgrader
}
@ -42,7 +44,6 @@ type conn struct {
}
var (
once sync.Once
contentType = "text/plain"
)
@ -155,10 +156,15 @@ func (b *brokerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
br := b.opts.Service.Client().Options().Broker
// Setup the broker
once.Do(func() {
br.Init()
br.Connect()
})
if !b.once.Load().(bool) {
if err := br.Init(); err != nil {
http.Error(w, err.Error(), 500)
}
if err := br.Connect(); err != nil {
http.Error(w, err.Error(), 500)
}
b.once.Store(true)
}
// Parse
r.ParseForm()
@ -235,7 +241,7 @@ func (b *brokerHandler) String() string {
}
func NewHandler(opts ...handler.Option) handler.Handler {
return &brokerHandler{
h := &brokerHandler{
u: websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
return true
@ -245,6 +251,8 @@ func NewHandler(opts ...handler.Option) handler.Handler {
},
opts: handler.NewOptions(opts...),
}
h.once.Store(true)
return h
}
func WithCors(cors map[string]bool, opts ...handler.Option) handler.Handler {

View File

@ -100,7 +100,10 @@ func (s *svc) Validate(token string) (*auth.Account, error) {
return nil, ErrInvalidToken
}
claims := res.Claims.(*AuthClaims)
claims, ok := res.Claims.(*AuthClaims)
if !ok {
return nil, ErrInvalidToken
}
return &auth.Account{
Id: claims.Id,

View File

@ -6,7 +6,7 @@ import (
"crypto/tls"
"fmt"
"os"
"sync"
"sync/atomic"
"time"
"github.com/micro/go-micro/v2/broker"
@ -24,9 +24,9 @@ import (
)
type grpcClient struct {
once sync.Once
opts client.Options
pool *pool
once atomic.Value
}
func init() {
@ -570,9 +570,12 @@ func (g *grpcClient) Publish(ctx context.Context, p client.Message, opts ...clie
body = b
}
g.once.Do(func() {
g.opts.Broker.Connect()
})
if !g.once.Load().(bool) {
if err = g.opts.Broker.Connect(); err != nil {
return errors.InternalServerError("go.micro.client", err.Error())
}
g.once.Store(true)
}
topic := p.Topic()
@ -641,9 +644,9 @@ func newClient(opts ...client.Option) client.Client {
}
rc := &grpcClient{
once: sync.Once{},
opts: options,
}
rc.once.Store(false)
rc.pool = newPool(options.PoolSize, options.PoolTTL, rc.poolMaxIdle(), rc.poolMaxStreams())

View File

@ -4,7 +4,6 @@ import (
"context"
"fmt"
"os"
"sync"
"sync/atomic"
"time"
@ -22,7 +21,7 @@ import (
)
type rpcClient struct {
once sync.Once
once atomic.Value
opts Options
pool pool.Pool
seq uint64
@ -38,11 +37,11 @@ func newRpcClient(opt ...Option) Client {
)
rc := &rpcClient{
once: sync.Once{},
opts: opts,
pool: p,
seq: 0,
}
rc.once.Store(false)
c := Client(rc)
@ -645,9 +644,12 @@ func (r *rpcClient) Publish(ctx context.Context, msg Message, opts ...PublishOpt
body = b.Bytes()
}
r.once.Do(func() {
r.opts.Broker.Connect()
})
if !r.once.Load().(bool) {
if err = r.opts.Broker.Connect(); err != nil {
return errors.InternalServerError("go.micro.client", err.Error())
}
r.once.Store(true)
}
return r.opts.Broker.Publish(topic, &broker.Message{
Header: md,

View File

@ -2,6 +2,7 @@
package codec
import (
"errors"
"io"
)
@ -12,6 +13,10 @@ const (
Event
)
var (
ErrInvalidMessage = errors.New("invalid message")
)
type MessageType int
// Takes in a connection/buffer and returns a new Codec

View File

@ -25,13 +25,17 @@ func (c *Codec) ReadBody(b interface{}) error {
if err != nil {
return err
}
return proto.Unmarshal(buf, b.(proto.Message))
m, ok := b.(proto.Message)
if !ok {
return codec.ErrInvalidMessage
}
return proto.Unmarshal(buf, m)
}
func (c *Codec) Write(m *codec.Message, b interface{}) error {
p, ok := b.(proto.Message)
if !ok {
return nil
return codec.ErrInvalidMessage
}
buf, err := proto.Marshal(p)
if err != nil {

View File

@ -56,8 +56,12 @@ func (c *protoCodec) Write(m *codec.Message, b interface{}) error {
if err != nil {
return err
}
// Of course this is a protobuf! Trust me or detonate the program.
data, err = proto.Marshal(b.(proto.Message))
// dont trust or incoming message
m, ok := b.(proto.Message)
if !ok {
return codec.ErrInvalidMessage
}
data, err = proto.Marshal(m)
if err != nil {
return err
}
@ -100,7 +104,11 @@ func (c *protoCodec) Write(m *codec.Message, b interface{}) error {
}
}
case codec.Event:
data, err := proto.Marshal(b.(proto.Message))
m, ok := b.(proto.Message)
if !ok {
return codec.ErrInvalidMessage
}
data, err := proto.Marshal(m)
if err != nil {
return err
}

View File

@ -2,7 +2,6 @@ package grpc
import (
"encoding/json"
"fmt"
"strings"
b "bytes"
@ -71,11 +70,19 @@ func (w wrapCodec) Unmarshal(data []byte, v interface{}) error {
}
func (protoCodec) Marshal(v interface{}) ([]byte, error) {
return proto.Marshal(v.(proto.Message))
m, ok := v.(proto.Message)
if !ok {
return nil, codec.ErrInvalidMessage
}
return proto.Marshal(m)
}
func (protoCodec) Unmarshal(data []byte, v interface{}) error {
return proto.Unmarshal(data, v.(proto.Message))
m, ok := v.(proto.Message)
if !ok {
return codec.ErrInvalidMessage
}
return proto.Unmarshal(data, m)
}
func (protoCodec) Name() string {
@ -85,7 +92,6 @@ func (protoCodec) Name() string {
func (jsonCodec) Marshal(v interface{}) ([]byte, error) {
if pb, ok := v.(proto.Message); ok {
s, err := jsonpbMarshaler.MarshalToString(pb)
return []byte(s), err
}
@ -109,7 +115,7 @@ func (jsonCodec) Name() string {
func (bytesCodec) Marshal(v interface{}) ([]byte, error) {
b, ok := v.(*[]byte)
if !ok {
return nil, fmt.Errorf("failed to marshal: %v is not type of *[]byte", v)
return nil, codec.ErrInvalidMessage
}
return *b, nil
}
@ -117,7 +123,7 @@ func (bytesCodec) Marshal(v interface{}) ([]byte, error) {
func (bytesCodec) Unmarshal(data []byte, v interface{}) error {
b, ok := v.(*[]byte)
if !ok {
return fmt.Errorf("failed to unmarshal: %v is not type of *[]byte", v)
return codec.ErrInvalidMessage
}
*b = data
return nil

16
server/grpc/context.go Normal file
View File

@ -0,0 +1,16 @@
package grpc
import (
"context"
"github.com/micro/go-micro/v2/server"
)
func setServerOption(k, v interface{}) server.Option {
return func(o *server.Options) {
if o.Context == nil {
o.Context = context.Background()
}
o.Context = context.WithValue(o.Context, k, v)
}
}

View File

@ -143,9 +143,8 @@ func (g *grpcServer) getMaxMsgSize() int {
func (g *grpcServer) getCredentials() credentials.TransportCredentials {
if g.opts.Context != nil {
if v := g.opts.Context.Value(tlsAuth{}); v != nil {
tls := v.(*tls.Config)
return credentials.NewTLS(tls)
if v, ok := g.opts.Context.Value(tlsAuth{}).(*tls.Config); ok && v != nil {
return credentials.NewTLS(v)
}
}
return nil
@ -156,15 +155,8 @@ func (g *grpcServer) getGrpcOptions() []grpc.ServerOption {
return nil
}
v := g.opts.Context.Value(grpcOptions{})
if v == nil {
return nil
}
opts, ok := v.([]grpc.ServerOption)
if !ok {
opts, ok := g.opts.Context.Value(grpcOptions{}).([]grpc.ServerOption)
if !ok || opts == nil {
return nil
}
@ -505,8 +497,8 @@ func (g *grpcServer) processStream(stream grpc.ServerStream, service *service, m
func (g *grpcServer) newGRPCCodec(contentType string) (encoding.Codec, error) {
codecs := make(map[string]encoding.Codec)
if g.opts.Context != nil {
if v := g.opts.Context.Value(codecsKey{}); v != nil {
codecs = v.(map[string]encoding.Codec)
if v, ok := g.opts.Context.Value(codecsKey{}).(map[string]encoding.Codec); ok && v != nil {
codecs = v
}
}
if c, ok := codecs[contentType]; ok {
@ -573,10 +565,10 @@ func (g *grpcServer) Subscribe(sb server.Subscriber) error {
g.Lock()
_, ok = g.subscribers[sub]
if ok {
if _, ok = g.subscribers[sub]; ok {
return fmt.Errorf("subscriber %v already exists", sub)
}
g.subscribers[sub] = nil
g.Unlock()
return nil

View File

@ -27,8 +27,8 @@ func Codec(contentType string, c encoding.Codec) server.Option {
if o.Context == nil {
o.Context = context.Background()
}
if v := o.Context.Value(codecsKey{}); v != nil {
codecs = v.(map[string]encoding.Codec)
if v, ok := o.Context.Value(codecsKey{}).(map[string]encoding.Codec); ok && v != nil {
codecs = v
}
codecs[contentType] = c
o.Context = context.WithValue(o.Context, codecsKey{}, codecs)
@ -37,32 +37,17 @@ func Codec(contentType string, c encoding.Codec) server.Option {
// AuthTLS should be used to setup a secure authentication using TLS
func AuthTLS(t *tls.Config) server.Option {
return func(o *server.Options) {
if o.Context == nil {
o.Context = context.Background()
}
o.Context = context.WithValue(o.Context, tlsAuth{}, t)
}
return setServerOption(tlsAuth{}, t)
}
// Listener specifies the net.Listener to use instead of the default
func Listener(l net.Listener) server.Option {
return func(o *server.Options) {
if o.Context == nil {
o.Context = context.Background()
}
o.Context = context.WithValue(o.Context, netListener{}, l)
}
return setServerOption(netListener{}, l)
}
// Options to be used to configure gRPC options
func Options(opts ...grpc.ServerOption) server.Option {
return func(o *server.Options) {
if o.Context == nil {
o.Context = context.Background()
}
o.Context = context.WithValue(o.Context, grpcOptions{}, opts)
}
return setServerOption(grpcOptions{}, opts)
}
//
@ -70,51 +55,25 @@ func Options(opts ...grpc.ServerOption) server.Option {
// send. Default maximum message size is 4 MB.
//
func MaxMsgSize(s int) server.Option {
return func(o *server.Options) {
if o.Context == nil {
o.Context = context.Background()
}
o.Context = context.WithValue(o.Context, maxMsgSizeKey{}, s)
}
return setServerOption(maxMsgSizeKey{}, s)
}
func newOptions(opt ...server.Option) server.Options {
opts := server.Options{
Codecs: make(map[string]codec.NewCodec),
Metadata: map[string]string{},
Codecs: make(map[string]codec.NewCodec),
Metadata: map[string]string{},
Broker: broker.DefaultBroker,
Registry: registry.DefaultRegistry,
Transport: transport.DefaultTransport,
Address: server.DefaultAddress,
Name: server.DefaultName,
Id: server.DefaultId,
Version: server.DefaultVersion,
}
for _, o := range opt {
o(&opts)
}
if opts.Broker == nil {
opts.Broker = broker.DefaultBroker
}
if opts.Registry == nil {
opts.Registry = registry.DefaultRegistry
}
if opts.Transport == nil {
opts.Transport = transport.DefaultTransport
}
if len(opts.Address) == 0 {
opts.Address = server.DefaultAddress
}
if len(opts.Name) == 0 {
opts.Name = server.DefaultName
}
if len(opts.Id) == 0 {
opts.Id = server.DefaultId
}
if len(opts.Version) == 0 {
opts.Version = server.DefaultVersion
}
return opts
}