mirror of
https://github.com/go-kratos/kratos.git
synced 2025-03-17 21:07:54 +02:00
transport/http: uses gRPC status to the HTTP error. (#870)
* uses gRPC status to the HTTP error.
This commit is contained in:
parent
b03c810dce
commit
7c3212c306
@ -25,6 +25,7 @@ func callHTTP() {
|
||||
transhttp.WithMiddleware(
|
||||
middleware.Chain(
|
||||
recovery.Recovery(),
|
||||
status.Client(),
|
||||
),
|
||||
),
|
||||
)
|
||||
|
@ -63,6 +63,7 @@ func main() {
|
||||
httpSrv.HandlePrefix("/", pb.NewGreeterHandler(s,
|
||||
http.Middleware(
|
||||
middleware.Chain(
|
||||
status.Server(),
|
||||
logging.Server(logger),
|
||||
recovery.Recovery(),
|
||||
),
|
||||
|
121
internal/http/http.go
Normal file
121
internal/http/http.go
Normal file
@ -0,0 +1,121 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
)
|
||||
|
||||
const (
|
||||
baseContentType = "application"
|
||||
)
|
||||
|
||||
var (
|
||||
// HeaderAccept is accept header.
|
||||
HeaderAccept = http.CanonicalHeaderKey("Accept")
|
||||
// HeaderContentType is content-type header.
|
||||
HeaderContentType = http.CanonicalHeaderKey("Content-Type")
|
||||
// HeaderAcceptLanguage is accept-language header.
|
||||
HeaderAcceptLanguage = http.CanonicalHeaderKey("Accept-Language")
|
||||
)
|
||||
|
||||
// ContentType returns the content-type with base prefix.
|
||||
func ContentType(subtype string) string {
|
||||
return strings.Join([]string{baseContentType, subtype}, "/")
|
||||
}
|
||||
|
||||
// ContentSubtype returns the content-subtype for the given content-type. The
|
||||
// given content-type must be a valid content-type that starts with
|
||||
// but no content-subtype will be returned.
|
||||
//
|
||||
// contentType is assumed to be lowercase already.
|
||||
func ContentSubtype(contentType string) string {
|
||||
if contentType == baseContentType {
|
||||
return ""
|
||||
}
|
||||
if !strings.HasPrefix(contentType, baseContentType) {
|
||||
return ""
|
||||
}
|
||||
switch contentType[len(baseContentType)] {
|
||||
case '/', ';':
|
||||
if i := strings.Index(contentType, ";"); i != -1 {
|
||||
return contentType[len(baseContentType)+1 : i]
|
||||
}
|
||||
return contentType[len(baseContentType)+1:]
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// GRPCCodeFromStatus converts a HTTP error code into the corresponding gRPC response status.
|
||||
// See: https://github.com/googleapis/googleapis/blob/master/google/rpc/code.proto
|
||||
func GRPCCodeFromStatus(code int) codes.Code {
|
||||
switch code {
|
||||
case http.StatusOK:
|
||||
return codes.OK
|
||||
case http.StatusBadRequest:
|
||||
return codes.InvalidArgument
|
||||
case http.StatusUnauthorized:
|
||||
return codes.Unauthenticated
|
||||
case http.StatusForbidden:
|
||||
return codes.PermissionDenied
|
||||
case http.StatusNotFound:
|
||||
return codes.NotFound
|
||||
case http.StatusConflict:
|
||||
return codes.Aborted
|
||||
case http.StatusTooManyRequests:
|
||||
return codes.ResourceExhausted
|
||||
case http.StatusInternalServerError:
|
||||
return codes.Internal
|
||||
case http.StatusNotImplemented:
|
||||
return codes.Unimplemented
|
||||
case http.StatusServiceUnavailable:
|
||||
return codes.Unavailable
|
||||
case http.StatusGatewayTimeout:
|
||||
return codes.DeadlineExceeded
|
||||
}
|
||||
return codes.Unknown
|
||||
}
|
||||
|
||||
// StatusFromGRPCCode converts a gRPC error code into the corresponding HTTP response status.
|
||||
// See: https://github.com/googleapis/googleapis/blob/master/google/rpc/code.proto
|
||||
func StatusFromGRPCCode(code codes.Code) int {
|
||||
switch code {
|
||||
case codes.OK:
|
||||
return http.StatusOK
|
||||
case codes.Canceled:
|
||||
return http.StatusRequestTimeout
|
||||
case codes.Unknown:
|
||||
return http.StatusInternalServerError
|
||||
case codes.InvalidArgument:
|
||||
return http.StatusBadRequest
|
||||
case codes.DeadlineExceeded:
|
||||
return http.StatusGatewayTimeout
|
||||
case codes.NotFound:
|
||||
return http.StatusNotFound
|
||||
case codes.AlreadyExists:
|
||||
return http.StatusConflict
|
||||
case codes.PermissionDenied:
|
||||
return http.StatusForbidden
|
||||
case codes.Unauthenticated:
|
||||
return http.StatusUnauthorized
|
||||
case codes.ResourceExhausted:
|
||||
return http.StatusTooManyRequests
|
||||
case codes.FailedPrecondition:
|
||||
return http.StatusBadRequest
|
||||
case codes.Aborted:
|
||||
return http.StatusConflict
|
||||
case codes.OutOfRange:
|
||||
return http.StatusBadRequest
|
||||
case codes.Unimplemented:
|
||||
return http.StatusNotImplemented
|
||||
case codes.Internal:
|
||||
return http.StatusInternalServerError
|
||||
case codes.Unavailable:
|
||||
return http.StatusServiceUnavailable
|
||||
case codes.DataLoss:
|
||||
return http.StatusInternalServerError
|
||||
}
|
||||
return http.StatusInternalServerError
|
||||
}
|
@ -2,15 +2,14 @@ package status
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-kratos/kratos/v2/errors"
|
||||
"github.com/go-kratos/kratos/v2/internal/http"
|
||||
"github.com/go-kratos/kratos/v2/middleware"
|
||||
|
||||
//lint:ignore SA1019 grpc
|
||||
"github.com/golang/protobuf/proto"
|
||||
"google.golang.org/genproto/googleapis/rpc/errdetails"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
@ -34,7 +33,7 @@ func WithHandler(h HandlerFunc) Option {
|
||||
// Server is an error middleware.
|
||||
func Server(opts ...Option) middleware.Middleware {
|
||||
options := options{
|
||||
handler: encodeErr,
|
||||
handler: encodeError,
|
||||
}
|
||||
for _, o := range opts {
|
||||
o(&options)
|
||||
@ -53,7 +52,7 @@ func Server(opts ...Option) middleware.Middleware {
|
||||
// Client is an error middleware.
|
||||
func Client(opts ...Option) middleware.Middleware {
|
||||
options := options{
|
||||
handler: decodeErr,
|
||||
handler: decodeError,
|
||||
}
|
||||
for _, o := range opts {
|
||||
o(&options)
|
||||
@ -69,7 +68,7 @@ func Client(opts ...Option) middleware.Middleware {
|
||||
}
|
||||
}
|
||||
|
||||
func encodeErr(ctx context.Context, err error) error {
|
||||
func encodeError(ctx context.Context, err error) error {
|
||||
var details []proto.Message
|
||||
if target := new(errors.ErrorInfo); errors.As(err, &target) {
|
||||
details = append(details, &errdetails.ErrorInfo{
|
||||
@ -79,7 +78,7 @@ func encodeErr(ctx context.Context, err error) error {
|
||||
})
|
||||
}
|
||||
es := errors.FromError(err)
|
||||
gs := status.New(httpToGRPCCode(es.Code), es.Message)
|
||||
gs := status.New(http.GRPCCodeFromStatus(es.Code), es.Message)
|
||||
gs, err = gs.WithDetails(details...)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -87,9 +86,9 @@ func encodeErr(ctx context.Context, err error) error {
|
||||
return gs.Err()
|
||||
}
|
||||
|
||||
func decodeErr(ctx context.Context, err error) error {
|
||||
func decodeError(ctx context.Context, err error) error {
|
||||
gs := status.Convert(err)
|
||||
code := grpcToHTTPCode(gs.Code())
|
||||
code := http.StatusFromGRPCCode(gs.Code())
|
||||
message := gs.Message()
|
||||
for _, detail := range gs.Details() {
|
||||
switch d := detail.(type) {
|
||||
@ -104,43 +103,3 @@ func decodeErr(ctx context.Context, err error) error {
|
||||
}
|
||||
return errors.New(code, message)
|
||||
}
|
||||
|
||||
func httpToGRPCCode(code int) codes.Code {
|
||||
switch code {
|
||||
case http.StatusBadRequest:
|
||||
return codes.InvalidArgument
|
||||
case http.StatusUnauthorized:
|
||||
return codes.Unauthenticated
|
||||
case http.StatusForbidden:
|
||||
return codes.PermissionDenied
|
||||
case http.StatusNotFound:
|
||||
return codes.NotFound
|
||||
case http.StatusConflict:
|
||||
return codes.Aborted
|
||||
case http.StatusInternalServerError:
|
||||
return codes.Internal
|
||||
case http.StatusServiceUnavailable:
|
||||
return codes.Unavailable
|
||||
}
|
||||
return codes.Unknown
|
||||
}
|
||||
|
||||
func grpcToHTTPCode(code codes.Code) int {
|
||||
switch code {
|
||||
case codes.InvalidArgument:
|
||||
return http.StatusBadRequest
|
||||
case codes.Unauthenticated:
|
||||
return http.StatusUnauthorized
|
||||
case codes.PermissionDenied:
|
||||
return http.StatusForbidden
|
||||
case codes.NotFound:
|
||||
return http.StatusNotFound
|
||||
case codes.Aborted:
|
||||
return http.StatusConflict
|
||||
case codes.Internal:
|
||||
return http.StatusInternalServerError
|
||||
case codes.Unavailable:
|
||||
return http.StatusServiceUnavailable
|
||||
}
|
||||
return http.StatusInternalServerError
|
||||
}
|
||||
|
@ -9,8 +9,8 @@ import (
|
||||
|
||||
func TestErrEncoder(t *testing.T) {
|
||||
err := errors.BadRequest("test", "invalid_argument", "format")
|
||||
en := encodeErr(context.Background(), err)
|
||||
de := decodeErr(context.Background(), en)
|
||||
en := encodeError(context.Background(), err)
|
||||
de := decodeError(context.Background(), en)
|
||||
if !errors.IsBadRequest(de) {
|
||||
t.Errorf("expected %v got %v", err, de)
|
||||
}
|
||||
|
@ -8,10 +8,17 @@ import (
|
||||
|
||||
"github.com/go-kratos/kratos/v2/encoding"
|
||||
"github.com/go-kratos/kratos/v2/errors"
|
||||
xhttp "github.com/go-kratos/kratos/v2/internal/http"
|
||||
"github.com/go-kratos/kratos/v2/middleware"
|
||||
"github.com/go-kratos/kratos/v2/transport"
|
||||
spb "google.golang.org/genproto/googleapis/rpc/status"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
)
|
||||
|
||||
// DecodeErrorFunc is decode error func.
|
||||
type DecodeErrorFunc func(ctx context.Context, w *http.Response) error
|
||||
|
||||
// ClientOption is HTTP client option.
|
||||
type ClientOption func(*clientOptions)
|
||||
|
||||
@ -45,11 +52,12 @@ func WithMiddleware(m middleware.Middleware) ClientOption {
|
||||
|
||||
// Client is a HTTP transport client.
|
||||
type clientOptions struct {
|
||||
ctx context.Context
|
||||
timeout time.Duration
|
||||
userAgent string
|
||||
transport http.RoundTripper
|
||||
middleware middleware.Middleware
|
||||
ctx context.Context
|
||||
timeout time.Duration
|
||||
userAgent string
|
||||
transport http.RoundTripper
|
||||
errorDecoder DecodeErrorFunc
|
||||
middleware middleware.Middleware
|
||||
}
|
||||
|
||||
// NewClient returns an HTTP client.
|
||||
@ -64,26 +72,29 @@ func NewClient(ctx context.Context, opts ...ClientOption) (*http.Client, error)
|
||||
// NewTransport creates an http.RoundTripper.
|
||||
func NewTransport(ctx context.Context, opts ...ClientOption) (http.RoundTripper, error) {
|
||||
options := &clientOptions{
|
||||
ctx: ctx,
|
||||
timeout: 500 * time.Millisecond,
|
||||
transport: http.DefaultTransport,
|
||||
ctx: ctx,
|
||||
timeout: 500 * time.Millisecond,
|
||||
transport: http.DefaultTransport,
|
||||
errorDecoder: checkResponse,
|
||||
}
|
||||
for _, o := range opts {
|
||||
o(options)
|
||||
}
|
||||
return &baseTransport{
|
||||
middleware: options.middleware,
|
||||
userAgent: options.userAgent,
|
||||
timeout: options.timeout,
|
||||
base: options.transport,
|
||||
errorDecoder: options.errorDecoder,
|
||||
middleware: options.middleware,
|
||||
userAgent: options.userAgent,
|
||||
timeout: options.timeout,
|
||||
base: options.transport,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type baseTransport struct {
|
||||
userAgent string
|
||||
timeout time.Duration
|
||||
base http.RoundTripper
|
||||
middleware middleware.Middleware
|
||||
userAgent string
|
||||
timeout time.Duration
|
||||
base http.RoundTripper
|
||||
errorDecoder DecodeErrorFunc
|
||||
middleware middleware.Middleware
|
||||
}
|
||||
|
||||
func (t *baseTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
@ -96,7 +107,14 @@ func (t *baseTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
defer cancel()
|
||||
|
||||
h := func(ctx context.Context, in interface{}) (interface{}, error) {
|
||||
return t.base.RoundTrip(in.(*http.Request))
|
||||
res, err := t.base.RoundTrip(in.(*http.Request))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := t.errorDecoder(ctx, res); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
if t.middleware != nil {
|
||||
h = t.middleware(h)
|
||||
@ -115,19 +133,7 @@ func Do(client *http.Client, req *http.Request, target interface{}) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode < 200 || res.StatusCode > 299 {
|
||||
se := &errors.Error{Code: 500}
|
||||
if err := decodeResponse(res, se); err != nil {
|
||||
return err
|
||||
}
|
||||
return se
|
||||
}
|
||||
return decodeResponse(res, target)
|
||||
}
|
||||
|
||||
func decodeResponse(res *http.Response, target interface{}) error {
|
||||
subtype := contentSubtype(res.Header.Get(contentTypeHeader))
|
||||
subtype := xhttp.ContentSubtype(res.Header.Get(xhttp.HeaderContentType))
|
||||
codec := encoding.GetCodec(subtype)
|
||||
if codec == nil {
|
||||
codec = encoding.GetCodec("json")
|
||||
@ -138,3 +144,19 @@ func decodeResponse(res *http.Response, target interface{}) error {
|
||||
}
|
||||
return codec.Unmarshal(data, target)
|
||||
}
|
||||
|
||||
// checkResponse returns an error (of type *Error) if the response
|
||||
// status code is not 2xx.
|
||||
func checkResponse(ctx context.Context, res *http.Response) error {
|
||||
if res.StatusCode >= 200 && res.StatusCode <= 299 {
|
||||
return nil
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if data, err := ioutil.ReadAll(res.Body); err == nil {
|
||||
st := new(spb.Status)
|
||||
if err = protojson.Unmarshal(data, st); err == nil {
|
||||
return status.ErrorProto(st)
|
||||
}
|
||||
}
|
||||
return errors.New(res.StatusCode, "")
|
||||
}
|
||||
|
@ -3,26 +3,18 @@ package http
|
||||
import (
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/go-kratos/kratos/v2/encoding"
|
||||
"github.com/go-kratos/kratos/v2/encoding/json"
|
||||
"github.com/go-kratos/kratos/v2/errors"
|
||||
xhttp "github.com/go-kratos/kratos/v2/internal/http"
|
||||
"github.com/go-kratos/kratos/v2/middleware"
|
||||
"github.com/go-kratos/kratos/v2/transport/http/binding"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
)
|
||||
|
||||
const (
|
||||
// SupportPackageIsVersion1 These constants should not be referenced from any other code.
|
||||
SupportPackageIsVersion1 = true
|
||||
|
||||
baseContentType = "application"
|
||||
)
|
||||
|
||||
var (
|
||||
acceptHeader = http.CanonicalHeaderKey("Accept")
|
||||
contentTypeHeader = http.CanonicalHeaderKey("Content-Type")
|
||||
)
|
||||
// SupportPackageIsVersion1 These constants should not be referenced from any other code.
|
||||
const SupportPackageIsVersion1 = true
|
||||
|
||||
// DecodeRequestFunc is decode request func.
|
||||
type DecodeRequestFunc func(*http.Request, interface{}) error
|
||||
@ -83,7 +75,7 @@ func Middleware(m middleware.Middleware) HandleOption {
|
||||
|
||||
// decodeRequest decodes the request body to object.
|
||||
func decodeRequest(req *http.Request, v interface{}) error {
|
||||
subtype := contentSubtype(req.Header.Get(contentTypeHeader))
|
||||
subtype := xhttp.ContentSubtype(req.Header.Get(xhttp.HeaderContentType))
|
||||
if codec := encoding.GetCodec(subtype); codec != nil {
|
||||
data, err := ioutil.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
@ -101,26 +93,29 @@ func encodeResponse(w http.ResponseWriter, r *http.Request, v interface{}) error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
w.Header().Set(contentTypeHeader, contentType(codec.Name()))
|
||||
w.Header().Set(xhttp.HeaderContentType, xhttp.ContentType(codec.Name()))
|
||||
_, _ = w.Write(data)
|
||||
return nil
|
||||
}
|
||||
|
||||
// encodeError encodes the error to the HTTP response.
|
||||
func encodeError(w http.ResponseWriter, r *http.Request, err error) {
|
||||
se := errors.FromError(err)
|
||||
codec := codecForRequest(r)
|
||||
data, _ := codec.Marshal(se)
|
||||
w.Header().Set(contentTypeHeader, contentType(codec.Name()))
|
||||
w.WriteHeader(se.Code)
|
||||
_, _ = w.Write(data)
|
||||
st, _ := status.FromError(err)
|
||||
data, err := protojson.Marshal(st.Proto())
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Header().Set(xhttp.HeaderContentType, "application/json; charset=utf-8")
|
||||
w.WriteHeader(xhttp.StatusFromGRPCCode(st.Code()))
|
||||
w.Write(data)
|
||||
}
|
||||
|
||||
// codecForRequest get encoding.Codec via http.Request
|
||||
func codecForRequest(r *http.Request) encoding.Codec {
|
||||
var codec encoding.Codec
|
||||
for _, accept := range r.Header[acceptHeader] {
|
||||
if codec = encoding.GetCodec(contentSubtype(accept)); codec != nil {
|
||||
for _, accept := range r.Header[xhttp.HeaderAccept] {
|
||||
if codec = encoding.GetCodec(xhttp.ContentSubtype(accept)); codec != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
@ -129,25 +124,3 @@ func codecForRequest(r *http.Request) encoding.Codec {
|
||||
}
|
||||
return codec
|
||||
}
|
||||
|
||||
func contentType(subtype string) string {
|
||||
return strings.Join([]string{baseContentType, subtype}, "/")
|
||||
}
|
||||
|
||||
func contentSubtype(contentType string) string {
|
||||
if contentType == baseContentType {
|
||||
return ""
|
||||
}
|
||||
if !strings.HasPrefix(contentType, baseContentType) {
|
||||
return ""
|
||||
}
|
||||
switch contentType[len(baseContentType)] {
|
||||
case '/', ';':
|
||||
if i := strings.Index(contentType, ";"); i != -1 {
|
||||
return contentType[len(baseContentType)+1 : i]
|
||||
}
|
||||
return contentType[len(baseContentType)+1:]
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user