diff --git a/echo.go b/echo.go index dfb7590f..f984f405 100644 --- a/echo.go +++ b/echo.go @@ -135,7 +135,8 @@ const ( Upgrade = "Upgrade" Vary = "Vary" WWWAuthenticate = "WWW-Authenticate" - + XForwardedFor = "X-Forwarded-For" + XRealIP = "X-Real-IP" //----------- // Protocols //----------- diff --git a/middleware/logger.go b/middleware/logger.go index 19d1301d..6249cde7 100644 --- a/middleware/logger.go +++ b/middleware/logger.go @@ -2,6 +2,7 @@ package middleware import ( "log" + "net" "time" "github.com/labstack/echo" @@ -11,19 +12,30 @@ import ( func Logger() echo.MiddlewareFunc { return func(h echo.HandlerFunc) echo.HandlerFunc { return func(c *echo.Context) error { + req := c.Request() + res := c.Response() + + remoteAddr := req.RemoteAddr + if ip := req.Header.Get(echo.XRealIP); ip != "" { + remoteAddr = ip + } else if ip = req.Header.Get(echo.XForwardedFor); ip != "" { + remoteAddr = ip + } + remoteAddr, _, _ = net.SplitHostPort(remoteAddr) + start := time.Now() if err := h(c); err != nil { c.Error(err) } stop := time.Now() - method := c.Request().Method - path := c.Request().URL.Path + method := req.Method + path := req.URL.Path if path == "" { path = "/" } - size := c.Response().Size() + size := res.Size() - n := c.Response().Status() + n := res.Status() code := color.Green(n) switch { case n >= 500: @@ -34,7 +46,7 @@ func Logger() echo.MiddlewareFunc { code = color.Cyan(n) } - log.Printf("%s %s %s %s %d", method, path, code, stop.Sub(start), size) + log.Printf("%s %s %s %s %s %d", remoteAddr, method, path, code, stop.Sub(start), size) return nil } } diff --git a/middleware/logger_test.go b/middleware/logger_test.go index e46019f3..69c0ca46 100644 --- a/middleware/logger_test.go +++ b/middleware/logger_test.go @@ -2,24 +2,41 @@ package middleware import ( "errors" - "github.com/labstack/echo" "net/http" "net/http/httptest" "testing" + + "github.com/labstack/echo" ) func TestLogger(t *testing.T) { + // Note: Just for the test coverage, not a real test. e := echo.New() req, _ := http.NewRequest(echo.GET, "/", nil) rec := httptest.NewRecorder() c := echo.NewContext(req, echo.NewResponse(rec), e) - // Status 2xx + // With X-Real-IP + req.Header.Add(echo.XRealIP, "127.0.0.1") h := func(c *echo.Context) error { return c.String(http.StatusOK, "test") } Logger()(h)(c) + // With X-Forwarded-For + req.Header.Del(echo.XRealIP) + req.Header.Add(echo.XForwardedFor, "127.0.0.1") + h = func(c *echo.Context) error { + return c.String(http.StatusOK, "test") + } + Logger()(h)(c) + + // Status 2xx + h = func(c *echo.Context) error { + return c.String(http.StatusOK, "test") + } + Logger()(h)(c) + // Status 3xx rec = httptest.NewRecorder() c = echo.NewContext(req, echo.NewResponse(rec), e)