1
0
mirror of https://github.com/labstack/echo.git synced 2024-12-24 20:14:31 +02:00

Better HTTP status in basic auth middleware

Signed-off-by: Vishal Rana <vr@labstack.com>
This commit is contained in:
Vishal Rana 2015-05-21 14:02:29 -07:00
parent 3fcf7a470d
commit cfd6d8b77f
3 changed files with 32 additions and 16 deletions

View File

@ -100,7 +100,7 @@ func (c *Context) NoContent(code int) error {
return nil return nil
} }
// Error invokes the registered HTTP error handler. // Error invokes the registered HTTP error handler. Usually used by middleware.
func (c *Context) Error(err error) { func (c *Context) Error(err error) {
c.echo.httpErrorHandler(err, c) c.echo.httpErrorHandler(err, c)
} }

View File

@ -14,12 +14,16 @@ const (
Basic = "Basic" Basic = "Basic"
) )
// BasicAuth returns an HTTP basic authentication middleware. // BasicAuth returns an HTTP basic authentication middleware. For valid credentials
// it calls the next handler in the chain.
// For invalid Authorization header it sends "404 - Bad Request" response.
// For invalid credentials, it sends "401 - Unauthorized" response.
func BasicAuth(fn AuthFunc) echo.HandlerFunc { func BasicAuth(fn AuthFunc) echo.HandlerFunc {
return func(c *echo.Context) error { return func(c *echo.Context) error {
auth := c.Request.Header.Get(echo.Authorization) auth := c.Request.Header.Get(echo.Authorization)
i := 0 i := 0
he := echo.NewHTTPError(http.StatusUnauthorized) code := http.StatusBadRequest
for ; i < len(auth); i++ { for ; i < len(auth); i++ {
c := auth[i] c := auth[i]
@ -33,31 +37,32 @@ func BasicAuth(fn AuthFunc) echo.HandlerFunc {
// Ignore case // Ignore case
if i == 0 { if i == 0 {
if c != Basic[i] && c != 'b' { if c != Basic[i] && c != 'b' {
return he break
} }
} else { } else {
if c != Basic[i] { if c != Basic[i] {
return he break
} }
} }
} else { } else {
// Extract credentials // Extract credentials
b, err := base64.StdEncoding.DecodeString(auth[i:]) b, err := base64.StdEncoding.DecodeString(auth[i:])
if err != nil { if err != nil {
return he break
} }
cred := string(b) cred := string(b)
for i := 0; i < len(cred); i++ { for i := 0; i < len(cred); i++ {
if cred[i] == ':' { if cred[i] == ':' {
// Verify credentials // Verify credentials
if !fn(cred[:i], cred[i+1:]) { if fn(cred[:i], cred[i+1:]) {
return he
}
return nil return nil
} }
code = http.StatusUnauthorized
break
} }
} }
} }
return he }
return echo.NewHTTPError(code)
} }
} }

View File

@ -42,34 +42,45 @@ func TestBasicAuth(t *testing.T) {
//--------------------- //---------------------
// Incorrect password // Incorrect password
auth = Basic + " " + base64.StdEncoding.EncodeToString([]byte("joe: password")) auth = Basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:password"))
req.Header.Set(echo.Authorization, auth) req.Header.Set(echo.Authorization, auth)
ba = BasicAuth(fn) ba = BasicAuth(fn)
he := ba(c).(*echo.HTTPError)
if ba(c) == nil { if ba(c) == nil {
t.Error("expected `fail`, with incorrect password.") t.Error("expected `fail`, with incorrect password.")
} else if he.Code != http.StatusUnauthorized {
t.Errorf("expected status `401`, got %d", he.Code)
} }
// Empty Authorization header // Empty Authorization header
req.Header.Set(echo.Authorization, "") req.Header.Set(echo.Authorization, "")
ba = BasicAuth(fn) ba = BasicAuth(fn)
if ba(c) == nil { he = ba(c).(*echo.HTTPError)
if he == nil {
t.Error("expected `fail`, with empty Authorization header.") t.Error("expected `fail`, with empty Authorization header.")
} else if he.Code != http.StatusBadRequest {
t.Errorf("expected status `400`, got %d", he.Code)
} }
// Invalid Authorization header // Invalid Authorization header
auth = base64.StdEncoding.EncodeToString([]byte(" :secret")) auth = base64.StdEncoding.EncodeToString([]byte(" :secret"))
req.Header.Set(echo.Authorization, auth) req.Header.Set(echo.Authorization, auth)
ba = BasicAuth(fn) ba = BasicAuth(fn)
if ba(c) == nil { he = ba(c).(*echo.HTTPError)
if he == nil {
t.Error("expected `fail`, with invalid Authorization header.") t.Error("expected `fail`, with invalid Authorization header.")
} else if he.Code != http.StatusBadRequest {
t.Errorf("expected status `400`, got %d", he.Code)
} }
// Invalid scheme // Invalid scheme
auth = "Base " + base64.StdEncoding.EncodeToString([]byte(" :secret")) auth = "Ace " + base64.StdEncoding.EncodeToString([]byte(" :secret"))
req.Header.Set(echo.Authorization, auth) req.Header.Set(echo.Authorization, auth)
ba = BasicAuth(fn) ba = BasicAuth(fn)
if ba(c) == nil { he = ba(c).(*echo.HTTPError)
if he == nil {
t.Error("expected `fail`, with invalid scheme.") t.Error("expected `fail`, with invalid scheme.")
} else if he.Code != http.StatusBadRequest {
t.Errorf("expected status `400`, got %d", he.Code)
} }
} }