diff --git a/transport/http/client.go b/transport/http/client.go index fcd29d2cb..b792bccc3 100644 --- a/transport/http/client.go +++ b/transport/http/client.go @@ -232,27 +232,7 @@ func (client *Client) Invoke(ctx context.Context, method, path string, args inte func (client *Client) invoke(ctx context.Context, req *http.Request, args interface{}, reply interface{}, c callInfo, opts ...CallOption) error { h := func(ctx context.Context, in interface{}) (interface{}, error) { - var done func(context.Context, selector.DoneInfo) - if client.r != nil { - var ( - err error - node selector.Node - ) - if node, done, err = client.opts.selector.Select(ctx); err != nil { - return nil, errors.ServiceUnavailable("NODE_NOT_FOUND", err.Error()) - } - if client.insecure { - req.URL.Scheme = "http" - } else { - req.URL.Scheme = "https" - } - req.URL.Host = node.Address() - req.Host = node.Address() - } res, err := client.do(ctx, req, c) - if done != nil { - done(ctx, selector.DoneInfo{Err: err}) - } if res != nil { cs := csAttempt{res: res} for _, o := range opts { @@ -284,16 +264,39 @@ func (client *Client) Do(req *http.Request, opts ...CallOption) (*http.Response, return nil, err } } - return client.do(req.Context(), req, c) + ctx := req.Context() + + return client.do(ctx, req, c) } func (client *Client) do(ctx context.Context, req *http.Request, c callInfo) (*http.Response, error) { + var done func(context.Context, selector.DoneInfo) + if client.r != nil { + var ( + err error + node selector.Node + ) + if node, done, err = client.opts.selector.Select(ctx); err != nil { + return nil, errors.ServiceUnavailable("NODE_NOT_FOUND", err.Error()) + } + if client.insecure { + req.URL.Scheme = "http" + } else { + req.URL.Scheme = "https" + } + req.URL.Host = node.Address() + req.Host = node.Address() + } resp, err := client.cc.Do(req) + if err == nil { + err = client.opts.errorDecoder(ctx, resp) + } + if err != nil { return nil, err } - if err := client.opts.errorDecoder(ctx, resp); err != nil { - return nil, err + if done != nil { + done(ctx, selector.DoneInfo{Err: err}) } return resp, nil }