1
0
mirror of https://github.com/go-kratos/kratos.git synced 2025-01-07 23:02:12 +02:00

transport/http: fix content type (#1070)

* fix content type
This commit is contained in:
Tony Chen 2021-06-17 10:26:31 +08:00 committed by GitHub
parent db02034dd1
commit 16b1da04e0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 30 additions and 3 deletions

View File

@ -43,10 +43,27 @@ type Context interface {
Reset(http.ResponseWriter, *http.Request)
}
type responseWriter struct {
code int
w http.ResponseWriter
}
func (w *responseWriter) rest(res http.ResponseWriter) {
w.w = res
w.code = http.StatusOK
}
func (w *responseWriter) Header() http.Header { return w.w.Header() }
func (w *responseWriter) WriteHeader(statusCode int) { w.code = statusCode }
func (w *responseWriter) Write(data []byte) (int, error) {
w.w.WriteHeader(w.code)
return w.w.Write(data)
}
type wrapper struct {
route *Route
req *http.Request
res http.ResponseWriter
w responseWriter
}
func (c *wrapper) Header() http.Header {
@ -83,14 +100,14 @@ func (c *wrapper) Returns(v interface{}, err error) error {
if err != nil {
return err
}
if err := c.route.srv.enc(c.res, c.req, v); err != nil {
if err := c.route.srv.enc(&c.w, c.req, v); err != nil {
return err
}
return nil
}
func (c *wrapper) Result(code int, v interface{}) error {
c.res.WriteHeader(code)
if err := c.route.srv.enc(c.res, c.req, v); err != nil {
c.w.WriteHeader(code)
if err := c.route.srv.enc(&c.w, c.req, v); err != nil {
return err
}
return nil
@ -124,6 +141,7 @@ func (c *wrapper) Stream(code int, contentType string, rd io.Reader) error {
return err
}
func (c *wrapper) Reset(res http.ResponseWriter, req *http.Request) {
c.w.rest(res)
c.res = res
c.req = req
}

View File

@ -91,6 +91,9 @@ func testRoute(t *testing.T, srv *Server) {
if resp.StatusCode != 200 {
t.Fatalf("code: %d", resp.StatusCode)
}
if v := resp.Header.Get("Content-Type"); v != "application/json" {
t.Fatalf("contentType: %s", v)
}
u := new(User)
if err := json.NewDecoder(resp.Body).Decode(u); err != nil {
t.Fatal(err)
@ -107,6 +110,9 @@ func testRoute(t *testing.T, srv *Server) {
if resp.StatusCode != 201 {
t.Fatalf("code: %d", resp.StatusCode)
}
if v := resp.Header.Get("Content-Type"); v != "application/json" {
t.Fatalf("contentType: %s", v)
}
u = new(User)
if err = json.NewDecoder(resp.Body).Decode(u); err != nil {
t.Fatal(err)
@ -125,6 +131,9 @@ func testRoute(t *testing.T, srv *Server) {
if resp.StatusCode != 200 {
t.Fatalf("code: %d", resp.StatusCode)
}
if v := resp.Header.Get("Content-Type"); v != "application/json" {
t.Fatalf("contentType: %s", v)
}
u = new(User)
if err = json.NewDecoder(resp.Body).Decode(u); err != nil {
t.Fatal(err)