1
0
mirror of https://github.com/labstack/echo.git synced 2025-01-12 01:22:21 +02:00

Merge pull request #192 from labstack/issue-180

Closes #180
This commit is contained in:
Vishal Rana 2015-09-01 11:32:12 -07:00
commit e0a40f864c
3 changed files with 8 additions and 6 deletions

View File

@ -134,6 +134,7 @@ const (
Location = "Location" Location = "Location"
Upgrade = "Upgrade" Upgrade = "Upgrade"
Vary = "Vary" Vary = "Vary"
WWWAuthenticate = "WWW-Authenticate"
//----------- //-----------
// Protocols // Protocols

View File

@ -18,7 +18,6 @@ const (
// BasicAuth returns an HTTP basic authentication middleware. // BasicAuth returns an HTTP basic authentication middleware.
// //
// For valid credentials it calls the next handler. // For valid credentials it calls the next handler.
// For invalid Authorization header it sends "404 - Bad Request" response.
// For invalid credentials, it sends "401 - Unauthorized" response. // For invalid credentials, it sends "401 - Unauthorized" response.
func BasicAuth(fn BasicValidateFunc) echo.HandlerFunc { func BasicAuth(fn BasicValidateFunc) echo.HandlerFunc {
return func(c *echo.Context) error { return func(c *echo.Context) error {
@ -29,7 +28,6 @@ func BasicAuth(fn BasicValidateFunc) echo.HandlerFunc {
auth := c.Request().Header.Get(echo.Authorization) auth := c.Request().Header.Get(echo.Authorization)
l := len(Basic) l := len(Basic)
he := echo.NewHTTPError(http.StatusBadRequest)
if len(auth) > l+1 && auth[:l] == Basic { if len(auth) > l+1 && auth[:l] == Basic {
b, err := base64.StdEncoding.DecodeString(auth[l+1:]) b, err := base64.StdEncoding.DecodeString(auth[l+1:])
@ -41,11 +39,11 @@ func BasicAuth(fn BasicValidateFunc) echo.HandlerFunc {
if fn(cred[:i], cred[i+1:]) { if fn(cred[:i], cred[i+1:]) {
return nil return nil
} }
he.SetCode(http.StatusUnauthorized) c.Response().Header().Set(echo.WWWAuthenticate, Basic + " realm=Restricted")
} }
} }
} }
} }
return he return echo.NewHTTPError(http.StatusUnauthorized)
} }
} }

View File

@ -36,17 +36,20 @@ func TestBasicAuth(t *testing.T) {
req.Header.Set(echo.Authorization, auth) req.Header.Set(echo.Authorization, auth)
he := ba(c).(*echo.HTTPError) he := ba(c).(*echo.HTTPError)
assert.Equal(t, http.StatusUnauthorized, he.Code()) assert.Equal(t, http.StatusUnauthorized, he.Code())
assert.Equal(t, Basic + " realm=Restricted", rec.Header().Get(echo.WWWAuthenticate))
// Empty Authorization header // Empty Authorization header
req.Header.Set(echo.Authorization, "") req.Header.Set(echo.Authorization, "")
he = ba(c).(*echo.HTTPError) he = ba(c).(*echo.HTTPError)
assert.Equal(t, http.StatusBadRequest, he.Code()) assert.Equal(t, http.StatusUnauthorized, he.Code())
assert.Equal(t, Basic + " realm=Restricted", rec.Header().Get(echo.WWWAuthenticate))
// Invalid Authorization header // Invalid Authorization header
auth = base64.StdEncoding.EncodeToString([]byte("invalid")) auth = base64.StdEncoding.EncodeToString([]byte("invalid"))
req.Header.Set(echo.Authorization, auth) req.Header.Set(echo.Authorization, auth)
he = ba(c).(*echo.HTTPError) he = ba(c).(*echo.HTTPError)
assert.Equal(t, http.StatusBadRequest, he.Code()) assert.Equal(t, http.StatusUnauthorized, he.Code())
assert.Equal(t, Basic + " realm=Restricted", rec.Header().Get(echo.WWWAuthenticate))
// WebSocket // WebSocket
c.Request().Header.Set(echo.Upgrade, echo.WebSocket) c.Request().Header.Set(echo.Upgrade, echo.WebSocket)