diff --git a/encoding/form/proto_decode.go b/encoding/form/proto_decode.go index 47affee1f..fcf46d469 100644 --- a/encoding/form/proto_decode.go +++ b/encoding/form/proto_decode.go @@ -99,6 +99,9 @@ func getDescriptorByFieldAndName(fields protoreflect.FieldDescriptors, fieldName } func populateField(fd protoreflect.FieldDescriptor, v protoreflect.Message, value string) error { + if value == "" { + return nil + } val, err := parseField(fd, value) if err != nil { return fmt.Errorf("parsing field %q: %w", fd.FullName().Name(), err) diff --git a/transport/http/codec.go b/transport/http/codec.go index 20b1a8ec6..896683659 100644 --- a/transport/http/codec.go +++ b/transport/http/codec.go @@ -4,10 +4,13 @@ import ( "fmt" "io" "net/http" + "net/url" "github.com/go-kratos/kratos/v2/encoding" "github.com/go-kratos/kratos/v2/errors" "github.com/go-kratos/kratos/v2/internal/httputil" + "github.com/go-kratos/kratos/v2/transport/http/binding" + "github.com/gorilla/mux" ) // SupportPackageIsVersion1 These constants should not be referenced from any other code. @@ -37,6 +40,21 @@ type EncodeResponseFunc func(http.ResponseWriter, *http.Request, interface{}) er // EncodeErrorFunc is encode error func. type EncodeErrorFunc func(http.ResponseWriter, *http.Request, error) +// DefaultRequestVars decodes the request vars to object. +func DefaultRequestVars(r *http.Request, v interface{}) error { + raws := mux.Vars(r) + vars := make(url.Values, len(raws)) + for k, v := range raws { + vars[k] = []string{v} + } + return binding.BindQuery(vars, v) +} + +// DefaultRequestQuery decodes the request vars to object. +func DefaultRequestQuery(r *http.Request, v interface{}) error { + return binding.BindQuery(r.URL.Query(), v) +} + // DefaultRequestDecoder decodes the request body to object. func DefaultRequestDecoder(r *http.Request, v interface{}) error { codec, ok := CodecForRequest(r, "Content-Type") diff --git a/transport/http/context.go b/transport/http/context.go index 347d7f2b7..ab9a1d391 100644 --- a/transport/http/context.go +++ b/transport/http/context.go @@ -96,9 +96,9 @@ func (c *wrapper) Middleware(h middleware.Handler) middleware.Handler { } return middleware.Chain(c.router.srv.middleware.Match(c.req.URL.Path)...)(h) } -func (c *wrapper) Bind(v interface{}) error { return c.router.srv.dec(c.req, v) } -func (c *wrapper) BindVars(v interface{}) error { return binding.BindQuery(c.Vars(), v) } -func (c *wrapper) BindQuery(v interface{}) error { return binding.BindQuery(c.Query(), v) } +func (c *wrapper) Bind(v interface{}) error { return c.router.srv.decBody(c.req, v) } +func (c *wrapper) BindVars(v interface{}) error { return c.router.srv.decVars(c.req, v) } +func (c *wrapper) BindQuery(v interface{}) error { return c.router.srv.decQuery(c.req, v) } func (c *wrapper) BindForm(v interface{}) error { return binding.BindForm(c.req, v) } func (c *wrapper) Returns(v interface{}, err error) error { if err != nil { diff --git a/transport/http/context_test.go b/transport/http/context_test.go index b440e9df7..daccccd95 100644 --- a/transport/http/context_test.go +++ b/transport/http/context_test.go @@ -12,9 +12,11 @@ import ( "time" ) +var testRouter = &Router{srv: NewServer()} + func TestContextHeader(t *testing.T) { w := wrapper{ - router: nil, + router: testRouter, req: &http.Request{Header: map[string][]string{"name": {"kratos"}}}, res: nil, w: responseWriter{}, @@ -27,7 +29,7 @@ func TestContextHeader(t *testing.T) { func TestContextForm(t *testing.T) { w := wrapper{ - router: nil, + router: testRouter, req: &http.Request{Header: map[string][]string{"name": {"kratos"}}, Method: "POST"}, res: nil, w: responseWriter{}, @@ -38,7 +40,7 @@ func TestContextForm(t *testing.T) { } w = wrapper{ - router: nil, + router: testRouter, req: &http.Request{Form: map[string][]string{"name": {"kratos"}}}, res: nil, w: responseWriter{}, @@ -51,7 +53,7 @@ func TestContextForm(t *testing.T) { func TestContextQuery(t *testing.T) { w := wrapper{ - router: nil, + router: testRouter, req: &http.Request{URL: &url.URL{Scheme: "https", Host: "github.com", Path: "go-kratos/kratos", RawQuery: "page=1"}, Method: "POST"}, res: nil, w: responseWriter{}, @@ -65,7 +67,7 @@ func TestContextQuery(t *testing.T) { func TestContextRequest(t *testing.T) { req := &http.Request{Method: "POST"} w := wrapper{ - router: nil, + router: testRouter, req: req, res: nil, w: responseWriter{}, @@ -100,7 +102,7 @@ func TestContextResponse(t *testing.T) { func TestContextBindQuery(t *testing.T) { w := wrapper{ - router: nil, + router: testRouter, req: &http.Request{URL: &url.URL{Scheme: "https", Host: "go-kratos-dev", RawQuery: "page=2"}}, res: nil, w: responseWriter{}, @@ -120,7 +122,7 @@ func TestContextBindQuery(t *testing.T) { func TestContextBindForm(t *testing.T) { w := wrapper{ - router: nil, + router: testRouter, req: &http.Request{URL: &url.URL{Scheme: "https", Host: "go-kratos-dev"}, Form: map[string][]string{"page": {"2"}}}, res: nil, w: responseWriter{}, @@ -141,7 +143,7 @@ func TestContextBindForm(t *testing.T) { func TestContextResponseReturn(t *testing.T) { writer := httptest.NewRecorder() w := wrapper{ - router: nil, + router: testRouter, req: nil, res: writer, w: responseWriter{}, @@ -174,7 +176,7 @@ func TestContextCtx(t *testing.T) { req := &http.Request{Method: "POST"} req = req.WithContext(ctx) w := wrapper{ - router: &Router{srv: &Server{enc: DefaultResponseEncoder}}, + router: testRouter, req: req, res: nil, w: responseWriter{}, diff --git a/transport/http/server.go b/transport/http/server.go index db7d0a723..b756a0f80 100644 --- a/transport/http/server.go +++ b/transport/http/server.go @@ -70,10 +70,24 @@ func Filter(filters ...FilterFunc) ServerOption { } } +// RequestVarsDecoder with request decoder. +func RequestVarsDecoder(dec DecodeRequestFunc) ServerOption { + return func(o *Server) { + o.decVars = dec + } +} + +// RequestQueryDecoder with request decoder. +func RequestQueryDecoder(dec DecodeRequestFunc) ServerOption { + return func(o *Server) { + o.decQuery = dec + } +} + // RequestDecoder with request decoder. func RequestDecoder(dec DecodeRequestFunc) ServerOption { return func(o *Server) { - o.dec = dec + o.decBody = dec } } @@ -126,7 +140,9 @@ type Server struct { timeout time.Duration filters []FilterFunc middleware matcher.Matcher - dec DecodeRequestFunc + decVars DecodeRequestFunc + decQuery DecodeRequestFunc + decBody DecodeRequestFunc enc EncodeResponseFunc ene EncodeErrorFunc strictSlash bool @@ -140,7 +156,9 @@ func NewServer(opts ...ServerOption) *Server { address: ":0", timeout: 1 * time.Second, middleware: matcher.New(), - dec: DefaultRequestDecoder, + decVars: DefaultRequestVars, + decQuery: DefaultRequestQuery, + decBody: DefaultRequestDecoder, enc: DefaultResponseEncoder, ene: DefaultErrorEncoder, strictSlash: true, diff --git a/transport/http/server_test.go b/transport/http/server_test.go index 0d4be88bd..ff9b3918c 100644 --- a/transport/http/server_test.go +++ b/transport/http/server_test.go @@ -316,8 +316,8 @@ func TestRequestDecoder(t *testing.T) { o := &Server{} v := func(*http.Request, interface{}) error { return nil } RequestDecoder(v)(o) - if o.dec == nil { - t.Errorf("expected nil got %v", o.dec) + if o.decBody == nil { + t.Errorf("expected nil got %v", o.decBody) } }