From cfd6d8b77feaee381a039bb039c6044a4200efec Mon Sep 17 00:00:00 2001 From: Vishal Rana Date: Thu, 21 May 2015 14:02:29 -0700 Subject: [PATCH] Better HTTP status in basic auth middleware Signed-off-by: Vishal Rana --- context.go | 2 +- middleware/auth.go | 23 ++++++++++++++--------- middleware/auth_test.go | 23 +++++++++++++++++------ 3 files changed, 32 insertions(+), 16 deletions(-) diff --git a/context.go b/context.go index 82df218d..e92dd875 100644 --- a/context.go +++ b/context.go @@ -100,7 +100,7 @@ func (c *Context) NoContent(code int) error { return nil } -// Error invokes the registered HTTP error handler. +// Error invokes the registered HTTP error handler. Usually used by middleware. func (c *Context) Error(err error) { c.echo.httpErrorHandler(err, c) } diff --git a/middleware/auth.go b/middleware/auth.go index 5bc0b2d0..09d5e4c0 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -14,12 +14,16 @@ const ( Basic = "Basic" ) -// BasicAuth returns an HTTP basic authentication middleware. +// BasicAuth returns an HTTP basic authentication middleware. For valid credentials +// it calls the next handler in the chain. + +// For invalid Authorization header it sends "404 - Bad Request" response. +// For invalid credentials, it sends "401 - Unauthorized" response. func BasicAuth(fn AuthFunc) echo.HandlerFunc { return func(c *echo.Context) error { auth := c.Request.Header.Get(echo.Authorization) i := 0 - he := echo.NewHTTPError(http.StatusUnauthorized) + code := http.StatusBadRequest for ; i < len(auth); i++ { c := auth[i] @@ -33,31 +37,32 @@ func BasicAuth(fn AuthFunc) echo.HandlerFunc { // Ignore case if i == 0 { if c != Basic[i] && c != 'b' { - return he + break } } else { if c != Basic[i] { - return he + break } } } else { // Extract credentials b, err := base64.StdEncoding.DecodeString(auth[i:]) if err != nil { - return he + break } cred := string(b) for i := 0; i < len(cred); i++ { if cred[i] == ':' { // Verify credentials - if !fn(cred[:i], cred[i+1:]) { - return he + if fn(cred[:i], cred[i+1:]) { + return nil } - return nil + code = http.StatusUnauthorized + break } } } } - return he + return echo.NewHTTPError(code) } } diff --git a/middleware/auth_test.go b/middleware/auth_test.go index 24ceed00..9933b434 100644 --- a/middleware/auth_test.go +++ b/middleware/auth_test.go @@ -42,34 +42,45 @@ func TestBasicAuth(t *testing.T) { //--------------------- // Incorrect password - auth = Basic + " " + base64.StdEncoding.EncodeToString([]byte("joe: password")) + auth = Basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:password")) req.Header.Set(echo.Authorization, auth) ba = BasicAuth(fn) + he := ba(c).(*echo.HTTPError) if ba(c) == nil { t.Error("expected `fail`, with incorrect password.") + } else if he.Code != http.StatusUnauthorized { + t.Errorf("expected status `401`, got %d", he.Code) } // Empty Authorization header req.Header.Set(echo.Authorization, "") ba = BasicAuth(fn) - if ba(c) == nil { + he = ba(c).(*echo.HTTPError) + if he == nil { t.Error("expected `fail`, with empty Authorization header.") + } else if he.Code != http.StatusBadRequest { + t.Errorf("expected status `400`, got %d", he.Code) } // Invalid Authorization header auth = base64.StdEncoding.EncodeToString([]byte(" :secret")) req.Header.Set(echo.Authorization, auth) ba = BasicAuth(fn) - if ba(c) == nil { + he = ba(c).(*echo.HTTPError) + if he == nil { t.Error("expected `fail`, with invalid Authorization header.") + } else if he.Code != http.StatusBadRequest { + t.Errorf("expected status `400`, got %d", he.Code) } // Invalid scheme - auth = "Base " + base64.StdEncoding.EncodeToString([]byte(" :secret")) + auth = "Ace " + base64.StdEncoding.EncodeToString([]byte(" :secret")) req.Header.Set(echo.Authorization, auth) ba = BasicAuth(fn) - if ba(c) == nil { + he = ba(c).(*echo.HTTPError) + if he == nil { t.Error("expected `fail`, with invalid scheme.") + } else if he.Code != http.StatusBadRequest { + t.Errorf("expected status `400`, got %d", he.Code) } - }