1
0
mirror of https://github.com/labstack/echo.git synced 2025-01-20 02:59:54 +02:00

refactor basic_auth_test to utilize table driven tests

This commit is contained in:
eolson 2024-10-16 12:37:13 -07:00 committed by Martti T.
parent 822d11a465
commit 03c0236fb3

View File

@ -5,6 +5,7 @@ package middleware
import ( import (
"encoding/base64" "encoding/base64"
"errors"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings" "strings"
@ -16,78 +17,110 @@ import (
func TestBasicAuth(t *testing.T) { func TestBasicAuth(t *testing.T) {
e := echo.New() e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
res := httptest.NewRecorder() mockValidator := func(u, p string, c echo.Context) (bool, error) {
c := e.NewContext(req, res)
f := func(u, p string, c echo.Context) (bool, error) {
if u == "joe" && p == "secret" { if u == "joe" && p == "secret" {
return true, nil return true, nil
} }
return false, nil return false, nil
} }
h := BasicAuth(f)(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})
// Valid credentials // Define the test cases
auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) tests := []struct {
req.Header.Set(echo.HeaderAuthorization, auth) name string
assert.NoError(t, h(c)) authHeader string
expectedCode int
expectedAuth string
skipperResult bool
expectedErr bool
expectedErrMsg string
}{
{
name: "Valid credentials",
authHeader: basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")),
expectedCode: http.StatusOK,
skipperResult: false,
},
{
name: "Case-insensitive header scheme",
authHeader: strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")),
expectedCode: http.StatusOK,
skipperResult: false,
},
{
name: "Invalid credentials",
authHeader: basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password")),
expectedCode: http.StatusUnauthorized,
expectedAuth: basic + ` realm="someRealm"`,
skipperResult: false,
expectedErr: true,
expectedErrMsg: "Unauthorized",
},
{
name: "Invalid base64 string",
authHeader: basic + " invalidString",
expectedCode: http.StatusBadRequest,
skipperResult: false,
expectedErr: true,
expectedErrMsg: "Bad Request",
},
{
name: "Missing Authorization header",
expectedCode: http.StatusUnauthorized,
skipperResult: false,
expectedErr: true,
expectedErrMsg: "Unauthorized",
},
{
name: "Invalid Authorization header",
authHeader: base64.StdEncoding.EncodeToString([]byte("invalid")),
expectedCode: http.StatusUnauthorized,
skipperResult: false,
expectedErr: true,
expectedErrMsg: "Unauthorized",
},
{
name: "Skipped Request",
authHeader: basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:skip")),
expectedCode: http.StatusOK,
skipperResult: true,
},
}
h = BasicAuthWithConfig(BasicAuthConfig{ for _, tt := range tests {
Validator: f, t.Run(tt.name, func(t *testing.T) {
Realm: "someRealm",
})(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})
// Valid credentials req := httptest.NewRequest(http.MethodGet, "/", nil)
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) res := httptest.NewRecorder()
req.Header.Set(echo.HeaderAuthorization, auth) c := e.NewContext(req, res)
assert.NoError(t, h(c))
// Case-insensitive header scheme if tt.authHeader != "" {
auth = strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) req.Header.Set(echo.HeaderAuthorization, tt.authHeader)
req.Header.Set(echo.HeaderAuthorization, auth) }
assert.NoError(t, h(c))
// Invalid credentials h := BasicAuthWithConfig(BasicAuthConfig{
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password")) Validator: mockValidator,
req.Header.Set(echo.HeaderAuthorization, auth)
he := h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusUnauthorized, he.Code)
assert.Equal(t, basic+` realm="someRealm"`, res.Header().Get(echo.HeaderWWWAuthenticate))
// Invalid base64 string
auth = basic + " invalidString"
req.Header.Set(echo.HeaderAuthorization, auth)
he = h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusBadRequest, he.Code)
// Missing Authorization header
req.Header.Del(echo.HeaderAuthorization)
he = h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusUnauthorized, he.Code)
// Invalid Authorization header
auth = base64.StdEncoding.EncodeToString([]byte("invalid"))
req.Header.Set(echo.HeaderAuthorization, auth)
he = h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusUnauthorized, he.Code)
h = BasicAuthWithConfig(BasicAuthConfig{
Validator: f,
Realm: "someRealm", Realm: "someRealm",
Skipper: func(c echo.Context) bool { Skipper: func(c echo.Context) bool {
return true return tt.skipperResult
}, },
})(func(c echo.Context) error { })(func(c echo.Context) error {
return c.String(http.StatusOK, "test") return c.String(http.StatusOK, "test")
}) })
// Skipped Request err := h(c)
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:skip"))
req.Header.Set(echo.HeaderAuthorization, auth)
assert.NoError(t, h(c))
if tt.expectedErr {
var he *echo.HTTPError
errors.As(err, &he)
assert.Equal(t, tt.expectedCode, he.Code)
if tt.expectedAuth != "" {
assert.Equal(t, tt.expectedAuth, res.Header().Get(echo.HeaderWWWAuthenticate))
}
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expectedCode, res.Code)
}
})
}
} }