1
0
mirror of https://github.com/volatiletech/authboss.git synced 2025-01-06 03:54:17 +02:00

Fix authboss core tests

- Delete callbacks tests
- Remove some useless code (SendMail), as well as some extra arguments
  in certain functions that didn't require them.
- Remove tests for more code that has been moved to default
  implementations
This commit is contained in:
Aaron L 2018-02-01 11:51:43 -08:00
parent 2db3a3f782
commit cbfc1d8388
12 changed files with 44 additions and 512 deletions

View File

@ -11,6 +11,7 @@ func TestAuthBossInit(t *testing.T) {
ab := New()
ab.LogWriter = ioutil.Discard
ab.ViewLoader = mockRenderLoader{}
ab.MailViewLoader = mockRenderLoader{}
err := ab.Init()
if err != nil {
t.Error("Unexpected error:", err)

View File

@ -1,197 +0,0 @@
package authboss
import (
"bytes"
"context"
"testing"
"github.com/pkg/errors"
)
func TestCallbacks(t *testing.T) {
t.Parallel()
ab := New()
afterCalled := false
beforeCalled := false
ab.Callbacks.Before(EventRegister, func(ctx context.Context) (Interrupt, error) {
beforeCalled = true
return InterruptNone, nil
})
ab.Callbacks.After(EventRegister, func(ctx context.Context) error {
afterCalled = true
return nil
})
if beforeCalled || afterCalled {
t.Error("Neither should be called.")
}
interrupt, err := ab.Callbacks.FireBefore(EventRegister, context.TODO())
if err != nil {
t.Error("Unexpected error:", err)
}
if interrupt != InterruptNone {
t.Error("It should not have been stopped.")
}
if !beforeCalled {
t.Error("Expected before to have been called.")
}
if afterCalled {
t.Error("Expected after not to be called.")
}
ab.Callbacks.FireAfter(EventRegister, context.TODO())
if !afterCalled {
t.Error("Expected after to be called.")
}
}
func TestCallbacksInterrupt(t *testing.T) {
t.Parallel()
ab := New()
before1 := false
before2 := false
ab.Callbacks.Before(EventRegister, func(ctx context.Context) (Interrupt, error) {
before1 = true
return InterruptAccountLocked, nil
})
ab.Callbacks.Before(EventRegister, func(ctx context.Context) (Interrupt, error) {
before2 = true
return InterruptNone, nil
})
interrupt, err := ab.Callbacks.FireBefore(EventRegister, context.TODO())
if err != nil {
t.Error(err)
}
if interrupt != InterruptAccountLocked {
t.Error("The interrupt signal was not account locked:", interrupt)
}
if !before1 {
t.Error("Before1 should have been called.")
}
if before2 {
t.Error("Before2 should not have been called.")
}
}
func TestCallbacksBeforeErrors(t *testing.T) {
t.Parallel()
ab := New()
log := &bytes.Buffer{}
ab.LogWriter = log
before1 := false
before2 := false
errValue := errors.New("problem occured")
ab.Callbacks.Before(EventRegister, func(ctx context.Context) (Interrupt, error) {
before1 = true
return InterruptNone, errValue
})
ab.Callbacks.Before(EventRegister, func(ctx context.Context) (Interrupt, error) {
before2 = true
return InterruptNone, nil
})
interrupt, err := ab.Callbacks.FireBefore(EventRegister, context.TODO())
if err != errValue {
t.Error("Expected an error to come back.")
}
if interrupt != InterruptNone {
t.Error("It should not have been stopped.")
}
if !before1 {
t.Error("Before1 should have been called.")
}
if before2 {
t.Error("Before2 should not have been called.")
}
}
func TestCallbacksAfterErrors(t *testing.T) {
t.Parallel()
log := &bytes.Buffer{}
ab := New()
ab.LogWriter = log
after1 := false
after2 := false
errValue := errors.New("problem occured")
ab.Callbacks.After(EventRegister, func(ctx context.Context) error {
after1 = true
return errValue
})
ab.Callbacks.After(EventRegister, func(ctx context.Context) error {
after2 = true
return nil
})
err := ab.Callbacks.FireAfter(EventRegister, context.TODO())
if err != errValue {
t.Error("Expected an error to come back.")
}
if !after1 {
t.Error("After1 should have been called.")
}
if after2 {
t.Error("After2 should not have been called.")
}
}
func TestEventString(t *testing.T) {
t.Parallel()
tests := []struct {
ev Event
str string
}{
{EventRegister, "EventRegister"},
{EventAuth, "EventAuth"},
{EventOAuth, "EventOAuth"},
{EventAuthFail, "EventAuthFail"},
{EventOAuthFail, "EventOAuthFail"},
{EventRecoverStart, "EventRecoverStart"},
{EventRecoverEnd, "EventRecoverEnd"},
{EventGetUser, "EventGetUser"},
{EventGetUserSession, "EventGetUserSession"},
{EventPasswordReset, "EventPasswordReset"},
}
for i, test := range tests {
if got := test.ev.String(); got != test.str {
t.Errorf("%d) Wrong string for Event(%d) expected: %v got: %s", i, test.ev, test.str, got)
}
}
}
func TestInterruptString(t *testing.T) {
t.Parallel()
tests := []struct {
in Interrupt
str string
}{
{InterruptNone, "InterruptNone"},
{InterruptAccountLocked, "InterruptAccountLocked"},
{InterruptAccountNotConfirmed, "InterruptAccountNotConfirmed"},
{InterruptSessionExpired, "InterruptSessionExpired"},
}
for i, test := range tests {
if got := test.in.String(); got != test.str {
t.Errorf("%d) Wrong string for Event(%d) expected: %v got: %s", i, test.in, test.str, got)
}
}
}

View File

@ -40,7 +40,6 @@ func TestStateResponseWriterDoubleWritePanic(t *testing.T) {
r := httptest.NewRequest("GET", "/", nil)
w := ab.NewResponse(httptest.NewRecorder(), r)
csrw := w.(*ClientStateResponseWriter)
w.WriteHeader(200)
// Check this doesn't panic
@ -52,7 +51,7 @@ func TestStateResponseWriterDoubleWritePanic(t *testing.T) {
}
}()
csrw.putClientState()
w.putClientState()
}
func TestStateResponseWriterLastSecondWriteWithPrevious(t *testing.T) {
@ -126,30 +125,28 @@ func TestStateResponseWriterEvents(t *testing.T) {
r := httptest.NewRequest("GET", "/", nil)
w := ab.NewResponse(httptest.NewRecorder(), r)
csrw := w.(*ClientStateResponseWriter)
PutSession(w, "one", "two")
DelSession(w, "one")
DelCookie(w, "one")
PutCookie(w, "two", "one")
want := ClientStateEvent{Kind: ClientStateEventPut, Key: "one", Value: "two"}
if got := csrw.sessionStateEvents[0]; got != want {
if got := w.sessionStateEvents[0]; got != want {
t.Error("event was wrong", got)
}
want = ClientStateEvent{Kind: ClientStateEventDel, Key: "one"}
if got := csrw.sessionStateEvents[1]; got != want {
if got := w.sessionStateEvents[1]; got != want {
t.Error("event was wrong", got)
}
want = ClientStateEvent{Kind: ClientStateEventDel, Key: "one"}
if got := csrw.cookieStateEvents[0]; got != want {
if got := w.cookieStateEvents[0]; got != want {
t.Error("event was wrong", got)
}
want = ClientStateEvent{Kind: ClientStateEventPut, Key: "two", Value: "one"}
if got := csrw.cookieStateEvents[1]; got != want {
if got := w.cookieStateEvents[1]; got != want {
t.Error("event was wrong", got)
}
}
@ -162,7 +159,6 @@ func TestFlashClearer(t *testing.T) {
r := httptest.NewRequest("GET", "/", nil)
w := ab.NewResponse(httptest.NewRecorder(), r)
csrw := w.(*ClientStateResponseWriter)
if msg := FlashSuccess(w, r); msg != "" {
t.Error("Unexpected flash success:", msg)
@ -187,11 +183,11 @@ func TestFlashClearer(t *testing.T) {
}
want := ClientStateEvent{Kind: ClientStateEventDel, Key: FlashSuccessKey}
if got := csrw.sessionStateEvents[0]; got != want {
if got := w.sessionStateEvents[0]; got != want {
t.Error("event was wrong", got)
}
want = ClientStateEvent{Kind: ClientStateEventDel, Key: FlashErrorKey}
if got := csrw.sessionStateEvents[1]; got != want {
if got := w.sessionStateEvents[1]; got != want {
t.Error("event was wrong", got)
}
}

View File

@ -18,7 +18,7 @@ func loadClientStateP(ab *Authboss, w http.ResponseWriter, r *http.Request) *htt
func testSetupContext() (*Authboss, *http.Request) {
ab := New()
ab.SessionStateStorer = newMockClientStateRW(SessionKey, "george-pid")
ab.StoreLoader = mockStoreLoader{
ab.Storer = mockServerStorer{
"george-pid": mockUser{Email: "george-pid", Password: "unreadable"},
}
r := loadClientStateP(ab, nil, httptest.NewRequest("GET", "/", nil))
@ -29,12 +29,9 @@ func testSetupContext() (*Authboss, *http.Request) {
func testSetupContextCached() (*Authboss, mockUser, *http.Request) {
ab := New()
wantUser := mockUser{Email: "george-pid", Password: "unreadable"}
storer := mockStoredUser{
mockUser: wantUser,
}
req := httptest.NewRequest("GET", "/", nil)
ctx := context.WithValue(req.Context(), ctxKeyPID, "george-pid")
ctx = context.WithValue(ctx, ctxKeyUser, storer)
ctx = context.WithValue(ctx, ctxKeyUser, wantUser)
req = req.WithContext(ctx)
return ab, wantUser, req
@ -43,7 +40,7 @@ func testSetupContextCached() (*Authboss, mockUser, *http.Request) {
func testSetupContextPanic() *Authboss {
ab := New()
ab.SessionStateStorer = newMockClientStateRW(SessionKey, "george-pid")
ab.StoreLoader = mockStoreLoader{}
ab.Storer = mockServerStorer{}
return ab
}
@ -207,10 +204,10 @@ func TestLoadCurrentUser(t *testing.T) {
t.Error("got:", got)
}
want := user.(mockStoredUser).mockUser
got := r.Context().Value(ctxKeyUser).(mockStoredUser).mockUser
want := user.(mockUser)
got := r.Context().Value(ctxKeyUser).(mockUser)
if got != want {
t.Error("users mismatched:\nwant: %#v\ngot: %#v", want, got)
t.Errorf("users mismatched:\nwant: %#v\ngot: %#v", want, got)
}
}
@ -224,9 +221,9 @@ func TestLoadCurrentUserContext(t *testing.T) {
t.Error(err)
}
got := user.(mockStoredUser).mockUser
got := user.(mockUser)
if got != wantUser {
t.Error("users mismatched:\nwant: %#v\ngot: %#v", wantUser, got)
t.Errorf("users mismatched:\nwant: %#v\ngot: %#v", wantUser, got)
}
}

View File

@ -76,6 +76,7 @@ func TestResponder(t *testing.T) {
t.Errorf("data mismatched:\nwant: %#v\ngot: %#v", expectData, gotData)
}
}
func TestRedirector(t *testing.T) {
t.Parallel()

View File

@ -10,7 +10,7 @@ var nowTime = time.Now
// TimeToExpiry returns zero if the user session is expired else the time
// until expiry. Takes in the allowed idle duration.
func TimeToExpiry(w http.ResponseWriter, r *http.Request, expireAfter time.Duration) time.Duration {
func TimeToExpiry(r *http.Request, expireAfter time.Duration) time.Duration {
return timeToExpiry(r, expireAfter)
}

View File

@ -54,36 +54,37 @@ func TestExpireIsExpired(t *testing.T) {
t.Error("expected user not to be present")
}
csrw := w.(*ClientStateResponseWriter)
want := ClientStateEvent{
Kind: ClientStateEventDel,
Key: SessionKey,
}
if got := csrw.sessionStateEvents[0]; got != want {
if got := w.sessionStateEvents[0]; got != want {
t.Error("want:", want, "got:", got)
}
want = ClientStateEvent{
Kind: ClientStateEventDel,
Key: SessionLastAction,
}
if got := csrw.sessionStateEvents[1]; got != want {
if got := w.sessionStateEvents[1]; got != want {
t.Error("want:", want, "got:", got)
}
}
func TestExpireNotExpired(t *testing.T) {
ab := New()
ab.Config.ExpireAfter = time.Hour
ab.SessionStateStorer = newMockClientStateRW(
SessionKey, "username",
SessionLastAction, time.Now().UTC().Format(time.RFC3339),
)
var err error
r := httptest.NewRequest("GET", "/", nil)
r = r.WithContext(context.WithValue(r.Context(), ctxKeyPID, "primaryid"))
r = r.WithContext(context.WithValue(r.Context(), ctxKeyUser, struct{}{}))
w := ab.NewResponse(httptest.NewRecorder(), r)
r, err := ab.LoadClientState(w, r)
r, err = ab.LoadClientState(w, r)
if err != nil {
t.Error(err)
}
@ -119,14 +120,13 @@ func TestExpireNotExpired(t *testing.T) {
t.Error("expected user to be present")
}
csrw := w.(*ClientStateResponseWriter)
want := ClientStateEvent{
Kind: ClientStateEventPut,
Key: SessionLastAction,
Value: newTime.Format(time.RFC3339),
}
if got := csrw.sessionStateEvents[0]; got != want {
if got := w.sessionStateEvents[0]; got != want {
t.Error("want:", want, "got:", got)
}
}
@ -134,12 +134,10 @@ func TestExpireNotExpired(t *testing.T) {
func TestExpireTimeToExpiry(t *testing.T) {
t.Parallel()
ab := New()
r := httptest.NewRequest("GET", "/", nil)
w := ab.NewResponse(httptest.NewRecorder(), r)
want := 5 * time.Second
dur := TimeToExpiry(w, r, want)
dur := TimeToExpiry(r, want)
if dur != want {
t.Error("duration was wrong:", dur)
}
@ -153,11 +151,10 @@ func TestExpireRefreshExpiry(t *testing.T) {
w := ab.NewResponse(httptest.NewRecorder(), r)
RefreshExpiry(w, r)
csrw := w.(*ClientStateResponseWriter)
if got := csrw.sessionStateEvents[0].Kind; got != ClientStateEventPut {
if got := w.sessionStateEvents[0].Kind; got != ClientStateEventPut {
t.Error("wrong event:", got)
}
if got := csrw.sessionStateEvents[0].Key; got != SessionLastAction {
if got := w.sessionStateEvents[0].Key; got != SessionLastAction {
t.Error("wrong key:", got)
}
}

View File

@ -4,11 +4,6 @@ import (
"context"
)
// SendMail uses the currently configured mailer to deliver e-mails.
func (a *Authboss) SendMail(ctx context.Context, data Email) error {
return a.Mailer.Send(ctx, data)
}
// Mailer is a type that is capable of sending an e-mail.
type Mailer interface {
Send(context.Context, Email) error

View File

@ -17,66 +17,51 @@ type mockUser struct {
Password string
}
type mockStoredUser struct {
mockUser
mockStoreLoader
}
type mockServerStorer map[string]mockUser
type mockStoreLoader map[string]mockUser
func (m mockStoreLoader) Load(ctx context.Context, key string) (Storer, error) {
func (m mockServerStorer) Load(ctx context.Context, key string) (User, error) {
u, ok := m[key]
if !ok {
return nil, ErrUserNotFound
}
return mockStoredUser{
mockUser: u,
mockStoreLoader: m,
}, nil
return u, nil
}
func (m mockStoredUser) Load(ctx context.Context) error {
u, ok := m.mockStoreLoader[m.Email]
if !ok {
return ErrUserNotFound
func (m mockServerStorer) Save(ctx context.Context, user User) error {
e, err := user.GetEmail(ctx)
if err != nil {
panic(err)
}
m.Email = u.Email
m.Password = u.Password
m[e] = user.(mockUser)
return nil
}
func (m mockStoredUser) Save(ctx context.Context) error {
m.mockStoreLoader[m.Email] = m.mockUser
return nil
}
func (m mockStoredUser) PutEmail(ctx context.Context, email string) error {
func (m mockUser) PutEmail(ctx context.Context, email string) error {
m.Email = email
return nil
}
func (m mockStoredUser) PutUsername(ctx context.Context, username string) error {
func (m mockUser) PutUsername(ctx context.Context, username string) error {
return errors.New("not impl")
}
func (m mockStoredUser) PutPassword(ctx context.Context, password string) error {
func (m mockUser) PutPassword(ctx context.Context, password string) error {
m.Password = password
return nil
}
func (m mockStoredUser) GetEmail(ctx context.Context) (email string, err error) {
func (m mockUser) GetEmail(ctx context.Context) (email string, err error) {
return m.Email, nil
}
func (m mockStoredUser) GetUsername(ctx context.Context) (username string, err error) {
func (m mockUser) GetUsername(ctx context.Context) (username string, err error) {
return "", errors.New("not impl")
}
func (m mockStoredUser) GetPassword(ctx context.Context) (password string, err error) {
func (m mockUser) GetPassword(ctx context.Context) (password string, err error) {
return m.Password, nil
}

View File

@ -1,72 +0,0 @@
package authboss
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
type testMailer struct {
io.Writer
}
func (t testMailer) Send(_ context.Context, email Email) error {
fmt.Fprintf(t.Writer, "%v", email)
return nil
}
func TestResponseEmail(t *testing.T) {
t.Parallel()
ab := New()
ab.renderer = mockEmailRenderer{}
ab.SessionStateStorer = newMockClientStateRW(
FlashSuccessKey, "flash_success",
FlashErrorKey, "flash_error",
)
ab.XSRFName = "xsrf"
ab.XSRFMaker = func(w http.ResponseWriter, r *http.Request) string {
return "xsrftoken"
}
ab.LayoutDataMaker = func(w http.ResponseWriter, r *http.Request) HTMLData {
return HTMLData{"hello": "world"}
}
output := &bytes.Buffer{}
ab.Mailer = testMailer{output}
r := httptest.NewRequest("GET", "/", nil)
wr := httptest.NewRecorder()
w := ab.NewResponse(wr, r)
email := Email{
To: []string{"test@example.com"},
From: "test@example.com",
Subject: "subject",
}
ro := EmailResponseOptions{Data: nil, HTMLTemplate: "html", TextTemplate: "text"}
err := ab.Email(w, r, email, ro)
if err != nil {
t.Error(err)
}
wantStrings := []string{
"To: test@example.com",
"From: test@example.com",
"Subject: subject",
"development text e-mail",
"development html e-mail",
}
out := output.String()
for i, test := range wantStrings {
if !strings.Contains(out, test) {
t.Errorf("output missing string(%d): %s\n%s", i, test, out)
}
}
}

View File

@ -1,171 +0,0 @@
package authboss
import (
"bytes"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/pkg/errors"
)
const testRouterModName = "testrouter"
func init() {
RegisterModule(testRouterModName, testRouterModule{})
}
type testRouterModule struct {
routes RouteTable
}
func (t testRouterModule) Initialize(ab *Authboss) error { return nil }
func (t testRouterModule) Routes() RouteTable { return t.routes }
func (t testRouterModule) Templates() []string { return []string{"template1.tpl"} }
func testRouterSetup() (*Authboss, http.Handler, *bytes.Buffer) {
ab := New()
logger := &bytes.Buffer{}
ab.LogWriter = logger
ab.ViewLoader = mockRenderLoader{}
ab.Init(testRouterModName)
ab.MountPath = "/prefix"
//ab.SessionStoreMaker = newMockClientStoreMaker(mockClientStore{})
//ab.CookieStoreMaker = newMockClientStoreMaker(mockClientStore{})
logger.Reset() // Clear out the module load messages
return ab, ab.NewRouter(), logger
}
// testRouterCallbackSetup is NOT safe for use by multiple goroutines, don't use parallel
func testRouterCallbackSetup(path string, h HandlerFunc) (w *httptest.ResponseRecorder, r *http.Request) {
registeredModules[testRouterModName] = testRouterModule{
routes: map[string]HandlerFunc{path: h},
}
w = httptest.NewRecorder()
r, _ = http.NewRequest("GET", "http://localhost/prefix"+path, nil)
return w, r
}
func TestRouter(t *testing.T) {
called := false
w, r := testRouterCallbackSetup("/called", func(http.ResponseWriter, *http.Request) error {
called = true
return nil
})
_, router, _ := testRouterSetup()
router.ServeHTTP(w, r)
if !called {
t.Error("Expected handler to be called.")
}
}
func TestRouter_NotFound(t *testing.T) {
ab, router, _ := testRouterSetup()
w := httptest.NewRecorder()
r, _ := http.NewRequest("GET", "http://localhost/wat", nil)
router.ServeHTTP(w, r)
if w.Code != http.StatusNotFound {
t.Error("Wrong code:", w.Code)
}
if body := w.Body.String(); body != "404 Page not found" {
t.Error("Wrong body:", body)
}
called := false
ab.NotFoundHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called = true
})
router.ServeHTTP(w, r)
if !called {
t.Error("Should be called.")
}
}
func TestRouter_BadRequest(t *testing.T) {
err := ClientDataErr{"what"}
w, r := testRouterCallbackSetup("/badrequest",
func(http.ResponseWriter, *http.Request) error {
return err
},
)
ab, router, logger := testRouterSetup()
logger.Reset()
router.ServeHTTP(w, r)
if w.Code != http.StatusBadRequest {
t.Error("Wrong code:", w.Code)
}
if body := w.Body.String(); body != "400 Bad request" {
t.Error("Wrong body:", body)
}
if str := logger.String(); !strings.Contains(str, err.Error()) {
t.Error(str)
}
called := false
ab.BadRequestHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called = true
})
logger.Reset()
router.ServeHTTP(w, r)
if !called {
t.Error("Should be called.")
}
if str := logger.String(); !strings.Contains(str, err.Error()) {
t.Error(str)
}
}
func TestRouter_Error(t *testing.T) {
err := errors.New("error")
w, r := testRouterCallbackSetup("/error",
func(http.ResponseWriter, *http.Request) error {
return err
},
)
ab, router, logger := testRouterSetup()
logger.Reset()
router.ServeHTTP(w, r)
if w.Code != http.StatusInternalServerError {
t.Error("Wrong code:", w.Code)
}
if body := w.Body.String(); body != "500 An error has occurred" {
t.Error("Wrong body:", body)
}
if str := logger.String(); !strings.Contains(str, err.Error()) {
t.Error(str)
}
called := false
ab.ErrorHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called = true
})
logger.Reset()
router.ServeHTTP(w, r)
if !called {
t.Error("Should be called.")
}
if str := logger.String(); !strings.Contains(str, err.Error()) {
t.Error(str)
}
}

View File

@ -51,7 +51,7 @@ func TestErrorList_Map(t *testing.T) {
t.Error("Wrong number of fields:", len(m))
}
usernameErrs := m["email"]
usernameErrs := m["username"]
if len(usernameErrs) != 2 {
t.Error("Wrong number of username errors:", len(usernameErrs))
}