1
0
mirror of https://github.com/go-kratos/kratos.git synced 2025-01-07 23:02:12 +02:00
kratos/transport/http/client.go
Tony Chen c7827ff701
fix http middleware (#710)
* fix http middleware
2021-02-18 16:53:59 +08:00

137 lines
3.3 KiB
Go

package http
import (
"context"
"io/ioutil"
"net/http"
"time"
"github.com/go-kratos/kratos/v2/encoding"
"github.com/go-kratos/kratos/v2/errors"
"github.com/go-kratos/kratos/v2/middleware"
"github.com/go-kratos/kratos/v2/transport"
)
// ClientOption is a functional option that configures the client and
// transport built by NewClient and NewTransport.
type ClientOption func(*clientOptions)
// WithTimeout sets the per-request timeout used by the transport.
// NewTransport defaults it to 500ms when this option is not supplied.
func WithTimeout(d time.Duration) ClientOption {
	setTimeout := func(opts *clientOptions) {
		opts.timeout = d
	}
	return setTimeout
}
// WithUserAgent sets the User-Agent value the transport adds to outgoing
// requests that do not already carry a User-Agent header.
func WithUserAgent(ua string) ClientOption {
	return ClientOption(func(opts *clientOptions) {
		opts.userAgent = ua
	})
}
// WithTransport sets the underlying http.RoundTripper that performs the
// actual HTTP exchange. NewTransport uses http.DefaultTransport otherwise.
func WithTransport(trans http.RoundTripper) ClientOption {
	apply := func(opts *clientOptions) {
		opts.transport = trans
	}
	return apply
}
// WithMiddleware sets the middleware that wraps each round trip made by the
// transport. A later WithMiddleware option replaces an earlier one.
func WithMiddleware(m middleware.Middleware) ClientOption {
	return ClientOption(func(opts *clientOptions) { opts.middleware = m })
}
// clientOptions collects the settings applied by ClientOption values.
// NewTransport fills in its defaults and then copies the fields into the
// transport it returns.
type clientOptions struct {
	// ctx is the context passed to NewTransport / NewClient.
	// NOTE(review): it is stored but not read by any code in this file.
	ctx context.Context

	// timeout is the per-request timeout (NewTransport defaults it to 500ms).
	timeout time.Duration

	// userAgent is added to requests that have no User-Agent header; empty disables it.
	userAgent string

	// transport is the underlying RoundTripper (NewTransport defaults it to http.DefaultTransport).
	transport http.RoundTripper

	// middleware wraps each round trip; nil means no middleware.
	middleware middleware.Middleware
}
// NewClient builds an *http.Client whose Transport is the RoundTripper that
// NewTransport creates from the same ctx and options. An error from
// NewTransport is returned unchanged together with a nil client.
func NewClient(ctx context.Context, opts ...ClientOption) (*http.Client, error) {
	rt, buildErr := NewTransport(ctx, opts...)
	if buildErr != nil {
		return nil, buildErr
	}
	client := &http.Client{Transport: rt}
	return client, nil
}
// NewTransport creates an http.RoundTripper configured by opts.
//
// Defaults are a 500ms per-request timeout and http.DefaultTransport as the
// underlying transport. Nil options are skipped. A nil transport supplied via
// WithTransport also falls back to http.DefaultTransport. Without that, the
// returned RoundTripper would call a method on a nil base on its first request.
// The returned error is currently always nil.
func NewTransport(ctx context.Context, opts ...ClientOption) (http.RoundTripper, error) {
	options := &clientOptions{
		ctx:       ctx,
		timeout:   500 * time.Millisecond,
		transport: http.DefaultTransport,
	}
	for _, o := range opts {
		if o == nil {
			continue
		}
		o(options)
	}
	if options.transport == nil {
		options.transport = http.DefaultTransport
	}
	return &baseTransport{
		middleware: options.middleware,
		userAgent:  options.userAgent,
		timeout:    options.timeout,
		base:       options.transport,
	}, nil
}
// baseTransport is the http.RoundTripper returned by NewTransport. Its
// RoundTrip method fills in a default User-Agent, derives a per-request
// deadline from timeout, and runs base through the optional middleware.
type baseTransport struct {
	// userAgent is set on requests lacking a User-Agent header; empty disables it.
	userAgent string

	// timeout is the per-request deadline RoundTrip derives its context from.
	timeout time.Duration

	// base performs the underlying HTTP exchange.
	base http.RoundTripper

	// middleware, when non-nil, wraps the call to base.
	middleware middleware.Middleware
}
// RoundTrip implements http.RoundTripper.
//
// It works on a clone of req, because a RoundTripper must not modify the
// caller's request. It fills in the configured User-Agent when the request
// has none. The base transport runs through the configured middleware.
//
// The context attached to the outgoing request carries the transport and
// client info. When t.timeout > 0, it also carries a deadline of t.timeout.
// That deadline stays in force while the response body is read. It is
// released when the caller closes the body, or immediately when no usable
// body is returned.
func (t *baseTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	req = req.Clone(req.Context())
	if req.Header == nil {
		// Header.Set on a nil map would panic.
		req.Header = make(http.Header)
	}
	if t.userAgent != "" && req.Header.Get("User-Agent") == "" {
		req.Header.Set("User-Agent", t.userAgent)
	}

	ctx := transport.NewContext(req.Context(), transport.Transport{Kind: "HTTP"})
	ctx = NewClientContext(ctx, ClientInfo{Request: req})
	cancel := context.CancelFunc(func() {})
	if t.timeout > 0 {
		ctx, cancel = context.WithTimeout(ctx, t.timeout)
	}

	h := func(ctx context.Context, in interface{}) (interface{}, error) {
		// Attach the (possibly middleware-derived) context to the outgoing
		// request; otherwise the deadline never reaches the base transport.
		return t.base.RoundTrip(in.(*http.Request).WithContext(ctx))
	}
	if t.middleware != nil {
		h = t.middleware(h)
	}

	out, err := h(ctx, req)
	if err != nil {
		cancel()
		return nil, err
	}
	res := out.(*http.Response)
	if res == nil || res.Body == nil {
		cancel()
		return res, nil
	}
	// Cancelling now would abort reading the body, so defer it to Close.
	res.Body = &cancelOnCloseBody{body: res.Body, cancel: cancel}
	return res, nil
}

// cancelOnCloseBody wraps a response body so that the request context created
// by RoundTrip is released when the caller closes the body.
type cancelOnCloseBody struct {
	body interface {
		Read(p []byte) (n int, err error)
		Close() error
	}
	cancel context.CancelFunc
}

// Read forwards to the wrapped body.
func (b *cancelOnCloseBody) Read(p []byte) (int, error) {
	return b.body.Read(p)
}

// Close closes the wrapped body and then releases the request context.
func (b *cancelOnCloseBody) Close() error {
	err := b.body.Close()
	b.cancel()
	return err
}
// Do send an HTTP request and decodes the body of response into target.
// returns an error (of type *Error) if the response status code is not 2xx.
func Do(client *http.Client, req *http.Request, target interface{}) error {
res, err := client.Do(req)
if err != nil {
return err
}
data, err := ioutil.ReadAll(res.Body)
if err != nil {
return err
}
defer res.Body.Close()
subtype := contentSubtype(res.Header.Get("content-type"))
codec := encoding.GetCodec(subtype)
if codec == nil {
codec = encoding.GetCodec("json")
}
if res.StatusCode < 200 || res.StatusCode > 299 {
se := &errors.StatusError{}
if err := codec.Unmarshal(data, se); err != nil {
return err
}
return se
}
return codec.Unmarshal(data, target)
}