1
0
mirror of https://github.com/go-micro/go-micro.git synced 2025-01-05 10:20:53 +02:00
go-micro/codec/protorpc/protorpc.go
2021-10-12 12:55:53 +01:00

187 lines
3.5 KiB
Go

// Protorpc provides a net/rpc proto-rpc codec. See envelope.proto for the format.
package protorpc
import (
"bytes"
"fmt"
"io"
"strconv"
"sync"
"go-micro.dev/v4/codec"
"github.com/golang/protobuf/proto"
)
type flusher interface {
Flush() error
}
type protoCodec struct {
sync.Mutex
rwc io.ReadWriteCloser
mt codec.MessageType
buf *bytes.Buffer
}
func (c *protoCodec) Close() error {
c.buf.Reset()
return c.rwc.Close()
}
func (c *protoCodec) String() string {
return "proto-rpc"
}
func id(id string) uint64 {
p, err := strconv.ParseInt(id, 10, 64)
if err != nil {
p = 0
}
i := uint64(p)
return i
}
func (c *protoCodec) Write(m *codec.Message, b interface{}) error {
switch m.Type {
case codec.Request:
c.Lock()
defer c.Unlock()
// This is protobuf, of course we copy it.
pbr := &Request{ServiceMethod: m.Method, Seq: id(m.Id)}
data, err := proto.Marshal(pbr)
if err != nil {
return err
}
_, err = WriteNetString(c.rwc, data)
if err != nil {
return err
}
// dont trust or incoming message
m, ok := b.(proto.Message)
if !ok {
return codec.ErrInvalidMessage
}
data, err = proto.Marshal(m)
if err != nil {
return err
}
_, err = WriteNetString(c.rwc, data)
if err != nil {
return err
}
if flusher, ok := c.rwc.(flusher); ok {
if err = flusher.Flush(); err != nil {
return err
}
}
case codec.Response, codec.Error:
c.Lock()
defer c.Unlock()
rtmp := &Response{ServiceMethod: m.Method, Seq: id(m.Id), Error: m.Error}
data, err := proto.Marshal(rtmp)
if err != nil {
return err
}
_, err = WriteNetString(c.rwc, data)
if err != nil {
return err
}
if pb, ok := b.(proto.Message); ok {
data, err = proto.Marshal(pb)
if err != nil {
return err
}
} else {
data = nil
}
_, err = WriteNetString(c.rwc, data)
if err != nil {
return err
}
if flusher, ok := c.rwc.(flusher); ok {
if err = flusher.Flush(); err != nil {
return err
}
}
case codec.Event:
m, ok := b.(proto.Message)
if !ok {
return codec.ErrInvalidMessage
}
data, err := proto.Marshal(m)
if err != nil {
return err
}
c.rwc.Write(data)
default:
return fmt.Errorf("Unrecognised message type: %v", m.Type)
}
return nil
}
func (c *protoCodec) ReadHeader(m *codec.Message, mt codec.MessageType) error {
c.buf.Reset()
c.mt = mt
switch mt {
case codec.Request:
data, err := ReadNetString(c.rwc)
if err != nil {
return err
}
rtmp := new(Request)
err = proto.Unmarshal(data, rtmp)
if err != nil {
return err
}
m.Method = rtmp.GetServiceMethod()
m.Id = fmt.Sprintf("%d", rtmp.GetSeq())
case codec.Response:
data, err := ReadNetString(c.rwc)
if err != nil {
return err
}
rtmp := new(Response)
err = proto.Unmarshal(data, rtmp)
if err != nil {
return err
}
m.Method = rtmp.GetServiceMethod()
m.Id = fmt.Sprintf("%d", rtmp.GetSeq())
m.Error = rtmp.GetError()
case codec.Event:
_, err := io.Copy(c.buf, c.rwc)
return err
default:
return fmt.Errorf("Unrecognised message type: %v", mt)
}
return nil
}
func (c *protoCodec) ReadBody(b interface{}) error {
var data []byte
switch c.mt {
case codec.Request, codec.Response:
var err error
data, err = ReadNetString(c.rwc)
if err != nil {
return err
}
case codec.Event:
data = c.buf.Bytes()
default:
return fmt.Errorf("Unrecognised message type: %v", c.mt)
}
if b != nil {
return proto.Unmarshal(data, b.(proto.Message))
}
return nil
}
func NewCodec(rwc io.ReadWriteCloser) codec.Codec {
return &protoCodec{
buf: bytes.NewBuffer(nil),
rwc: rwc,
}
}