1
0
mirror of https://github.com/go-micro/go-micro.git synced 2025-06-18 22:17:44 +02:00

feat: add test framework & refactor RPC server (#2579)

Co-authored-by: Rene Jochum <rene@jochum.dev>
This commit is contained in:
David Brouwer
2022-10-20 13:00:50 +02:00
committed by GitHub
parent c25dee7c8a
commit a3980c2308
54 changed files with 3703 additions and 2497 deletions

View File

@ -1,11 +1,5 @@
package server
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//
// Meh, we need to get rid of this shit
import (
"context"
"errors"
@ -80,7 +74,7 @@ type router struct {
subscribers map[string][]*subscriber
}
// rpcRouter encapsulates functions that become a server.Router.
// rpcRouter encapsulates functions that become a Router.
type rpcRouter struct {
h func(context.Context, Request, interface{}) error
m func(context.Context, Message) error
@ -96,7 +90,7 @@ func (r rpcRouter) ServeRequest(ctx context.Context, req Request, rsp Response)
func newRpcRouter(opts ...RouterOption) *router {
return &router{
ops: newRouterOptions(opts...),
ops: NewRouterOptions(opts...),
serviceMap: make(map[string]*service),
subscribers: make(map[string][]*subscriber),
}
@ -180,11 +174,13 @@ func prepareMethod(method reflect.Method, logger log.Logger) *methodType {
logger.Logf(log.ErrorLevel, "method %v has wrong number of outs: %v", mname, mtype.NumOut())
return nil
}
// The return type of the method must be error.
if returnType := mtype.Out(0); returnType != typeOfError {
logger.Logf(log.ErrorLevel, "method %v returns %v not error", mname, returnType.String())
return nil
}
return &methodType{method: method, ArgType: argType, ReplyType: replyType, ContextType: contextType, stream: stream}
}
@ -195,10 +191,13 @@ func (router *router) sendResponse(sending sync.Locker, req *request, reply inte
resp.msg = msg
resp.msg.Id = req.msg.Id
sending.Lock()
err := cc.Write(resp.msg, reply)
sending.Unlock()
router.freeResponse(resp)
return err
}
@ -261,6 +260,7 @@ func (s *service) call(ctx context.Context, router *router, sending *sync.Mutex,
// Invoke the method, providing a new value for the reply.
fn := func(ctx context.Context, req Request, stream interface{}) error {
returnValues = function.Call([]reflect.Value{s.rcvr, mtype.prepareContext(ctx), reflect.ValueOf(stream)})
if err := returnValues[0].Interface(); err != nil {
// the function returned an error, we use that
return err.(error)
@ -288,11 +288,14 @@ func (m *methodType) prepareContext(ctx context.Context) reflect.Value {
if contextv := reflect.ValueOf(ctx); contextv.IsValid() {
return contextv
}
return reflect.Zero(m.ContextType)
}
func (router *router) getRequest() *request {
router.reqLock.Lock()
defer router.reqLock.Unlock()
req := router.freeReq
if req == nil {
req = new(request)
@ -300,19 +303,22 @@ func (router *router) getRequest() *request {
router.freeReq = req.next
*req = request{}
}
router.reqLock.Unlock()
return req
}
func (router *router) freeRequest(req *request) {
router.reqLock.Lock()
defer router.reqLock.Unlock()
req.next = router.freeReq
router.freeReq = req
router.reqLock.Unlock()
}
func (router *router) getResponse() *response {
router.respLock.Lock()
defer router.respLock.Unlock()
resp := router.freeResp
if resp == nil {
resp = new(response)
@ -320,15 +326,16 @@ func (router *router) getResponse() *response {
router.freeResp = resp.next
*resp = response{}
}
router.respLock.Unlock()
return resp
}
func (router *router) freeResponse(resp *response) {
router.respLock.Lock()
defer router.respLock.Unlock()
resp.next = router.freeResp
router.freeResp = resp
router.respLock.Unlock()
}
func (router *router) readRequest(r Request) (service *service, mtype *methodType, req *request, argv, replyv reflect.Value, keepReading bool, err error) {
@ -341,8 +348,10 @@ func (router *router) readRequest(r Request) (service *service, mtype *methodTyp
}
// discard body
cc.ReadBody(nil)
return
}
// is it a streaming request? then we don't read the body
if mtype.stream {
if cc.(codec.Codec).String() != "grpc" {
@ -359,10 +368,12 @@ func (router *router) readRequest(r Request) (service *service, mtype *methodTyp
argv = reflect.New(mtype.ArgType)
argIsValue = true
}
// argv guaranteed to be a pointer now.
if err = cc.ReadBody(argv.Interface()); err != nil {
return
}
if argIsValue {
argv = argv.Elem()
}
@ -370,6 +381,7 @@ func (router *router) readRequest(r Request) (service *service, mtype *methodTyp
if !mtype.stream {
replyv = reflect.New(mtype.ReplyType.Elem())
}
return
}
@ -387,6 +399,7 @@ func (router *router) readHeader(cc codec.Reader) (service *service, mtype *meth
return
}
err = errors.New("rpc: router cannot decode request: " + err.Error())
return
}
@ -399,28 +412,33 @@ func (router *router) readHeader(cc codec.Reader) (service *service, mtype *meth
err = errors.New("rpc: service/endpoint request ill-formed: " + req.msg.Endpoint)
return
}
// Look up the request.
router.mu.Lock()
service = router.serviceMap[serviceMethod[0]]
router.mu.Unlock()
if service == nil {
err = errors.New("rpc: can't find service " + serviceMethod[0])
return
}
mtype = service.method[serviceMethod[1]]
if mtype == nil {
err = errors.New("rpc: can't find method " + serviceMethod[1])
}
return
}
func (router *router) NewHandler(h interface{}, opts ...HandlerOption) Handler {
return newRpcHandler(h, opts...)
return NewRpcHandler(h, opts...)
}
func (router *router) Handle(h Handler) error {
router.mu.Lock()
defer router.mu.Unlock()
if router.serviceMap == nil {
router.serviceMap = make(map[string]*service)
}
@ -428,6 +446,7 @@ func (router *router) Handle(h Handler) error {
if len(h.Name()) == 0 {
return errors.New("rpc.Handle: handler has no name")
}
if !isExported(h.Name()) {
return errors.New("rpc.Handle: type " + h.Name() + " is not exported")
}
@ -460,6 +479,7 @@ func (router *router) Handle(h Handler) error {
// save handler
router.serviceMap[s.name] = s
return nil
}
@ -474,8 +494,10 @@ func (router *router) ServeRequest(ctx context.Context, r Request, rsp Response)
if req != nil {
router.freeRequest(req)
}
return err
}
return service.call(ctx, router, sending, mtype, req, argv, replyv, rsp.Codec())
}
@ -488,6 +510,7 @@ func (router *router) Subscribe(s Subscriber) error {
if !ok {
return fmt.Errorf("invalid subscriber: expected *subscriber")
}
if len(sub.handlers) == 0 {
return fmt.Errorf("invalid subscriber: no handler functions")
}
@ -517,10 +540,9 @@ func (router *router) ProcessMessage(ctx context.Context, msg Message) (err erro
}
}()
router.su.RLock()
// get the subscribers by topic
router.su.RLock()
subs, ok := router.subscribers[msg.Topic()]
// unlock since we only need to get the subs
router.su.RUnlock()
if !ok {
return nil