mirror of
https://github.com/labstack/echo.git
synced 2025-01-20 02:59:54 +02:00
refactor basic_auth_test to utilize table driven tests
This commit is contained in:
parent
822d11a465
commit
03c0236fb3
@ -5,6 +5,7 @@ package middleware
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"strings"
|
"strings"
|
||||||
@ -16,78 +17,110 @@ import (
|
|||||||
|
|
||||||
func TestBasicAuth(t *testing.T) {
|
func TestBasicAuth(t *testing.T) {
|
||||||
e := echo.New()
|
e := echo.New()
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
||||||
res := httptest.NewRecorder()
|
mockValidator := func(u, p string, c echo.Context) (bool, error) {
|
||||||
c := e.NewContext(req, res)
|
|
||||||
f := func(u, p string, c echo.Context) (bool, error) {
|
|
||||||
if u == "joe" && p == "secret" {
|
if u == "joe" && p == "secret" {
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
h := BasicAuth(f)(func(c echo.Context) error {
|
|
||||||
return c.String(http.StatusOK, "test")
|
|
||||||
})
|
|
||||||
|
|
||||||
// Valid credentials
|
// Define the test cases
|
||||||
auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
|
tests := []struct {
|
||||||
req.Header.Set(echo.HeaderAuthorization, auth)
|
name string
|
||||||
assert.NoError(t, h(c))
|
authHeader string
|
||||||
|
expectedCode int
|
||||||
|
expectedAuth string
|
||||||
|
skipperResult bool
|
||||||
|
expectedErr bool
|
||||||
|
expectedErrMsg string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Valid credentials",
|
||||||
|
authHeader: basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")),
|
||||||
|
expectedCode: http.StatusOK,
|
||||||
|
skipperResult: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Case-insensitive header scheme",
|
||||||
|
authHeader: strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")),
|
||||||
|
expectedCode: http.StatusOK,
|
||||||
|
skipperResult: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid credentials",
|
||||||
|
authHeader: basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password")),
|
||||||
|
expectedCode: http.StatusUnauthorized,
|
||||||
|
expectedAuth: basic + ` realm="someRealm"`,
|
||||||
|
skipperResult: false,
|
||||||
|
expectedErr: true,
|
||||||
|
expectedErrMsg: "Unauthorized",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid base64 string",
|
||||||
|
authHeader: basic + " invalidString",
|
||||||
|
expectedCode: http.StatusBadRequest,
|
||||||
|
skipperResult: false,
|
||||||
|
expectedErr: true,
|
||||||
|
expectedErrMsg: "Bad Request",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Missing Authorization header",
|
||||||
|
expectedCode: http.StatusUnauthorized,
|
||||||
|
skipperResult: false,
|
||||||
|
expectedErr: true,
|
||||||
|
expectedErrMsg: "Unauthorized",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid Authorization header",
|
||||||
|
authHeader: base64.StdEncoding.EncodeToString([]byte("invalid")),
|
||||||
|
expectedCode: http.StatusUnauthorized,
|
||||||
|
skipperResult: false,
|
||||||
|
expectedErr: true,
|
||||||
|
expectedErrMsg: "Unauthorized",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Skipped Request",
|
||||||
|
authHeader: basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:skip")),
|
||||||
|
expectedCode: http.StatusOK,
|
||||||
|
skipperResult: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
h = BasicAuthWithConfig(BasicAuthConfig{
|
for _, tt := range tests {
|
||||||
Validator: f,
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
Realm: "someRealm",
|
|
||||||
})(func(c echo.Context) error {
|
|
||||||
return c.String(http.StatusOK, "test")
|
|
||||||
})
|
|
||||||
|
|
||||||
// Valid credentials
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
|
res := httptest.NewRecorder()
|
||||||
req.Header.Set(echo.HeaderAuthorization, auth)
|
c := e.NewContext(req, res)
|
||||||
assert.NoError(t, h(c))
|
|
||||||
|
|
||||||
// Case-insensitive header scheme
|
if tt.authHeader != "" {
|
||||||
auth = strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
|
req.Header.Set(echo.HeaderAuthorization, tt.authHeader)
|
||||||
req.Header.Set(echo.HeaderAuthorization, auth)
|
}
|
||||||
assert.NoError(t, h(c))
|
|
||||||
|
|
||||||
// Invalid credentials
|
h := BasicAuthWithConfig(BasicAuthConfig{
|
||||||
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password"))
|
Validator: mockValidator,
|
||||||
req.Header.Set(echo.HeaderAuthorization, auth)
|
|
||||||
he := h(c).(*echo.HTTPError)
|
|
||||||
assert.Equal(t, http.StatusUnauthorized, he.Code)
|
|
||||||
assert.Equal(t, basic+` realm="someRealm"`, res.Header().Get(echo.HeaderWWWAuthenticate))
|
|
||||||
|
|
||||||
// Invalid base64 string
|
|
||||||
auth = basic + " invalidString"
|
|
||||||
req.Header.Set(echo.HeaderAuthorization, auth)
|
|
||||||
he = h(c).(*echo.HTTPError)
|
|
||||||
assert.Equal(t, http.StatusBadRequest, he.Code)
|
|
||||||
|
|
||||||
// Missing Authorization header
|
|
||||||
req.Header.Del(echo.HeaderAuthorization)
|
|
||||||
he = h(c).(*echo.HTTPError)
|
|
||||||
assert.Equal(t, http.StatusUnauthorized, he.Code)
|
|
||||||
|
|
||||||
// Invalid Authorization header
|
|
||||||
auth = base64.StdEncoding.EncodeToString([]byte("invalid"))
|
|
||||||
req.Header.Set(echo.HeaderAuthorization, auth)
|
|
||||||
he = h(c).(*echo.HTTPError)
|
|
||||||
assert.Equal(t, http.StatusUnauthorized, he.Code)
|
|
||||||
|
|
||||||
h = BasicAuthWithConfig(BasicAuthConfig{
|
|
||||||
Validator: f,
|
|
||||||
Realm: "someRealm",
|
Realm: "someRealm",
|
||||||
Skipper: func(c echo.Context) bool {
|
Skipper: func(c echo.Context) bool {
|
||||||
return true
|
return tt.skipperResult
|
||||||
},
|
},
|
||||||
})(func(c echo.Context) error {
|
})(func(c echo.Context) error {
|
||||||
return c.String(http.StatusOK, "test")
|
return c.String(http.StatusOK, "test")
|
||||||
})
|
})
|
||||||
|
|
||||||
// Skipped Request
|
err := h(c)
|
||||||
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:skip"))
|
|
||||||
req.Header.Set(echo.HeaderAuthorization, auth)
|
|
||||||
assert.NoError(t, h(c))
|
|
||||||
|
|
||||||
|
if tt.expectedErr {
|
||||||
|
var he *echo.HTTPError
|
||||||
|
errors.As(err, &he)
|
||||||
|
assert.Equal(t, tt.expectedCode, he.Code)
|
||||||
|
if tt.expectedAuth != "" {
|
||||||
|
assert.Equal(t, tt.expectedAuth, res.Header().Get(echo.HeaderWWWAuthenticate))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, tt.expectedCode, res.Code)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user