diff --git a/server/rpc_router.go b/server/rpc_router.go index 1ffb4963..c4591f33 100644 --- a/server/rpc_router.go +++ b/server/rpc_router.go @@ -202,6 +202,11 @@ func (s *service) call(ctx context.Context, router *router, sending *sync.Mutex, return nil } + // wrap the handler + for i := len(router.hdlrWrappers); i > 0; i-- { + fn = router.hdlrWrappers[i-1](fn) + } + // execute handler if err := fn(ctx, r, replyv.Interface()); err != nil { return err @@ -235,6 +240,11 @@ func (s *service) call(ctx context.Context, router *router, sending *sync.Mutex, } } + // wrap the handler + for i := len(router.hdlrWrappers); i > 0; i-- { + fn = router.hdlrWrappers[i-1](fn) + } + // client.Stream request r.stream = true diff --git a/server/rpc_server.go b/server/rpc_server.go index 8fe3a812..22fca2d9 100644 --- a/server/rpc_server.go +++ b/server/rpc_server.go @@ -35,6 +35,9 @@ type rpcServer struct { func newRpcServer(opts ...Option) Server { options := newOptions(opts...) + router := newRpcRouter() + router.hdlrWrappers = options.HdlrWrappers + return &rpcServer{ opts: options, router: DefaultRouter, @@ -45,6 +48,14 @@ func newRpcServer(opts ...Option) Server { } } +type rpcRouter struct { + h func(context.Context, Request, interface{}) error +} + +func (r rpcRouter) ServeRequest(ctx context.Context, req Request, rsp Response) error { + return r.h(ctx, req, rsp) +} + // ServeConn serves a single connection func (s *rpcServer) ServeConn(sock transport.Socket) { defer func() { @@ -143,24 +154,26 @@ func (s *rpcServer) ServeConn(sock transport.Socket) { } // set router - r := s.opts.Router + r := Router(s.router) - // if nil use default router - if s.opts.Router == nil { - r = s.router + // if not nil use the router specified + if s.opts.Router != nil { + // create a wrapped function + handler := func(ctx context.Context, req Request, rsp interface{}) error { + return s.opts.Router.ServeRequest(ctx, req, rsp.(Response)) + } + + // execute the wrapper for it + for i := len(s.opts.HdlrWrappers); i > 0; i-- { + handler = s.opts.HdlrWrappers[i-1](handler) + } + + // set the router + r = rpcRouter{handler} } - // create a wrapped function - handler := func(ctx context.Context, req Request, rsp interface{}) error { - return r.ServeRequest(ctx, req, rsp.(Response)) - } - - for i := len(s.opts.HdlrWrappers); i > 0; i-- { - handler = s.opts.HdlrWrappers[i-1](handler) - } - - // TODO: handle error better - if err := handler(ctx, request, response); err != nil { + // serve the actual request using the request router + if err := r.ServeRequest(ctx, request, response); err != nil { // write an error response err = rcodec.Write(&codec.Message{ Header: msg.Header, @@ -206,6 +219,15 @@ func (s *rpcServer) Init(opts ...Option) error { for _, opt := range opts { opt(&s.opts) } + + // update router if its the default + if s.opts.Router == nil { + r := newRpcRouter() + r.hdlrWrappers = s.opts.HdlrWrappers + r.serviceMap = s.router.serviceMap + s.router = r + } + s.Unlock() return nil }