1
0
mirror of https://github.com/labstack/echo.git synced 2025-01-03 22:52:19 +02:00

refactor basic_auth_test to utilize table driven tests

This commit is contained in:
eolson 2024-10-16 12:37:13 -07:00 committed by Martti T.
parent 822d11a465
commit 03c0236fb3

View File

@ -5,6 +5,7 @@ package middleware
import ( import (
"encoding/base64" "encoding/base64"
"errors"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings" "strings"
@ -16,78 +17,110 @@ import (
func TestBasicAuth(t *testing.T) { func TestBasicAuth(t *testing.T) {
e := echo.New() e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
res := httptest.NewRecorder() mockValidator := func(u, p string, c echo.Context) (bool, error) {
c := e.NewContext(req, res)
f := func(u, p string, c echo.Context) (bool, error) {
if u == "joe" && p == "secret" { if u == "joe" && p == "secret" {
return true, nil return true, nil
} }
return false, nil return false, nil
} }
h := BasicAuth(f)(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})
// Valid credentials // Define the test cases
auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) tests := []struct {
req.Header.Set(echo.HeaderAuthorization, auth) name string
assert.NoError(t, h(c)) authHeader string
expectedCode int
h = BasicAuthWithConfig(BasicAuthConfig{ expectedAuth string
Validator: f, skipperResult bool
Realm: "someRealm", expectedErr bool
})(func(c echo.Context) error { expectedErrMsg string
return c.String(http.StatusOK, "test") }{
}) {
name: "Valid credentials",
// Valid credentials authHeader: basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")),
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) expectedCode: http.StatusOK,
req.Header.Set(echo.HeaderAuthorization, auth) skipperResult: false,
assert.NoError(t, h(c))
// Case-insensitive header scheme
auth = strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
req.Header.Set(echo.HeaderAuthorization, auth)
assert.NoError(t, h(c))
// Invalid credentials
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password"))
req.Header.Set(echo.HeaderAuthorization, auth)
he := h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusUnauthorized, he.Code)
assert.Equal(t, basic+` realm="someRealm"`, res.Header().Get(echo.HeaderWWWAuthenticate))
// Invalid base64 string
auth = basic + " invalidString"
req.Header.Set(echo.HeaderAuthorization, auth)
he = h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusBadRequest, he.Code)
// Missing Authorization header
req.Header.Del(echo.HeaderAuthorization)
he = h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusUnauthorized, he.Code)
// Invalid Authorization header
auth = base64.StdEncoding.EncodeToString([]byte("invalid"))
req.Header.Set(echo.HeaderAuthorization, auth)
he = h(c).(*echo.HTTPError)
assert.Equal(t, http.StatusUnauthorized, he.Code)
h = BasicAuthWithConfig(BasicAuthConfig{
Validator: f,
Realm: "someRealm",
Skipper: func(c echo.Context) bool {
return true
}, },
})(func(c echo.Context) error { {
return c.String(http.StatusOK, "test") name: "Case-insensitive header scheme",
}) authHeader: strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")),
expectedCode: http.StatusOK,
skipperResult: false,
},
{
name: "Invalid credentials",
authHeader: basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password")),
expectedCode: http.StatusUnauthorized,
expectedAuth: basic + ` realm="someRealm"`,
skipperResult: false,
expectedErr: true,
expectedErrMsg: "Unauthorized",
},
{
name: "Invalid base64 string",
authHeader: basic + " invalidString",
expectedCode: http.StatusBadRequest,
skipperResult: false,
expectedErr: true,
expectedErrMsg: "Bad Request",
},
{
name: "Missing Authorization header",
expectedCode: http.StatusUnauthorized,
skipperResult: false,
expectedErr: true,
expectedErrMsg: "Unauthorized",
},
{
name: "Invalid Authorization header",
authHeader: base64.StdEncoding.EncodeToString([]byte("invalid")),
expectedCode: http.StatusUnauthorized,
skipperResult: false,
expectedErr: true,
expectedErrMsg: "Unauthorized",
},
{
name: "Skipped Request",
authHeader: basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:skip")),
expectedCode: http.StatusOK,
skipperResult: true,
},
}
// Skipped Request for _, tt := range tests {
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:skip")) t.Run(tt.name, func(t *testing.T) {
req.Header.Set(echo.HeaderAuthorization, auth)
assert.NoError(t, h(c))
req := httptest.NewRequest(http.MethodGet, "/", nil)
res := httptest.NewRecorder()
c := e.NewContext(req, res)
if tt.authHeader != "" {
req.Header.Set(echo.HeaderAuthorization, tt.authHeader)
}
h := BasicAuthWithConfig(BasicAuthConfig{
Validator: mockValidator,
Realm: "someRealm",
Skipper: func(c echo.Context) bool {
return tt.skipperResult
},
})(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})
err := h(c)
if tt.expectedErr {
var he *echo.HTTPError
errors.As(err, &he)
assert.Equal(t, tt.expectedCode, he.Code)
if tt.expectedAuth != "" {
assert.Equal(t, tt.expectedAuth, res.Header().Get(echo.HeaderWWWAuthenticate))
}
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expectedCode, res.Code)
}
})
}
} }