mirror of
https://github.com/labstack/echo.git
synced 2025-01-01 22:09:21 +02:00
refactor basic_auth_test to utilize table driven tests
This commit is contained in:
parent
822d11a465
commit
03c0236fb3
@ -5,6 +5,7 @@ package middleware
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
@ -16,78 +17,110 @@ import (
|
||||
|
||||
func TestBasicAuth(t *testing.T) {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
res := httptest.NewRecorder()
|
||||
c := e.NewContext(req, res)
|
||||
f := func(u, p string, c echo.Context) (bool, error) {
|
||||
|
||||
mockValidator := func(u, p string, c echo.Context) (bool, error) {
|
||||
if u == "joe" && p == "secret" {
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
h := BasicAuth(f)(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "test")
|
||||
})
|
||||
|
||||
// Valid credentials
|
||||
auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
|
||||
req.Header.Set(echo.HeaderAuthorization, auth)
|
||||
assert.NoError(t, h(c))
|
||||
|
||||
h = BasicAuthWithConfig(BasicAuthConfig{
|
||||
Validator: f,
|
||||
Realm: "someRealm",
|
||||
})(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "test")
|
||||
})
|
||||
|
||||
// Valid credentials
|
||||
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
|
||||
req.Header.Set(echo.HeaderAuthorization, auth)
|
||||
assert.NoError(t, h(c))
|
||||
|
||||
// Case-insensitive header scheme
|
||||
auth = strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
|
||||
req.Header.Set(echo.HeaderAuthorization, auth)
|
||||
assert.NoError(t, h(c))
|
||||
|
||||
// Invalid credentials
|
||||
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password"))
|
||||
req.Header.Set(echo.HeaderAuthorization, auth)
|
||||
he := h(c).(*echo.HTTPError)
|
||||
assert.Equal(t, http.StatusUnauthorized, he.Code)
|
||||
assert.Equal(t, basic+` realm="someRealm"`, res.Header().Get(echo.HeaderWWWAuthenticate))
|
||||
|
||||
// Invalid base64 string
|
||||
auth = basic + " invalidString"
|
||||
req.Header.Set(echo.HeaderAuthorization, auth)
|
||||
he = h(c).(*echo.HTTPError)
|
||||
assert.Equal(t, http.StatusBadRequest, he.Code)
|
||||
|
||||
// Missing Authorization header
|
||||
req.Header.Del(echo.HeaderAuthorization)
|
||||
he = h(c).(*echo.HTTPError)
|
||||
assert.Equal(t, http.StatusUnauthorized, he.Code)
|
||||
|
||||
// Invalid Authorization header
|
||||
auth = base64.StdEncoding.EncodeToString([]byte("invalid"))
|
||||
req.Header.Set(echo.HeaderAuthorization, auth)
|
||||
he = h(c).(*echo.HTTPError)
|
||||
assert.Equal(t, http.StatusUnauthorized, he.Code)
|
||||
|
||||
h = BasicAuthWithConfig(BasicAuthConfig{
|
||||
Validator: f,
|
||||
Realm: "someRealm",
|
||||
Skipper: func(c echo.Context) bool {
|
||||
return true
|
||||
// Define the test cases
|
||||
tests := []struct {
|
||||
name string
|
||||
authHeader string
|
||||
expectedCode int
|
||||
expectedAuth string
|
||||
skipperResult bool
|
||||
expectedErr bool
|
||||
expectedErrMsg string
|
||||
}{
|
||||
{
|
||||
name: "Valid credentials",
|
||||
authHeader: basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")),
|
||||
expectedCode: http.StatusOK,
|
||||
skipperResult: false,
|
||||
},
|
||||
})(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "test")
|
||||
})
|
||||
{
|
||||
name: "Case-insensitive header scheme",
|
||||
authHeader: strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")),
|
||||
expectedCode: http.StatusOK,
|
||||
skipperResult: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid credentials",
|
||||
authHeader: basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password")),
|
||||
expectedCode: http.StatusUnauthorized,
|
||||
expectedAuth: basic + ` realm="someRealm"`,
|
||||
skipperResult: false,
|
||||
expectedErr: true,
|
||||
expectedErrMsg: "Unauthorized",
|
||||
},
|
||||
{
|
||||
name: "Invalid base64 string",
|
||||
authHeader: basic + " invalidString",
|
||||
expectedCode: http.StatusBadRequest,
|
||||
skipperResult: false,
|
||||
expectedErr: true,
|
||||
expectedErrMsg: "Bad Request",
|
||||
},
|
||||
{
|
||||
name: "Missing Authorization header",
|
||||
expectedCode: http.StatusUnauthorized,
|
||||
skipperResult: false,
|
||||
expectedErr: true,
|
||||
expectedErrMsg: "Unauthorized",
|
||||
},
|
||||
{
|
||||
name: "Invalid Authorization header",
|
||||
authHeader: base64.StdEncoding.EncodeToString([]byte("invalid")),
|
||||
expectedCode: http.StatusUnauthorized,
|
||||
skipperResult: false,
|
||||
expectedErr: true,
|
||||
expectedErrMsg: "Unauthorized",
|
||||
},
|
||||
{
|
||||
name: "Skipped Request",
|
||||
authHeader: basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:skip")),
|
||||
expectedCode: http.StatusOK,
|
||||
skipperResult: true,
|
||||
},
|
||||
}
|
||||
|
||||
// Skipped Request
|
||||
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:skip"))
|
||||
req.Header.Set(echo.HeaderAuthorization, auth)
|
||||
assert.NoError(t, h(c))
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
res := httptest.NewRecorder()
|
||||
c := e.NewContext(req, res)
|
||||
|
||||
if tt.authHeader != "" {
|
||||
req.Header.Set(echo.HeaderAuthorization, tt.authHeader)
|
||||
}
|
||||
|
||||
h := BasicAuthWithConfig(BasicAuthConfig{
|
||||
Validator: mockValidator,
|
||||
Realm: "someRealm",
|
||||
Skipper: func(c echo.Context) bool {
|
||||
return tt.skipperResult
|
||||
},
|
||||
})(func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "test")
|
||||
})
|
||||
|
||||
err := h(c)
|
||||
|
||||
if tt.expectedErr {
|
||||
var he *echo.HTTPError
|
||||
errors.As(err, &he)
|
||||
assert.Equal(t, tt.expectedCode, he.Code)
|
||||
if tt.expectedAuth != "" {
|
||||
assert.Equal(t, tt.expectedAuth, res.Header().Get(echo.HeaderWWWAuthenticate))
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.expectedCode, res.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user