From 8aaf620c2d1836f73b49764d37fdb7de0d9b1a94 Mon Sep 17 00:00:00 2001 From: Vishal Rana Date: Tue, 1 Sep 2015 11:24:36 -0700 Subject: [PATCH] Closes #180 Signed-off-by: Vishal Rana --- echo.go | 1 + middleware/auth.go | 6 ++---- middleware/auth_test.go | 7 +++++-- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/echo.go b/echo.go index 54d3ce2a..07c0072d 100644 --- a/echo.go +++ b/echo.go @@ -133,6 +133,7 @@ const ( Location = "Location" Upgrade = "Upgrade" Vary = "Vary" + WWWAuthenticate = "WWW-Authenticate" //----------- // Protocols diff --git a/middleware/auth.go b/middleware/auth.go index 1dc5c4ea..99a89f52 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -18,7 +18,6 @@ const ( // BasicAuth returns an HTTP basic authentication middleware. // // For valid credentials it calls the next handler. -// For invalid Authorization header it sends "404 - Bad Request" response. // For invalid credentials, it sends "401 - Unauthorized" response. func BasicAuth(fn BasicValidateFunc) echo.HandlerFunc { return func(c *echo.Context) error { @@ -29,7 +28,6 @@ func BasicAuth(fn BasicValidateFunc) echo.HandlerFunc { auth := c.Request().Header.Get(echo.Authorization) l := len(Basic) - he := echo.NewHTTPError(http.StatusBadRequest) if len(auth) > l+1 && auth[:l] == Basic { b, err := base64.StdEncoding.DecodeString(auth[l+1:]) @@ -41,11 +39,11 @@ func BasicAuth(fn BasicValidateFunc) echo.HandlerFunc { if fn(cred[:i], cred[i+1:]) { return nil } - he.SetCode(http.StatusUnauthorized) + c.Response().Header().Set(echo.WWWAuthenticate, Basic + " realm=Restricted") } } } } - return he + return echo.NewHTTPError(http.StatusUnauthorized) } } diff --git a/middleware/auth_test.go b/middleware/auth_test.go index c953d927..278a75ef 100644 --- a/middleware/auth_test.go +++ b/middleware/auth_test.go @@ -36,17 +36,20 @@ func TestBasicAuth(t *testing.T) { req.Header.Set(echo.Authorization, auth) he := ba(c).(*echo.HTTPError) assert.Equal(t, http.StatusUnauthorized, he.Code()) + assert.Equal(t, Basic + " realm=Restricted", rec.Header().Get(echo.WWWAuthenticate)) // Empty Authorization header req.Header.Set(echo.Authorization, "") he = ba(c).(*echo.HTTPError) - assert.Equal(t, http.StatusBadRequest, he.Code()) + assert.Equal(t, http.StatusUnauthorized, he.Code()) + assert.Equal(t, Basic + " realm=Restricted", rec.Header().Get(echo.WWWAuthenticate)) // Invalid Authorization header auth = base64.StdEncoding.EncodeToString([]byte("invalid")) req.Header.Set(echo.Authorization, auth) he = ba(c).(*echo.HTTPError) - assert.Equal(t, http.StatusBadRequest, he.Code()) + assert.Equal(t, http.StatusUnauthorized, he.Code()) + assert.Equal(t, Basic + " realm=Restricted", rec.Header().Get(echo.WWWAuthenticate)) // WebSocket c.Request().Header.Set(echo.Upgrade, echo.WebSocket)