mirror of
				https://github.com/go-kratos/kratos.git
				synced 2025-10-30 23:47:59 +02:00 
			
		
		
		
	transport/http: uses gRPC status to the HTTP error. (#870)
* uses gRPC status to the HTTP error.
This commit is contained in:
		| @@ -25,6 +25,7 @@ func callHTTP() { | ||||
| 		transhttp.WithMiddleware( | ||||
| 			middleware.Chain( | ||||
| 				recovery.Recovery(), | ||||
| 				status.Client(), | ||||
| 			), | ||||
| 		), | ||||
| 	) | ||||
|   | ||||
| @@ -63,6 +63,7 @@ func main() { | ||||
| 	httpSrv.HandlePrefix("/", pb.NewGreeterHandler(s, | ||||
| 		http.Middleware( | ||||
| 			middleware.Chain( | ||||
| 				status.Server(), | ||||
| 				logging.Server(logger), | ||||
| 				recovery.Recovery(), | ||||
| 			), | ||||
|   | ||||
							
								
								
									
										121
									
								
								internal/http/http.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										121
									
								
								internal/http/http.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,121 @@ | ||||
| package http | ||||
|  | ||||
| import ( | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
|  | ||||
| 	"google.golang.org/grpc/codes" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	baseContentType = "application" | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| 	// HeaderAccept is accept header. | ||||
| 	HeaderAccept = http.CanonicalHeaderKey("Accept") | ||||
| 	// HeaderContentType is content-type header. | ||||
| 	HeaderContentType = http.CanonicalHeaderKey("Content-Type") | ||||
| 	// HeaderAcceptLanguage is accept-language header. | ||||
| 	HeaderAcceptLanguage = http.CanonicalHeaderKey("Accept-Language") | ||||
| ) | ||||
|  | ||||
| // ContentType returns the content-type with base prefix. | ||||
| func ContentType(subtype string) string { | ||||
| 	return strings.Join([]string{baseContentType, subtype}, "/") | ||||
| } | ||||
|  | ||||
| // ContentSubtype returns the content-subtype for the given content-type.  The | ||||
| // given content-type must be a valid content-type that starts with | ||||
| // but no content-subtype will be returned. | ||||
| // | ||||
| // contentType is assumed to be lowercase already. | ||||
| func ContentSubtype(contentType string) string { | ||||
| 	if contentType == baseContentType { | ||||
| 		return "" | ||||
| 	} | ||||
| 	if !strings.HasPrefix(contentType, baseContentType) { | ||||
| 		return "" | ||||
| 	} | ||||
| 	switch contentType[len(baseContentType)] { | ||||
| 	case '/', ';': | ||||
| 		if i := strings.Index(contentType, ";"); i != -1 { | ||||
| 			return contentType[len(baseContentType)+1 : i] | ||||
| 		} | ||||
| 		return contentType[len(baseContentType)+1:] | ||||
| 	default: | ||||
| 		return "" | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // GRPCCodeFromStatus converts a HTTP error code into the corresponding gRPC response status. | ||||
| // See: https://github.com/googleapis/googleapis/blob/master/google/rpc/code.proto | ||||
| func GRPCCodeFromStatus(code int) codes.Code { | ||||
| 	switch code { | ||||
| 	case http.StatusOK: | ||||
| 		return codes.OK | ||||
| 	case http.StatusBadRequest: | ||||
| 		return codes.InvalidArgument | ||||
| 	case http.StatusUnauthorized: | ||||
| 		return codes.Unauthenticated | ||||
| 	case http.StatusForbidden: | ||||
| 		return codes.PermissionDenied | ||||
| 	case http.StatusNotFound: | ||||
| 		return codes.NotFound | ||||
| 	case http.StatusConflict: | ||||
| 		return codes.Aborted | ||||
| 	case http.StatusTooManyRequests: | ||||
| 		return codes.ResourceExhausted | ||||
| 	case http.StatusInternalServerError: | ||||
| 		return codes.Internal | ||||
| 	case http.StatusNotImplemented: | ||||
| 		return codes.Unimplemented | ||||
| 	case http.StatusServiceUnavailable: | ||||
| 		return codes.Unavailable | ||||
| 	case http.StatusGatewayTimeout: | ||||
| 		return codes.DeadlineExceeded | ||||
| 	} | ||||
| 	return codes.Unknown | ||||
| } | ||||
|  | ||||
| // StatusFromGRPCCode converts a gRPC error code into the corresponding HTTP response status. | ||||
| // See: https://github.com/googleapis/googleapis/blob/master/google/rpc/code.proto | ||||
| func StatusFromGRPCCode(code codes.Code) int { | ||||
| 	switch code { | ||||
| 	case codes.OK: | ||||
| 		return http.StatusOK | ||||
| 	case codes.Canceled: | ||||
| 		return http.StatusRequestTimeout | ||||
| 	case codes.Unknown: | ||||
| 		return http.StatusInternalServerError | ||||
| 	case codes.InvalidArgument: | ||||
| 		return http.StatusBadRequest | ||||
| 	case codes.DeadlineExceeded: | ||||
| 		return http.StatusGatewayTimeout | ||||
| 	case codes.NotFound: | ||||
| 		return http.StatusNotFound | ||||
| 	case codes.AlreadyExists: | ||||
| 		return http.StatusConflict | ||||
| 	case codes.PermissionDenied: | ||||
| 		return http.StatusForbidden | ||||
| 	case codes.Unauthenticated: | ||||
| 		return http.StatusUnauthorized | ||||
| 	case codes.ResourceExhausted: | ||||
| 		return http.StatusTooManyRequests | ||||
| 	case codes.FailedPrecondition: | ||||
| 		return http.StatusBadRequest | ||||
| 	case codes.Aborted: | ||||
| 		return http.StatusConflict | ||||
| 	case codes.OutOfRange: | ||||
| 		return http.StatusBadRequest | ||||
| 	case codes.Unimplemented: | ||||
| 		return http.StatusNotImplemented | ||||
| 	case codes.Internal: | ||||
| 		return http.StatusInternalServerError | ||||
| 	case codes.Unavailable: | ||||
| 		return http.StatusServiceUnavailable | ||||
| 	case codes.DataLoss: | ||||
| 		return http.StatusInternalServerError | ||||
| 	} | ||||
| 	return http.StatusInternalServerError | ||||
| } | ||||
| @@ -2,15 +2,14 @@ package status | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"net/http" | ||||
|  | ||||
| 	"github.com/go-kratos/kratos/v2/errors" | ||||
| 	"github.com/go-kratos/kratos/v2/internal/http" | ||||
| 	"github.com/go-kratos/kratos/v2/middleware" | ||||
|  | ||||
| 	//lint:ignore SA1019 grpc | ||||
| 	"github.com/golang/protobuf/proto" | ||||
| 	"google.golang.org/genproto/googleapis/rpc/errdetails" | ||||
| 	"google.golang.org/grpc/codes" | ||||
| 	"google.golang.org/grpc/status" | ||||
| ) | ||||
|  | ||||
| @@ -34,7 +33,7 @@ func WithHandler(h HandlerFunc) Option { | ||||
| // Server is an error middleware. | ||||
| func Server(opts ...Option) middleware.Middleware { | ||||
| 	options := options{ | ||||
| 		handler: encodeErr, | ||||
| 		handler: encodeError, | ||||
| 	} | ||||
| 	for _, o := range opts { | ||||
| 		o(&options) | ||||
| @@ -53,7 +52,7 @@ func Server(opts ...Option) middleware.Middleware { | ||||
| // Client is an error middleware. | ||||
| func Client(opts ...Option) middleware.Middleware { | ||||
| 	options := options{ | ||||
| 		handler: decodeErr, | ||||
| 		handler: decodeError, | ||||
| 	} | ||||
| 	for _, o := range opts { | ||||
| 		o(&options) | ||||
| @@ -69,7 +68,7 @@ func Client(opts ...Option) middleware.Middleware { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func encodeErr(ctx context.Context, err error) error { | ||||
| func encodeError(ctx context.Context, err error) error { | ||||
| 	var details []proto.Message | ||||
| 	if target := new(errors.ErrorInfo); errors.As(err, &target) { | ||||
| 		details = append(details, &errdetails.ErrorInfo{ | ||||
| @@ -79,7 +78,7 @@ func encodeErr(ctx context.Context, err error) error { | ||||
| 		}) | ||||
| 	} | ||||
| 	es := errors.FromError(err) | ||||
| 	gs := status.New(httpToGRPCCode(es.Code), es.Message) | ||||
| 	gs := status.New(http.GRPCCodeFromStatus(es.Code), es.Message) | ||||
| 	gs, err = gs.WithDetails(details...) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| @@ -87,9 +86,9 @@ func encodeErr(ctx context.Context, err error) error { | ||||
| 	return gs.Err() | ||||
| } | ||||
|  | ||||
| func decodeErr(ctx context.Context, err error) error { | ||||
| func decodeError(ctx context.Context, err error) error { | ||||
| 	gs := status.Convert(err) | ||||
| 	code := grpcToHTTPCode(gs.Code()) | ||||
| 	code := http.StatusFromGRPCCode(gs.Code()) | ||||
| 	message := gs.Message() | ||||
| 	for _, detail := range gs.Details() { | ||||
| 		switch d := detail.(type) { | ||||
| @@ -104,43 +103,3 @@ func decodeErr(ctx context.Context, err error) error { | ||||
| 	} | ||||
| 	return errors.New(code, message) | ||||
| } | ||||
|  | ||||
| func httpToGRPCCode(code int) codes.Code { | ||||
| 	switch code { | ||||
| 	case http.StatusBadRequest: | ||||
| 		return codes.InvalidArgument | ||||
| 	case http.StatusUnauthorized: | ||||
| 		return codes.Unauthenticated | ||||
| 	case http.StatusForbidden: | ||||
| 		return codes.PermissionDenied | ||||
| 	case http.StatusNotFound: | ||||
| 		return codes.NotFound | ||||
| 	case http.StatusConflict: | ||||
| 		return codes.Aborted | ||||
| 	case http.StatusInternalServerError: | ||||
| 		return codes.Internal | ||||
| 	case http.StatusServiceUnavailable: | ||||
| 		return codes.Unavailable | ||||
| 	} | ||||
| 	return codes.Unknown | ||||
| } | ||||
|  | ||||
| func grpcToHTTPCode(code codes.Code) int { | ||||
| 	switch code { | ||||
| 	case codes.InvalidArgument: | ||||
| 		return http.StatusBadRequest | ||||
| 	case codes.Unauthenticated: | ||||
| 		return http.StatusUnauthorized | ||||
| 	case codes.PermissionDenied: | ||||
| 		return http.StatusForbidden | ||||
| 	case codes.NotFound: | ||||
| 		return http.StatusNotFound | ||||
| 	case codes.Aborted: | ||||
| 		return http.StatusConflict | ||||
| 	case codes.Internal: | ||||
| 		return http.StatusInternalServerError | ||||
| 	case codes.Unavailable: | ||||
| 		return http.StatusServiceUnavailable | ||||
| 	} | ||||
| 	return http.StatusInternalServerError | ||||
| } | ||||
|   | ||||
| @@ -9,8 +9,8 @@ import ( | ||||
|  | ||||
| func TestErrEncoder(t *testing.T) { | ||||
| 	err := errors.BadRequest("test", "invalid_argument", "format") | ||||
| 	en := encodeErr(context.Background(), err) | ||||
| 	de := decodeErr(context.Background(), en) | ||||
| 	en := encodeError(context.Background(), err) | ||||
| 	de := decodeError(context.Background(), en) | ||||
| 	if !errors.IsBadRequest(de) { | ||||
| 		t.Errorf("expected %v got %v", err, de) | ||||
| 	} | ||||
|   | ||||
| @@ -8,10 +8,17 @@ import ( | ||||
|  | ||||
| 	"github.com/go-kratos/kratos/v2/encoding" | ||||
| 	"github.com/go-kratos/kratos/v2/errors" | ||||
| 	xhttp "github.com/go-kratos/kratos/v2/internal/http" | ||||
| 	"github.com/go-kratos/kratos/v2/middleware" | ||||
| 	"github.com/go-kratos/kratos/v2/transport" | ||||
| 	spb "google.golang.org/genproto/googleapis/rpc/status" | ||||
| 	"google.golang.org/grpc/status" | ||||
| 	"google.golang.org/protobuf/encoding/protojson" | ||||
| ) | ||||
|  | ||||
| // DecodeErrorFunc is decode error func. | ||||
| type DecodeErrorFunc func(ctx context.Context, w *http.Response) error | ||||
|  | ||||
| // ClientOption is HTTP client option. | ||||
| type ClientOption func(*clientOptions) | ||||
|  | ||||
| @@ -45,11 +52,12 @@ func WithMiddleware(m middleware.Middleware) ClientOption { | ||||
|  | ||||
| // Client is a HTTP transport client. | ||||
| type clientOptions struct { | ||||
| 	ctx        context.Context | ||||
| 	timeout    time.Duration | ||||
| 	userAgent  string | ||||
| 	transport  http.RoundTripper | ||||
| 	middleware middleware.Middleware | ||||
| 	ctx          context.Context | ||||
| 	timeout      time.Duration | ||||
| 	userAgent    string | ||||
| 	transport    http.RoundTripper | ||||
| 	errorDecoder DecodeErrorFunc | ||||
| 	middleware   middleware.Middleware | ||||
| } | ||||
|  | ||||
| // NewClient returns an HTTP client. | ||||
| @@ -64,26 +72,29 @@ func NewClient(ctx context.Context, opts ...ClientOption) (*http.Client, error) | ||||
| // NewTransport creates an http.RoundTripper. | ||||
| func NewTransport(ctx context.Context, opts ...ClientOption) (http.RoundTripper, error) { | ||||
| 	options := &clientOptions{ | ||||
| 		ctx:       ctx, | ||||
| 		timeout:   500 * time.Millisecond, | ||||
| 		transport: http.DefaultTransport, | ||||
| 		ctx:          ctx, | ||||
| 		timeout:      500 * time.Millisecond, | ||||
| 		transport:    http.DefaultTransport, | ||||
| 		errorDecoder: checkResponse, | ||||
| 	} | ||||
| 	for _, o := range opts { | ||||
| 		o(options) | ||||
| 	} | ||||
| 	return &baseTransport{ | ||||
| 		middleware: options.middleware, | ||||
| 		userAgent:  options.userAgent, | ||||
| 		timeout:    options.timeout, | ||||
| 		base:       options.transport, | ||||
| 		errorDecoder: options.errorDecoder, | ||||
| 		middleware:   options.middleware, | ||||
| 		userAgent:    options.userAgent, | ||||
| 		timeout:      options.timeout, | ||||
| 		base:         options.transport, | ||||
| 	}, nil | ||||
| } | ||||
|  | ||||
| type baseTransport struct { | ||||
| 	userAgent  string | ||||
| 	timeout    time.Duration | ||||
| 	base       http.RoundTripper | ||||
| 	middleware middleware.Middleware | ||||
| 	userAgent    string | ||||
| 	timeout      time.Duration | ||||
| 	base         http.RoundTripper | ||||
| 	errorDecoder DecodeErrorFunc | ||||
| 	middleware   middleware.Middleware | ||||
| } | ||||
|  | ||||
| func (t *baseTransport) RoundTrip(req *http.Request) (*http.Response, error) { | ||||
| @@ -96,7 +107,14 @@ func (t *baseTransport) RoundTrip(req *http.Request) (*http.Response, error) { | ||||
| 	defer cancel() | ||||
|  | ||||
| 	h := func(ctx context.Context, in interface{}) (interface{}, error) { | ||||
| 		return t.base.RoundTrip(in.(*http.Request)) | ||||
| 		res, err := t.base.RoundTrip(in.(*http.Request)) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		if err := t.errorDecoder(ctx, res); err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		return res, nil | ||||
| 	} | ||||
| 	if t.middleware != nil { | ||||
| 		h = t.middleware(h) | ||||
| @@ -115,19 +133,7 @@ func Do(client *http.Client, req *http.Request, target interface{}) error { | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	defer res.Body.Close() | ||||
| 	if res.StatusCode < 200 || res.StatusCode > 299 { | ||||
| 		se := &errors.Error{Code: 500} | ||||
| 		if err := decodeResponse(res, se); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		return se | ||||
| 	} | ||||
| 	return decodeResponse(res, target) | ||||
| } | ||||
|  | ||||
| func decodeResponse(res *http.Response, target interface{}) error { | ||||
| 	subtype := contentSubtype(res.Header.Get(contentTypeHeader)) | ||||
| 	subtype := xhttp.ContentSubtype(res.Header.Get(xhttp.HeaderContentType)) | ||||
| 	codec := encoding.GetCodec(subtype) | ||||
| 	if codec == nil { | ||||
| 		codec = encoding.GetCodec("json") | ||||
| @@ -138,3 +144,19 @@ func decodeResponse(res *http.Response, target interface{}) error { | ||||
| 	} | ||||
| 	return codec.Unmarshal(data, target) | ||||
| } | ||||
|  | ||||
| // checkResponse returns an error (of type *Error) if the response | ||||
| // status code is not 2xx. | ||||
| func checkResponse(ctx context.Context, res *http.Response) error { | ||||
| 	if res.StatusCode >= 200 && res.StatusCode <= 299 { | ||||
| 		return nil | ||||
| 	} | ||||
| 	defer res.Body.Close() | ||||
| 	if data, err := ioutil.ReadAll(res.Body); err == nil { | ||||
| 		st := new(spb.Status) | ||||
| 		if err = protojson.Unmarshal(data, st); err == nil { | ||||
| 			return status.ErrorProto(st) | ||||
| 		} | ||||
| 	} | ||||
| 	return errors.New(res.StatusCode, "") | ||||
| } | ||||
|   | ||||
| @@ -3,26 +3,18 @@ package http | ||||
| import ( | ||||
| 	"io/ioutil" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/go-kratos/kratos/v2/encoding" | ||||
| 	"github.com/go-kratos/kratos/v2/encoding/json" | ||||
| 	"github.com/go-kratos/kratos/v2/errors" | ||||
| 	xhttp "github.com/go-kratos/kratos/v2/internal/http" | ||||
| 	"github.com/go-kratos/kratos/v2/middleware" | ||||
| 	"github.com/go-kratos/kratos/v2/transport/http/binding" | ||||
| 	"google.golang.org/grpc/status" | ||||
| 	"google.golang.org/protobuf/encoding/protojson" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	// SupportPackageIsVersion1 These constants should not be referenced from any other code. | ||||
| 	SupportPackageIsVersion1 = true | ||||
|  | ||||
| 	baseContentType = "application" | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| 	acceptHeader      = http.CanonicalHeaderKey("Accept") | ||||
| 	contentTypeHeader = http.CanonicalHeaderKey("Content-Type") | ||||
| ) | ||||
| // SupportPackageIsVersion1 These constants should not be referenced from any other code. | ||||
| const SupportPackageIsVersion1 = true | ||||
|  | ||||
| // DecodeRequestFunc is decode request func. | ||||
| type DecodeRequestFunc func(*http.Request, interface{}) error | ||||
| @@ -83,7 +75,7 @@ func Middleware(m middleware.Middleware) HandleOption { | ||||
|  | ||||
| // decodeRequest decodes the request body to object. | ||||
| func decodeRequest(req *http.Request, v interface{}) error { | ||||
| 	subtype := contentSubtype(req.Header.Get(contentTypeHeader)) | ||||
| 	subtype := xhttp.ContentSubtype(req.Header.Get(xhttp.HeaderContentType)) | ||||
| 	if codec := encoding.GetCodec(subtype); codec != nil { | ||||
| 		data, err := ioutil.ReadAll(req.Body) | ||||
| 		if err != nil { | ||||
| @@ -101,26 +93,29 @@ func encodeResponse(w http.ResponseWriter, r *http.Request, v interface{}) error | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	w.Header().Set(contentTypeHeader, contentType(codec.Name())) | ||||
| 	w.Header().Set(xhttp.HeaderContentType, xhttp.ContentType(codec.Name())) | ||||
| 	_, _ = w.Write(data) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // encodeError encodes the error to the HTTP response. | ||||
| func encodeError(w http.ResponseWriter, r *http.Request, err error) { | ||||
| 	se := errors.FromError(err) | ||||
| 	codec := codecForRequest(r) | ||||
| 	data, _ := codec.Marshal(se) | ||||
| 	w.Header().Set(contentTypeHeader, contentType(codec.Name())) | ||||
| 	w.WriteHeader(se.Code) | ||||
| 	_, _ = w.Write(data) | ||||
| 	st, _ := status.FromError(err) | ||||
| 	data, err := protojson.Marshal(st.Proto()) | ||||
| 	if err != nil { | ||||
| 		w.WriteHeader(http.StatusInternalServerError) | ||||
| 		return | ||||
| 	} | ||||
| 	w.Header().Set(xhttp.HeaderContentType, "application/json; charset=utf-8") | ||||
| 	w.WriteHeader(xhttp.StatusFromGRPCCode(st.Code())) | ||||
| 	w.Write(data) | ||||
| } | ||||
|  | ||||
| // codecForRequest get encoding.Codec via http.Request | ||||
| func codecForRequest(r *http.Request) encoding.Codec { | ||||
| 	var codec encoding.Codec | ||||
| 	for _, accept := range r.Header[acceptHeader] { | ||||
| 		if codec = encoding.GetCodec(contentSubtype(accept)); codec != nil { | ||||
| 	for _, accept := range r.Header[xhttp.HeaderAccept] { | ||||
| 		if codec = encoding.GetCodec(xhttp.ContentSubtype(accept)); codec != nil { | ||||
| 			break | ||||
| 		} | ||||
| 	} | ||||
| @@ -129,25 +124,3 @@ func codecForRequest(r *http.Request) encoding.Codec { | ||||
| 	} | ||||
| 	return codec | ||||
| } | ||||
|  | ||||
| func contentType(subtype string) string { | ||||
| 	return strings.Join([]string{baseContentType, subtype}, "/") | ||||
| } | ||||
|  | ||||
| func contentSubtype(contentType string) string { | ||||
| 	if contentType == baseContentType { | ||||
| 		return "" | ||||
| 	} | ||||
| 	if !strings.HasPrefix(contentType, baseContentType) { | ||||
| 		return "" | ||||
| 	} | ||||
| 	switch contentType[len(baseContentType)] { | ||||
| 	case '/', ';': | ||||
| 		if i := strings.Index(contentType, ";"); i != -1 { | ||||
| 			return contentType[len(baseContentType)+1 : i] | ||||
| 		} | ||||
| 		return contentType[len(baseContentType)+1:] | ||||
| 	default: | ||||
| 		return "" | ||||
| 	} | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user