diff --git a/examples/helloworld/client/main.go b/examples/helloworld/client/main.go index d967168c5..bd4ab79c9 100644 --- a/examples/helloworld/client/main.go +++ b/examples/helloworld/client/main.go @@ -25,6 +25,7 @@ func callHTTP() { transhttp.WithMiddleware( middleware.Chain( recovery.Recovery(), + status.Client(), ), ), ) diff --git a/examples/helloworld/server/main.go b/examples/helloworld/server/main.go index c052e0be8..8ed226b2b 100644 --- a/examples/helloworld/server/main.go +++ b/examples/helloworld/server/main.go @@ -63,6 +63,7 @@ func main() { httpSrv.HandlePrefix("/", pb.NewGreeterHandler(s, http.Middleware( middleware.Chain( + status.Server(), logging.Server(logger), recovery.Recovery(), ), diff --git a/internal/http/http.go b/internal/http/http.go new file mode 100644 index 000000000..a3ae15525 --- /dev/null +++ b/internal/http/http.go @@ -0,0 +1,121 @@ +package http + +import ( + "net/http" + "strings" + + "google.golang.org/grpc/codes" +) + +const ( + baseContentType = "application" +) + +var ( + // HeaderAccept is accept header. + HeaderAccept = http.CanonicalHeaderKey("Accept") + // HeaderContentType is content-type header. + HeaderContentType = http.CanonicalHeaderKey("Content-Type") + // HeaderAcceptLanguage is accept-language header. + HeaderAcceptLanguage = http.CanonicalHeaderKey("Accept-Language") +) + +// ContentType returns the content-type with base prefix. +func ContentType(subtype string) string { + return strings.Join([]string{baseContentType, subtype}, "/") +} + +// ContentSubtype returns the content-subtype for the given content-type. The +// given content-type must be a valid content-type that starts with +// but no content-subtype will be returned. +// +// contentType is assumed to be lowercase already. +func ContentSubtype(contentType string) string { + if contentType == baseContentType { + return "" + } + if !strings.HasPrefix(contentType, baseContentType) { + return "" + } + switch contentType[len(baseContentType)] { + case '/', ';': + if i := strings.Index(contentType, ";"); i != -1 { + return contentType[len(baseContentType)+1 : i] + } + return contentType[len(baseContentType)+1:] + default: + return "" + } +} + +// GRPCCodeFromStatus converts a HTTP error code into the corresponding gRPC response status. +// See: https://github.com/googleapis/googleapis/blob/master/google/rpc/code.proto +func GRPCCodeFromStatus(code int) codes.Code { + switch code { + case http.StatusOK: + return codes.OK + case http.StatusBadRequest: + return codes.InvalidArgument + case http.StatusUnauthorized: + return codes.Unauthenticated + case http.StatusForbidden: + return codes.PermissionDenied + case http.StatusNotFound: + return codes.NotFound + case http.StatusConflict: + return codes.Aborted + case http.StatusTooManyRequests: + return codes.ResourceExhausted + case http.StatusInternalServerError: + return codes.Internal + case http.StatusNotImplemented: + return codes.Unimplemented + case http.StatusServiceUnavailable: + return codes.Unavailable + case http.StatusGatewayTimeout: + return codes.DeadlineExceeded + } + return codes.Unknown +} + +// StatusFromGRPCCode converts a gRPC error code into the corresponding HTTP response status. +// See: https://github.com/googleapis/googleapis/blob/master/google/rpc/code.proto +func StatusFromGRPCCode(code codes.Code) int { + switch code { + case codes.OK: + return http.StatusOK + case codes.Canceled: + return http.StatusRequestTimeout + case codes.Unknown: + return http.StatusInternalServerError + case codes.InvalidArgument: + return http.StatusBadRequest + case codes.DeadlineExceeded: + return http.StatusGatewayTimeout + case codes.NotFound: + return http.StatusNotFound + case codes.AlreadyExists: + return http.StatusConflict + case codes.PermissionDenied: + return http.StatusForbidden + case codes.Unauthenticated: + return http.StatusUnauthorized + case codes.ResourceExhausted: + return http.StatusTooManyRequests + case codes.FailedPrecondition: + return http.StatusBadRequest + case codes.Aborted: + return http.StatusConflict + case codes.OutOfRange: + return http.StatusBadRequest + case codes.Unimplemented: + return http.StatusNotImplemented + case codes.Internal: + return http.StatusInternalServerError + case codes.Unavailable: + return http.StatusServiceUnavailable + case codes.DataLoss: + return http.StatusInternalServerError + } + return http.StatusInternalServerError +} diff --git a/middleware/status/status.go b/middleware/status/status.go index 768bf9ec3..5d4311faa 100644 --- a/middleware/status/status.go +++ b/middleware/status/status.go @@ -2,15 +2,14 @@ package status import ( "context" - "net/http" "github.com/go-kratos/kratos/v2/errors" + "github.com/go-kratos/kratos/v2/internal/http" "github.com/go-kratos/kratos/v2/middleware" //lint:ignore SA1019 grpc "github.com/golang/protobuf/proto" "google.golang.org/genproto/googleapis/rpc/errdetails" - "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -34,7 +33,7 @@ func WithHandler(h HandlerFunc) Option { // Server is an error middleware. func Server(opts ...Option) middleware.Middleware { options := options{ - handler: encodeErr, + handler: encodeError, } for _, o := range opts { o(&options) @@ -53,7 +52,7 @@ func Server(opts ...Option) middleware.Middleware { // Client is an error middleware. func Client(opts ...Option) middleware.Middleware { options := options{ - handler: decodeErr, + handler: decodeError, } for _, o := range opts { o(&options) @@ -69,7 +68,7 @@ func Client(opts ...Option) middleware.Middleware { } } -func encodeErr(ctx context.Context, err error) error { +func encodeError(ctx context.Context, err error) error { var details []proto.Message if target := new(errors.ErrorInfo); errors.As(err, &target) { details = append(details, &errdetails.ErrorInfo{ @@ -79,7 +78,7 @@ func encodeErr(ctx context.Context, err error) error { }) } es := errors.FromError(err) - gs := status.New(httpToGRPCCode(es.Code), es.Message) + gs := status.New(http.GRPCCodeFromStatus(es.Code), es.Message) gs, err = gs.WithDetails(details...) if err != nil { return err @@ -87,9 +86,9 @@ func encodeErr(ctx context.Context, err error) error { return gs.Err() } -func decodeErr(ctx context.Context, err error) error { +func decodeError(ctx context.Context, err error) error { gs := status.Convert(err) - code := grpcToHTTPCode(gs.Code()) + code := http.StatusFromGRPCCode(gs.Code()) message := gs.Message() for _, detail := range gs.Details() { switch d := detail.(type) { @@ -104,43 +103,3 @@ func decodeErr(ctx context.Context, err error) error { } return errors.New(code, message) } - -func httpToGRPCCode(code int) codes.Code { - switch code { - case http.StatusBadRequest: - return codes.InvalidArgument - case http.StatusUnauthorized: - return codes.Unauthenticated - case http.StatusForbidden: - return codes.PermissionDenied - case http.StatusNotFound: - return codes.NotFound - case http.StatusConflict: - return codes.Aborted - case http.StatusInternalServerError: - return codes.Internal - case http.StatusServiceUnavailable: - return codes.Unavailable - } - return codes.Unknown -} - -func grpcToHTTPCode(code codes.Code) int { - switch code { - case codes.InvalidArgument: - return http.StatusBadRequest - case codes.Unauthenticated: - return http.StatusUnauthorized - case codes.PermissionDenied: - return http.StatusForbidden - case codes.NotFound: - return http.StatusNotFound - case codes.Aborted: - return http.StatusConflict - case codes.Internal: - return http.StatusInternalServerError - case codes.Unavailable: - return http.StatusServiceUnavailable - } - return http.StatusInternalServerError -} diff --git a/middleware/status/status_test.go b/middleware/status/status_test.go index 074f5c9a0..a974e09ee 100644 --- a/middleware/status/status_test.go +++ b/middleware/status/status_test.go @@ -9,8 +9,8 @@ import ( func TestErrEncoder(t *testing.T) { err := errors.BadRequest("test", "invalid_argument", "format") - en := encodeErr(context.Background(), err) - de := decodeErr(context.Background(), en) + en := encodeError(context.Background(), err) + de := decodeError(context.Background(), en) if !errors.IsBadRequest(de) { t.Errorf("expected %v got %v", err, de) } diff --git a/transport/http/client.go b/transport/http/client.go index 95b4e1bb9..9c987757b 100644 --- a/transport/http/client.go +++ b/transport/http/client.go @@ -8,10 +8,17 @@ import ( "github.com/go-kratos/kratos/v2/encoding" "github.com/go-kratos/kratos/v2/errors" + xhttp "github.com/go-kratos/kratos/v2/internal/http" "github.com/go-kratos/kratos/v2/middleware" "github.com/go-kratos/kratos/v2/transport" + spb "google.golang.org/genproto/googleapis/rpc/status" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/encoding/protojson" ) +// DecodeErrorFunc is decode error func. +type DecodeErrorFunc func(ctx context.Context, w *http.Response) error + // ClientOption is HTTP client option. type ClientOption func(*clientOptions) @@ -45,11 +52,12 @@ func WithMiddleware(m middleware.Middleware) ClientOption { // Client is a HTTP transport client. type clientOptions struct { - ctx context.Context - timeout time.Duration - userAgent string - transport http.RoundTripper - middleware middleware.Middleware + ctx context.Context + timeout time.Duration + userAgent string + transport http.RoundTripper + errorDecoder DecodeErrorFunc + middleware middleware.Middleware } // NewClient returns an HTTP client. @@ -64,26 +72,29 @@ func NewClient(ctx context.Context, opts ...ClientOption) (*http.Client, error) // NewTransport creates an http.RoundTripper. func NewTransport(ctx context.Context, opts ...ClientOption) (http.RoundTripper, error) { options := &clientOptions{ - ctx: ctx, - timeout: 500 * time.Millisecond, - transport: http.DefaultTransport, + ctx: ctx, + timeout: 500 * time.Millisecond, + transport: http.DefaultTransport, + errorDecoder: checkResponse, } for _, o := range opts { o(options) } return &baseTransport{ - middleware: options.middleware, - userAgent: options.userAgent, - timeout: options.timeout, - base: options.transport, + errorDecoder: options.errorDecoder, + middleware: options.middleware, + userAgent: options.userAgent, + timeout: options.timeout, + base: options.transport, }, nil } type baseTransport struct { - userAgent string - timeout time.Duration - base http.RoundTripper - middleware middleware.Middleware + userAgent string + timeout time.Duration + base http.RoundTripper + errorDecoder DecodeErrorFunc + middleware middleware.Middleware } func (t *baseTransport) RoundTrip(req *http.Request) (*http.Response, error) { @@ -96,7 +107,14 @@ func (t *baseTransport) RoundTrip(req *http.Request) (*http.Response, error) { defer cancel() h := func(ctx context.Context, in interface{}) (interface{}, error) { - return t.base.RoundTrip(in.(*http.Request)) + res, err := t.base.RoundTrip(in.(*http.Request)) + if err != nil { + return nil, err + } + if err := t.errorDecoder(ctx, res); err != nil { + return nil, err + } + return res, nil } if t.middleware != nil { h = t.middleware(h) @@ -115,19 +133,7 @@ func Do(client *http.Client, req *http.Request, target interface{}) error { if err != nil { return err } - defer res.Body.Close() - if res.StatusCode < 200 || res.StatusCode > 299 { - se := &errors.Error{Code: 500} - if err := decodeResponse(res, se); err != nil { - return err - } - return se - } - return decodeResponse(res, target) -} - -func decodeResponse(res *http.Response, target interface{}) error { - subtype := contentSubtype(res.Header.Get(contentTypeHeader)) + subtype := xhttp.ContentSubtype(res.Header.Get(xhttp.HeaderContentType)) codec := encoding.GetCodec(subtype) if codec == nil { codec = encoding.GetCodec("json") @@ -138,3 +144,19 @@ func decodeResponse(res *http.Response, target interface{}) error { } return codec.Unmarshal(data, target) } + +// checkResponse returns an error (of type *Error) if the response +// status code is not 2xx. +func checkResponse(ctx context.Context, res *http.Response) error { + if res.StatusCode >= 200 && res.StatusCode <= 299 { + return nil + } + defer res.Body.Close() + if data, err := ioutil.ReadAll(res.Body); err == nil { + st := new(spb.Status) + if err = protojson.Unmarshal(data, st); err == nil { + return status.ErrorProto(st) + } + } + return errors.New(res.StatusCode, "") +} diff --git a/transport/http/handle.go b/transport/http/handle.go index 1bce7e9a6..b02782c87 100644 --- a/transport/http/handle.go +++ b/transport/http/handle.go @@ -3,26 +3,18 @@ package http import ( "io/ioutil" "net/http" - "strings" "github.com/go-kratos/kratos/v2/encoding" "github.com/go-kratos/kratos/v2/encoding/json" - "github.com/go-kratos/kratos/v2/errors" + xhttp "github.com/go-kratos/kratos/v2/internal/http" "github.com/go-kratos/kratos/v2/middleware" "github.com/go-kratos/kratos/v2/transport/http/binding" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/encoding/protojson" ) -const ( - // SupportPackageIsVersion1 These constants should not be referenced from any other code. - SupportPackageIsVersion1 = true - - baseContentType = "application" -) - -var ( - acceptHeader = http.CanonicalHeaderKey("Accept") - contentTypeHeader = http.CanonicalHeaderKey("Content-Type") -) +// SupportPackageIsVersion1 These constants should not be referenced from any other code. +const SupportPackageIsVersion1 = true // DecodeRequestFunc is decode request func. type DecodeRequestFunc func(*http.Request, interface{}) error @@ -83,7 +75,7 @@ func Middleware(m middleware.Middleware) HandleOption { // decodeRequest decodes the request body to object. func decodeRequest(req *http.Request, v interface{}) error { - subtype := contentSubtype(req.Header.Get(contentTypeHeader)) + subtype := xhttp.ContentSubtype(req.Header.Get(xhttp.HeaderContentType)) if codec := encoding.GetCodec(subtype); codec != nil { data, err := ioutil.ReadAll(req.Body) if err != nil { @@ -101,26 +93,29 @@ func encodeResponse(w http.ResponseWriter, r *http.Request, v interface{}) error if err != nil { return err } - w.Header().Set(contentTypeHeader, contentType(codec.Name())) + w.Header().Set(xhttp.HeaderContentType, xhttp.ContentType(codec.Name())) _, _ = w.Write(data) return nil } // encodeError encodes the error to the HTTP response. func encodeError(w http.ResponseWriter, r *http.Request, err error) { - se := errors.FromError(err) - codec := codecForRequest(r) - data, _ := codec.Marshal(se) - w.Header().Set(contentTypeHeader, contentType(codec.Name())) - w.WriteHeader(se.Code) - _, _ = w.Write(data) + st, _ := status.FromError(err) + data, err := protojson.Marshal(st.Proto()) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + w.Header().Set(xhttp.HeaderContentType, "application/json; charset=utf-8") + w.WriteHeader(xhttp.StatusFromGRPCCode(st.Code())) + w.Write(data) } // codecForRequest get encoding.Codec via http.Request func codecForRequest(r *http.Request) encoding.Codec { var codec encoding.Codec - for _, accept := range r.Header[acceptHeader] { - if codec = encoding.GetCodec(contentSubtype(accept)); codec != nil { + for _, accept := range r.Header[xhttp.HeaderAccept] { + if codec = encoding.GetCodec(xhttp.ContentSubtype(accept)); codec != nil { break } } @@ -129,25 +124,3 @@ func codecForRequest(r *http.Request) encoding.Codec { } return codec } - -func contentType(subtype string) string { - return strings.Join([]string{baseContentType, subtype}, "/") -} - -func contentSubtype(contentType string) string { - if contentType == baseContentType { - return "" - } - if !strings.HasPrefix(contentType, baseContentType) { - return "" - } - switch contentType[len(baseContentType)] { - case '/', ';': - if i := strings.Index(contentType, ";"); i != -1 { - return contentType[len(baseContentType)+1 : i] - } - return contentType[len(baseContentType)+1:] - default: - return "" - } -}