From bf85c56b08ab05db0c75ee5b6380eb908f97d5fc Mon Sep 17 00:00:00 2001 From: Vishal Rana Date: Fri, 22 May 2015 04:40:01 -0700 Subject: [PATCH] Encapsulated fields and exposed public functions. Signed-off-by: Vishal Rana --- context.go | 59 +++++++++++++++++++++++-------------- context_test.go | 8 ++--- echo.go | 58 ++++++++++++++++++++---------------- echo_test.go | 2 +- middleware/auth.go | 2 +- middleware/auth_test.go | 16 +++++----- middleware/compress.go | 17 ++++++----- middleware/compress_test.go | 2 +- middleware/logger.go | 8 ++--- middleware/logger_test.go | 4 +-- middleware/slash.go | 6 ++-- middleware/slash_test.go | 4 +-- response.go | 4 +++ 13 files changed, 108 insertions(+), 82 deletions(-) diff --git a/context.go b/context.go index e92dd875..3e36abcc 100644 --- a/context.go +++ b/context.go @@ -11,9 +11,9 @@ type ( // Context represents context for the current request. It holds request and // response objects, path parameters, data and registered handler. Context struct { - Request *http.Request - Response *Response - Socket *websocket.Conn + request *http.Request + response *Response + socket *websocket.Conn pnames []string pvalues []string store store @@ -24,8 +24,8 @@ type ( func NewContext(req *http.Request, res *Response, e *Echo) *Context { return &Context{ - Request: req, - Response: res, + request: req, + response: res, echo: e, pnames: make([]string, e.maxParam), pvalues: make([]string, e.maxParam), @@ -33,6 +33,21 @@ func NewContext(req *http.Request, res *Response, e *Echo) *Context { } } +// Request returns *http.Request. +func (c *Context) Request() *http.Request { + return c.request +} + +// Response returns *Response. +func (c *Context) Response() *Response { + return c.response +} + +// Socket returns *websocket.Conn. +func (c *Context) Socket() *websocket.Conn { + return c.socket +} + // P returns path parameter by index. func (c *Context) P(i uint8) (value string) { l := uint8(len(c.pnames)) @@ -57,7 +72,7 @@ func (c *Context) Param(name string) (value string) { // Bind binds the request body into specified type v. Default binder does it // based on Content-Type header. func (c *Context) Bind(i interface{}) error { - return c.echo.binder(c.Request, i) + return c.echo.binder(c.request, i) } // Render invokes the registered HTML template renderer and sends a text/html @@ -66,37 +81,37 @@ func (c *Context) Render(code int, name string, data interface{}) error { if c.echo.renderer == nil { return RendererNotRegistered } - c.Response.Header().Set(ContentType, TextHTML+"; charset=utf-8") - c.Response.WriteHeader(code) - return c.echo.renderer.Render(c.Response, name, data) + c.response.Header().Set(ContentType, TextHTML+"; charset=utf-8") + c.response.WriteHeader(code) + return c.echo.renderer.Render(c.response, name, data) } // JSON sends an application/json response with status code. func (c *Context) JSON(code int, i interface{}) error { - c.Response.Header().Set(ContentType, ApplicationJSON+"; charset=utf-8") - c.Response.WriteHeader(code) - return json.NewEncoder(c.Response).Encode(i) + c.response.Header().Set(ContentType, ApplicationJSON+"; charset=utf-8") + c.response.WriteHeader(code) + return json.NewEncoder(c.response).Encode(i) } // String sends a text/plain response with status code. func (c *Context) String(code int, s string) error { - c.Response.Header().Set(ContentType, TextPlain+"; charset=utf-8") - c.Response.WriteHeader(code) - _, err := c.Response.Write([]byte(s)) + c.response.Header().Set(ContentType, TextPlain+"; charset=utf-8") + c.response.WriteHeader(code) + _, err := c.response.Write([]byte(s)) return err } // HTML sends a text/html response with status code. func (c *Context) HTML(code int, html string) error { - c.Response.Header().Set(ContentType, TextHTML+"; charset=utf-8") - c.Response.WriteHeader(code) - _, err := c.Response.Write([]byte(html)) + c.response.Header().Set(ContentType, TextHTML+"; charset=utf-8") + c.response.WriteHeader(code) + _, err := c.response.Write([]byte(html)) return err } // NoContent sends a response with no body and a status code. func (c *Context) NoContent(code int) error { - c.Response.WriteHeader(code) + c.response.WriteHeader(code) return nil } @@ -117,11 +132,11 @@ func (c *Context) Set(key string, val interface{}) { // Redirect redirects the request using http.Redirect with status code. func (c *Context) Redirect(code int, url string) { - http.Redirect(c.Response, c.Request, url, code) + http.Redirect(c.response, c.request, url, code) } func (c *Context) reset(w http.ResponseWriter, r *http.Request, e *Echo) { - c.Request = r - c.Response.reset(w) + c.request = r + c.response.reset(w) c.echo = e } diff --git a/context_test.go b/context_test.go index b2b2e3f6..0cc11ac7 100644 --- a/context_test.go +++ b/context_test.go @@ -91,26 +91,26 @@ func TestContext(t *testing.T) { // JSON r.Header.Set(Accept, ApplicationJSON) - c.Response.committed = false + c.response.committed = false if he := c.JSON(http.StatusOK, u1); he != nil { t.Errorf("json %#v", he) } // String r.Header.Set(Accept, TextPlain) - c.Response.committed = false + c.response.committed = false if he := c.String(http.StatusOK, "Hello, World!"); he != nil { t.Errorf("string %#v", he.Error) } // HTML r.Header.Set(Accept, TextHTML) - c.Response.committed = false + c.response.committed = false if he := c.HTML(http.StatusOK, "Hello, World!"); he != nil { t.Errorf("html %v", he.Error) } // Redirect - c.Response.committed = false + c.response.committed = false c.Redirect(http.StatusMovedPermanently, "http://labstack.github.io/echo") } diff --git a/echo.go b/echo.go index b67d98c8..a20f5971 100644 --- a/echo.go +++ b/echo.go @@ -13,9 +13,9 @@ import ( "strings" "sync" + "github.com/bradfitz/http2" "github.com/mattn/go-colorable" "golang.org/x/net/websocket" - "github.com/bradfitz/http2" ) type ( @@ -34,8 +34,8 @@ type ( debug bool } HTTPError struct { - Code int - Message string + code int + message string } Middleware interface{} MiddlewareFunc func(HandlerFunc) HandlerFunc @@ -123,15 +123,21 @@ var ( ) func NewHTTPError(code int, msg ...string) *HTTPError { - he := &HTTPError{Code: code, Message: http.StatusText(code)} + he := &HTTPError{code: code, message: http.StatusText(code)} for _, m := range msg { - he.Message = m + he.message = m } return he } +// Code returns code. +func (e *HTTPError) Code() int { + return e.code +} + +// Error returns message. func (e *HTTPError) Error() string { - return e.Message + return e.message } // New creates an Echo instance. @@ -157,13 +163,13 @@ func New() (e *Echo) { code := http.StatusInternalServerError msg := http.StatusText(code) if he, ok := err.(*HTTPError); ok { - code = he.Code - msg = he.Message + code = he.code + msg = he.message } if e.Debug() { msg = err.Error() } - http.Error(c.Response, msg, code) + http.Error(c.response, msg, code) }) e.SetBinder(func(r *http.Request, v interface{}) error { ct := r.Header.Get(ContentType) @@ -283,12 +289,12 @@ func (e *Echo) WebSocket(path string, h HandlerFunc) { e.Get(path, func(c *Context) (err error) { wss := websocket.Server{ Handler: func(ws *websocket.Conn) { - c.Socket = ws - c.Response.status = http.StatusSwitchingProtocols + c.socket = ws + c.response.status = http.StatusSwitchingProtocols err = h(c) }, } - wss.ServeHTTP(c.Response.writer, c.Request) + wss.ServeHTTP(c.response.writer, c.request) return err }) } @@ -313,7 +319,7 @@ func (e *Echo) Favicon(file string) { func (e *Echo) Static(path, root string) { fs := http.StripPrefix(path, http.FileServer(http.Dir(root))) e.Get(path+"*", func(c *Context) error { - fs.ServeHTTP(c.Response, c.Request) + fs.ServeHTTP(c.response, c.request) return nil }) } @@ -321,7 +327,7 @@ func (e *Echo) Static(path, root string) { // ServeFile serves a file. func (e *Echo) ServeFile(path, file string) { e.Get(path, func(c *Context) error { - http.ServeFile(c.Response, c.Request, file) + http.ServeFile(c.response, c.request, file) return nil }) } @@ -399,17 +405,17 @@ func (e *Echo) RunTLSServer(srv *http.Server, certFile, keyFile string) { e.run(srv, certFile, keyFile) } -func (e *Echo) run(s *http.Server, f ...string) { +func (e *Echo) run(s *http.Server, files ...string) { s.Handler = e if e.http2 { http2.ConfigureServer(s, nil) } - if len(f) == 0 { + if len(files) == 0 { log.Fatal(s.ListenAndServe()) - } else if len(f) == 2 { - log.Fatal(s.ListenAndServeTLS(f[0], f[1])) + } else if len(files) == 2 { + log.Fatal(s.ListenAndServeTLS(files[0], files[1])) } else { - log.Fatal("echo: invalid TLS configuration") + log.Fatal("echo => invalid TLS configuration") } } @@ -428,10 +434,10 @@ func wrapMiddleware(m Middleware) MiddlewareFunc { return func(h HandlerFunc) HandlerFunc { return func(c *Context) (err error) { m(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - c.Response.writer = w - c.Request = r + c.response.writer = w + c.request = r err = h(c) - })).ServeHTTP(c.Response.writer, c.Request) + })).ServeHTTP(c.response.writer, c.request) return } } @@ -462,8 +468,8 @@ func wrapHandlerFuncMW(m HandlerFunc) MiddlewareFunc { func wrapHTTPHandlerFuncMW(m http.HandlerFunc) MiddlewareFunc { return func(h HandlerFunc) HandlerFunc { return func(c *Context) error { - if !c.Response.committed { - m.ServeHTTP(c.Response.writer, c.Request) + if !c.response.committed { + m.ServeHTTP(c.response.writer, c.request) } return h(c) } @@ -479,12 +485,12 @@ func wrapHandler(h Handler) HandlerFunc { return h case http.Handler, http.HandlerFunc: return func(c *Context) error { - h.(http.Handler).ServeHTTP(c.Response, c.Request) + h.(http.Handler).ServeHTTP(c.response, c.request) return nil } case func(http.ResponseWriter, *http.Request): return func(c *Context) error { - h(c.Response, c.Request) + h(c.response, c.request) return nil } default: diff --git a/echo_test.go b/echo_test.go index fad3febf..7b007e91 100644 --- a/echo_test.go +++ b/echo_test.go @@ -252,7 +252,7 @@ func TestEchoMethod(t *testing.T) { func TestWebSocket(t *testing.T) { e := New() e.WebSocket("/ws", func(c *Context) error { - c.Socket.Write([]byte("test")) + c.socket.Write([]byte("test")) return nil }) srv := httptest.NewServer(e) diff --git a/middleware/auth.go b/middleware/auth.go index 09d5e4c0..bd3a1469 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -21,7 +21,7 @@ const ( // For invalid credentials, it sends "401 - Unauthorized" response. func BasicAuth(fn AuthFunc) echo.HandlerFunc { return func(c *echo.Context) error { - auth := c.Request.Header.Get(echo.Authorization) + auth := c.Request().Header.Get(echo.Authorization) i := 0 code := http.StatusBadRequest diff --git a/middleware/auth_test.go b/middleware/auth_test.go index 9933b434..d4d77165 100644 --- a/middleware/auth_test.go +++ b/middleware/auth_test.go @@ -48,8 +48,8 @@ func TestBasicAuth(t *testing.T) { he := ba(c).(*echo.HTTPError) if ba(c) == nil { t.Error("expected `fail`, with incorrect password.") - } else if he.Code != http.StatusUnauthorized { - t.Errorf("expected status `401`, got %d", he.Code) + } else if he.Code() != http.StatusUnauthorized { + t.Errorf("expected status `401`, got %d", he.Code()) } // Empty Authorization header @@ -58,8 +58,8 @@ func TestBasicAuth(t *testing.T) { he = ba(c).(*echo.HTTPError) if he == nil { t.Error("expected `fail`, with empty Authorization header.") - } else if he.Code != http.StatusBadRequest { - t.Errorf("expected status `400`, got %d", he.Code) + } else if he.Code() != http.StatusBadRequest { + t.Errorf("expected status `400`, got %d", he.Code()) } // Invalid Authorization header @@ -69,8 +69,8 @@ func TestBasicAuth(t *testing.T) { he = ba(c).(*echo.HTTPError) if he == nil { t.Error("expected `fail`, with invalid Authorization header.") - } else if he.Code != http.StatusBadRequest { - t.Errorf("expected status `400`, got %d", he.Code) + } else if he.Code() != http.StatusBadRequest { + t.Errorf("expected status `400`, got %d", he.Code()) } // Invalid scheme @@ -80,7 +80,7 @@ func TestBasicAuth(t *testing.T) { he = ba(c).(*echo.HTTPError) if he == nil { t.Error("expected `fail`, with invalid scheme.") - } else if he.Code != http.StatusBadRequest { - t.Errorf("expected status `400`, got %d", he.Code) + } else if he.Code() != http.StatusBadRequest { + t.Errorf("expected status `400`, got %d", he.Code()) } } diff --git a/middleware/compress.go b/middleware/compress.go index af6af39f..33736a27 100644 --- a/middleware/compress.go +++ b/middleware/compress.go @@ -6,17 +6,18 @@ import ( "strings" "github.com/labstack/echo" + "net/http" ) type ( gzipWriter struct { io.Writer - *echo.Response + http.ResponseWriter } ) -func (g gzipWriter) Write(b []byte) (int, error) { - return g.Writer.Write(b) +func (w gzipWriter) Write(b []byte) (int, error) { + return w.Writer.Write(b) } // Gzip returns a middleware which compresses HTTP response using gzip compression @@ -26,12 +27,12 @@ func Gzip() echo.MiddlewareFunc { return func(h echo.HandlerFunc) echo.HandlerFunc { return func(c *echo.Context) error { - if strings.Contains(c.Request.Header.Get(echo.AcceptEncoding), scheme) { - w := gzip.NewWriter(c.Response.Writer()) + if strings.Contains(c.Request().Header.Get(echo.AcceptEncoding), scheme) { + w := gzip.NewWriter(c.Response().Writer()) defer w.Close() - gw := gzipWriter{Writer: w, Response: c.Response} - c.Response.Header().Set(echo.ContentEncoding, scheme) - c.Response = echo.NewResponse(gw) + gw := gzipWriter{Writer: w, ResponseWriter: c.Response().Writer()} + c.Response().Header().Set(echo.ContentEncoding, scheme) + c.Response().SetWriter(gw) } return h(c) } diff --git a/middleware/compress_test.go b/middleware/compress_test.go index 78c4d6e4..90417ac7 100644 --- a/middleware/compress_test.go +++ b/middleware/compress_test.go @@ -29,7 +29,7 @@ func TestGzip(t *testing.T) { // Content-Encoding header req.Header.Set(echo.AcceptEncoding, "gzip") w = httptest.NewRecorder() - c.Response = echo.NewResponse(w) + c.Response().SetWriter(w) Gzip()(h)(c) ce := w.Header().Get(echo.ContentEncoding) if ce != "gzip" { diff --git a/middleware/logger.go b/middleware/logger.go index 132c4d51..92bc8ed8 100644 --- a/middleware/logger.go +++ b/middleware/logger.go @@ -16,14 +16,14 @@ func Logger() echo.MiddlewareFunc { c.Error(err) } end := time.Now() - method := c.Request.Method - path := c.Request.URL.Path + method := c.Request().Method + path := c.Request().URL.Path if path == "" { path = "/" } - size := c.Response.Size() + size := c.Response().Size() - n := c.Response.Status() + n := c.Response().Status() code := color.Green(n) switch { case n >= 500: diff --git a/middleware/logger_test.go b/middleware/logger_test.go index aa1f2341..0c341e67 100644 --- a/middleware/logger_test.go +++ b/middleware/logger_test.go @@ -21,14 +21,14 @@ func TestLogger(t *testing.T) { Logger()(h)(c) // Status 4xx - c.Response = echo.NewResponse(w) + c = echo.NewContext(req, echo.NewResponse(w), e) h = func(c *echo.Context) error { return c.String(http.StatusNotFound, "test") } Logger()(h)(c) // Status 5xx - c.Response = echo.NewResponse(w) + c = echo.NewContext(req, echo.NewResponse(w), e) h = func(c *echo.Context) error { return c.String(http.StatusInternalServerError, "test") } diff --git a/middleware/slash.go b/middleware/slash.go index 4d84c140..acd4bef5 100644 --- a/middleware/slash.go +++ b/middleware/slash.go @@ -15,10 +15,10 @@ type ( // path. func StripTrailingSlash() echo.HandlerFunc { return func(c *echo.Context) error { - p := c.Request.URL.Path + p := c.Request().URL.Path l := len(p) if p[l-1] == '/' { - c.Request.URL.Path = p[:l-1] + c.Request().URL.Path = p[:l-1] } return nil } @@ -36,7 +36,7 @@ func RedirectToSlash(opts ...RedirectToSlashOptions) echo.HandlerFunc { } return func(c *echo.Context) error { - p := c.Request.URL.Path + p := c.Request().URL.Path l := len(p) if p[l-1] != '/' { c.Redirect(code, p+"/") diff --git a/middleware/slash_test.go b/middleware/slash_test.go index 1ac0bc59..67fe72bf 100644 --- a/middleware/slash_test.go +++ b/middleware/slash_test.go @@ -13,7 +13,7 @@ func TestStripTrailingSlash(t *testing.T) { res := echo.NewResponse(httptest.NewRecorder()) c := echo.NewContext(req, res, echo.New()) StripTrailingSlash()(c) - p := c.Request.URL.Path + p := c.Request().URL.Path if p != "/users" { t.Errorf("expected path `/users` got, %s.", p) } @@ -31,7 +31,7 @@ func TestRedirectToSlash(t *testing.T) { } // Location header - l := c.Response.Header().Get("Location") + l := c.Response().Header().Get("Location") if l != "/users/" { t.Errorf("expected Location header `/users/`, got %s.", l) } diff --git a/response.go b/response.go index 1774c580..caa191bf 100644 --- a/response.go +++ b/response.go @@ -26,6 +26,10 @@ func (r *Response) Header() http.Header { return r.writer.Header() } +func (r *Response) SetWriter(w http.ResponseWriter) { + r.writer = w +} + func (r *Response) Writer() http.ResponseWriter { return r.writer }