mirror of
https://github.com/volatiletech/authboss.git
synced 2025-01-06 03:54:17 +02:00
Fix authboss core tests
- Delete callbacks tests - Remove some useless code (SendMail), as well as some extra arguments in certain functions that didn't require them. - Remove tests for more code that has been moved to default implementations
This commit is contained in:
parent
2db3a3f782
commit
cbfc1d8388
@ -11,6 +11,7 @@ func TestAuthBossInit(t *testing.T) {
|
||||
ab := New()
|
||||
ab.LogWriter = ioutil.Discard
|
||||
ab.ViewLoader = mockRenderLoader{}
|
||||
ab.MailViewLoader = mockRenderLoader{}
|
||||
err := ab.Init()
|
||||
if err != nil {
|
||||
t.Error("Unexpected error:", err)
|
||||
|
@ -1,197 +0,0 @@
|
||||
package authboss
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func TestCallbacks(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := New()
|
||||
afterCalled := false
|
||||
beforeCalled := false
|
||||
|
||||
ab.Callbacks.Before(EventRegister, func(ctx context.Context) (Interrupt, error) {
|
||||
beforeCalled = true
|
||||
return InterruptNone, nil
|
||||
})
|
||||
ab.Callbacks.After(EventRegister, func(ctx context.Context) error {
|
||||
afterCalled = true
|
||||
return nil
|
||||
})
|
||||
|
||||
if beforeCalled || afterCalled {
|
||||
t.Error("Neither should be called.")
|
||||
}
|
||||
|
||||
interrupt, err := ab.Callbacks.FireBefore(EventRegister, context.TODO())
|
||||
if err != nil {
|
||||
t.Error("Unexpected error:", err)
|
||||
}
|
||||
if interrupt != InterruptNone {
|
||||
t.Error("It should not have been stopped.")
|
||||
}
|
||||
|
||||
if !beforeCalled {
|
||||
t.Error("Expected before to have been called.")
|
||||
}
|
||||
if afterCalled {
|
||||
t.Error("Expected after not to be called.")
|
||||
}
|
||||
|
||||
ab.Callbacks.FireAfter(EventRegister, context.TODO())
|
||||
if !afterCalled {
|
||||
t.Error("Expected after to be called.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCallbacksInterrupt(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := New()
|
||||
before1 := false
|
||||
before2 := false
|
||||
|
||||
ab.Callbacks.Before(EventRegister, func(ctx context.Context) (Interrupt, error) {
|
||||
before1 = true
|
||||
return InterruptAccountLocked, nil
|
||||
})
|
||||
ab.Callbacks.Before(EventRegister, func(ctx context.Context) (Interrupt, error) {
|
||||
before2 = true
|
||||
return InterruptNone, nil
|
||||
})
|
||||
|
||||
interrupt, err := ab.Callbacks.FireBefore(EventRegister, context.TODO())
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if interrupt != InterruptAccountLocked {
|
||||
t.Error("The interrupt signal was not account locked:", interrupt)
|
||||
}
|
||||
|
||||
if !before1 {
|
||||
t.Error("Before1 should have been called.")
|
||||
}
|
||||
if before2 {
|
||||
t.Error("Before2 should not have been called.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCallbacksBeforeErrors(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := New()
|
||||
log := &bytes.Buffer{}
|
||||
ab.LogWriter = log
|
||||
before1 := false
|
||||
before2 := false
|
||||
|
||||
errValue := errors.New("problem occured")
|
||||
|
||||
ab.Callbacks.Before(EventRegister, func(ctx context.Context) (Interrupt, error) {
|
||||
before1 = true
|
||||
return InterruptNone, errValue
|
||||
})
|
||||
ab.Callbacks.Before(EventRegister, func(ctx context.Context) (Interrupt, error) {
|
||||
before2 = true
|
||||
return InterruptNone, nil
|
||||
})
|
||||
|
||||
interrupt, err := ab.Callbacks.FireBefore(EventRegister, context.TODO())
|
||||
if err != errValue {
|
||||
t.Error("Expected an error to come back.")
|
||||
}
|
||||
if interrupt != InterruptNone {
|
||||
t.Error("It should not have been stopped.")
|
||||
}
|
||||
|
||||
if !before1 {
|
||||
t.Error("Before1 should have been called.")
|
||||
}
|
||||
if before2 {
|
||||
t.Error("Before2 should not have been called.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCallbacksAfterErrors(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
log := &bytes.Buffer{}
|
||||
ab := New()
|
||||
ab.LogWriter = log
|
||||
after1 := false
|
||||
after2 := false
|
||||
|
||||
errValue := errors.New("problem occured")
|
||||
|
||||
ab.Callbacks.After(EventRegister, func(ctx context.Context) error {
|
||||
after1 = true
|
||||
return errValue
|
||||
})
|
||||
ab.Callbacks.After(EventRegister, func(ctx context.Context) error {
|
||||
after2 = true
|
||||
return nil
|
||||
})
|
||||
|
||||
err := ab.Callbacks.FireAfter(EventRegister, context.TODO())
|
||||
if err != errValue {
|
||||
t.Error("Expected an error to come back.")
|
||||
}
|
||||
|
||||
if !after1 {
|
||||
t.Error("After1 should have been called.")
|
||||
}
|
||||
if after2 {
|
||||
t.Error("After2 should not have been called.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventString(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
ev Event
|
||||
str string
|
||||
}{
|
||||
{EventRegister, "EventRegister"},
|
||||
{EventAuth, "EventAuth"},
|
||||
{EventOAuth, "EventOAuth"},
|
||||
{EventAuthFail, "EventAuthFail"},
|
||||
{EventOAuthFail, "EventOAuthFail"},
|
||||
{EventRecoverStart, "EventRecoverStart"},
|
||||
{EventRecoverEnd, "EventRecoverEnd"},
|
||||
{EventGetUser, "EventGetUser"},
|
||||
{EventGetUserSession, "EventGetUserSession"},
|
||||
{EventPasswordReset, "EventPasswordReset"},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
if got := test.ev.String(); got != test.str {
|
||||
t.Errorf("%d) Wrong string for Event(%d) expected: %v got: %s", i, test.ev, test.str, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestInterruptString(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
in Interrupt
|
||||
str string
|
||||
}{
|
||||
{InterruptNone, "InterruptNone"},
|
||||
{InterruptAccountLocked, "InterruptAccountLocked"},
|
||||
{InterruptAccountNotConfirmed, "InterruptAccountNotConfirmed"},
|
||||
{InterruptSessionExpired, "InterruptSessionExpired"},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
if got := test.in.String(); got != test.str {
|
||||
t.Errorf("%d) Wrong string for Event(%d) expected: %v got: %s", i, test.in, test.str, got)
|
||||
}
|
||||
}
|
||||
}
|
@ -40,7 +40,6 @@ func TestStateResponseWriterDoubleWritePanic(t *testing.T) {
|
||||
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
w := ab.NewResponse(httptest.NewRecorder(), r)
|
||||
csrw := w.(*ClientStateResponseWriter)
|
||||
|
||||
w.WriteHeader(200)
|
||||
// Check this doesn't panic
|
||||
@ -52,7 +51,7 @@ func TestStateResponseWriterDoubleWritePanic(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
csrw.putClientState()
|
||||
w.putClientState()
|
||||
}
|
||||
|
||||
func TestStateResponseWriterLastSecondWriteWithPrevious(t *testing.T) {
|
||||
@ -126,30 +125,28 @@ func TestStateResponseWriterEvents(t *testing.T) {
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
w := ab.NewResponse(httptest.NewRecorder(), r)
|
||||
|
||||
csrw := w.(*ClientStateResponseWriter)
|
||||
|
||||
PutSession(w, "one", "two")
|
||||
DelSession(w, "one")
|
||||
DelCookie(w, "one")
|
||||
PutCookie(w, "two", "one")
|
||||
|
||||
want := ClientStateEvent{Kind: ClientStateEventPut, Key: "one", Value: "two"}
|
||||
if got := csrw.sessionStateEvents[0]; got != want {
|
||||
if got := w.sessionStateEvents[0]; got != want {
|
||||
t.Error("event was wrong", got)
|
||||
}
|
||||
|
||||
want = ClientStateEvent{Kind: ClientStateEventDel, Key: "one"}
|
||||
if got := csrw.sessionStateEvents[1]; got != want {
|
||||
if got := w.sessionStateEvents[1]; got != want {
|
||||
t.Error("event was wrong", got)
|
||||
}
|
||||
|
||||
want = ClientStateEvent{Kind: ClientStateEventDel, Key: "one"}
|
||||
if got := csrw.cookieStateEvents[0]; got != want {
|
||||
if got := w.cookieStateEvents[0]; got != want {
|
||||
t.Error("event was wrong", got)
|
||||
}
|
||||
|
||||
want = ClientStateEvent{Kind: ClientStateEventPut, Key: "two", Value: "one"}
|
||||
if got := csrw.cookieStateEvents[1]; got != want {
|
||||
if got := w.cookieStateEvents[1]; got != want {
|
||||
t.Error("event was wrong", got)
|
||||
}
|
||||
}
|
||||
@ -162,7 +159,6 @@ func TestFlashClearer(t *testing.T) {
|
||||
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
w := ab.NewResponse(httptest.NewRecorder(), r)
|
||||
csrw := w.(*ClientStateResponseWriter)
|
||||
|
||||
if msg := FlashSuccess(w, r); msg != "" {
|
||||
t.Error("Unexpected flash success:", msg)
|
||||
@ -187,11 +183,11 @@ func TestFlashClearer(t *testing.T) {
|
||||
}
|
||||
|
||||
want := ClientStateEvent{Kind: ClientStateEventDel, Key: FlashSuccessKey}
|
||||
if got := csrw.sessionStateEvents[0]; got != want {
|
||||
if got := w.sessionStateEvents[0]; got != want {
|
||||
t.Error("event was wrong", got)
|
||||
}
|
||||
want = ClientStateEvent{Kind: ClientStateEventDel, Key: FlashErrorKey}
|
||||
if got := csrw.sessionStateEvents[1]; got != want {
|
||||
if got := w.sessionStateEvents[1]; got != want {
|
||||
t.Error("event was wrong", got)
|
||||
}
|
||||
}
|
||||
|
@ -18,7 +18,7 @@ func loadClientStateP(ab *Authboss, w http.ResponseWriter, r *http.Request) *htt
|
||||
func testSetupContext() (*Authboss, *http.Request) {
|
||||
ab := New()
|
||||
ab.SessionStateStorer = newMockClientStateRW(SessionKey, "george-pid")
|
||||
ab.StoreLoader = mockStoreLoader{
|
||||
ab.Storer = mockServerStorer{
|
||||
"george-pid": mockUser{Email: "george-pid", Password: "unreadable"},
|
||||
}
|
||||
r := loadClientStateP(ab, nil, httptest.NewRequest("GET", "/", nil))
|
||||
@ -29,12 +29,9 @@ func testSetupContext() (*Authboss, *http.Request) {
|
||||
func testSetupContextCached() (*Authboss, mockUser, *http.Request) {
|
||||
ab := New()
|
||||
wantUser := mockUser{Email: "george-pid", Password: "unreadable"}
|
||||
storer := mockStoredUser{
|
||||
mockUser: wantUser,
|
||||
}
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
ctx := context.WithValue(req.Context(), ctxKeyPID, "george-pid")
|
||||
ctx = context.WithValue(ctx, ctxKeyUser, storer)
|
||||
ctx = context.WithValue(ctx, ctxKeyUser, wantUser)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
return ab, wantUser, req
|
||||
@ -43,7 +40,7 @@ func testSetupContextCached() (*Authboss, mockUser, *http.Request) {
|
||||
func testSetupContextPanic() *Authboss {
|
||||
ab := New()
|
||||
ab.SessionStateStorer = newMockClientStateRW(SessionKey, "george-pid")
|
||||
ab.StoreLoader = mockStoreLoader{}
|
||||
ab.Storer = mockServerStorer{}
|
||||
|
||||
return ab
|
||||
}
|
||||
@ -207,10 +204,10 @@ func TestLoadCurrentUser(t *testing.T) {
|
||||
t.Error("got:", got)
|
||||
}
|
||||
|
||||
want := user.(mockStoredUser).mockUser
|
||||
got := r.Context().Value(ctxKeyUser).(mockStoredUser).mockUser
|
||||
want := user.(mockUser)
|
||||
got := r.Context().Value(ctxKeyUser).(mockUser)
|
||||
if got != want {
|
||||
t.Error("users mismatched:\nwant: %#v\ngot: %#v", want, got)
|
||||
t.Errorf("users mismatched:\nwant: %#v\ngot: %#v", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
@ -224,9 +221,9 @@ func TestLoadCurrentUserContext(t *testing.T) {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
got := user.(mockStoredUser).mockUser
|
||||
got := user.(mockUser)
|
||||
if got != wantUser {
|
||||
t.Error("users mismatched:\nwant: %#v\ngot: %#v", wantUser, got)
|
||||
t.Errorf("users mismatched:\nwant: %#v\ngot: %#v", wantUser, got)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -76,6 +76,7 @@ func TestResponder(t *testing.T) {
|
||||
t.Errorf("data mismatched:\nwant: %#v\ngot: %#v", expectData, gotData)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedirector(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
@ -10,7 +10,7 @@ var nowTime = time.Now
|
||||
|
||||
// TimeToExpiry returns zero if the user session is expired else the time
|
||||
// until expiry. Takes in the allowed idle duration.
|
||||
func TimeToExpiry(w http.ResponseWriter, r *http.Request, expireAfter time.Duration) time.Duration {
|
||||
func TimeToExpiry(r *http.Request, expireAfter time.Duration) time.Duration {
|
||||
return timeToExpiry(r, expireAfter)
|
||||
}
|
||||
|
||||
|
@ -54,36 +54,37 @@ func TestExpireIsExpired(t *testing.T) {
|
||||
t.Error("expected user not to be present")
|
||||
}
|
||||
|
||||
csrw := w.(*ClientStateResponseWriter)
|
||||
|
||||
want := ClientStateEvent{
|
||||
Kind: ClientStateEventDel,
|
||||
Key: SessionKey,
|
||||
}
|
||||
if got := csrw.sessionStateEvents[0]; got != want {
|
||||
if got := w.sessionStateEvents[0]; got != want {
|
||||
t.Error("want:", want, "got:", got)
|
||||
}
|
||||
want = ClientStateEvent{
|
||||
Kind: ClientStateEventDel,
|
||||
Key: SessionLastAction,
|
||||
}
|
||||
if got := csrw.sessionStateEvents[1]; got != want {
|
||||
if got := w.sessionStateEvents[1]; got != want {
|
||||
t.Error("want:", want, "got:", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpireNotExpired(t *testing.T) {
|
||||
ab := New()
|
||||
ab.Config.ExpireAfter = time.Hour
|
||||
ab.SessionStateStorer = newMockClientStateRW(
|
||||
SessionKey, "username",
|
||||
SessionLastAction, time.Now().UTC().Format(time.RFC3339),
|
||||
)
|
||||
|
||||
var err error
|
||||
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
r = r.WithContext(context.WithValue(r.Context(), ctxKeyPID, "primaryid"))
|
||||
r = r.WithContext(context.WithValue(r.Context(), ctxKeyUser, struct{}{}))
|
||||
w := ab.NewResponse(httptest.NewRecorder(), r)
|
||||
r, err := ab.LoadClientState(w, r)
|
||||
r, err = ab.LoadClientState(w, r)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
@ -119,14 +120,13 @@ func TestExpireNotExpired(t *testing.T) {
|
||||
t.Error("expected user to be present")
|
||||
}
|
||||
|
||||
csrw := w.(*ClientStateResponseWriter)
|
||||
|
||||
want := ClientStateEvent{
|
||||
Kind: ClientStateEventPut,
|
||||
Key: SessionLastAction,
|
||||
Value: newTime.Format(time.RFC3339),
|
||||
}
|
||||
if got := csrw.sessionStateEvents[0]; got != want {
|
||||
|
||||
if got := w.sessionStateEvents[0]; got != want {
|
||||
t.Error("want:", want, "got:", got)
|
||||
}
|
||||
}
|
||||
@ -134,12 +134,10 @@ func TestExpireNotExpired(t *testing.T) {
|
||||
func TestExpireTimeToExpiry(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := New()
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
w := ab.NewResponse(httptest.NewRecorder(), r)
|
||||
|
||||
want := 5 * time.Second
|
||||
dur := TimeToExpiry(w, r, want)
|
||||
dur := TimeToExpiry(r, want)
|
||||
if dur != want {
|
||||
t.Error("duration was wrong:", dur)
|
||||
}
|
||||
@ -153,11 +151,10 @@ func TestExpireRefreshExpiry(t *testing.T) {
|
||||
w := ab.NewResponse(httptest.NewRecorder(), r)
|
||||
|
||||
RefreshExpiry(w, r)
|
||||
csrw := w.(*ClientStateResponseWriter)
|
||||
if got := csrw.sessionStateEvents[0].Kind; got != ClientStateEventPut {
|
||||
if got := w.sessionStateEvents[0].Kind; got != ClientStateEventPut {
|
||||
t.Error("wrong event:", got)
|
||||
}
|
||||
if got := csrw.sessionStateEvents[0].Key; got != SessionLastAction {
|
||||
if got := w.sessionStateEvents[0].Key; got != SessionLastAction {
|
||||
t.Error("wrong key:", got)
|
||||
}
|
||||
}
|
||||
|
@ -4,11 +4,6 @@ import (
|
||||
"context"
|
||||
)
|
||||
|
||||
// SendMail uses the currently configured mailer to deliver e-mails.
|
||||
func (a *Authboss) SendMail(ctx context.Context, data Email) error {
|
||||
return a.Mailer.Send(ctx, data)
|
||||
}
|
||||
|
||||
// Mailer is a type that is capable of sending an e-mail.
|
||||
type Mailer interface {
|
||||
Send(context.Context, Email) error
|
||||
|
@ -17,66 +17,51 @@ type mockUser struct {
|
||||
Password string
|
||||
}
|
||||
|
||||
type mockStoredUser struct {
|
||||
mockUser
|
||||
mockStoreLoader
|
||||
}
|
||||
type mockServerStorer map[string]mockUser
|
||||
|
||||
type mockStoreLoader map[string]mockUser
|
||||
|
||||
func (m mockStoreLoader) Load(ctx context.Context, key string) (Storer, error) {
|
||||
func (m mockServerStorer) Load(ctx context.Context, key string) (User, error) {
|
||||
u, ok := m[key]
|
||||
if !ok {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
|
||||
return mockStoredUser{
|
||||
mockUser: u,
|
||||
mockStoreLoader: m,
|
||||
}, nil
|
||||
return u, nil
|
||||
}
|
||||
|
||||
func (m mockStoredUser) Load(ctx context.Context) error {
|
||||
u, ok := m.mockStoreLoader[m.Email]
|
||||
if !ok {
|
||||
return ErrUserNotFound
|
||||
func (m mockServerStorer) Save(ctx context.Context, user User) error {
|
||||
e, err := user.GetEmail(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
m.Email = u.Email
|
||||
m.Password = u.Password
|
||||
m[e] = user.(mockUser)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m mockStoredUser) Save(ctx context.Context) error {
|
||||
m.mockStoreLoader[m.Email] = m.mockUser
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m mockStoredUser) PutEmail(ctx context.Context, email string) error {
|
||||
func (m mockUser) PutEmail(ctx context.Context, email string) error {
|
||||
m.Email = email
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m mockStoredUser) PutUsername(ctx context.Context, username string) error {
|
||||
func (m mockUser) PutUsername(ctx context.Context, username string) error {
|
||||
return errors.New("not impl")
|
||||
}
|
||||
|
||||
func (m mockStoredUser) PutPassword(ctx context.Context, password string) error {
|
||||
func (m mockUser) PutPassword(ctx context.Context, password string) error {
|
||||
m.Password = password
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m mockStoredUser) GetEmail(ctx context.Context) (email string, err error) {
|
||||
func (m mockUser) GetEmail(ctx context.Context) (email string, err error) {
|
||||
return m.Email, nil
|
||||
}
|
||||
|
||||
func (m mockStoredUser) GetUsername(ctx context.Context) (username string, err error) {
|
||||
func (m mockUser) GetUsername(ctx context.Context) (username string, err error) {
|
||||
return "", errors.New("not impl")
|
||||
}
|
||||
|
||||
func (m mockStoredUser) GetPassword(ctx context.Context) (password string, err error) {
|
||||
func (m mockUser) GetPassword(ctx context.Context) (password string, err error) {
|
||||
return m.Password, nil
|
||||
}
|
||||
|
||||
|
@ -1,72 +0,0 @@
|
||||
package authboss
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type testMailer struct {
|
||||
io.Writer
|
||||
}
|
||||
|
||||
func (t testMailer) Send(_ context.Context, email Email) error {
|
||||
fmt.Fprintf(t.Writer, "%v", email)
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestResponseEmail(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := New()
|
||||
ab.renderer = mockEmailRenderer{}
|
||||
ab.SessionStateStorer = newMockClientStateRW(
|
||||
FlashSuccessKey, "flash_success",
|
||||
FlashErrorKey, "flash_error",
|
||||
)
|
||||
ab.XSRFName = "xsrf"
|
||||
ab.XSRFMaker = func(w http.ResponseWriter, r *http.Request) string {
|
||||
return "xsrftoken"
|
||||
}
|
||||
ab.LayoutDataMaker = func(w http.ResponseWriter, r *http.Request) HTMLData {
|
||||
return HTMLData{"hello": "world"}
|
||||
}
|
||||
|
||||
output := &bytes.Buffer{}
|
||||
ab.Mailer = testMailer{output}
|
||||
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
wr := httptest.NewRecorder()
|
||||
w := ab.NewResponse(wr, r)
|
||||
|
||||
email := Email{
|
||||
To: []string{"test@example.com"},
|
||||
From: "test@example.com",
|
||||
Subject: "subject",
|
||||
}
|
||||
ro := EmailResponseOptions{Data: nil, HTMLTemplate: "html", TextTemplate: "text"}
|
||||
err := ab.Email(w, r, email, ro)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
wantStrings := []string{
|
||||
"To: test@example.com",
|
||||
"From: test@example.com",
|
||||
"Subject: subject",
|
||||
"development text e-mail",
|
||||
"development html e-mail",
|
||||
}
|
||||
|
||||
out := output.String()
|
||||
for i, test := range wantStrings {
|
||||
if !strings.Contains(out, test) {
|
||||
t.Errorf("output missing string(%d): %s\n%s", i, test, out)
|
||||
}
|
||||
}
|
||||
}
|
171
router_test.go
171
router_test.go
@ -1,171 +0,0 @@
|
||||
package authboss
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const testRouterModName = "testrouter"
|
||||
|
||||
func init() {
|
||||
RegisterModule(testRouterModName, testRouterModule{})
|
||||
}
|
||||
|
||||
type testRouterModule struct {
|
||||
routes RouteTable
|
||||
}
|
||||
|
||||
func (t testRouterModule) Initialize(ab *Authboss) error { return nil }
|
||||
func (t testRouterModule) Routes() RouteTable { return t.routes }
|
||||
func (t testRouterModule) Templates() []string { return []string{"template1.tpl"} }
|
||||
|
||||
func testRouterSetup() (*Authboss, http.Handler, *bytes.Buffer) {
|
||||
ab := New()
|
||||
logger := &bytes.Buffer{}
|
||||
ab.LogWriter = logger
|
||||
ab.ViewLoader = mockRenderLoader{}
|
||||
ab.Init(testRouterModName)
|
||||
ab.MountPath = "/prefix"
|
||||
//ab.SessionStoreMaker = newMockClientStoreMaker(mockClientStore{})
|
||||
//ab.CookieStoreMaker = newMockClientStoreMaker(mockClientStore{})
|
||||
|
||||
logger.Reset() // Clear out the module load messages
|
||||
|
||||
return ab, ab.NewRouter(), logger
|
||||
}
|
||||
|
||||
// testRouterCallbackSetup is NOT safe for use by multiple goroutines, don't use parallel
|
||||
func testRouterCallbackSetup(path string, h HandlerFunc) (w *httptest.ResponseRecorder, r *http.Request) {
|
||||
registeredModules[testRouterModName] = testRouterModule{
|
||||
routes: map[string]HandlerFunc{path: h},
|
||||
}
|
||||
|
||||
w = httptest.NewRecorder()
|
||||
r, _ = http.NewRequest("GET", "http://localhost/prefix"+path, nil)
|
||||
|
||||
return w, r
|
||||
}
|
||||
|
||||
func TestRouter(t *testing.T) {
|
||||
called := false
|
||||
|
||||
w, r := testRouterCallbackSetup("/called", func(http.ResponseWriter, *http.Request) error {
|
||||
called = true
|
||||
return nil
|
||||
})
|
||||
|
||||
_, router, _ := testRouterSetup()
|
||||
|
||||
router.ServeHTTP(w, r)
|
||||
|
||||
if !called {
|
||||
t.Error("Expected handler to be called.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_NotFound(t *testing.T) {
|
||||
ab, router, _ := testRouterSetup()
|
||||
w := httptest.NewRecorder()
|
||||
r, _ := http.NewRequest("GET", "http://localhost/wat", nil)
|
||||
|
||||
router.ServeHTTP(w, r)
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Error("Wrong code:", w.Code)
|
||||
}
|
||||
if body := w.Body.String(); body != "404 Page not found" {
|
||||
t.Error("Wrong body:", body)
|
||||
}
|
||||
|
||||
called := false
|
||||
ab.NotFoundHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
called = true
|
||||
})
|
||||
|
||||
router.ServeHTTP(w, r)
|
||||
if !called {
|
||||
t.Error("Should be called.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_BadRequest(t *testing.T) {
|
||||
err := ClientDataErr{"what"}
|
||||
w, r := testRouterCallbackSetup("/badrequest",
|
||||
func(http.ResponseWriter, *http.Request) error {
|
||||
return err
|
||||
},
|
||||
)
|
||||
|
||||
ab, router, logger := testRouterSetup()
|
||||
logger.Reset()
|
||||
router.ServeHTTP(w, r)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Error("Wrong code:", w.Code)
|
||||
}
|
||||
if body := w.Body.String(); body != "400 Bad request" {
|
||||
t.Error("Wrong body:", body)
|
||||
}
|
||||
|
||||
if str := logger.String(); !strings.Contains(str, err.Error()) {
|
||||
t.Error(str)
|
||||
}
|
||||
|
||||
called := false
|
||||
ab.BadRequestHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
called = true
|
||||
})
|
||||
|
||||
logger.Reset()
|
||||
router.ServeHTTP(w, r)
|
||||
if !called {
|
||||
t.Error("Should be called.")
|
||||
}
|
||||
|
||||
if str := logger.String(); !strings.Contains(str, err.Error()) {
|
||||
t.Error(str)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_Error(t *testing.T) {
|
||||
err := errors.New("error")
|
||||
w, r := testRouterCallbackSetup("/error",
|
||||
func(http.ResponseWriter, *http.Request) error {
|
||||
return err
|
||||
},
|
||||
)
|
||||
|
||||
ab, router, logger := testRouterSetup()
|
||||
logger.Reset()
|
||||
router.ServeHTTP(w, r)
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Error("Wrong code:", w.Code)
|
||||
}
|
||||
if body := w.Body.String(); body != "500 An error has occurred" {
|
||||
t.Error("Wrong body:", body)
|
||||
}
|
||||
|
||||
if str := logger.String(); !strings.Contains(str, err.Error()) {
|
||||
t.Error(str)
|
||||
}
|
||||
|
||||
called := false
|
||||
ab.ErrorHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
called = true
|
||||
})
|
||||
|
||||
logger.Reset()
|
||||
router.ServeHTTP(w, r)
|
||||
if !called {
|
||||
t.Error("Should be called.")
|
||||
}
|
||||
|
||||
if str := logger.String(); !strings.Contains(str, err.Error()) {
|
||||
t.Error(str)
|
||||
}
|
||||
}
|
@ -51,7 +51,7 @@ func TestErrorList_Map(t *testing.T) {
|
||||
t.Error("Wrong number of fields:", len(m))
|
||||
}
|
||||
|
||||
usernameErrs := m["email"]
|
||||
usernameErrs := m["username"]
|
||||
if len(usernameErrs) != 2 {
|
||||
t.Error("Wrong number of username errors:", len(usernameErrs))
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user