1
0
mirror of https://github.com/volatiletech/authboss.git synced 2024-11-24 08:42:17 +02:00

Bring back events

- Rename callbacks -> events
- Regenerate stringers.go with later version of stringer
This commit is contained in:
Aaron L 2018-02-01 16:31:08 -08:00
parent de1c2ed081
commit ad5230a303
14 changed files with 261 additions and 76 deletions

View File

@ -94,13 +94,13 @@ func (a *Auth) loginHandlerFunc(ctx *authboss.Context, w http.ResponseWriter, r
fmt.Fprintf(ctx.LogWriter, "auth: validate credentials failed: %v\n", err)
return a.templates.Render(ctx, w, r, tplLogin, errData)
} else if !valid {
if err := a.Callbacks.FireAfter(authboss.EventAuthFail, ctx); err != nil {
if err := a.Events.FireAfter(authboss.EventAuthFail, ctx); err != nil {
fmt.Fprintf(ctx.LogWriter, "EventAuthFail callback error'd out: %v\n", err)
}
return a.templates.Render(ctx, w, r, tplLogin, errData)
}
interrupted, err := a.Callbacks.FireBefore(authboss.EventAuth, ctx)
interrupted, err := a.Events.FireBefore(authboss.EventAuth, ctx)
if err != nil {
return err
} else if interrupted != authboss.InterruptNone {
@ -119,7 +119,7 @@ func (a *Auth) loginHandlerFunc(ctx *authboss.Context, w http.ResponseWriter, r
ctx.SessionStorer.Del(authboss.SessionHalfAuthKey)
ctx.Values = map[string]string{authboss.CookieRemember: r.FormValue(authboss.CookieRemember)}
if err := a.Callbacks.FireAfter(authboss.EventAuth, ctx); err != nil {
if err := a.Events.FireAfter(authboss.EventAuth, ctx); err != nil {
return err
}
response.Redirect(ctx, w, r, a.AuthLoginOKPath, "", "", true)

View File

@ -99,8 +99,8 @@ func TestAuth_loginHandlerFunc_POST_ReturnsErrorOnCallbackFailure(t *testing.T)
a, storer := testSetup()
storer.Users["john"] = authboss.Attributes{"password": "$2a$10$B7aydtqVF9V8RSNx3lCKB.l09jqLV/aMiVqQHajtL7sWGhCS9jlOu"}
a.Callbacks = authboss.NewCallbacks()
a.Callbacks.Before(authboss.EventAuth, func(_ *authboss.Context) (authboss.Interrupt, error) {
a.Events = authboss.NewCallbacks()
a.Events.Before(authboss.EventAuth, func(_ *authboss.Context) (authboss.Interrupt, error) {
return authboss.InterruptNone, errors.New("explode")
})
@ -117,8 +117,8 @@ func TestAuth_loginHandlerFunc_POST_RedirectsWhenInterrupted(t *testing.T) {
a, storer := testSetup()
storer.Users["john"] = authboss.Attributes{"password": "$2a$10$B7aydtqVF9V8RSNx3lCKB.l09jqLV/aMiVqQHajtL7sWGhCS9jlOu"}
a.Callbacks = authboss.NewCallbacks()
a.Callbacks.Before(authboss.EventAuth, func(_ *authboss.Context) (authboss.Interrupt, error) {
a.Events = authboss.NewCallbacks()
a.Events.Before(authboss.EventAuth, func(_ *authboss.Context) (authboss.Interrupt, error) {
return authboss.InterruptAccountLocked, nil
})
@ -142,8 +142,8 @@ func TestAuth_loginHandlerFunc_POST_RedirectsWhenInterrupted(t *testing.T) {
t.Error("Expected error flash message:", expectedMsg)
}
a.Callbacks = authboss.NewCallbacks()
a.Callbacks.Before(authboss.EventAuth, func(_ *authboss.Context) (authboss.Interrupt, error) {
a.Events = authboss.NewCallbacks()
a.Events.Before(authboss.EventAuth, func(_ *authboss.Context) (authboss.Interrupt, error) {
return authboss.InterruptAccountNotConfirmed, nil
})
@ -229,8 +229,8 @@ func TestAuth_loginHandlerFunc_POST(t *testing.T) {
ctx, w, r, _ := testRequest(a.Authboss, "POST", "username", "john", "password", "1234")
cb := mocks.NewMockAfterCallback()
a.Callbacks = authboss.NewCallbacks()
a.Callbacks.After(authboss.EventAuth, cb.Fn)
a.Events = authboss.NewCallbacks()
a.Events.After(authboss.EventAuth, cb.Fn)
a.AuthLoginOKPath = "/dashboard"
sessions := mocks.NewMockClientStorer()

View File

@ -11,6 +11,8 @@ import "github.com/pkg/errors"
// Authboss contains a configuration and other details for running.
type Authboss struct {
Config
Events *Events
loadedModules map[string]Moduler
}
@ -18,7 +20,9 @@ type Authboss struct {
// configuration.
func New() *Authboss {
ab := &Authboss{}
ab.loadedModules = make(map[string]Moduler)
ab.Events = NewEvents()
ab.Config.Defaults()
return ab
@ -90,7 +94,7 @@ func (a *Authboss) UpdatePassword(w http.ResponseWriter, r *http.Request,
return nil
}
return a.Callbacks.FireAfter(EventPasswordReset, r.Context())
return a.Events.FireAfter(EventPasswordReset, r.Context())
// TODO(aarondl): Fix
return errors.New("not implemented")
}

View File

@ -26,7 +26,7 @@ func TestAuthbossUpdatePassword(t *testing.T) {
ab.CookieStoreMaker = newMockClientStoreMaker(cookies)
called := false
ab.Callbacks.After(EventPasswordReset, func(ctx context.Context) error {
ab.Events.After(EventPasswordReset, func(ctx context.Context) error {
called = true
return nil
})
@ -50,7 +50,7 @@ func TestAuthbossUpdatePassword(t *testing.T) {
t.Error("Password not updated")
}
if !called {
t.Error("Callbacks should have been called.")
t.Error("Events should have been called.")
}
called = false
@ -63,7 +63,7 @@ func TestAuthbossUpdatePassword(t *testing.T) {
t.Error("Password not updated")
}
if !called {
t.Error("Callbacks should have been called.")
t.Error("Events should have been called.")
}
called = false
@ -77,7 +77,7 @@ func TestAuthbossUpdatePassword(t *testing.T) {
t.Error("Password not updated")
}
if called {
t.Error("Callbacks should not have been called")
t.Error("Events should not have been called")
}
*/
}

View File

@ -66,12 +66,12 @@ type Confirm struct {
func (c *Confirm) Initialize(ab *authboss.Authboss) (err error) {
c.Authboss = ab
c.Callbacks.After(authboss.EventGetUser, func(ctx context.Context) error {
c.Events.After(authboss.EventGetUser, func(ctx context.Context) error {
_, err := c.beforeGet(ctx)
return err
})
c.Callbacks.Before(authboss.EventAuth, c.beforeGet)
c.Callbacks.After(authboss.EventRegister, c.afterRegister)
c.Events.Before(authboss.EventAuth, c.beforeGet)
c.Events.After(authboss.EventRegister, c.afterRegister)
return nil
}

View File

@ -40,7 +40,7 @@ const (
InterruptSessionExpired
)
// Before callbacks can interrupt the flow by returning an interrupt value.
// Before Events can interrupt the flow by returning an interrupt value.
// This is used to stop the callback chain and the original handler from
// continuing execution. The execution should also stopped if there is an error.
type Before func(context.Context) (Interrupt, error)
@ -48,45 +48,45 @@ type Before func(context.Context) (Interrupt, error)
// After is a request callback that happens after the event.
type After func(context.Context) error
// Callbacks is a collection of callbacks that fire before and after certain
// Events is a collection of Events that fire before and after certain
// methods.
type Callbacks struct {
type Events struct {
before map[Event][]Before
after map[Event][]After
}
// NewCallbacks creates a new set of before and after callbacks.
// NewEvents creates a new set of before and after Events.
// Called only by authboss internals and for testing.
func NewCallbacks() *Callbacks {
return &Callbacks{
func NewEvents() *Events {
return &Events{
before: make(map[Event][]Before),
after: make(map[Event][]After),
}
}
// Before event, call f.
func (c *Callbacks) Before(e Event, f Before) {
callbacks := c.before[e]
callbacks = append(callbacks, f)
c.before[e] = callbacks
func (c *Events) Before(e Event, f Before) {
Events := c.before[e]
Events = append(Events, f)
c.before[e] = Events
}
// After event, call f.
func (c *Callbacks) After(e Event, f After) {
callbacks := c.after[e]
callbacks = append(callbacks, f)
c.after[e] = callbacks
func (c *Events) After(e Event, f After) {
Events := c.after[e]
Events = append(Events, f)
c.after[e] = Events
}
// FireBefore event to all the callbacks with a context. The error
// FireBefore event to all the Events with a context. The error
// should be passed up despite being logged once here already so it
// can write an error out to the HTTP Client. If err is nil then
// check the value of interrupted. If error is nil then the interrupt
// value should be checked. If it is not InterruptNone then there is a reason
// the current process should stop it's course of action.
func (c *Callbacks) FireBefore(e Event, ctx context.Context) (interrupt Interrupt, err error) {
callbacks := c.before[e]
for _, fn := range callbacks {
func (c *Events) FireBefore(ctx context.Context, e Event) (interrupt Interrupt, err error) {
Events := c.before[e]
for _, fn := range Events {
interrupt, err = fn(ctx)
if err != nil {
return InterruptNone, err
@ -99,11 +99,11 @@ func (c *Callbacks) FireBefore(e Event, ctx context.Context) (interrupt Interrup
return InterruptNone, nil
}
// FireAfter event to all the callbacks with a context. The error can safely be
// FireAfter event to all the Events with a context. The error can safely be
// ignored as it is logged.
func (c *Callbacks) FireAfter(e Event, ctx context.Context) (err error) {
callbacks := c.after[e]
for _, fn := range callbacks {
func (c *Events) FireAfter(ctx context.Context, e Event) (err error) {
Events := c.after[e]
for _, fn := range Events {
if err = fn(ctx); err != nil {
return err
}

191
events_test.go Normal file
View File

@ -0,0 +1,191 @@
package authboss
import (
"context"
"errors"
"testing"
)
func TestEvents(t *testing.T) {
t.Parallel()
ab := New()
afterCalled := false
beforeCalled := false
ab.Events.Before(EventRegister, func(ctx context.Context) (Interrupt, error) {
beforeCalled = true
return InterruptNone, nil
})
ab.Events.After(EventRegister, func(ctx context.Context) error {
afterCalled = true
return nil
})
if beforeCalled || afterCalled {
t.Error("Neither should be called.")
}
interrupt, err := ab.Events.FireBefore(context.Background(), EventRegister)
if err != nil {
t.Error("Unexpected error:", err)
}
if interrupt != InterruptNone {
t.Error("It should not have been stopped.")
}
if !beforeCalled {
t.Error("Expected before to have been called.")
}
if afterCalled {
t.Error("Expected after not to be called.")
}
ab.Events.FireAfter(context.Background(), EventRegister)
if !afterCalled {
t.Error("Expected after to be called.")
}
}
func TestCallbacksInterrupt(t *testing.T) {
t.Parallel()
ev := NewEvents()
before1 := false
before2 := false
ev.Before(EventRegister, func(ctx context.Context) (Interrupt, error) {
before1 = true
return InterruptAccountLocked, nil
})
ev.Before(EventRegister, func(ctx context.Context) (Interrupt, error) {
before2 = true
return InterruptNone, nil
})
interrupt, err := ev.FireBefore(context.Background(), EventRegister)
if err != nil {
t.Error(err)
}
if interrupt != InterruptAccountLocked {
t.Error("The interrupt signal was not account locked:", interrupt)
}
if !before1 {
t.Error("Before1 should have been called.")
}
if before2 {
t.Error("Before2 should not have been called.")
}
}
func TestCallbacksBeforeErrors(t *testing.T) {
t.Parallel()
ev := NewEvents()
before1 := false
before2 := false
errValue := errors.New("Problem occured")
ev.Before(EventRegister, func(ctx context.Context) (Interrupt, error) {
before1 = true
return InterruptNone, errValue
})
ev.Before(EventRegister, func(ctx context.Context) (Interrupt, error) {
before2 = true
return InterruptNone, nil
})
interrupt, err := ev.FireBefore(context.Background(), EventRegister)
if err != errValue {
t.Error("Expected an error to come back.")
}
if interrupt != InterruptNone {
t.Error("It should not have been stopped.")
}
if !before1 {
t.Error("Before1 should have been called.")
}
if before2 {
t.Error("Before2 should not have been called.")
}
}
func TestCallbacksAfterErrors(t *testing.T) {
t.Parallel()
ev := NewEvents()
after1 := false
after2 := false
errValue := errors.New("Problem occured")
ev.After(EventRegister, func(ctx context.Context) error {
after1 = true
return errValue
})
ev.After(EventRegister, func(ctx context.Context) error {
after2 = true
return nil
})
err := ev.FireAfter(context.Background(), EventRegister)
if err != errValue {
t.Error("Expected an error to come back.")
}
if !after1 {
t.Error("After1 should have been called.")
}
if after2 {
t.Error("After2 should not have been called.")
}
}
func TestEventString(t *testing.T) {
t.Parallel()
tests := []struct {
ev Event
str string
}{
{EventRegister, "EventRegister"},
{EventAuth, "EventAuth"},
{EventOAuth, "EventOAuth"},
{EventAuthFail, "EventAuthFail"},
{EventOAuthFail, "EventOAuthFail"},
{EventRecoverStart, "EventRecoverStart"},
{EventRecoverEnd, "EventRecoverEnd"},
{EventGetUser, "EventGetUser"},
{EventGetUserSession, "EventGetUserSession"},
{EventPasswordReset, "EventPasswordReset"},
}
for i, test := range tests {
if got := test.ev.String(); got != test.str {
t.Errorf("%d) Wrong string for Event(%d) expected: %v got: %s", i, test.ev, test.str, got)
}
}
}
func TestInterruptString(t *testing.T) {
t.Parallel()
tests := []struct {
in Interrupt
str string
}{
{InterruptNone, "InterruptNone"},
{InterruptAccountLocked, "InterruptAccountLocked"},
{InterruptAccountNotConfirmed, "InterruptAccountNotConfirmed"},
{InterruptSessionExpired, "InterruptSessionExpired"},
}
for i, test := range tests {
if got := test.in.String(); got != test.str {
t.Errorf("%d) Wrong string for Event(%d) expected: %v got: %s", i, test.in, test.str, got)
}
}
}

View File

@ -37,13 +37,13 @@ func (l *Lock) Initialize(ab *authboss.Authboss) error {
}
// Events
l.Callbacks.After(authboss.EventGetUser, func(ctx *authboss.Context) error {
l.Events.After(authboss.EventGetUser, func(ctx *authboss.Context) error {
_, err := l.beforeAuth(ctx)
return err
})
l.Callbacks.Before(authboss.EventAuth, l.beforeAuth)
l.Callbacks.After(authboss.EventAuth, l.afterAuth)
l.Callbacks.After(authboss.EventAuthFail, l.afterAuthFail)
l.Events.Before(authboss.EventAuth, l.beforeAuth)
l.Events.After(authboss.EventAuth, l.afterAuth)
l.Events.After(authboss.EventAuthFail, l.afterAuthFail)
return nil
}

View File

@ -134,7 +134,7 @@ func (o *OAuth2) oauthCallback(ctx *authboss.Context, w http.ResponseWriter, r *
}
sessValues, ok := ctx.SessionStorer.Get(authboss.SessionOAuth2Params)
// Don't delete this value from session immediately, callbacks use this too
// Don't delete this value from session immediately, Events use this too
var values map[string]string
if ok {
if err := json.Unmarshal([]byte(sessValues), &values); err != nil {
@ -144,7 +144,7 @@ func (o *OAuth2) oauthCallback(ctx *authboss.Context, w http.ResponseWriter, r *
hasErr := r.FormValue("error")
if len(hasErr) > 0 {
if err := o.Callbacks.FireAfter(authboss.EventOAuthFail, ctx); err != nil {
if err := o.Events.FireAfter(authboss.EventOAuthFail, ctx); err != nil {
return err
}
@ -201,7 +201,7 @@ func (o *OAuth2) oauthCallback(ctx *authboss.Context, w http.ResponseWriter, r *
ctx.SessionStorer.Put(authboss.SessionKey, fmt.Sprintf("%s;%s", uid, provider))
ctx.SessionStorer.Del(authboss.SessionHalfAuthKey)
if err = o.Callbacks.FireAfter(authboss.EventOAuth, ctx); err != nil {
if err = o.Events.FireAfter(authboss.EventOAuth, ctx); err != nil {
return nil
}

View File

@ -273,7 +273,7 @@ func (r *Recover) completeHandlerFunc(ctx *authboss.Context, w http.ResponseWrit
return err
}
if err := r.Callbacks.FireAfter(authboss.EventPasswordReset, ctx); err != nil {
if err := r.Events.FireAfter(authboss.EventPasswordReset, ctx); err != nil {
return err
}

View File

@ -424,8 +424,8 @@ func TestRecover_completeHandlerFunc_POST(t *testing.T) {
cbCalled := false
rec.Callbacks = authboss.NewCallbacks()
rec.Callbacks.After(authboss.EventPasswordReset, func(_ *authboss.Context) error {
rec.Events = authboss.NewCallbacks()
rec.Events.After(authboss.EventPasswordReset, func(_ *authboss.Context) error {
cbCalled = true
return nil
})

View File

@ -141,7 +141,7 @@ func (reg *Register) registerPostHandler(ctx *authboss.Context, w http.ResponseW
return err
}
if err := reg.Callbacks.FireAfter(authboss.EventRegister, ctx); err != nil {
if err := reg.Events.FireAfter(authboss.EventRegister, ctx); err != nil {
return err
}

View File

@ -61,10 +61,10 @@ func (r *Remember) Initialize(ab *authboss.Authboss) error {
return errors.New("need a rememberStorer")
}
r.Callbacks.Before(authboss.EventGetUserSession, r.auth)
r.Callbacks.After(authboss.EventAuth, r.afterAuth)
r.Callbacks.After(authboss.EventOAuth, r.afterOAuth)
r.Callbacks.After(authboss.EventPasswordReset, r.afterPassword)
r.Events.Before(authboss.EventGetUserSession, r.auth)
r.Events.After(authboss.EventAuth, r.afterAuth)
r.Events.After(authboss.EventOAuth, r.afterOAuth)
r.Events.After(authboss.EventPasswordReset, r.afterPassword)
return nil
}

View File

@ -1,37 +1,27 @@
// generated by stringer -output stringers.go -type Event,Interrupt; DO NOT EDIT
// Code generated by "stringer -output stringers.go -type Event,Interrupt"; DO NOT EDIT.
package authboss
import "fmt"
import "strconv"
const _Event_name = "EventRegisterEventAuthEventOAuthEventAuthFailEventOAuthFailEventRecoverStartEventRecoverEndEventGetUserEventGetUserSessionEventPasswordReset"
var _Event_index = [...]uint8{13, 22, 32, 45, 59, 76, 91, 103, 122, 140}
var _Event_index = [...]uint8{0, 13, 22, 32, 45, 59, 76, 91, 103, 122, 140}
func (i Event) String() string {
if i < 0 || i >= Event(len(_Event_index)) {
return fmt.Sprintf("Event(%d)", i)
if i < 0 || i >= Event(len(_Event_index)-1) {
return "Event(" + strconv.FormatInt(int64(i), 10) + ")"
}
hi := _Event_index[i]
lo := uint8(0)
if i > 0 {
lo = _Event_index[i-1]
}
return _Event_name[lo:hi]
return _Event_name[_Event_index[i]:_Event_index[i+1]]
}
const _Interrupt_name = "InterruptNoneInterruptAccountLockedInterruptAccountNotConfirmedInterruptSessionExpired"
var _Interrupt_index = [...]uint8{13, 35, 63, 86}
var _Interrupt_index = [...]uint8{0, 13, 35, 63, 86}
func (i Interrupt) String() string {
if i < 0 || i >= Interrupt(len(_Interrupt_index)) {
return fmt.Sprintf("Interrupt(%d)", i)
if i < 0 || i >= Interrupt(len(_Interrupt_index)-1) {
return "Interrupt(" + strconv.FormatInt(int64(i), 10) + ")"
}
hi := _Interrupt_index[i]
lo := uint8(0)
if i > 0 {
lo = _Interrupt_index[i-1]
}
return _Interrupt_name[lo:hi]
return _Interrupt_name[_Interrupt_index[i]:_Interrupt_index[i+1]]
}