1
0
mirror of https://github.com/volatiletech/authboss.git synced 2025-02-03 13:21:22 +02:00

Stop reliance on global scope.

- This change was necessary because multi-tenancy sites could not use
  authboss properly.
This commit is contained in:
Aaron 2015-03-31 12:34:03 -07:00
parent bd0d3c5f68
commit f12f10fa43
40 changed files with 556 additions and 495 deletions

View File

@ -29,19 +29,19 @@ type Auth struct {
// Initialize module
func (a *Auth) Initialize() (err error) {
if authboss.Cfg.Storer == nil {
if authboss.a.Storer == nil {
return errors.New("auth: Need a Storer")
}
if len(authboss.Cfg.XSRFName) == 0 {
if len(authboss.a.XSRFName) == 0 {
return errors.New("auth: XSRFName must be set")
}
if authboss.Cfg.XSRFMaker == nil {
if authboss.a.XSRFMaker == nil {
return errors.New("auth: XSRFMaker must be defined")
}
a.templates, err = response.LoadTemplates(authboss.Cfg.Layout, authboss.Cfg.ViewsPath, tplLogin)
a.templates, err = response.LoadTemplates(authboss.a.Layout, authboss.a.ViewsPath, tplLogin)
if err != nil {
return err
}
@ -60,7 +60,7 @@ func (a *Auth) Routes() authboss.RouteTable {
// Storage requirements
func (a *Auth) Storage() authboss.StorageOptions {
return authboss.StorageOptions{
authboss.Cfg.PrimaryID: authboss.String,
authboss.a.PrimaryID: authboss.String,
authboss.StorePassword: authboss.String,
}
}
@ -70,8 +70,8 @@ func (a *Auth) loginHandlerFunc(ctx *authboss.Context, w http.ResponseWriter, r
case methodGET:
if _, ok := ctx.SessionStorer.Get(authboss.SessionKey); ok {
if halfAuthed, ok := ctx.SessionStorer.Get(authboss.SessionHalfAuthKey); !ok || halfAuthed == "false" {
//http.Redirect(w, r, authboss.Cfg.AuthLoginOKPath, http.StatusFound, true)
response.Redirect(ctx, w, r, authboss.Cfg.AuthLoginOKPath, "", "", true)
//http.Redirect(w, r, authboss.a.AuthLoginOKPath, http.StatusFound, true)
response.Redirect(ctx, w, r, authboss.a.AuthLoginOKPath, "", "", true)
return nil
}
}
@ -79,23 +79,23 @@ func (a *Auth) loginHandlerFunc(ctx *authboss.Context, w http.ResponseWriter, r
data := authboss.NewHTMLData(
"showRemember", authboss.IsLoaded("remember"),
"showRecover", authboss.IsLoaded("recover"),
"primaryID", authboss.Cfg.PrimaryID,
"primaryID", authboss.a.PrimaryID,
"primaryIDValue", "",
)
return a.templates.Render(ctx, w, r, tplLogin, data)
case methodPOST:
key, _ := ctx.FirstPostFormValue(authboss.Cfg.PrimaryID)
key, _ := ctx.FirstPostFormValue(authboss.a.PrimaryID)
password, _ := ctx.FirstPostFormValue("password")
errData := authboss.NewHTMLData(
"error", fmt.Sprintf("invalid %s and/or password", authboss.Cfg.PrimaryID),
"primaryID", authboss.Cfg.PrimaryID,
"error", fmt.Sprintf("invalid %s and/or password", authboss.a.PrimaryID),
"primaryID", authboss.a.PrimaryID,
"primaryIDValue", key,
"showRemember", authboss.IsLoaded("remember"),
"showRecover", authboss.IsLoaded("recover"),
)
policies := authboss.FilterValidators(authboss.Cfg.Policies, authboss.Cfg.PrimaryID, authboss.StorePassword)
policies := authboss.FilterValidators(authboss.a.Policies, authboss.a.PrimaryID, authboss.StorePassword)
if validationErrs := ctx.Validate(policies); len(validationErrs) > 0 {
return a.templates.Render(ctx, w, r, tplLogin, errData)
}
@ -104,7 +104,7 @@ func (a *Auth) loginHandlerFunc(ctx *authboss.Context, w http.ResponseWriter, r
return a.templates.Render(ctx, w, r, tplLogin, errData)
}
interrupted, err := authboss.Cfg.Callbacks.FireBefore(authboss.EventAuth, ctx)
interrupted, err := authboss.a.Callbacks.FireBefore(authboss.EventAuth, ctx)
if err != nil {
return err
} else if interrupted != authboss.InterruptNone {
@ -115,17 +115,17 @@ func (a *Auth) loginHandlerFunc(ctx *authboss.Context, w http.ResponseWriter, r
case authboss.InterruptAccountNotConfirmed:
reason = "Your account has not been confirmed."
}
response.Redirect(ctx, w, r, authboss.Cfg.AuthLoginFailPath, "", reason, false)
response.Redirect(ctx, w, r, authboss.a.AuthLoginFailPath, "", reason, false)
return nil
}
ctx.SessionStorer.Put(authboss.SessionKey, key)
ctx.SessionStorer.Del(authboss.SessionHalfAuthKey)
if err := authboss.Cfg.Callbacks.FireAfter(authboss.EventAuth, ctx); err != nil {
if err := authboss.a.Callbacks.FireAfter(authboss.EventAuth, ctx); err != nil {
return err
}
response.Redirect(ctx, w, r, authboss.Cfg.AuthLoginOKPath, "", "", true)
response.Redirect(ctx, w, r, authboss.a.AuthLoginOKPath, "", "", true)
default:
w.WriteHeader(http.StatusMethodNotAllowed)
}
@ -157,7 +157,7 @@ func (a *Auth) logoutHandlerFunc(ctx *authboss.Context, w http.ResponseWriter, r
ctx.CookieStorer.Del(authboss.CookieRemember)
ctx.SessionStorer.Del(authboss.SessionLastAction)
response.Redirect(ctx, w, r, authboss.Cfg.AuthLogoutOKPath, "You have logged out", "", true)
response.Redirect(ctx, w, r, authboss.a.AuthLogoutOKPath, "You have logged out", "", true)
default:
w.WriteHeader(http.StatusMethodNotAllowed)
}

View File

@ -17,14 +17,14 @@ func testSetup() (a *Auth, s *mocks.MockStorer) {
s = mocks.NewMockStorer()
authboss.Cfg = authboss.NewConfig()
authboss.Cfg.LogWriter = ioutil.Discard
authboss.Cfg.Layout = template.Must(template.New("").Parse(`{{template "authboss" .}}`))
authboss.Cfg.Storer = s
authboss.Cfg.XSRFName = "xsrf"
authboss.Cfg.XSRFMaker = func(_ http.ResponseWriter, _ *http.Request) string {
authboss.a.LogWriter = ioutil.Discard
authboss.a.Layout = template.Must(template.New("").Parse(`{{template "authboss" .}}`))
authboss.a.Storer = s
authboss.a.XSRFName = "xsrf"
authboss.a.XSRFMaker = func(_ http.ResponseWriter, _ *http.Request) string {
return "xsrfvalue"
}
authboss.Cfg.PrimaryID = authboss.StoreUsername
authboss.a.PrimaryID = authboss.StoreUsername
a = &Auth{}
if err := a.Initialize(); err != nil {
@ -51,8 +51,8 @@ func TestAuth(t *testing.T) {
a, _ := testSetup()
storage := a.Storage()
if storage[authboss.Cfg.PrimaryID] != authboss.String {
t.Error("Expected storage KV:", authboss.Cfg.PrimaryID, authboss.String)
if storage[authboss.a.PrimaryID] != authboss.String {
t.Error("Expected storage KV:", authboss.a.PrimaryID, authboss.String)
}
if storage[authboss.StorePassword] != authboss.String {
t.Error("Expected storage KV:", authboss.StorePassword, authboss.String)
@ -74,7 +74,7 @@ func TestAuth_loginHandlerFunc_GET_RedirectsWhenHalfAuthed(t *testing.T) {
sessionStore.Put(authboss.SessionKey, "a")
sessionStore.Put(authboss.SessionHalfAuthKey, "false")
authboss.Cfg.AuthLoginOKPath = "/dashboard"
authboss.a.AuthLoginOKPath = "/dashboard"
if err := a.loginHandlerFunc(ctx, w, r); err != nil {
t.Error("Unexpeced error:", err)
@ -85,7 +85,7 @@ func TestAuth_loginHandlerFunc_GET_RedirectsWhenHalfAuthed(t *testing.T) {
}
loc := w.Header().Get("Location")
if loc != authboss.Cfg.AuthLoginOKPath {
if loc != authboss.a.AuthLoginOKPath {
t.Error("Unexpected redirect:", loc)
}
}
@ -106,7 +106,7 @@ func TestAuth_loginHandlerFunc_GET(t *testing.T) {
if !strings.Contains(body, "<form") {
t.Error("Should have rendered a form")
}
if !strings.Contains(body, `name="`+authboss.Cfg.PrimaryID) {
if !strings.Contains(body, `name="`+authboss.a.PrimaryID) {
t.Error("Form should contain the primary ID field:", body)
}
if !strings.Contains(body, `name="password"`) {
@ -118,8 +118,8 @@ func TestAuth_loginHandlerFunc_POST_ReturnsErrorOnCallbackFailure(t *testing.T)
a, storer := testSetup()
storer.Users["john"] = authboss.Attributes{"password": "$2a$10$B7aydtqVF9V8RSNx3lCKB.l09jqLV/aMiVqQHajtL7sWGhCS9jlOu"}
authboss.Cfg.Callbacks = authboss.NewCallbacks()
authboss.Cfg.Callbacks.Before(authboss.EventAuth, func(_ *authboss.Context) (authboss.Interrupt, error) {
authboss.a.Callbacks = authboss.NewCallbacks()
authboss.a.Callbacks.Before(authboss.EventAuth, func(_ *authboss.Context) (authboss.Interrupt, error) {
return authboss.InterruptNone, errors.New("explode")
})
@ -134,8 +134,8 @@ func TestAuth_loginHandlerFunc_POST_RedirectsWhenInterrupted(t *testing.T) {
a, storer := testSetup()
storer.Users["john"] = authboss.Attributes{"password": "$2a$10$B7aydtqVF9V8RSNx3lCKB.l09jqLV/aMiVqQHajtL7sWGhCS9jlOu"}
authboss.Cfg.Callbacks = authboss.NewCallbacks()
authboss.Cfg.Callbacks.Before(authboss.EventAuth, func(_ *authboss.Context) (authboss.Interrupt, error) {
authboss.a.Callbacks = authboss.NewCallbacks()
authboss.a.Callbacks.Before(authboss.EventAuth, func(_ *authboss.Context) (authboss.Interrupt, error) {
return authboss.InterruptAccountLocked, nil
})
@ -150,7 +150,7 @@ func TestAuth_loginHandlerFunc_POST_RedirectsWhenInterrupted(t *testing.T) {
}
loc := w.Header().Get("Location")
if loc != authboss.Cfg.AuthLoginFailPath {
if loc != authboss.a.AuthLoginFailPath {
t.Error("Unexpeced location:", loc)
}
@ -159,8 +159,8 @@ func TestAuth_loginHandlerFunc_POST_RedirectsWhenInterrupted(t *testing.T) {
t.Error("Expected error flash message:", expectedMsg)
}
authboss.Cfg.Callbacks = authboss.NewCallbacks()
authboss.Cfg.Callbacks.Before(authboss.EventAuth, func(_ *authboss.Context) (authboss.Interrupt, error) {
authboss.a.Callbacks = authboss.NewCallbacks()
authboss.a.Callbacks.Before(authboss.EventAuth, func(_ *authboss.Context) (authboss.Interrupt, error) {
return authboss.InterruptAccountNotConfirmed, nil
})
@ -173,7 +173,7 @@ func TestAuth_loginHandlerFunc_POST_RedirectsWhenInterrupted(t *testing.T) {
}
loc = w.Header().Get("Location")
if loc != authboss.Cfg.AuthLoginFailPath {
if loc != authboss.a.AuthLoginFailPath {
t.Error("Unexpeced location:", loc)
}
@ -224,9 +224,9 @@ func TestAuth_loginHandlerFunc_POST(t *testing.T) {
ctx, w, r, _ := testRequest("POST", "username", "john", "password", "1234")
cb := mocks.NewMockAfterCallback()
authboss.Cfg.Callbacks = authboss.NewCallbacks()
authboss.Cfg.Callbacks.After(authboss.EventAuth, cb.Fn)
authboss.Cfg.AuthLoginOKPath = "/dashboard"
authboss.a.Callbacks = authboss.NewCallbacks()
authboss.a.Callbacks.After(authboss.EventAuth, cb.Fn)
authboss.a.AuthLoginOKPath = "/dashboard"
sessions := mocks.NewMockClientStorer()
ctx.SessionStorer = sessions
@ -244,7 +244,7 @@ func TestAuth_loginHandlerFunc_POST(t *testing.T) {
}
loc := w.Header().Get("Location")
if loc != authboss.Cfg.AuthLoginOKPath {
if loc != authboss.a.AuthLoginOKPath {
t.Error("Unexpeced location:", loc)
}
@ -283,7 +283,7 @@ func TestAuth_validateCredentials(t *testing.T) {
storer := mocks.NewMockStorer()
storer.GetErr = "Failed to load user"
authboss.Cfg.Storer = storer
authboss.a.Storer = storer
ctx := authboss.Context{}
@ -305,7 +305,7 @@ func TestAuth_validateCredentials(t *testing.T) {
func TestAuth_logoutHandlerFunc_GET(t *testing.T) {
a, _ := testSetup()
authboss.Cfg.AuthLogoutOKPath = "/dashboard"
authboss.a.AuthLogoutOKPath = "/dashboard"
ctx, w, r, sessionStorer := testRequest("GET")
sessionStorer.Put(authboss.SessionKey, "asdf")

View File

@ -17,11 +17,24 @@ import (
"golang.org/x/crypto/bcrypt"
)
// Authboss contains a configuration and other details for running.
type Authboss struct {
Config
}
// New makes a new instance of authboss with a default
// configuration.
func New() *Authboss {
ab := &Authboss{}
ab.Defaults()
return ab
}
// Init authboss and it's loaded modules.
func Init() error {
func (a *Authboss) Init() error {
for name, mod := range modules {
fmt.Fprintf(Cfg.LogWriter, "%-10s Initializing\n", "["+name+"]")
if err := mod.Initialize(); err != nil {
fmt.Fprintf(a.LogWriter, "%-10s Initializing\n", "["+name+"]")
if err := mod.Initialize(a); err != nil {
return fmt.Errorf("[%s] Error Initializing: %v", name, err)
}
}
@ -30,16 +43,16 @@ func Init() error {
}
// CurrentUser retrieves the current user from the session and the database.
func CurrentUser(w http.ResponseWriter, r *http.Request) (interface{}, error) {
ctx, err := ContextFromRequest(r)
func (a *Authboss) CurrentUser(w http.ResponseWriter, r *http.Request) (interface{}, error) {
ctx, err := a.ContextFromRequest(r)
if err != nil {
return nil, err
}
ctx.SessionStorer = clientStoreWrapper{Cfg.SessionStoreMaker(w, r)}
ctx.CookieStorer = clientStoreWrapper{Cfg.CookieStoreMaker(w, r)}
ctx.SessionStorer = clientStoreWrapper{a.SessionStoreMaker(w, r)}
ctx.CookieStorer = clientStoreWrapper{a.CookieStoreMaker(w, r)}
_, err = Cfg.Callbacks.FireBefore(EventGetUserSession, ctx)
_, err = a.Callbacks.FireBefore(EventGetUserSession, ctx)
if err != nil {
return nil, err
}
@ -54,22 +67,22 @@ func CurrentUser(w http.ResponseWriter, r *http.Request) (interface{}, error) {
return nil, err
}
_, err = Cfg.Callbacks.FireBefore(EventGet, ctx)
_, err = a.Callbacks.FireBefore(EventGet, ctx)
if err != nil {
return nil, err
}
if index := strings.IndexByte(key, ';'); index > 0 {
return Cfg.OAuth2Storer.GetOAuth(key[:index], key[index+1:])
return a.OAuth2Storer.GetOAuth(key[:index], key[index+1:])
}
return Cfg.Storer.Get(key)
return a.Storer.Get(key)
}
// CurrentUserP retrieves the current user but panics if it's not available for
// any reason.
func CurrentUserP(w http.ResponseWriter, r *http.Request) interface{} {
i, err := CurrentUser(w, r)
func (a *Authboss) CurrentUserP(w http.ResponseWriter, r *http.Request) interface{} {
i, err := a.CurrentUser(w, r)
if err != nil {
panic(err.Error())
}
@ -96,13 +109,13 @@ will be returned.
The error returned is returned either from the updater if that produced an error
or from the cleanup routines.
*/
func UpdatePassword(w http.ResponseWriter, r *http.Request,
func (a *Authboss) UpdatePassword(w http.ResponseWriter, r *http.Request,
ptPassword string, user interface{}, updater func() error) error {
updatePwd := len(ptPassword) > 0
if updatePwd {
pass, err := bcrypt.GenerateFromPassword([]byte(ptPassword), Cfg.BCryptCost)
pass, err := bcrypt.GenerateFromPassword([]byte(ptPassword), a.BCryptCost)
if err != nil {
return err
}
@ -131,11 +144,11 @@ func UpdatePassword(w http.ResponseWriter, r *http.Request,
return nil
}
ctx, err := ContextFromRequest(r)
ctx, err := a.ContextFromRequest(r)
if err != nil {
return err
}
ctx.SessionStorer = clientStoreWrapper{Cfg.SessionStoreMaker(w, r)}
ctx.CookieStorer = clientStoreWrapper{Cfg.CookieStoreMaker(w, r)}
return Cfg.Callbacks.FireAfter(EventPasswordReset, ctx)
ctx.SessionStorer = clientStoreWrapper{a.SessionStoreMaker(w, r)}
ctx.CookieStorer = clientStoreWrapper{a.CookieStoreMaker(w, r)}
return a.Callbacks.FireAfter(EventPasswordReset, ctx)
}

View File

@ -6,46 +6,41 @@ import (
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"testing"
)
func TestMain(main *testing.M) {
RegisterModule("testmodule", testMod)
Cfg.LogWriter = ioutil.Discard
Init()
code := main.Run()
os.Exit(code)
}
func TestAuthBossInit(t *testing.T) {
Cfg = NewConfig()
Cfg.LogWriter = ioutil.Discard
err := Init()
t.Parallel()
ab := New()
ab.LogWriter = ioutil.Discard
err := ab.Init()
if err != nil {
t.Error("Unexpected error:", err)
}
}
func TestAuthBossCurrentUser(t *testing.T) {
Cfg = NewConfig()
Cfg.LogWriter = ioutil.Discard
Cfg.Storer = mockStorer{"joe": Attributes{"email": "john@john.com", "password": "lies"}}
Cfg.SessionStoreMaker = func(_ http.ResponseWriter, _ *http.Request) ClientStorer {
t.Parallel()
ab := New()
ab.LogWriter = ioutil.Discard
ab.Storer = mockStorer{"joe": Attributes{"email": "john@john.com", "password": "lies"}}
ab.SessionStoreMaker = func(_ http.ResponseWriter, _ *http.Request) ClientStorer {
return mockClientStore{SessionKey: "joe"}
}
Cfg.CookieStoreMaker = func(_ http.ResponseWriter, _ *http.Request) ClientStorer {
ab.CookieStoreMaker = func(_ http.ResponseWriter, _ *http.Request) ClientStorer {
return mockClientStore{}
}
if err := Init(); err != nil {
if err := ab.Init(); err != nil {
t.Error("Unexpected error:", err)
}
rec := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "localhost", nil)
userStruct := CurrentUserP(rec, req)
userStruct := ab.CurrentUserP(rec, req)
us := userStruct.(*mockUser)
if us.Email != "john@john.com" || us.Password != "lies" {
@ -54,18 +49,20 @@ func TestAuthBossCurrentUser(t *testing.T) {
}
func TestAuthbossUpdatePassword(t *testing.T) {
Cfg = NewConfig()
t.Parallel()
ab := New()
session := mockClientStore{}
cookies := mockClientStore{}
Cfg.SessionStoreMaker = func(_ http.ResponseWriter, _ *http.Request) ClientStorer {
ab.SessionStoreMaker = func(_ http.ResponseWriter, _ *http.Request) ClientStorer {
return session
}
Cfg.CookieStoreMaker = func(_ http.ResponseWriter, _ *http.Request) ClientStorer {
ab.CookieStoreMaker = func(_ http.ResponseWriter, _ *http.Request) ClientStorer {
return cookies
}
called := false
Cfg.Callbacks.After(EventPasswordReset, func(ctx *Context) error {
ab.Callbacks.After(EventPasswordReset, func(ctx *Context) error {
called = true
return nil
})
@ -80,7 +77,7 @@ func TestAuthbossUpdatePassword(t *testing.T) {
r, _ := http.NewRequest("GET", "http://localhost", nil)
called = false
err := UpdatePassword(nil, r, "newpassword", &user1, func() error { return nil })
err := ab.UpdatePassword(nil, r, "newpassword", &user1, func() error { return nil })
if err != nil {
t.Error(err)
}
@ -93,7 +90,7 @@ func TestAuthbossUpdatePassword(t *testing.T) {
}
called = false
err = UpdatePassword(nil, r, "newpassword", &user2, func() error { return nil })
err = ab.UpdatePassword(nil, r, "newpassword", &user2, func() error { return nil })
if err != nil {
t.Error(err)
}
@ -107,7 +104,7 @@ func TestAuthbossUpdatePassword(t *testing.T) {
called = false
oldPassword := user1.Password
err = UpdatePassword(nil, r, "", &user1, func() error { return nil })
err = ab.UpdatePassword(nil, r, "", &user1, func() error { return nil })
if err != nil {
t.Error(err)
}
@ -121,12 +118,16 @@ func TestAuthbossUpdatePassword(t *testing.T) {
}
func TestAuthbossUpdatePasswordFail(t *testing.T) {
t.Parallel()
ab := New()
user1 := struct {
Password string
}{}
anErr := errors.New("AnError")
err := UpdatePassword(nil, nil, "update", &user1, func() error { return anErr })
err := ab.UpdatePassword(nil, nil, "update", &user1, func() error { return anErr })
if err != anErr {
t.Error("Expected an specific error:", err)
}

View File

@ -115,7 +115,7 @@ func (c *Callbacks) FireBefore(e Event, ctx *Context) (interrupt Interrupt, err
for _, fn := range callbacks {
interrupt, err = fn(ctx)
if err != nil {
fmt.Fprintf(Cfg.LogWriter, "Callback error (%s): %v\n", runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name(), err)
fmt.Fprintf(ctx.LogWriter, "Callback error (%s): %v\n", runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name(), err)
return InterruptNone, err
}
if interrupt != InterruptNone {
@ -132,7 +132,7 @@ func (c *Callbacks) FireAfter(e Event, ctx *Context) (err error) {
callbacks := c.after[e]
for _, fn := range callbacks {
if err = fn(ctx); err != nil {
fmt.Fprintf(Cfg.LogWriter, "Callback error (%s): %v\n", runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name(), err)
fmt.Fprintf(ctx.LogWriter, "Callback error (%s): %v\n", runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name(), err)
return err
}
}

View File

@ -8,15 +8,17 @@ import (
)
func TestCallbacks(t *testing.T) {
t.Parallel()
ab := New()
afterCalled := false
beforeCalled := false
c := NewCallbacks()
c.Before(EventRegister, func(ctx *Context) (Interrupt, error) {
ab.Callbacks.Before(EventRegister, func(ctx *Context) (Interrupt, error) {
beforeCalled = true
return InterruptNone, nil
})
c.After(EventRegister, func(ctx *Context) error {
ab.Callbacks.After(EventRegister, func(ctx *Context) error {
afterCalled = true
return nil
})
@ -25,7 +27,7 @@ func TestCallbacks(t *testing.T) {
t.Error("Neither should be called.")
}
interrupt, err := c.FireBefore(EventRegister, NewContext())
interrupt, err := ab.Callbacks.FireBefore(EventRegister, ab.NewContext())
if err != nil {
t.Error("Unexpected error:", err)
}
@ -40,27 +42,29 @@ func TestCallbacks(t *testing.T) {
t.Error("Expected after not to be called.")
}
c.FireAfter(EventRegister, NewContext())
ab.Callbacks.FireAfter(EventRegister, ab.NewContext())
if !afterCalled {
t.Error("Expected after to be called.")
}
}
func TestCallbacksInterrupt(t *testing.T) {
t.Parallel()
ab := New()
before1 := false
before2 := false
c := NewCallbacks()
c.Before(EventRegister, func(ctx *Context) (Interrupt, error) {
ab.Callbacks.Before(EventRegister, func(ctx *Context) (Interrupt, error) {
before1 = true
return InterruptAccountLocked, nil
})
c.Before(EventRegister, func(ctx *Context) (Interrupt, error) {
ab.Callbacks.Before(EventRegister, func(ctx *Context) (Interrupt, error) {
before2 = true
return InterruptNone, nil
})
interrupt, err := c.FireBefore(EventRegister, NewContext())
interrupt, err := ab.Callbacks.FireBefore(EventRegister, ab.NewContext())
if err != nil {
t.Error(err)
}
@ -77,26 +81,26 @@ func TestCallbacksInterrupt(t *testing.T) {
}
func TestCallbacksBeforeErrors(t *testing.T) {
t.Parallel()
ab := New()
log := &bytes.Buffer{}
Cfg = &Config{
LogWriter: log,
}
ab.LogWriter = log
before1 := false
before2 := false
c := NewCallbacks()
errValue := errors.New("Problem occured")
c.Before(EventRegister, func(ctx *Context) (Interrupt, error) {
ab.Callbacks.Before(EventRegister, func(ctx *Context) (Interrupt, error) {
before1 = true
return InterruptNone, errValue
})
c.Before(EventRegister, func(ctx *Context) (Interrupt, error) {
ab.Callbacks.Before(EventRegister, func(ctx *Context) (Interrupt, error) {
before2 = true
return InterruptNone, nil
})
interrupt, err := c.FireBefore(EventRegister, NewContext())
interrupt, err := ab.Callbacks.FireBefore(EventRegister, ab.NewContext())
if err != errValue {
t.Error("Expected an error to come back.")
}
@ -117,26 +121,26 @@ func TestCallbacksBeforeErrors(t *testing.T) {
}
func TestCallbacksAfterErrors(t *testing.T) {
t.Parallel()
log := &bytes.Buffer{}
Cfg = &Config{
LogWriter: log,
}
ab := New()
ab.LogWriter = log
after1 := false
after2 := false
c := NewCallbacks()
errValue := errors.New("Problem occured")
c.After(EventRegister, func(ctx *Context) error {
ab.Callbacks.After(EventRegister, func(ctx *Context) error {
after1 = true
return errValue
})
c.After(EventRegister, func(ctx *Context) error {
ab.Callbacks.After(EventRegister, func(ctx *Context) error {
after2 = true
return nil
})
err := c.FireAfter(EventRegister, NewContext())
err := ab.Callbacks.FireAfter(EventRegister, ab.NewContext())
if err != errValue {
t.Error("Expected an error to come back.")
}
@ -154,6 +158,8 @@ func TestCallbacksAfterErrors(t *testing.T) {
}
func TestEventString(t *testing.T) {
t.Parallel()
tests := []struct {
ev Event
str string
@ -178,6 +184,8 @@ func TestEventString(t *testing.T) {
}
func TestInterruptString(t *testing.T) {
t.Parallel()
tests := []struct {
in Interrupt
str string

View File

@ -64,8 +64,8 @@ type CookieStoreMaker func(http.ResponseWriter, *http.Request) ClientStorer
type SessionStoreMaker func(http.ResponseWriter, *http.Request) ClientStorer
// FlashSuccess returns FlashSuccessKey from the session and removes it.
func FlashSuccess(w http.ResponseWriter, r *http.Request) string {
storer := Cfg.SessionStoreMaker(w, r)
func (a *Authboss) FlashSuccess(w http.ResponseWriter, r *http.Request) string {
storer := a.SessionStoreMaker(w, r)
msg, ok := storer.Get(FlashSuccessKey)
if ok {
storer.Del(FlashSuccessKey)
@ -75,8 +75,8 @@ func FlashSuccess(w http.ResponseWriter, r *http.Request) string {
}
// FlashError returns FlashError from the session and removes it.
func FlashError(w http.ResponseWriter, r *http.Request) string {
storer := Cfg.SessionStoreMaker(w, r)
func (a *Authboss) FlashError(w http.ResponseWriter, r *http.Request) string {
storer := a.SessionStoreMaker(w, r)
msg, ok := storer.Get(FlashErrorKey)
if ok {
storer.Del(FlashErrorKey)

View File

@ -14,6 +14,8 @@ func (t testClientStorerErr) Get(key string) (string, bool) {
func (t testClientStorerErr) Del(key string) {}
func TestClientStorerErr(t *testing.T) {
t.Parallel()
var cs testClientStorerErr
csw := clientStoreWrapper{&cs}
@ -30,19 +32,22 @@ func TestClientStorerErr(t *testing.T) {
}
func TestFlashClearer(t *testing.T) {
t.Parallel()
session := mockClientStore{FlashSuccessKey: "success", FlashErrorKey: "error"}
Cfg.SessionStoreMaker = func(w http.ResponseWriter, r *http.Request) ClientStorer {
ab := New()
ab.SessionStoreMaker = func(w http.ResponseWriter, r *http.Request) ClientStorer {
return session
}
if msg := FlashSuccess(nil, nil); msg != "success" {
if msg := ab.FlashSuccess(nil, nil); msg != "success" {
t.Error("Unexpected flash success:", msg)
}
if msg, ok := session.Get(FlashSuccessKey); ok {
t.Error("Unexpected success flash:", msg)
}
if msg := FlashError(nil, nil); msg != "error" {
if msg := ab.FlashError(nil, nil); msg != "error" {
t.Error("Unexpected flash error:", msg)
}
if msg, ok := session.Get(FlashErrorKey); ok {

View File

@ -10,9 +10,6 @@ import (
"golang.org/x/crypto/bcrypt"
)
// Cfg is the singleton instance of Config
var Cfg = NewConfig()
// Config holds all the configuration for both authboss and it's modules.
type Config struct {
// MountPath is the path to mount authboss's routes at (eg /auth).
@ -117,61 +114,55 @@ type Config struct {
Mailer Mailer
}
// NewConfig creates a config full of healthy default values.
// Notable exceptions to default values are the Storers.
// This method is called automatically on startup and is set to authboss.Cfg
// so implementers need not call it. Primarily exported for testing.
func NewConfig() *Config {
return &Config{
MountPath: "/",
ViewsPath: "./",
RootURL: "http://localhost:8080",
BCryptCost: bcrypt.DefaultCost,
// Defaults sets the configuration's default values.
func (c *Config) Defaults() {
c.MountPath = "/"
c.ViewsPath = "./"
c.RootURL = "http://localhost:8080"
c.BCryptCost = bcrypt.DefaultCost
PrimaryID: StoreEmail,
c.PrimaryID = StoreEmail
Layout: template.Must(template.New("").Parse(`<!DOCTYPE html><html><body>{{template "authboss" .}}</body></html>`)),
LayoutHTMLEmail: template.Must(template.New("").Parse(`<!DOCTYPE html><html><body>{{template "authboss" .}}</body></html>`)),
LayoutTextEmail: template.Must(template.New("").Parse(`{{template "authboss" .}}`)),
c.Layout = template.Must(template.New("").Parse(`<!DOCTYPE html><html><body>{{template "authboss" .}}</body></html>`))
c.LayoutHTMLEmail = template.Must(template.New("").Parse(`<!DOCTYPE html><html><body>{{template "authboss" .}}</body></html>`))
c.LayoutTextEmail = template.Must(template.New("").Parse(`{{template "authboss" .}}`))
AuthLoginOKPath: "/",
AuthLoginFailPath: "/",
AuthLogoutOKPath: "/",
c.AuthLoginOKPath = "/"
c.AuthLoginFailPath = "/"
c.AuthLogoutOKPath = "/"
RecoverOKPath: "/",
RecoverTokenDuration: time.Duration(24) * time.Hour,
c.RecoverOKPath = "/"
c.RecoverTokenDuration = time.Duration(24) * time.Hour
RegisterOKPath: "/",
c.RegisterOKPath = "/"
Policies: []Validator{
Rules{
FieldName: "username",
Required: true,
MinLength: 2,
MaxLength: 4,
AllowWhitespace: false,
},
Rules{
FieldName: "password",
Required: true,
MinLength: 4,
MaxLength: 8,
AllowWhitespace: false,
},
c.Policies = []Validator{
Rules{
FieldName: "username",
Required: true,
MinLength: 2,
MaxLength: 4,
AllowWhitespace: false,
},
ConfirmFields: []string{
StorePassword, ConfirmPrefix + StorePassword,
Rules{
FieldName: "password",
Required: true,
MinLength: 4,
MaxLength: 8,
AllowWhitespace: false,
},
ExpireAfter: 60 * time.Minute,
LockAfter: 3,
LockWindow: 5 * time.Minute,
LockDuration: 5 * time.Hour,
LogWriter: NewDefaultLogger(),
Callbacks: NewCallbacks(),
Mailer: LogMailer(ioutil.Discard),
}
c.ConfirmFields = []string{
StorePassword, ConfirmPrefix + StorePassword,
}
c.ExpireAfter = 60 * time.Minute
c.LockAfter = 3
c.LockWindow = 5 * time.Minute
c.LockDuration = 5 * time.Hour
c.LogWriter = NewDefaultLogger()
c.Callbacks = NewCallbacks()
c.Mailer = LogMailer(ioutil.Discard)
}

View File

@ -53,23 +53,23 @@ type Confirm struct {
// Initialize the module
func (c *Confirm) Initialize() (err error) {
var ok bool
storer, ok := authboss.Cfg.Storer.(ConfirmStorer)
storer, ok := authboss.a.Storer.(ConfirmStorer)
if storer == nil || !ok {
return errors.New("confirm: Need a ConfirmStorer")
}
c.emailHTMLTemplates, err = response.LoadTemplates(authboss.Cfg.LayoutHTMLEmail, authboss.Cfg.ViewsPath, tplConfirmHTML)
c.emailHTMLTemplates, err = response.LoadTemplates(authboss.a.LayoutHTMLEmail, authboss.a.ViewsPath, tplConfirmHTML)
if err != nil {
return err
}
c.emailTextTemplates, err = response.LoadTemplates(authboss.Cfg.LayoutTextEmail, authboss.Cfg.ViewsPath, tplConfirmText)
c.emailTextTemplates, err = response.LoadTemplates(authboss.a.LayoutTextEmail, authboss.a.ViewsPath, tplConfirmText)
if err != nil {
return err
}
authboss.Cfg.Callbacks.Before(authboss.EventGet, c.beforeGet)
authboss.Cfg.Callbacks.Before(authboss.EventAuth, c.beforeGet)
authboss.Cfg.Callbacks.After(authboss.EventRegister, c.afterRegister)
authboss.a.Callbacks.Before(authboss.EventGet, c.beforeGet)
authboss.a.Callbacks.Before(authboss.EventAuth, c.beforeGet)
authboss.a.Callbacks.After(authboss.EventRegister, c.afterRegister)
return nil
}
@ -84,10 +84,10 @@ func (c *Confirm) Routes() authboss.RouteTable {
// Storage requirements
func (c *Confirm) Storage() authboss.StorageOptions {
return authboss.StorageOptions{
authboss.Cfg.PrimaryID: authboss.String,
authboss.StoreEmail: authboss.String,
StoreConfirmToken: authboss.String,
StoreConfirmed: authboss.Bool,
authboss.a.PrimaryID: authboss.String,
authboss.StoreEmail: authboss.String,
StoreConfirmToken: authboss.String,
StoreConfirmed: authboss.Bool,
}
}
@ -135,18 +135,18 @@ var goConfirmEmail = func(c *Confirm, to, token string) {
// confirmEmail sends a confirmation e-mail.
func (c *Confirm) confirmEmail(to, token string) {
p := path.Join(authboss.Cfg.MountPath, "confirm")
url := fmt.Sprintf("%s%s?%s=%s", authboss.Cfg.RootURL, p, url.QueryEscape(FormValueConfirm), url.QueryEscape(token))
p := path.Join(authboss.a.MountPath, "confirm")
url := fmt.Sprintf("%s%s?%s=%s", authboss.a.RootURL, p, url.QueryEscape(FormValueConfirm), url.QueryEscape(token))
email := authboss.Email{
To: []string{to},
From: authboss.Cfg.EmailFrom,
Subject: authboss.Cfg.EmailSubjectPrefix + "Confirm New Account",
From: authboss.a.EmailFrom,
Subject: authboss.a.EmailSubjectPrefix + "Confirm New Account",
}
err := response.Email(email, c.emailHTMLTemplates, tplConfirmHTML, c.emailTextTemplates, tplConfirmText, url)
if err != nil {
fmt.Fprintf(authboss.Cfg.LogWriter, "confirm: Failed to send e-mail: %v", err)
fmt.Fprintf(authboss.a.LogWriter, "confirm: Failed to send e-mail: %v", err)
}
}
@ -166,7 +166,7 @@ func (c *Confirm) confirmHandler(ctx *authboss.Context, w http.ResponseWriter, r
sum := md5.Sum(toHash)
dbTok := base64.StdEncoding.EncodeToString(sum[:])
user, err := authboss.Cfg.Storer.(ConfirmStorer).ConfirmUser(dbTok)
user, err := authboss.a.Storer.(ConfirmStorer).ConfirmUser(dbTok)
if err == authboss.ErrUserNotFound {
return authboss.ErrAndRedirect{Location: "/", Err: errors.New("confirm: token not found")}
} else if err != nil {
@ -178,7 +178,7 @@ func (c *Confirm) confirmHandler(ctx *authboss.Context, w http.ResponseWriter, r
ctx.User[StoreConfirmToken] = ""
ctx.User[StoreConfirmed] = true
key, err := ctx.User.StringErr(authboss.Cfg.PrimaryID)
key, err := ctx.User.StringErr(authboss.a.PrimaryID)
if err != nil {
return err
}
@ -188,7 +188,7 @@ func (c *Confirm) confirmHandler(ctx *authboss.Context, w http.ResponseWriter, r
}
ctx.SessionStorer.Put(authboss.SessionKey, key)
response.Redirect(ctx, w, r, authboss.Cfg.RegisterOKPath, "You have successfully confirmed your account.", "", true)
response.Redirect(ctx, w, r, authboss.a.RegisterOKPath, "You have successfully confirmed your account.", "", true)
return nil
}

View File

@ -18,9 +18,9 @@ import (
func setup() *Confirm {
authboss.NewConfig()
authboss.Cfg.Storer = mocks.NewMockStorer()
authboss.Cfg.LayoutHTMLEmail = template.Must(template.New("").Parse(`email ^_^`))
authboss.Cfg.LayoutTextEmail = template.Must(template.New("").Parse(`email`))
authboss.a.Storer = mocks.NewMockStorer()
authboss.a.LayoutHTMLEmail = template.Must(template.New("").Parse(`email ^_^`))
authboss.a.LayoutTextEmail = template.Must(template.New("").Parse(`email`))
c := &Confirm{}
if err := c.Initialize(); err != nil {
@ -100,9 +100,9 @@ func TestConfirm_AfterRegister(t *testing.T) {
c := setup()
ctx := authboss.NewContext()
log := &bytes.Buffer{}
authboss.Cfg.LogWriter = log
authboss.Cfg.Mailer = authboss.LogMailer(log)
authboss.Cfg.PrimaryID = authboss.StoreUsername
authboss.a.LogWriter = log
authboss.a.Mailer = authboss.LogMailer(log)
authboss.a.PrimaryID = authboss.StoreUsername
sentEmail := false
@ -115,7 +115,7 @@ func TestConfirm_AfterRegister(t *testing.T) {
t.Error("Expected it to die with user error:", err)
}
ctx.User = authboss.Attributes{authboss.Cfg.PrimaryID: "username"}
ctx.User = authboss.Attributes{authboss.a.PrimaryID: "username"}
if err := c.afterRegister(ctx); err == nil || err.(authboss.AttributeErr).Name != "email" {
t.Error("Expected it to die with e-mail address error:", err)
}
@ -135,8 +135,8 @@ func TestConfirm_AfterRegister(t *testing.T) {
func TestConfirm_ConfirmHandlerErrors(t *testing.T) {
c := setup()
log := &bytes.Buffer{}
authboss.Cfg.LogWriter = log
authboss.Cfg.Mailer = authboss.LogMailer(log)
authboss.a.LogWriter = log
authboss.a.Mailer = authboss.LogMailer(log)
tests := []struct {
URL string
@ -177,8 +177,8 @@ func TestConfirm_Confirm(t *testing.T) {
c := setup()
ctx := authboss.NewContext()
log := &bytes.Buffer{}
authboss.Cfg.LogWriter = log
authboss.Cfg.Mailer = authboss.LogMailer(log)
authboss.a.LogWriter = log
authboss.a.Mailer = authboss.LogMailer(log)
// Create a token
token := []byte("hi")
@ -186,7 +186,7 @@ func TestConfirm_Confirm(t *testing.T) {
// Create the "database"
storer := mocks.NewMockStorer()
authboss.Cfg.Storer = storer
authboss.a.Storer = storer
user := authboss.Attributes{
authboss.StoreUsername: "usern",
StoreConfirmToken: base64.StdEncoding.EncodeToString(sum[:]),

View File

@ -19,6 +19,8 @@ var (
// need for context is a request's session store. It is not safe for use by
// multiple goroutines.
type Context struct {
*Authboss
SessionStorer ClientStorerErr
CookieStorer ClientStorerErr
User Attributes
@ -28,17 +30,19 @@ type Context struct {
}
// NewContext is exported for testing modules.
func NewContext() *Context {
return &Context{}
func (a *Authboss) NewContext() *Context {
return &Context{
Authboss: a,
}
}
// ContextFromRequest creates a context from an http request.
func ContextFromRequest(r *http.Request) (*Context, error) {
func (a *Authboss) ContextFromRequest(r *http.Request) (*Context, error) {
if err := r.ParseForm(); err != nil {
return nil, err
}
c := NewContext()
c := a.NewContext()
c.formValues = map[string][]string(r.Form)
c.postFormValues = map[string][]string(r.PostForm)
return c, nil
@ -111,9 +115,9 @@ func (c *Context) LoadUser(key string) error {
var err error
if index := strings.IndexByte(key, ';'); index > 0 {
user, err = Cfg.OAuth2Storer.GetOAuth(key[:index], key[index+1:])
user, err = c.OAuth2Storer.GetOAuth(key[:index], key[index+1:])
} else {
user, err = Cfg.Storer.Get(key)
user, err = c.Storer.Get(key)
}
if err != nil {
return err
@ -144,12 +148,12 @@ func (c *Context) SaveUser() error {
return errors.New("User not initialized.")
}
key, ok := c.User.String(Cfg.PrimaryID)
key, ok := c.User.String(c.PrimaryID)
if !ok {
return errors.New("User improperly initialized, primary ID missing")
}
return Cfg.Storer.Put(key, c.User)
return c.Storer.Put(key, c.User)
}
// Attributes converts the post form values into an attributes map.

View File

@ -8,13 +8,17 @@ import (
)
func TestContext_Request(t *testing.T) {
t.Parallel()
ab := New()
req, err := http.NewRequest("POST", "http://localhost?query=string", bytes.NewBufferString("post=form"))
if err != nil {
t.Error("Unexpected Error:", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
ctx, err := ContextFromRequest(req)
ctx, err := ab.ContextFromRequest(req)
if err != nil {
t.Error("Unexpected Error:", err)
}
@ -69,10 +73,12 @@ func TestContext_Request(t *testing.T) {
}
func TestContext_SaveUser(t *testing.T) {
Cfg = NewConfig()
ctx := NewContext()
t.Parallel()
ab := New()
ctx := ab.NewContext()
storer := mockStorer{}
Cfg.Storer = storer
ab.Storer = storer
ctx.User = Attributes{StoreUsername: "joe", StoreEmail: "hello@joe.com", StorePassword: "mysticalhash"}
err := ctx.SaveUser()
@ -93,8 +99,10 @@ func TestContext_SaveUser(t *testing.T) {
}
func TestContext_LoadUser(t *testing.T) {
Cfg = NewConfig()
ctx := NewContext()
t.Parallel()
ab := New()
ctx := ab.NewContext()
attr := Attributes{
"email": "hello@joe.com",
@ -107,8 +115,8 @@ func TestContext_LoadUser(t *testing.T) {
"joe": attr,
"whatgoogle": attr,
}
Cfg.Storer = storer
Cfg.OAuth2Storer = storer
ab.Storer = storer
ab.OAuth2Storer = storer
ctx.User = nil
if err := ctx.LoadUser("joe"); err != nil {
@ -144,12 +152,14 @@ func TestContext_LoadUser(t *testing.T) {
}
func TestContext_LoadSessionUser(t *testing.T) {
Cfg = NewConfig()
ctx := NewContext()
t.Parallel()
ab := New()
ctx := ab.NewContext()
storer := mockStorer{
"joe": Attributes{"email": "hello@joe.com", "password": "mysticalhash"},
}
Cfg.Storer = storer
ab.Storer = storer
ctx.SessionStorer = mockClientStore{
SessionKey: "joe",
}
@ -169,9 +179,12 @@ func TestContext_LoadSessionUser(t *testing.T) {
}
func TestContext_Attributes(t *testing.T) {
t.Parallel()
now := time.Now().UTC()
ctx := NewContext()
ab := New()
ctx := ab.NewContext()
ctx.postFormValues = map[string][]string{
"a": []string{"a", "1"},
"b_int": []string{"5", "hello"},

View File

@ -6,6 +6,8 @@ import (
)
func TestAttributeErr(t *testing.T) {
t.Parallel()
estr := "Failed to retrieve database attribute, type was wrong: lol (want: String, got: int)"
if str := NewAttributeErr("lol", String, 5).Error(); str != estr {
t.Error("Error was wrong:", str)
@ -19,6 +21,8 @@ func TestAttributeErr(t *testing.T) {
}
func TestClientDataErr(t *testing.T) {
t.Parallel()
estr := "Failed to retrieve client attribute: lol"
err := ClientDataErr{"lol"}
if str := err.Error(); str != estr {
@ -27,6 +31,8 @@ func TestClientDataErr(t *testing.T) {
}
func TestErrAndRedirect(t *testing.T) {
t.Parallel()
estr := "Error: cause, Redirecting to: /"
err := ErrAndRedirect{errors.New("cause"), "/", "success", "failure"}
if str := err.Error(); str != estr {
@ -35,6 +41,8 @@ func TestErrAndRedirect(t *testing.T) {
}
func TestRenderErr(t *testing.T) {
t.Parallel()
estr := `Error rendering template "lol": cause, data: authboss.HTMLData{"a":5}`
err := RenderErr{"lol", NewHTMLData("a", 5), errors.New("cause")}
if str := err.Error(); str != estr {

View File

@ -8,14 +8,14 @@ import (
var nowTime = time.Now
// TimeToExpiry returns zero if the user session is expired else the time until expiry.
func TimeToExpiry(w http.ResponseWriter, r *http.Request) time.Duration {
return timeToExpiry(Cfg.SessionStoreMaker(w, r))
func (a *Authboss) TimeToExpiry(w http.ResponseWriter, r *http.Request) time.Duration {
return a.timeToExpiry(a.SessionStoreMaker(w, r))
}
func timeToExpiry(session ClientStorer) time.Duration {
func (a *Authboss) timeToExpiry(session ClientStorer) time.Duration {
dateStr, ok := session.Get(SessionLastAction)
if !ok {
return Cfg.ExpireAfter
return a.ExpireAfter
}
date, err := time.Parse(time.RFC3339, dateStr)
@ -23,7 +23,7 @@ func timeToExpiry(session ClientStorer) time.Duration {
panic("last_action is not a valid RFC3339 date")
}
remaining := date.Add(Cfg.ExpireAfter).Sub(nowTime().UTC())
remaining := date.Add(a.ExpireAfter).Sub(nowTime().UTC())
if remaining > 0 {
return remaining
}
@ -32,35 +32,36 @@ func timeToExpiry(session ClientStorer) time.Duration {
}
// RefreshExpiry updates the last action for the user, so he doesn't become expired.
func RefreshExpiry(w http.ResponseWriter, r *http.Request) {
session := Cfg.SessionStoreMaker(w, r)
refreshExpiry(session)
func (a *Authboss) RefreshExpiry(w http.ResponseWriter, r *http.Request) {
session := a.SessionStoreMaker(w, r)
a.refreshExpiry(session)
}
func refreshExpiry(session ClientStorer) {
func (a *Authboss) refreshExpiry(session ClientStorer) {
session.Put(SessionLastAction, nowTime().UTC().Format(time.RFC3339))
}
type expireMiddleware struct {
ab *Authboss
next http.Handler
}
// ExpireMiddleware ensures that the user's expiry information is kept up-to-date
// on each request. Deletes the SessionKey from the session if the user is
// expired (Cfg.ExpireAfter duration since SessionLastAction).
func ExpireMiddleware(next http.Handler) http.Handler {
return expireMiddleware{next}
// expired (a.ExpireAfter duration since SessionLastAction).
func (a *Authboss) ExpireMiddleware(next http.Handler) http.Handler {
return expireMiddleware{a, next}
}
func (m expireMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
session := Cfg.SessionStoreMaker(w, r)
session := m.ab.SessionStoreMaker(w, r)
if _, ok := session.Get(SessionKey); ok {
ttl := timeToExpiry(session)
ttl := m.ab.timeToExpiry(session)
if ttl == 0 {
session.Del(SessionKey)
session.Del(SessionLastAction)
} else {
refreshExpiry(session)
m.ab.refreshExpiry(session)
}
}

View File

@ -7,15 +7,17 @@ import (
"time"
)
// These tests use the global variable nowTime so cannot be parallelized
func TestDudeIsExpired(t *testing.T) {
Cfg = NewConfig()
ab := New()
session := mockClientStore{SessionKey: "username"}
refreshExpiry(session)
ab.refreshExpiry(session)
nowTime = func() time.Time {
return time.Now().UTC().Add(Cfg.ExpireAfter * 2)
return time.Now().UTC().Add(ab.ExpireAfter * 2)
}
Cfg.SessionStoreMaker = func(_ http.ResponseWriter, _ *http.Request) ClientStorer {
ab.SessionStoreMaker = func(_ http.ResponseWriter, _ *http.Request) ClientStorer {
return session
}
@ -23,7 +25,7 @@ func TestDudeIsExpired(t *testing.T) {
w := httptest.NewRecorder()
called := false
m := ExpireMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
m := ab.ExpireMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called = true
}))
@ -43,14 +45,14 @@ func TestDudeIsExpired(t *testing.T) {
}
func TestDudeIsNotExpired(t *testing.T) {
Cfg = NewConfig()
ab := New()
session := mockClientStore{SessionKey: "username"}
refreshExpiry(session)
ab.refreshExpiry(session)
nowTime = func() time.Time {
return time.Now().UTC().Add(Cfg.ExpireAfter / 2)
return time.Now().UTC().Add(ab.ExpireAfter / 2)
}
Cfg.SessionStoreMaker = func(_ http.ResponseWriter, _ *http.Request) ClientStorer {
ab.SessionStoreMaker = func(_ http.ResponseWriter, _ *http.Request) ClientStorer {
return session
}
@ -58,7 +60,7 @@ func TestDudeIsNotExpired(t *testing.T) {
w := httptest.NewRecorder()
called := false
m := ExpireMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
m := ab.ExpireMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called = true
}))

View File

@ -25,10 +25,10 @@ var (
funcMap = template.FuncMap{
"title": strings.Title,
"mountpathed": func(location string) string {
if authboss.Cfg.MountPath == "/" {
if authboss.a.MountPath == "/" {
return location
}
return path.Join(authboss.Cfg.MountPath, location)
return path.Join(authboss.a.MountPath, location)
},
}
)
@ -77,10 +77,10 @@ func (t Templates) Render(ctx *authboss.Context, w http.ResponseWriter, r *http.
return authboss.RenderErr{tpl.Name(), data, ErrTemplateNotFound}
}
data.MergeKV("xsrfName", template.HTML(authboss.Cfg.XSRFName), "xsrfToken", template.HTML(authboss.Cfg.XSRFMaker(w, r)))
data.MergeKV("xsrfName", template.HTML(authboss.a.XSRFName), "xsrfToken", template.HTML(authboss.a.XSRFMaker(w, r)))
if authboss.Cfg.LayoutDataMaker != nil {
data.Merge(authboss.Cfg.LayoutDataMaker(w, r))
if authboss.a.LayoutDataMaker != nil {
data.Merge(authboss.a.LayoutDataMaker(w, r))
}
if flash, ok := ctx.SessionStorer.Get(authboss.FlashSuccessKey); ok {
@ -130,7 +130,7 @@ func Email(email authboss.Email, htmlTpls Templates, nameHTML string, textTpls T
}
email.TextBody = plainBuffer.String()
if err := authboss.Cfg.Mailer.Send(email); err != nil {
if err := authboss.a.Mailer.Send(email); err != nil {
return err
}

View File

@ -86,7 +86,7 @@ func TestTemplates_Render(t *testing.T) {
func Test_Email(t *testing.T) {
mockMailer := &mocks.MockMailer{}
authboss.Cfg.Mailer = mockMailer
authboss.a.Mailer = mockMailer
htmlTpls := Templates{"html": testEmailHTMLTempalte}
textTpls := Templates{"plain": testEmailPlainTempalte}

View File

@ -29,15 +29,15 @@ type Lock struct {
// Initialize the module
func (l *Lock) Initialize() error {
if authboss.Cfg.Storer == nil {
if authboss.a.Storer == nil {
return errors.New("lock: Need a Storer")
}
// Events
authboss.Cfg.Callbacks.Before(authboss.EventGet, l.beforeAuth)
authboss.Cfg.Callbacks.Before(authboss.EventAuth, l.beforeAuth)
authboss.Cfg.Callbacks.After(authboss.EventAuth, l.afterAuth)
authboss.Cfg.Callbacks.After(authboss.EventAuthFail, l.afterAuthFail)
authboss.a.Callbacks.Before(authboss.EventGet, l.beforeAuth)
authboss.a.Callbacks.Before(authboss.EventAuth, l.beforeAuth)
authboss.a.Callbacks.After(authboss.EventAuth, l.afterAuth)
authboss.a.Callbacks.After(authboss.EventAuthFail, l.afterAuthFail)
return nil
}
@ -50,10 +50,10 @@ func (l *Lock) Routes() authboss.RouteTable {
// Storage requirements
func (l *Lock) Storage() authboss.StorageOptions {
return authboss.StorageOptions{
authboss.Cfg.PrimaryID: authboss.String,
StoreAttemptNumber: authboss.Integer,
StoreAttemptTime: authboss.DateTime,
StoreLocked: authboss.DateTime,
authboss.a.PrimaryID: authboss.String,
StoreAttemptNumber: authboss.Integer,
StoreAttemptTime: authboss.DateTime,
StoreLocked: authboss.DateTime,
}
}
@ -104,9 +104,9 @@ func (l *Lock) afterAuthFail(ctx *authboss.Context) error {
nAttempts++
if time.Now().UTC().Sub(lastAttempt) <= authboss.Cfg.LockWindow {
if nAttempts >= int64(authboss.Cfg.LockAfter) {
ctx.User[StoreLocked] = time.Now().UTC().Add(authboss.Cfg.LockDuration)
if time.Now().UTC().Sub(lastAttempt) <= authboss.a.LockWindow {
if nAttempts >= int64(authboss.a.LockAfter) {
ctx.User[StoreLocked] = time.Now().UTC().Add(authboss.a.LockDuration)
}
ctx.User[StoreAttemptNumber] = nAttempts
@ -124,7 +124,7 @@ func (l *Lock) afterAuthFail(ctx *authboss.Context) error {
// Lock a user manually.
func (l *Lock) Lock(key string) error {
user, err := authboss.Cfg.Storer.Get(key)
user, err := authboss.a.Storer.Get(key)
if err != nil {
return err
}
@ -134,14 +134,14 @@ func (l *Lock) Lock(key string) error {
return err
}
attr[StoreLocked] = time.Now().UTC().Add(authboss.Cfg.LockDuration)
attr[StoreLocked] = time.Now().UTC().Add(authboss.a.LockDuration)
return authboss.Cfg.Storer.Put(key, attr)
return authboss.a.Storer.Put(key, attr)
}
// Unlock a user that was locked by this module.
func (l *Lock) Unlock(key string) error {
user, err := authboss.Cfg.Storer.Get(key)
user, err := authboss.a.Storer.Get(key)
if err != nil {
return err
}
@ -153,9 +153,9 @@ func (l *Lock) Unlock(key string) error {
// Set the last attempt to be -window*2 to avoid immediately
// giving another login failure.
attr[StoreAttemptTime] = time.Now().UTC().Add(-authboss.Cfg.LockWindow * 2)
attr[StoreAttemptTime] = time.Now().UTC().Add(-authboss.a.LockWindow * 2)
attr[StoreAttemptNumber] = int64(0)
attr[StoreLocked] = time.Now().UTC().Add(-authboss.Cfg.LockDuration)
attr[StoreLocked] = time.Now().UTC().Add(-authboss.a.LockDuration)
return authboss.Cfg.Storer.Put(key, attr)
return authboss.a.Storer.Put(key, attr)
}

View File

@ -53,8 +53,8 @@ func TestAfterAuth(t *testing.T) {
}
storer := mocks.NewMockStorer()
authboss.Cfg.Storer = storer
ctx.User = authboss.Attributes{authboss.Cfg.PrimaryID: "john@john.com"}
authboss.a.Storer = storer
ctx.User = authboss.Attributes{authboss.a.PrimaryID: "john@john.com"}
if err := lock.afterAuth(ctx); err != nil {
t.Error(err)
@ -74,15 +74,15 @@ func TestAfterAuthFail_Lock(t *testing.T) {
ctx := authboss.NewContext()
storer := mocks.NewMockStorer()
authboss.Cfg.Storer = storer
authboss.a.Storer = storer
lock := Lock{}
authboss.Cfg.LockWindow = 30 * time.Minute
authboss.Cfg.LockDuration = 30 * time.Minute
authboss.Cfg.LockAfter = 3
authboss.a.LockWindow = 30 * time.Minute
authboss.a.LockDuration = 30 * time.Minute
authboss.a.LockAfter = 3
email := "john@john.com"
ctx.User = map[string]interface{}{authboss.Cfg.PrimaryID: email}
ctx.User = map[string]interface{}{authboss.a.PrimaryID: email}
old = time.Now().UTC().Add(-1 * time.Hour)
@ -123,17 +123,17 @@ func TestAfterAuthFail_Reset(t *testing.T) {
ctx := authboss.NewContext()
storer := mocks.NewMockStorer()
lock := Lock{}
authboss.Cfg.LockWindow = 30 * time.Minute
authboss.Cfg.Storer = storer
authboss.a.LockWindow = 30 * time.Minute
authboss.a.Storer = storer
old = time.Now().UTC().Add(-time.Hour)
email := "john@john.com"
ctx.User = map[string]interface{}{
authboss.Cfg.PrimaryID: email,
StoreAttemptNumber: int64(2),
StoreAttemptTime: old,
StoreLocked: old,
authboss.a.PrimaryID: email,
StoreAttemptNumber: int64(2),
StoreAttemptTime: old,
StoreLocked: old,
}
lock.afterAuthFail(ctx)
@ -162,13 +162,13 @@ func TestAfterAuthFail_Errors(t *testing.T) {
func TestLock(t *testing.T) {
authboss.NewConfig()
storer := mocks.NewMockStorer()
authboss.Cfg.Storer = storer
authboss.a.Storer = storer
lock := Lock{}
email := "john@john.com"
storer.Users[email] = map[string]interface{}{
authboss.Cfg.PrimaryID: email,
"password": "password",
authboss.a.PrimaryID: email,
"password": "password",
}
err := lock.Lock(email)
@ -184,15 +184,15 @@ func TestLock(t *testing.T) {
func TestUnlock(t *testing.T) {
authboss.NewConfig()
storer := mocks.NewMockStorer()
authboss.Cfg.Storer = storer
authboss.a.Storer = storer
lock := Lock{}
authboss.Cfg.LockWindow = 1 * time.Hour
authboss.a.LockWindow = 1 * time.Hour
email := "john@john.com"
storer.Users[email] = map[string]interface{}{
authboss.Cfg.PrimaryID: email,
"password": "password",
"locked": true,
authboss.a.PrimaryID: email,
"password": "password",
"locked": true,
}
err := lock.Unlock(email)
@ -201,7 +201,7 @@ func TestUnlock(t *testing.T) {
}
attemptTime := storer.Users[email][StoreAttemptTime].(time.Time)
if attemptTime.After(time.Now().UTC().Add(-authboss.Cfg.LockWindow)) {
if attemptTime.After(time.Now().UTC().Add(-authboss.a.LockWindow)) {
t.Error("StoreLocked not set correctly:", attemptTime)
}
if number := storer.Users[email][StoreAttemptNumber].(int64); number != int64(0) {

View File

@ -9,6 +9,8 @@ import (
)
func TestDefaultLogger(t *testing.T) {
t.Parallel()
logger := NewDefaultLogger()
if logger == nil {
t.Error("Logger was not created.")
@ -16,6 +18,8 @@ func TestDefaultLogger(t *testing.T) {
}
func TestDefaultLoggerOutput(t *testing.T) {
t.Parallel()
buffer := &bytes.Buffer{}
logger := (*DefaultLogger)(log.New(buffer, "", log.LstdFlags))
io.WriteString(logger, "hello world")

View File

@ -10,8 +10,8 @@ import (
)
// SendMail uses the currently configured mailer to deliver e-mails.
func SendMail(data Email) error {
return Cfg.Mailer.Send(data)
func (a *Authboss) SendMail(data Email) error {
return a.Mailer.Send(data)
}
// Mailer is a type that is capable of sending an e-mail.

View File

@ -8,15 +8,16 @@ import (
)
func TestMailer(t *testing.T) {
Cfg = NewConfig()
t.Parallel()
ab := New()
mailServer := &bytes.Buffer{}
Cfg.Mailer = LogMailer(mailServer)
Cfg.Storer = mockStorer{}
Cfg.LogWriter = ioutil.Discard
Init()
ab.Mailer = LogMailer(mailServer)
ab.Storer = mockStorer{}
ab.LogWriter = ioutil.Discard
err := SendMail(Email{
err := ab.SendMail(Email{
To: []string{"some@email.com", "a@a.com"},
ToNames: []string{"Jake", "Noname"},
From: "some@guy.com",
@ -53,6 +54,8 @@ func TestMailer(t *testing.T) {
}
func TestSMTPMailer(t *testing.T) {
t.Parallel()
var _ Mailer = SMTPMailer("server", nil)
recovered := false

View File

@ -56,7 +56,7 @@ func (m mockClientStore) GetErr(key string) (string, error) {
func (m mockClientStore) Put(key, val string) { m[key] = val }
func (m mockClientStore) Del(key string) { delete(m, key) }
func mockRequestContext(postKeyValues ...string) *Context {
func mockRequestContext(ab *Authboss, postKeyValues ...string) *Context {
keyValues := &bytes.Buffer{}
for i := 0; i < len(postKeyValues); i += 2 {
if i != 0 {
@ -71,7 +71,7 @@ func mockRequestContext(postKeyValues ...string) *Context {
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
ctx, err := ContextFromRequest(req)
ctx, err := ab.ContextFromRequest(req)
if err != nil {
panic(err)
}

View File

@ -9,7 +9,7 @@ var ModuleAttributes = make(AttributeMeta)
// Modularizer should be implemented by all the authboss modules.
type Modularizer interface {
Initialize() error
Initialize(*Authboss) error
Routes() RouteTable
Storage() StorageOptions
}

View File

@ -23,12 +23,13 @@ func testHandler(ctx *Context, w http.ResponseWriter, r *http.Request) error {
return nil
}
func (t *testModule) Initialize() error { return nil }
func (t *testModule) Routes() RouteTable { return t.r }
func (t *testModule) Storage() StorageOptions { return t.s }
func (t *testModule) Initialize(a *Authboss) error { return nil }
func (t *testModule) Routes() RouteTable { return t.r }
func (t *testModule) Storage() StorageOptions { return t.s }
func TestRegister(t *testing.T) {
// RegisterModule called by TestMain.
modules = make(map[string]Modularizer)
RegisterModule("testmodule", testMod)
if _, ok := modules["testmodule"]; !ok {
t.Error("Expected module to be saved.")
@ -40,7 +41,8 @@ func TestRegister(t *testing.T) {
}
func TestLoadedModules(t *testing.T) {
// RegisterModule called by TestMain.
modules = make(map[string]Modularizer)
RegisterModule("testmodule", testMod)
loadedMods := LoadedModules()
if len(loadedMods) != 1 {

View File

@ -30,7 +30,7 @@ func init() {
// Initialize module
func (o *OAuth2) Initialize() error {
if authboss.Cfg.OAuth2Storer == nil {
if authboss.a.OAuth2Storer == nil {
return errors.New("oauth2: need an OAuth2Storer")
}
return nil
@ -40,7 +40,7 @@ func (o *OAuth2) Initialize() error {
func (o *OAuth2) Routes() authboss.RouteTable {
routes := make(authboss.RouteTable)
for prov, cfg := range authboss.Cfg.OAuth2Providers {
for prov, cfg := range authboss.a.OAuth2Providers {
prov = strings.ToLower(prov)
init := fmt.Sprintf("/oauth2/%s", prov)
@ -49,11 +49,11 @@ func (o *OAuth2) Routes() authboss.RouteTable {
routes[init] = oauthInit
routes[callback] = oauthCallback
if len(authboss.Cfg.MountPath) > 0 {
callback = path.Join(authboss.Cfg.MountPath, callback)
if len(authboss.a.MountPath) > 0 {
callback = path.Join(authboss.a.MountPath, callback)
}
cfg.OAuth2Config.RedirectURL = authboss.Cfg.RootURL + callback
a.OAuth2Config.RedirectURL = authboss.a.RootURL + callback
}
routes["/oauth2/logout"] = logout
@ -75,7 +75,7 @@ func (o *OAuth2) Storage() authboss.StorageOptions {
func oauthInit(ctx *authboss.Context, w http.ResponseWriter, r *http.Request) error {
provider := strings.ToLower(filepath.Base(r.URL.Path))
cfg, ok := authboss.Cfg.OAuth2Providers[provider]
cfg, ok := authboss.a.OAuth2Providers[provider]
if !ok {
return fmt.Errorf("OAuth2 provider %q not found", provider)
}
@ -106,9 +106,9 @@ func oauthInit(ctx *authboss.Context, w http.ResponseWriter, r *http.Request) er
ctx.SessionStorer.Del(authboss.SessionOAuth2Params)
}
url := cfg.OAuth2Config.AuthCodeURL(state)
url := a.OAuth2Config.AuthCodeURL(state)
extraParams := cfg.AdditionalParams.Encode()
extraParams := a.AdditionalParams.Encode()
if len(extraParams) > 0 {
url = fmt.Sprintf("%s&%s", url, extraParams)
}
@ -140,18 +140,18 @@ func oauthCallback(ctx *authboss.Context, w http.ResponseWriter, r *http.Request
hasErr := r.FormValue("error")
if len(hasErr) > 0 {
if err := authboss.Cfg.Callbacks.FireAfter(authboss.EventOAuthFail, ctx); err != nil {
if err := authboss.a.Callbacks.FireAfter(authboss.EventOAuthFail, ctx); err != nil {
return err
}
return authboss.ErrAndRedirect{
Err: errors.New(r.FormValue("error_reason")),
Location: authboss.Cfg.AuthLoginFailPath,
Location: authboss.a.AuthLoginFailPath,
FlashError: fmt.Sprintf("%s login cancelled or failed.", strings.Title(provider)),
}
}
cfg, ok := authboss.Cfg.OAuth2Providers[provider]
cfg, ok := authboss.a.OAuth2Providers[provider]
if !ok {
return fmt.Errorf("OAuth2 provider %q not found", provider)
}
@ -165,12 +165,12 @@ func oauthCallback(ctx *authboss.Context, w http.ResponseWriter, r *http.Request
// Get the code
code := r.FormValue("code")
token, err := exchanger(cfg.OAuth2Config, oauth2.NoContext, code)
token, err := exchanger(a.OAuth2Config, oauth2.NoContext, code)
if err != nil {
return fmt.Errorf("Could not validate oauth2 code: %v", err)
}
user, err := cfg.Callback(*cfg.OAuth2Config, token)
user, err := a.Callback(*cfg.OAuth2Config, token)
if err != nil {
return err
}
@ -189,7 +189,7 @@ func oauthCallback(ctx *authboss.Context, w http.ResponseWriter, r *http.Request
user[authboss.StoreOAuth2Refresh] = token.RefreshToken
}
if err = authboss.Cfg.OAuth2Storer.PutOAuth(uid, provider, user); err != nil {
if err = authboss.a.OAuth2Storer.PutOAuth(uid, provider, user); err != nil {
return err
}
@ -197,13 +197,13 @@ func oauthCallback(ctx *authboss.Context, w http.ResponseWriter, r *http.Request
ctx.SessionStorer.Put(authboss.SessionKey, fmt.Sprintf("%s;%s", uid, provider))
ctx.SessionStorer.Del(authboss.SessionHalfAuthKey)
if err = authboss.Cfg.Callbacks.FireAfter(authboss.EventOAuth, ctx); err != nil {
if err = authboss.a.Callbacks.FireAfter(authboss.EventOAuth, ctx); err != nil {
return nil
}
ctx.SessionStorer.Del(authboss.SessionOAuth2Params)
redirect := authboss.Cfg.AuthLoginOKPath
redirect := authboss.a.AuthLoginOKPath
query := make(url.Values)
for k, v := range values {
switch k {
@ -231,7 +231,7 @@ func logout(ctx *authboss.Context, w http.ResponseWriter, r *http.Request) error
ctx.CookieStorer.Del(authboss.CookieRemember)
ctx.SessionStorer.Del(authboss.SessionLastAction)
response.Redirect(ctx, w, r, authboss.Cfg.AuthLogoutOKPath, "You have logged out", "", true)
response.Redirect(ctx, w, r, authboss.a.AuthLogoutOKPath, "You have logged out", "", true)
default:
w.WriteHeader(http.StatusMethodNotAllowed)
}

View File

@ -31,7 +31,7 @@ var testProviders = map[string]authboss.OAuth2Provider{
func TestInitialize(t *testing.T) {
authboss.Cfg = authboss.NewConfig()
authboss.Cfg.OAuth2Storer = mocks.NewMockStorer()
authboss.a.OAuth2Storer = mocks.NewMockStorer()
o := OAuth2{}
if err := o.Initialize(); err != nil {
t.Error(err)
@ -43,11 +43,11 @@ func TestRoutes(t *testing.T) {
mount := "/auth"
authboss.Cfg = authboss.NewConfig()
authboss.Cfg.RootURL = root
authboss.Cfg.MountPath = mount
authboss.Cfg.OAuth2Providers = testProviders
authboss.a.RootURL = root
authboss.a.MountPath = mount
authboss.a.OAuth2Providers = testProviders
googleCfg := authboss.Cfg.OAuth2Providers["google"].OAuth2Config
googleCfg := authboss.a.OAuth2Providers["google"].OAuth2Config
if 0 != len(googleCfg.RedirectURL) {
t.Error("RedirectURL should not be set")
}
@ -74,7 +74,7 @@ func TestOAuth2Init(t *testing.T) {
cfg := authboss.NewConfig()
session := mocks.NewMockClientStorer()
cfg.OAuth2Providers = testProviders
a.OAuth2Providers = testProviders
authboss.Cfg = cfg
r, _ := http.NewRequest("GET", "/oauth2/google?redir=/my/redirect%23lol&rm=true", nil)
@ -137,7 +137,7 @@ func TestOAuthSuccess(t *testing.T) {
return fakeToken, nil
}
cfg.OAuth2Providers = map[string]authboss.OAuth2Provider{
a.OAuth2Providers = map[string]authboss.OAuth2Provider{
"fake": authboss.OAuth2Provider{
OAuth2Config: &oauth2.Config{
ClientID: `jazz`,
@ -168,8 +168,8 @@ func TestOAuthSuccess(t *testing.T) {
storer := mocks.NewMockStorer()
ctx.SessionStorer = session
cfg.OAuth2Storer = storer
cfg.AuthLoginOKPath = "/fakeloginok"
a.OAuth2Storer = storer
a.AuthLoginOKPath = "/fakeloginok"
if err := oauthCallback(ctx, w, r); err != nil {
t.Error(err)
@ -214,7 +214,7 @@ func TestOAuthXSRFFailure(t *testing.T) {
session := mocks.NewMockClientStorer()
session.Put(authboss.SessionOAuth2State, authboss.FormValueOAuth2State)
cfg.OAuth2Providers = testProviders
a.OAuth2Providers = testProviders
authboss.Cfg = cfg
values := url.Values{}
@ -234,7 +234,7 @@ func TestOAuthXSRFFailure(t *testing.T) {
func TestOAuthFailure(t *testing.T) {
cfg := authboss.NewConfig()
cfg.OAuth2Providers = testProviders
a.OAuth2Providers = testProviders
authboss.Cfg = cfg
values := url.Values{}
@ -260,7 +260,7 @@ func TestOAuthFailure(t *testing.T) {
func TestLogout(t *testing.T) {
authboss.Cfg = authboss.NewConfig()
authboss.Cfg.AuthLogoutOKPath = "/dashboard"
authboss.a.AuthLogoutOKPath = "/dashboard"
r, _ := http.NewRequest("GET", "/oauth2/google?", nil)
w := httptest.NewRecorder()
@ -292,7 +292,7 @@ func TestLogout(t *testing.T) {
}
location := w.Header().Get("Location")
if location != authboss.Cfg.AuthLogoutOKPath {
if location != authboss.a.AuthLogoutOKPath {
t.Error("Redirect wrong:", location)
}
}

View File

@ -26,8 +26,8 @@ type googleMeResponse struct {
var clientGet = (*http.Client).Get
// Google is a callback appropriate for use with Google's OAuth2 configuration.
func Google(cfg oauth2.Config, token *oauth2.Token) (authboss.Attributes, error) {
client := cfg.Client(oauth2.NoContext, token)
func Google(a.oauth2.Config, token *oauth2.Token) (authboss.Attributes, error) {
client := a.Client(oauth2.NoContext, token)
resp, err := clientGet(client, googleInfoEndpoint)
if err != nil {
return nil, err

View File

@ -63,32 +63,32 @@ type Recover struct {
// Initialize module
func (r *Recover) Initialize() (err error) {
if authboss.Cfg.Storer == nil {
if authboss.a.Storer == nil {
return errors.New("recover: Need a RecoverStorer")
}
if _, ok := authboss.Cfg.Storer.(RecoverStorer); !ok {
if _, ok := authboss.a.Storer.(RecoverStorer); !ok {
return errors.New("recover: RecoverStorer required for recover functionality")
}
if len(authboss.Cfg.XSRFName) == 0 {
if len(authboss.a.XSRFName) == 0 {
return errors.New("auth: XSRFName must be set")
}
if authboss.Cfg.XSRFMaker == nil {
if authboss.a.XSRFMaker == nil {
return errors.New("auth: XSRFMaker must be defined")
}
r.templates, err = response.LoadTemplates(authboss.Cfg.Layout, authboss.Cfg.ViewsPath, tplRecover, tplRecoverComplete)
r.templates, err = response.LoadTemplates(authboss.a.Layout, authboss.a.ViewsPath, tplRecover, tplRecoverComplete)
if err != nil {
return err
}
r.emailHTMLTemplates, err = response.LoadTemplates(authboss.Cfg.LayoutHTMLEmail, authboss.Cfg.ViewsPath, tplInitHTMLEmail)
r.emailHTMLTemplates, err = response.LoadTemplates(authboss.a.LayoutHTMLEmail, authboss.a.ViewsPath, tplInitHTMLEmail)
if err != nil {
return err
}
r.emailTextTemplates, err = response.LoadTemplates(authboss.Cfg.LayoutTextEmail, authboss.Cfg.ViewsPath, tplInitTextEmail)
r.emailTextTemplates, err = response.LoadTemplates(authboss.a.LayoutTextEmail, authboss.a.ViewsPath, tplInitTextEmail)
if err != nil {
return err
}
@ -107,7 +107,7 @@ func (r *Recover) Routes() authboss.RouteTable {
// Storage requirements
func (r *Recover) Storage() authboss.StorageOptions {
return authboss.StorageOptions{
authboss.Cfg.PrimaryID: authboss.String,
authboss.a.PrimaryID: authboss.String,
authboss.StoreEmail: authboss.String,
authboss.StorePassword: authboss.String,
StoreRecoverToken: authboss.String,
@ -119,31 +119,31 @@ func (rec *Recover) startHandlerFunc(ctx *authboss.Context, w http.ResponseWrite
switch r.Method {
case methodGET:
data := authboss.NewHTMLData(
"primaryID", authboss.Cfg.PrimaryID,
"primaryID", authboss.a.PrimaryID,
"primaryIDValue", "",
"confirmPrimaryIDValue", "",
)
return rec.templates.Render(ctx, w, r, tplRecover, data)
case methodPOST:
primaryID, _ := ctx.FirstPostFormValue(authboss.Cfg.PrimaryID)
confirmPrimaryID, _ := ctx.FirstPostFormValue(fmt.Sprintf("confirm_%s", authboss.Cfg.PrimaryID))
primaryID, _ := ctx.FirstPostFormValue(authboss.a.PrimaryID)
confirmPrimaryID, _ := ctx.FirstPostFormValue(fmt.Sprintf("confirm_%s", authboss.a.PrimaryID))
errData := authboss.NewHTMLData(
"primaryID", authboss.Cfg.PrimaryID,
"primaryID", authboss.a.PrimaryID,
"primaryIDValue", primaryID,
"confirmPrimaryIDValue", confirmPrimaryID,
)
policies := authboss.FilterValidators(authboss.Cfg.Policies, authboss.Cfg.PrimaryID)
if validationErrs := ctx.Validate(policies, authboss.Cfg.PrimaryID, authboss.ConfirmPrefix+authboss.Cfg.PrimaryID).Map(); len(validationErrs) > 0 {
policies := authboss.FilterValidators(authboss.a.Policies, authboss.a.PrimaryID)
if validationErrs := ctx.Validate(policies, authboss.a.PrimaryID, authboss.ConfirmPrefix+authboss.a.PrimaryID).Map(); len(validationErrs) > 0 {
errData.MergeKV("errs", validationErrs)
return rec.templates.Render(ctx, w, r, tplRecover, errData)
}
// redirect to login when user not found to prevent username sniffing
if err := ctx.LoadUser(primaryID); err == authboss.ErrUserNotFound {
return authboss.ErrAndRedirect{err, authboss.Cfg.RecoverOKPath, recoverInitiateSuccessFlash, ""}
return authboss.ErrAndRedirect{err, authboss.a.RecoverOKPath, recoverInitiateSuccessFlash, ""}
} else if err != nil {
return err
}
@ -159,7 +159,7 @@ func (rec *Recover) startHandlerFunc(ctx *authboss.Context, w http.ResponseWrite
}
ctx.User[StoreRecoverToken] = encodedChecksum
ctx.User[StoreRecoverTokenExpiry] = time.Now().Add(authboss.Cfg.RecoverTokenDuration)
ctx.User[StoreRecoverTokenExpiry] = time.Now().Add(authboss.a.RecoverTokenDuration)
if err := ctx.SaveUser(); err != nil {
return err
@ -168,7 +168,7 @@ func (rec *Recover) startHandlerFunc(ctx *authboss.Context, w http.ResponseWrite
goRecoverEmail(rec, email, encodedToken)
ctx.SessionStorer.Put(authboss.FlashSuccessKey, recoverInitiateSuccessFlash)
response.Redirect(ctx, w, r, authboss.Cfg.RecoverOKPath, "", "", true)
response.Redirect(ctx, w, r, authboss.a.RecoverOKPath, "", "", true)
default:
w.WriteHeader(http.StatusMethodNotAllowed)
}
@ -191,17 +191,17 @@ var goRecoverEmail = func(r *Recover, to, encodedToken string) {
}
func (r *Recover) sendRecoverEmail(to, encodedToken string) {
p := path.Join(authboss.Cfg.MountPath, "recover/complete")
url := fmt.Sprintf("%s%s?token=%s", authboss.Cfg.RootURL, p, encodedToken)
p := path.Join(authboss.a.MountPath, "recover/complete")
url := fmt.Sprintf("%s%s?token=%s", authboss.a.RootURL, p, encodedToken)
email := authboss.Email{
To: []string{to},
From: authboss.Cfg.EmailFrom,
Subject: authboss.Cfg.EmailSubjectPrefix + "Password Reset",
From: authboss.a.EmailFrom,
Subject: authboss.a.EmailSubjectPrefix + "Password Reset",
}
if err := response.Email(email, r.emailHTMLTemplates, tplInitHTMLEmail, r.emailTextTemplates, tplInitTextEmail, url); err != nil {
fmt.Fprintln(authboss.Cfg.LogWriter, "recover: failed to send recover email:", err)
fmt.Fprintln(authboss.a.LogWriter, "recover: failed to send recover email:", err)
}
}
@ -227,7 +227,7 @@ func (r *Recover) completeHandlerFunc(ctx *authboss.Context, w http.ResponseWrit
password, _ := ctx.FirstPostFormValue("password")
confirmPassword, _ := ctx.FirstPostFormValue("confirmPassword")
policies := authboss.FilterValidators(authboss.Cfg.Policies, "password")
policies := authboss.FilterValidators(authboss.a.Policies, "password")
if validationErrs := ctx.Validate(policies, authboss.StorePassword, authboss.ConfirmPrefix+authboss.StorePassword).Map(); len(validationErrs) > 0 {
data := authboss.NewHTMLData(
"token", token,
@ -242,7 +242,7 @@ func (r *Recover) completeHandlerFunc(ctx *authboss.Context, w http.ResponseWrit
return err
}
encryptedPassword, err := bcrypt.GenerateFromPassword([]byte(password), authboss.Cfg.BCryptCost)
encryptedPassword, err := bcrypt.GenerateFromPassword([]byte(password), authboss.a.BCryptCost)
if err != nil {
return err
}
@ -252,7 +252,7 @@ func (r *Recover) completeHandlerFunc(ctx *authboss.Context, w http.ResponseWrit
var nullTime time.Time
ctx.User[StoreRecoverTokenExpiry] = nullTime
primaryID, err := ctx.User.StringErr(authboss.Cfg.PrimaryID)
primaryID, err := ctx.User.StringErr(authboss.a.PrimaryID)
if err != nil {
return err
}
@ -261,12 +261,12 @@ func (r *Recover) completeHandlerFunc(ctx *authboss.Context, w http.ResponseWrit
return err
}
if err := authboss.Cfg.Callbacks.FireAfter(authboss.EventPasswordReset, ctx); err != nil {
if err := authboss.a.Callbacks.FireAfter(authboss.EventPasswordReset, ctx); err != nil {
return err
}
ctx.SessionStorer.Put(authboss.SessionKey, primaryID)
response.Redirect(ctx, w, req, authboss.Cfg.AuthLoginOKPath, "", "", true)
response.Redirect(ctx, w, req, authboss.a.AuthLoginOKPath, "", "", true)
default:
w.WriteHeader(http.StatusMethodNotAllowed)
}
@ -287,7 +287,7 @@ func verifyToken(ctx *authboss.Context) (attrs authboss.Attributes, err error) {
}
sum := md5.Sum(decoded)
storer := authboss.Cfg.Storer.(RecoverStorer)
storer := authboss.a.Storer.(RecoverStorer)
userInter, err := storer.RecoverUser(base64.StdEncoding.EncodeToString(sum[:]))
if err != nil {

View File

@ -26,16 +26,16 @@ func testSetup() (r *Recover, s *mocks.MockStorer, l *bytes.Buffer) {
l = &bytes.Buffer{}
authboss.Cfg = authboss.NewConfig()
authboss.Cfg.Layout = template.Must(template.New("").Parse(`{{template "authboss" .}}`))
authboss.Cfg.LayoutHTMLEmail = template.Must(template.New("").Parse(`<strong>{{template "authboss" .}}</strong>`))
authboss.Cfg.LayoutTextEmail = template.Must(template.New("").Parse(`{{template "authboss" .}}`))
authboss.Cfg.Storer = s
authboss.Cfg.XSRFName = "xsrf"
authboss.Cfg.XSRFMaker = func(_ http.ResponseWriter, _ *http.Request) string {
authboss.a.Layout = template.Must(template.New("").Parse(`{{template "authboss" .}}`))
authboss.a.LayoutHTMLEmail = template.Must(template.New("").Parse(`<strong>{{template "authboss" .}}</strong>`))
authboss.a.LayoutTextEmail = template.Must(template.New("").Parse(`{{template "authboss" .}}`))
authboss.a.Storer = s
authboss.a.XSRFName = "xsrf"
authboss.a.XSRFMaker = func(_ http.ResponseWriter, _ *http.Request) string {
return "xsrfvalue"
}
authboss.Cfg.PrimaryID = authboss.StoreUsername
authboss.Cfg.LogWriter = l
authboss.a.PrimaryID = authboss.StoreUsername
authboss.a.LogWriter = l
r = &Recover{}
if err := r.Initialize(); err != nil {
@ -62,8 +62,8 @@ func TestRecover(t *testing.T) {
r, _, _ := testSetup()
storage := r.Storage()
if storage[authboss.Cfg.PrimaryID] != authboss.String {
t.Error("Expected storage KV:", authboss.Cfg.PrimaryID, authboss.String)
if storage[authboss.a.PrimaryID] != authboss.String {
t.Error("Expected storage KV:", authboss.a.PrimaryID, authboss.String)
}
if storage[authboss.StoreEmail] != authboss.String {
t.Error("Expected storage KV:", authboss.StoreEmail, authboss.String)
@ -103,10 +103,10 @@ func TestRecover_startHandlerFunc_GET(t *testing.T) {
if !strings.Contains(body, `<form action="recover"`) {
t.Error("Should have rendered a form")
}
if !strings.Contains(body, `name="`+authboss.Cfg.PrimaryID) {
if !strings.Contains(body, `name="`+authboss.a.PrimaryID) {
t.Error("Form should contain the primary ID field")
}
if !strings.Contains(body, `name="confirm_`+authboss.Cfg.PrimaryID) {
if !strings.Contains(body, `name="confirm_`+authboss.a.PrimaryID) {
t.Error("Form should contain the confirm primary ID field")
}
}
@ -141,7 +141,7 @@ func TestRecover_startHandlerFunc_POST_UserNotFound(t *testing.T) {
t.Error("Expected ErrAndRedirect error")
}
if rerr.Location != authboss.Cfg.RecoverOKPath {
if rerr.Location != authboss.a.RecoverOKPath {
t.Error("Unexpected location:", rerr.Location)
}
@ -187,7 +187,7 @@ func TestRecover_startHandlerFunc_POST(t *testing.T) {
}
loc := w.Header().Get("Location")
if loc != authboss.Cfg.RecoverOKPath {
if loc != authboss.a.RecoverOKPath {
t.Error("Unexpected location:", loc)
}
@ -237,7 +237,7 @@ func TestRecover_sendRecoverMail_FailToSend(t *testing.T) {
mailer := mocks.NewMockMailer()
mailer.SendErr = "failed to send"
authboss.Cfg.Mailer = mailer
authboss.a.Mailer = mailer
a.sendRecoverEmail("", "")
@ -250,9 +250,9 @@ func TestRecover_sendRecoverEmail(t *testing.T) {
a, _, _ := testSetup()
mailer := mocks.NewMockMailer()
authboss.Cfg.EmailSubjectPrefix = "foo "
authboss.Cfg.RootURL = "bar"
authboss.Cfg.Mailer = mailer
authboss.a.EmailSubjectPrefix = "foo "
authboss.a.RootURL = "bar"
authboss.a.Mailer = mailer
a.sendRecoverEmail("a@b.c", "abc=")
if len(mailer.Last.To) != 1 {
@ -265,7 +265,7 @@ func TestRecover_sendRecoverEmail(t *testing.T) {
t.Error("Unexpected subject:", mailer.Last.Subject)
}
url := fmt.Sprintf("%s/recover/complete?token=abc=", authboss.Cfg.RootURL)
url := fmt.Sprintf("%s/recover/complete?token=abc=", authboss.a.RootURL)
if !strings.Contains(mailer.Last.HTMLBody, url) {
t.Error("Expected HTMLBody to contain url:", url)
}
@ -377,12 +377,12 @@ func TestRecover_completeHandlerFunc_POST_VerificationFails(t *testing.T) {
func TestRecover_completeHandlerFunc_POST(t *testing.T) {
rec, storer, _ := testSetup()
storer.Users["john"] = authboss.Attributes{authboss.Cfg.PrimaryID: "john", StoreRecoverToken: testStdBase64Token, StoreRecoverTokenExpiry: time.Now().Add(1 * time.Hour), authboss.StorePassword: "asdf"}
storer.Users["john"] = authboss.Attributes{authboss.a.PrimaryID: "john", StoreRecoverToken: testStdBase64Token, StoreRecoverTokenExpiry: time.Now().Add(1 * time.Hour), authboss.StorePassword: "asdf"}
cbCalled := false
authboss.Cfg.Callbacks = authboss.NewCallbacks()
authboss.Cfg.Callbacks.After(authboss.EventPasswordReset, func(_ *authboss.Context) error {
authboss.a.Callbacks = authboss.NewCallbacks()
authboss.a.Callbacks.After(authboss.EventPasswordReset, func(_ *authboss.Context) error {
cbCalled = true
return nil
})
@ -421,7 +421,7 @@ func TestRecover_completeHandlerFunc_POST(t *testing.T) {
}
loc := w.Header().Get("Location")
if loc != authboss.Cfg.AuthLogoutOKPath {
if loc != authboss.a.AuthLogoutOKPath {
t.Error("Unexpected location:", loc)
}
}

View File

@ -33,15 +33,15 @@ type Register struct {
// Initialize the module.
func (r *Register) Initialize() (err error) {
if authboss.Cfg.Storer == nil {
if authboss.a.Storer == nil {
return errors.New("register: Need a RegisterStorer")
}
if _, ok := authboss.Cfg.Storer.(RegisterStorer); !ok {
if _, ok := authboss.a.Storer.(RegisterStorer); !ok {
return errors.New("register: RegisterStorer required for register functionality")
}
if r.templates, err = response.LoadTemplates(authboss.Cfg.Layout, authboss.Cfg.ViewsPath, tplRegister); err != nil {
if r.templates, err = response.LoadTemplates(authboss.a.Layout, authboss.a.ViewsPath, tplRegister); err != nil {
return err
}
@ -58,7 +58,7 @@ func (r *Register) Routes() authboss.RouteTable {
// Storage returns storage requirements.
func (r *Register) Storage() authboss.StorageOptions {
return authboss.StorageOptions{
authboss.Cfg.PrimaryID: authboss.String,
authboss.a.PrimaryID: authboss.String,
authboss.StorePassword: authboss.String,
}
}
@ -67,7 +67,7 @@ func (reg *Register) registerHandler(ctx *authboss.Context, w http.ResponseWrite
switch r.Method {
case "GET":
data := authboss.HTMLData{
"primaryID": authboss.Cfg.PrimaryID,
"primaryID": authboss.a.PrimaryID,
"primaryIDValue": "",
}
return reg.templates.Render(ctx, w, r, tplRegister, data)
@ -78,15 +78,15 @@ func (reg *Register) registerHandler(ctx *authboss.Context, w http.ResponseWrite
}
func (reg *Register) registerPostHandler(ctx *authboss.Context, w http.ResponseWriter, r *http.Request) error {
key, _ := ctx.FirstPostFormValue(authboss.Cfg.PrimaryID)
key, _ := ctx.FirstPostFormValue(authboss.a.PrimaryID)
password, _ := ctx.FirstPostFormValue(authboss.StorePassword)
policies := authboss.FilterValidators(authboss.Cfg.Policies, authboss.Cfg.PrimaryID, authboss.StorePassword)
validationErrs := ctx.Validate(policies, authboss.Cfg.ConfirmFields...)
policies := authboss.FilterValidators(authboss.a.Policies, authboss.a.PrimaryID, authboss.StorePassword)
validationErrs := ctx.Validate(policies, authboss.a.ConfirmFields...)
if len(validationErrs) != 0 {
data := authboss.HTMLData{
"primaryID": authboss.Cfg.PrimaryID,
"primaryID": authboss.a.PrimaryID,
"primaryIDValue": key,
"errs": validationErrs.Map(),
}
@ -99,30 +99,30 @@ func (reg *Register) registerPostHandler(ctx *authboss.Context, w http.ResponseW
return err
}
pass, err := bcrypt.GenerateFromPassword([]byte(password), authboss.Cfg.BCryptCost)
pass, err := bcrypt.GenerateFromPassword([]byte(password), authboss.a.BCryptCost)
if err != nil {
return err
}
attr[authboss.Cfg.PrimaryID] = key
attr[authboss.a.PrimaryID] = key
attr[authboss.StorePassword] = string(pass)
ctx.User = attr
if err := authboss.Cfg.Storer.(RegisterStorer).Create(key, attr); err != nil {
if err := authboss.a.Storer.(RegisterStorer).Create(key, attr); err != nil {
return err
}
if err := authboss.Cfg.Callbacks.FireAfter(authboss.EventRegister, ctx); err != nil {
if err := authboss.a.Callbacks.FireAfter(authboss.EventRegister, ctx); err != nil {
return err
}
if authboss.IsLoaded("confirm") {
response.Redirect(ctx, w, r, authboss.Cfg.RegisterOKPath, "Account successfully created, please verify your e-mail address.", "", true)
response.Redirect(ctx, w, r, authboss.a.RegisterOKPath, "Account successfully created, please verify your e-mail address.", "", true)
return nil
}
ctx.SessionStorer.Put(authboss.SessionKey, key)
response.Redirect(ctx, w, r, authboss.Cfg.RegisterOKPath, "Account successfully created, you are now logged in.", "", true)
response.Redirect(ctx, w, r, authboss.a.RegisterOKPath, "Account successfully created, you are now logged in.", "", true)
return nil
}

View File

@ -15,14 +15,14 @@ import (
func setup() *Register {
authboss.Cfg = authboss.NewConfig()
authboss.Cfg.RegisterOKPath = "/regsuccess"
authboss.Cfg.Layout = template.Must(template.New("").Parse(`{{template "authboss" .}}`))
authboss.Cfg.XSRFName = "xsrf"
authboss.Cfg.XSRFMaker = func(_ http.ResponseWriter, _ *http.Request) string {
authboss.a.RegisterOKPath = "/regsuccess"
authboss.a.Layout = template.Must(template.New("").Parse(`{{template "authboss" .}}`))
authboss.a.XSRFName = "xsrf"
authboss.a.XSRFMaker = func(_ http.ResponseWriter, _ *http.Request) string {
return "xsrfvalue"
}
authboss.Cfg.ConfirmFields = []string{"password", "confirm_password"}
authboss.Cfg.Storer = mocks.NewMockStorer()
authboss.a.ConfirmFields = []string{"password", "confirm_password"}
authboss.a.Storer = mocks.NewMockStorer()
reg := Register{}
if err := reg.Initialize(); err != nil {
@ -34,7 +34,7 @@ func setup() *Register {
func TestRegister(t *testing.T) {
authboss.Cfg = authboss.NewConfig()
authboss.Cfg.Storer = mocks.NewMockStorer()
authboss.a.Storer = mocks.NewMockStorer()
r := Register{}
if err := r.Initialize(); err != nil {
@ -46,7 +46,7 @@ func TestRegister(t *testing.T) {
}
sto := r.Storage()
if sto[authboss.Cfg.PrimaryID] != authboss.String {
if sto[authboss.a.PrimaryID] != authboss.String {
t.Error("Wanted primary ID to be a string.")
}
if sto[authboss.StorePassword] != authboss.String {
@ -76,7 +76,7 @@ func TestRegisterGet(t *testing.T) {
if str := w.Body.String(); !strings.Contains(str, "<form") {
t.Error("It should have rendered a nice form:", str)
} else if !strings.Contains(str, `name="`+authboss.Cfg.PrimaryID) {
} else if !strings.Contains(str, `name="`+authboss.a.PrimaryID) {
t.Error("Form should contain the primary ID:", str)
}
}
@ -88,7 +88,7 @@ func TestRegisterPostValidationErrs(t *testing.T) {
vals := url.Values{}
email := "email@address.com"
vals.Set(authboss.Cfg.PrimaryID, email)
vals.Set(authboss.a.PrimaryID, email)
vals.Set(authboss.StorePassword, "pass")
vals.Set(authboss.ConfirmPrefix+authboss.StorePassword, "pass2")
@ -113,7 +113,7 @@ func TestRegisterPostValidationErrs(t *testing.T) {
t.Error("Confirm password should have an error:", str)
}
if _, err := authboss.Cfg.Storer.Get(email); err != authboss.ErrUserNotFound {
if _, err := authboss.a.Storer.Get(email); err != authboss.ErrUserNotFound {
t.Error("The user should not have been saved.")
}
}
@ -125,7 +125,7 @@ func TestRegisterPostSuccess(t *testing.T) {
vals := url.Values{}
email := "email@address.com"
vals.Set(authboss.Cfg.PrimaryID, email)
vals.Set(authboss.a.PrimaryID, email)
vals.Set(authboss.StorePassword, "pass")
vals.Set(authboss.ConfirmPrefix+authboss.StorePassword, "pass")
@ -142,17 +142,17 @@ func TestRegisterPostSuccess(t *testing.T) {
t.Error("It should have written a redirect:", w.Code)
}
if loc := w.Header().Get("Location"); loc != authboss.Cfg.RegisterOKPath {
if loc := w.Header().Get("Location"); loc != authboss.a.RegisterOKPath {
t.Error("Redirected to the wrong location", loc)
}
user, err := authboss.Cfg.Storer.Get(email)
user, err := authboss.a.Storer.Get(email)
if err == authboss.ErrUserNotFound {
t.Error("The user have been saved.")
}
attrs := authboss.Unbind(user)
if e, err := attrs.StringErr(authboss.Cfg.PrimaryID); err != nil {
if e, err := attrs.StringErr(authboss.a.PrimaryID); err != nil {
t.Error(err)
} else if e != email {
t.Errorf("Email was not set properly, want: %s, got: %s", email, e)

View File

@ -47,20 +47,20 @@ type Remember struct{}
// Initialize module
func (r *Remember) Initialize() error {
if authboss.Cfg.Storer == nil && authboss.Cfg.OAuth2Storer == nil {
if authboss.a.Storer == nil && authboss.a.OAuth2Storer == nil {
return errors.New("remember: Need a RememberStorer")
}
if _, ok := authboss.Cfg.Storer.(RememberStorer); !ok {
if _, ok := authboss.Cfg.OAuth2Storer.(RememberStorer); !ok {
if _, ok := authboss.a.Storer.(RememberStorer); !ok {
if _, ok := authboss.a.OAuth2Storer.(RememberStorer); !ok {
return errors.New("remember: RememberStorer required for remember functionality")
}
}
authboss.Cfg.Callbacks.Before(authboss.EventGetUserSession, r.auth)
authboss.Cfg.Callbacks.After(authboss.EventAuth, r.afterAuth)
authboss.Cfg.Callbacks.After(authboss.EventOAuth, r.afterOAuth)
authboss.Cfg.Callbacks.After(authboss.EventPasswordReset, r.afterPassword)
authboss.a.Callbacks.Before(authboss.EventGetUserSession, r.auth)
authboss.a.Callbacks.After(authboss.EventAuth, r.afterAuth)
authboss.a.Callbacks.After(authboss.EventOAuth, r.afterOAuth)
authboss.a.Callbacks.After(authboss.EventPasswordReset, r.afterPassword)
return nil
}
@ -73,7 +73,7 @@ func (r *Remember) Routes() authboss.RouteTable {
// Storage requirements
func (r *Remember) Storage() authboss.StorageOptions {
return authboss.StorageOptions{
authboss.Cfg.PrimaryID: authboss.String,
authboss.a.PrimaryID: authboss.String,
}
}
@ -87,7 +87,7 @@ func (r *Remember) afterAuth(ctx *authboss.Context) error {
return errUserMissing
}
key, err := ctx.User.StringErr(authboss.Cfg.PrimaryID)
key, err := ctx.User.StringErr(authboss.a.PrimaryID)
if err != nil {
return err
}
@ -146,7 +146,7 @@ func (r *Remember) afterPassword(ctx *authboss.Context) error {
return nil
}
id, ok := ctx.User.String(authboss.Cfg.PrimaryID)
id, ok := ctx.User.String(authboss.a.PrimaryID)
if !ok {
return nil
}
@ -154,8 +154,8 @@ func (r *Remember) afterPassword(ctx *authboss.Context) error {
ctx.CookieStorer.Del(authboss.CookieRemember)
var storer RememberStorer
if storer, ok = authboss.Cfg.Storer.(RememberStorer); !ok {
if storer, ok = authboss.Cfg.OAuth2Storer.(RememberStorer); !ok {
if storer, ok = authboss.a.Storer.(RememberStorer); !ok {
if storer, ok = authboss.a.OAuth2Storer.(RememberStorer); !ok {
return nil
}
}
@ -181,8 +181,8 @@ func (r *Remember) new(cstorer authboss.ClientStorer, storageKey string) (string
var storer RememberStorer
var ok bool
if storer, ok = authboss.Cfg.Storer.(RememberStorer); !ok {
storer, ok = authboss.Cfg.OAuth2Storer.(RememberStorer)
if storer, ok = authboss.a.Storer.(RememberStorer); !ok {
storer, ok = authboss.a.OAuth2Storer.(RememberStorer)
}
// Save the token in the DB
@ -226,8 +226,8 @@ func (r *Remember) auth(ctx *authboss.Context) (authboss.Interrupt, error) {
sum := md5.Sum(token)
var storer RememberStorer
if storer, ok = authboss.Cfg.Storer.(RememberStorer); !ok {
storer, ok = authboss.Cfg.OAuth2Storer.(RememberStorer)
if storer, ok = authboss.a.Storer.(RememberStorer); !ok {
storer, ok = authboss.a.OAuth2Storer.(RememberStorer)
}
err = storer.UseToken(givenKey, base64.StdEncoding.EncodeToString(sum[:]))

View File

@ -19,13 +19,13 @@ func TestInitialize(t *testing.T) {
t.Error("Expected error about token storers.")
}
authboss.Cfg.Storer = mocks.MockFailStorer{}
authboss.a.Storer = mocks.MockFailStorer{}
err = r.Initialize()
if err == nil {
t.Error("Expected error about token storers.")
}
authboss.Cfg.Storer = mocks.NewMockStorer()
authboss.a.Storer = mocks.NewMockStorer()
err = r.Initialize()
if err != nil {
t.Error("Unexpected error:", err)
@ -36,7 +36,7 @@ func TestAfterAuth(t *testing.T) {
r := Remember{}
authboss.NewConfig()
storer := mocks.NewMockStorer()
authboss.Cfg.Storer = storer
authboss.a.Storer = storer
cookies := mocks.NewMockClientStorer()
session := mocks.NewMockClientStorer()
@ -54,7 +54,7 @@ func TestAfterAuth(t *testing.T) {
ctx.SessionStorer = session
ctx.CookieStorer = cookies
ctx.User = authboss.Attributes{authboss.Cfg.PrimaryID: "test@email.com"}
ctx.User = authboss.Attributes{authboss.a.PrimaryID: "test@email.com"}
if err := r.afterAuth(ctx); err != nil {
t.Error(err)
@ -69,7 +69,7 @@ func TestAfterOAuth(t *testing.T) {
r := Remember{}
authboss.NewConfig()
storer := mocks.NewMockStorer()
authboss.Cfg.Storer = storer
authboss.a.Storer = storer
cookies := mocks.NewMockClientStorer()
session := mocks.NewMockClientStorer(authboss.SessionOAuth2Params, `{"rm":"true"}`)
@ -108,14 +108,14 @@ func TestAfterPasswordReset(t *testing.T) {
id := "test@email.com"
storer := mocks.NewMockStorer()
authboss.Cfg.Storer = storer
authboss.a.Storer = storer
session := mocks.NewMockClientStorer()
cookies := mocks.NewMockClientStorer()
storer.Tokens[id] = []string{"one", "two"}
cookies.Values[authboss.CookieRemember] = "token"
ctx := authboss.NewContext()
ctx.User = authboss.Attributes{authboss.Cfg.PrimaryID: id}
ctx.User = authboss.Attributes{authboss.a.PrimaryID: id}
ctx.SessionStorer = session
ctx.CookieStorer = cookies
@ -136,7 +136,7 @@ func TestNew(t *testing.T) {
r := &Remember{}
authboss.NewConfig()
storer := mocks.NewMockStorer()
authboss.Cfg.Storer = storer
authboss.a.Storer = storer
cookies := mocks.NewMockClientStorer()
key := "tester"
@ -165,7 +165,7 @@ func TestAuth(t *testing.T) {
r := &Remember{}
authboss.NewConfig()
storer := mocks.NewMockStorer()
authboss.Cfg.Storer = storer
authboss.a.Storer = storer
cookies := mocks.NewMockClientStorer()
session := mocks.NewMockClientStorer()

View File

@ -14,19 +14,19 @@ type HandlerFunc func(*Context, http.ResponseWriter, *http.Request) error
type RouteTable map[string]HandlerFunc
// NewRouter returns a router to be mounted at some mountpoint.
func NewRouter() http.Handler {
func (a *Authboss) NewRouter() http.Handler {
mux := http.NewServeMux()
for name, mod := range modules {
for route, handler := range mod.Routes() {
fmt.Fprintf(Cfg.LogWriter, "%-10s Route: %s\n", "["+name+"]", path.Join(Cfg.MountPath, route))
mux.Handle(path.Join(Cfg.MountPath, route), contextRoute{handler})
fmt.Fprintf(a.LogWriter, "%-10s Route: %s\n", "["+name+"]", path.Join(a.MountPath, route))
mux.Handle(path.Join(a.MountPath, route), contextRoute{a, handler})
}
}
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
if Cfg.NotFoundHandler != nil {
Cfg.NotFoundHandler.ServeHTTP(w, r)
if a.NotFoundHandler != nil {
a.NotFoundHandler.ServeHTTP(w, r)
} else {
w.WriteHeader(http.StatusNotFound)
io.WriteString(w, "404 Page not found")
@ -37,25 +37,26 @@ func NewRouter() http.Handler {
}
type contextRoute struct {
*Authboss
fn HandlerFunc
}
func (c contextRoute) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx, err := ContextFromRequest(r)
ctx, err := c.Authboss.ContextFromRequest(r)
if err != nil {
fmt.Fprintf(Cfg.LogWriter, "route: Malformed request, could not create context: %v", err)
fmt.Fprintf(c.LogWriter, "route: Malformed request, could not create context: %v", err)
return
}
ctx.CookieStorer = clientStoreWrapper{Cfg.CookieStoreMaker(w, r)}
ctx.SessionStorer = clientStoreWrapper{Cfg.SessionStoreMaker(w, r)}
ctx.CookieStorer = clientStoreWrapper{c.CookieStoreMaker(w, r)}
ctx.SessionStorer = clientStoreWrapper{c.SessionStoreMaker(w, r)}
err = c.fn(ctx, w, r)
if err == nil {
return
}
fmt.Fprintf(Cfg.LogWriter, "Error Occurred at %s: %v", r.URL.Path, err)
fmt.Fprintf(c.LogWriter, "Error Occurred at %s: %v", r.URL.Path, err)
switch e := err.(type) {
case ErrAndRedirect:
@ -67,15 +68,15 @@ func (c contextRoute) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
http.Redirect(w, r, e.Location, http.StatusFound)
case ClientDataErr:
if Cfg.BadRequestHandler != nil {
Cfg.BadRequestHandler.ServeHTTP(w, r)
if c.BadRequestHandler != nil {
c.BadRequestHandler.ServeHTTP(w, r)
} else {
w.WriteHeader(http.StatusBadRequest)
io.WriteString(w, "400 Bad request")
}
default:
if Cfg.ErrorHandler != nil {
Cfg.ErrorHandler.ServeHTTP(w, r)
if c.ErrorHandler != nil {
c.ErrorHandler.ServeHTTP(w, r)
} else {
w.WriteHeader(http.StatusInternalServerError)
io.WriteString(w, "500 An error has occurred")

View File

@ -9,34 +9,31 @@ import (
"testing"
)
type testRouterMod struct {
handler HandlerFunc
routes RouteTable
type testRouterModule struct {
routes RouteTable
}
func (t testRouterMod) Initialize() error { return nil }
func (t testRouterMod) Routes() RouteTable { return t.routes }
func (t testRouterMod) Storage() StorageOptions { return nil }
func (t testRouterModule) Initialize(ab *Authboss) error { return nil }
func (t testRouterModule) Routes() RouteTable { return t.routes }
func (t testRouterModule) Storage() StorageOptions { return nil }
func testRouterSetup() (http.Handler, *bytes.Buffer) {
Cfg = NewConfig()
Cfg.MountPath = "/prefix"
Cfg.SessionStoreMaker = func(w http.ResponseWriter, r *http.Request) ClientStorer { return mockClientStore{} }
Cfg.CookieStoreMaker = func(w http.ResponseWriter, r *http.Request) ClientStorer { return mockClientStore{} }
func testRouterSetup() (*Authboss, http.Handler, *bytes.Buffer) {
ab := New()
ab.MountPath = "/prefix"
ab.SessionStoreMaker = func(w http.ResponseWriter, r *http.Request) ClientStorer { return mockClientStore{} }
ab.CookieStoreMaker = func(w http.ResponseWriter, r *http.Request) ClientStorer { return mockClientStore{} }
logger := &bytes.Buffer{}
Cfg.LogWriter = logger
ab.LogWriter = logger
return NewRouter(), logger
return ab, ab.NewRouter(), logger
}
// testRouterCallbackSetup is NOT safe for use by multiple goroutines, don't use parallel
func testRouterCallbackSetup(path string, h HandlerFunc) (w *httptest.ResponseRecorder, r *http.Request) {
modules = map[string]Modularizer{
"test": testRouterMod{
routes: map[string]HandlerFunc{
path: h,
},
},
}
modules = map[string]Modularizer{}
RegisterModule("testrouter", testRouterModule{
routes: map[string]HandlerFunc{path: h},
})
w = httptest.NewRecorder()
r, _ = http.NewRequest("GET", "http://localhost/prefix"+path, nil)
@ -52,7 +49,7 @@ func TestRouter(t *testing.T) {
return nil
})
router, _ := testRouterSetup()
_, router, _ := testRouterSetup()
router.ServeHTTP(w, r)
@ -62,7 +59,7 @@ func TestRouter(t *testing.T) {
}
func TestRouter_NotFound(t *testing.T) {
router, _ := testRouterSetup()
ab, router, _ := testRouterSetup()
w := httptest.NewRecorder()
r, _ := http.NewRequest("GET", "http://localhost/wat", nil)
@ -75,7 +72,7 @@ func TestRouter_NotFound(t *testing.T) {
}
called := false
Cfg.NotFoundHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ab.NotFoundHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called = true
})
@ -93,7 +90,7 @@ func TestRouter_BadRequest(t *testing.T) {
},
)
router, logger := testRouterSetup()
ab, router, logger := testRouterSetup()
logger.Reset()
router.ServeHTTP(w, r)
@ -109,7 +106,7 @@ func TestRouter_BadRequest(t *testing.T) {
}
called := false
Cfg.BadRequestHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ab.BadRequestHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called = true
})
@ -132,7 +129,7 @@ func TestRouter_Error(t *testing.T) {
},
)
router, logger := testRouterSetup()
ab, router, logger := testRouterSetup()
logger.Reset()
router.ServeHTTP(w, r)
@ -148,7 +145,7 @@ func TestRouter_Error(t *testing.T) {
}
called := false
Cfg.ErrorHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ab.ErrorHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called = true
})
@ -177,10 +174,10 @@ func TestRouter_Redirect(t *testing.T) {
},
)
router, logger := testRouterSetup()
ab, router, logger := testRouterSetup()
session := mockClientStore{}
Cfg.SessionStoreMaker = func(w http.ResponseWriter, r *http.Request) ClientStorer { return session }
ab.SessionStoreMaker = func(w http.ResponseWriter, r *http.Request) ClientStorer { return session }
logger.Reset()
router.ServeHTTP(w, r)

View File

@ -72,6 +72,8 @@ func TestAttributeMeta_Names(t *testing.T) {
}
func TestAttributeMeta_Helpers(t *testing.T) {
t.Parallel()
now := time.Now()
attr := Attributes{
"integer": int64(5),

View File

@ -64,7 +64,8 @@ func TestErrorList_Map(t *testing.T) {
func TestValidate(t *testing.T) {
t.Parallel()
ctx := mockRequestContext(StoreUsername, "john", StoreEmail, "john@john.com")
ab := New()
ctx := mockRequestContext(ab, StoreUsername, "john", StoreEmail, "john@john.com")
errList := ctx.Validate([]Validator{
mockValidator{
@ -95,19 +96,20 @@ func TestValidate(t *testing.T) {
func TestValidate_Confirm(t *testing.T) {
t.Parallel()
ctx := mockRequestContext(StoreUsername, "john", "confirmUsername", "johnny")
ab := New()
ctx := mockRequestContext(ab, StoreUsername, "john", "confirmUsername", "johnny")
errs := ctx.Validate(nil, StoreUsername, "confirmUsername").Map()
if errs["confirmUsername"][0] != "Does not match username" {
t.Error("Expected a different error for confirmUsername:", errs["confirmUsername"][0])
}
ctx = mockRequestContext(StoreUsername, "john", "confirmUsername", "john")
ctx = mockRequestContext(ab, StoreUsername, "john", "confirmUsername", "john")
errs = ctx.Validate(nil, StoreUsername, "confirmUsername").Map()
if len(errs) != 0 {
t.Error("Expected no errors:", errs)
}
ctx = mockRequestContext(StoreUsername, "john", "confirmUsername", "john")
ctx = mockRequestContext(ab, StoreUsername, "john", "confirmUsername", "john")
errs = ctx.Validate(nil, StoreUsername).Map()
if len(errs) != 0 {
t.Error("Expected no errors:", errs)

View File

@ -3,6 +3,8 @@ package authboss
import "testing"
func TestHTMLData(t *testing.T) {
t.Parallel()
data := NewHTMLData("a", "b").MergeKV("c", "d").Merge(NewHTMLData("e", "f"))
if data["a"].(string) != "b" {
t.Error("A was wrong:", data["a"])
@ -16,6 +18,8 @@ func TestHTMLData(t *testing.T) {
}
func TestHTMLData_Panics(t *testing.T) {
t.Parallel()
nPanics := 0
panicCount := func() {
if r := recover(); r != nil {