diff --git a/authboss.go b/authboss.go index 7c8f4b5..283882c 100644 --- a/authboss.go +++ b/authboss.go @@ -83,21 +83,31 @@ func (a *Authboss) currentUser(ctx *Context, w http.ResponseWriter, r *http.Requ return nil, nil } - err = ctx.LoadUser(key) + _, err = a.Callbacks.FireBefore(EventGetUser, ctx) if err != nil { return nil, err } - _, err = a.Callbacks.FireBefore(EventGet, ctx) - if err != nil { - return nil, err - } + var user interface{} if index := strings.IndexByte(key, ';'); index > 0 { - return a.OAuth2Storer.GetOAuth(key[:index], key[index+1:]) + user, err = a.OAuth2Storer.GetOAuth(key[:index], key[index+1:]) + } else { + user, err = a.Storer.Get(key) } - return a.Storer.Get(key) + if err != nil { + return nil, err + } + + ctx.User = Unbind(user) + + err = a.Callbacks.FireAfter(EventGetUser, ctx) + if err != nil { + return nil, err + } + + return user, err } // CurrentUserP retrieves the current user but panics if it's not available for diff --git a/authboss_test.go b/authboss_test.go index 6274961..00251a6 100644 --- a/authboss_test.go +++ b/authboss_test.go @@ -48,6 +48,52 @@ func TestAuthBossCurrentUser(t *testing.T) { } } +func TestAuthBossCurrentUserCallbacks(t *testing.T) { + t.Parallel() + + ab := New() + ab.LogWriter = ioutil.Discard + ab.Storer = mockStorer{"joe": Attributes{"email": "john@john.com", "password": "lies"}} + ab.SessionStoreMaker = func(_ http.ResponseWriter, _ *http.Request) ClientStorer { + return mockClientStore{SessionKey: "joe"} + } + ab.CookieStoreMaker = func(_ http.ResponseWriter, _ *http.Request) ClientStorer { + return mockClientStore{} + } + + if err := ab.Init(); err != nil { + t.Error("Unexpected error:", err) + } + + rec := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "localhost", nil) + + afterGetUser := errors.New("afterGetUser") + beforeGetUser := errors.New("beforeGetUser") + beforeGetUserSession := errors.New("beforeGetUserSession") + + ab.Callbacks.After(EventGetUser, func(*Context) error { + return afterGetUser + }) + if _, err := ab.CurrentUser(rec, req); err != afterGetUser { + t.Error("Want:", afterGetUser, "Got:", err) + } + + ab.Callbacks.Before(EventGetUser, func(*Context) (Interrupt, error) { + return InterruptNone, beforeGetUser + }) + if _, err := ab.CurrentUser(rec, req); err != beforeGetUser { + t.Error("Want:", beforeGetUser, "Got:", err) + } + + ab.Callbacks.Before(EventGetUserSession, func(*Context) (Interrupt, error) { + return InterruptNone, beforeGetUserSession + }) + if _, err := ab.CurrentUser(rec, req); err != beforeGetUserSession { + t.Error("Want:", beforeGetUserSession, "Got:", err) + } +} + func TestAuthbossUpdatePassword(t *testing.T) { t.Parallel() diff --git a/callbacks.go b/callbacks.go index 20a2991..946d4e0 100644 --- a/callbacks.go +++ b/callbacks.go @@ -20,7 +20,7 @@ const ( EventOAuthFail EventRecoverStart EventRecoverEnd - EventGet + EventGetUser EventGetUserSession EventPasswordReset ) diff --git a/callbacks_test.go b/callbacks_test.go index 4a015e6..84fddfa 100644 --- a/callbacks_test.go +++ b/callbacks_test.go @@ -171,7 +171,7 @@ func TestEventString(t *testing.T) { {EventOAuthFail, "EventOAuthFail"}, {EventRecoverStart, "EventRecoverStart"}, {EventRecoverEnd, "EventRecoverEnd"}, - {EventGet, "EventGet"}, + {EventGetUser, "EventGetUser"}, {EventGetUserSession, "EventGetUserSession"}, {EventPasswordReset, "EventPasswordReset"}, } diff --git a/confirm/confirm.go b/confirm/confirm.go index d9b98f0..acf4097 100644 --- a/confirm/confirm.go +++ b/confirm/confirm.go @@ -70,7 +70,10 @@ func (c *Confirm) Initialize(ab *authboss.Authboss) (err error) { return err } - c.Callbacks.Before(authboss.EventGet, c.beforeGet) + c.Callbacks.After(authboss.EventGetUser, func(ctx *authboss.Context) error { + _, err := c.beforeGet(ctx) + return err + }) c.Callbacks.Before(authboss.EventAuth, c.beforeGet) c.Callbacks.After(authboss.EventRegister, c.afterRegister) diff --git a/lock/lock.go b/lock/lock.go index 2517f3b..59da114 100644 --- a/lock/lock.go +++ b/lock/lock.go @@ -36,7 +36,10 @@ func (l *Lock) Initialize(ab *authboss.Authboss) error { } // Events - l.Callbacks.Before(authboss.EventGet, l.beforeAuth) + l.Callbacks.After(authboss.EventGetUser, func(ctx *authboss.Context) error { + _, err := l.beforeAuth(ctx) + return err + }) l.Callbacks.Before(authboss.EventAuth, l.beforeAuth) l.Callbacks.After(authboss.EventAuth, l.afterAuth) l.Callbacks.After(authboss.EventAuthFail, l.afterAuthFail) diff --git a/stringers.go b/stringers.go index 51f2fcd..0e8ae28 100644 --- a/stringers.go +++ b/stringers.go @@ -4,9 +4,9 @@ package authboss import "fmt" -const _Event_name = "EventRegisterEventAuthEventOAuthEventAuthFailEventOAuthFailEventRecoverStartEventRecoverEndEventGetEventGetUserSessionEventPasswordReset" +const _Event_name = "EventRegisterEventAuthEventOAuthEventAuthFailEventOAuthFailEventRecoverStartEventRecoverEndEventGetUserEventGetUserSessionEventPasswordReset" -var _Event_index = [...]uint8{13, 22, 32, 45, 59, 76, 91, 99, 118, 136} +var _Event_index = [...]uint8{13, 22, 32, 45, 59, 76, 91, 103, 122, 140} func (i Event) String() string { if i < 0 || i >= Event(len(_Event_index)) {