diff --git a/client/client.go b/client/client.go index ad117412..fdff08f7 100644 --- a/client/client.go +++ b/client/client.go @@ -27,13 +27,13 @@ import ( type Client interface { NewPublication(topic string, msg interface{}) Publication - NewRequest(service, method string, req interface{}) Request - NewProtoRequest(service, method string, req interface{}) Request - NewJsonRequest(service, method string, req interface{}) Request + NewRequest(service, method string, req interface{}, reqOpts ...RequestOption) Request + NewProtoRequest(service, method string, req interface{}, reqOpts ...RequestOption) Request + NewJsonRequest(service, method string, req interface{}, reqOpts ...RequestOption) Request Call(ctx context.Context, req Request, rsp interface{}, opts ...CallOption) error CallRemote(ctx context.Context, addr string, req Request, rsp interface{}, opts ...CallOption) error - Stream(ctx context.Context, req Request, rspChan interface{}, opts ...CallOption) (Streamer, error) - StreamRemote(ctx context.Context, addr string, req Request, rspChan interface{}, opts ...CallOption) (Streamer, error) + Stream(ctx context.Context, req Request, opts ...CallOption) (Streamer, error) + StreamRemote(ctx context.Context, addr string, req Request, opts ...CallOption) (Streamer, error) Publish(ctx context.Context, p Publication, opts ...PublishOption) error } @@ -48,10 +48,15 @@ type Request interface { Method() string ContentType() string Request() interface{} + // indicates whether the request will be a streaming one rather than unary + Stream() bool } type Streamer interface { + Context() context.Context Request() Request + Send(interface{}) error + Recv(interface{}) error Error() error Close() error } @@ -59,6 +64,7 @@ type Streamer interface { type Option func(*options) type CallOption func(*callOptions) type PublishOption func(*publishOptions) +type RequestOption func(*requestOptions) var ( DefaultClient Client = newRpcClient() @@ -76,13 +82,13 @@ func CallRemote(ctx context.Context, address string, request Request, response i // Creates a streaming connection with a service and returns responses on the // channel passed in. It's upto the user to close the streamer. -func Stream(ctx context.Context, request Request, responseChan interface{}, opts ...CallOption) (Streamer, error) { - return DefaultClient.Stream(ctx, request, responseChan, opts...) +func Stream(ctx context.Context, request Request, opts ...CallOption) (Streamer, error) { + return DefaultClient.Stream(ctx, request, opts...) } // Creates a streaming connection to the address specified. -func StreamRemote(ctx context.Context, address string, request Request, responseChan interface{}, opts ...CallOption) (Streamer, error) { - return DefaultClient.StreamRemote(ctx, address, request, responseChan, opts...) +func StreamRemote(ctx context.Context, address string, request Request, opts ...CallOption) (Streamer, error) { + return DefaultClient.StreamRemote(ctx, address, request, opts...) } // Publishes a publication using the default client. Using the underlying broker @@ -103,16 +109,16 @@ func NewPublication(topic string, message interface{}) Publication { // Creates a new request using the default client. Content Type will // be set to the default within options and use the appropriate codec -func NewRequest(service, method string, request interface{}) Request { - return DefaultClient.NewRequest(service, method, request) +func NewRequest(service, method string, request interface{}, reqOpts ...RequestOption) Request { + return DefaultClient.NewRequest(service, method, request, reqOpts...) } // Creates a new protobuf request using the default client -func NewProtoRequest(service, method string, request interface{}) Request { - return DefaultClient.NewProtoRequest(service, method, request) +func NewProtoRequest(service, method string, request interface{}, reqOpts ...RequestOption) Request { + return DefaultClient.NewProtoRequest(service, method, request, reqOpts...) } // Creates a new json request using the default client -func NewJsonRequest(service, method string, request interface{}) Request { - return DefaultClient.NewJsonRequest(service, method, request) +func NewJsonRequest(service, method string, request interface{}, reqOpts ...RequestOption) Request { + return DefaultClient.NewJsonRequest(service, method, request, reqOpts...) } diff --git a/client/client_wrapper.go b/client/client_wrapper.go index ecbad74f..382747a3 100644 --- a/client/client_wrapper.go +++ b/client/client_wrapper.go @@ -36,3 +36,6 @@ Example usage: // Wrapper wraps a client and returns a client type Wrapper func(Client) Client + +// StreamWrapper wraps a Stream and returns the equivalent +type StreamWrapper func(Streamer) Streamer diff --git a/client/options.go b/client/options.go index ca0f6fc5..389b34b1 100644 --- a/client/options.go +++ b/client/options.go @@ -24,6 +24,10 @@ type callOptions struct { type publishOptions struct{} +type requestOptions struct { + stream bool +} + // Broker to be used for pub/sub func Broker(b broker.Broker) Option { return func(o *options) { @@ -80,3 +84,11 @@ func WithSelectOption(so selector.SelectOption) CallOption { o.selectOptions = append(o.selectOptions, so) } } + +// Request Options + +func StreamingRequest() RequestOption { + return func(o *requestOptions) { + o.stream = true + } +} diff --git a/client/rpc_client.go b/client/rpc_client.go index bf2abacf..97e41120 100644 --- a/client/rpc_client.go +++ b/client/rpc_client.go @@ -112,7 +112,7 @@ func (r *rpcClient) call(ctx context.Context, address string, request Request, r return client.Close() } -func (r *rpcClient) stream(ctx context.Context, address string, request Request, responseChan interface{}) (Streamer, error) { +func (r *rpcClient) stream(ctx context.Context, address string, req Request) (Streamer, error) { msg := &transport.Message{ Header: make(map[string]string), } @@ -124,9 +124,9 @@ func (r *rpcClient) stream(ctx context.Context, address string, request Request, } } - msg.Header["Content-Type"] = request.ContentType() + msg.Header["Content-Type"] = req.ContentType() - cf, err := r.newCodec(request.ContentType()) + cf, err := r.newCodec(req.ContentType()) if err != nil { return nil, errors.InternalServerError("go.micro.client", err.Error()) } @@ -136,14 +136,14 @@ func (r *rpcClient) stream(ctx context.Context, address string, request Request, return nil, errors.InternalServerError("go.micro.client", fmt.Sprintf("Error sending request: %v", err)) } - client := newClientWithCodec(newRpcPlusCodec(msg, c, cf)) - call := client.StreamGo(request.Service(), request.Method(), request.Request(), responseChan) + stream := &rpcStream{ + context: ctx, + request: req, + codec: newRpcPlusCodec(msg, c, cf), + } - return &rpcStream{ - request: request, - call: call, - client: client, - }, nil + err = stream.Send(req.Request()) + return stream, err } func (r *rpcClient) CallRemote(ctx context.Context, address string, request Request, response interface{}, opts ...CallOption) error { @@ -180,11 +180,11 @@ func (r *rpcClient) Call(ctx context.Context, request Request, response interfac return err } -func (r *rpcClient) StreamRemote(ctx context.Context, address string, request Request, responseChan interface{}, opts ...CallOption) (Streamer, error) { - return r.stream(ctx, address, request, responseChan) +func (r *rpcClient) StreamRemote(ctx context.Context, address string, request Request, opts ...CallOption) (Streamer, error) { + return r.stream(ctx, address, request) } -func (r *rpcClient) Stream(ctx context.Context, request Request, responseChan interface{}, opts ...CallOption) (Streamer, error) { +func (r *rpcClient) Stream(ctx context.Context, request Request, opts ...CallOption) (Streamer, error) { var copts callOptions for _, opt := range opts { opt(&copts) @@ -209,7 +209,7 @@ func (r *rpcClient) Stream(ctx context.Context, request Request, responseChan in address = fmt.Sprintf("%s:%d", address, node.Port) } - stream, err := r.stream(ctx, address, request, responseChan) + stream, err := r.stream(ctx, address, request) r.opts.selector.Mark(request.Service(), node, err) return stream, err } @@ -247,14 +247,14 @@ func (r *rpcClient) NewPublication(topic string, message interface{}) Publicatio func (r *rpcClient) NewProtoPublication(topic string, message interface{}) Publication { return newRpcPublication(topic, message, "application/octet-stream") } -func (r *rpcClient) NewRequest(service, method string, request interface{}) Request { - return newRpcRequest(service, method, request, r.opts.contentType) +func (r *rpcClient) NewRequest(service, method string, request interface{}, reqOpts ...RequestOption) Request { + return newRpcRequest(service, method, request, r.opts.contentType, reqOpts...) } -func (r *rpcClient) NewProtoRequest(service, method string, request interface{}) Request { - return newRpcRequest(service, method, request, "application/octet-stream") +func (r *rpcClient) NewProtoRequest(service, method string, request interface{}, reqOpts ...RequestOption) Request { + return newRpcRequest(service, method, request, "application/octet-stream", reqOpts...) } -func (r *rpcClient) NewJsonRequest(service, method string, request interface{}) Request { - return newRpcRequest(service, method, request, "application/json") +func (r *rpcClient) NewJsonRequest(service, method string, request interface{}, reqOpts ...RequestOption) Request { + return newRpcRequest(service, method, request, "application/json", reqOpts...) } diff --git a/client/rpc_codec.go b/client/rpc_codec.go index fd5df68b..74f6c4f2 100644 --- a/client/rpc_codec.go +++ b/client/rpc_codec.go @@ -63,6 +63,7 @@ func newRpcPlusCodec(req *transport.Message, client transport.Client, c codec.Ne } func (c *rpcPlusCodec) WriteRequest(req *request, body interface{}) error { + c.buf.wbuf.Reset() m := &codec.Message{ Id: req.Seq, Target: req.Service, diff --git a/client/rpc_request.go b/client/rpc_request.go index eb799193..5a5ec0df 100644 --- a/client/rpc_request.go +++ b/client/rpc_request.go @@ -5,14 +5,22 @@ type rpcRequest struct { method string contentType string request interface{} + opts requestOptions } -func newRpcRequest(service, method string, request interface{}, contentType string) Request { +func newRpcRequest(service, method string, request interface{}, contentType string, reqOpts ...RequestOption) Request { + var opts requestOptions + + for _, o := range reqOpts { + o(&opts) + } + return &rpcRequest{ service: service, method: method, request: request, contentType: contentType, + opts: opts, } } @@ -31,3 +39,7 @@ func (r *rpcRequest) Method() string { func (r *rpcRequest) Request() interface{} { return r.request } + +func (r *rpcRequest) Stream() bool { + return r.opts.stream +} diff --git a/client/rpc_stream.go b/client/rpc_stream.go index 9f64c588..837b97d2 100644 --- a/client/rpc_stream.go +++ b/client/rpc_stream.go @@ -1,19 +1,112 @@ package client +import ( + "errors" + "io" + "log" + "sync" + + "golang.org/x/net/context" +) + +// Implements the streamer interface type rpcStream struct { + sync.RWMutex + seq uint64 + closed bool + err error request Request - call *call - client *client + codec clientCodec + context context.Context +} + +func (r *rpcStream) Context() context.Context { + return r.context } func (r *rpcStream) Request() Request { return r.request } +func (r *rpcStream) Send(msg interface{}) error { + r.Lock() + defer r.Unlock() + + if r.closed { + r.err = errShutdown + return errShutdown + } + + seq := r.seq + r.seq++ + + req := request{ + Service: r.request.Service(), + Seq: seq, + ServiceMethod: r.request.Method(), + } + + if err := r.codec.WriteRequest(&req, msg); err != nil { + r.err = err + return err + } + return nil +} + +func (r *rpcStream) Recv(msg interface{}) error { + r.Lock() + defer r.Unlock() + + if r.closed { + r.err = errShutdown + return errShutdown + } + + var resp response + if err := r.codec.ReadResponseHeader(&resp); err != nil { + if err == io.EOF && !r.closed { + r.err = io.ErrUnexpectedEOF + return io.ErrUnexpectedEOF + } + r.err = err + return err + } + + switch { + case len(resp.Error) > 0: + // We've got an error response. Give this to the request; + // any subsequent requests will get the ReadResponseBody + // error if there is one. + if resp.Error != lastStreamResponseError { + r.err = serverError(resp.Error) + } else { + r.err = io.EOF + } + if err := r.codec.ReadResponseBody(nil); err != nil { + r.err = errors.New("reading error payload: " + err.Error()) + } + default: + if err := r.codec.ReadResponseBody(msg); err != nil { + r.err = errors.New("reading body " + err.Error()) + } + } + + if r.err != nil && r.err != io.EOF && !r.closed { + log.Println("rpc: client protocol error:", r.err) + } + + return r.err +} + func (r *rpcStream) Error() error { - return r.call.Error + r.RLock() + defer r.RUnlock() + return r.err } func (r *rpcStream) Close() error { - return r.client.Close() + r.Lock() + defer r.Unlock() + r.closed = true + return r.codec.Close() } diff --git a/client/rpcplus_client.go b/client/rpcplus_client.go index f119cc3e..5325474d 100644 --- a/client/rpcplus_client.go +++ b/client/rpcplus_client.go @@ -8,7 +8,6 @@ import ( "errors" "io" "log" - "reflect" "sync" "github.com/youtube/vitess/go/trace" @@ -38,7 +37,6 @@ type call struct { Reply interface{} // The reply from the function (*struct for single, chan * struct for streaming). Error error // After completion, the error status. Done chan *call // Strobes when call is complete (nil for streaming RPCs) - Stream bool // True for a streaming RPC call, false otherwise Subseq uint64 // The next expected subseq in the packets } @@ -145,28 +143,12 @@ func (client *client) input() { // We've got an error response. Give this to the request; // any subsequent requests will get the ReadResponseBody // error if there is one. - if !(call.Stream && resp.Error == lastStreamResponseError) { - call.Error = serverError(resp.Error) - } + call.Error = serverError(resp.Error) err = client.codec.ReadResponseBody(nil) if err != nil { err = errors.New("reading error payload: " + err.Error()) } client.done(seq) - case call.Stream: - // call.Reply is a chan *T2 - // we need to create a T2 and get a *T2 back - value := reflect.New(reflect.TypeOf(call.Reply).Elem().Elem()).Interface() - err = client.codec.ReadResponseBody(value) - if err != nil { - call.Error = errors.New("reading body " + err.Error()) - } else { - // writing on the channel could block forever. For - // instance, if a client calls 'close', this might block - // forever. the current suggestion is for the - // client to drain the receiving channel in that case - reflect.ValueOf(call.Reply).Send(reflect.ValueOf(value)) - } default: err = client.codec.ReadResponseBody(call.Reply) if err != nil { @@ -203,12 +185,6 @@ func (client *client) done(seq uint64) { } func (call *call) done() { - if call.Stream { - // need to close the channel. client won't be able to read any more. - reflect.ValueOf(call.Reply).Close() - return - } - select { case call.Done <- call: // ok @@ -270,28 +246,6 @@ func (client *client) Go(ctx context.Context, service, serviceMethod string, arg return cal } -// StreamGo invokes the streaming function asynchronously. It returns the call structure representing -// the invocation. -func (client *client) StreamGo(service string, serviceMethod string, args interface{}, replyStream interface{}) *call { - // first check the replyStream object is a stream of pointers to a data structure - typ := reflect.TypeOf(replyStream) - // FIXME: check the direction of the channel, maybe? - if typ.Kind() != reflect.Chan || typ.Elem().Kind() != reflect.Ptr { - log.Panic("rpc: replyStream is not a channel of pointers") - return nil - } - - call := new(call) - call.Service = service - call.ServiceMethod = serviceMethod - call.Args = args - call.Reply = replyStream - call.Stream = true - call.Subseq = 0 - client.send(call) - return call -} - // call invokes the named function, waits for it to complete, and returns its error status. func (client *client) Call(ctx context.Context, service string, serviceMethod string, args interface{}, reply interface{}) error { call := <-client.Go(ctx, service, serviceMethod, args, reply, make(chan *call, 1)).Done diff --git a/examples/client/codegen/codegen.go b/examples/client/codegen/codegen.go index 0289dd2d..aaa0db9e 100644 --- a/examples/client/codegen/codegen.go +++ b/examples/client/codegen/codegen.go @@ -21,14 +21,14 @@ func call(i int) { fmt.Println("Call:", i, "rsp:", rsp.Msg) } -func stream() { - stream, err := cl.Stream(context.Background(), &example.StreamingRequest{Count: int64(10)}) +func stream(i int) { + stream, err := cl.Stream(context.Background(), &example.StreamingRequest{Count: int64(i)}) if err != nil { fmt.Println("err:", err) return } - for i := 0; i < 10; i++ { - rsp, err := stream.Next() + for j := 0; j < i; j++ { + rsp, err := stream.RecvR() if err != nil { fmt.Println("err:", err) break @@ -44,6 +44,34 @@ func stream() { } } +func pingPong(i int) { + stream, err := cl.PingPong(context.Background()) + if err != nil { + fmt.Println("err:", err) + return + } + for j := 0; j < i; j++ { + if err := stream.SendR(&example.Ping{Stroke: int64(j)}); err != nil { + fmt.Println("err:", err) + return + } + rsp, err := stream.RecvR() + if err != nil { + fmt.Println("recv err", err) + break + } + fmt.Printf("Sent ping %v got pong %v\n", j, rsp.Stroke) + } + if stream.Error() != nil { + fmt.Println("stream err:", err) + return + } + + if err := stream.Close(); err != nil { + fmt.Println("stream close err:", err) + } +} + func main() { cmd.Init() @@ -51,6 +79,10 @@ func main() { for i := 0; i < 10; i++ { call(i) } + fmt.Println("\n--- Streamer example ---\n") - stream() + stream(10) + + fmt.Println("\n--- Ping Pong example ---\n") + pingPong(10) } diff --git a/examples/client/main.go b/examples/client/main.go index 369553a8..2fb229e5 100644 --- a/examples/client/main.go +++ b/examples/client/main.go @@ -54,22 +54,63 @@ func call(i int) { fmt.Println("Call:", i, "rsp:", rsp.Msg) } -func stream() { +func stream(i int) { // Create new request to service go.micro.srv.example, method Example.Call - req := client.NewRequest("go.micro.srv.example", "Example.Stream", &example.StreamingRequest{ - Count: int64(10), - }) + // Request can be empty as its actually ignored and merely used to call the handler + req := client.NewRequest("go.micro.srv.example", "Example.Stream", &example.StreamingRequest{}) - rspChan := make(chan *example.StreamingResponse, 10) + stream, err := client.Stream(context.Background(), req) + if err != nil { + fmt.Println("err:", err) + return + } + if err := stream.Send(&example.StreamingRequest{Count: int64(i)}); err != nil { + fmt.Println("err:", err) + return + } + for stream.Error() == nil { + rsp := &example.StreamingResponse{} + err := stream.Recv(rsp) + if err != nil { + fmt.Println("recv err", err) + break + } + fmt.Println("Stream: rsp:", rsp.Count) + } - stream, err := client.Stream(context.Background(), req, rspChan) + if stream.Error() != nil { + fmt.Println("stream err:", err) + return + } + + if err := stream.Close(); err != nil { + fmt.Println("stream close err:", err) + } +} + +func pingPong(i int) { + // Create new request to service go.micro.srv.example, method Example.Call + // Request can be empty as its actually ignored and merely used to call the handler + req := client.NewRequest("go.micro.srv.example", "Example.PingPong", &example.StreamingRequest{}) + + stream, err := client.Stream(context.Background(), req) if err != nil { fmt.Println("err:", err) return } - for rsp := range rspChan { - fmt.Println("Stream: rsp:", rsp.Count) + for j := 0; j < i; j++ { + if err := stream.Send(&example.Ping{Stroke: int64(j + 1)}); err != nil { + fmt.Println("err:", err) + return + } + rsp := &example.Pong{} + err := stream.Recv(rsp) + if err != nil { + fmt.Println("recv err", err) + break + } + fmt.Printf("Sent ping %v got pong %v\n", j+1, rsp.Stroke) } if stream.Error() != nil { @@ -90,7 +131,10 @@ func main() { } fmt.Println("\n--- Streamer example ---\n") - stream() + stream(10) + + fmt.Println("\n--- Ping Pong example ---\n") + pingPong(10) fmt.Println("\n--- Publisher example ---\n") pub() diff --git a/examples/server/handler/example.go b/examples/server/handler/example.go index 6e84cf72..d109ac60 100644 --- a/examples/server/handler/example.go +++ b/examples/server/handler/example.go @@ -18,19 +18,41 @@ func (e *Example) Call(ctx context.Context, req *example.Request, rsp *example.R return nil } -func (e *Example) Stream(ctx context.Context, req *example.StreamingRequest, response func(interface{}) error) error { +func (e *Example) Stream(ctx context.Context, stream server.Streamer) error { + log.Info("Executing streaming handler") + req := &example.StreamingRequest{} + + // We just want to receive 1 request and then process here + if err := stream.Recv(req); err != nil { + log.Errorf("Error receiving streaming request: %v", err) + return err + } + log.Infof("Received Example.Stream request with count: %d", req.Count) + for i := 0; i < int(req.Count); i++ { log.Infof("Responding: %d", i) - r := &example.StreamingResponse{ + if err := stream.Send(&example.StreamingResponse{ Count: int64(i), - } - - if err := response(r); err != nil { + }); err != nil { return err } } return nil } + +func (e *Example) PingPong(ctx context.Context, stream server.Streamer) error { + for { + req := &example.Ping{} + if err := stream.Recv(req); err != nil { + return err + } + log.Infof("Got ping %v", req.Stroke) + if err := stream.Send(&example.Pong{Stroke: req.Stroke}); err != nil { + return err + } + } + return nil +} diff --git a/examples/server/proto/example/example.pb.go b/examples/server/proto/example/example.pb.go index 3d7e141d..f48bff5a 100644 --- a/examples/server/proto/example/example.pb.go +++ b/examples/server/proto/example/example.pb.go @@ -14,6 +14,8 @@ It has these top-level messages: Response StreamingRequest StreamingResponse + Ping + Pong */ package go_micro_srv_example @@ -77,12 +79,32 @@ func (m *StreamingResponse) String() string { return proto.CompactTex func (*StreamingResponse) ProtoMessage() {} func (*StreamingResponse) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{4} } +type Ping struct { + Stroke int64 `protobuf:"varint,1,opt,name=stroke" json:"stroke,omitempty"` +} + +func (m *Ping) Reset() { *m = Ping{} } +func (m *Ping) String() string { return proto.CompactTextString(m) } +func (*Ping) ProtoMessage() {} +func (*Ping) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{5} } + +type Pong struct { + Stroke int64 `protobuf:"varint,1,opt,name=stroke" json:"stroke,omitempty"` +} + +func (m *Pong) Reset() { *m = Pong{} } +func (m *Pong) String() string { return proto.CompactTextString(m) } +func (*Pong) ProtoMessage() {} +func (*Pong) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{6} } + func init() { proto.RegisterType((*Message)(nil), "go.micro.srv.example.Message") proto.RegisterType((*Request)(nil), "go.micro.srv.example.Request") proto.RegisterType((*Response)(nil), "go.micro.srv.example.Response") proto.RegisterType((*StreamingRequest)(nil), "go.micro.srv.example.StreamingRequest") proto.RegisterType((*StreamingResponse)(nil), "go.micro.srv.example.StreamingResponse") + proto.RegisterType((*Ping)(nil), "go.micro.srv.example.Ping") + proto.RegisterType((*Pong)(nil), "go.micro.srv.example.Pong") } // Reference imports to suppress errors if they are not otherwise used. @@ -95,6 +117,7 @@ var _ server.Option type ExampleClient interface { Call(ctx context.Context, in *Request, opts ...client.CallOption) (*Response, error) Stream(ctx context.Context, in *StreamingRequest, opts ...client.CallOption) (Example_StreamClient, error) + PingPong(ctx context.Context, opts ...client.CallOption) (Example_PingPongClient, error) } type exampleClient struct { @@ -126,59 +149,151 @@ func (c *exampleClient) Call(ctx context.Context, in *Request, opts ...client.Ca } func (c *exampleClient) Stream(ctx context.Context, in *StreamingRequest, opts ...client.CallOption) (Example_StreamClient, error) { - req := c.c.NewRequest(c.serviceName, "Example.Stream", in) - outCh := make(chan *StreamingResponse) - stream, err := c.c.Stream(ctx, req, outCh, opts...) + req := c.c.NewRequest(c.serviceName, "Example.Stream", &StreamingRequest{}) + stream, err := c.c.Stream(ctx, req, opts...) if err != nil { return nil, err } - return &exampleStreamClient{stream, outCh}, nil + if err := stream.Send(in); err != nil { + return nil, err + } + return &exampleStreamClient{stream}, nil } type Example_StreamClient interface { - Next() (*StreamingResponse, error) + RecvR() (*StreamingResponse, error) client.Streamer } type exampleStreamClient struct { client.Streamer - next chan *StreamingResponse } -func (x *exampleStreamClient) Next() (*StreamingResponse, error) { - out, ok := <-x.next - if !ok { - return nil, fmt.Errorf(`chan closed`) +func (x *exampleStreamClient) RecvR() (*StreamingResponse, error) { + m := new(StreamingResponse) + err := x.Recv(m) + if err != nil { + return nil, err } - return out, nil + return m, nil +} + +func (c *exampleClient) PingPong(ctx context.Context, opts ...client.CallOption) (Example_PingPongClient, error) { + req := c.c.NewRequest(c.serviceName, "Example.PingPong", &Ping{}) + stream, err := c.c.Stream(ctx, req, opts...) + if err != nil { + return nil, err + } + return &examplePingPongClient{stream}, nil +} + +type Example_PingPongClient interface { + SendR(*Ping) error + RecvR() (*Pong, error) + client.Streamer +} + +type examplePingPongClient struct { + client.Streamer +} + +func (x *examplePingPongClient) SendR(m *Ping) error { + return x.Send(m) +} + +func (x *examplePingPongClient) RecvR() (*Pong, error) { + m := new(Pong) + err := x.Recv(m) + if err != nil { + return nil, err + } + return m, nil } // Server API for Example service type ExampleHandler interface { Call(context.Context, *Request, *Response) error - Stream(context.Context, func(*StreamingResponse) error) error + Stream(context.Context, *StreamingRequest, Example_StreamStream) error + PingPong(context.Context, Example_PingPongStream) error } func RegisterExampleHandler(s server.Server, hdlr ExampleHandler) { - s.Handle(s.NewHandler(hdlr)) + s.Handle(s.NewHandler(&exampleHandler{hdlr})) +} + +type exampleHandler struct { + ExampleHandler +} + +func (h *exampleHandler) Call(ctx context.Context, in *Request, out *Response) error { + return h.ExampleHandler.Call(ctx, in, out) +} + +func (h *exampleHandler) Stream(ctx context.Context, stream server.Streamer) error { + m := new(StreamingRequest) + if err := stream.Recv(m); err != nil { + return err + } + return h.ExampleHandler.Stream(ctx, m, &exampleStreamStream{stream}) +} + +type Example_StreamStream interface { + SendR(*StreamingResponse) error + server.Streamer +} + +type exampleStreamStream struct { + server.Streamer +} + +func (x *exampleStreamStream) SendR(m *StreamingResponse) error { + return x.Streamer.Send(m) +} + +func (h *exampleHandler) PingPong(ctx context.Context, stream server.Streamer) error { + return h.ExampleHandler.PingPong(ctx, &examplePingPongStream{stream}) +} + +type Example_PingPongStream interface { + SendR(*Pong) error + RecvR() (*Ping, error) + server.Streamer +} + +type examplePingPongStream struct { + server.Streamer +} + +func (x *examplePingPongStream) SendR(m *Pong) error { + return x.Streamer.Send(m) +} + +func (x *examplePingPongStream) RecvR() (*Ping, error) { + m := new(Ping) + if err := x.Streamer.Recv(m); err != nil { + return nil, err + } + return m, nil } var fileDescriptor0 = []byte{ - // 230 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0x84, 0x90, 0xcd, 0x4a, 0xc5, 0x30, - 0x10, 0x85, 0x0d, 0xf7, 0x7a, 0xab, 0xa3, 0x82, 0x06, 0x51, 0x29, 0x28, 0x9a, 0x85, 0xba, 0x31, - 0x15, 0xf5, 0x0d, 0x44, 0x5c, 0xb9, 0xa9, 0x6b, 0x17, 0xb1, 0x0c, 0xa1, 0xd0, 0x24, 0x35, 0x93, - 0x16, 0x7d, 0x2c, 0xdf, 0x50, 0x48, 0xd3, 0xa2, 0x52, 0x71, 0x15, 0x98, 0xf3, 0x9d, 0x1f, 0x02, - 0x77, 0xda, 0x5d, 0x99, 0xba, 0xf2, 0xae, 0xc0, 0x77, 0x65, 0xda, 0x06, 0xa9, 0x20, 0xf4, 0x3d, - 0xfa, 0xa2, 0xf5, 0x2e, 0x4c, 0xd7, 0xf1, 0x95, 0xf1, 0xca, 0xf7, 0xb5, 0x93, 0xd1, 0x25, 0xc9, - 0xf7, 0x32, 0x69, 0xe2, 0x00, 0xb2, 0x27, 0x24, 0x52, 0x1a, 0xf9, 0x16, 0x2c, 0x48, 0x7d, 0x1c, - 0xb1, 0x53, 0x76, 0xb9, 0x29, 0x0e, 0x21, 0x2b, 0xf1, 0xad, 0x43, 0x0a, 0x7c, 0x1b, 0x96, 0x56, - 0x19, 0x9c, 0x84, 0x8d, 0x12, 0xa9, 0x75, 0x96, 0xa2, 0xc3, 0x90, 0x4e, 0xc2, 0x19, 0xec, 0x3e, - 0x07, 0x8f, 0xca, 0xd4, 0x56, 0x8f, 0xd6, 0x1d, 0x58, 0xaf, 0x5c, 0x67, 0x43, 0x44, 0x16, 0x42, - 0xc0, 0xde, 0x37, 0x24, 0x85, 0xfc, 0x64, 0x6e, 0x3e, 0x19, 0x64, 0x0f, 0xc3, 0x38, 0xfe, 0x08, - 0xcb, 0x7b, 0xd5, 0x34, 0xfc, 0x58, 0xce, 0x6d, 0x97, 0xa9, 0x25, 0x3f, 0xf9, 0x4b, 0x1e, 0x1a, - 0xc4, 0x1a, 0x7f, 0x81, 0xd5, 0x50, 0xcc, 0xcf, 0xe7, 0xd9, 0xdf, 0xcb, 0xf3, 0x8b, 0x7f, 0xb9, - 0x31, 0xfc, 0x9a, 0xbd, 0xae, 0xe2, 0x0f, 0xdf, 0x7e, 0x05, 0x00, 0x00, 0xff, 0xff, 0x63, 0x02, - 0xbf, 0x5f, 0x99, 0x01, 0x00, 0x00, + // 270 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0x84, 0x91, 0x5f, 0x4b, 0xc3, 0x30, + 0x14, 0xc5, 0x17, 0x56, 0xdb, 0x79, 0xfd, 0x83, 0x06, 0x99, 0x52, 0x50, 0x34, 0x0f, 0xba, 0x17, + 0xd3, 0xa1, 0x7e, 0x03, 0x11, 0x7d, 0x11, 0x64, 0x3e, 0xfb, 0x10, 0xc7, 0x25, 0x0c, 0x9b, 0xa6, + 0xe6, 0x66, 0x43, 0x3f, 0xbb, 0x2f, 0x6e, 0x69, 0x3b, 0xc6, 0xec, 0xf0, 0x29, 0x70, 0x7e, 0xe7, + 0x5c, 0xce, 0x21, 0x70, 0xa7, 0xed, 0xb5, 0x99, 0x8c, 0x9d, 0xcd, 0xf0, 0x4b, 0x99, 0x32, 0x47, + 0xca, 0x08, 0xdd, 0x0c, 0x5d, 0x56, 0x3a, 0xeb, 0x97, 0x6a, 0xf3, 0xca, 0xa0, 0xf2, 0x23, 0x6d, + 0x65, 0x48, 0x49, 0x72, 0x33, 0x59, 0x33, 0xd1, 0x87, 0xe4, 0x19, 0x89, 0x94, 0x46, 0xbe, 0x03, + 0x5d, 0x52, 0xdf, 0x27, 0xec, 0x9c, 0x0d, 0xb6, 0xc5, 0x31, 0x24, 0x23, 0xfc, 0x9c, 0x22, 0x79, + 0xbe, 0x0b, 0x51, 0xa1, 0x0c, 0x2e, 0x41, 0x6f, 0x84, 0x54, 0xda, 0x82, 0x42, 0xc2, 0x90, 0xae, + 0xc1, 0x05, 0x1c, 0xbc, 0x7a, 0x87, 0xca, 0x4c, 0x0a, 0xdd, 0x44, 0xf7, 0x60, 0x6b, 0x6c, 0xa7, + 0x85, 0x0f, 0x96, 0xae, 0x10, 0x70, 0xb8, 0x62, 0xa9, 0x8f, 0xac, 0x79, 0xfa, 0x10, 0xbd, 0xcc, + 0x31, 0xdf, 0x87, 0x98, 0xbc, 0xb3, 0x1f, 0xb8, 0xa2, 0xdb, 0xbf, 0xfa, 0xcd, 0x0f, 0x83, 0xe4, + 0xa1, 0x1a, 0xc3, 0x1f, 0x21, 0xba, 0x57, 0x79, 0xce, 0x4f, 0x65, 0xdb, 0x56, 0x59, 0xb7, 0x4a, + 0xcf, 0x36, 0xe1, 0xaa, 0x91, 0xe8, 0xf0, 0x37, 0x88, 0xab, 0xa2, 0xfc, 0xb2, 0xdd, 0xbb, 0xbe, + 0x34, 0xbd, 0xfa, 0xd7, 0xd7, 0x1c, 0x1f, 0x32, 0xfe, 0x04, 0xbd, 0xc5, 0xc6, 0xb0, 0x27, 0x6d, + 0x0f, 0x2e, 0x78, 0xba, 0x89, 0xcd, 0x73, 0xa2, 0x33, 0x60, 0x43, 0xf6, 0x1e, 0x87, 0xbf, 0xbd, + 0xfd, 0x0d, 0x00, 0x00, 0xff, 0xff, 0x53, 0xb5, 0xeb, 0x31, 0x13, 0x02, 0x00, 0x00, } diff --git a/examples/server/proto/example/example.proto b/examples/server/proto/example/example.proto index b7dd8748..48c687e4 100644 --- a/examples/server/proto/example/example.proto +++ b/examples/server/proto/example/example.proto @@ -5,6 +5,7 @@ package go.micro.srv.example; service Example { rpc Call(Request) returns (Response) {} rpc Stream(StreamingRequest) returns (stream StreamingResponse) {} + rpc PingPong(stream Ping) returns (stream Pong) {} } message Message { @@ -26,3 +27,11 @@ message StreamingRequest { message StreamingResponse { int64 count = 1; } + +message Ping { + int64 stroke = 1; +} + +message Pong { + int64 stroke = 1; +} diff --git a/server/rpc_codec.go b/server/rpc_codec.go index 3aeb963f..69987924 100644 --- a/server/rpc_codec.go +++ b/server/rpc_codec.go @@ -60,9 +60,20 @@ func newRpcPlusCodec(req *transport.Message, socket transport.Socket, c codec.Ne return r } -func (c *rpcPlusCodec) ReadRequestHeader(r *request) error { - m := codec.Message{ - Headers: c.req.Header, +func (c *rpcPlusCodec) ReadRequestHeader(r *request, first bool) error { + m := codec.Message{Headers: c.req.Header} + + if !first { + var tm transport.Message + if err := c.socket.Recv(&tm); err != nil { + return err + } + c.buf.rbuf.Reset() + if _, err := c.buf.rbuf.Write(tm.Body); err != nil { + return err + } + + m.Headers = tm.Header } err := c.codec.ReadHeader(&m, codec.Request) diff --git a/server/rpc_stream.go b/server/rpc_stream.go new file mode 100644 index 00000000..1819c011 --- /dev/null +++ b/server/rpc_stream.go @@ -0,0 +1,78 @@ +package server + +import ( + "log" + "sync" + + "golang.org/x/net/context" +) + +// Implements the Streamer interface +type rpcStream struct { + sync.RWMutex + seq uint64 + closed bool + err error + request Request + codec serverCodec + context context.Context +} + +func (r *rpcStream) Context() context.Context { + return r.context +} + +func (r *rpcStream) Request() Request { + return r.request +} + +func (r *rpcStream) Send(msg interface{}) error { + r.Lock() + defer r.Unlock() + + seq := r.seq + r.seq++ + + resp := response{ + ServiceMethod: r.request.Method(), + Seq: seq, + } + + err := r.codec.WriteResponse(&resp, msg, false) + if err != nil { + log.Println("rpc: writing response:", err) + } + return err +} + +func (r *rpcStream) Recv(msg interface{}) error { + r.Lock() + defer r.Unlock() + + req := request{} + + if err := r.codec.ReadRequestHeader(&req, false); err != nil { + // discard body + r.codec.ReadRequestBody(nil) + return err + } + + if err := r.codec.ReadRequestBody(msg); err != nil { + return err + } + + return nil +} + +func (r *rpcStream) Error() error { + r.RLock() + defer r.RUnlock() + return r.err +} + +func (r *rpcStream) Close() error { + r.Lock() + defer r.Unlock() + r.closed = true + return r.codec.Close() +} diff --git a/server/rpcplus_server.go b/server/rpcplus_server.go index 62fca91c..2a270a7a 100644 --- a/server/rpcplus_server.go +++ b/server/rpcplus_server.go @@ -102,14 +102,19 @@ func prepareMethod(method reflect.Method) *methodType { mtype := method.Type mname := method.Name var replyType, argType, contextType reflect.Type + var stream bool - stream := false // Method must be exported. if method.PkgPath != "" { return nil } switch mtype.NumIn() { + case 3: + // assuming streaming + argType = mtype.In(2) + contextType = mtype.In(1) + stream = true case 4: // method that takes a context argType = mtype.In(2) @@ -120,44 +125,34 @@ func prepareMethod(method reflect.Method) *methodType { return nil } - // First arg need not be a pointer. - if !isExportedOrBuiltinType(argType) { - log.Println(mname, "argument type not exported:", argType) - return nil - } + if stream { + // check stream type + streamType := reflect.TypeOf((*Streamer)(nil)).Elem() + if !argType.Implements(streamType) { + log.Println(mname, "argument does not implement Streamer interface:", argType) + return nil + } + } else { + // if not stream check the replyType - // the second argument will tell us if it's a streaming call - // or a regular call - if replyType.Kind() == reflect.Func { - // this is a streaming call - stream = true - if replyType.NumIn() != 1 { - log.Println("method", mname, "sendReply has wrong number of ins:", replyType.NumIn()) - return nil - } - if replyType.In(0).Kind() != reflect.Interface { - log.Println("method", mname, "sendReply parameter type not an interface:", replyType.In(0)) - return nil - } - if replyType.NumOut() != 1 { - log.Println("method", mname, "sendReply has wrong number of outs:", replyType.NumOut()) - return nil - } - if returnType := replyType.Out(0); returnType != typeOfError { - log.Println("method", mname, "sendReply returns", returnType.String(), "not error") + // First arg need not be a pointer. + if !isExportedOrBuiltinType(argType) { + log.Println(mname, "argument type not exported:", argType) return nil } - } else if replyType.Kind() != reflect.Ptr { - log.Println("method", mname, "reply type not a pointer:", replyType) - return nil + if replyType.Kind() != reflect.Ptr { + log.Println("method", mname, "reply type not a pointer:", replyType) + return nil + } + + // Reply type must be exported. + if !isExportedOrBuiltinType(replyType) { + log.Println("method", mname, "reply type not exported:", replyType) + return nil + } } - // Reply type must be exported. - if !isExportedOrBuiltinType(replyType) { - log.Println("method", mname, "reply type not exported:", replyType) - return nil - } // Method needs one out. if mtype.NumOut() != 1 { log.Println("method", mname, "has wrong number of outs:", mtype.NumOut()) @@ -242,10 +237,11 @@ func (s *service) call(ctx context.Context, server *server, sending *sync.Mutex, service: s.name, contentType: ct, method: req.ServiceMethod, - request: argv.Interface(), } if !mtype.stream { + r.request = argv.Interface() + fn := func(ctx context.Context, req Request, rsp interface{}) error { returnValues = function.Call([]reflect.Value{s.rcvr, mtype.prepareContext(ctx), reflect.ValueOf(req.Request()), reflect.ValueOf(rsp)}) @@ -276,40 +272,16 @@ func (s *service) call(ctx context.Context, server *server, sending *sync.Mutex, // keep track of the type, to make sure we return // the same one consistently var lastError error - var firstType reflect.Type - sendReply := func(oneReply interface{}) error { - - // we already triggered an error, we're done - if lastError != nil { - return lastError - } - - // check the oneReply has the right type using reflection - typ := reflect.TypeOf(oneReply) - if firstType == nil { - firstType = typ - } else { - if firstType != typ { - log.Println("passing wrong type to sendReply", - firstType, "!=", typ) - lastError = errors.New("rpc: passing wrong type to sendReply") - return lastError - } - } - - lastError = server.sendResponse(sending, req, oneReply, codec, "", false) - if lastError != nil { - return lastError - } - - // we manage to send, we're good - return nil + stream := &rpcStream{ + context: ctx, + codec: codec, + request: r, } // Invoke the method, providing a new value for the reply. - fn := func(ctx context.Context, req Request, rspFn interface{}) error { - returnValues = function.Call([]reflect.Value{s.rcvr, mtype.prepareContext(ctx), reflect.ValueOf(req.Request()), reflect.ValueOf(rspFn)}) + fn := func(ctx context.Context, req Request, stream interface{}) error { + returnValues = function.Call([]reflect.Value{s.rcvr, mtype.prepareContext(ctx), reflect.ValueOf(stream)}) if err := returnValues[0].Interface(); err != nil { // the function returned an error, we use that return err.(error) @@ -331,7 +303,7 @@ func (s *service) call(ctx context.Context, server *server, sending *sync.Mutex, r.stream = true errmsg := "" - if err := fn(ctx, r, reflect.ValueOf(sendReply).Interface()); err != nil { + if err := fn(ctx, r, stream); err != nil { errmsg = err.Error() } @@ -417,6 +389,11 @@ func (server *server) readRequest(codec serverCodec) (service *service, mtype *m codec.ReadRequestBody(nil) return } + // is it a streaming request? then we don't read the body + if mtype.stream { + codec.ReadRequestBody(nil) + return + } // Decode the argument value. argIsValue := false // if true, need to indirect before calling. @@ -443,7 +420,7 @@ func (server *server) readRequest(codec serverCodec) (service *service, mtype *m func (server *server) readRequestHeader(codec serverCodec) (service *service, mtype *methodType, req *request, keepReading bool, err error) { // Grab the request header. req = server.getRequest() - err = codec.ReadRequestHeader(req) + err = codec.ReadRequestHeader(req, true) if err != nil { req = nil if err == io.EOF || err == io.ErrUnexpectedEOF { @@ -478,7 +455,7 @@ func (server *server) readRequestHeader(codec serverCodec) (service *service, mt } type serverCodec interface { - ReadRequestHeader(*request) error + ReadRequestHeader(*request, bool) error ReadRequestBody(interface{}) error WriteResponse(*response, interface{}, bool) error diff --git a/server/server.go b/server/server.go index 2a197bff..7044da0d 100644 --- a/server/server.go +++ b/server/server.go @@ -35,6 +35,7 @@ import ( log "github.com/golang/glog" "github.com/pborman/uuid" + "golang.org/x/net/context" ) type Server interface { @@ -61,10 +62,23 @@ type Request interface { Method() string ContentType() string Request() interface{} - // indicates whether the response should be streaming + // indicates whether the request will be streamed Stream() bool } +// Streamer represents a stream established with a client. +// A stream can be bidirectional which is indicated by the request. +// The last error will be left in Error(). +// EOF indicated end of the stream. +type Streamer interface { + Context() context.Context + Request() Request + Send(interface{}) error + Recv(interface{}) error + Error() error + Close() error +} + type Option func(*options) var ( diff --git a/server/server_wrapper.go b/server/server_wrapper.go index c6c4303f..45d2c46c 100644 --- a/server/server_wrapper.go +++ b/server/server_wrapper.go @@ -19,3 +19,9 @@ type HandlerWrapper func(HandlerFunc) HandlerFunc // SubscriberWrapper wraps the SubscriberFunc and returns the equivalent type SubscriberWrapper func(SubscriberFunc) SubscriberFunc + +// StreamerWrapper wraps a Streamer interface and returns the equivalent. +// Because streams exist for the lifetime of a method invocation this +// is a convenient way to wrap a Stream as its in use for trace, monitoring, +// metrics, etc. +type StreamerWrapper func(Streamer) Streamer diff --git a/transport/http_transport.go b/transport/http_transport.go index 91527e57..c2a2d770 100644 --- a/transport/http_transport.go +++ b/transport/http_transport.go @@ -23,16 +23,21 @@ type httpTransportClient struct { addr string conn net.Conn dialOpts dialOptions - r chan *http.Request once sync.Once sync.Mutex + r chan *http.Request + bl []*http.Request buff *bufio.Reader } type httpTransportSocket struct { - r *http.Request + r chan *http.Request conn net.Conn + once sync.Once + + sync.Mutex + buff *bufio.Reader } type httpTransportListener struct { @@ -68,7 +73,14 @@ func (h *httpTransportClient) Send(m *Message) error { Host: h.addr, } - h.r <- req + h.Lock() + h.bl = append(h.bl, req) + select { + case h.r <- h.bl[0]: + h.bl = h.bl[1:] + default: + } + h.Unlock() return req.Write(h.conn) } @@ -134,17 +146,23 @@ func (h *httpTransportSocket) Recv(m *Message) error { return errors.New("message passed in is nil") } - b, err := ioutil.ReadAll(h.r.Body) + r, err := http.ReadRequest(h.buff) if err != nil { return err } - h.r.Body.Close() + + b, err := ioutil.ReadAll(r.Body) + if err != nil { + return err + } + r.Body.Close() + mr := &Message{ Header: make(map[string]string), Body: b, } - for k, v := range h.r.Header { + for k, v := range r.Header { if len(v) > 0 { mr.Header[k] = v[0] } else { @@ -152,6 +170,11 @@ func (h *httpTransportSocket) Recv(m *Message) error { } } + select { + case h.r <- r: + default: + } + *m = *mr return nil } @@ -159,8 +182,11 @@ func (h *httpTransportSocket) Recv(m *Message) error { func (h *httpTransportSocket) Send(m *Message) error { b := bytes.NewBuffer(m.Body) defer b.Reset() + + r := <-h.r + rsp := &http.Response{ - Header: h.r.Header, + Header: r.Header, Body: &buffer{b}, Status: "200 OK", StatusCode: 200, @@ -174,6 +200,11 @@ func (h *httpTransportSocket) Send(m *Message) error { rsp.Header.Set(k, v) } + select { + case h.r <- r: + default: + } + return rsp.Write(h.conn) } @@ -199,7 +230,14 @@ func (h *httpTransportSocket) error(m *Message) error { } func (h *httpTransportSocket) Close() error { - return h.conn.Close() + err := h.conn.Close() + h.once.Do(func() { + h.Lock() + h.buff.Reset(nil) + h.buff = nil + h.Unlock() + }) + return err } func (h *httpTransportListener) Addr() string { @@ -211,18 +249,19 @@ func (h *httpTransportListener) Close() error { } func (h *httpTransportListener) Accept(fn func(Socket)) error { - srv := &http.Server{ - Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, _, err := w.(http.Hijacker).Hijack() - if err != nil { - return - } + for { + c, err := h.listener.Accept() + if err != nil { + return err + } - sock := &httpTransportSocket{ - conn: conn, - r: r, - } + sock := &httpTransportSocket{ + conn: c, + buff: bufio.NewReader(c), + r: make(chan *http.Request, 1), + } + go func() { // TODO: think of a better error response strategy defer func() { if r := recover(); r != nil { @@ -231,10 +270,9 @@ func (h *httpTransportListener) Accept(fn func(Socket)) error { }() fn(sock) - }), + }() } - - return srv.Serve(h.listener) + return nil } func (h *httpTransport) Dial(addr string, opts ...DialOption) (Client, error) {