1
0
mirror of https://github.com/volatiletech/authboss.git synced 2025-09-16 09:06:20 +02:00

Fix modules after refactor.

This commit is contained in:
Aaron
2015-03-31 15:27:47 -07:00
committed by Aaron L
parent 9ff0b65629
commit c98ef93e06
17 changed files with 550 additions and 414 deletions

View File

@@ -24,24 +24,27 @@ func init() {
// Auth module // Auth module
type Auth struct { type Auth struct {
*authboss.Authboss
templates response.Templates templates response.Templates
} }
// Initialize module // Initialize module
func (a *Auth) Initialize() (err error) { func (a *Auth) Initialize(ab *authboss.Authboss) (err error) {
if authboss.a.Storer == nil { a.Authboss = ab
if a.Storer == nil {
return errors.New("auth: Need a Storer") return errors.New("auth: Need a Storer")
} }
if len(authboss.a.XSRFName) == 0 { if len(a.XSRFName) == 0 {
return errors.New("auth: XSRFName must be set") return errors.New("auth: XSRFName must be set")
} }
if authboss.a.XSRFMaker == nil { if a.XSRFMaker == nil {
return errors.New("auth: XSRFMaker must be defined") return errors.New("auth: XSRFMaker must be defined")
} }
a.templates, err = response.LoadTemplates(authboss.a.Layout, authboss.a.ViewsPath, tplLogin) a.templates, err = response.LoadTemplates(a.Authboss, a.Layout, a.ViewsPath, tplLogin)
if err != nil { if err != nil {
return err return err
} }
@@ -60,7 +63,7 @@ func (a *Auth) Routes() authboss.RouteTable {
// Storage requirements // Storage requirements
func (a *Auth) Storage() authboss.StorageOptions { func (a *Auth) Storage() authboss.StorageOptions {
return authboss.StorageOptions{ return authboss.StorageOptions{
authboss.a.PrimaryID: authboss.String, a.PrimaryID: authboss.String,
authboss.StorePassword: authboss.String, authboss.StorePassword: authboss.String,
} }
} }
@@ -70,32 +73,32 @@ func (a *Auth) loginHandlerFunc(ctx *authboss.Context, w http.ResponseWriter, r
case methodGET: case methodGET:
if _, ok := ctx.SessionStorer.Get(authboss.SessionKey); ok { if _, ok := ctx.SessionStorer.Get(authboss.SessionKey); ok {
if halfAuthed, ok := ctx.SessionStorer.Get(authboss.SessionHalfAuthKey); !ok || halfAuthed == "false" { if halfAuthed, ok := ctx.SessionStorer.Get(authboss.SessionHalfAuthKey); !ok || halfAuthed == "false" {
//http.Redirect(w, r, authboss.a.AuthLoginOKPath, http.StatusFound, true) //http.Redirect(w, r, a.AuthLoginOKPath, http.StatusFound, true)
response.Redirect(ctx, w, r, authboss.a.AuthLoginOKPath, "", "", true) response.Redirect(ctx, w, r, a.AuthLoginOKPath, "", "", true)
return nil return nil
} }
} }
data := authboss.NewHTMLData( data := authboss.NewHTMLData(
"showRemember", authboss.IsLoaded("remember"), "showRemember", a.IsLoaded("remember"),
"showRecover", authboss.IsLoaded("recover"), "showRecover", a.IsLoaded("recover"),
"primaryID", authboss.a.PrimaryID, "primaryID", a.PrimaryID,
"primaryIDValue", "", "primaryIDValue", "",
) )
return a.templates.Render(ctx, w, r, tplLogin, data) return a.templates.Render(ctx, w, r, tplLogin, data)
case methodPOST: case methodPOST:
key, _ := ctx.FirstPostFormValue(authboss.a.PrimaryID) key, _ := ctx.FirstPostFormValue(a.PrimaryID)
password, _ := ctx.FirstPostFormValue("password") password, _ := ctx.FirstPostFormValue("password")
errData := authboss.NewHTMLData( errData := authboss.NewHTMLData(
"error", fmt.Sprintf("invalid %s and/or password", authboss.a.PrimaryID), "error", fmt.Sprintf("invalid %s and/or password", a.PrimaryID),
"primaryID", authboss.a.PrimaryID, "primaryID", a.PrimaryID,
"primaryIDValue", key, "primaryIDValue", key,
"showRemember", authboss.IsLoaded("remember"), "showRemember", a.IsLoaded("remember"),
"showRecover", authboss.IsLoaded("recover"), "showRecover", a.IsLoaded("recover"),
) )
policies := authboss.FilterValidators(authboss.a.Policies, authboss.a.PrimaryID, authboss.StorePassword) policies := authboss.FilterValidators(a.Policies, a.PrimaryID, authboss.StorePassword)
if validationErrs := ctx.Validate(policies); len(validationErrs) > 0 { if validationErrs := ctx.Validate(policies); len(validationErrs) > 0 {
return a.templates.Render(ctx, w, r, tplLogin, errData) return a.templates.Render(ctx, w, r, tplLogin, errData)
} }
@@ -104,7 +107,7 @@ func (a *Auth) loginHandlerFunc(ctx *authboss.Context, w http.ResponseWriter, r
return a.templates.Render(ctx, w, r, tplLogin, errData) return a.templates.Render(ctx, w, r, tplLogin, errData)
} }
interrupted, err := authboss.a.Callbacks.FireBefore(authboss.EventAuth, ctx) interrupted, err := a.Callbacks.FireBefore(authboss.EventAuth, ctx)
if err != nil { if err != nil {
return err return err
} else if interrupted != authboss.InterruptNone { } else if interrupted != authboss.InterruptNone {
@@ -115,17 +118,17 @@ func (a *Auth) loginHandlerFunc(ctx *authboss.Context, w http.ResponseWriter, r
case authboss.InterruptAccountNotConfirmed: case authboss.InterruptAccountNotConfirmed:
reason = "Your account has not been confirmed." reason = "Your account has not been confirmed."
} }
response.Redirect(ctx, w, r, authboss.a.AuthLoginFailPath, "", reason, false) response.Redirect(ctx, w, r, a.AuthLoginFailPath, "", reason, false)
return nil return nil
} }
ctx.SessionStorer.Put(authboss.SessionKey, key) ctx.SessionStorer.Put(authboss.SessionKey, key)
ctx.SessionStorer.Del(authboss.SessionHalfAuthKey) ctx.SessionStorer.Del(authboss.SessionHalfAuthKey)
if err := authboss.a.Callbacks.FireAfter(authboss.EventAuth, ctx); err != nil { if err := a.Callbacks.FireAfter(authboss.EventAuth, ctx); err != nil {
return err return err
} }
response.Redirect(ctx, w, r, authboss.a.AuthLoginOKPath, "", "", true) response.Redirect(ctx, w, r, a.AuthLoginOKPath, "", "", true)
default: default:
w.WriteHeader(http.StatusMethodNotAllowed) w.WriteHeader(http.StatusMethodNotAllowed)
} }
@@ -157,7 +160,7 @@ func (a *Auth) logoutHandlerFunc(ctx *authboss.Context, w http.ResponseWriter, r
ctx.CookieStorer.Del(authboss.CookieRemember) ctx.CookieStorer.Del(authboss.CookieRemember)
ctx.SessionStorer.Del(authboss.SessionLastAction) ctx.SessionStorer.Del(authboss.SessionLastAction)
response.Redirect(ctx, w, r, authboss.a.AuthLogoutOKPath, "You have logged out", "", true) response.Redirect(ctx, w, r, a.AuthLogoutOKPath, "You have logged out", "", true)
default: default:
w.WriteHeader(http.StatusMethodNotAllowed) w.WriteHeader(http.StatusMethodNotAllowed)
} }

View File

@@ -16,43 +16,45 @@ import (
func testSetup() (a *Auth, s *mocks.MockStorer) { func testSetup() (a *Auth, s *mocks.MockStorer) {
s = mocks.NewMockStorer() s = mocks.NewMockStorer()
authboss.Cfg = authboss.NewConfig() ab := authboss.New()
authboss.a.LogWriter = ioutil.Discard ab.LogWriter = ioutil.Discard
authboss.a.Layout = template.Must(template.New("").Parse(`{{template "authboss" .}}`)) ab.Layout = template.Must(template.New("").Parse(`{{template "authboss" .}}`))
authboss.a.Storer = s ab.Storer = s
authboss.a.XSRFName = "xsrf" ab.XSRFName = "xsrf"
authboss.a.XSRFMaker = func(_ http.ResponseWriter, _ *http.Request) string { ab.XSRFMaker = func(_ http.ResponseWriter, _ *http.Request) string {
return "xsrfvalue" return "xsrfvalue"
} }
authboss.a.PrimaryID = authboss.StoreUsername ab.PrimaryID = authboss.StoreUsername
a = &Auth{} a = &Auth{}
if err := a.Initialize(); err != nil { if err := a.Initialize(ab); err != nil {
panic(err) panic(err)
} }
return a, s return a, s
} }
func testRequest(method string, postFormValues ...string) (*authboss.Context, *httptest.ResponseRecorder, *http.Request, authboss.ClientStorerErr) { func testRequest(ab *authboss.Authboss, method string, postFormValues ...string) (*authboss.Context, *httptest.ResponseRecorder, *http.Request, authboss.ClientStorerErr) {
r, err := http.NewRequest(method, "", nil) r, err := http.NewRequest(method, "", nil)
if err != nil { if err != nil {
panic(err) panic(err)
} }
sessionStorer := mocks.NewMockClientStorer() sessionStorer := mocks.NewMockClientStorer()
ctx := mocks.MockRequestContext(postFormValues...) ctx := mocks.MockRequestContext(ab, postFormValues...)
ctx.SessionStorer = sessionStorer ctx.SessionStorer = sessionStorer
return ctx, httptest.NewRecorder(), r, sessionStorer return ctx, httptest.NewRecorder(), r, sessionStorer
} }
func TestAuth(t *testing.T) { func TestAuth(t *testing.T) {
t.Parallel()
a, _ := testSetup() a, _ := testSetup()
storage := a.Storage() storage := a.Storage()
if storage[authboss.a.PrimaryID] != authboss.String { if storage[a.PrimaryID] != authboss.String {
t.Error("Expected storage KV:", authboss.a.PrimaryID, authboss.String) t.Error("Expected storage KV:", a.PrimaryID, authboss.String)
} }
if storage[authboss.StorePassword] != authboss.String { if storage[authboss.StorePassword] != authboss.String {
t.Error("Expected storage KV:", authboss.StorePassword, authboss.String) t.Error("Expected storage KV:", authboss.StorePassword, authboss.String)
@@ -68,13 +70,15 @@ func TestAuth(t *testing.T) {
} }
func TestAuth_loginHandlerFunc_GET_RedirectsWhenHalfAuthed(t *testing.T) { func TestAuth_loginHandlerFunc_GET_RedirectsWhenHalfAuthed(t *testing.T) {
t.Parallel()
a, _ := testSetup() a, _ := testSetup()
ctx, w, r, sessionStore := testRequest("GET") ctx, w, r, sessionStore := testRequest(a.Authboss, "GET")
sessionStore.Put(authboss.SessionKey, "a") sessionStore.Put(authboss.SessionKey, "a")
sessionStore.Put(authboss.SessionHalfAuthKey, "false") sessionStore.Put(authboss.SessionHalfAuthKey, "false")
authboss.a.AuthLoginOKPath = "/dashboard" a.AuthLoginOKPath = "/dashboard"
if err := a.loginHandlerFunc(ctx, w, r); err != nil { if err := a.loginHandlerFunc(ctx, w, r); err != nil {
t.Error("Unexpeced error:", err) t.Error("Unexpeced error:", err)
@@ -85,14 +89,16 @@ func TestAuth_loginHandlerFunc_GET_RedirectsWhenHalfAuthed(t *testing.T) {
} }
loc := w.Header().Get("Location") loc := w.Header().Get("Location")
if loc != authboss.a.AuthLoginOKPath { if loc != a.AuthLoginOKPath {
t.Error("Unexpected redirect:", loc) t.Error("Unexpected redirect:", loc)
} }
} }
func TestAuth_loginHandlerFunc_GET(t *testing.T) { func TestAuth_loginHandlerFunc_GET(t *testing.T) {
t.Parallel()
a, _ := testSetup() a, _ := testSetup()
ctx, w, r, _ := testRequest("GET") ctx, w, r, _ := testRequest(a.Authboss, "GET")
if err := a.loginHandlerFunc(ctx, w, r); err != nil { if err := a.loginHandlerFunc(ctx, w, r); err != nil {
t.Error("Unexpected error:", err) t.Error("Unexpected error:", err)
@@ -106,7 +112,7 @@ func TestAuth_loginHandlerFunc_GET(t *testing.T) {
if !strings.Contains(body, "<form") { if !strings.Contains(body, "<form") {
t.Error("Should have rendered a form") t.Error("Should have rendered a form")
} }
if !strings.Contains(body, `name="`+authboss.a.PrimaryID) { if !strings.Contains(body, `name="`+a.PrimaryID) {
t.Error("Form should contain the primary ID field:", body) t.Error("Form should contain the primary ID field:", body)
} }
if !strings.Contains(body, `name="password"`) { if !strings.Contains(body, `name="password"`) {
@@ -115,15 +121,17 @@ func TestAuth_loginHandlerFunc_GET(t *testing.T) {
} }
func TestAuth_loginHandlerFunc_POST_ReturnsErrorOnCallbackFailure(t *testing.T) { func TestAuth_loginHandlerFunc_POST_ReturnsErrorOnCallbackFailure(t *testing.T) {
t.Parallel()
a, storer := testSetup() a, storer := testSetup()
storer.Users["john"] = authboss.Attributes{"password": "$2a$10$B7aydtqVF9V8RSNx3lCKB.l09jqLV/aMiVqQHajtL7sWGhCS9jlOu"} storer.Users["john"] = authboss.Attributes{"password": "$2a$10$B7aydtqVF9V8RSNx3lCKB.l09jqLV/aMiVqQHajtL7sWGhCS9jlOu"}
authboss.a.Callbacks = authboss.NewCallbacks() a.Callbacks = authboss.NewCallbacks()
authboss.a.Callbacks.Before(authboss.EventAuth, func(_ *authboss.Context) (authboss.Interrupt, error) { a.Callbacks.Before(authboss.EventAuth, func(_ *authboss.Context) (authboss.Interrupt, error) {
return authboss.InterruptNone, errors.New("explode") return authboss.InterruptNone, errors.New("explode")
}) })
ctx, w, r, _ := testRequest("POST", "username", "john", "password", "1234") ctx, w, r, _ := testRequest(a.Authboss, "POST", "username", "john", "password", "1234")
if err := a.loginHandlerFunc(ctx, w, r); err.Error() != "explode" { if err := a.loginHandlerFunc(ctx, w, r); err.Error() != "explode" {
t.Error("Unexpected error:", err) t.Error("Unexpected error:", err)
@@ -131,15 +139,17 @@ func TestAuth_loginHandlerFunc_POST_ReturnsErrorOnCallbackFailure(t *testing.T)
} }
func TestAuth_loginHandlerFunc_POST_RedirectsWhenInterrupted(t *testing.T) { func TestAuth_loginHandlerFunc_POST_RedirectsWhenInterrupted(t *testing.T) {
t.Parallel()
a, storer := testSetup() a, storer := testSetup()
storer.Users["john"] = authboss.Attributes{"password": "$2a$10$B7aydtqVF9V8RSNx3lCKB.l09jqLV/aMiVqQHajtL7sWGhCS9jlOu"} storer.Users["john"] = authboss.Attributes{"password": "$2a$10$B7aydtqVF9V8RSNx3lCKB.l09jqLV/aMiVqQHajtL7sWGhCS9jlOu"}
authboss.a.Callbacks = authboss.NewCallbacks() a.Callbacks = authboss.NewCallbacks()
authboss.a.Callbacks.Before(authboss.EventAuth, func(_ *authboss.Context) (authboss.Interrupt, error) { a.Callbacks.Before(authboss.EventAuth, func(_ *authboss.Context) (authboss.Interrupt, error) {
return authboss.InterruptAccountLocked, nil return authboss.InterruptAccountLocked, nil
}) })
ctx, w, r, sessionStore := testRequest("POST", "username", "john", "password", "1234") ctx, w, r, sessionStore := testRequest(a.Authboss, "POST", "username", "john", "password", "1234")
if err := a.loginHandlerFunc(ctx, w, r); err != nil { if err := a.loginHandlerFunc(ctx, w, r); err != nil {
t.Error("Unexpected error:", err) t.Error("Unexpected error:", err)
@@ -150,7 +160,7 @@ func TestAuth_loginHandlerFunc_POST_RedirectsWhenInterrupted(t *testing.T) {
} }
loc := w.Header().Get("Location") loc := w.Header().Get("Location")
if loc != authboss.a.AuthLoginFailPath { if loc != a.AuthLoginFailPath {
t.Error("Unexpeced location:", loc) t.Error("Unexpeced location:", loc)
} }
@@ -159,8 +169,8 @@ func TestAuth_loginHandlerFunc_POST_RedirectsWhenInterrupted(t *testing.T) {
t.Error("Expected error flash message:", expectedMsg) t.Error("Expected error flash message:", expectedMsg)
} }
authboss.a.Callbacks = authboss.NewCallbacks() a.Callbacks = authboss.NewCallbacks()
authboss.a.Callbacks.Before(authboss.EventAuth, func(_ *authboss.Context) (authboss.Interrupt, error) { a.Callbacks.Before(authboss.EventAuth, func(_ *authboss.Context) (authboss.Interrupt, error) {
return authboss.InterruptAccountNotConfirmed, nil return authboss.InterruptAccountNotConfirmed, nil
}) })
@@ -173,7 +183,7 @@ func TestAuth_loginHandlerFunc_POST_RedirectsWhenInterrupted(t *testing.T) {
} }
loc = w.Header().Get("Location") loc = w.Header().Get("Location")
if loc != authboss.a.AuthLoginFailPath { if loc != a.AuthLoginFailPath {
t.Error("Unexpeced location:", loc) t.Error("Unexpeced location:", loc)
} }
@@ -184,9 +194,11 @@ func TestAuth_loginHandlerFunc_POST_RedirectsWhenInterrupted(t *testing.T) {
} }
func TestAuth_loginHandlerFunc_POST_AuthenticationFailure(t *testing.T) { func TestAuth_loginHandlerFunc_POST_AuthenticationFailure(t *testing.T) {
t.Parallel()
a, _ := testSetup() a, _ := testSetup()
ctx, w, r, _ := testRequest("POST", "username", "john", "password", "1") ctx, w, r, _ := testRequest(a.Authboss, "POST", "username", "john", "password", "1")
if err := a.loginHandlerFunc(ctx, w, r); err != nil { if err := a.loginHandlerFunc(ctx, w, r); err != nil {
t.Error("Unexpected error:", err) t.Error("Unexpected error:", err)
@@ -201,7 +213,7 @@ func TestAuth_loginHandlerFunc_POST_AuthenticationFailure(t *testing.T) {
t.Error("Should have rendered with error") t.Error("Should have rendered with error")
} }
ctx, w, r, _ = testRequest("POST", "username", "john", "password", "1234") ctx, w, r, _ = testRequest(a.Authboss, "POST", "username", "john", "password", "1234")
if err := a.loginHandlerFunc(ctx, w, r); err != nil { if err := a.loginHandlerFunc(ctx, w, r); err != nil {
t.Error("Unexpected error:", err) t.Error("Unexpected error:", err)
@@ -218,15 +230,17 @@ func TestAuth_loginHandlerFunc_POST_AuthenticationFailure(t *testing.T) {
} }
func TestAuth_loginHandlerFunc_POST(t *testing.T) { func TestAuth_loginHandlerFunc_POST(t *testing.T) {
t.Parallel()
a, storer := testSetup() a, storer := testSetup()
storer.Users["john"] = authboss.Attributes{"password": "$2a$10$B7aydtqVF9V8RSNx3lCKB.l09jqLV/aMiVqQHajtL7sWGhCS9jlOu"} storer.Users["john"] = authboss.Attributes{"password": "$2a$10$B7aydtqVF9V8RSNx3lCKB.l09jqLV/aMiVqQHajtL7sWGhCS9jlOu"}
ctx, w, r, _ := testRequest("POST", "username", "john", "password", "1234") ctx, w, r, _ := testRequest(a.Authboss, "POST", "username", "john", "password", "1234")
cb := mocks.NewMockAfterCallback() cb := mocks.NewMockAfterCallback()
authboss.a.Callbacks = authboss.NewCallbacks() a.Callbacks = authboss.NewCallbacks()
authboss.a.Callbacks.After(authboss.EventAuth, cb.Fn) a.Callbacks.After(authboss.EventAuth, cb.Fn)
authboss.a.AuthLoginOKPath = "/dashboard" a.AuthLoginOKPath = "/dashboard"
sessions := mocks.NewMockClientStorer() sessions := mocks.NewMockClientStorer()
ctx.SessionStorer = sessions ctx.SessionStorer = sessions
@@ -244,7 +258,7 @@ func TestAuth_loginHandlerFunc_POST(t *testing.T) {
} }
loc := w.Header().Get("Location") loc := w.Header().Get("Location")
if loc != authboss.a.AuthLoginOKPath { if loc != a.AuthLoginOKPath {
t.Error("Unexpeced location:", loc) t.Error("Unexpeced location:", loc)
} }
@@ -257,6 +271,8 @@ func TestAuth_loginHandlerFunc_POST(t *testing.T) {
} }
func TestAuth_loginHandlerFunc_OtherMethods(t *testing.T) { func TestAuth_loginHandlerFunc_OtherMethods(t *testing.T) {
t.Parallel()
a, _ := testSetup() a, _ := testSetup()
methods := []string{"HEAD", "PUT", "DELETE", "TRACE", "CONNECT"} methods := []string{"HEAD", "PUT", "DELETE", "TRACE", "CONNECT"}
@@ -279,35 +295,39 @@ func TestAuth_loginHandlerFunc_OtherMethods(t *testing.T) {
} }
func TestAuth_validateCredentials(t *testing.T) { func TestAuth_validateCredentials(t *testing.T) {
authboss.Cfg = authboss.NewConfig() t.Parallel()
ab := authboss.New()
storer := mocks.NewMockStorer() storer := mocks.NewMockStorer()
storer.GetErr = "Failed to load user" storer.GetErr = "Failed to load user"
authboss.a.Storer = storer ab.Storer = storer
ctx := authboss.Context{} ctx := ab.NewContext()
if err := validateCredentials(&ctx, "", ""); err.Error() != "Failed to load user" { if err := validateCredentials(ctx, "", ""); err.Error() != "Failed to load user" {
t.Error("Unexpected error:", err) t.Error("Unexpected error:", err)
} }
storer.GetErr = "" storer.GetErr = ""
storer.Users["john"] = authboss.Attributes{"password": "$2a$10$pgFsuQwdhwOdZp/v52dvHeEi53ZaI7dGmtwK4bAzGGN5A4nT6doqm"} storer.Users["john"] = authboss.Attributes{"password": "$2a$10$pgFsuQwdhwOdZp/v52dvHeEi53ZaI7dGmtwK4bAzGGN5A4nT6doqm"}
if err := validateCredentials(&ctx, "john", "b"); err == nil { if err := validateCredentials(ctx, "john", "b"); err == nil {
t.Error("Expected error about passwords mismatch") t.Error("Expected error about passwords mismatch")
} }
if err := validateCredentials(&ctx, "john", "a"); err != nil { if err := validateCredentials(ctx, "john", "a"); err != nil {
t.Error("Unexpected error:", err) t.Error("Unexpected error:", err)
} }
} }
func TestAuth_logoutHandlerFunc_GET(t *testing.T) { func TestAuth_logoutHandlerFunc_GET(t *testing.T) {
t.Parallel()
a, _ := testSetup() a, _ := testSetup()
authboss.a.AuthLogoutOKPath = "/dashboard" a.AuthLogoutOKPath = "/dashboard"
ctx, w, r, sessionStorer := testRequest("GET") ctx, w, r, sessionStorer := testRequest(a.Authboss, "GET")
sessionStorer.Put(authboss.SessionKey, "asdf") sessionStorer.Put(authboss.SessionKey, "asdf")
sessionStorer.Put(authboss.SessionLastAction, "1234") sessionStorer.Put(authboss.SessionLastAction, "1234")

View File

@@ -46,30 +46,33 @@ func init() {
// Confirm module // Confirm module
type Confirm struct { type Confirm struct {
*authboss.Authboss
emailHTMLTemplates response.Templates emailHTMLTemplates response.Templates
emailTextTemplates response.Templates emailTextTemplates response.Templates
} }
// Initialize the module // Initialize the module
func (c *Confirm) Initialize() (err error) { func (c *Confirm) Initialize(ab *authboss.Authboss) (err error) {
c.Authboss = ab
var ok bool var ok bool
storer, ok := authboss.a.Storer.(ConfirmStorer) storer, ok := c.Storer.(ConfirmStorer)
if storer == nil || !ok { if storer == nil || !ok {
return errors.New("confirm: Need a ConfirmStorer") return errors.New("confirm: Need a ConfirmStorer")
} }
c.emailHTMLTemplates, err = response.LoadTemplates(authboss.a.LayoutHTMLEmail, authboss.a.ViewsPath, tplConfirmHTML) c.emailHTMLTemplates, err = response.LoadTemplates(ab, c.LayoutHTMLEmail, c.ViewsPath, tplConfirmHTML)
if err != nil { if err != nil {
return err return err
} }
c.emailTextTemplates, err = response.LoadTemplates(authboss.a.LayoutTextEmail, authboss.a.ViewsPath, tplConfirmText) c.emailTextTemplates, err = response.LoadTemplates(ab, c.LayoutTextEmail, c.ViewsPath, tplConfirmText)
if err != nil { if err != nil {
return err return err
} }
authboss.a.Callbacks.Before(authboss.EventGet, c.beforeGet) c.Callbacks.Before(authboss.EventGet, c.beforeGet)
authboss.a.Callbacks.Before(authboss.EventAuth, c.beforeGet) c.Callbacks.Before(authboss.EventAuth, c.beforeGet)
authboss.a.Callbacks.After(authboss.EventRegister, c.afterRegister) c.Callbacks.After(authboss.EventRegister, c.afterRegister)
return nil return nil
} }
@@ -84,10 +87,10 @@ func (c *Confirm) Routes() authboss.RouteTable {
// Storage requirements // Storage requirements
func (c *Confirm) Storage() authboss.StorageOptions { func (c *Confirm) Storage() authboss.StorageOptions {
return authboss.StorageOptions{ return authboss.StorageOptions{
authboss.a.PrimaryID: authboss.String, c.PrimaryID: authboss.String,
authboss.StoreEmail: authboss.String, authboss.StoreEmail: authboss.String,
StoreConfirmToken: authboss.String, StoreConfirmToken: authboss.String,
StoreConfirmed: authboss.Bool, StoreConfirmed: authboss.Bool,
} }
} }
@@ -135,18 +138,18 @@ var goConfirmEmail = func(c *Confirm, to, token string) {
// confirmEmail sends a confirmation e-mail. // confirmEmail sends a confirmation e-mail.
func (c *Confirm) confirmEmail(to, token string) { func (c *Confirm) confirmEmail(to, token string) {
p := path.Join(authboss.a.MountPath, "confirm") p := path.Join(c.MountPath, "confirm")
url := fmt.Sprintf("%s%s?%s=%s", authboss.a.RootURL, p, url.QueryEscape(FormValueConfirm), url.QueryEscape(token)) url := fmt.Sprintf("%s%s?%s=%s", c.RootURL, p, url.QueryEscape(FormValueConfirm), url.QueryEscape(token))
email := authboss.Email{ email := authboss.Email{
To: []string{to}, To: []string{to},
From: authboss.a.EmailFrom, From: c.EmailFrom,
Subject: authboss.a.EmailSubjectPrefix + "Confirm New Account", Subject: c.EmailSubjectPrefix + "Confirm New Account",
} }
err := response.Email(email, c.emailHTMLTemplates, tplConfirmHTML, c.emailTextTemplates, tplConfirmText, url) err := response.Email(c.Mailer, email, c.emailHTMLTemplates, tplConfirmHTML, c.emailTextTemplates, tplConfirmText, url)
if err != nil { if err != nil {
fmt.Fprintf(authboss.a.LogWriter, "confirm: Failed to send e-mail: %v", err) fmt.Fprintf(c.LogWriter, "confirm: Failed to send e-mail: %v", err)
} }
} }
@@ -166,7 +169,7 @@ func (c *Confirm) confirmHandler(ctx *authboss.Context, w http.ResponseWriter, r
sum := md5.Sum(toHash) sum := md5.Sum(toHash)
dbTok := base64.StdEncoding.EncodeToString(sum[:]) dbTok := base64.StdEncoding.EncodeToString(sum[:])
user, err := authboss.a.Storer.(ConfirmStorer).ConfirmUser(dbTok) user, err := c.Storer.(ConfirmStorer).ConfirmUser(dbTok)
if err == authboss.ErrUserNotFound { if err == authboss.ErrUserNotFound {
return authboss.ErrAndRedirect{Location: "/", Err: errors.New("confirm: token not found")} return authboss.ErrAndRedirect{Location: "/", Err: errors.New("confirm: token not found")}
} else if err != nil { } else if err != nil {
@@ -178,7 +181,7 @@ func (c *Confirm) confirmHandler(ctx *authboss.Context, w http.ResponseWriter, r
ctx.User[StoreConfirmToken] = "" ctx.User[StoreConfirmToken] = ""
ctx.User[StoreConfirmed] = true ctx.User[StoreConfirmed] = true
key, err := ctx.User.StringErr(authboss.a.PrimaryID) key, err := ctx.User.StringErr(c.PrimaryID)
if err != nil { if err != nil {
return err return err
} }
@@ -188,7 +191,7 @@ func (c *Confirm) confirmHandler(ctx *authboss.Context, w http.ResponseWriter, r
} }
ctx.SessionStorer.Put(authboss.SessionKey, key) ctx.SessionStorer.Put(authboss.SessionKey, key)
response.Redirect(ctx, w, r, authboss.a.RegisterOKPath, "You have successfully confirmed your account.", "", true) response.Redirect(ctx, w, r, c.RegisterOKPath, "You have successfully confirmed your account.", "", true)
return nil return nil
} }

View File

@@ -17,22 +17,24 @@ import (
) )
func setup() *Confirm { func setup() *Confirm {
authboss.NewConfig() ab := authboss.New()
authboss.a.Storer = mocks.NewMockStorer() ab.Storer = mocks.NewMockStorer()
authboss.a.LayoutHTMLEmail = template.Must(template.New("").Parse(`email ^_^`)) ab.LayoutHTMLEmail = template.Must(template.New("").Parse(`email ^_^`))
authboss.a.LayoutTextEmail = template.Must(template.New("").Parse(`email`)) ab.LayoutTextEmail = template.Must(template.New("").Parse(`email`))
c := &Confirm{} c := &Confirm{}
if err := c.Initialize(); err != nil { if err := c.Initialize(ab); err != nil {
panic(err) panic(err)
} }
return c return c
} }
func TestConfirm_Initialize(t *testing.T) { func TestConfirm_Initialize(t *testing.T) {
authboss.NewConfig() t.Parallel()
ab := authboss.New()
c := &Confirm{} c := &Confirm{}
if err := c.Initialize(); err == nil { if err := c.Initialize(ab); err == nil {
t.Error("Should cry about not having a storer.") t.Error("Should cry about not having a storer.")
} }
@@ -58,7 +60,7 @@ func TestConfirm_Routes(t *testing.T) {
func TestConfirm_Storage(t *testing.T) { func TestConfirm_Storage(t *testing.T) {
t.Parallel() t.Parallel()
c := &Confirm{} c := &Confirm{Authboss: authboss.New()}
storage := c.Storage() storage := c.Storage()
if authboss.String != storage[StoreConfirmToken] { if authboss.String != storage[StoreConfirmToken] {
@@ -70,8 +72,10 @@ func TestConfirm_Storage(t *testing.T) {
} }
func TestConfirm_BeforeGet(t *testing.T) { func TestConfirm_BeforeGet(t *testing.T) {
t.Parallel()
c := setup() c := setup()
ctx := authboss.NewContext() ctx := c.NewContext()
if _, err := c.beforeGet(ctx); err == nil { if _, err := c.beforeGet(ctx); err == nil {
t.Error("Should stop the get due to attribute missing:", err) t.Error("Should stop the get due to attribute missing:", err)
@@ -97,12 +101,14 @@ func TestConfirm_BeforeGet(t *testing.T) {
} }
func TestConfirm_AfterRegister(t *testing.T) { func TestConfirm_AfterRegister(t *testing.T) {
t.Parallel()
c := setup() c := setup()
ctx := authboss.NewContext() ctx := c.NewContext()
log := &bytes.Buffer{} log := &bytes.Buffer{}
authboss.a.LogWriter = log c.LogWriter = log
authboss.a.Mailer = authboss.LogMailer(log) c.Mailer = authboss.LogMailer(log)
authboss.a.PrimaryID = authboss.StoreUsername c.PrimaryID = authboss.StoreUsername
sentEmail := false sentEmail := false
@@ -115,7 +121,7 @@ func TestConfirm_AfterRegister(t *testing.T) {
t.Error("Expected it to die with user error:", err) t.Error("Expected it to die with user error:", err)
} }
ctx.User = authboss.Attributes{authboss.a.PrimaryID: "username"} ctx.User = authboss.Attributes{c.PrimaryID: "username"}
if err := c.afterRegister(ctx); err == nil || err.(authboss.AttributeErr).Name != "email" { if err := c.afterRegister(ctx); err == nil || err.(authboss.AttributeErr).Name != "email" {
t.Error("Expected it to die with e-mail address error:", err) t.Error("Expected it to die with e-mail address error:", err)
} }
@@ -133,10 +139,12 @@ func TestConfirm_AfterRegister(t *testing.T) {
} }
func TestConfirm_ConfirmHandlerErrors(t *testing.T) { func TestConfirm_ConfirmHandlerErrors(t *testing.T) {
t.Parallel()
c := setup() c := setup()
log := &bytes.Buffer{} log := &bytes.Buffer{}
authboss.a.LogWriter = log c.LogWriter = log
authboss.a.Mailer = authboss.LogMailer(log) c.Mailer = authboss.LogMailer(log)
tests := []struct { tests := []struct {
URL string URL string
@@ -155,7 +163,7 @@ func TestConfirm_ConfirmHandlerErrors(t *testing.T) {
for i, test := range tests { for i, test := range tests {
r, _ := http.NewRequest("GET", test.URL, nil) r, _ := http.NewRequest("GET", test.URL, nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
ctx, _ := authboss.ContextFromRequest(r) ctx, _ := c.ContextFromRequest(r)
err := c.confirmHandler(ctx, w, r) err := c.confirmHandler(ctx, w, r)
if err == nil { if err == nil {
@@ -174,11 +182,14 @@ func TestConfirm_ConfirmHandlerErrors(t *testing.T) {
} }
func TestConfirm_Confirm(t *testing.T) { func TestConfirm_Confirm(t *testing.T) {
t.Parallel()
c := setup() c := setup()
ctx := authboss.NewContext() ctx := c.NewContext()
log := &bytes.Buffer{} log := &bytes.Buffer{}
authboss.a.LogWriter = log c.LogWriter = log
authboss.a.Mailer = authboss.LogMailer(log) c.PrimaryID = authboss.StoreUsername
c.Mailer = authboss.LogMailer(log)
// Create a token // Create a token
token := []byte("hi") token := []byte("hi")
@@ -186,7 +197,7 @@ func TestConfirm_Confirm(t *testing.T) {
// Create the "database" // Create the "database"
storer := mocks.NewMockStorer() storer := mocks.NewMockStorer()
authboss.a.Storer = storer c.Storer = storer
user := authboss.Attributes{ user := authboss.Attributes{
authboss.StoreUsername: "usern", authboss.StoreUsername: "usern",
StoreConfirmToken: base64.StdEncoding.EncodeToString(sum[:]), StoreConfirmToken: base64.StdEncoding.EncodeToString(sum[:]),
@@ -196,7 +207,7 @@ func TestConfirm_Confirm(t *testing.T) {
// Make a request with session and context support. // Make a request with session and context support.
r, _ := http.NewRequest("GET", "http://localhost?cnf="+base64.URLEncoding.EncodeToString(token), nil) r, _ := http.NewRequest("GET", "http://localhost?cnf="+base64.URLEncoding.EncodeToString(token), nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
ctx, _ = authboss.ContextFromRequest(r) ctx, _ = c.ContextFromRequest(r)
ctx.CookieStorer = mocks.NewMockClientStorer() ctx.CookieStorer = mocks.NewMockClientStorer()
session := mocks.NewMockClientStorer() session := mocks.NewMockClientStorer()
ctx.User = user ctx.User = user

View File

@@ -279,7 +279,7 @@ func (m *MockClientStorer) Put(key, val string) { m.Values[key] = val }
func (m *MockClientStorer) Del(key string) { delete(m.Values, key) } func (m *MockClientStorer) Del(key string) { delete(m.Values, key) }
// MockRequestContext returns a new context as if it came from POST request. // MockRequestContext returns a new context as if it came from POST request.
func MockRequestContext(ab authboss.Authboss, postKeyValues ...string) *authboss.Context { func MockRequestContext(ab *authboss.Authboss, postKeyValues ...string) *authboss.Context {
keyValues := &bytes.Buffer{} keyValues := &bytes.Buffer{}
for i := 0; i < len(postKeyValues); i += 2 { for i := 0; i < len(postKeyValues); i += 2 {
if i != 0 { if i != 0 {

View File

@@ -14,8 +14,8 @@ import (
) )
var testViewTemplate = template.Must(template.New("").Parse(`{{.external}} {{.fun}} {{.flash_success}} {{.flash_error}} {{.xsrfName}} {{.xsrfToken}}`)) var testViewTemplate = template.Must(template.New("").Parse(`{{.external}} {{.fun}} {{.flash_success}} {{.flash_error}} {{.xsrfName}} {{.xsrfToken}}`))
var testEmailHTMLTempalte = template.Must(template.New("").Parse(`<h2>{{.}}</h2>`)) var testEmailHTMLTemplate = template.Must(template.New("").Parse(`<h2>{{.}}</h2>`))
var testEmailPlainTempalte = template.Must(template.New("").Parse(`i am a {{.}}`)) var testEmailPlainTemplate = template.Must(template.New("").Parse(`i am a {{.}}`))
func TestLoadTemplates(t *testing.T) { func TestLoadTemplates(t *testing.T) {
t.Parallel() t.Parallel()
@@ -35,7 +35,7 @@ func TestLoadTemplates(t *testing.T) {
filename := filepath.Base(file.Name()) filename := filepath.Base(file.Name())
tpls, err := LoadTemplates(layout, filepath.Dir(file.Name()), filename) tpls, err := LoadTemplates(authboss.New(), layout, filepath.Dir(file.Name()), filename)
if err != nil { if err != nil {
t.Error("Unexpected error:", err) t.Error("Unexpected error:", err)
} }
@@ -50,15 +50,16 @@ func TestLoadTemplates(t *testing.T) {
} }
func TestTemplates_Render(t *testing.T) { func TestTemplates_Render(t *testing.T) {
t.Parallel()
cookies := mocks.NewMockClientStorer() cookies := mocks.NewMockClientStorer()
authboss.Cfg = &authboss.Config{ ab := authboss.New()
LayoutDataMaker: func(_ http.ResponseWriter, _ *http.Request) authboss.HTMLData { ab.LayoutDataMaker = func(_ http.ResponseWriter, _ *http.Request) authboss.HTMLData {
return authboss.HTMLData{"fun": "is"} return authboss.HTMLData{"fun": "is"}
}, }
XSRFName: "do you think", ab.XSRFName = "do you think"
XSRFMaker: func(_ http.ResponseWriter, _ *http.Request) string { ab.XSRFMaker = func(_ http.ResponseWriter, _ *http.Request) string {
return "that's air you're breathing now?" return "that's air you're breathing now?"
},
} }
// Set up flashes // Set up flashes
@@ -67,7 +68,7 @@ func TestTemplates_Render(t *testing.T) {
r, _ := http.NewRequest("GET", "http://localhost", nil) r, _ := http.NewRequest("GET", "http://localhost", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
ctx, _ := authboss.ContextFromRequest(r) ctx, _ := ab.ContextFromRequest(r)
ctx.SessionStorer = cookies ctx.SessionStorer = cookies
tpls := Templates{ tpls := Templates{
@@ -85,17 +86,20 @@ func TestTemplates_Render(t *testing.T) {
} }
func Test_Email(t *testing.T) { func Test_Email(t *testing.T) {
mockMailer := &mocks.MockMailer{} t.Parallel()
authboss.a.Mailer = mockMailer
htmlTpls := Templates{"html": testEmailHTMLTempalte} ab := authboss.New()
textTpls := Templates{"plain": testEmailPlainTempalte} mockMailer := &mocks.MockMailer{}
ab.Mailer = mockMailer
htmlTpls := Templates{"html": testEmailHTMLTemplate}
textTpls := Templates{"plain": testEmailPlainTemplate}
email := authboss.Email{ email := authboss.Email{
To: []string{"a@b.c"}, To: []string{"a@b.c"},
} }
err := Email(email, htmlTpls, "html", textTpls, "plain", "spoon") err := Email(ab.Mailer, email, htmlTpls, "html", textTpls, "plain", "spoon")
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@@ -117,11 +121,14 @@ func Test_Email(t *testing.T) {
} }
func TestRedirect(t *testing.T) { func TestRedirect(t *testing.T) {
t.Parallel()
ab := authboss.New()
cookies := mocks.NewMockClientStorer() cookies := mocks.NewMockClientStorer()
r, _ := http.NewRequest("GET", "http://localhost", nil) r, _ := http.NewRequest("GET", "http://localhost", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
ctx, _ := authboss.ContextFromRequest(r) ctx, _ := ab.ContextFromRequest(r)
ctx.SessionStorer = cookies ctx.SessionStorer = cookies
Redirect(ctx, w, r, "/", "success", "failure", false) Redirect(ctx, w, r, "/", "success", "failure", false)
@@ -143,11 +150,14 @@ func TestRedirect(t *testing.T) {
} }
func TestRedirect_Override(t *testing.T) { func TestRedirect_Override(t *testing.T) {
t.Parallel()
ab := authboss.New()
cookies := mocks.NewMockClientStorer() cookies := mocks.NewMockClientStorer()
r, _ := http.NewRequest("GET", "http://localhost?redir=foo/bar", nil) r, _ := http.NewRequest("GET", "http://localhost?redir=foo/bar", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
ctx, _ := authboss.ContextFromRequest(r) ctx, _ := ab.ContextFromRequest(r)
ctx.SessionStorer = cookies ctx.SessionStorer = cookies
Redirect(ctx, w, r, "/shouldNotGo", "success", "failure", true) Redirect(ctx, w, r, "/shouldNotGo", "success", "failure", true)

View File

@@ -25,19 +25,21 @@ func init() {
// Lock module // Lock module
type Lock struct { type Lock struct {
*authboss.Authboss
} }
// Initialize the module // Initialize the module
func (l *Lock) Initialize() error { func (l *Lock) Initialize(ab *authboss.Authboss) error {
if authboss.a.Storer == nil { l.Authboss = ab
if l.Storer == nil {
return errors.New("lock: Need a Storer") return errors.New("lock: Need a Storer")
} }
// Events // Events
authboss.a.Callbacks.Before(authboss.EventGet, l.beforeAuth) l.Callbacks.Before(authboss.EventGet, l.beforeAuth)
authboss.a.Callbacks.Before(authboss.EventAuth, l.beforeAuth) l.Callbacks.Before(authboss.EventAuth, l.beforeAuth)
authboss.a.Callbacks.After(authboss.EventAuth, l.afterAuth) l.Callbacks.After(authboss.EventAuth, l.afterAuth)
authboss.a.Callbacks.After(authboss.EventAuthFail, l.afterAuthFail) l.Callbacks.After(authboss.EventAuthFail, l.afterAuthFail)
return nil return nil
} }
@@ -50,10 +52,10 @@ func (l *Lock) Routes() authboss.RouteTable {
// Storage requirements // Storage requirements
func (l *Lock) Storage() authboss.StorageOptions { func (l *Lock) Storage() authboss.StorageOptions {
return authboss.StorageOptions{ return authboss.StorageOptions{
authboss.a.PrimaryID: authboss.String, l.PrimaryID: authboss.String,
StoreAttemptNumber: authboss.Integer, StoreAttemptNumber: authboss.Integer,
StoreAttemptTime: authboss.DateTime, StoreAttemptTime: authboss.DateTime,
StoreLocked: authboss.DateTime, StoreLocked: authboss.DateTime,
} }
} }
@@ -104,9 +106,9 @@ func (l *Lock) afterAuthFail(ctx *authboss.Context) error {
nAttempts++ nAttempts++
if time.Now().UTC().Sub(lastAttempt) <= authboss.a.LockWindow { if time.Now().UTC().Sub(lastAttempt) <= l.LockWindow {
if nAttempts >= int64(authboss.a.LockAfter) { if nAttempts >= int64(l.LockAfter) {
ctx.User[StoreLocked] = time.Now().UTC().Add(authboss.a.LockDuration) ctx.User[StoreLocked] = time.Now().UTC().Add(l.LockDuration)
} }
ctx.User[StoreAttemptNumber] = nAttempts ctx.User[StoreAttemptNumber] = nAttempts
@@ -124,7 +126,7 @@ func (l *Lock) afterAuthFail(ctx *authboss.Context) error {
// Lock a user manually. // Lock a user manually.
func (l *Lock) Lock(key string) error { func (l *Lock) Lock(key string) error {
user, err := authboss.a.Storer.Get(key) user, err := l.Storer.Get(key)
if err != nil { if err != nil {
return err return err
} }
@@ -134,14 +136,14 @@ func (l *Lock) Lock(key string) error {
return err return err
} }
attr[StoreLocked] = time.Now().UTC().Add(authboss.a.LockDuration) attr[StoreLocked] = time.Now().UTC().Add(l.LockDuration)
return authboss.a.Storer.Put(key, attr) return l.Storer.Put(key, attr)
} }
// Unlock a user that was locked by this module. // Unlock a user that was locked by this module.
func (l *Lock) Unlock(key string) error { func (l *Lock) Unlock(key string) error {
user, err := authboss.a.Storer.Get(key) user, err := l.Storer.Get(key)
if err != nil { if err != nil {
return err return err
} }
@@ -153,9 +155,9 @@ func (l *Lock) Unlock(key string) error {
// Set the last attempt to be -window*2 to avoid immediately // Set the last attempt to be -window*2 to avoid immediately
// giving another login failure. // giving another login failure.
attr[StoreAttemptTime] = time.Now().UTC().Add(-authboss.a.LockWindow * 2) attr[StoreAttemptTime] = time.Now().UTC().Add(-l.LockWindow * 2)
attr[StoreAttemptNumber] = int64(0) attr[StoreAttemptNumber] = int64(0)
attr[StoreLocked] = time.Now().UTC().Add(-authboss.a.LockDuration) attr[StoreLocked] = time.Now().UTC().Add(-l.LockDuration)
return authboss.a.Storer.Put(key, attr) return l.Storer.Put(key, attr)
} }

View File

@@ -9,8 +9,9 @@ import (
) )
func TestStorage(t *testing.T) { func TestStorage(t *testing.T) {
l := &Lock{} t.Parallel()
authboss.NewConfig()
l := &Lock{authboss.New()}
storage := l.Storage() storage := l.Storage()
if _, ok := storage[StoreAttemptNumber]; !ok { if _, ok := storage[StoreAttemptNumber]; !ok {
t.Error("Expected attempt number storage option.") t.Error("Expected attempt number storage option.")
@@ -24,9 +25,11 @@ func TestStorage(t *testing.T) {
} }
func TestBeforeAuth(t *testing.T) { func TestBeforeAuth(t *testing.T) {
t.Parallel()
l := &Lock{} l := &Lock{}
authboss.NewConfig() ab := authboss.New()
ctx := authboss.NewContext() ctx := ab.NewContext()
if interrupt, err := l.beforeAuth(ctx); err != errUserMissing { if interrupt, err := l.beforeAuth(ctx); err != errUserMissing {
t.Error("Expected an error because of missing user:", err) t.Error("Expected an error because of missing user:", err)
@@ -44,17 +47,19 @@ func TestBeforeAuth(t *testing.T) {
} }
func TestAfterAuth(t *testing.T) { func TestAfterAuth(t *testing.T) {
authboss.NewConfig() t.Parallel()
ab := authboss.New()
lock := Lock{} lock := Lock{}
ctx := authboss.NewContext() ctx := ab.NewContext()
if err := lock.afterAuth(ctx); err != errUserMissing { if err := lock.afterAuth(ctx); err != errUserMissing {
t.Error("Expected an error because of missing user:", err) t.Error("Expected an error because of missing user:", err)
} }
storer := mocks.NewMockStorer() storer := mocks.NewMockStorer()
authboss.a.Storer = storer ab.Storer = storer
ctx.User = authboss.Attributes{authboss.a.PrimaryID: "john@john.com"} ctx.User = authboss.Attributes{ab.PrimaryID: "john@john.com"}
if err := lock.afterAuth(ctx); err != nil { if err := lock.afterAuth(ctx); err != nil {
t.Error(err) t.Error(err)
@@ -68,21 +73,23 @@ func TestAfterAuth(t *testing.T) {
} }
func TestAfterAuthFail_Lock(t *testing.T) { func TestAfterAuthFail_Lock(t *testing.T) {
authboss.NewConfig() t.Parallel()
ab := authboss.New()
var old, current time.Time var old, current time.Time
var ok bool var ok bool
ctx := authboss.NewContext() ctx := ab.NewContext()
storer := mocks.NewMockStorer() storer := mocks.NewMockStorer()
authboss.a.Storer = storer ab.Storer = storer
lock := Lock{} lock := Lock{ab}
authboss.a.LockWindow = 30 * time.Minute ab.LockWindow = 30 * time.Minute
authboss.a.LockDuration = 30 * time.Minute ab.LockDuration = 30 * time.Minute
authboss.a.LockAfter = 3 ab.LockAfter = 3
email := "john@john.com" email := "john@john.com"
ctx.User = map[string]interface{}{authboss.a.PrimaryID: email} ctx.User = map[string]interface{}{ab.PrimaryID: email}
old = time.Now().UTC().Add(-1 * time.Hour) old = time.Now().UTC().Add(-1 * time.Hour)
@@ -116,24 +123,26 @@ func TestAfterAuthFail_Lock(t *testing.T) {
} }
func TestAfterAuthFail_Reset(t *testing.T) { func TestAfterAuthFail_Reset(t *testing.T) {
authboss.NewConfig() t.Parallel()
ab := authboss.New()
var old, current time.Time var old, current time.Time
var ok bool var ok bool
ctx := authboss.NewContext() ctx := ab.NewContext()
storer := mocks.NewMockStorer() storer := mocks.NewMockStorer()
lock := Lock{} lock := Lock{ab}
authboss.a.LockWindow = 30 * time.Minute ab.LockWindow = 30 * time.Minute
authboss.a.Storer = storer ab.Storer = storer
old = time.Now().UTC().Add(-time.Hour) old = time.Now().UTC().Add(-time.Hour)
email := "john@john.com" email := "john@john.com"
ctx.User = map[string]interface{}{ ctx.User = map[string]interface{}{
authboss.a.PrimaryID: email, ab.PrimaryID: email,
StoreAttemptNumber: int64(2), StoreAttemptNumber: int64(2),
StoreAttemptTime: old, StoreAttemptTime: old,
StoreLocked: old, StoreLocked: old,
} }
lock.afterAuthFail(ctx) lock.afterAuthFail(ctx)
@@ -149,9 +158,11 @@ func TestAfterAuthFail_Reset(t *testing.T) {
} }
func TestAfterAuthFail_Errors(t *testing.T) { func TestAfterAuthFail_Errors(t *testing.T) {
authboss.NewConfig() t.Parallel()
lock := Lock{}
ctx := authboss.NewContext() ab := authboss.New()
lock := Lock{ab}
ctx := ab.NewContext()
lock.afterAuthFail(ctx) lock.afterAuthFail(ctx)
if _, ok := ctx.User[StoreAttemptNumber]; ok { if _, ok := ctx.User[StoreAttemptNumber]; ok {
@@ -160,15 +171,17 @@ func TestAfterAuthFail_Errors(t *testing.T) {
} }
func TestLock(t *testing.T) { func TestLock(t *testing.T) {
authboss.NewConfig() t.Parallel()
ab := authboss.New()
storer := mocks.NewMockStorer() storer := mocks.NewMockStorer()
authboss.a.Storer = storer ab.Storer = storer
lock := Lock{} lock := Lock{ab}
email := "john@john.com" email := "john@john.com"
storer.Users[email] = map[string]interface{}{ storer.Users[email] = map[string]interface{}{
authboss.a.PrimaryID: email, ab.PrimaryID: email,
"password": "password", "password": "password",
} }
err := lock.Lock(email) err := lock.Lock(email)
@@ -182,17 +195,19 @@ func TestLock(t *testing.T) {
} }
func TestUnlock(t *testing.T) { func TestUnlock(t *testing.T) {
authboss.NewConfig() t.Parallel()
ab := authboss.New()
storer := mocks.NewMockStorer() storer := mocks.NewMockStorer()
authboss.a.Storer = storer ab.Storer = storer
lock := Lock{} lock := Lock{ab}
authboss.a.LockWindow = 1 * time.Hour ab.LockWindow = 1 * time.Hour
email := "john@john.com" email := "john@john.com"
storer.Users[email] = map[string]interface{}{ storer.Users[email] = map[string]interface{}{
authboss.a.PrimaryID: email, ab.PrimaryID: email,
"password": "password", "password": "password",
"locked": true, "locked": true,
} }
err := lock.Unlock(email) err := lock.Unlock(email)
@@ -201,7 +216,7 @@ func TestUnlock(t *testing.T) {
} }
attemptTime := storer.Users[email][StoreAttemptTime].(time.Time) attemptTime := storer.Users[email][StoreAttemptTime].(time.Time)
if attemptTime.After(time.Now().UTC().Add(-authboss.a.LockWindow)) { if attemptTime.After(time.Now().UTC().Add(-ab.LockWindow)) {
t.Error("StoreLocked not set correctly:", attemptTime) t.Error("StoreLocked not set correctly:", attemptTime)
} }
if number := storer.Users[email][StoreAttemptNumber].(int64); number != int64(0) { if number := storer.Users[email][StoreAttemptNumber].(int64); number != int64(0) {

View File

@@ -22,15 +22,18 @@ var (
) )
// OAuth2 module // OAuth2 module
type OAuth2 struct{} type OAuth2 struct {
*authboss.Authboss
}
func init() { func init() {
authboss.RegisterModule("oauth2", &OAuth2{}) authboss.RegisterModule("oauth2", &OAuth2{})
} }
// Initialize module // Initialize module
func (o *OAuth2) Initialize() error { func (o *OAuth2) Initialize(ab *authboss.Authboss) error {
if authboss.a.OAuth2Storer == nil { o.Authboss = ab
if o.OAuth2Storer == nil {
return errors.New("oauth2: need an OAuth2Storer") return errors.New("oauth2: need an OAuth2Storer")
} }
return nil return nil
@@ -40,23 +43,23 @@ func (o *OAuth2) Initialize() error {
func (o *OAuth2) Routes() authboss.RouteTable { func (o *OAuth2) Routes() authboss.RouteTable {
routes := make(authboss.RouteTable) routes := make(authboss.RouteTable)
for prov, cfg := range authboss.a.OAuth2Providers { for prov, cfg := range o.OAuth2Providers {
prov = strings.ToLower(prov) prov = strings.ToLower(prov)
init := fmt.Sprintf("/oauth2/%s", prov) init := fmt.Sprintf("/oauth2/%s", prov)
callback := fmt.Sprintf("/oauth2/callback/%s", prov) callback := fmt.Sprintf("/oauth2/callback/%s", prov)
routes[init] = oauthInit routes[init] = o.oauthInit
routes[callback] = oauthCallback routes[callback] = o.oauthCallback
if len(authboss.a.MountPath) > 0 { if len(o.MountPath) > 0 {
callback = path.Join(authboss.a.MountPath, callback) callback = path.Join(o.MountPath, callback)
} }
a.OAuth2Config.RedirectURL = authboss.a.RootURL + callback cfg.OAuth2Config.RedirectURL = o.RootURL + callback
} }
routes["/oauth2/logout"] = logout routes["/oauth2/logout"] = o.logout
return routes return routes
} }
@@ -73,9 +76,9 @@ func (o *OAuth2) Storage() authboss.StorageOptions {
} }
} }
func oauthInit(ctx *authboss.Context, w http.ResponseWriter, r *http.Request) error { func (o *OAuth2) oauthInit(ctx *authboss.Context, w http.ResponseWriter, r *http.Request) error {
provider := strings.ToLower(filepath.Base(r.URL.Path)) provider := strings.ToLower(filepath.Base(r.URL.Path))
cfg, ok := authboss.a.OAuth2Providers[provider] cfg, ok := o.OAuth2Providers[provider]
if !ok { if !ok {
return fmt.Errorf("OAuth2 provider %q not found", provider) return fmt.Errorf("OAuth2 provider %q not found", provider)
} }
@@ -106,9 +109,9 @@ func oauthInit(ctx *authboss.Context, w http.ResponseWriter, r *http.Request) er
ctx.SessionStorer.Del(authboss.SessionOAuth2Params) ctx.SessionStorer.Del(authboss.SessionOAuth2Params)
} }
url := a.OAuth2Config.AuthCodeURL(state) url := cfg.OAuth2Config.AuthCodeURL(state)
extraParams := a.AdditionalParams.Encode() extraParams := cfg.AdditionalParams.Encode()
if len(extraParams) > 0 { if len(extraParams) > 0 {
url = fmt.Sprintf("%s&%s", url, extraParams) url = fmt.Sprintf("%s&%s", url, extraParams)
} }
@@ -120,7 +123,7 @@ func oauthInit(ctx *authboss.Context, w http.ResponseWriter, r *http.Request) er
// for testing // for testing
var exchanger = (*oauth2.Config).Exchange var exchanger = (*oauth2.Config).Exchange
func oauthCallback(ctx *authboss.Context, w http.ResponseWriter, r *http.Request) error { func (o *OAuth2) oauthCallback(ctx *authboss.Context, w http.ResponseWriter, r *http.Request) error {
provider := strings.ToLower(filepath.Base(r.URL.Path)) provider := strings.ToLower(filepath.Base(r.URL.Path))
sessState, err := ctx.SessionStorer.GetErr(authboss.SessionOAuth2State) sessState, err := ctx.SessionStorer.GetErr(authboss.SessionOAuth2State)
@@ -140,18 +143,18 @@ func oauthCallback(ctx *authboss.Context, w http.ResponseWriter, r *http.Request
hasErr := r.FormValue("error") hasErr := r.FormValue("error")
if len(hasErr) > 0 { if len(hasErr) > 0 {
if err := authboss.a.Callbacks.FireAfter(authboss.EventOAuthFail, ctx); err != nil { if err := o.Callbacks.FireAfter(authboss.EventOAuthFail, ctx); err != nil {
return err return err
} }
return authboss.ErrAndRedirect{ return authboss.ErrAndRedirect{
Err: errors.New(r.FormValue("error_reason")), Err: errors.New(r.FormValue("error_reason")),
Location: authboss.a.AuthLoginFailPath, Location: o.AuthLoginFailPath,
FlashError: fmt.Sprintf("%s login cancelled or failed.", strings.Title(provider)), FlashError: fmt.Sprintf("%s login cancelled or failed.", strings.Title(provider)),
} }
} }
cfg, ok := authboss.a.OAuth2Providers[provider] cfg, ok := o.OAuth2Providers[provider]
if !ok { if !ok {
return fmt.Errorf("OAuth2 provider %q not found", provider) return fmt.Errorf("OAuth2 provider %q not found", provider)
} }
@@ -165,12 +168,12 @@ func oauthCallback(ctx *authboss.Context, w http.ResponseWriter, r *http.Request
// Get the code // Get the code
code := r.FormValue("code") code := r.FormValue("code")
token, err := exchanger(a.OAuth2Config, oauth2.NoContext, code) token, err := exchanger(cfg.OAuth2Config, oauth2.NoContext, code)
if err != nil { if err != nil {
return fmt.Errorf("Could not validate oauth2 code: %v", err) return fmt.Errorf("Could not validate oauth2 code: %v", err)
} }
user, err := a.Callback(*cfg.OAuth2Config, token) user, err := cfg.Callback(*cfg.OAuth2Config, token)
if err != nil { if err != nil {
return err return err
} }
@@ -189,7 +192,7 @@ func oauthCallback(ctx *authboss.Context, w http.ResponseWriter, r *http.Request
user[authboss.StoreOAuth2Refresh] = token.RefreshToken user[authboss.StoreOAuth2Refresh] = token.RefreshToken
} }
if err = authboss.a.OAuth2Storer.PutOAuth(uid, provider, user); err != nil { if err = o.OAuth2Storer.PutOAuth(uid, provider, user); err != nil {
return err return err
} }
@@ -197,13 +200,13 @@ func oauthCallback(ctx *authboss.Context, w http.ResponseWriter, r *http.Request
ctx.SessionStorer.Put(authboss.SessionKey, fmt.Sprintf("%s;%s", uid, provider)) ctx.SessionStorer.Put(authboss.SessionKey, fmt.Sprintf("%s;%s", uid, provider))
ctx.SessionStorer.Del(authboss.SessionHalfAuthKey) ctx.SessionStorer.Del(authboss.SessionHalfAuthKey)
if err = authboss.a.Callbacks.FireAfter(authboss.EventOAuth, ctx); err != nil { if err = o.Callbacks.FireAfter(authboss.EventOAuth, ctx); err != nil {
return nil return nil
} }
ctx.SessionStorer.Del(authboss.SessionOAuth2Params) ctx.SessionStorer.Del(authboss.SessionOAuth2Params)
redirect := authboss.a.AuthLoginOKPath redirect := o.AuthLoginOKPath
query := make(url.Values) query := make(url.Values)
for k, v := range values { for k, v := range values {
switch k { switch k {
@@ -224,14 +227,14 @@ func oauthCallback(ctx *authboss.Context, w http.ResponseWriter, r *http.Request
return nil return nil
} }
func logout(ctx *authboss.Context, w http.ResponseWriter, r *http.Request) error { func (o *OAuth2) logout(ctx *authboss.Context, w http.ResponseWriter, r *http.Request) error {
switch r.Method { switch r.Method {
case "GET": case "GET":
ctx.SessionStorer.Del(authboss.SessionKey) ctx.SessionStorer.Del(authboss.SessionKey)
ctx.CookieStorer.Del(authboss.CookieRemember) ctx.CookieStorer.Del(authboss.CookieRemember)
ctx.SessionStorer.Del(authboss.SessionLastAction) ctx.SessionStorer.Del(authboss.SessionLastAction)
response.Redirect(ctx, w, r, authboss.a.AuthLogoutOKPath, "You have logged out", "", true) response.Redirect(ctx, w, r, o.AuthLogoutOKPath, "You have logged out", "", true)
default: default:
w.WriteHeader(http.StatusMethodNotAllowed) w.WriteHeader(http.StatusMethodNotAllowed)
} }

View File

@@ -30,29 +30,34 @@ var testProviders = map[string]authboss.OAuth2Provider{
} }
func TestInitialize(t *testing.T) { func TestInitialize(t *testing.T) {
authboss.Cfg = authboss.NewConfig() t.Parallel()
authboss.a.OAuth2Storer = mocks.NewMockStorer()
ab := authboss.New()
ab.OAuth2Storer = mocks.NewMockStorer()
o := OAuth2{} o := OAuth2{}
if err := o.Initialize(); err != nil { if err := o.Initialize(ab); err != nil {
t.Error(err) t.Error(err)
} }
} }
func TestRoutes(t *testing.T) { func TestRoutes(t *testing.T) {
t.Parallel()
root := "https://localhost:8080" root := "https://localhost:8080"
mount := "/auth" mount := "/auth"
authboss.Cfg = authboss.NewConfig() ab := authboss.New()
authboss.a.RootURL = root o := OAuth2{ab}
authboss.a.MountPath = mount
authboss.a.OAuth2Providers = testProviders
googleCfg := authboss.a.OAuth2Providers["google"].OAuth2Config ab.RootURL = root
ab.MountPath = mount
ab.OAuth2Providers = testProviders
googleCfg := ab.OAuth2Providers["google"].OAuth2Config
if 0 != len(googleCfg.RedirectURL) { if 0 != len(googleCfg.RedirectURL) {
t.Error("RedirectURL should not be set") t.Error("RedirectURL should not be set")
} }
o := OAuth2{}
routes := o.Routes() routes := o.Routes()
authURL := path.Join("/oauth2", "google") authURL := path.Join("/oauth2", "google")
tokenURL := path.Join("/oauth2", "callback", "google") tokenURL := path.Join("/oauth2", "callback", "google")
@@ -71,18 +76,20 @@ func TestRoutes(t *testing.T) {
} }
func TestOAuth2Init(t *testing.T) { func TestOAuth2Init(t *testing.T) {
cfg := authboss.NewConfig() t.Parallel()
ab := authboss.New()
oauth := OAuth2{ab}
session := mocks.NewMockClientStorer() session := mocks.NewMockClientStorer()
a.OAuth2Providers = testProviders ab.OAuth2Providers = testProviders
authboss.Cfg = cfg
r, _ := http.NewRequest("GET", "/oauth2/google?redir=/my/redirect%23lol&rm=true", nil) r, _ := http.NewRequest("GET", "/oauth2/google?redir=/my/redirect%23lol&rm=true", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
ctx := authboss.NewContext() ctx := ab.NewContext()
ctx.SessionStorer = session ctx.SessionStorer = session
oauthInit(ctx, w, r) oauth.oauthInit(ctx, w, r)
if w.Code != http.StatusFound { if w.Code != http.StatusFound {
t.Error("Code was wrong:", w.Code) t.Error("Code was wrong:", w.Code)
@@ -112,7 +119,10 @@ func TestOAuth2Init(t *testing.T) {
} }
func TestOAuthSuccess(t *testing.T) { func TestOAuthSuccess(t *testing.T) {
cfg := authboss.NewConfig() t.Parallel()
ab := authboss.New()
oauth := OAuth2{ab}
expiry := time.Now().UTC().Add(3600 * time.Second) expiry := time.Now().UTC().Add(3600 * time.Second)
fakeToken := &oauth2.Token{ fakeToken := &oauth2.Token{
@@ -137,7 +147,7 @@ func TestOAuthSuccess(t *testing.T) {
return fakeToken, nil return fakeToken, nil
} }
a.OAuth2Providers = map[string]authboss.OAuth2Provider{ ab.OAuth2Providers = map[string]authboss.OAuth2Provider{
"fake": authboss.OAuth2Provider{ "fake": authboss.OAuth2Provider{
OAuth2Config: &oauth2.Config{ OAuth2Config: &oauth2.Config{
ClientID: `jazz`, ClientID: `jazz`,
@@ -152,7 +162,6 @@ func TestOAuthSuccess(t *testing.T) {
AdditionalParams: url.Values{"include_requested_scopes": []string{"true"}}, AdditionalParams: url.Values{"include_requested_scopes": []string{"true"}},
}, },
} }
authboss.Cfg = cfg
values := make(url.Values) values := make(url.Values)
values.Set("code", "code") values.Set("code", "code")
@@ -161,17 +170,17 @@ func TestOAuthSuccess(t *testing.T) {
url := fmt.Sprintf("/oauth2/fake?%s", values.Encode()) url := fmt.Sprintf("/oauth2/fake?%s", values.Encode())
r, _ := http.NewRequest("GET", url, nil) r, _ := http.NewRequest("GET", url, nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
ctx := authboss.NewContext() ctx := ab.NewContext()
session := mocks.NewMockClientStorer() session := mocks.NewMockClientStorer()
session.Put(authboss.SessionOAuth2State, authboss.FormValueOAuth2State) session.Put(authboss.SessionOAuth2State, authboss.FormValueOAuth2State)
session.Put(authboss.SessionOAuth2Params, `{"redir":"/myurl?myparam=5","rm":"true"}`) session.Put(authboss.SessionOAuth2Params, `{"redir":"/myurl?myparam=5","rm":"true"}`)
storer := mocks.NewMockStorer() storer := mocks.NewMockStorer()
ctx.SessionStorer = session ctx.SessionStorer = session
a.OAuth2Storer = storer ab.OAuth2Storer = storer
a.AuthLoginOKPath = "/fakeloginok" ab.AuthLoginOKPath = "/fakeloginok"
if err := oauthCallback(ctx, w, r); err != nil { if err := oauth.oauthCallback(ctx, w, r); err != nil {
t.Error(err) t.Error(err)
} }
@@ -209,46 +218,50 @@ func TestOAuthSuccess(t *testing.T) {
} }
func TestOAuthXSRFFailure(t *testing.T) { func TestOAuthXSRFFailure(t *testing.T) {
cfg := authboss.NewConfig() t.Parallel()
ab := authboss.New()
oauth := OAuth2{ab}
session := mocks.NewMockClientStorer() session := mocks.NewMockClientStorer()
session.Put(authboss.SessionOAuth2State, authboss.FormValueOAuth2State) session.Put(authboss.SessionOAuth2State, authboss.FormValueOAuth2State)
a.OAuth2Providers = testProviders ab.OAuth2Providers = testProviders
authboss.Cfg = cfg
values := url.Values{} values := url.Values{}
values.Set(authboss.FormValueOAuth2State, "notstate") values.Set(authboss.FormValueOAuth2State, "notstate")
values.Set("code", "code") values.Set("code", "code")
ctx := authboss.NewContext() ctx := ab.NewContext()
ctx.SessionStorer = session ctx.SessionStorer = session
r, _ := http.NewRequest("GET", "/oauth2/google?"+values.Encode(), nil) r, _ := http.NewRequest("GET", "/oauth2/google?"+values.Encode(), nil)
err := oauthCallback(ctx, nil, r) err := oauth.oauthCallback(ctx, nil, r)
if err != errOAuthStateValidation { if err != errOAuthStateValidation {
t.Error("Should have gotten an error about state validation:", err) t.Error("Should have gotten an error about state validation:", err)
} }
} }
func TestOAuthFailure(t *testing.T) { func TestOAuthFailure(t *testing.T) {
cfg := authboss.NewConfig() t.Parallel()
a.OAuth2Providers = testProviders ab := authboss.New()
authboss.Cfg = cfg oauth := OAuth2{ab}
ab.OAuth2Providers = testProviders
values := url.Values{} values := url.Values{}
values.Set("error", "something") values.Set("error", "something")
values.Set("error_reason", "auth_failure") values.Set("error_reason", "auth_failure")
values.Set("error_description", "Failed to auth.") values.Set("error_description", "Failed to auth.")
ctx := authboss.NewContext() ctx := ab.NewContext()
session := mocks.NewMockClientStorer() session := mocks.NewMockClientStorer()
session.Put(authboss.SessionOAuth2State, authboss.FormValueOAuth2State) session.Put(authboss.SessionOAuth2State, authboss.FormValueOAuth2State)
ctx.SessionStorer = session ctx.SessionStorer = session
r, _ := http.NewRequest("GET", "/oauth2/google?"+values.Encode(), nil) r, _ := http.NewRequest("GET", "/oauth2/google?"+values.Encode(), nil)
err := oauthCallback(ctx, nil, r) err := oauth.oauthCallback(ctx, nil, r)
if red, ok := err.(authboss.ErrAndRedirect); !ok { if red, ok := err.(authboss.ErrAndRedirect); !ok {
t.Error("Should be a redirect error") t.Error("Should be a redirect error")
} else if len(red.FlashError) == 0 { } else if len(red.FlashError) == 0 {
@@ -259,19 +272,22 @@ func TestOAuthFailure(t *testing.T) {
} }
func TestLogout(t *testing.T) { func TestLogout(t *testing.T) {
authboss.Cfg = authboss.NewConfig() t.Parallel()
authboss.a.AuthLogoutOKPath = "/dashboard"
ab := authboss.New()
oauth := OAuth2{ab}
ab.AuthLogoutOKPath = "/dashboard"
r, _ := http.NewRequest("GET", "/oauth2/google?", nil) r, _ := http.NewRequest("GET", "/oauth2/google?", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
ctx := authboss.NewContext() ctx := ab.NewContext()
session := mocks.NewMockClientStorer(authboss.SessionKey, "asdf", authboss.SessionLastAction, "1234") session := mocks.NewMockClientStorer(authboss.SessionKey, "asdf", authboss.SessionLastAction, "1234")
cookies := mocks.NewMockClientStorer(authboss.CookieRemember, "qwert") cookies := mocks.NewMockClientStorer(authboss.CookieRemember, "qwert")
ctx.SessionStorer = session ctx.SessionStorer = session
ctx.CookieStorer = cookies ctx.CookieStorer = cookies
if err := logout(ctx, w, r); err != nil { if err := oauth.logout(ctx, w, r); err != nil {
t.Error(err) t.Error(err)
} }
@@ -292,7 +308,7 @@ func TestLogout(t *testing.T) {
} }
location := w.Header().Get("Location") location := w.Header().Get("Location")
if location != authboss.a.AuthLogoutOKPath { if location != ab.AuthLogoutOKPath {
t.Error("Redirect wrong:", location) t.Error("Redirect wrong:", location)
} }
} }

View File

@@ -26,8 +26,8 @@ type googleMeResponse struct {
var clientGet = (*http.Client).Get var clientGet = (*http.Client).Get
// Google is a callback appropriate for use with Google's OAuth2 configuration. // Google is a callback appropriate for use with Google's OAuth2 configuration.
func Google(a.oauth2.Config, token *oauth2.Token) (authboss.Attributes, error) { func Google(cfg oauth2.Config, token *oauth2.Token) (authboss.Attributes, error) {
client := a.Client(oauth2.NoContext, token) client := cfg.Client(oauth2.NoContext, token)
resp, err := clientGet(client, googleInfoEndpoint) resp, err := clientGet(client, googleInfoEndpoint)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -56,39 +56,42 @@ func init() {
// Recover module // Recover module
type Recover struct { type Recover struct {
*authboss.Authboss
templates response.Templates templates response.Templates
emailHTMLTemplates response.Templates emailHTMLTemplates response.Templates
emailTextTemplates response.Templates emailTextTemplates response.Templates
} }
// Initialize module // Initialize module
func (r *Recover) Initialize() (err error) { func (r *Recover) Initialize(ab *authboss.Authboss) (err error) {
if authboss.a.Storer == nil { r.Authboss = ab
if r.Storer == nil {
return errors.New("recover: Need a RecoverStorer") return errors.New("recover: Need a RecoverStorer")
} }
if _, ok := authboss.a.Storer.(RecoverStorer); !ok { if _, ok := r.Storer.(RecoverStorer); !ok {
return errors.New("recover: RecoverStorer required for recover functionality") return errors.New("recover: RecoverStorer required for recover functionality")
} }
if len(authboss.a.XSRFName) == 0 { if len(r.XSRFName) == 0 {
return errors.New("auth: XSRFName must be set") return errors.New("auth: XSRFName must be set")
} }
if authboss.a.XSRFMaker == nil { if r.XSRFMaker == nil {
return errors.New("auth: XSRFMaker must be defined") return errors.New("auth: XSRFMaker must be defined")
} }
r.templates, err = response.LoadTemplates(authboss.a.Layout, authboss.a.ViewsPath, tplRecover, tplRecoverComplete) r.templates, err = response.LoadTemplates(r.Authboss, r.Layout, r.ViewsPath, tplRecover, tplRecoverComplete)
if err != nil { if err != nil {
return err return err
} }
r.emailHTMLTemplates, err = response.LoadTemplates(authboss.a.LayoutHTMLEmail, authboss.a.ViewsPath, tplInitHTMLEmail) r.emailHTMLTemplates, err = response.LoadTemplates(r.Authboss, r.LayoutHTMLEmail, r.ViewsPath, tplInitHTMLEmail)
if err != nil { if err != nil {
return err return err
} }
r.emailTextTemplates, err = response.LoadTemplates(authboss.a.LayoutTextEmail, authboss.a.ViewsPath, tplInitTextEmail) r.emailTextTemplates, err = response.LoadTemplates(r.Authboss, r.LayoutTextEmail, r.ViewsPath, tplInitTextEmail)
if err != nil { if err != nil {
return err return err
} }
@@ -107,7 +110,7 @@ func (r *Recover) Routes() authboss.RouteTable {
// Storage requirements // Storage requirements
func (r *Recover) Storage() authboss.StorageOptions { func (r *Recover) Storage() authboss.StorageOptions {
return authboss.StorageOptions{ return authboss.StorageOptions{
authboss.a.PrimaryID: authboss.String, r.PrimaryID: authboss.String,
authboss.StoreEmail: authboss.String, authboss.StoreEmail: authboss.String,
authboss.StorePassword: authboss.String, authboss.StorePassword: authboss.String,
StoreRecoverToken: authboss.String, StoreRecoverToken: authboss.String,
@@ -119,31 +122,31 @@ func (rec *Recover) startHandlerFunc(ctx *authboss.Context, w http.ResponseWrite
switch r.Method { switch r.Method {
case methodGET: case methodGET:
data := authboss.NewHTMLData( data := authboss.NewHTMLData(
"primaryID", authboss.a.PrimaryID, "primaryID", rec.PrimaryID,
"primaryIDValue", "", "primaryIDValue", "",
"confirmPrimaryIDValue", "", "confirmPrimaryIDValue", "",
) )
return rec.templates.Render(ctx, w, r, tplRecover, data) return rec.templates.Render(ctx, w, r, tplRecover, data)
case methodPOST: case methodPOST:
primaryID, _ := ctx.FirstPostFormValue(authboss.a.PrimaryID) primaryID, _ := ctx.FirstPostFormValue(rec.PrimaryID)
confirmPrimaryID, _ := ctx.FirstPostFormValue(fmt.Sprintf("confirm_%s", authboss.a.PrimaryID)) confirmPrimaryID, _ := ctx.FirstPostFormValue(fmt.Sprintf("confirm_%s", rec.PrimaryID))
errData := authboss.NewHTMLData( errData := authboss.NewHTMLData(
"primaryID", authboss.a.PrimaryID, "primaryID", rec.PrimaryID,
"primaryIDValue", primaryID, "primaryIDValue", primaryID,
"confirmPrimaryIDValue", confirmPrimaryID, "confirmPrimaryIDValue", confirmPrimaryID,
) )
policies := authboss.FilterValidators(authboss.a.Policies, authboss.a.PrimaryID) policies := authboss.FilterValidators(rec.Policies, rec.PrimaryID)
if validationErrs := ctx.Validate(policies, authboss.a.PrimaryID, authboss.ConfirmPrefix+authboss.a.PrimaryID).Map(); len(validationErrs) > 0 { if validationErrs := ctx.Validate(policies, rec.PrimaryID, authboss.ConfirmPrefix+rec.PrimaryID).Map(); len(validationErrs) > 0 {
errData.MergeKV("errs", validationErrs) errData.MergeKV("errs", validationErrs)
return rec.templates.Render(ctx, w, r, tplRecover, errData) return rec.templates.Render(ctx, w, r, tplRecover, errData)
} }
// redirect to login when user not found to prevent username sniffing // redirect to login when user not found to prevent username sniffing
if err := ctx.LoadUser(primaryID); err == authboss.ErrUserNotFound { if err := ctx.LoadUser(primaryID); err == authboss.ErrUserNotFound {
return authboss.ErrAndRedirect{err, authboss.a.RecoverOKPath, recoverInitiateSuccessFlash, ""} return authboss.ErrAndRedirect{err, rec.RecoverOKPath, recoverInitiateSuccessFlash, ""}
} else if err != nil { } else if err != nil {
return err return err
} }
@@ -159,7 +162,7 @@ func (rec *Recover) startHandlerFunc(ctx *authboss.Context, w http.ResponseWrite
} }
ctx.User[StoreRecoverToken] = encodedChecksum ctx.User[StoreRecoverToken] = encodedChecksum
ctx.User[StoreRecoverTokenExpiry] = time.Now().Add(authboss.a.RecoverTokenDuration) ctx.User[StoreRecoverTokenExpiry] = time.Now().Add(rec.RecoverTokenDuration)
if err := ctx.SaveUser(); err != nil { if err := ctx.SaveUser(); err != nil {
return err return err
@@ -168,7 +171,7 @@ func (rec *Recover) startHandlerFunc(ctx *authboss.Context, w http.ResponseWrite
goRecoverEmail(rec, email, encodedToken) goRecoverEmail(rec, email, encodedToken)
ctx.SessionStorer.Put(authboss.FlashSuccessKey, recoverInitiateSuccessFlash) ctx.SessionStorer.Put(authboss.FlashSuccessKey, recoverInitiateSuccessFlash)
response.Redirect(ctx, w, r, authboss.a.RecoverOKPath, "", "", true) response.Redirect(ctx, w, r, rec.RecoverOKPath, "", "", true)
default: default:
w.WriteHeader(http.StatusMethodNotAllowed) w.WriteHeader(http.StatusMethodNotAllowed)
} }
@@ -191,17 +194,17 @@ var goRecoverEmail = func(r *Recover, to, encodedToken string) {
} }
func (r *Recover) sendRecoverEmail(to, encodedToken string) { func (r *Recover) sendRecoverEmail(to, encodedToken string) {
p := path.Join(authboss.a.MountPath, "recover/complete") p := path.Join(r.MountPath, "recover/complete")
url := fmt.Sprintf("%s%s?token=%s", authboss.a.RootURL, p, encodedToken) url := fmt.Sprintf("%s%s?token=%s", r.RootURL, p, encodedToken)
email := authboss.Email{ email := authboss.Email{
To: []string{to}, To: []string{to},
From: authboss.a.EmailFrom, From: r.EmailFrom,
Subject: authboss.a.EmailSubjectPrefix + "Password Reset", Subject: r.EmailSubjectPrefix + "Password Reset",
} }
if err := response.Email(email, r.emailHTMLTemplates, tplInitHTMLEmail, r.emailTextTemplates, tplInitTextEmail, url); err != nil { if err := response.Email(r.Mailer, email, r.emailHTMLTemplates, tplInitHTMLEmail, r.emailTextTemplates, tplInitTextEmail, url); err != nil {
fmt.Fprintln(authboss.a.LogWriter, "recover: failed to send recover email:", err) fmt.Fprintln(r.LogWriter, "recover: failed to send recover email:", err)
} }
} }
@@ -227,7 +230,7 @@ func (r *Recover) completeHandlerFunc(ctx *authboss.Context, w http.ResponseWrit
password, _ := ctx.FirstPostFormValue("password") password, _ := ctx.FirstPostFormValue("password")
confirmPassword, _ := ctx.FirstPostFormValue("confirmPassword") confirmPassword, _ := ctx.FirstPostFormValue("confirmPassword")
policies := authboss.FilterValidators(authboss.a.Policies, "password") policies := authboss.FilterValidators(r.Policies, "password")
if validationErrs := ctx.Validate(policies, authboss.StorePassword, authboss.ConfirmPrefix+authboss.StorePassword).Map(); len(validationErrs) > 0 { if validationErrs := ctx.Validate(policies, authboss.StorePassword, authboss.ConfirmPrefix+authboss.StorePassword).Map(); len(validationErrs) > 0 {
data := authboss.NewHTMLData( data := authboss.NewHTMLData(
"token", token, "token", token,
@@ -242,7 +245,7 @@ func (r *Recover) completeHandlerFunc(ctx *authboss.Context, w http.ResponseWrit
return err return err
} }
encryptedPassword, err := bcrypt.GenerateFromPassword([]byte(password), authboss.a.BCryptCost) encryptedPassword, err := bcrypt.GenerateFromPassword([]byte(password), r.BCryptCost)
if err != nil { if err != nil {
return err return err
} }
@@ -252,7 +255,7 @@ func (r *Recover) completeHandlerFunc(ctx *authboss.Context, w http.ResponseWrit
var nullTime time.Time var nullTime time.Time
ctx.User[StoreRecoverTokenExpiry] = nullTime ctx.User[StoreRecoverTokenExpiry] = nullTime
primaryID, err := ctx.User.StringErr(authboss.a.PrimaryID) primaryID, err := ctx.User.StringErr(r.PrimaryID)
if err != nil { if err != nil {
return err return err
} }
@@ -261,12 +264,12 @@ func (r *Recover) completeHandlerFunc(ctx *authboss.Context, w http.ResponseWrit
return err return err
} }
if err := authboss.a.Callbacks.FireAfter(authboss.EventPasswordReset, ctx); err != nil { if err := r.Callbacks.FireAfter(authboss.EventPasswordReset, ctx); err != nil {
return err return err
} }
ctx.SessionStorer.Put(authboss.SessionKey, primaryID) ctx.SessionStorer.Put(authboss.SessionKey, primaryID)
response.Redirect(ctx, w, req, authboss.a.AuthLoginOKPath, "", "", true) response.Redirect(ctx, w, req, r.AuthLoginOKPath, "", "", true)
default: default:
w.WriteHeader(http.StatusMethodNotAllowed) w.WriteHeader(http.StatusMethodNotAllowed)
} }
@@ -287,7 +290,7 @@ func verifyToken(ctx *authboss.Context) (attrs authboss.Attributes, err error) {
} }
sum := md5.Sum(decoded) sum := md5.Sum(decoded)
storer := authboss.a.Storer.(RecoverStorer) storer := ctx.Storer.(RecoverStorer)
userInter, err := storer.RecoverUser(base64.StdEncoding.EncodeToString(sum[:])) userInter, err := storer.RecoverUser(base64.StdEncoding.EncodeToString(sum[:]))
if err != nil { if err != nil {

View File

@@ -25,45 +25,47 @@ func testSetup() (r *Recover, s *mocks.MockStorer, l *bytes.Buffer) {
s = mocks.NewMockStorer() s = mocks.NewMockStorer()
l = &bytes.Buffer{} l = &bytes.Buffer{}
authboss.Cfg = authboss.NewConfig() ab := authboss.New()
authboss.a.Layout = template.Must(template.New("").Parse(`{{template "authboss" .}}`)) ab.Layout = template.Must(template.New("").Parse(`{{template "authboss" .}}`))
authboss.a.LayoutHTMLEmail = template.Must(template.New("").Parse(`<strong>{{template "authboss" .}}</strong>`)) ab.LayoutHTMLEmail = template.Must(template.New("").Parse(`<strong>{{template "authboss" .}}</strong>`))
authboss.a.LayoutTextEmail = template.Must(template.New("").Parse(`{{template "authboss" .}}`)) ab.LayoutTextEmail = template.Must(template.New("").Parse(`{{template "authboss" .}}`))
authboss.a.Storer = s ab.Storer = s
authboss.a.XSRFName = "xsrf" ab.XSRFName = "xsrf"
authboss.a.XSRFMaker = func(_ http.ResponseWriter, _ *http.Request) string { ab.XSRFMaker = func(_ http.ResponseWriter, _ *http.Request) string {
return "xsrfvalue" return "xsrfvalue"
} }
authboss.a.PrimaryID = authboss.StoreUsername ab.PrimaryID = authboss.StoreUsername
authboss.a.LogWriter = l ab.LogWriter = l
r = &Recover{} r = &Recover{}
if err := r.Initialize(); err != nil { if err := r.Initialize(ab); err != nil {
panic(err) panic(err)
} }
return r, s, l return r, s, l
} }
func testRequest(method string, postFormValues ...string) (*authboss.Context, *httptest.ResponseRecorder, *http.Request, authboss.ClientStorerErr) { func testRequest(ab *authboss.Authboss, method string, postFormValues ...string) (*authboss.Context, *httptest.ResponseRecorder, *http.Request, authboss.ClientStorerErr) {
r, err := http.NewRequest(method, "", nil) r, err := http.NewRequest(method, "", nil)
if err != nil { if err != nil {
panic(err) panic(err)
} }
sessionStorer := mocks.NewMockClientStorer() sessionStorer := mocks.NewMockClientStorer()
ctx := mocks.MockRequestContext(postFormValues...) ctx := mocks.MockRequestContext(ab, postFormValues...)
ctx.SessionStorer = sessionStorer ctx.SessionStorer = sessionStorer
return ctx, httptest.NewRecorder(), r, sessionStorer return ctx, httptest.NewRecorder(), r, sessionStorer
} }
func TestRecover(t *testing.T) { func TestRecover(t *testing.T) {
t.Parallel()
r, _, _ := testSetup() r, _, _ := testSetup()
storage := r.Storage() storage := r.Storage()
if storage[authboss.a.PrimaryID] != authboss.String { if storage[r.PrimaryID] != authboss.String {
t.Error("Expected storage KV:", authboss.a.PrimaryID, authboss.String) t.Error("Expected storage KV:", r.PrimaryID, authboss.String)
} }
if storage[authboss.StoreEmail] != authboss.String { if storage[authboss.StoreEmail] != authboss.String {
t.Error("Expected storage KV:", authboss.StoreEmail, authboss.String) t.Error("Expected storage KV:", authboss.StoreEmail, authboss.String)
@@ -88,8 +90,10 @@ func TestRecover(t *testing.T) {
} }
func TestRecover_startHandlerFunc_GET(t *testing.T) { func TestRecover_startHandlerFunc_GET(t *testing.T) {
t.Parallel()
rec, _, _ := testSetup() rec, _, _ := testSetup()
ctx, w, r, _ := testRequest("GET") ctx, w, r, _ := testRequest(rec.Authboss, "GET")
if err := rec.startHandlerFunc(ctx, w, r); err != nil { if err := rec.startHandlerFunc(ctx, w, r); err != nil {
t.Error("Unexpected error:", err) t.Error("Unexpected error:", err)
@@ -103,17 +107,19 @@ func TestRecover_startHandlerFunc_GET(t *testing.T) {
if !strings.Contains(body, `<form action="recover"`) { if !strings.Contains(body, `<form action="recover"`) {
t.Error("Should have rendered a form") t.Error("Should have rendered a form")
} }
if !strings.Contains(body, `name="`+authboss.a.PrimaryID) { if !strings.Contains(body, `name="`+rec.PrimaryID) {
t.Error("Form should contain the primary ID field") t.Error("Form should contain the primary ID field")
} }
if !strings.Contains(body, `name="confirm_`+authboss.a.PrimaryID) { if !strings.Contains(body, `name="confirm_`+rec.PrimaryID) {
t.Error("Form should contain the confirm primary ID field") t.Error("Form should contain the confirm primary ID field")
} }
} }
func TestRecover_startHandlerFunc_POST_ValidationFails(t *testing.T) { func TestRecover_startHandlerFunc_POST_ValidationFails(t *testing.T) {
t.Parallel()
rec, _, _ := testSetup() rec, _, _ := testSetup()
ctx, w, r, _ := testRequest("POST") ctx, w, r, _ := testRequest(rec.Authboss, "POST")
if err := rec.startHandlerFunc(ctx, w, r); err != nil { if err := rec.startHandlerFunc(ctx, w, r); err != nil {
t.Error("Unexpected error:", err) t.Error("Unexpected error:", err)
@@ -129,8 +135,10 @@ func TestRecover_startHandlerFunc_POST_ValidationFails(t *testing.T) {
} }
func TestRecover_startHandlerFunc_POST_UserNotFound(t *testing.T) { func TestRecover_startHandlerFunc_POST_UserNotFound(t *testing.T) {
t.Parallel()
rec, _, _ := testSetup() rec, _, _ := testSetup()
ctx, w, r, _ := testRequest("POST", "username", "john", "confirm_username", "john") ctx, w, r, _ := testRequest(rec.Authboss, "POST", "username", "john", "confirm_username", "john")
err := rec.startHandlerFunc(ctx, w, r) err := rec.startHandlerFunc(ctx, w, r)
if err == nil { if err == nil {
@@ -141,7 +149,7 @@ func TestRecover_startHandlerFunc_POST_UserNotFound(t *testing.T) {
t.Error("Expected ErrAndRedirect error") t.Error("Expected ErrAndRedirect error")
} }
if rerr.Location != authboss.a.RecoverOKPath { if rerr.Location != rec.RecoverOKPath {
t.Error("Unexpected location:", rerr.Location) t.Error("Unexpected location:", rerr.Location)
} }
@@ -151,6 +159,8 @@ func TestRecover_startHandlerFunc_POST_UserNotFound(t *testing.T) {
} }
func TestRecover_startHandlerFunc_POST(t *testing.T) { func TestRecover_startHandlerFunc_POST(t *testing.T) {
t.Parallel()
rec, storer, _ := testSetup() rec, storer, _ := testSetup()
storer.Users["john"] = authboss.Attributes{authboss.StoreUsername: "john", authboss.StoreEmail: "a@b.c"} storer.Users["john"] = authboss.Attributes{authboss.StoreUsername: "john", authboss.StoreEmail: "a@b.c"}
@@ -160,7 +170,7 @@ func TestRecover_startHandlerFunc_POST(t *testing.T) {
sentEmail = true sentEmail = true
} }
ctx, w, r, sessionStorer := testRequest("POST", "username", "john", "confirm_username", "john") ctx, w, r, sessionStorer := testRequest(rec.Authboss, "POST", "username", "john", "confirm_username", "john")
if err := rec.startHandlerFunc(ctx, w, r); err != nil { if err := rec.startHandlerFunc(ctx, w, r); err != nil {
t.Error("Unexpected error:", err) t.Error("Unexpected error:", err)
@@ -187,7 +197,7 @@ func TestRecover_startHandlerFunc_POST(t *testing.T) {
} }
loc := w.Header().Get("Location") loc := w.Header().Get("Location")
if loc != authboss.a.RecoverOKPath { if loc != rec.RecoverOKPath {
t.Error("Unexpected location:", loc) t.Error("Unexpected location:", loc)
} }
@@ -199,12 +209,14 @@ func TestRecover_startHandlerFunc_POST(t *testing.T) {
} }
func TestRecover_startHandlerFunc_OtherMethods(t *testing.T) { func TestRecover_startHandlerFunc_OtherMethods(t *testing.T) {
t.Parallel()
rec, _, _ := testSetup() rec, _, _ := testSetup()
methods := []string{"HEAD", "PUT", "DELETE", "TRACE", "CONNECT"} methods := []string{"HEAD", "PUT", "DELETE", "TRACE", "CONNECT"}
for i, method := range methods { for i, method := range methods {
_, w, r, _ := testRequest(method) _, w, r, _ := testRequest(rec.Authboss, method)
if err := rec.startHandlerFunc(nil, w, r); err != nil { if err := rec.startHandlerFunc(nil, w, r); err != nil {
t.Errorf("%d> Unexpected error: %s", i, err) t.Errorf("%d> Unexpected error: %s", i, err)
@@ -218,6 +230,8 @@ func TestRecover_startHandlerFunc_OtherMethods(t *testing.T) {
} }
func TestRecover_newToken(t *testing.T) { func TestRecover_newToken(t *testing.T) {
t.Parallel()
regexURL := regexp.MustCompile(`^(?:[A-Za-z0-9-_]{4})*(?:[A-Za-z0-9-_]{2}==|[A-Za-z0-9-_]{3}=)?$`) regexURL := regexp.MustCompile(`^(?:[A-Za-z0-9-_]{4})*(?:[A-Za-z0-9-_]{2}==|[A-Za-z0-9-_]{3}=)?$`)
regexSTD := regexp.MustCompile(`^(?:[A-Za-z0-9+/]{4})*(?:[A-Za-z0-9+/]{2}==|[A-Za-z0-9+/]{3}=)?$`) regexSTD := regexp.MustCompile(`^(?:[A-Za-z0-9+/]{4})*(?:[A-Za-z0-9+/]{2}==|[A-Za-z0-9+/]{3}=)?$`)
@@ -233,13 +247,15 @@ func TestRecover_newToken(t *testing.T) {
} }
func TestRecover_sendRecoverMail_FailToSend(t *testing.T) { func TestRecover_sendRecoverMail_FailToSend(t *testing.T) {
a, _, logger := testSetup() t.Parallel()
r, _, logger := testSetup()
mailer := mocks.NewMockMailer() mailer := mocks.NewMockMailer()
mailer.SendErr = "failed to send" mailer.SendErr = "failed to send"
authboss.a.Mailer = mailer r.Mailer = mailer
a.sendRecoverEmail("", "") r.sendRecoverEmail("", "")
if !strings.Contains(logger.String(), "failed to send") { if !strings.Contains(logger.String(), "failed to send") {
t.Error("Expected logged to have msg:", "failed to send") t.Error("Expected logged to have msg:", "failed to send")
@@ -247,14 +263,16 @@ func TestRecover_sendRecoverMail_FailToSend(t *testing.T) {
} }
func TestRecover_sendRecoverEmail(t *testing.T) { func TestRecover_sendRecoverEmail(t *testing.T) {
a, _, _ := testSetup() t.Parallel()
r, _, _ := testSetup()
mailer := mocks.NewMockMailer() mailer := mocks.NewMockMailer()
authboss.a.EmailSubjectPrefix = "foo " r.EmailSubjectPrefix = "foo "
authboss.a.RootURL = "bar" r.RootURL = "bar"
authboss.a.Mailer = mailer r.Mailer = mailer
a.sendRecoverEmail("a@b.c", "abc=") r.sendRecoverEmail("a@b.c", "abc=")
if len(mailer.Last.To) != 1 { if len(mailer.Last.To) != 1 {
t.Error("Expected 1 to email") t.Error("Expected 1 to email")
} }
@@ -265,7 +283,7 @@ func TestRecover_sendRecoverEmail(t *testing.T) {
t.Error("Unexpected subject:", mailer.Last.Subject) t.Error("Unexpected subject:", mailer.Last.Subject)
} }
url := fmt.Sprintf("%s/recover/complete?token=abc=", authboss.a.RootURL) url := fmt.Sprintf("%s/recover/complete?token=abc=", r.RootURL)
if !strings.Contains(mailer.Last.HTMLBody, url) { if !strings.Contains(mailer.Last.HTMLBody, url) {
t.Error("Expected HTMLBody to contain url:", url) t.Error("Expected HTMLBody to contain url:", url)
} }
@@ -275,9 +293,11 @@ func TestRecover_sendRecoverEmail(t *testing.T) {
} }
func TestRecover_completeHandlerFunc_GET_VerifyFails(t *testing.T) { func TestRecover_completeHandlerFunc_GET_VerifyFails(t *testing.T) {
t.Parallel()
rec, storer, _ := testSetup() rec, storer, _ := testSetup()
ctx, w, r, _ := testRequest("GET", "token", testURLBase64Token) ctx, w, r, _ := testRequest(rec.Authboss, "GET", "token", testURLBase64Token)
err := rec.completeHandlerFunc(ctx, w, r) err := rec.completeHandlerFunc(ctx, w, r)
rerr, ok := err.(authboss.ErrAndRedirect) rerr, ok := err.(authboss.ErrAndRedirect)
@@ -291,7 +311,7 @@ func TestRecover_completeHandlerFunc_GET_VerifyFails(t *testing.T) {
var zeroTime time.Time var zeroTime time.Time
storer.Users["john"] = authboss.Attributes{StoreRecoverToken: testStdBase64Token, StoreRecoverTokenExpiry: zeroTime} storer.Users["john"] = authboss.Attributes{StoreRecoverToken: testStdBase64Token, StoreRecoverTokenExpiry: zeroTime}
ctx, w, r, _ = testRequest("GET", "token", testURLBase64Token) ctx, w, r, _ = testRequest(rec.Authboss, "GET", "token", testURLBase64Token)
err = rec.completeHandlerFunc(ctx, w, r) err = rec.completeHandlerFunc(ctx, w, r)
rerr, ok = err.(authboss.ErrAndRedirect) rerr, ok = err.(authboss.ErrAndRedirect)
@@ -307,11 +327,13 @@ func TestRecover_completeHandlerFunc_GET_VerifyFails(t *testing.T) {
} }
func TestRecover_completeHandlerFunc_GET(t *testing.T) { func TestRecover_completeHandlerFunc_GET(t *testing.T) {
t.Parallel()
rec, storer, _ := testSetup() rec, storer, _ := testSetup()
storer.Users["john"] = authboss.Attributes{StoreRecoverToken: testStdBase64Token, StoreRecoverTokenExpiry: time.Now().Add(1 * time.Hour)} storer.Users["john"] = authboss.Attributes{StoreRecoverToken: testStdBase64Token, StoreRecoverTokenExpiry: time.Now().Add(1 * time.Hour)}
ctx, w, r, _ := testRequest("GET", "token", testURLBase64Token) ctx, w, r, _ := testRequest(rec.Authboss, "GET", "token", testURLBase64Token)
if err := rec.completeHandlerFunc(ctx, w, r); err != nil { if err := rec.completeHandlerFunc(ctx, w, r); err != nil {
t.Error("Unexpected error:", err) t.Error("Unexpected error:", err)
@@ -337,8 +359,10 @@ func TestRecover_completeHandlerFunc_GET(t *testing.T) {
} }
func TestRecover_completeHandlerFunc_POST_TokenMissing(t *testing.T) { func TestRecover_completeHandlerFunc_POST_TokenMissing(t *testing.T) {
t.Parallel()
rec, _, _ := testSetup() rec, _, _ := testSetup()
ctx, w, r, _ := testRequest("POST") ctx, w, r, _ := testRequest(rec.Authboss, "POST")
err := rec.completeHandlerFunc(ctx, w, r) err := rec.completeHandlerFunc(ctx, w, r)
if err.Error() != "Failed to retrieve client attribute: token" { if err.Error() != "Failed to retrieve client attribute: token" {
@@ -348,8 +372,10 @@ func TestRecover_completeHandlerFunc_POST_TokenMissing(t *testing.T) {
} }
func TestRecover_completeHandlerFunc_POST_ValidationFails(t *testing.T) { func TestRecover_completeHandlerFunc_POST_ValidationFails(t *testing.T) {
t.Parallel()
rec, _, _ := testSetup() rec, _, _ := testSetup()
ctx, w, r, _ := testRequest("POST", "token", testURLBase64Token) ctx, w, r, _ := testRequest(rec.Authboss, "POST", "token", testURLBase64Token)
if err := rec.completeHandlerFunc(ctx, w, r); err != nil { if err := rec.completeHandlerFunc(ctx, w, r); err != nil {
t.Error("Unexpected error:", err) t.Error("Unexpected error:", err)
@@ -365,8 +391,10 @@ func TestRecover_completeHandlerFunc_POST_ValidationFails(t *testing.T) {
} }
func TestRecover_completeHandlerFunc_POST_VerificationFails(t *testing.T) { func TestRecover_completeHandlerFunc_POST_VerificationFails(t *testing.T) {
t.Parallel()
rec, _, _ := testSetup() rec, _, _ := testSetup()
ctx, w, r, _ := testRequest("POST", "token", testURLBase64Token, authboss.StorePassword, "abcd", "confirm_"+authboss.StorePassword, "abcd") ctx, w, r, _ := testRequest(rec.Authboss, "POST", "token", testURLBase64Token, authboss.StorePassword, "abcd", "confirm_"+authboss.StorePassword, "abcd")
if err := rec.completeHandlerFunc(ctx, w, r); err == nil { if err := rec.completeHandlerFunc(ctx, w, r); err == nil {
log.Println(w.Body.String()) log.Println(w.Body.String())
@@ -375,19 +403,21 @@ func TestRecover_completeHandlerFunc_POST_VerificationFails(t *testing.T) {
} }
func TestRecover_completeHandlerFunc_POST(t *testing.T) { func TestRecover_completeHandlerFunc_POST(t *testing.T) {
t.Parallel()
rec, storer, _ := testSetup() rec, storer, _ := testSetup()
storer.Users["john"] = authboss.Attributes{authboss.a.PrimaryID: "john", StoreRecoverToken: testStdBase64Token, StoreRecoverTokenExpiry: time.Now().Add(1 * time.Hour), authboss.StorePassword: "asdf"} storer.Users["john"] = authboss.Attributes{rec.PrimaryID: "john", StoreRecoverToken: testStdBase64Token, StoreRecoverTokenExpiry: time.Now().Add(1 * time.Hour), authboss.StorePassword: "asdf"}
cbCalled := false cbCalled := false
authboss.a.Callbacks = authboss.NewCallbacks() rec.Callbacks = authboss.NewCallbacks()
authboss.a.Callbacks.After(authboss.EventPasswordReset, func(_ *authboss.Context) error { rec.Callbacks.After(authboss.EventPasswordReset, func(_ *authboss.Context) error {
cbCalled = true cbCalled = true
return nil return nil
}) })
ctx, w, r, sessionStorer := testRequest("POST", "token", testURLBase64Token, authboss.StorePassword, "abcd", "confirm_"+authboss.StorePassword, "abcd") ctx, w, r, sessionStorer := testRequest(rec.Authboss, "POST", "token", testURLBase64Token, authboss.StorePassword, "abcd", "confirm_"+authboss.StorePassword, "abcd")
if err := rec.completeHandlerFunc(ctx, w, r); err != nil { if err := rec.completeHandlerFunc(ctx, w, r); err != nil {
t.Error("Unexpected error:", err) t.Error("Unexpected error:", err)
@@ -421,12 +451,14 @@ func TestRecover_completeHandlerFunc_POST(t *testing.T) {
} }
loc := w.Header().Get("Location") loc := w.Header().Get("Location")
if loc != authboss.a.AuthLogoutOKPath { if loc != rec.AuthLogoutOKPath {
t.Error("Unexpected location:", loc) t.Error("Unexpected location:", loc)
} }
} }
func Test_verifyToken_MissingToken(t *testing.T) { func Test_verifyToken_MissingToken(t *testing.T) {
t.Parallel()
testSetup() testSetup()
ctx := &authboss.Context{} ctx := &authboss.Context{}
@@ -436,38 +468,44 @@ func Test_verifyToken_MissingToken(t *testing.T) {
} }
func Test_verifyToken_InvalidToken(t *testing.T) { func Test_verifyToken_InvalidToken(t *testing.T) {
_, storer, _ := testSetup() t.Parallel()
rec, storer, _ := testSetup()
storer.Users["a"] = authboss.Attributes{ storer.Users["a"] = authboss.Attributes{
StoreRecoverToken: testStdBase64Token, StoreRecoverToken: testStdBase64Token,
} }
ctx := mocks.MockRequestContext("token", "asdf") ctx := mocks.MockRequestContext(rec.Authboss, "token", "asdf")
if _, err := verifyToken(ctx); err != authboss.ErrUserNotFound { if _, err := verifyToken(ctx); err != authboss.ErrUserNotFound {
t.Error("Unexpected error:", err) t.Error("Unexpected error:", err)
} }
} }
func Test_verifyToken_ExpiredToken(t *testing.T) { func Test_verifyToken_ExpiredToken(t *testing.T) {
_, storer, _ := testSetup() t.Parallel()
rec, storer, _ := testSetup()
storer.Users["a"] = authboss.Attributes{ storer.Users["a"] = authboss.Attributes{
StoreRecoverToken: testStdBase64Token, StoreRecoverToken: testStdBase64Token,
StoreRecoverTokenExpiry: time.Now().Add(time.Duration(-24) * time.Hour), StoreRecoverTokenExpiry: time.Now().Add(time.Duration(-24) * time.Hour),
} }
ctx := mocks.MockRequestContext("token", testURLBase64Token) ctx := mocks.MockRequestContext(rec.Authboss, "token", testURLBase64Token)
if _, err := verifyToken(ctx); err != errRecoveryTokenExpired { if _, err := verifyToken(ctx); err != errRecoveryTokenExpired {
t.Error("Unexpected error:", err) t.Error("Unexpected error:", err)
} }
} }
func Test_verifyToken(t *testing.T) { func Test_verifyToken(t *testing.T) {
_, storer, _ := testSetup() t.Parallel()
rec, storer, _ := testSetup()
storer.Users["a"] = authboss.Attributes{ storer.Users["a"] = authboss.Attributes{
StoreRecoverToken: testStdBase64Token, StoreRecoverToken: testStdBase64Token,
StoreRecoverTokenExpiry: time.Now().Add(time.Duration(24) * time.Hour), StoreRecoverTokenExpiry: time.Now().Add(time.Duration(24) * time.Hour),
} }
ctx := mocks.MockRequestContext("token", testURLBase64Token) ctx := mocks.MockRequestContext(rec.Authboss, "token", testURLBase64Token)
attrs, err := verifyToken(ctx) attrs, err := verifyToken(ctx)
if err != nil { if err != nil {
t.Error("Unexpected error:", err) t.Error("Unexpected error:", err)

View File

@@ -28,20 +28,23 @@ func init() {
// Register module. // Register module.
type Register struct { type Register struct {
*authboss.Authboss
templates response.Templates templates response.Templates
} }
// Initialize the module. // Initialize the module.
func (r *Register) Initialize() (err error) { func (r *Register) Initialize(ab *authboss.Authboss) (err error) {
if authboss.a.Storer == nil { r.Authboss = ab
if r.Storer == nil {
return errors.New("register: Need a RegisterStorer") return errors.New("register: Need a RegisterStorer")
} }
if _, ok := authboss.a.Storer.(RegisterStorer); !ok { if _, ok := r.Storer.(RegisterStorer); !ok {
return errors.New("register: RegisterStorer required for register functionality") return errors.New("register: RegisterStorer required for register functionality")
} }
if r.templates, err = response.LoadTemplates(authboss.a.Layout, authboss.a.ViewsPath, tplRegister); err != nil { if r.templates, err = response.LoadTemplates(r.Authboss, r.Layout, r.ViewsPath, tplRegister); err != nil {
return err return err
} }
@@ -58,7 +61,7 @@ func (r *Register) Routes() authboss.RouteTable {
// Storage returns storage requirements. // Storage returns storage requirements.
func (r *Register) Storage() authboss.StorageOptions { func (r *Register) Storage() authboss.StorageOptions {
return authboss.StorageOptions{ return authboss.StorageOptions{
authboss.a.PrimaryID: authboss.String, r.PrimaryID: authboss.String,
authboss.StorePassword: authboss.String, authboss.StorePassword: authboss.String,
} }
} }
@@ -67,7 +70,7 @@ func (reg *Register) registerHandler(ctx *authboss.Context, w http.ResponseWrite
switch r.Method { switch r.Method {
case "GET": case "GET":
data := authboss.HTMLData{ data := authboss.HTMLData{
"primaryID": authboss.a.PrimaryID, "primaryID": reg.PrimaryID,
"primaryIDValue": "", "primaryIDValue": "",
} }
return reg.templates.Render(ctx, w, r, tplRegister, data) return reg.templates.Render(ctx, w, r, tplRegister, data)
@@ -78,15 +81,15 @@ func (reg *Register) registerHandler(ctx *authboss.Context, w http.ResponseWrite
} }
func (reg *Register) registerPostHandler(ctx *authboss.Context, w http.ResponseWriter, r *http.Request) error { func (reg *Register) registerPostHandler(ctx *authboss.Context, w http.ResponseWriter, r *http.Request) error {
key, _ := ctx.FirstPostFormValue(authboss.a.PrimaryID) key, _ := ctx.FirstPostFormValue(reg.PrimaryID)
password, _ := ctx.FirstPostFormValue(authboss.StorePassword) password, _ := ctx.FirstPostFormValue(authboss.StorePassword)
policies := authboss.FilterValidators(authboss.a.Policies, authboss.a.PrimaryID, authboss.StorePassword) policies := authboss.FilterValidators(reg.Policies, reg.PrimaryID, authboss.StorePassword)
validationErrs := ctx.Validate(policies, authboss.a.ConfirmFields...) validationErrs := ctx.Validate(policies, reg.ConfirmFields...)
if len(validationErrs) != 0 { if len(validationErrs) != 0 {
data := authboss.HTMLData{ data := authboss.HTMLData{
"primaryID": authboss.a.PrimaryID, "primaryID": reg.PrimaryID,
"primaryIDValue": key, "primaryIDValue": key,
"errs": validationErrs.Map(), "errs": validationErrs.Map(),
} }
@@ -99,30 +102,30 @@ func (reg *Register) registerPostHandler(ctx *authboss.Context, w http.ResponseW
return err return err
} }
pass, err := bcrypt.GenerateFromPassword([]byte(password), authboss.a.BCryptCost) pass, err := bcrypt.GenerateFromPassword([]byte(password), reg.BCryptCost)
if err != nil { if err != nil {
return err return err
} }
attr[authboss.a.PrimaryID] = key attr[reg.PrimaryID] = key
attr[authboss.StorePassword] = string(pass) attr[authboss.StorePassword] = string(pass)
ctx.User = attr ctx.User = attr
if err := authboss.a.Storer.(RegisterStorer).Create(key, attr); err != nil { if err := reg.Storer.(RegisterStorer).Create(key, attr); err != nil {
return err return err
} }
if err := authboss.a.Callbacks.FireAfter(authboss.EventRegister, ctx); err != nil { if err := reg.Callbacks.FireAfter(authboss.EventRegister, ctx); err != nil {
return err return err
} }
if authboss.IsLoaded("confirm") { if reg.IsLoaded("confirm") {
response.Redirect(ctx, w, r, authboss.a.RegisterOKPath, "Account successfully created, please verify your e-mail address.", "", true) response.Redirect(ctx, w, r, reg.RegisterOKPath, "Account successfully created, please verify your e-mail address.", "", true)
return nil return nil
} }
ctx.SessionStorer.Put(authboss.SessionKey, key) ctx.SessionStorer.Put(authboss.SessionKey, key)
response.Redirect(ctx, w, r, authboss.a.RegisterOKPath, "Account successfully created, you are now logged in.", "", true) response.Redirect(ctx, w, r, reg.RegisterOKPath, "Account successfully created, you are now logged in.", "", true)
return nil return nil
} }

View File

@@ -14,18 +14,18 @@ import (
) )
func setup() *Register { func setup() *Register {
authboss.Cfg = authboss.NewConfig() ab := authboss.New()
authboss.a.RegisterOKPath = "/regsuccess" ab.RegisterOKPath = "/regsuccess"
authboss.a.Layout = template.Must(template.New("").Parse(`{{template "authboss" .}}`)) ab.Layout = template.Must(template.New("").Parse(`{{template "authboss" .}}`))
authboss.a.XSRFName = "xsrf" ab.XSRFName = "xsrf"
authboss.a.XSRFMaker = func(_ http.ResponseWriter, _ *http.Request) string { ab.XSRFMaker = func(_ http.ResponseWriter, _ *http.Request) string {
return "xsrfvalue" return "xsrfvalue"
} }
authboss.a.ConfirmFields = []string{"password", "confirm_password"} ab.ConfirmFields = []string{"password", "confirm_password"}
authboss.a.Storer = mocks.NewMockStorer() ab.Storer = mocks.NewMockStorer()
reg := Register{} reg := Register{}
if err := reg.Initialize(); err != nil { if err := reg.Initialize(ab); err != nil {
panic(err) panic(err)
} }
@@ -33,11 +33,10 @@ func setup() *Register {
} }
func TestRegister(t *testing.T) { func TestRegister(t *testing.T) {
authboss.Cfg = authboss.NewConfig() ab := authboss.New()
authboss.a.Storer = mocks.NewMockStorer() ab.Storer = mocks.NewMockStorer()
r := Register{} r := Register{}
if err := r.Initialize(ab); err != nil {
if err := r.Initialize(); err != nil {
t.Error(err) t.Error(err)
} }
@@ -46,7 +45,7 @@ func TestRegister(t *testing.T) {
} }
sto := r.Storage() sto := r.Storage()
if sto[authboss.a.PrimaryID] != authboss.String { if sto[r.PrimaryID] != authboss.String {
t.Error("Wanted primary ID to be a string.") t.Error("Wanted primary ID to be a string.")
} }
if sto[authboss.StorePassword] != authboss.String { if sto[authboss.StorePassword] != authboss.String {
@@ -59,7 +58,7 @@ func TestRegisterGet(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
r, _ := http.NewRequest("GET", "/register", nil) r, _ := http.NewRequest("GET", "/register", nil)
ctx, _ := authboss.ContextFromRequest(r) ctx, _ := reg.ContextFromRequest(r)
ctx.SessionStorer = mocks.NewMockClientStorer() ctx.SessionStorer = mocks.NewMockClientStorer()
if err := reg.registerHandler(ctx, w, r); err != nil { if err := reg.registerHandler(ctx, w, r); err != nil {
@@ -76,7 +75,7 @@ func TestRegisterGet(t *testing.T) {
if str := w.Body.String(); !strings.Contains(str, "<form") { if str := w.Body.String(); !strings.Contains(str, "<form") {
t.Error("It should have rendered a nice form:", str) t.Error("It should have rendered a nice form:", str)
} else if !strings.Contains(str, `name="`+authboss.a.PrimaryID) { } else if !strings.Contains(str, `name="`+reg.PrimaryID) {
t.Error("Form should contain the primary ID:", str) t.Error("Form should contain the primary ID:", str)
} }
} }
@@ -88,13 +87,13 @@ func TestRegisterPostValidationErrs(t *testing.T) {
vals := url.Values{} vals := url.Values{}
email := "email@address.com" email := "email@address.com"
vals.Set(authboss.a.PrimaryID, email) vals.Set(reg.PrimaryID, email)
vals.Set(authboss.StorePassword, "pass") vals.Set(authboss.StorePassword, "pass")
vals.Set(authboss.ConfirmPrefix+authboss.StorePassword, "pass2") vals.Set(authboss.ConfirmPrefix+authboss.StorePassword, "pass2")
r, _ := http.NewRequest("POST", "/register", bytes.NewBufferString(vals.Encode())) r, _ := http.NewRequest("POST", "/register", bytes.NewBufferString(vals.Encode()))
r.Header.Set("Content-Type", "application/x-www-form-urlencoded") r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
ctx, _ := authboss.ContextFromRequest(r) ctx, _ := reg.ContextFromRequest(r)
ctx.SessionStorer = mocks.NewMockClientStorer() ctx.SessionStorer = mocks.NewMockClientStorer()
if err := reg.registerHandler(ctx, w, r); err != nil { if err := reg.registerHandler(ctx, w, r); err != nil {
@@ -113,7 +112,7 @@ func TestRegisterPostValidationErrs(t *testing.T) {
t.Error("Confirm password should have an error:", str) t.Error("Confirm password should have an error:", str)
} }
if _, err := authboss.a.Storer.Get(email); err != authboss.ErrUserNotFound { if _, err := reg.Storer.Get(email); err != authboss.ErrUserNotFound {
t.Error("The user should not have been saved.") t.Error("The user should not have been saved.")
} }
} }
@@ -125,13 +124,13 @@ func TestRegisterPostSuccess(t *testing.T) {
vals := url.Values{} vals := url.Values{}
email := "email@address.com" email := "email@address.com"
vals.Set(authboss.a.PrimaryID, email) vals.Set(reg.PrimaryID, email)
vals.Set(authboss.StorePassword, "pass") vals.Set(authboss.StorePassword, "pass")
vals.Set(authboss.ConfirmPrefix+authboss.StorePassword, "pass") vals.Set(authboss.ConfirmPrefix+authboss.StorePassword, "pass")
r, _ := http.NewRequest("POST", "/register", bytes.NewBufferString(vals.Encode())) r, _ := http.NewRequest("POST", "/register", bytes.NewBufferString(vals.Encode()))
r.Header.Set("Content-Type", "application/x-www-form-urlencoded") r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
ctx, _ := authboss.ContextFromRequest(r) ctx, _ := reg.ContextFromRequest(r)
ctx.SessionStorer = mocks.NewMockClientStorer() ctx.SessionStorer = mocks.NewMockClientStorer()
if err := reg.registerHandler(ctx, w, r); err != nil { if err := reg.registerHandler(ctx, w, r); err != nil {
@@ -142,17 +141,17 @@ func TestRegisterPostSuccess(t *testing.T) {
t.Error("It should have written a redirect:", w.Code) t.Error("It should have written a redirect:", w.Code)
} }
if loc := w.Header().Get("Location"); loc != authboss.a.RegisterOKPath { if loc := w.Header().Get("Location"); loc != reg.RegisterOKPath {
t.Error("Redirected to the wrong location", loc) t.Error("Redirected to the wrong location", loc)
} }
user, err := authboss.a.Storer.Get(email) user, err := reg.Storer.Get(email)
if err == authboss.ErrUserNotFound { if err == authboss.ErrUserNotFound {
t.Error("The user have been saved.") t.Error("The user have been saved.")
} }
attrs := authboss.Unbind(user) attrs := authboss.Unbind(user)
if e, err := attrs.StringErr(authboss.a.PrimaryID); err != nil { if e, err := attrs.StringErr(reg.PrimaryID); err != nil {
t.Error(err) t.Error(err)
} else if e != email { } else if e != email {
t.Errorf("Email was not set properly, want: %s, got: %s", email, e) t.Errorf("Email was not set properly, want: %s, got: %s", email, e)

View File

@@ -43,24 +43,28 @@ func init() {
} }
// Remember module // Remember module
type Remember struct{} type Remember struct {
*authboss.Authboss
}
// Initialize module // Initialize module
func (r *Remember) Initialize() error { func (r *Remember) Initialize(ab *authboss.Authboss) error {
if authboss.a.Storer == nil && authboss.a.OAuth2Storer == nil { r.Authboss = ab
if r.Storer == nil && r.OAuth2Storer == nil {
return errors.New("remember: Need a RememberStorer") return errors.New("remember: Need a RememberStorer")
} }
if _, ok := authboss.a.Storer.(RememberStorer); !ok { if _, ok := r.Storer.(RememberStorer); !ok {
if _, ok := authboss.a.OAuth2Storer.(RememberStorer); !ok { if _, ok := r.OAuth2Storer.(RememberStorer); !ok {
return errors.New("remember: RememberStorer required for remember functionality") return errors.New("remember: RememberStorer required for remember functionality")
} }
} }
authboss.a.Callbacks.Before(authboss.EventGetUserSession, r.auth) r.Callbacks.Before(authboss.EventGetUserSession, r.auth)
authboss.a.Callbacks.After(authboss.EventAuth, r.afterAuth) r.Callbacks.After(authboss.EventAuth, r.afterAuth)
authboss.a.Callbacks.After(authboss.EventOAuth, r.afterOAuth) r.Callbacks.After(authboss.EventOAuth, r.afterOAuth)
authboss.a.Callbacks.After(authboss.EventPasswordReset, r.afterPassword) r.Callbacks.After(authboss.EventPasswordReset, r.afterPassword)
return nil return nil
} }
@@ -73,7 +77,7 @@ func (r *Remember) Routes() authboss.RouteTable {
// Storage requirements // Storage requirements
func (r *Remember) Storage() authboss.StorageOptions { func (r *Remember) Storage() authboss.StorageOptions {
return authboss.StorageOptions{ return authboss.StorageOptions{
authboss.a.PrimaryID: authboss.String, r.PrimaryID: authboss.String,
} }
} }
@@ -87,7 +91,7 @@ func (r *Remember) afterAuth(ctx *authboss.Context) error {
return errUserMissing return errUserMissing
} }
key, err := ctx.User.StringErr(authboss.a.PrimaryID) key, err := ctx.User.StringErr(r.PrimaryID)
if err != nil { if err != nil {
return err return err
} }
@@ -146,7 +150,7 @@ func (r *Remember) afterPassword(ctx *authboss.Context) error {
return nil return nil
} }
id, ok := ctx.User.String(authboss.a.PrimaryID) id, ok := ctx.User.String(r.PrimaryID)
if !ok { if !ok {
return nil return nil
} }
@@ -154,8 +158,8 @@ func (r *Remember) afterPassword(ctx *authboss.Context) error {
ctx.CookieStorer.Del(authboss.CookieRemember) ctx.CookieStorer.Del(authboss.CookieRemember)
var storer RememberStorer var storer RememberStorer
if storer, ok = authboss.a.Storer.(RememberStorer); !ok { if storer, ok = r.Storer.(RememberStorer); !ok {
if storer, ok = authboss.a.OAuth2Storer.(RememberStorer); !ok { if storer, ok = r.OAuth2Storer.(RememberStorer); !ok {
return nil return nil
} }
} }
@@ -181,8 +185,8 @@ func (r *Remember) new(cstorer authboss.ClientStorer, storageKey string) (string
var storer RememberStorer var storer RememberStorer
var ok bool var ok bool
if storer, ok = authboss.a.Storer.(RememberStorer); !ok { if storer, ok = r.Storer.(RememberStorer); !ok {
storer, ok = authboss.a.OAuth2Storer.(RememberStorer) storer, ok = r.OAuth2Storer.(RememberStorer)
} }
// Save the token in the DB // Save the token in the DB
@@ -226,8 +230,8 @@ func (r *Remember) auth(ctx *authboss.Context) (authboss.Interrupt, error) {
sum := md5.Sum(token) sum := md5.Sum(token)
var storer RememberStorer var storer RememberStorer
if storer, ok = authboss.a.Storer.(RememberStorer); !ok { if storer, ok = r.Storer.(RememberStorer); !ok {
storer, ok = authboss.a.OAuth2Storer.(RememberStorer) storer, ok = r.OAuth2Storer.(RememberStorer)
} }
err = storer.UseToken(givenKey, base64.StdEncoding.EncodeToString(sum[:])) err = storer.UseToken(givenKey, base64.StdEncoding.EncodeToString(sum[:]))

View File

@@ -11,32 +11,34 @@ import (
) )
func TestInitialize(t *testing.T) { func TestInitialize(t *testing.T) {
authboss.NewConfig() t.Parallel()
ab := authboss.New()
r := &Remember{} r := &Remember{}
err := r.Initialize() err := r.Initialize(ab)
if err == nil { if err == nil {
t.Error("Expected error about token storers.") t.Error("Expected error about token storers.")
} }
authboss.a.Storer = mocks.MockFailStorer{} ab.Storer = mocks.MockFailStorer{}
err = r.Initialize() err = r.Initialize(ab)
if err == nil { if err == nil {
t.Error("Expected error about token storers.") t.Error("Expected error about token storers.")
} }
authboss.a.Storer = mocks.NewMockStorer() ab.Storer = mocks.NewMockStorer()
err = r.Initialize() err = r.Initialize(ab)
if err != nil { if err != nil {
t.Error("Unexpected error:", err) t.Error("Unexpected error:", err)
} }
} }
func TestAfterAuth(t *testing.T) { func TestAfterAuth(t *testing.T) {
r := Remember{} t.Parallel()
authboss.NewConfig()
r := Remember{authboss.New()}
storer := mocks.NewMockStorer() storer := mocks.NewMockStorer()
authboss.a.Storer = storer r.Storer = storer
cookies := mocks.NewMockClientStorer() cookies := mocks.NewMockClientStorer()
session := mocks.NewMockClientStorer() session := mocks.NewMockClientStorer()
@@ -47,14 +49,14 @@ func TestAfterAuth(t *testing.T) {
} }
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
ctx, err := authboss.ContextFromRequest(req) ctx, err := r.ContextFromRequest(req)
if err != nil { if err != nil {
t.Error("Unexpected error:", err) t.Error("Unexpected error:", err)
} }
ctx.SessionStorer = session ctx.SessionStorer = session
ctx.CookieStorer = cookies ctx.CookieStorer = cookies
ctx.User = authboss.Attributes{authboss.a.PrimaryID: "test@email.com"} ctx.User = authboss.Attributes{r.PrimaryID: "test@email.com"}
if err := r.afterAuth(ctx); err != nil { if err := r.afterAuth(ctx); err != nil {
t.Error(err) t.Error(err)
@@ -66,10 +68,11 @@ func TestAfterAuth(t *testing.T) {
} }
func TestAfterOAuth(t *testing.T) { func TestAfterOAuth(t *testing.T) {
r := Remember{} t.Parallel()
authboss.NewConfig()
r := Remember{authboss.New()}
storer := mocks.NewMockStorer() storer := mocks.NewMockStorer()
authboss.a.Storer = storer r.Storer = storer
cookies := mocks.NewMockClientStorer() cookies := mocks.NewMockClientStorer()
session := mocks.NewMockClientStorer(authboss.SessionOAuth2Params, `{"rm":"true"}`) session := mocks.NewMockClientStorer(authboss.SessionOAuth2Params, `{"rm":"true"}`)
@@ -80,7 +83,7 @@ func TestAfterOAuth(t *testing.T) {
t.Error("Unexpected Error:", err) t.Error("Unexpected Error:", err)
} }
ctx, err := authboss.ContextFromRequest(req) ctx, err := r.ContextFromRequest(req)
if err != nil { if err != nil {
t.Error("Unexpected error:", err) t.Error("Unexpected error:", err)
} }
@@ -102,20 +105,21 @@ func TestAfterOAuth(t *testing.T) {
} }
func TestAfterPasswordReset(t *testing.T) { func TestAfterPasswordReset(t *testing.T) {
r := Remember{} t.Parallel()
authboss.NewConfig()
r := Remember{authboss.New()}
id := "test@email.com" id := "test@email.com"
storer := mocks.NewMockStorer() storer := mocks.NewMockStorer()
authboss.a.Storer = storer r.Storer = storer
session := mocks.NewMockClientStorer() session := mocks.NewMockClientStorer()
cookies := mocks.NewMockClientStorer() cookies := mocks.NewMockClientStorer()
storer.Tokens[id] = []string{"one", "two"} storer.Tokens[id] = []string{"one", "two"}
cookies.Values[authboss.CookieRemember] = "token" cookies.Values[authboss.CookieRemember] = "token"
ctx := authboss.NewContext() ctx := r.NewContext()
ctx.User = authboss.Attributes{authboss.a.PrimaryID: id} ctx.User = authboss.Attributes{r.PrimaryID: id}
ctx.SessionStorer = session ctx.SessionStorer = session
ctx.CookieStorer = cookies ctx.CookieStorer = cookies
@@ -133,10 +137,11 @@ func TestAfterPasswordReset(t *testing.T) {
} }
func TestNew(t *testing.T) { func TestNew(t *testing.T) {
r := &Remember{} t.Parallel()
authboss.NewConfig()
r := &Remember{authboss.New()}
storer := mocks.NewMockStorer() storer := mocks.NewMockStorer()
authboss.a.Storer = storer r.Storer = storer
cookies := mocks.NewMockClientStorer() cookies := mocks.NewMockClientStorer()
key := "tester" key := "tester"
@@ -162,14 +167,15 @@ func TestNew(t *testing.T) {
} }
func TestAuth(t *testing.T) { func TestAuth(t *testing.T) {
r := &Remember{} t.Parallel()
authboss.NewConfig()
r := &Remember{authboss.New()}
storer := mocks.NewMockStorer() storer := mocks.NewMockStorer()
authboss.a.Storer = storer r.Storer = storer
cookies := mocks.NewMockClientStorer() cookies := mocks.NewMockClientStorer()
session := mocks.NewMockClientStorer() session := mocks.NewMockClientStorer()
ctx := authboss.NewContext() ctx := r.NewContext()
ctx.CookieStorer = cookies ctx.CookieStorer = cookies
ctx.SessionStorer = session ctx.SessionStorer = session