diff --git a/middleware/basic_auth_test.go b/middleware/basic_auth_test.go index 6e07065b..788c339e 100644 --- a/middleware/basic_auth_test.go +++ b/middleware/basic_auth_test.go @@ -5,6 +5,7 @@ package middleware import ( "encoding/base64" + "errors" "net/http" "net/http/httptest" "strings" @@ -16,78 +17,110 @@ import ( func TestBasicAuth(t *testing.T) { e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/", nil) - res := httptest.NewRecorder() - c := e.NewContext(req, res) - f := func(u, p string, c echo.Context) (bool, error) { + + mockValidator := func(u, p string, c echo.Context) (bool, error) { if u == "joe" && p == "secret" { return true, nil } return false, nil } - h := BasicAuth(f)(func(c echo.Context) error { - return c.String(http.StatusOK, "test") - }) - // Valid credentials - auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) - req.Header.Set(echo.HeaderAuthorization, auth) - assert.NoError(t, h(c)) - - h = BasicAuthWithConfig(BasicAuthConfig{ - Validator: f, - Realm: "someRealm", - })(func(c echo.Context) error { - return c.String(http.StatusOK, "test") - }) - - // Valid credentials - auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) - req.Header.Set(echo.HeaderAuthorization, auth) - assert.NoError(t, h(c)) - - // Case-insensitive header scheme - auth = strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")) - req.Header.Set(echo.HeaderAuthorization, auth) - assert.NoError(t, h(c)) - - // Invalid credentials - auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password")) - req.Header.Set(echo.HeaderAuthorization, auth) - he := h(c).(*echo.HTTPError) - assert.Equal(t, http.StatusUnauthorized, he.Code) - assert.Equal(t, basic+` realm="someRealm"`, res.Header().Get(echo.HeaderWWWAuthenticate)) - - // Invalid base64 string - auth = basic + " invalidString" - req.Header.Set(echo.HeaderAuthorization, auth) - he = h(c).(*echo.HTTPError) - assert.Equal(t, http.StatusBadRequest, he.Code) - - // Missing Authorization header - req.Header.Del(echo.HeaderAuthorization) - he = h(c).(*echo.HTTPError) - assert.Equal(t, http.StatusUnauthorized, he.Code) - - // Invalid Authorization header - auth = base64.StdEncoding.EncodeToString([]byte("invalid")) - req.Header.Set(echo.HeaderAuthorization, auth) - he = h(c).(*echo.HTTPError) - assert.Equal(t, http.StatusUnauthorized, he.Code) - - h = BasicAuthWithConfig(BasicAuthConfig{ - Validator: f, - Realm: "someRealm", - Skipper: func(c echo.Context) bool { - return true + // Define the test cases + tests := []struct { + name string + authHeader string + expectedCode int + expectedAuth string + skipperResult bool + expectedErr bool + expectedErrMsg string + }{ + { + name: "Valid credentials", + authHeader: basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")), + expectedCode: http.StatusOK, + skipperResult: false, }, - })(func(c echo.Context) error { - return c.String(http.StatusOK, "test") - }) + { + name: "Case-insensitive header scheme", + authHeader: strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")), + expectedCode: http.StatusOK, + skipperResult: false, + }, + { + name: "Invalid credentials", + authHeader: basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password")), + expectedCode: http.StatusUnauthorized, + expectedAuth: basic + ` realm="someRealm"`, + skipperResult: false, + expectedErr: true, + expectedErrMsg: "Unauthorized", + }, + { + name: "Invalid base64 string", + authHeader: basic + " invalidString", + expectedCode: http.StatusBadRequest, + skipperResult: false, + expectedErr: true, + expectedErrMsg: "Bad Request", + }, + { + name: "Missing Authorization header", + expectedCode: http.StatusUnauthorized, + skipperResult: false, + expectedErr: true, + expectedErrMsg: "Unauthorized", + }, + { + name: "Invalid Authorization header", + authHeader: base64.StdEncoding.EncodeToString([]byte("invalid")), + expectedCode: http.StatusUnauthorized, + skipperResult: false, + expectedErr: true, + expectedErrMsg: "Unauthorized", + }, + { + name: "Skipped Request", + authHeader: basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:skip")), + expectedCode: http.StatusOK, + skipperResult: true, + }, + } - // Skipped Request - auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:skip")) - req.Header.Set(echo.HeaderAuthorization, auth) - assert.NoError(t, h(c)) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + res := httptest.NewRecorder() + c := e.NewContext(req, res) + + if tt.authHeader != "" { + req.Header.Set(echo.HeaderAuthorization, tt.authHeader) + } + + h := BasicAuthWithConfig(BasicAuthConfig{ + Validator: mockValidator, + Realm: "someRealm", + Skipper: func(c echo.Context) bool { + return tt.skipperResult + }, + })(func(c echo.Context) error { + return c.String(http.StatusOK, "test") + }) + + err := h(c) + + if tt.expectedErr { + var he *echo.HTTPError + errors.As(err, &he) + assert.Equal(t, tt.expectedCode, he.Code) + if tt.expectedAuth != "" { + assert.Equal(t, tt.expectedAuth, res.Header().Get(echo.HeaderWWWAuthenticate)) + } + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expectedCode, res.Code) + } + }) + } }