1
0
mirror of https://github.com/volatiletech/authboss.git synced 2025-01-08 04:03:53 +02:00
authboss/callbacks_test.go

205 lines
4.5 KiB
Go
Raw Normal View History

package authboss
import (
2015-02-21 09:01:45 +02:00
"bytes"
"errors"
2015-02-21 09:01:45 +02:00
"strings"
"testing"
)
func TestCallbacks(t *testing.T) {
t.Parallel()
ab := New()
afterCalled := false
beforeCalled := false
ab.Callbacks.Before(EventRegister, func(ctx *Context) (Interrupt, error) {
beforeCalled = true
2015-02-22 07:03:03 +02:00
return InterruptNone, nil
})
ab.Callbacks.After(EventRegister, func(ctx *Context) error {
afterCalled = true
2015-02-20 14:03:22 +02:00
return nil
})
if beforeCalled || afterCalled {
t.Error("Neither should be called.")
}
interrupt, err := ab.Callbacks.FireBefore(EventRegister, ab.NewContext())
if err != nil {
t.Error("Unexpected error:", err)
}
2015-02-22 07:03:03 +02:00
if interrupt != InterruptNone {
2015-02-20 14:03:22 +02:00
t.Error("It should not have been stopped.")
}
if !beforeCalled {
t.Error("Expected before to have been called.")
}
if afterCalled {
t.Error("Expected after not to be called.")
}
ab.Callbacks.FireAfter(EventRegister, ab.NewContext())
if !afterCalled {
t.Error("Expected after to be called.")
}
}
func TestCallbacksInterrupt(t *testing.T) {
t.Parallel()
ab := New()
before1 := false
before2 := false
ab.Callbacks.Before(EventRegister, func(ctx *Context) (Interrupt, error) {
2015-02-20 14:03:22 +02:00
before1 = true
2015-02-22 07:03:03 +02:00
return InterruptAccountLocked, nil
2015-02-20 14:03:22 +02:00
})
ab.Callbacks.Before(EventRegister, func(ctx *Context) (Interrupt, error) {
2015-02-20 14:03:22 +02:00
before2 = true
2015-02-22 07:03:03 +02:00
return InterruptNone, nil
2015-02-20 14:03:22 +02:00
})
interrupt, err := ab.Callbacks.FireBefore(EventRegister, ab.NewContext())
2015-02-20 14:03:22 +02:00
if err != nil {
t.Error(err)
}
2015-02-22 07:03:03 +02:00
if interrupt != InterruptAccountLocked {
t.Error("The interrupt signal was not account locked:", interrupt)
2015-02-20 14:03:22 +02:00
}
if !before1 {
t.Error("Before1 should have been called.")
}
if before2 {
t.Error("Before2 should not have been called.")
}
}
2015-02-21 09:01:45 +02:00
func TestCallbacksBeforeErrors(t *testing.T) {
t.Parallel()
ab := New()
2015-02-21 09:01:45 +02:00
log := &bytes.Buffer{}
ab.LogWriter = log
2015-02-20 14:03:22 +02:00
before1 := false
before2 := false
2015-02-21 09:01:45 +02:00
errValue := errors.New("Problem occured")
ab.Callbacks.Before(EventRegister, func(ctx *Context) (Interrupt, error) {
before1 = true
2015-02-22 07:03:03 +02:00
return InterruptNone, errValue
})
ab.Callbacks.Before(EventRegister, func(ctx *Context) (Interrupt, error) {
before2 = true
2015-02-22 07:03:03 +02:00
return InterruptNone, nil
})
interrupt, err := ab.Callbacks.FireBefore(EventRegister, ab.NewContext())
if err != errValue {
t.Error("Expected an error to come back.")
}
2015-02-22 07:03:03 +02:00
if interrupt != InterruptNone {
2015-02-21 09:10:18 +02:00
t.Error("It should not have been stopped.")
2015-02-20 14:03:22 +02:00
}
if !before1 {
t.Error("Before1 should have been called.")
}
if before2 {
t.Error("Before2 should not have been called.")
}
2015-02-21 09:01:45 +02:00
if estr := log.String(); !strings.Contains(estr, errValue.Error()) {
t.Error("Error string wrong:", estr)
}
}
func TestCallbacksAfterErrors(t *testing.T) {
t.Parallel()
2015-02-21 09:01:45 +02:00
log := &bytes.Buffer{}
ab := New()
ab.LogWriter = log
2015-02-21 09:01:45 +02:00
after1 := false
after2 := false
errValue := errors.New("Problem occured")
ab.Callbacks.After(EventRegister, func(ctx *Context) error {
2015-02-21 09:01:45 +02:00
after1 = true
return errValue
})
ab.Callbacks.After(EventRegister, func(ctx *Context) error {
2015-02-21 09:01:45 +02:00
after2 = true
return nil
})
err := ab.Callbacks.FireAfter(EventRegister, ab.NewContext())
2015-02-21 09:01:45 +02:00
if err != errValue {
t.Error("Expected an error to come back.")
}
if !after1 {
t.Error("After1 should have been called.")
}
if after2 {
t.Error("After2 should not have been called.")
}
if estr := log.String(); !strings.Contains(estr, errValue.Error()) {
t.Error("Error string wrong:", estr)
}
}
2015-02-22 07:03:03 +02:00
func TestEventString(t *testing.T) {
t.Parallel()
2015-02-22 07:03:03 +02:00
tests := []struct {
ev Event
str string
}{
{EventRegister, "EventRegister"},
{EventAuth, "EventAuth"},
{EventOAuth, "EventOAuth"},
2015-02-22 07:03:03 +02:00
{EventAuthFail, "EventAuthFail"},
{EventOAuthFail, "EventOAuthFail"},
2015-02-22 07:03:03 +02:00
{EventRecoverStart, "EventRecoverStart"},
{EventRecoverEnd, "EventRecoverEnd"},
{EventGetUser, "EventGetUser"},
{EventGetUserSession, "EventGetUserSession"},
{EventPasswordReset, "EventPasswordReset"},
2015-02-22 07:03:03 +02:00
}
for i, test := range tests {
if got := test.ev.String(); got != test.str {
t.Errorf("%d) Wrong string for Event(%d) expected: %v got: %s", i, test.ev, test.str, got)
}
}
}
func TestInterruptString(t *testing.T) {
t.Parallel()
2015-02-22 07:03:03 +02:00
tests := []struct {
in Interrupt
str string
}{
{InterruptNone, "InterruptNone"},
{InterruptAccountLocked, "InterruptAccountLocked"},
{InterruptAccountNotConfirmed, "InterruptAccountNotConfirmed"},
2015-02-22 10:24:57 +02:00
{InterruptSessionExpired, "InterruptSessionExpired"},
2015-02-22 07:03:03 +02:00
}
for i, test := range tests {
if got := test.in.String(); got != test.str {
t.Errorf("%d) Wrong string for Event(%d) expected: %v got: %s", i, test.in, test.str, got)
}
}
}