1
0
mirror of https://github.com/volatiletech/authboss.git synced 2025-03-19 21:28:00 +02:00

Change the way callbacks work AGAIN.

This commit is contained in:
Aaron 2015-02-21 21:03:03 -08:00
parent 80e40e7817
commit 1ca6eb1cc0
2 changed files with 103 additions and 27 deletions

View File

@ -9,7 +9,7 @@ import (
// Event is used for callback registration. // Event is used for callback registration.
type Event int type Event int
// These are the events that are available for use. // Event values
const ( const (
EventRegister Event = iota EventRegister Event = iota
EventAuth EventAuth
@ -19,11 +19,49 @@ const (
EventGet EventGet
) )
const eventNames = "EventRegisterEventAuthEventAuthFailEventRecoverStartEventRecoverEndEventGet"
var eventIndexes = [...]uint8{0, 13, 22, 35, 52, 67, 75}
func (i Event) String() string {
if i < 0 || i+1 >= Event(len(eventIndexes)) {
return fmt.Sprintf("Event(%d)", i)
}
return eventNames[eventIndexes[i]:eventIndexes[i+1]]
}
// Interrupt is used to signal to callback mechanisms
// that the current process should not continue.
type Interrupt int
// Interrupt values
const (
// InterruptNone means there was no interrupt present and the process should continue.
InterruptNone Interrupt = iota
// InterruptAccountLocked occurs if a user's account has been locked
// by the lock module.
InterruptAccountLocked
// InterruptAccountNotConfirmed occurs if a user's account is not confirmed
// and therefore cannot be used yet.
InterruptAccountNotConfirmed
)
const interruptNames = "InterruptNoneInterruptAccountLockedInterruptAccountNotConfirmed"
var interruptIndexes = [...]uint8{0, 13, 35, 63}
func (i Interrupt) String() string {
if i < 0 || i+1 >= Interrupt(len(interruptIndexes)) {
return fmt.Sprintf("Interrupt(%d)", i)
}
return interruptNames[interruptIndexes[i]:interruptIndexes[i+1]]
}
// Before callbacks can interrupt the flow by returning a bool. This is used to stop // Before callbacks can interrupt the flow by returning a bool. This is used to stop
// the callback chain and the original handler from continuing execution. // the callback chain and the original handler from continuing execution.
// The execution should also stopped if there is an error (and therefore if error is set // The execution should also stopped if there is an error (and therefore if error is set
// the bool is automatically considered set). // the bool is automatically considered set).
type Before func(*Context) (bool, error) type Before func(*Context) (Interrupt, error)
// After is a request callback that happens after the event. // After is a request callback that happens after the event.
type After func(*Context) error type After func(*Context) error
@ -59,22 +97,23 @@ func (c *Callbacks) After(e Event, f After) {
// FireBefore event to all the callbacks with a context. The error // FireBefore event to all the callbacks with a context. The error
// should be passed up despite being logged once here already so it // should be passed up despite being logged once here already so it
// can write an error out to the HTTP Client. If err is nil then // can write an error out to the HTTP Client. If err is nil then
// check the value of interrupted. If interrupted is false // check the value of interrupted. If error is nil then the interrupt
// then the handler can continue it's task otherwise it must stop. // value should be checked. If it is not InterruptNone then there is a reason
func (c *Callbacks) FireBefore(e Event, ctx *Context) (interrupted bool, err error) { // the current process should stop it's course of action.
func (c *Callbacks) FireBefore(e Event, ctx *Context) (interrupt Interrupt, err error) {
callbacks := c.before[e] callbacks := c.before[e]
for _, fn := range callbacks { for _, fn := range callbacks {
interrupted, err = fn(ctx) interrupt, err = fn(ctx)
if err != nil { if err != nil {
fmt.Fprintf(Cfg.LogWriter, "Callback error (%s): %v\n", runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name(), err) fmt.Fprintf(Cfg.LogWriter, "Callback error (%s): %v\n", runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name(), err)
return false, err return InterruptNone, err
} }
if interrupted { if interrupt != InterruptNone {
return true, nil return interrupt, nil
} }
} }
return false, nil return InterruptNone, nil
} }
// FireAfter event to all the callbacks with a context. The error can safely be // FireAfter event to all the callbacks with a context. The error can safely be

View File

@ -12,9 +12,9 @@ func TestCallbacks(t *testing.T) {
beforeCalled := false beforeCalled := false
c := NewCallbacks() c := NewCallbacks()
c.Before(EventRegister, func(ctx *Context) (bool, error) { c.Before(EventRegister, func(ctx *Context) (Interrupt, error) {
beforeCalled = true beforeCalled = true
return false, nil return InterruptNone, nil
}) })
c.After(EventRegister, func(ctx *Context) error { c.After(EventRegister, func(ctx *Context) error {
afterCalled = true afterCalled = true
@ -25,11 +25,11 @@ func TestCallbacks(t *testing.T) {
t.Error("Neither should be called.") t.Error("Neither should be called.")
} }
stopped, err := c.FireBefore(EventRegister, NewContext()) interrupt, err := c.FireBefore(EventRegister, NewContext())
if err != nil { if err != nil {
t.Error("Unexpected error:", err) t.Error("Unexpected error:", err)
} }
if stopped { if interrupt != InterruptNone {
t.Error("It should not have been stopped.") t.Error("It should not have been stopped.")
} }
@ -51,21 +51,21 @@ func TestCallbacksInterrupt(t *testing.T) {
before2 := false before2 := false
c := NewCallbacks() c := NewCallbacks()
c.Before(EventRegister, func(ctx *Context) (bool, error) { c.Before(EventRegister, func(ctx *Context) (Interrupt, error) {
before1 = true before1 = true
return true, nil return InterruptAccountLocked, nil
}) })
c.Before(EventRegister, func(ctx *Context) (bool, error) { c.Before(EventRegister, func(ctx *Context) (Interrupt, error) {
before2 = true before2 = true
return false, nil return InterruptNone, nil
}) })
stopped, err := c.FireBefore(EventRegister, NewContext()) interrupt, err := c.FireBefore(EventRegister, NewContext())
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
if !stopped { if interrupt != InterruptAccountLocked {
t.Error("It was not stopped.") t.Error("The interrupt signal was not account locked:", interrupt)
} }
if !before1 { if !before1 {
@ -87,20 +87,20 @@ func TestCallbacksBeforeErrors(t *testing.T) {
errValue := errors.New("Problem occured") errValue := errors.New("Problem occured")
c.Before(EventRegister, func(ctx *Context) (bool, error) { c.Before(EventRegister, func(ctx *Context) (Interrupt, error) {
before1 = true before1 = true
return false, errValue return InterruptNone, errValue
}) })
c.Before(EventRegister, func(ctx *Context) (bool, error) { c.Before(EventRegister, func(ctx *Context) (Interrupt, error) {
before2 = true before2 = true
return false, nil return InterruptNone, nil
}) })
stopped, err := c.FireBefore(EventRegister, NewContext()) interrupt, err := c.FireBefore(EventRegister, NewContext())
if err != errValue { if err != errValue {
t.Error("Expected an error to come back.") t.Error("Expected an error to come back.")
} }
if stopped { if interrupt != InterruptNone {
t.Error("It should not have been stopped.") t.Error("It should not have been stopped.")
} }
@ -152,3 +152,40 @@ func TestCallbacksAfterErrors(t *testing.T) {
t.Error("Error string wrong:", estr) t.Error("Error string wrong:", estr)
} }
} }
func TestEventString(t *testing.T) {
tests := []struct {
ev Event
str string
}{
{EventRegister, "EventRegister"},
{EventAuth, "EventAuth"},
{EventAuthFail, "EventAuthFail"},
{EventRecoverStart, "EventRecoverStart"},
{EventRecoverEnd, "EventRecoverEnd"},
{EventGet, "EventGet"},
}
for i, test := range tests {
if got := test.ev.String(); got != test.str {
t.Errorf("%d) Wrong string for Event(%d) expected: %v got: %s", i, test.ev, test.str, got)
}
}
}
func TestInterruptString(t *testing.T) {
tests := []struct {
in Interrupt
str string
}{
{InterruptNone, "InterruptNone"},
{InterruptAccountLocked, "InterruptAccountLocked"},
{InterruptAccountNotConfirmed, "InterruptAccountNotConfirmed"},
}
for i, test := range tests {
if got := test.in.String(); got != test.str {
t.Errorf("%d) Wrong string for Event(%d) expected: %v got: %s", i, test.in, test.str, got)
}
}
}