mirror of
				https://github.com/volatiletech/authboss.git
				synced 2025-10-30 23:47:59 +02:00 
			
		
		
		
	Stop reliance on global scope.
- This change was necessary because multi-tenancy sites could not use authboss properly.
This commit is contained in:
		
							
								
								
									
										34
									
								
								auth/auth.go
									
									
									
									
									
								
							
							
						
						
									
										34
									
								
								auth/auth.go
									
									
									
									
									
								
							| @@ -29,19 +29,19 @@ type Auth struct { | ||||
|  | ||||
| // Initialize module | ||||
| func (a *Auth) Initialize() (err error) { | ||||
| 	if authboss.Cfg.Storer == nil { | ||||
| 	if authboss.a.Storer == nil { | ||||
| 		return errors.New("auth: Need a Storer") | ||||
| 	} | ||||
|  | ||||
| 	if len(authboss.Cfg.XSRFName) == 0 { | ||||
| 	if len(authboss.a.XSRFName) == 0 { | ||||
| 		return errors.New("auth: XSRFName must be set") | ||||
| 	} | ||||
|  | ||||
| 	if authboss.Cfg.XSRFMaker == nil { | ||||
| 	if authboss.a.XSRFMaker == nil { | ||||
| 		return errors.New("auth: XSRFMaker must be defined") | ||||
| 	} | ||||
|  | ||||
| 	a.templates, err = response.LoadTemplates(authboss.Cfg.Layout, authboss.Cfg.ViewsPath, tplLogin) | ||||
| 	a.templates, err = response.LoadTemplates(authboss.a.Layout, authboss.a.ViewsPath, tplLogin) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @@ -60,7 +60,7 @@ func (a *Auth) Routes() authboss.RouteTable { | ||||
| // Storage requirements | ||||
| func (a *Auth) Storage() authboss.StorageOptions { | ||||
| 	return authboss.StorageOptions{ | ||||
| 		authboss.Cfg.PrimaryID: authboss.String, | ||||
| 		authboss.a.PrimaryID:   authboss.String, | ||||
| 		authboss.StorePassword: authboss.String, | ||||
| 	} | ||||
| } | ||||
| @@ -70,8 +70,8 @@ func (a *Auth) loginHandlerFunc(ctx *authboss.Context, w http.ResponseWriter, r | ||||
| 	case methodGET: | ||||
| 		if _, ok := ctx.SessionStorer.Get(authboss.SessionKey); ok { | ||||
| 			if halfAuthed, ok := ctx.SessionStorer.Get(authboss.SessionHalfAuthKey); !ok || halfAuthed == "false" { | ||||
| 				//http.Redirect(w, r, authboss.Cfg.AuthLoginOKPath, http.StatusFound, true) | ||||
| 				response.Redirect(ctx, w, r, authboss.Cfg.AuthLoginOKPath, "", "", true) | ||||
| 				//http.Redirect(w, r, authboss.a.AuthLoginOKPath, http.StatusFound, true) | ||||
| 				response.Redirect(ctx, w, r, authboss.a.AuthLoginOKPath, "", "", true) | ||||
| 				return nil | ||||
| 			} | ||||
| 		} | ||||
| @@ -79,23 +79,23 @@ func (a *Auth) loginHandlerFunc(ctx *authboss.Context, w http.ResponseWriter, r | ||||
| 		data := authboss.NewHTMLData( | ||||
| 			"showRemember", authboss.IsLoaded("remember"), | ||||
| 			"showRecover", authboss.IsLoaded("recover"), | ||||
| 			"primaryID", authboss.Cfg.PrimaryID, | ||||
| 			"primaryID", authboss.a.PrimaryID, | ||||
| 			"primaryIDValue", "", | ||||
| 		) | ||||
| 		return a.templates.Render(ctx, w, r, tplLogin, data) | ||||
| 	case methodPOST: | ||||
| 		key, _ := ctx.FirstPostFormValue(authboss.Cfg.PrimaryID) | ||||
| 		key, _ := ctx.FirstPostFormValue(authboss.a.PrimaryID) | ||||
| 		password, _ := ctx.FirstPostFormValue("password") | ||||
|  | ||||
| 		errData := authboss.NewHTMLData( | ||||
| 			"error", fmt.Sprintf("invalid %s and/or password", authboss.Cfg.PrimaryID), | ||||
| 			"primaryID", authboss.Cfg.PrimaryID, | ||||
| 			"error", fmt.Sprintf("invalid %s and/or password", authboss.a.PrimaryID), | ||||
| 			"primaryID", authboss.a.PrimaryID, | ||||
| 			"primaryIDValue", key, | ||||
| 			"showRemember", authboss.IsLoaded("remember"), | ||||
| 			"showRecover", authboss.IsLoaded("recover"), | ||||
| 		) | ||||
|  | ||||
| 		policies := authboss.FilterValidators(authboss.Cfg.Policies, authboss.Cfg.PrimaryID, authboss.StorePassword) | ||||
| 		policies := authboss.FilterValidators(authboss.a.Policies, authboss.a.PrimaryID, authboss.StorePassword) | ||||
| 		if validationErrs := ctx.Validate(policies); len(validationErrs) > 0 { | ||||
| 			return a.templates.Render(ctx, w, r, tplLogin, errData) | ||||
| 		} | ||||
| @@ -104,7 +104,7 @@ func (a *Auth) loginHandlerFunc(ctx *authboss.Context, w http.ResponseWriter, r | ||||
| 			return a.templates.Render(ctx, w, r, tplLogin, errData) | ||||
| 		} | ||||
|  | ||||
| 		interrupted, err := authboss.Cfg.Callbacks.FireBefore(authboss.EventAuth, ctx) | ||||
| 		interrupted, err := authboss.a.Callbacks.FireBefore(authboss.EventAuth, ctx) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} else if interrupted != authboss.InterruptNone { | ||||
| @@ -115,17 +115,17 @@ func (a *Auth) loginHandlerFunc(ctx *authboss.Context, w http.ResponseWriter, r | ||||
| 			case authboss.InterruptAccountNotConfirmed: | ||||
| 				reason = "Your account has not been confirmed." | ||||
| 			} | ||||
| 			response.Redirect(ctx, w, r, authboss.Cfg.AuthLoginFailPath, "", reason, false) | ||||
| 			response.Redirect(ctx, w, r, authboss.a.AuthLoginFailPath, "", reason, false) | ||||
| 			return nil | ||||
| 		} | ||||
|  | ||||
| 		ctx.SessionStorer.Put(authboss.SessionKey, key) | ||||
| 		ctx.SessionStorer.Del(authboss.SessionHalfAuthKey) | ||||
|  | ||||
| 		if err := authboss.Cfg.Callbacks.FireAfter(authboss.EventAuth, ctx); err != nil { | ||||
| 		if err := authboss.a.Callbacks.FireAfter(authboss.EventAuth, ctx); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		response.Redirect(ctx, w, r, authboss.Cfg.AuthLoginOKPath, "", "", true) | ||||
| 		response.Redirect(ctx, w, r, authboss.a.AuthLoginOKPath, "", "", true) | ||||
| 	default: | ||||
| 		w.WriteHeader(http.StatusMethodNotAllowed) | ||||
| 	} | ||||
| @@ -157,7 +157,7 @@ func (a *Auth) logoutHandlerFunc(ctx *authboss.Context, w http.ResponseWriter, r | ||||
| 		ctx.CookieStorer.Del(authboss.CookieRemember) | ||||
| 		ctx.SessionStorer.Del(authboss.SessionLastAction) | ||||
|  | ||||
| 		response.Redirect(ctx, w, r, authboss.Cfg.AuthLogoutOKPath, "You have logged out", "", true) | ||||
| 		response.Redirect(ctx, w, r, authboss.a.AuthLogoutOKPath, "You have logged out", "", true) | ||||
| 	default: | ||||
| 		w.WriteHeader(http.StatusMethodNotAllowed) | ||||
| 	} | ||||
|   | ||||
| @@ -17,14 +17,14 @@ func testSetup() (a *Auth, s *mocks.MockStorer) { | ||||
| 	s = mocks.NewMockStorer() | ||||
|  | ||||
| 	authboss.Cfg = authboss.NewConfig() | ||||
| 	authboss.Cfg.LogWriter = ioutil.Discard | ||||
| 	authboss.Cfg.Layout = template.Must(template.New("").Parse(`{{template "authboss" .}}`)) | ||||
| 	authboss.Cfg.Storer = s | ||||
| 	authboss.Cfg.XSRFName = "xsrf" | ||||
| 	authboss.Cfg.XSRFMaker = func(_ http.ResponseWriter, _ *http.Request) string { | ||||
| 	authboss.a.LogWriter = ioutil.Discard | ||||
| 	authboss.a.Layout = template.Must(template.New("").Parse(`{{template "authboss" .}}`)) | ||||
| 	authboss.a.Storer = s | ||||
| 	authboss.a.XSRFName = "xsrf" | ||||
| 	authboss.a.XSRFMaker = func(_ http.ResponseWriter, _ *http.Request) string { | ||||
| 		return "xsrfvalue" | ||||
| 	} | ||||
| 	authboss.Cfg.PrimaryID = authboss.StoreUsername | ||||
| 	authboss.a.PrimaryID = authboss.StoreUsername | ||||
|  | ||||
| 	a = &Auth{} | ||||
| 	if err := a.Initialize(); err != nil { | ||||
| @@ -51,8 +51,8 @@ func TestAuth(t *testing.T) { | ||||
| 	a, _ := testSetup() | ||||
|  | ||||
| 	storage := a.Storage() | ||||
| 	if storage[authboss.Cfg.PrimaryID] != authboss.String { | ||||
| 		t.Error("Expected storage KV:", authboss.Cfg.PrimaryID, authboss.String) | ||||
| 	if storage[authboss.a.PrimaryID] != authboss.String { | ||||
| 		t.Error("Expected storage KV:", authboss.a.PrimaryID, authboss.String) | ||||
| 	} | ||||
| 	if storage[authboss.StorePassword] != authboss.String { | ||||
| 		t.Error("Expected storage KV:", authboss.StorePassword, authboss.String) | ||||
| @@ -74,7 +74,7 @@ func TestAuth_loginHandlerFunc_GET_RedirectsWhenHalfAuthed(t *testing.T) { | ||||
| 	sessionStore.Put(authboss.SessionKey, "a") | ||||
| 	sessionStore.Put(authboss.SessionHalfAuthKey, "false") | ||||
|  | ||||
| 	authboss.Cfg.AuthLoginOKPath = "/dashboard" | ||||
| 	authboss.a.AuthLoginOKPath = "/dashboard" | ||||
|  | ||||
| 	if err := a.loginHandlerFunc(ctx, w, r); err != nil { | ||||
| 		t.Error("Unexpeced error:", err) | ||||
| @@ -85,7 +85,7 @@ func TestAuth_loginHandlerFunc_GET_RedirectsWhenHalfAuthed(t *testing.T) { | ||||
| 	} | ||||
|  | ||||
| 	loc := w.Header().Get("Location") | ||||
| 	if loc != authboss.Cfg.AuthLoginOKPath { | ||||
| 	if loc != authboss.a.AuthLoginOKPath { | ||||
| 		t.Error("Unexpected redirect:", loc) | ||||
| 	} | ||||
| } | ||||
| @@ -106,7 +106,7 @@ func TestAuth_loginHandlerFunc_GET(t *testing.T) { | ||||
| 	if !strings.Contains(body, "<form") { | ||||
| 		t.Error("Should have rendered a form") | ||||
| 	} | ||||
| 	if !strings.Contains(body, `name="`+authboss.Cfg.PrimaryID) { | ||||
| 	if !strings.Contains(body, `name="`+authboss.a.PrimaryID) { | ||||
| 		t.Error("Form should contain the primary ID field:", body) | ||||
| 	} | ||||
| 	if !strings.Contains(body, `name="password"`) { | ||||
| @@ -118,8 +118,8 @@ func TestAuth_loginHandlerFunc_POST_ReturnsErrorOnCallbackFailure(t *testing.T) | ||||
| 	a, storer := testSetup() | ||||
| 	storer.Users["john"] = authboss.Attributes{"password": "$2a$10$B7aydtqVF9V8RSNx3lCKB.l09jqLV/aMiVqQHajtL7sWGhCS9jlOu"} | ||||
|  | ||||
| 	authboss.Cfg.Callbacks = authboss.NewCallbacks() | ||||
| 	authboss.Cfg.Callbacks.Before(authboss.EventAuth, func(_ *authboss.Context) (authboss.Interrupt, error) { | ||||
| 	authboss.a.Callbacks = authboss.NewCallbacks() | ||||
| 	authboss.a.Callbacks.Before(authboss.EventAuth, func(_ *authboss.Context) (authboss.Interrupt, error) { | ||||
| 		return authboss.InterruptNone, errors.New("explode") | ||||
| 	}) | ||||
|  | ||||
| @@ -134,8 +134,8 @@ func TestAuth_loginHandlerFunc_POST_RedirectsWhenInterrupted(t *testing.T) { | ||||
| 	a, storer := testSetup() | ||||
| 	storer.Users["john"] = authboss.Attributes{"password": "$2a$10$B7aydtqVF9V8RSNx3lCKB.l09jqLV/aMiVqQHajtL7sWGhCS9jlOu"} | ||||
|  | ||||
| 	authboss.Cfg.Callbacks = authboss.NewCallbacks() | ||||
| 	authboss.Cfg.Callbacks.Before(authboss.EventAuth, func(_ *authboss.Context) (authboss.Interrupt, error) { | ||||
| 	authboss.a.Callbacks = authboss.NewCallbacks() | ||||
| 	authboss.a.Callbacks.Before(authboss.EventAuth, func(_ *authboss.Context) (authboss.Interrupt, error) { | ||||
| 		return authboss.InterruptAccountLocked, nil | ||||
| 	}) | ||||
|  | ||||
| @@ -150,7 +150,7 @@ func TestAuth_loginHandlerFunc_POST_RedirectsWhenInterrupted(t *testing.T) { | ||||
| 	} | ||||
|  | ||||
| 	loc := w.Header().Get("Location") | ||||
| 	if loc != authboss.Cfg.AuthLoginFailPath { | ||||
| 	if loc != authboss.a.AuthLoginFailPath { | ||||
| 		t.Error("Unexpeced location:", loc) | ||||
| 	} | ||||
|  | ||||
| @@ -159,8 +159,8 @@ func TestAuth_loginHandlerFunc_POST_RedirectsWhenInterrupted(t *testing.T) { | ||||
| 		t.Error("Expected error flash message:", expectedMsg) | ||||
| 	} | ||||
|  | ||||
| 	authboss.Cfg.Callbacks = authboss.NewCallbacks() | ||||
| 	authboss.Cfg.Callbacks.Before(authboss.EventAuth, func(_ *authboss.Context) (authboss.Interrupt, error) { | ||||
| 	authboss.a.Callbacks = authboss.NewCallbacks() | ||||
| 	authboss.a.Callbacks.Before(authboss.EventAuth, func(_ *authboss.Context) (authboss.Interrupt, error) { | ||||
| 		return authboss.InterruptAccountNotConfirmed, nil | ||||
| 	}) | ||||
|  | ||||
| @@ -173,7 +173,7 @@ func TestAuth_loginHandlerFunc_POST_RedirectsWhenInterrupted(t *testing.T) { | ||||
| 	} | ||||
|  | ||||
| 	loc = w.Header().Get("Location") | ||||
| 	if loc != authboss.Cfg.AuthLoginFailPath { | ||||
| 	if loc != authboss.a.AuthLoginFailPath { | ||||
| 		t.Error("Unexpeced location:", loc) | ||||
| 	} | ||||
|  | ||||
| @@ -224,9 +224,9 @@ func TestAuth_loginHandlerFunc_POST(t *testing.T) { | ||||
| 	ctx, w, r, _ := testRequest("POST", "username", "john", "password", "1234") | ||||
| 	cb := mocks.NewMockAfterCallback() | ||||
|  | ||||
| 	authboss.Cfg.Callbacks = authboss.NewCallbacks() | ||||
| 	authboss.Cfg.Callbacks.After(authboss.EventAuth, cb.Fn) | ||||
| 	authboss.Cfg.AuthLoginOKPath = "/dashboard" | ||||
| 	authboss.a.Callbacks = authboss.NewCallbacks() | ||||
| 	authboss.a.Callbacks.After(authboss.EventAuth, cb.Fn) | ||||
| 	authboss.a.AuthLoginOKPath = "/dashboard" | ||||
|  | ||||
| 	sessions := mocks.NewMockClientStorer() | ||||
| 	ctx.SessionStorer = sessions | ||||
| @@ -244,7 +244,7 @@ func TestAuth_loginHandlerFunc_POST(t *testing.T) { | ||||
| 	} | ||||
|  | ||||
| 	loc := w.Header().Get("Location") | ||||
| 	if loc != authboss.Cfg.AuthLoginOKPath { | ||||
| 	if loc != authboss.a.AuthLoginOKPath { | ||||
| 		t.Error("Unexpeced location:", loc) | ||||
| 	} | ||||
|  | ||||
| @@ -283,7 +283,7 @@ func TestAuth_validateCredentials(t *testing.T) { | ||||
|  | ||||
| 	storer := mocks.NewMockStorer() | ||||
| 	storer.GetErr = "Failed to load user" | ||||
| 	authboss.Cfg.Storer = storer | ||||
| 	authboss.a.Storer = storer | ||||
|  | ||||
| 	ctx := authboss.Context{} | ||||
|  | ||||
| @@ -305,7 +305,7 @@ func TestAuth_validateCredentials(t *testing.T) { | ||||
| func TestAuth_logoutHandlerFunc_GET(t *testing.T) { | ||||
| 	a, _ := testSetup() | ||||
|  | ||||
| 	authboss.Cfg.AuthLogoutOKPath = "/dashboard" | ||||
| 	authboss.a.AuthLogoutOKPath = "/dashboard" | ||||
|  | ||||
| 	ctx, w, r, sessionStorer := testRequest("GET") | ||||
| 	sessionStorer.Put(authboss.SessionKey, "asdf") | ||||
|   | ||||
							
								
								
									
										51
									
								
								authboss.go
									
									
									
									
									
								
							
							
						
						
									
										51
									
								
								authboss.go
									
									
									
									
									
								
							| @@ -17,11 +17,24 @@ import ( | ||||
| 	"golang.org/x/crypto/bcrypt" | ||||
| ) | ||||
|  | ||||
| // Authboss contains a configuration and other details for running. | ||||
| type Authboss struct { | ||||
| 	Config | ||||
| } | ||||
|  | ||||
| // New makes a new instance of authboss with a default | ||||
| // configuration. | ||||
| func New() *Authboss { | ||||
| 	ab := &Authboss{} | ||||
| 	ab.Defaults() | ||||
| 	return ab | ||||
| } | ||||
|  | ||||
| // Init authboss and it's loaded modules. | ||||
| func Init() error { | ||||
| func (a *Authboss) Init() error { | ||||
| 	for name, mod := range modules { | ||||
| 		fmt.Fprintf(Cfg.LogWriter, "%-10s Initializing\n", "["+name+"]") | ||||
| 		if err := mod.Initialize(); err != nil { | ||||
| 		fmt.Fprintf(a.LogWriter, "%-10s Initializing\n", "["+name+"]") | ||||
| 		if err := mod.Initialize(a); err != nil { | ||||
| 			return fmt.Errorf("[%s] Error Initializing: %v", name, err) | ||||
| 		} | ||||
| 	} | ||||
| @@ -30,16 +43,16 @@ func Init() error { | ||||
| } | ||||
|  | ||||
| // CurrentUser retrieves the current user from the session and the database. | ||||
| func CurrentUser(w http.ResponseWriter, r *http.Request) (interface{}, error) { | ||||
| 	ctx, err := ContextFromRequest(r) | ||||
| func (a *Authboss) CurrentUser(w http.ResponseWriter, r *http.Request) (interface{}, error) { | ||||
| 	ctx, err := a.ContextFromRequest(r) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	ctx.SessionStorer = clientStoreWrapper{Cfg.SessionStoreMaker(w, r)} | ||||
| 	ctx.CookieStorer = clientStoreWrapper{Cfg.CookieStoreMaker(w, r)} | ||||
| 	ctx.SessionStorer = clientStoreWrapper{a.SessionStoreMaker(w, r)} | ||||
| 	ctx.CookieStorer = clientStoreWrapper{a.CookieStoreMaker(w, r)} | ||||
|  | ||||
| 	_, err = Cfg.Callbacks.FireBefore(EventGetUserSession, ctx) | ||||
| 	_, err = a.Callbacks.FireBefore(EventGetUserSession, ctx) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| @@ -54,22 +67,22 @@ func CurrentUser(w http.ResponseWriter, r *http.Request) (interface{}, error) { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	_, err = Cfg.Callbacks.FireBefore(EventGet, ctx) | ||||
| 	_, err = a.Callbacks.FireBefore(EventGet, ctx) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	if index := strings.IndexByte(key, ';'); index > 0 { | ||||
| 		return Cfg.OAuth2Storer.GetOAuth(key[:index], key[index+1:]) | ||||
| 		return a.OAuth2Storer.GetOAuth(key[:index], key[index+1:]) | ||||
| 	} | ||||
|  | ||||
| 	return Cfg.Storer.Get(key) | ||||
| 	return a.Storer.Get(key) | ||||
| } | ||||
|  | ||||
| // CurrentUserP retrieves the current user but panics if it's not available for | ||||
| // any reason. | ||||
| func CurrentUserP(w http.ResponseWriter, r *http.Request) interface{} { | ||||
| 	i, err := CurrentUser(w, r) | ||||
| func (a *Authboss) CurrentUserP(w http.ResponseWriter, r *http.Request) interface{} { | ||||
| 	i, err := a.CurrentUser(w, r) | ||||
| 	if err != nil { | ||||
| 		panic(err.Error()) | ||||
| 	} | ||||
| @@ -96,13 +109,13 @@ will be returned. | ||||
| The error returned is returned either from the updater if that produced an error | ||||
| or from the cleanup routines. | ||||
| */ | ||||
| func UpdatePassword(w http.ResponseWriter, r *http.Request, | ||||
| func (a *Authboss) UpdatePassword(w http.ResponseWriter, r *http.Request, | ||||
| 	ptPassword string, user interface{}, updater func() error) error { | ||||
|  | ||||
| 	updatePwd := len(ptPassword) > 0 | ||||
|  | ||||
| 	if updatePwd { | ||||
| 		pass, err := bcrypt.GenerateFromPassword([]byte(ptPassword), Cfg.BCryptCost) | ||||
| 		pass, err := bcrypt.GenerateFromPassword([]byte(ptPassword), a.BCryptCost) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| @@ -131,11 +144,11 @@ func UpdatePassword(w http.ResponseWriter, r *http.Request, | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	ctx, err := ContextFromRequest(r) | ||||
| 	ctx, err := a.ContextFromRequest(r) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	ctx.SessionStorer = clientStoreWrapper{Cfg.SessionStoreMaker(w, r)} | ||||
| 	ctx.CookieStorer = clientStoreWrapper{Cfg.CookieStoreMaker(w, r)} | ||||
| 	return Cfg.Callbacks.FireAfter(EventPasswordReset, ctx) | ||||
| 	ctx.SessionStorer = clientStoreWrapper{a.SessionStoreMaker(w, r)} | ||||
| 	ctx.CookieStorer = clientStoreWrapper{a.CookieStoreMaker(w, r)} | ||||
| 	return a.Callbacks.FireAfter(EventPasswordReset, ctx) | ||||
| } | ||||
|   | ||||
| @@ -6,46 +6,41 @@ import ( | ||||
| 	"io/ioutil" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"os" | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| func TestMain(main *testing.M) { | ||||
| 	RegisterModule("testmodule", testMod) | ||||
| 	Cfg.LogWriter = ioutil.Discard | ||||
| 	Init() | ||||
| 	code := main.Run() | ||||
| 	os.Exit(code) | ||||
| } | ||||
|  | ||||
| func TestAuthBossInit(t *testing.T) { | ||||
| 	Cfg = NewConfig() | ||||
| 	Cfg.LogWriter = ioutil.Discard | ||||
| 	err := Init() | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	ab := New() | ||||
| 	ab.LogWriter = ioutil.Discard | ||||
| 	err := ab.Init() | ||||
| 	if err != nil { | ||||
| 		t.Error("Unexpected error:", err) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestAuthBossCurrentUser(t *testing.T) { | ||||
| 	Cfg = NewConfig() | ||||
| 	Cfg.LogWriter = ioutil.Discard | ||||
| 	Cfg.Storer = mockStorer{"joe": Attributes{"email": "john@john.com", "password": "lies"}} | ||||
| 	Cfg.SessionStoreMaker = func(_ http.ResponseWriter, _ *http.Request) ClientStorer { | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	ab := New() | ||||
| 	ab.LogWriter = ioutil.Discard | ||||
| 	ab.Storer = mockStorer{"joe": Attributes{"email": "john@john.com", "password": "lies"}} | ||||
| 	ab.SessionStoreMaker = func(_ http.ResponseWriter, _ *http.Request) ClientStorer { | ||||
| 		return mockClientStore{SessionKey: "joe"} | ||||
| 	} | ||||
| 	Cfg.CookieStoreMaker = func(_ http.ResponseWriter, _ *http.Request) ClientStorer { | ||||
| 	ab.CookieStoreMaker = func(_ http.ResponseWriter, _ *http.Request) ClientStorer { | ||||
| 		return mockClientStore{} | ||||
| 	} | ||||
|  | ||||
| 	if err := Init(); err != nil { | ||||
| 	if err := ab.Init(); err != nil { | ||||
| 		t.Error("Unexpected error:", err) | ||||
| 	} | ||||
|  | ||||
| 	rec := httptest.NewRecorder() | ||||
| 	req, _ := http.NewRequest("GET", "localhost", nil) | ||||
|  | ||||
| 	userStruct := CurrentUserP(rec, req) | ||||
| 	userStruct := ab.CurrentUserP(rec, req) | ||||
| 	us := userStruct.(*mockUser) | ||||
|  | ||||
| 	if us.Email != "john@john.com" || us.Password != "lies" { | ||||
| @@ -54,18 +49,20 @@ func TestAuthBossCurrentUser(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func TestAuthbossUpdatePassword(t *testing.T) { | ||||
| 	Cfg = NewConfig() | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	ab := New() | ||||
| 	session := mockClientStore{} | ||||
| 	cookies := mockClientStore{} | ||||
| 	Cfg.SessionStoreMaker = func(_ http.ResponseWriter, _ *http.Request) ClientStorer { | ||||
| 	ab.SessionStoreMaker = func(_ http.ResponseWriter, _ *http.Request) ClientStorer { | ||||
| 		return session | ||||
| 	} | ||||
| 	Cfg.CookieStoreMaker = func(_ http.ResponseWriter, _ *http.Request) ClientStorer { | ||||
| 	ab.CookieStoreMaker = func(_ http.ResponseWriter, _ *http.Request) ClientStorer { | ||||
| 		return cookies | ||||
| 	} | ||||
|  | ||||
| 	called := false | ||||
| 	Cfg.Callbacks.After(EventPasswordReset, func(ctx *Context) error { | ||||
| 	ab.Callbacks.After(EventPasswordReset, func(ctx *Context) error { | ||||
| 		called = true | ||||
| 		return nil | ||||
| 	}) | ||||
| @@ -80,7 +77,7 @@ func TestAuthbossUpdatePassword(t *testing.T) { | ||||
| 	r, _ := http.NewRequest("GET", "http://localhost", nil) | ||||
|  | ||||
| 	called = false | ||||
| 	err := UpdatePassword(nil, r, "newpassword", &user1, func() error { return nil }) | ||||
| 	err := ab.UpdatePassword(nil, r, "newpassword", &user1, func() error { return nil }) | ||||
| 	if err != nil { | ||||
| 		t.Error(err) | ||||
| 	} | ||||
| @@ -93,7 +90,7 @@ func TestAuthbossUpdatePassword(t *testing.T) { | ||||
| 	} | ||||
|  | ||||
| 	called = false | ||||
| 	err = UpdatePassword(nil, r, "newpassword", &user2, func() error { return nil }) | ||||
| 	err = ab.UpdatePassword(nil, r, "newpassword", &user2, func() error { return nil }) | ||||
| 	if err != nil { | ||||
| 		t.Error(err) | ||||
| 	} | ||||
| @@ -107,7 +104,7 @@ func TestAuthbossUpdatePassword(t *testing.T) { | ||||
|  | ||||
| 	called = false | ||||
| 	oldPassword := user1.Password | ||||
| 	err = UpdatePassword(nil, r, "", &user1, func() error { return nil }) | ||||
| 	err = ab.UpdatePassword(nil, r, "", &user1, func() error { return nil }) | ||||
| 	if err != nil { | ||||
| 		t.Error(err) | ||||
| 	} | ||||
| @@ -121,12 +118,16 @@ func TestAuthbossUpdatePassword(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func TestAuthbossUpdatePasswordFail(t *testing.T) { | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	ab := New() | ||||
|  | ||||
| 	user1 := struct { | ||||
| 		Password string | ||||
| 	}{} | ||||
|  | ||||
| 	anErr := errors.New("AnError") | ||||
| 	err := UpdatePassword(nil, nil, "update", &user1, func() error { return anErr }) | ||||
| 	err := ab.UpdatePassword(nil, nil, "update", &user1, func() error { return anErr }) | ||||
| 	if err != anErr { | ||||
| 		t.Error("Expected an specific error:", err) | ||||
| 	} | ||||
|   | ||||
| @@ -115,7 +115,7 @@ func (c *Callbacks) FireBefore(e Event, ctx *Context) (interrupt Interrupt, err | ||||
| 	for _, fn := range callbacks { | ||||
| 		interrupt, err = fn(ctx) | ||||
| 		if err != nil { | ||||
| 			fmt.Fprintf(Cfg.LogWriter, "Callback error (%s): %v\n", runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name(), err) | ||||
| 			fmt.Fprintf(ctx.LogWriter, "Callback error (%s): %v\n", runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name(), err) | ||||
| 			return InterruptNone, err | ||||
| 		} | ||||
| 		if interrupt != InterruptNone { | ||||
| @@ -132,7 +132,7 @@ func (c *Callbacks) FireAfter(e Event, ctx *Context) (err error) { | ||||
| 	callbacks := c.after[e] | ||||
| 	for _, fn := range callbacks { | ||||
| 		if err = fn(ctx); err != nil { | ||||
| 			fmt.Fprintf(Cfg.LogWriter, "Callback error (%s): %v\n", runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name(), err) | ||||
| 			fmt.Fprintf(ctx.LogWriter, "Callback error (%s): %v\n", runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name(), err) | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
|   | ||||
| @@ -8,15 +8,17 @@ import ( | ||||
| ) | ||||
|  | ||||
| func TestCallbacks(t *testing.T) { | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	ab := New() | ||||
| 	afterCalled := false | ||||
| 	beforeCalled := false | ||||
| 	c := NewCallbacks() | ||||
|  | ||||
| 	c.Before(EventRegister, func(ctx *Context) (Interrupt, error) { | ||||
| 	ab.Callbacks.Before(EventRegister, func(ctx *Context) (Interrupt, error) { | ||||
| 		beforeCalled = true | ||||
| 		return InterruptNone, nil | ||||
| 	}) | ||||
| 	c.After(EventRegister, func(ctx *Context) error { | ||||
| 	ab.Callbacks.After(EventRegister, func(ctx *Context) error { | ||||
| 		afterCalled = true | ||||
| 		return nil | ||||
| 	}) | ||||
| @@ -25,7 +27,7 @@ func TestCallbacks(t *testing.T) { | ||||
| 		t.Error("Neither should be called.") | ||||
| 	} | ||||
|  | ||||
| 	interrupt, err := c.FireBefore(EventRegister, NewContext()) | ||||
| 	interrupt, err := ab.Callbacks.FireBefore(EventRegister, ab.NewContext()) | ||||
| 	if err != nil { | ||||
| 		t.Error("Unexpected error:", err) | ||||
| 	} | ||||
| @@ -40,27 +42,29 @@ func TestCallbacks(t *testing.T) { | ||||
| 		t.Error("Expected after not to be called.") | ||||
| 	} | ||||
|  | ||||
| 	c.FireAfter(EventRegister, NewContext()) | ||||
| 	ab.Callbacks.FireAfter(EventRegister, ab.NewContext()) | ||||
| 	if !afterCalled { | ||||
| 		t.Error("Expected after to be called.") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestCallbacksInterrupt(t *testing.T) { | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	ab := New() | ||||
| 	before1 := false | ||||
| 	before2 := false | ||||
| 	c := NewCallbacks() | ||||
|  | ||||
| 	c.Before(EventRegister, func(ctx *Context) (Interrupt, error) { | ||||
| 	ab.Callbacks.Before(EventRegister, func(ctx *Context) (Interrupt, error) { | ||||
| 		before1 = true | ||||
| 		return InterruptAccountLocked, nil | ||||
| 	}) | ||||
| 	c.Before(EventRegister, func(ctx *Context) (Interrupt, error) { | ||||
| 	ab.Callbacks.Before(EventRegister, func(ctx *Context) (Interrupt, error) { | ||||
| 		before2 = true | ||||
| 		return InterruptNone, nil | ||||
| 	}) | ||||
|  | ||||
| 	interrupt, err := c.FireBefore(EventRegister, NewContext()) | ||||
| 	interrupt, err := ab.Callbacks.FireBefore(EventRegister, ab.NewContext()) | ||||
| 	if err != nil { | ||||
| 		t.Error(err) | ||||
| 	} | ||||
| @@ -77,26 +81,26 @@ func TestCallbacksInterrupt(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func TestCallbacksBeforeErrors(t *testing.T) { | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	ab := New() | ||||
| 	log := &bytes.Buffer{} | ||||
| 	Cfg = &Config{ | ||||
| 		LogWriter: log, | ||||
| 	} | ||||
| 	ab.LogWriter = log | ||||
| 	before1 := false | ||||
| 	before2 := false | ||||
| 	c := NewCallbacks() | ||||
|  | ||||
| 	errValue := errors.New("Problem occured") | ||||
|  | ||||
| 	c.Before(EventRegister, func(ctx *Context) (Interrupt, error) { | ||||
| 	ab.Callbacks.Before(EventRegister, func(ctx *Context) (Interrupt, error) { | ||||
| 		before1 = true | ||||
| 		return InterruptNone, errValue | ||||
| 	}) | ||||
| 	c.Before(EventRegister, func(ctx *Context) (Interrupt, error) { | ||||
| 	ab.Callbacks.Before(EventRegister, func(ctx *Context) (Interrupt, error) { | ||||
| 		before2 = true | ||||
| 		return InterruptNone, nil | ||||
| 	}) | ||||
|  | ||||
| 	interrupt, err := c.FireBefore(EventRegister, NewContext()) | ||||
| 	interrupt, err := ab.Callbacks.FireBefore(EventRegister, ab.NewContext()) | ||||
| 	if err != errValue { | ||||
| 		t.Error("Expected an error to come back.") | ||||
| 	} | ||||
| @@ -117,26 +121,26 @@ func TestCallbacksBeforeErrors(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func TestCallbacksAfterErrors(t *testing.T) { | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	log := &bytes.Buffer{} | ||||
| 	Cfg = &Config{ | ||||
| 		LogWriter: log, | ||||
| 	} | ||||
| 	ab := New() | ||||
| 	ab.LogWriter = log | ||||
| 	after1 := false | ||||
| 	after2 := false | ||||
| 	c := NewCallbacks() | ||||
|  | ||||
| 	errValue := errors.New("Problem occured") | ||||
|  | ||||
| 	c.After(EventRegister, func(ctx *Context) error { | ||||
| 	ab.Callbacks.After(EventRegister, func(ctx *Context) error { | ||||
| 		after1 = true | ||||
| 		return errValue | ||||
| 	}) | ||||
| 	c.After(EventRegister, func(ctx *Context) error { | ||||
| 	ab.Callbacks.After(EventRegister, func(ctx *Context) error { | ||||
| 		after2 = true | ||||
| 		return nil | ||||
| 	}) | ||||
|  | ||||
| 	err := c.FireAfter(EventRegister, NewContext()) | ||||
| 	err := ab.Callbacks.FireAfter(EventRegister, ab.NewContext()) | ||||
| 	if err != errValue { | ||||
| 		t.Error("Expected an error to come back.") | ||||
| 	} | ||||
| @@ -154,6 +158,8 @@ func TestCallbacksAfterErrors(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func TestEventString(t *testing.T) { | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	tests := []struct { | ||||
| 		ev  Event | ||||
| 		str string | ||||
| @@ -178,6 +184,8 @@ func TestEventString(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func TestInterruptString(t *testing.T) { | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	tests := []struct { | ||||
| 		in  Interrupt | ||||
| 		str string | ||||
|   | ||||
| @@ -64,8 +64,8 @@ type CookieStoreMaker func(http.ResponseWriter, *http.Request) ClientStorer | ||||
| type SessionStoreMaker func(http.ResponseWriter, *http.Request) ClientStorer | ||||
|  | ||||
| // FlashSuccess returns FlashSuccessKey from the session and removes it. | ||||
| func FlashSuccess(w http.ResponseWriter, r *http.Request) string { | ||||
| 	storer := Cfg.SessionStoreMaker(w, r) | ||||
| func (a *Authboss) FlashSuccess(w http.ResponseWriter, r *http.Request) string { | ||||
| 	storer := a.SessionStoreMaker(w, r) | ||||
| 	msg, ok := storer.Get(FlashSuccessKey) | ||||
| 	if ok { | ||||
| 		storer.Del(FlashSuccessKey) | ||||
| @@ -75,8 +75,8 @@ func FlashSuccess(w http.ResponseWriter, r *http.Request) string { | ||||
| } | ||||
|  | ||||
| // FlashError returns FlashError from the session and removes it. | ||||
| func FlashError(w http.ResponseWriter, r *http.Request) string { | ||||
| 	storer := Cfg.SessionStoreMaker(w, r) | ||||
| func (a *Authboss) FlashError(w http.ResponseWriter, r *http.Request) string { | ||||
| 	storer := a.SessionStoreMaker(w, r) | ||||
| 	msg, ok := storer.Get(FlashErrorKey) | ||||
| 	if ok { | ||||
| 		storer.Del(FlashErrorKey) | ||||
|   | ||||
| @@ -14,6 +14,8 @@ func (t testClientStorerErr) Get(key string) (string, bool) { | ||||
| func (t testClientStorerErr) Del(key string) {} | ||||
|  | ||||
| func TestClientStorerErr(t *testing.T) { | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	var cs testClientStorerErr | ||||
|  | ||||
| 	csw := clientStoreWrapper{&cs} | ||||
| @@ -30,19 +32,22 @@ func TestClientStorerErr(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func TestFlashClearer(t *testing.T) { | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	session := mockClientStore{FlashSuccessKey: "success", FlashErrorKey: "error"} | ||||
| 	Cfg.SessionStoreMaker = func(w http.ResponseWriter, r *http.Request) ClientStorer { | ||||
| 	ab := New() | ||||
| 	ab.SessionStoreMaker = func(w http.ResponseWriter, r *http.Request) ClientStorer { | ||||
| 		return session | ||||
| 	} | ||||
|  | ||||
| 	if msg := FlashSuccess(nil, nil); msg != "success" { | ||||
| 	if msg := ab.FlashSuccess(nil, nil); msg != "success" { | ||||
| 		t.Error("Unexpected flash success:", msg) | ||||
| 	} | ||||
| 	if msg, ok := session.Get(FlashSuccessKey); ok { | ||||
| 		t.Error("Unexpected success flash:", msg) | ||||
| 	} | ||||
|  | ||||
| 	if msg := FlashError(nil, nil); msg != "error" { | ||||
| 	if msg := ab.FlashError(nil, nil); msg != "error" { | ||||
| 		t.Error("Unexpected flash error:", msg) | ||||
| 	} | ||||
| 	if msg, ok := session.Get(FlashErrorKey); ok { | ||||
|   | ||||
							
								
								
									
										93
									
								
								config.go
									
									
									
									
									
								
							
							
						
						
									
										93
									
								
								config.go
									
									
									
									
									
								
							| @@ -10,9 +10,6 @@ import ( | ||||
| 	"golang.org/x/crypto/bcrypt" | ||||
| ) | ||||
|  | ||||
| // Cfg is the singleton instance of Config | ||||
| var Cfg = NewConfig() | ||||
|  | ||||
| // Config holds all the configuration for both authboss and it's modules. | ||||
| type Config struct { | ||||
| 	// MountPath is the path to mount authboss's routes at (eg /auth). | ||||
| @@ -117,61 +114,55 @@ type Config struct { | ||||
| 	Mailer Mailer | ||||
| } | ||||
|  | ||||
| // NewConfig creates a config full of healthy default values. | ||||
| // Notable exceptions to default values are the Storers. | ||||
| // This method is called automatically on startup and is set to authboss.Cfg | ||||
| // so implementers need not call it. Primarily exported for testing. | ||||
| func NewConfig() *Config { | ||||
| 	return &Config{ | ||||
| 		MountPath:  "/", | ||||
| 		ViewsPath:  "./", | ||||
| 		RootURL:    "http://localhost:8080", | ||||
| 		BCryptCost: bcrypt.DefaultCost, | ||||
| // Defaults sets the configuration's default values. | ||||
| func (c *Config) Defaults() { | ||||
| 	c.MountPath = "/" | ||||
| 	c.ViewsPath = "./" | ||||
| 	c.RootURL = "http://localhost:8080" | ||||
| 	c.BCryptCost = bcrypt.DefaultCost | ||||
|  | ||||
| 		PrimaryID: StoreEmail, | ||||
| 	c.PrimaryID = StoreEmail | ||||
|  | ||||
| 		Layout:          template.Must(template.New("").Parse(`<!DOCTYPE html><html><body>{{template "authboss" .}}</body></html>`)), | ||||
| 		LayoutHTMLEmail: template.Must(template.New("").Parse(`<!DOCTYPE html><html><body>{{template "authboss" .}}</body></html>`)), | ||||
| 		LayoutTextEmail: template.Must(template.New("").Parse(`{{template "authboss" .}}`)), | ||||
| 	c.Layout = template.Must(template.New("").Parse(`<!DOCTYPE html><html><body>{{template "authboss" .}}</body></html>`)) | ||||
| 	c.LayoutHTMLEmail = template.Must(template.New("").Parse(`<!DOCTYPE html><html><body>{{template "authboss" .}}</body></html>`)) | ||||
| 	c.LayoutTextEmail = template.Must(template.New("").Parse(`{{template "authboss" .}}`)) | ||||
|  | ||||
| 		AuthLoginOKPath:   "/", | ||||
| 		AuthLoginFailPath: "/", | ||||
| 		AuthLogoutOKPath:  "/", | ||||
| 	c.AuthLoginOKPath = "/" | ||||
| 	c.AuthLoginFailPath = "/" | ||||
| 	c.AuthLogoutOKPath = "/" | ||||
|  | ||||
| 		RecoverOKPath:        "/", | ||||
| 		RecoverTokenDuration: time.Duration(24) * time.Hour, | ||||
| 	c.RecoverOKPath = "/" | ||||
| 	c.RecoverTokenDuration = time.Duration(24) * time.Hour | ||||
|  | ||||
| 		RegisterOKPath: "/", | ||||
| 	c.RegisterOKPath = "/" | ||||
|  | ||||
| 		Policies: []Validator{ | ||||
| 			Rules{ | ||||
| 				FieldName:       "username", | ||||
| 				Required:        true, | ||||
| 				MinLength:       2, | ||||
| 				MaxLength:       4, | ||||
| 				AllowWhitespace: false, | ||||
| 			}, | ||||
| 			Rules{ | ||||
| 				FieldName: "password", | ||||
| 				Required:  true, | ||||
| 				MinLength: 4, | ||||
| 				MaxLength: 8, | ||||
|  | ||||
| 				AllowWhitespace: false, | ||||
| 			}, | ||||
| 	c.Policies = []Validator{ | ||||
| 		Rules{ | ||||
| 			FieldName:       "username", | ||||
| 			Required:        true, | ||||
| 			MinLength:       2, | ||||
| 			MaxLength:       4, | ||||
| 			AllowWhitespace: false, | ||||
| 		}, | ||||
| 		ConfirmFields: []string{ | ||||
| 			StorePassword, ConfirmPrefix + StorePassword, | ||||
| 		Rules{ | ||||
| 			FieldName:       "password", | ||||
| 			Required:        true, | ||||
| 			MinLength:       4, | ||||
| 			MaxLength:       8, | ||||
| 			AllowWhitespace: false, | ||||
| 		}, | ||||
|  | ||||
| 		ExpireAfter: 60 * time.Minute, | ||||
|  | ||||
| 		LockAfter:    3, | ||||
| 		LockWindow:   5 * time.Minute, | ||||
| 		LockDuration: 5 * time.Hour, | ||||
|  | ||||
| 		LogWriter: NewDefaultLogger(), | ||||
| 		Callbacks: NewCallbacks(), | ||||
| 		Mailer:    LogMailer(ioutil.Discard), | ||||
| 	} | ||||
| 	c.ConfirmFields = []string{ | ||||
| 		StorePassword, ConfirmPrefix + StorePassword, | ||||
| 	} | ||||
|  | ||||
| 	c.ExpireAfter = 60 * time.Minute | ||||
|  | ||||
| 	c.LockAfter = 3 | ||||
| 	c.LockWindow = 5 * time.Minute | ||||
| 	c.LockDuration = 5 * time.Hour | ||||
|  | ||||
| 	c.LogWriter = NewDefaultLogger() | ||||
| 	c.Callbacks = NewCallbacks() | ||||
| 	c.Mailer = LogMailer(ioutil.Discard) | ||||
| } | ||||
|   | ||||
| @@ -53,23 +53,23 @@ type Confirm struct { | ||||
| // Initialize the module | ||||
| func (c *Confirm) Initialize() (err error) { | ||||
| 	var ok bool | ||||
| 	storer, ok := authboss.Cfg.Storer.(ConfirmStorer) | ||||
| 	storer, ok := authboss.a.Storer.(ConfirmStorer) | ||||
| 	if storer == nil || !ok { | ||||
| 		return errors.New("confirm: Need a ConfirmStorer") | ||||
| 	} | ||||
|  | ||||
| 	c.emailHTMLTemplates, err = response.LoadTemplates(authboss.Cfg.LayoutHTMLEmail, authboss.Cfg.ViewsPath, tplConfirmHTML) | ||||
| 	c.emailHTMLTemplates, err = response.LoadTemplates(authboss.a.LayoutHTMLEmail, authboss.a.ViewsPath, tplConfirmHTML) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	c.emailTextTemplates, err = response.LoadTemplates(authboss.Cfg.LayoutTextEmail, authboss.Cfg.ViewsPath, tplConfirmText) | ||||
| 	c.emailTextTemplates, err = response.LoadTemplates(authboss.a.LayoutTextEmail, authboss.a.ViewsPath, tplConfirmText) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	authboss.Cfg.Callbacks.Before(authboss.EventGet, c.beforeGet) | ||||
| 	authboss.Cfg.Callbacks.Before(authboss.EventAuth, c.beforeGet) | ||||
| 	authboss.Cfg.Callbacks.After(authboss.EventRegister, c.afterRegister) | ||||
| 	authboss.a.Callbacks.Before(authboss.EventGet, c.beforeGet) | ||||
| 	authboss.a.Callbacks.Before(authboss.EventAuth, c.beforeGet) | ||||
| 	authboss.a.Callbacks.After(authboss.EventRegister, c.afterRegister) | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
| @@ -84,10 +84,10 @@ func (c *Confirm) Routes() authboss.RouteTable { | ||||
| // Storage requirements | ||||
| func (c *Confirm) Storage() authboss.StorageOptions { | ||||
| 	return authboss.StorageOptions{ | ||||
| 		authboss.Cfg.PrimaryID: authboss.String, | ||||
| 		authboss.StoreEmail:    authboss.String, | ||||
| 		StoreConfirmToken:      authboss.String, | ||||
| 		StoreConfirmed:         authboss.Bool, | ||||
| 		authboss.a.PrimaryID: authboss.String, | ||||
| 		authboss.StoreEmail:  authboss.String, | ||||
| 		StoreConfirmToken:    authboss.String, | ||||
| 		StoreConfirmed:       authboss.Bool, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @@ -135,18 +135,18 @@ var goConfirmEmail = func(c *Confirm, to, token string) { | ||||
|  | ||||
| // confirmEmail sends a confirmation e-mail. | ||||
| func (c *Confirm) confirmEmail(to, token string) { | ||||
| 	p := path.Join(authboss.Cfg.MountPath, "confirm") | ||||
| 	url := fmt.Sprintf("%s%s?%s=%s", authboss.Cfg.RootURL, p, url.QueryEscape(FormValueConfirm), url.QueryEscape(token)) | ||||
| 	p := path.Join(authboss.a.MountPath, "confirm") | ||||
| 	url := fmt.Sprintf("%s%s?%s=%s", authboss.a.RootURL, p, url.QueryEscape(FormValueConfirm), url.QueryEscape(token)) | ||||
|  | ||||
| 	email := authboss.Email{ | ||||
| 		To:      []string{to}, | ||||
| 		From:    authboss.Cfg.EmailFrom, | ||||
| 		Subject: authboss.Cfg.EmailSubjectPrefix + "Confirm New Account", | ||||
| 		From:    authboss.a.EmailFrom, | ||||
| 		Subject: authboss.a.EmailSubjectPrefix + "Confirm New Account", | ||||
| 	} | ||||
|  | ||||
| 	err := response.Email(email, c.emailHTMLTemplates, tplConfirmHTML, c.emailTextTemplates, tplConfirmText, url) | ||||
| 	if err != nil { | ||||
| 		fmt.Fprintf(authboss.Cfg.LogWriter, "confirm: Failed to send e-mail: %v", err) | ||||
| 		fmt.Fprintf(authboss.a.LogWriter, "confirm: Failed to send e-mail: %v", err) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @@ -166,7 +166,7 @@ func (c *Confirm) confirmHandler(ctx *authboss.Context, w http.ResponseWriter, r | ||||
| 	sum := md5.Sum(toHash) | ||||
|  | ||||
| 	dbTok := base64.StdEncoding.EncodeToString(sum[:]) | ||||
| 	user, err := authboss.Cfg.Storer.(ConfirmStorer).ConfirmUser(dbTok) | ||||
| 	user, err := authboss.a.Storer.(ConfirmStorer).ConfirmUser(dbTok) | ||||
| 	if err == authboss.ErrUserNotFound { | ||||
| 		return authboss.ErrAndRedirect{Location: "/", Err: errors.New("confirm: token not found")} | ||||
| 	} else if err != nil { | ||||
| @@ -178,7 +178,7 @@ func (c *Confirm) confirmHandler(ctx *authboss.Context, w http.ResponseWriter, r | ||||
| 	ctx.User[StoreConfirmToken] = "" | ||||
| 	ctx.User[StoreConfirmed] = true | ||||
|  | ||||
| 	key, err := ctx.User.StringErr(authboss.Cfg.PrimaryID) | ||||
| 	key, err := ctx.User.StringErr(authboss.a.PrimaryID) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @@ -188,7 +188,7 @@ func (c *Confirm) confirmHandler(ctx *authboss.Context, w http.ResponseWriter, r | ||||
| 	} | ||||
|  | ||||
| 	ctx.SessionStorer.Put(authboss.SessionKey, key) | ||||
| 	response.Redirect(ctx, w, r, authboss.Cfg.RegisterOKPath, "You have successfully confirmed your account.", "", true) | ||||
| 	response.Redirect(ctx, w, r, authboss.a.RegisterOKPath, "You have successfully confirmed your account.", "", true) | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|   | ||||
| @@ -18,9 +18,9 @@ import ( | ||||
|  | ||||
| func setup() *Confirm { | ||||
| 	authboss.NewConfig() | ||||
| 	authboss.Cfg.Storer = mocks.NewMockStorer() | ||||
| 	authboss.Cfg.LayoutHTMLEmail = template.Must(template.New("").Parse(`email ^_^`)) | ||||
| 	authboss.Cfg.LayoutTextEmail = template.Must(template.New("").Parse(`email`)) | ||||
| 	authboss.a.Storer = mocks.NewMockStorer() | ||||
| 	authboss.a.LayoutHTMLEmail = template.Must(template.New("").Parse(`email ^_^`)) | ||||
| 	authboss.a.LayoutTextEmail = template.Must(template.New("").Parse(`email`)) | ||||
|  | ||||
| 	c := &Confirm{} | ||||
| 	if err := c.Initialize(); err != nil { | ||||
| @@ -100,9 +100,9 @@ func TestConfirm_AfterRegister(t *testing.T) { | ||||
| 	c := setup() | ||||
| 	ctx := authboss.NewContext() | ||||
| 	log := &bytes.Buffer{} | ||||
| 	authboss.Cfg.LogWriter = log | ||||
| 	authboss.Cfg.Mailer = authboss.LogMailer(log) | ||||
| 	authboss.Cfg.PrimaryID = authboss.StoreUsername | ||||
| 	authboss.a.LogWriter = log | ||||
| 	authboss.a.Mailer = authboss.LogMailer(log) | ||||
| 	authboss.a.PrimaryID = authboss.StoreUsername | ||||
|  | ||||
| 	sentEmail := false | ||||
|  | ||||
| @@ -115,7 +115,7 @@ func TestConfirm_AfterRegister(t *testing.T) { | ||||
| 		t.Error("Expected it to die with user error:", err) | ||||
| 	} | ||||
|  | ||||
| 	ctx.User = authboss.Attributes{authboss.Cfg.PrimaryID: "username"} | ||||
| 	ctx.User = authboss.Attributes{authboss.a.PrimaryID: "username"} | ||||
| 	if err := c.afterRegister(ctx); err == nil || err.(authboss.AttributeErr).Name != "email" { | ||||
| 		t.Error("Expected it to die with e-mail address error:", err) | ||||
| 	} | ||||
| @@ -135,8 +135,8 @@ func TestConfirm_AfterRegister(t *testing.T) { | ||||
| func TestConfirm_ConfirmHandlerErrors(t *testing.T) { | ||||
| 	c := setup() | ||||
| 	log := &bytes.Buffer{} | ||||
| 	authboss.Cfg.LogWriter = log | ||||
| 	authboss.Cfg.Mailer = authboss.LogMailer(log) | ||||
| 	authboss.a.LogWriter = log | ||||
| 	authboss.a.Mailer = authboss.LogMailer(log) | ||||
|  | ||||
| 	tests := []struct { | ||||
| 		URL       string | ||||
| @@ -177,8 +177,8 @@ func TestConfirm_Confirm(t *testing.T) { | ||||
| 	c := setup() | ||||
| 	ctx := authboss.NewContext() | ||||
| 	log := &bytes.Buffer{} | ||||
| 	authboss.Cfg.LogWriter = log | ||||
| 	authboss.Cfg.Mailer = authboss.LogMailer(log) | ||||
| 	authboss.a.LogWriter = log | ||||
| 	authboss.a.Mailer = authboss.LogMailer(log) | ||||
|  | ||||
| 	// Create a token | ||||
| 	token := []byte("hi") | ||||
| @@ -186,7 +186,7 @@ func TestConfirm_Confirm(t *testing.T) { | ||||
|  | ||||
| 	// Create the "database" | ||||
| 	storer := mocks.NewMockStorer() | ||||
| 	authboss.Cfg.Storer = storer | ||||
| 	authboss.a.Storer = storer | ||||
| 	user := authboss.Attributes{ | ||||
| 		authboss.StoreUsername: "usern", | ||||
| 		StoreConfirmToken:      base64.StdEncoding.EncodeToString(sum[:]), | ||||
|   | ||||
							
								
								
									
										20
									
								
								context.go
									
									
									
									
									
								
							
							
						
						
									
										20
									
								
								context.go
									
									
									
									
									
								
							| @@ -19,6 +19,8 @@ var ( | ||||
| // need for context is a request's session store. It is not safe for use by | ||||
| // multiple goroutines. | ||||
| type Context struct { | ||||
| 	*Authboss | ||||
|  | ||||
| 	SessionStorer ClientStorerErr | ||||
| 	CookieStorer  ClientStorerErr | ||||
| 	User          Attributes | ||||
| @@ -28,17 +30,19 @@ type Context struct { | ||||
| } | ||||
|  | ||||
| // NewContext is exported for testing modules. | ||||
| func NewContext() *Context { | ||||
| 	return &Context{} | ||||
| func (a *Authboss) NewContext() *Context { | ||||
| 	return &Context{ | ||||
| 		Authboss: a, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ContextFromRequest creates a context from an http request. | ||||
| func ContextFromRequest(r *http.Request) (*Context, error) { | ||||
| func (a *Authboss) ContextFromRequest(r *http.Request) (*Context, error) { | ||||
| 	if err := r.ParseForm(); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	c := NewContext() | ||||
| 	c := a.NewContext() | ||||
| 	c.formValues = map[string][]string(r.Form) | ||||
| 	c.postFormValues = map[string][]string(r.PostForm) | ||||
| 	return c, nil | ||||
| @@ -111,9 +115,9 @@ func (c *Context) LoadUser(key string) error { | ||||
| 	var err error | ||||
|  | ||||
| 	if index := strings.IndexByte(key, ';'); index > 0 { | ||||
| 		user, err = Cfg.OAuth2Storer.GetOAuth(key[:index], key[index+1:]) | ||||
| 		user, err = c.OAuth2Storer.GetOAuth(key[:index], key[index+1:]) | ||||
| 	} else { | ||||
| 		user, err = Cfg.Storer.Get(key) | ||||
| 		user, err = c.Storer.Get(key) | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| @@ -144,12 +148,12 @@ func (c *Context) SaveUser() error { | ||||
| 		return errors.New("User not initialized.") | ||||
| 	} | ||||
|  | ||||
| 	key, ok := c.User.String(Cfg.PrimaryID) | ||||
| 	key, ok := c.User.String(c.PrimaryID) | ||||
| 	if !ok { | ||||
| 		return errors.New("User improperly initialized, primary ID missing") | ||||
| 	} | ||||
|  | ||||
| 	return Cfg.Storer.Put(key, c.User) | ||||
| 	return c.Storer.Put(key, c.User) | ||||
| } | ||||
|  | ||||
| // Attributes converts the post form values into an attributes map. | ||||
|   | ||||
| @@ -8,13 +8,17 @@ import ( | ||||
| ) | ||||
|  | ||||
| func TestContext_Request(t *testing.T) { | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	ab := New() | ||||
|  | ||||
| 	req, err := http.NewRequest("POST", "http://localhost?query=string", bytes.NewBufferString("post=form")) | ||||
| 	if err != nil { | ||||
| 		t.Error("Unexpected Error:", err) | ||||
| 	} | ||||
| 	req.Header.Set("Content-Type", "application/x-www-form-urlencoded") | ||||
|  | ||||
| 	ctx, err := ContextFromRequest(req) | ||||
| 	ctx, err := ab.ContextFromRequest(req) | ||||
| 	if err != nil { | ||||
| 		t.Error("Unexpected Error:", err) | ||||
| 	} | ||||
| @@ -69,10 +73,12 @@ func TestContext_Request(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func TestContext_SaveUser(t *testing.T) { | ||||
| 	Cfg = NewConfig() | ||||
| 	ctx := NewContext() | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	ab := New() | ||||
| 	ctx := ab.NewContext() | ||||
| 	storer := mockStorer{} | ||||
| 	Cfg.Storer = storer | ||||
| 	ab.Storer = storer | ||||
| 	ctx.User = Attributes{StoreUsername: "joe", StoreEmail: "hello@joe.com", StorePassword: "mysticalhash"} | ||||
|  | ||||
| 	err := ctx.SaveUser() | ||||
| @@ -93,8 +99,10 @@ func TestContext_SaveUser(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func TestContext_LoadUser(t *testing.T) { | ||||
| 	Cfg = NewConfig() | ||||
| 	ctx := NewContext() | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	ab := New() | ||||
| 	ctx := ab.NewContext() | ||||
|  | ||||
| 	attr := Attributes{ | ||||
| 		"email":    "hello@joe.com", | ||||
| @@ -107,8 +115,8 @@ func TestContext_LoadUser(t *testing.T) { | ||||
| 		"joe":        attr, | ||||
| 		"whatgoogle": attr, | ||||
| 	} | ||||
| 	Cfg.Storer = storer | ||||
| 	Cfg.OAuth2Storer = storer | ||||
| 	ab.Storer = storer | ||||
| 	ab.OAuth2Storer = storer | ||||
|  | ||||
| 	ctx.User = nil | ||||
| 	if err := ctx.LoadUser("joe"); err != nil { | ||||
| @@ -144,12 +152,14 @@ func TestContext_LoadUser(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func TestContext_LoadSessionUser(t *testing.T) { | ||||
| 	Cfg = NewConfig() | ||||
| 	ctx := NewContext() | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	ab := New() | ||||
| 	ctx := ab.NewContext() | ||||
| 	storer := mockStorer{ | ||||
| 		"joe": Attributes{"email": "hello@joe.com", "password": "mysticalhash"}, | ||||
| 	} | ||||
| 	Cfg.Storer = storer | ||||
| 	ab.Storer = storer | ||||
| 	ctx.SessionStorer = mockClientStore{ | ||||
| 		SessionKey: "joe", | ||||
| 	} | ||||
| @@ -169,9 +179,12 @@ func TestContext_LoadSessionUser(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func TestContext_Attributes(t *testing.T) { | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	now := time.Now().UTC() | ||||
|  | ||||
| 	ctx := NewContext() | ||||
| 	ab := New() | ||||
| 	ctx := ab.NewContext() | ||||
| 	ctx.postFormValues = map[string][]string{ | ||||
| 		"a":        []string{"a", "1"}, | ||||
| 		"b_int":    []string{"5", "hello"}, | ||||
|   | ||||
| @@ -6,6 +6,8 @@ import ( | ||||
| ) | ||||
|  | ||||
| func TestAttributeErr(t *testing.T) { | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	estr := "Failed to retrieve database attribute, type was wrong: lol (want: String, got: int)" | ||||
| 	if str := NewAttributeErr("lol", String, 5).Error(); str != estr { | ||||
| 		t.Error("Error was wrong:", str) | ||||
| @@ -19,6 +21,8 @@ func TestAttributeErr(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func TestClientDataErr(t *testing.T) { | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	estr := "Failed to retrieve client attribute: lol" | ||||
| 	err := ClientDataErr{"lol"} | ||||
| 	if str := err.Error(); str != estr { | ||||
| @@ -27,6 +31,8 @@ func TestClientDataErr(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func TestErrAndRedirect(t *testing.T) { | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	estr := "Error: cause, Redirecting to: /" | ||||
| 	err := ErrAndRedirect{errors.New("cause"), "/", "success", "failure"} | ||||
| 	if str := err.Error(); str != estr { | ||||
| @@ -35,6 +41,8 @@ func TestErrAndRedirect(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func TestRenderErr(t *testing.T) { | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	estr := `Error rendering template "lol": cause, data: authboss.HTMLData{"a":5}` | ||||
| 	err := RenderErr{"lol", NewHTMLData("a", 5), errors.New("cause")} | ||||
| 	if str := err.Error(); str != estr { | ||||
|   | ||||
							
								
								
									
										31
									
								
								expire.go
									
									
									
									
									
								
							
							
						
						
									
										31
									
								
								expire.go
									
									
									
									
									
								
							| @@ -8,14 +8,14 @@ import ( | ||||
| var nowTime = time.Now | ||||
|  | ||||
| // TimeToExpiry returns zero if the user session is expired else the time until expiry. | ||||
| func TimeToExpiry(w http.ResponseWriter, r *http.Request) time.Duration { | ||||
| 	return timeToExpiry(Cfg.SessionStoreMaker(w, r)) | ||||
| func (a *Authboss) TimeToExpiry(w http.ResponseWriter, r *http.Request) time.Duration { | ||||
| 	return a.timeToExpiry(a.SessionStoreMaker(w, r)) | ||||
| } | ||||
|  | ||||
| func timeToExpiry(session ClientStorer) time.Duration { | ||||
| func (a *Authboss) timeToExpiry(session ClientStorer) time.Duration { | ||||
| 	dateStr, ok := session.Get(SessionLastAction) | ||||
| 	if !ok { | ||||
| 		return Cfg.ExpireAfter | ||||
| 		return a.ExpireAfter | ||||
| 	} | ||||
|  | ||||
| 	date, err := time.Parse(time.RFC3339, dateStr) | ||||
| @@ -23,7 +23,7 @@ func timeToExpiry(session ClientStorer) time.Duration { | ||||
| 		panic("last_action is not a valid RFC3339 date") | ||||
| 	} | ||||
|  | ||||
| 	remaining := date.Add(Cfg.ExpireAfter).Sub(nowTime().UTC()) | ||||
| 	remaining := date.Add(a.ExpireAfter).Sub(nowTime().UTC()) | ||||
| 	if remaining > 0 { | ||||
| 		return remaining | ||||
| 	} | ||||
| @@ -32,35 +32,36 @@ func timeToExpiry(session ClientStorer) time.Duration { | ||||
| } | ||||
|  | ||||
| // RefreshExpiry  updates the last action for the user, so he doesn't become expired. | ||||
| func RefreshExpiry(w http.ResponseWriter, r *http.Request) { | ||||
| 	session := Cfg.SessionStoreMaker(w, r) | ||||
| 	refreshExpiry(session) | ||||
| func (a *Authboss) RefreshExpiry(w http.ResponseWriter, r *http.Request) { | ||||
| 	session := a.SessionStoreMaker(w, r) | ||||
| 	a.refreshExpiry(session) | ||||
| } | ||||
|  | ||||
| func refreshExpiry(session ClientStorer) { | ||||
| func (a *Authboss) refreshExpiry(session ClientStorer) { | ||||
| 	session.Put(SessionLastAction, nowTime().UTC().Format(time.RFC3339)) | ||||
| } | ||||
|  | ||||
| type expireMiddleware struct { | ||||
| 	ab   *Authboss | ||||
| 	next http.Handler | ||||
| } | ||||
|  | ||||
| // ExpireMiddleware ensures that the user's expiry information is kept up-to-date | ||||
| // on each request. Deletes the SessionKey from the session if the user is | ||||
| // expired (Cfg.ExpireAfter duration since SessionLastAction). | ||||
| func ExpireMiddleware(next http.Handler) http.Handler { | ||||
| 	return expireMiddleware{next} | ||||
| // expired (a.ExpireAfter duration since SessionLastAction). | ||||
| func (a *Authboss) ExpireMiddleware(next http.Handler) http.Handler { | ||||
| 	return expireMiddleware{a, next} | ||||
| } | ||||
|  | ||||
| func (m expireMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) { | ||||
| 	session := Cfg.SessionStoreMaker(w, r) | ||||
| 	session := m.ab.SessionStoreMaker(w, r) | ||||
| 	if _, ok := session.Get(SessionKey); ok { | ||||
| 		ttl := timeToExpiry(session) | ||||
| 		ttl := m.ab.timeToExpiry(session) | ||||
| 		if ttl == 0 { | ||||
| 			session.Del(SessionKey) | ||||
| 			session.Del(SessionLastAction) | ||||
| 		} else { | ||||
| 			refreshExpiry(session) | ||||
| 			m.ab.refreshExpiry(session) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
|   | ||||
| @@ -7,15 +7,17 @@ import ( | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| // These tests use the global variable nowTime so cannot be parallelized | ||||
|  | ||||
| func TestDudeIsExpired(t *testing.T) { | ||||
| 	Cfg = NewConfig() | ||||
| 	ab := New() | ||||
|  | ||||
| 	session := mockClientStore{SessionKey: "username"} | ||||
| 	refreshExpiry(session) | ||||
| 	ab.refreshExpiry(session) | ||||
| 	nowTime = func() time.Time { | ||||
| 		return time.Now().UTC().Add(Cfg.ExpireAfter * 2) | ||||
| 		return time.Now().UTC().Add(ab.ExpireAfter * 2) | ||||
| 	} | ||||
| 	Cfg.SessionStoreMaker = func(_ http.ResponseWriter, _ *http.Request) ClientStorer { | ||||
| 	ab.SessionStoreMaker = func(_ http.ResponseWriter, _ *http.Request) ClientStorer { | ||||
| 		return session | ||||
| 	} | ||||
|  | ||||
| @@ -23,7 +25,7 @@ func TestDudeIsExpired(t *testing.T) { | ||||
| 	w := httptest.NewRecorder() | ||||
| 	called := false | ||||
|  | ||||
| 	m := ExpireMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 	m := ab.ExpireMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		called = true | ||||
| 	})) | ||||
|  | ||||
| @@ -43,14 +45,14 @@ func TestDudeIsExpired(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func TestDudeIsNotExpired(t *testing.T) { | ||||
| 	Cfg = NewConfig() | ||||
| 	ab := New() | ||||
|  | ||||
| 	session := mockClientStore{SessionKey: "username"} | ||||
| 	refreshExpiry(session) | ||||
| 	ab.refreshExpiry(session) | ||||
| 	nowTime = func() time.Time { | ||||
| 		return time.Now().UTC().Add(Cfg.ExpireAfter / 2) | ||||
| 		return time.Now().UTC().Add(ab.ExpireAfter / 2) | ||||
| 	} | ||||
| 	Cfg.SessionStoreMaker = func(_ http.ResponseWriter, _ *http.Request) ClientStorer { | ||||
| 	ab.SessionStoreMaker = func(_ http.ResponseWriter, _ *http.Request) ClientStorer { | ||||
| 		return session | ||||
| 	} | ||||
|  | ||||
| @@ -58,7 +60,7 @@ func TestDudeIsNotExpired(t *testing.T) { | ||||
| 	w := httptest.NewRecorder() | ||||
| 	called := false | ||||
|  | ||||
| 	m := ExpireMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 	m := ab.ExpireMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		called = true | ||||
| 	})) | ||||
|  | ||||
|   | ||||
| @@ -25,10 +25,10 @@ var ( | ||||
| 	funcMap = template.FuncMap{ | ||||
| 		"title": strings.Title, | ||||
| 		"mountpathed": func(location string) string { | ||||
| 			if authboss.Cfg.MountPath == "/" { | ||||
| 			if authboss.a.MountPath == "/" { | ||||
| 				return location | ||||
| 			} | ||||
| 			return path.Join(authboss.Cfg.MountPath, location) | ||||
| 			return path.Join(authboss.a.MountPath, location) | ||||
| 		}, | ||||
| 	} | ||||
| ) | ||||
| @@ -77,10 +77,10 @@ func (t Templates) Render(ctx *authboss.Context, w http.ResponseWriter, r *http. | ||||
| 		return authboss.RenderErr{tpl.Name(), data, ErrTemplateNotFound} | ||||
| 	} | ||||
|  | ||||
| 	data.MergeKV("xsrfName", template.HTML(authboss.Cfg.XSRFName), "xsrfToken", template.HTML(authboss.Cfg.XSRFMaker(w, r))) | ||||
| 	data.MergeKV("xsrfName", template.HTML(authboss.a.XSRFName), "xsrfToken", template.HTML(authboss.a.XSRFMaker(w, r))) | ||||
|  | ||||
| 	if authboss.Cfg.LayoutDataMaker != nil { | ||||
| 		data.Merge(authboss.Cfg.LayoutDataMaker(w, r)) | ||||
| 	if authboss.a.LayoutDataMaker != nil { | ||||
| 		data.Merge(authboss.a.LayoutDataMaker(w, r)) | ||||
| 	} | ||||
|  | ||||
| 	if flash, ok := ctx.SessionStorer.Get(authboss.FlashSuccessKey); ok { | ||||
| @@ -130,7 +130,7 @@ func Email(email authboss.Email, htmlTpls Templates, nameHTML string, textTpls T | ||||
| 	} | ||||
| 	email.TextBody = plainBuffer.String() | ||||
|  | ||||
| 	if err := authboss.Cfg.Mailer.Send(email); err != nil { | ||||
| 	if err := authboss.a.Mailer.Send(email); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
|   | ||||
| @@ -86,7 +86,7 @@ func TestTemplates_Render(t *testing.T) { | ||||
|  | ||||
| func Test_Email(t *testing.T) { | ||||
| 	mockMailer := &mocks.MockMailer{} | ||||
| 	authboss.Cfg.Mailer = mockMailer | ||||
| 	authboss.a.Mailer = mockMailer | ||||
|  | ||||
| 	htmlTpls := Templates{"html": testEmailHTMLTempalte} | ||||
| 	textTpls := Templates{"plain": testEmailPlainTempalte} | ||||
|   | ||||
							
								
								
									
										38
									
								
								lock/lock.go
									
									
									
									
									
								
							
							
						
						
									
										38
									
								
								lock/lock.go
									
									
									
									
									
								
							| @@ -29,15 +29,15 @@ type Lock struct { | ||||
|  | ||||
| // Initialize the module | ||||
| func (l *Lock) Initialize() error { | ||||
| 	if authboss.Cfg.Storer == nil { | ||||
| 	if authboss.a.Storer == nil { | ||||
| 		return errors.New("lock: Need a Storer") | ||||
| 	} | ||||
|  | ||||
| 	// Events | ||||
| 	authboss.Cfg.Callbacks.Before(authboss.EventGet, l.beforeAuth) | ||||
| 	authboss.Cfg.Callbacks.Before(authboss.EventAuth, l.beforeAuth) | ||||
| 	authboss.Cfg.Callbacks.After(authboss.EventAuth, l.afterAuth) | ||||
| 	authboss.Cfg.Callbacks.After(authboss.EventAuthFail, l.afterAuthFail) | ||||
| 	authboss.a.Callbacks.Before(authboss.EventGet, l.beforeAuth) | ||||
| 	authboss.a.Callbacks.Before(authboss.EventAuth, l.beforeAuth) | ||||
| 	authboss.a.Callbacks.After(authboss.EventAuth, l.afterAuth) | ||||
| 	authboss.a.Callbacks.After(authboss.EventAuthFail, l.afterAuthFail) | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
| @@ -50,10 +50,10 @@ func (l *Lock) Routes() authboss.RouteTable { | ||||
| // Storage requirements | ||||
| func (l *Lock) Storage() authboss.StorageOptions { | ||||
| 	return authboss.StorageOptions{ | ||||
| 		authboss.Cfg.PrimaryID: authboss.String, | ||||
| 		StoreAttemptNumber:     authboss.Integer, | ||||
| 		StoreAttemptTime:       authboss.DateTime, | ||||
| 		StoreLocked:            authboss.DateTime, | ||||
| 		authboss.a.PrimaryID: authboss.String, | ||||
| 		StoreAttemptNumber:   authboss.Integer, | ||||
| 		StoreAttemptTime:     authboss.DateTime, | ||||
| 		StoreLocked:          authboss.DateTime, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @@ -104,9 +104,9 @@ func (l *Lock) afterAuthFail(ctx *authboss.Context) error { | ||||
|  | ||||
| 	nAttempts++ | ||||
|  | ||||
| 	if time.Now().UTC().Sub(lastAttempt) <= authboss.Cfg.LockWindow { | ||||
| 		if nAttempts >= int64(authboss.Cfg.LockAfter) { | ||||
| 			ctx.User[StoreLocked] = time.Now().UTC().Add(authboss.Cfg.LockDuration) | ||||
| 	if time.Now().UTC().Sub(lastAttempt) <= authboss.a.LockWindow { | ||||
| 		if nAttempts >= int64(authboss.a.LockAfter) { | ||||
| 			ctx.User[StoreLocked] = time.Now().UTC().Add(authboss.a.LockDuration) | ||||
| 		} | ||||
|  | ||||
| 		ctx.User[StoreAttemptNumber] = nAttempts | ||||
| @@ -124,7 +124,7 @@ func (l *Lock) afterAuthFail(ctx *authboss.Context) error { | ||||
|  | ||||
| // Lock a user manually. | ||||
| func (l *Lock) Lock(key string) error { | ||||
| 	user, err := authboss.Cfg.Storer.Get(key) | ||||
| 	user, err := authboss.a.Storer.Get(key) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @@ -134,14 +134,14 @@ func (l *Lock) Lock(key string) error { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	attr[StoreLocked] = time.Now().UTC().Add(authboss.Cfg.LockDuration) | ||||
| 	attr[StoreLocked] = time.Now().UTC().Add(authboss.a.LockDuration) | ||||
|  | ||||
| 	return authboss.Cfg.Storer.Put(key, attr) | ||||
| 	return authboss.a.Storer.Put(key, attr) | ||||
| } | ||||
|  | ||||
| // Unlock a user that was locked by this module. | ||||
| func (l *Lock) Unlock(key string) error { | ||||
| 	user, err := authboss.Cfg.Storer.Get(key) | ||||
| 	user, err := authboss.a.Storer.Get(key) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @@ -153,9 +153,9 @@ func (l *Lock) Unlock(key string) error { | ||||
|  | ||||
| 	// Set the last attempt to be -window*2 to avoid immediately | ||||
| 	// giving another login failure. | ||||
| 	attr[StoreAttemptTime] = time.Now().UTC().Add(-authboss.Cfg.LockWindow * 2) | ||||
| 	attr[StoreAttemptTime] = time.Now().UTC().Add(-authboss.a.LockWindow * 2) | ||||
| 	attr[StoreAttemptNumber] = int64(0) | ||||
| 	attr[StoreLocked] = time.Now().UTC().Add(-authboss.Cfg.LockDuration) | ||||
| 	attr[StoreLocked] = time.Now().UTC().Add(-authboss.a.LockDuration) | ||||
|  | ||||
| 	return authboss.Cfg.Storer.Put(key, attr) | ||||
| 	return authboss.a.Storer.Put(key, attr) | ||||
| } | ||||
|   | ||||
| @@ -53,8 +53,8 @@ func TestAfterAuth(t *testing.T) { | ||||
| 	} | ||||
|  | ||||
| 	storer := mocks.NewMockStorer() | ||||
| 	authboss.Cfg.Storer = storer | ||||
| 	ctx.User = authboss.Attributes{authboss.Cfg.PrimaryID: "john@john.com"} | ||||
| 	authboss.a.Storer = storer | ||||
| 	ctx.User = authboss.Attributes{authboss.a.PrimaryID: "john@john.com"} | ||||
|  | ||||
| 	if err := lock.afterAuth(ctx); err != nil { | ||||
| 		t.Error(err) | ||||
| @@ -74,15 +74,15 @@ func TestAfterAuthFail_Lock(t *testing.T) { | ||||
|  | ||||
| 	ctx := authboss.NewContext() | ||||
| 	storer := mocks.NewMockStorer() | ||||
| 	authboss.Cfg.Storer = storer | ||||
| 	authboss.a.Storer = storer | ||||
| 	lock := Lock{} | ||||
| 	authboss.Cfg.LockWindow = 30 * time.Minute | ||||
| 	authboss.Cfg.LockDuration = 30 * time.Minute | ||||
| 	authboss.Cfg.LockAfter = 3 | ||||
| 	authboss.a.LockWindow = 30 * time.Minute | ||||
| 	authboss.a.LockDuration = 30 * time.Minute | ||||
| 	authboss.a.LockAfter = 3 | ||||
|  | ||||
| 	email := "john@john.com" | ||||
|  | ||||
| 	ctx.User = map[string]interface{}{authboss.Cfg.PrimaryID: email} | ||||
| 	ctx.User = map[string]interface{}{authboss.a.PrimaryID: email} | ||||
|  | ||||
| 	old = time.Now().UTC().Add(-1 * time.Hour) | ||||
|  | ||||
| @@ -123,17 +123,17 @@ func TestAfterAuthFail_Reset(t *testing.T) { | ||||
| 	ctx := authboss.NewContext() | ||||
| 	storer := mocks.NewMockStorer() | ||||
| 	lock := Lock{} | ||||
| 	authboss.Cfg.LockWindow = 30 * time.Minute | ||||
| 	authboss.Cfg.Storer = storer | ||||
| 	authboss.a.LockWindow = 30 * time.Minute | ||||
| 	authboss.a.Storer = storer | ||||
|  | ||||
| 	old = time.Now().UTC().Add(-time.Hour) | ||||
|  | ||||
| 	email := "john@john.com" | ||||
| 	ctx.User = map[string]interface{}{ | ||||
| 		authboss.Cfg.PrimaryID: email, | ||||
| 		StoreAttemptNumber:     int64(2), | ||||
| 		StoreAttemptTime:       old, | ||||
| 		StoreLocked:            old, | ||||
| 		authboss.a.PrimaryID: email, | ||||
| 		StoreAttemptNumber:   int64(2), | ||||
| 		StoreAttemptTime:     old, | ||||
| 		StoreLocked:          old, | ||||
| 	} | ||||
|  | ||||
| 	lock.afterAuthFail(ctx) | ||||
| @@ -162,13 +162,13 @@ func TestAfterAuthFail_Errors(t *testing.T) { | ||||
| func TestLock(t *testing.T) { | ||||
| 	authboss.NewConfig() | ||||
| 	storer := mocks.NewMockStorer() | ||||
| 	authboss.Cfg.Storer = storer | ||||
| 	authboss.a.Storer = storer | ||||
| 	lock := Lock{} | ||||
|  | ||||
| 	email := "john@john.com" | ||||
| 	storer.Users[email] = map[string]interface{}{ | ||||
| 		authboss.Cfg.PrimaryID: email, | ||||
| 		"password":             "password", | ||||
| 		authboss.a.PrimaryID: email, | ||||
| 		"password":           "password", | ||||
| 	} | ||||
|  | ||||
| 	err := lock.Lock(email) | ||||
| @@ -184,15 +184,15 @@ func TestLock(t *testing.T) { | ||||
| func TestUnlock(t *testing.T) { | ||||
| 	authboss.NewConfig() | ||||
| 	storer := mocks.NewMockStorer() | ||||
| 	authboss.Cfg.Storer = storer | ||||
| 	authboss.a.Storer = storer | ||||
| 	lock := Lock{} | ||||
| 	authboss.Cfg.LockWindow = 1 * time.Hour | ||||
| 	authboss.a.LockWindow = 1 * time.Hour | ||||
|  | ||||
| 	email := "john@john.com" | ||||
| 	storer.Users[email] = map[string]interface{}{ | ||||
| 		authboss.Cfg.PrimaryID: email, | ||||
| 		"password":             "password", | ||||
| 		"locked":               true, | ||||
| 		authboss.a.PrimaryID: email, | ||||
| 		"password":           "password", | ||||
| 		"locked":             true, | ||||
| 	} | ||||
|  | ||||
| 	err := lock.Unlock(email) | ||||
| @@ -201,7 +201,7 @@ func TestUnlock(t *testing.T) { | ||||
| 	} | ||||
|  | ||||
| 	attemptTime := storer.Users[email][StoreAttemptTime].(time.Time) | ||||
| 	if attemptTime.After(time.Now().UTC().Add(-authboss.Cfg.LockWindow)) { | ||||
| 	if attemptTime.After(time.Now().UTC().Add(-authboss.a.LockWindow)) { | ||||
| 		t.Error("StoreLocked not set correctly:", attemptTime) | ||||
| 	} | ||||
| 	if number := storer.Users[email][StoreAttemptNumber].(int64); number != int64(0) { | ||||
|   | ||||
| @@ -9,6 +9,8 @@ import ( | ||||
| ) | ||||
|  | ||||
| func TestDefaultLogger(t *testing.T) { | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	logger := NewDefaultLogger() | ||||
| 	if logger == nil { | ||||
| 		t.Error("Logger was not created.") | ||||
| @@ -16,6 +18,8 @@ func TestDefaultLogger(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func TestDefaultLoggerOutput(t *testing.T) { | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	buffer := &bytes.Buffer{} | ||||
| 	logger := (*DefaultLogger)(log.New(buffer, "", log.LstdFlags)) | ||||
| 	io.WriteString(logger, "hello world") | ||||
|   | ||||
| @@ -10,8 +10,8 @@ import ( | ||||
| ) | ||||
|  | ||||
| // SendMail uses the currently configured mailer to deliver e-mails. | ||||
| func SendMail(data Email) error { | ||||
| 	return Cfg.Mailer.Send(data) | ||||
| func (a *Authboss) SendMail(data Email) error { | ||||
| 	return a.Mailer.Send(data) | ||||
| } | ||||
|  | ||||
| // Mailer is a type that is capable of sending an e-mail. | ||||
|   | ||||
| @@ -8,15 +8,16 @@ import ( | ||||
| ) | ||||
|  | ||||
| func TestMailer(t *testing.T) { | ||||
| 	Cfg = NewConfig() | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	ab := New() | ||||
| 	mailServer := &bytes.Buffer{} | ||||
|  | ||||
| 	Cfg.Mailer = LogMailer(mailServer) | ||||
| 	Cfg.Storer = mockStorer{} | ||||
| 	Cfg.LogWriter = ioutil.Discard | ||||
| 	Init() | ||||
| 	ab.Mailer = LogMailer(mailServer) | ||||
| 	ab.Storer = mockStorer{} | ||||
| 	ab.LogWriter = ioutil.Discard | ||||
|  | ||||
| 	err := SendMail(Email{ | ||||
| 	err := ab.SendMail(Email{ | ||||
| 		To:       []string{"some@email.com", "a@a.com"}, | ||||
| 		ToNames:  []string{"Jake", "Noname"}, | ||||
| 		From:     "some@guy.com", | ||||
| @@ -53,6 +54,8 @@ func TestMailer(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func TestSMTPMailer(t *testing.T) { | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	var _ Mailer = SMTPMailer("server", nil) | ||||
|  | ||||
| 	recovered := false | ||||
|   | ||||
| @@ -56,7 +56,7 @@ func (m mockClientStore) GetErr(key string) (string, error) { | ||||
| func (m mockClientStore) Put(key, val string) { m[key] = val } | ||||
| func (m mockClientStore) Del(key string)      { delete(m, key) } | ||||
|  | ||||
| func mockRequestContext(postKeyValues ...string) *Context { | ||||
| func mockRequestContext(ab *Authboss, postKeyValues ...string) *Context { | ||||
| 	keyValues := &bytes.Buffer{} | ||||
| 	for i := 0; i < len(postKeyValues); i += 2 { | ||||
| 		if i != 0 { | ||||
| @@ -71,7 +71,7 @@ func mockRequestContext(postKeyValues ...string) *Context { | ||||
| 	} | ||||
| 	req.Header.Set("Content-Type", "application/x-www-form-urlencoded") | ||||
|  | ||||
| 	ctx, err := ContextFromRequest(req) | ||||
| 	ctx, err := ab.ContextFromRequest(req) | ||||
| 	if err != nil { | ||||
| 		panic(err) | ||||
| 	} | ||||
|   | ||||
| @@ -9,7 +9,7 @@ var ModuleAttributes = make(AttributeMeta) | ||||
|  | ||||
| // Modularizer should be implemented by all the authboss modules. | ||||
| type Modularizer interface { | ||||
| 	Initialize() error | ||||
| 	Initialize(*Authboss) error | ||||
| 	Routes() RouteTable | ||||
| 	Storage() StorageOptions | ||||
| } | ||||
|   | ||||
| @@ -23,12 +23,13 @@ func testHandler(ctx *Context, w http.ResponseWriter, r *http.Request) error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (t *testModule) Initialize() error       { return nil } | ||||
| func (t *testModule) Routes() RouteTable      { return t.r } | ||||
| func (t *testModule) Storage() StorageOptions { return t.s } | ||||
| func (t *testModule) Initialize(a *Authboss) error { return nil } | ||||
| func (t *testModule) Routes() RouteTable           { return t.r } | ||||
| func (t *testModule) Storage() StorageOptions      { return t.s } | ||||
|  | ||||
| func TestRegister(t *testing.T) { | ||||
| 	// RegisterModule called by TestMain. | ||||
| 	modules = make(map[string]Modularizer) | ||||
| 	RegisterModule("testmodule", testMod) | ||||
|  | ||||
| 	if _, ok := modules["testmodule"]; !ok { | ||||
| 		t.Error("Expected module to be saved.") | ||||
| @@ -40,7 +41,8 @@ func TestRegister(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func TestLoadedModules(t *testing.T) { | ||||
| 	// RegisterModule called by TestMain. | ||||
| 	modules = make(map[string]Modularizer) | ||||
| 	RegisterModule("testmodule", testMod) | ||||
|  | ||||
| 	loadedMods := LoadedModules() | ||||
| 	if len(loadedMods) != 1 { | ||||
|   | ||||
| @@ -30,7 +30,7 @@ func init() { | ||||
|  | ||||
| // Initialize module | ||||
| func (o *OAuth2) Initialize() error { | ||||
| 	if authboss.Cfg.OAuth2Storer == nil { | ||||
| 	if authboss.a.OAuth2Storer == nil { | ||||
| 		return errors.New("oauth2: need an OAuth2Storer") | ||||
| 	} | ||||
| 	return nil | ||||
| @@ -40,7 +40,7 @@ func (o *OAuth2) Initialize() error { | ||||
| func (o *OAuth2) Routes() authboss.RouteTable { | ||||
| 	routes := make(authboss.RouteTable) | ||||
|  | ||||
| 	for prov, cfg := range authboss.Cfg.OAuth2Providers { | ||||
| 	for prov, cfg := range authboss.a.OAuth2Providers { | ||||
| 		prov = strings.ToLower(prov) | ||||
|  | ||||
| 		init := fmt.Sprintf("/oauth2/%s", prov) | ||||
| @@ -49,11 +49,11 @@ func (o *OAuth2) Routes() authboss.RouteTable { | ||||
| 		routes[init] = oauthInit | ||||
| 		routes[callback] = oauthCallback | ||||
|  | ||||
| 		if len(authboss.Cfg.MountPath) > 0 { | ||||
| 			callback = path.Join(authboss.Cfg.MountPath, callback) | ||||
| 		if len(authboss.a.MountPath) > 0 { | ||||
| 			callback = path.Join(authboss.a.MountPath, callback) | ||||
| 		} | ||||
|  | ||||
| 		cfg.OAuth2Config.RedirectURL = authboss.Cfg.RootURL + callback | ||||
| 		a.OAuth2Config.RedirectURL = authboss.a.RootURL + callback | ||||
| 	} | ||||
|  | ||||
| 	routes["/oauth2/logout"] = logout | ||||
| @@ -75,7 +75,7 @@ func (o *OAuth2) Storage() authboss.StorageOptions { | ||||
|  | ||||
| func oauthInit(ctx *authboss.Context, w http.ResponseWriter, r *http.Request) error { | ||||
| 	provider := strings.ToLower(filepath.Base(r.URL.Path)) | ||||
| 	cfg, ok := authboss.Cfg.OAuth2Providers[provider] | ||||
| 	cfg, ok := authboss.a.OAuth2Providers[provider] | ||||
| 	if !ok { | ||||
| 		return fmt.Errorf("OAuth2 provider %q not found", provider) | ||||
| 	} | ||||
| @@ -106,9 +106,9 @@ func oauthInit(ctx *authboss.Context, w http.ResponseWriter, r *http.Request) er | ||||
| 		ctx.SessionStorer.Del(authboss.SessionOAuth2Params) | ||||
| 	} | ||||
|  | ||||
| 	url := cfg.OAuth2Config.AuthCodeURL(state) | ||||
| 	url := a.OAuth2Config.AuthCodeURL(state) | ||||
|  | ||||
| 	extraParams := cfg.AdditionalParams.Encode() | ||||
| 	extraParams := a.AdditionalParams.Encode() | ||||
| 	if len(extraParams) > 0 { | ||||
| 		url = fmt.Sprintf("%s&%s", url, extraParams) | ||||
| 	} | ||||
| @@ -140,18 +140,18 @@ func oauthCallback(ctx *authboss.Context, w http.ResponseWriter, r *http.Request | ||||
|  | ||||
| 	hasErr := r.FormValue("error") | ||||
| 	if len(hasErr) > 0 { | ||||
| 		if err := authboss.Cfg.Callbacks.FireAfter(authboss.EventOAuthFail, ctx); err != nil { | ||||
| 		if err := authboss.a.Callbacks.FireAfter(authboss.EventOAuthFail, ctx); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		return authboss.ErrAndRedirect{ | ||||
| 			Err:        errors.New(r.FormValue("error_reason")), | ||||
| 			Location:   authboss.Cfg.AuthLoginFailPath, | ||||
| 			Location:   authboss.a.AuthLoginFailPath, | ||||
| 			FlashError: fmt.Sprintf("%s login cancelled or failed.", strings.Title(provider)), | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	cfg, ok := authboss.Cfg.OAuth2Providers[provider] | ||||
| 	cfg, ok := authboss.a.OAuth2Providers[provider] | ||||
| 	if !ok { | ||||
| 		return fmt.Errorf("OAuth2 provider %q not found", provider) | ||||
| 	} | ||||
| @@ -165,12 +165,12 @@ func oauthCallback(ctx *authboss.Context, w http.ResponseWriter, r *http.Request | ||||
|  | ||||
| 	// Get the code | ||||
| 	code := r.FormValue("code") | ||||
| 	token, err := exchanger(cfg.OAuth2Config, oauth2.NoContext, code) | ||||
| 	token, err := exchanger(a.OAuth2Config, oauth2.NoContext, code) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("Could not validate oauth2 code: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	user, err := cfg.Callback(*cfg.OAuth2Config, token) | ||||
| 	user, err := a.Callback(*cfg.OAuth2Config, token) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @@ -189,7 +189,7 @@ func oauthCallback(ctx *authboss.Context, w http.ResponseWriter, r *http.Request | ||||
| 		user[authboss.StoreOAuth2Refresh] = token.RefreshToken | ||||
| 	} | ||||
|  | ||||
| 	if err = authboss.Cfg.OAuth2Storer.PutOAuth(uid, provider, user); err != nil { | ||||
| 	if err = authboss.a.OAuth2Storer.PutOAuth(uid, provider, user); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| @@ -197,13 +197,13 @@ func oauthCallback(ctx *authboss.Context, w http.ResponseWriter, r *http.Request | ||||
| 	ctx.SessionStorer.Put(authboss.SessionKey, fmt.Sprintf("%s;%s", uid, provider)) | ||||
| 	ctx.SessionStorer.Del(authboss.SessionHalfAuthKey) | ||||
|  | ||||
| 	if err = authboss.Cfg.Callbacks.FireAfter(authboss.EventOAuth, ctx); err != nil { | ||||
| 	if err = authboss.a.Callbacks.FireAfter(authboss.EventOAuth, ctx); err != nil { | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	ctx.SessionStorer.Del(authboss.SessionOAuth2Params) | ||||
|  | ||||
| 	redirect := authboss.Cfg.AuthLoginOKPath | ||||
| 	redirect := authboss.a.AuthLoginOKPath | ||||
| 	query := make(url.Values) | ||||
| 	for k, v := range values { | ||||
| 		switch k { | ||||
| @@ -231,7 +231,7 @@ func logout(ctx *authboss.Context, w http.ResponseWriter, r *http.Request) error | ||||
| 		ctx.CookieStorer.Del(authboss.CookieRemember) | ||||
| 		ctx.SessionStorer.Del(authboss.SessionLastAction) | ||||
|  | ||||
| 		response.Redirect(ctx, w, r, authboss.Cfg.AuthLogoutOKPath, "You have logged out", "", true) | ||||
| 		response.Redirect(ctx, w, r, authboss.a.AuthLogoutOKPath, "You have logged out", "", true) | ||||
| 	default: | ||||
| 		w.WriteHeader(http.StatusMethodNotAllowed) | ||||
| 	} | ||||
|   | ||||
| @@ -31,7 +31,7 @@ var testProviders = map[string]authboss.OAuth2Provider{ | ||||
|  | ||||
| func TestInitialize(t *testing.T) { | ||||
| 	authboss.Cfg = authboss.NewConfig() | ||||
| 	authboss.Cfg.OAuth2Storer = mocks.NewMockStorer() | ||||
| 	authboss.a.OAuth2Storer = mocks.NewMockStorer() | ||||
| 	o := OAuth2{} | ||||
| 	if err := o.Initialize(); err != nil { | ||||
| 		t.Error(err) | ||||
| @@ -43,11 +43,11 @@ func TestRoutes(t *testing.T) { | ||||
| 	mount := "/auth" | ||||
|  | ||||
| 	authboss.Cfg = authboss.NewConfig() | ||||
| 	authboss.Cfg.RootURL = root | ||||
| 	authboss.Cfg.MountPath = mount | ||||
| 	authboss.Cfg.OAuth2Providers = testProviders | ||||
| 	authboss.a.RootURL = root | ||||
| 	authboss.a.MountPath = mount | ||||
| 	authboss.a.OAuth2Providers = testProviders | ||||
|  | ||||
| 	googleCfg := authboss.Cfg.OAuth2Providers["google"].OAuth2Config | ||||
| 	googleCfg := authboss.a.OAuth2Providers["google"].OAuth2Config | ||||
| 	if 0 != len(googleCfg.RedirectURL) { | ||||
| 		t.Error("RedirectURL should not be set") | ||||
| 	} | ||||
| @@ -74,7 +74,7 @@ func TestOAuth2Init(t *testing.T) { | ||||
| 	cfg := authboss.NewConfig() | ||||
| 	session := mocks.NewMockClientStorer() | ||||
|  | ||||
| 	cfg.OAuth2Providers = testProviders | ||||
| 	a.OAuth2Providers = testProviders | ||||
| 	authboss.Cfg = cfg | ||||
|  | ||||
| 	r, _ := http.NewRequest("GET", "/oauth2/google?redir=/my/redirect%23lol&rm=true", nil) | ||||
| @@ -137,7 +137,7 @@ func TestOAuthSuccess(t *testing.T) { | ||||
| 		return fakeToken, nil | ||||
| 	} | ||||
|  | ||||
| 	cfg.OAuth2Providers = map[string]authboss.OAuth2Provider{ | ||||
| 	a.OAuth2Providers = map[string]authboss.OAuth2Provider{ | ||||
| 		"fake": authboss.OAuth2Provider{ | ||||
| 			OAuth2Config: &oauth2.Config{ | ||||
| 				ClientID:     `jazz`, | ||||
| @@ -168,8 +168,8 @@ func TestOAuthSuccess(t *testing.T) { | ||||
|  | ||||
| 	storer := mocks.NewMockStorer() | ||||
| 	ctx.SessionStorer = session | ||||
| 	cfg.OAuth2Storer = storer | ||||
| 	cfg.AuthLoginOKPath = "/fakeloginok" | ||||
| 	a.OAuth2Storer = storer | ||||
| 	a.AuthLoginOKPath = "/fakeloginok" | ||||
|  | ||||
| 	if err := oauthCallback(ctx, w, r); err != nil { | ||||
| 		t.Error(err) | ||||
| @@ -214,7 +214,7 @@ func TestOAuthXSRFFailure(t *testing.T) { | ||||
| 	session := mocks.NewMockClientStorer() | ||||
| 	session.Put(authboss.SessionOAuth2State, authboss.FormValueOAuth2State) | ||||
|  | ||||
| 	cfg.OAuth2Providers = testProviders | ||||
| 	a.OAuth2Providers = testProviders | ||||
| 	authboss.Cfg = cfg | ||||
|  | ||||
| 	values := url.Values{} | ||||
| @@ -234,7 +234,7 @@ func TestOAuthXSRFFailure(t *testing.T) { | ||||
| func TestOAuthFailure(t *testing.T) { | ||||
| 	cfg := authboss.NewConfig() | ||||
|  | ||||
| 	cfg.OAuth2Providers = testProviders | ||||
| 	a.OAuth2Providers = testProviders | ||||
| 	authboss.Cfg = cfg | ||||
|  | ||||
| 	values := url.Values{} | ||||
| @@ -260,7 +260,7 @@ func TestOAuthFailure(t *testing.T) { | ||||
|  | ||||
| func TestLogout(t *testing.T) { | ||||
| 	authboss.Cfg = authboss.NewConfig() | ||||
| 	authboss.Cfg.AuthLogoutOKPath = "/dashboard" | ||||
| 	authboss.a.AuthLogoutOKPath = "/dashboard" | ||||
|  | ||||
| 	r, _ := http.NewRequest("GET", "/oauth2/google?", nil) | ||||
| 	w := httptest.NewRecorder() | ||||
| @@ -292,7 +292,7 @@ func TestLogout(t *testing.T) { | ||||
| 	} | ||||
|  | ||||
| 	location := w.Header().Get("Location") | ||||
| 	if location != authboss.Cfg.AuthLogoutOKPath { | ||||
| 	if location != authboss.a.AuthLogoutOKPath { | ||||
| 		t.Error("Redirect wrong:", location) | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -26,8 +26,8 @@ type googleMeResponse struct { | ||||
| var clientGet = (*http.Client).Get | ||||
|  | ||||
| // Google is a callback appropriate for use with Google's OAuth2 configuration. | ||||
| func Google(cfg oauth2.Config, token *oauth2.Token) (authboss.Attributes, error) { | ||||
| 	client := cfg.Client(oauth2.NoContext, token) | ||||
| func Google(a.oauth2.Config, token *oauth2.Token) (authboss.Attributes, error) { | ||||
| 	client := a.Client(oauth2.NoContext, token) | ||||
| 	resp, err := clientGet(client, googleInfoEndpoint) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
|   | ||||
| @@ -63,32 +63,32 @@ type Recover struct { | ||||
|  | ||||
| // Initialize module | ||||
| func (r *Recover) Initialize() (err error) { | ||||
| 	if authboss.Cfg.Storer == nil { | ||||
| 	if authboss.a.Storer == nil { | ||||
| 		return errors.New("recover: Need a RecoverStorer") | ||||
| 	} | ||||
|  | ||||
| 	if _, ok := authboss.Cfg.Storer.(RecoverStorer); !ok { | ||||
| 	if _, ok := authboss.a.Storer.(RecoverStorer); !ok { | ||||
| 		return errors.New("recover: RecoverStorer required for recover functionality") | ||||
| 	} | ||||
|  | ||||
| 	if len(authboss.Cfg.XSRFName) == 0 { | ||||
| 	if len(authboss.a.XSRFName) == 0 { | ||||
| 		return errors.New("auth: XSRFName must be set") | ||||
| 	} | ||||
|  | ||||
| 	if authboss.Cfg.XSRFMaker == nil { | ||||
| 	if authboss.a.XSRFMaker == nil { | ||||
| 		return errors.New("auth: XSRFMaker must be defined") | ||||
| 	} | ||||
|  | ||||
| 	r.templates, err = response.LoadTemplates(authboss.Cfg.Layout, authboss.Cfg.ViewsPath, tplRecover, tplRecoverComplete) | ||||
| 	r.templates, err = response.LoadTemplates(authboss.a.Layout, authboss.a.ViewsPath, tplRecover, tplRecoverComplete) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	r.emailHTMLTemplates, err = response.LoadTemplates(authboss.Cfg.LayoutHTMLEmail, authboss.Cfg.ViewsPath, tplInitHTMLEmail) | ||||
| 	r.emailHTMLTemplates, err = response.LoadTemplates(authboss.a.LayoutHTMLEmail, authboss.a.ViewsPath, tplInitHTMLEmail) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	r.emailTextTemplates, err = response.LoadTemplates(authboss.Cfg.LayoutTextEmail, authboss.Cfg.ViewsPath, tplInitTextEmail) | ||||
| 	r.emailTextTemplates, err = response.LoadTemplates(authboss.a.LayoutTextEmail, authboss.a.ViewsPath, tplInitTextEmail) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @@ -107,7 +107,7 @@ func (r *Recover) Routes() authboss.RouteTable { | ||||
| // Storage requirements | ||||
| func (r *Recover) Storage() authboss.StorageOptions { | ||||
| 	return authboss.StorageOptions{ | ||||
| 		authboss.Cfg.PrimaryID:  authboss.String, | ||||
| 		authboss.a.PrimaryID:    authboss.String, | ||||
| 		authboss.StoreEmail:     authboss.String, | ||||
| 		authboss.StorePassword:  authboss.String, | ||||
| 		StoreRecoverToken:       authboss.String, | ||||
| @@ -119,31 +119,31 @@ func (rec *Recover) startHandlerFunc(ctx *authboss.Context, w http.ResponseWrite | ||||
| 	switch r.Method { | ||||
| 	case methodGET: | ||||
| 		data := authboss.NewHTMLData( | ||||
| 			"primaryID", authboss.Cfg.PrimaryID, | ||||
| 			"primaryID", authboss.a.PrimaryID, | ||||
| 			"primaryIDValue", "", | ||||
| 			"confirmPrimaryIDValue", "", | ||||
| 		) | ||||
|  | ||||
| 		return rec.templates.Render(ctx, w, r, tplRecover, data) | ||||
| 	case methodPOST: | ||||
| 		primaryID, _ := ctx.FirstPostFormValue(authboss.Cfg.PrimaryID) | ||||
| 		confirmPrimaryID, _ := ctx.FirstPostFormValue(fmt.Sprintf("confirm_%s", authboss.Cfg.PrimaryID)) | ||||
| 		primaryID, _ := ctx.FirstPostFormValue(authboss.a.PrimaryID) | ||||
| 		confirmPrimaryID, _ := ctx.FirstPostFormValue(fmt.Sprintf("confirm_%s", authboss.a.PrimaryID)) | ||||
|  | ||||
| 		errData := authboss.NewHTMLData( | ||||
| 			"primaryID", authboss.Cfg.PrimaryID, | ||||
| 			"primaryID", authboss.a.PrimaryID, | ||||
| 			"primaryIDValue", primaryID, | ||||
| 			"confirmPrimaryIDValue", confirmPrimaryID, | ||||
| 		) | ||||
|  | ||||
| 		policies := authboss.FilterValidators(authboss.Cfg.Policies, authboss.Cfg.PrimaryID) | ||||
| 		if validationErrs := ctx.Validate(policies, authboss.Cfg.PrimaryID, authboss.ConfirmPrefix+authboss.Cfg.PrimaryID).Map(); len(validationErrs) > 0 { | ||||
| 		policies := authboss.FilterValidators(authboss.a.Policies, authboss.a.PrimaryID) | ||||
| 		if validationErrs := ctx.Validate(policies, authboss.a.PrimaryID, authboss.ConfirmPrefix+authboss.a.PrimaryID).Map(); len(validationErrs) > 0 { | ||||
| 			errData.MergeKV("errs", validationErrs) | ||||
| 			return rec.templates.Render(ctx, w, r, tplRecover, errData) | ||||
| 		} | ||||
|  | ||||
| 		// redirect to login when user not found to prevent username sniffing | ||||
| 		if err := ctx.LoadUser(primaryID); err == authboss.ErrUserNotFound { | ||||
| 			return authboss.ErrAndRedirect{err, authboss.Cfg.RecoverOKPath, recoverInitiateSuccessFlash, ""} | ||||
| 			return authboss.ErrAndRedirect{err, authboss.a.RecoverOKPath, recoverInitiateSuccessFlash, ""} | ||||
| 		} else if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| @@ -159,7 +159,7 @@ func (rec *Recover) startHandlerFunc(ctx *authboss.Context, w http.ResponseWrite | ||||
| 		} | ||||
|  | ||||
| 		ctx.User[StoreRecoverToken] = encodedChecksum | ||||
| 		ctx.User[StoreRecoverTokenExpiry] = time.Now().Add(authboss.Cfg.RecoverTokenDuration) | ||||
| 		ctx.User[StoreRecoverTokenExpiry] = time.Now().Add(authboss.a.RecoverTokenDuration) | ||||
|  | ||||
| 		if err := ctx.SaveUser(); err != nil { | ||||
| 			return err | ||||
| @@ -168,7 +168,7 @@ func (rec *Recover) startHandlerFunc(ctx *authboss.Context, w http.ResponseWrite | ||||
| 		goRecoverEmail(rec, email, encodedToken) | ||||
|  | ||||
| 		ctx.SessionStorer.Put(authboss.FlashSuccessKey, recoverInitiateSuccessFlash) | ||||
| 		response.Redirect(ctx, w, r, authboss.Cfg.RecoverOKPath, "", "", true) | ||||
| 		response.Redirect(ctx, w, r, authboss.a.RecoverOKPath, "", "", true) | ||||
| 	default: | ||||
| 		w.WriteHeader(http.StatusMethodNotAllowed) | ||||
| 	} | ||||
| @@ -191,17 +191,17 @@ var goRecoverEmail = func(r *Recover, to, encodedToken string) { | ||||
| } | ||||
|  | ||||
| func (r *Recover) sendRecoverEmail(to, encodedToken string) { | ||||
| 	p := path.Join(authboss.Cfg.MountPath, "recover/complete") | ||||
| 	url := fmt.Sprintf("%s%s?token=%s", authboss.Cfg.RootURL, p, encodedToken) | ||||
| 	p := path.Join(authboss.a.MountPath, "recover/complete") | ||||
| 	url := fmt.Sprintf("%s%s?token=%s", authboss.a.RootURL, p, encodedToken) | ||||
|  | ||||
| 	email := authboss.Email{ | ||||
| 		To:      []string{to}, | ||||
| 		From:    authboss.Cfg.EmailFrom, | ||||
| 		Subject: authboss.Cfg.EmailSubjectPrefix + "Password Reset", | ||||
| 		From:    authboss.a.EmailFrom, | ||||
| 		Subject: authboss.a.EmailSubjectPrefix + "Password Reset", | ||||
| 	} | ||||
|  | ||||
| 	if err := response.Email(email, r.emailHTMLTemplates, tplInitHTMLEmail, r.emailTextTemplates, tplInitTextEmail, url); err != nil { | ||||
| 		fmt.Fprintln(authboss.Cfg.LogWriter, "recover: failed to send recover email:", err) | ||||
| 		fmt.Fprintln(authboss.a.LogWriter, "recover: failed to send recover email:", err) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @@ -227,7 +227,7 @@ func (r *Recover) completeHandlerFunc(ctx *authboss.Context, w http.ResponseWrit | ||||
| 		password, _ := ctx.FirstPostFormValue("password") | ||||
| 		confirmPassword, _ := ctx.FirstPostFormValue("confirmPassword") | ||||
|  | ||||
| 		policies := authboss.FilterValidators(authboss.Cfg.Policies, "password") | ||||
| 		policies := authboss.FilterValidators(authboss.a.Policies, "password") | ||||
| 		if validationErrs := ctx.Validate(policies, authboss.StorePassword, authboss.ConfirmPrefix+authboss.StorePassword).Map(); len(validationErrs) > 0 { | ||||
| 			data := authboss.NewHTMLData( | ||||
| 				"token", token, | ||||
| @@ -242,7 +242,7 @@ func (r *Recover) completeHandlerFunc(ctx *authboss.Context, w http.ResponseWrit | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		encryptedPassword, err := bcrypt.GenerateFromPassword([]byte(password), authboss.Cfg.BCryptCost) | ||||
| 		encryptedPassword, err := bcrypt.GenerateFromPassword([]byte(password), authboss.a.BCryptCost) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| @@ -252,7 +252,7 @@ func (r *Recover) completeHandlerFunc(ctx *authboss.Context, w http.ResponseWrit | ||||
| 		var nullTime time.Time | ||||
| 		ctx.User[StoreRecoverTokenExpiry] = nullTime | ||||
|  | ||||
| 		primaryID, err := ctx.User.StringErr(authboss.Cfg.PrimaryID) | ||||
| 		primaryID, err := ctx.User.StringErr(authboss.a.PrimaryID) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| @@ -261,12 +261,12 @@ func (r *Recover) completeHandlerFunc(ctx *authboss.Context, w http.ResponseWrit | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		if err := authboss.Cfg.Callbacks.FireAfter(authboss.EventPasswordReset, ctx); err != nil { | ||||
| 		if err := authboss.a.Callbacks.FireAfter(authboss.EventPasswordReset, ctx); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		ctx.SessionStorer.Put(authboss.SessionKey, primaryID) | ||||
| 		response.Redirect(ctx, w, req, authboss.Cfg.AuthLoginOKPath, "", "", true) | ||||
| 		response.Redirect(ctx, w, req, authboss.a.AuthLoginOKPath, "", "", true) | ||||
| 	default: | ||||
| 		w.WriteHeader(http.StatusMethodNotAllowed) | ||||
| 	} | ||||
| @@ -287,7 +287,7 @@ func verifyToken(ctx *authboss.Context) (attrs authboss.Attributes, err error) { | ||||
| 	} | ||||
|  | ||||
| 	sum := md5.Sum(decoded) | ||||
| 	storer := authboss.Cfg.Storer.(RecoverStorer) | ||||
| 	storer := authboss.a.Storer.(RecoverStorer) | ||||
|  | ||||
| 	userInter, err := storer.RecoverUser(base64.StdEncoding.EncodeToString(sum[:])) | ||||
| 	if err != nil { | ||||
|   | ||||
| @@ -26,16 +26,16 @@ func testSetup() (r *Recover, s *mocks.MockStorer, l *bytes.Buffer) { | ||||
| 	l = &bytes.Buffer{} | ||||
|  | ||||
| 	authboss.Cfg = authboss.NewConfig() | ||||
| 	authboss.Cfg.Layout = template.Must(template.New("").Parse(`{{template "authboss" .}}`)) | ||||
| 	authboss.Cfg.LayoutHTMLEmail = template.Must(template.New("").Parse(`<strong>{{template "authboss" .}}</strong>`)) | ||||
| 	authboss.Cfg.LayoutTextEmail = template.Must(template.New("").Parse(`{{template "authboss" .}}`)) | ||||
| 	authboss.Cfg.Storer = s | ||||
| 	authboss.Cfg.XSRFName = "xsrf" | ||||
| 	authboss.Cfg.XSRFMaker = func(_ http.ResponseWriter, _ *http.Request) string { | ||||
| 	authboss.a.Layout = template.Must(template.New("").Parse(`{{template "authboss" .}}`)) | ||||
| 	authboss.a.LayoutHTMLEmail = template.Must(template.New("").Parse(`<strong>{{template "authboss" .}}</strong>`)) | ||||
| 	authboss.a.LayoutTextEmail = template.Must(template.New("").Parse(`{{template "authboss" .}}`)) | ||||
| 	authboss.a.Storer = s | ||||
| 	authboss.a.XSRFName = "xsrf" | ||||
| 	authboss.a.XSRFMaker = func(_ http.ResponseWriter, _ *http.Request) string { | ||||
| 		return "xsrfvalue" | ||||
| 	} | ||||
| 	authboss.Cfg.PrimaryID = authboss.StoreUsername | ||||
| 	authboss.Cfg.LogWriter = l | ||||
| 	authboss.a.PrimaryID = authboss.StoreUsername | ||||
| 	authboss.a.LogWriter = l | ||||
|  | ||||
| 	r = &Recover{} | ||||
| 	if err := r.Initialize(); err != nil { | ||||
| @@ -62,8 +62,8 @@ func TestRecover(t *testing.T) { | ||||
| 	r, _, _ := testSetup() | ||||
|  | ||||
| 	storage := r.Storage() | ||||
| 	if storage[authboss.Cfg.PrimaryID] != authboss.String { | ||||
| 		t.Error("Expected storage KV:", authboss.Cfg.PrimaryID, authboss.String) | ||||
| 	if storage[authboss.a.PrimaryID] != authboss.String { | ||||
| 		t.Error("Expected storage KV:", authboss.a.PrimaryID, authboss.String) | ||||
| 	} | ||||
| 	if storage[authboss.StoreEmail] != authboss.String { | ||||
| 		t.Error("Expected storage KV:", authboss.StoreEmail, authboss.String) | ||||
| @@ -103,10 +103,10 @@ func TestRecover_startHandlerFunc_GET(t *testing.T) { | ||||
| 	if !strings.Contains(body, `<form action="recover"`) { | ||||
| 		t.Error("Should have rendered a form") | ||||
| 	} | ||||
| 	if !strings.Contains(body, `name="`+authboss.Cfg.PrimaryID) { | ||||
| 	if !strings.Contains(body, `name="`+authboss.a.PrimaryID) { | ||||
| 		t.Error("Form should contain the primary ID field") | ||||
| 	} | ||||
| 	if !strings.Contains(body, `name="confirm_`+authboss.Cfg.PrimaryID) { | ||||
| 	if !strings.Contains(body, `name="confirm_`+authboss.a.PrimaryID) { | ||||
| 		t.Error("Form should contain the confirm primary ID field") | ||||
| 	} | ||||
| } | ||||
| @@ -141,7 +141,7 @@ func TestRecover_startHandlerFunc_POST_UserNotFound(t *testing.T) { | ||||
| 		t.Error("Expected ErrAndRedirect error") | ||||
| 	} | ||||
|  | ||||
| 	if rerr.Location != authboss.Cfg.RecoverOKPath { | ||||
| 	if rerr.Location != authboss.a.RecoverOKPath { | ||||
| 		t.Error("Unexpected location:", rerr.Location) | ||||
| 	} | ||||
|  | ||||
| @@ -187,7 +187,7 @@ func TestRecover_startHandlerFunc_POST(t *testing.T) { | ||||
| 	} | ||||
|  | ||||
| 	loc := w.Header().Get("Location") | ||||
| 	if loc != authboss.Cfg.RecoverOKPath { | ||||
| 	if loc != authboss.a.RecoverOKPath { | ||||
| 		t.Error("Unexpected location:", loc) | ||||
| 	} | ||||
|  | ||||
| @@ -237,7 +237,7 @@ func TestRecover_sendRecoverMail_FailToSend(t *testing.T) { | ||||
|  | ||||
| 	mailer := mocks.NewMockMailer() | ||||
| 	mailer.SendErr = "failed to send" | ||||
| 	authboss.Cfg.Mailer = mailer | ||||
| 	authboss.a.Mailer = mailer | ||||
|  | ||||
| 	a.sendRecoverEmail("", "") | ||||
|  | ||||
| @@ -250,9 +250,9 @@ func TestRecover_sendRecoverEmail(t *testing.T) { | ||||
| 	a, _, _ := testSetup() | ||||
|  | ||||
| 	mailer := mocks.NewMockMailer() | ||||
| 	authboss.Cfg.EmailSubjectPrefix = "foo " | ||||
| 	authboss.Cfg.RootURL = "bar" | ||||
| 	authboss.Cfg.Mailer = mailer | ||||
| 	authboss.a.EmailSubjectPrefix = "foo " | ||||
| 	authboss.a.RootURL = "bar" | ||||
| 	authboss.a.Mailer = mailer | ||||
|  | ||||
| 	a.sendRecoverEmail("a@b.c", "abc=") | ||||
| 	if len(mailer.Last.To) != 1 { | ||||
| @@ -265,7 +265,7 @@ func TestRecover_sendRecoverEmail(t *testing.T) { | ||||
| 		t.Error("Unexpected subject:", mailer.Last.Subject) | ||||
| 	} | ||||
|  | ||||
| 	url := fmt.Sprintf("%s/recover/complete?token=abc=", authboss.Cfg.RootURL) | ||||
| 	url := fmt.Sprintf("%s/recover/complete?token=abc=", authboss.a.RootURL) | ||||
| 	if !strings.Contains(mailer.Last.HTMLBody, url) { | ||||
| 		t.Error("Expected HTMLBody to contain url:", url) | ||||
| 	} | ||||
| @@ -377,12 +377,12 @@ func TestRecover_completeHandlerFunc_POST_VerificationFails(t *testing.T) { | ||||
| func TestRecover_completeHandlerFunc_POST(t *testing.T) { | ||||
| 	rec, storer, _ := testSetup() | ||||
|  | ||||
| 	storer.Users["john"] = authboss.Attributes{authboss.Cfg.PrimaryID: "john", StoreRecoverToken: testStdBase64Token, StoreRecoverTokenExpiry: time.Now().Add(1 * time.Hour), authboss.StorePassword: "asdf"} | ||||
| 	storer.Users["john"] = authboss.Attributes{authboss.a.PrimaryID: "john", StoreRecoverToken: testStdBase64Token, StoreRecoverTokenExpiry: time.Now().Add(1 * time.Hour), authboss.StorePassword: "asdf"} | ||||
|  | ||||
| 	cbCalled := false | ||||
|  | ||||
| 	authboss.Cfg.Callbacks = authboss.NewCallbacks() | ||||
| 	authboss.Cfg.Callbacks.After(authboss.EventPasswordReset, func(_ *authboss.Context) error { | ||||
| 	authboss.a.Callbacks = authboss.NewCallbacks() | ||||
| 	authboss.a.Callbacks.After(authboss.EventPasswordReset, func(_ *authboss.Context) error { | ||||
| 		cbCalled = true | ||||
| 		return nil | ||||
| 	}) | ||||
| @@ -421,7 +421,7 @@ func TestRecover_completeHandlerFunc_POST(t *testing.T) { | ||||
| 	} | ||||
|  | ||||
| 	loc := w.Header().Get("Location") | ||||
| 	if loc != authboss.Cfg.AuthLogoutOKPath { | ||||
| 	if loc != authboss.a.AuthLogoutOKPath { | ||||
| 		t.Error("Unexpected location:", loc) | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -33,15 +33,15 @@ type Register struct { | ||||
|  | ||||
| // Initialize the module. | ||||
| func (r *Register) Initialize() (err error) { | ||||
| 	if authboss.Cfg.Storer == nil { | ||||
| 	if authboss.a.Storer == nil { | ||||
| 		return errors.New("register: Need a RegisterStorer") | ||||
| 	} | ||||
|  | ||||
| 	if _, ok := authboss.Cfg.Storer.(RegisterStorer); !ok { | ||||
| 	if _, ok := authboss.a.Storer.(RegisterStorer); !ok { | ||||
| 		return errors.New("register: RegisterStorer required for register functionality") | ||||
| 	} | ||||
|  | ||||
| 	if r.templates, err = response.LoadTemplates(authboss.Cfg.Layout, authboss.Cfg.ViewsPath, tplRegister); err != nil { | ||||
| 	if r.templates, err = response.LoadTemplates(authboss.a.Layout, authboss.a.ViewsPath, tplRegister); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| @@ -58,7 +58,7 @@ func (r *Register) Routes() authboss.RouteTable { | ||||
| // Storage returns storage requirements. | ||||
| func (r *Register) Storage() authboss.StorageOptions { | ||||
| 	return authboss.StorageOptions{ | ||||
| 		authboss.Cfg.PrimaryID: authboss.String, | ||||
| 		authboss.a.PrimaryID:   authboss.String, | ||||
| 		authboss.StorePassword: authboss.String, | ||||
| 	} | ||||
| } | ||||
| @@ -67,7 +67,7 @@ func (reg *Register) registerHandler(ctx *authboss.Context, w http.ResponseWrite | ||||
| 	switch r.Method { | ||||
| 	case "GET": | ||||
| 		data := authboss.HTMLData{ | ||||
| 			"primaryID":      authboss.Cfg.PrimaryID, | ||||
| 			"primaryID":      authboss.a.PrimaryID, | ||||
| 			"primaryIDValue": "", | ||||
| 		} | ||||
| 		return reg.templates.Render(ctx, w, r, tplRegister, data) | ||||
| @@ -78,15 +78,15 @@ func (reg *Register) registerHandler(ctx *authboss.Context, w http.ResponseWrite | ||||
| } | ||||
|  | ||||
| func (reg *Register) registerPostHandler(ctx *authboss.Context, w http.ResponseWriter, r *http.Request) error { | ||||
| 	key, _ := ctx.FirstPostFormValue(authboss.Cfg.PrimaryID) | ||||
| 	key, _ := ctx.FirstPostFormValue(authboss.a.PrimaryID) | ||||
| 	password, _ := ctx.FirstPostFormValue(authboss.StorePassword) | ||||
|  | ||||
| 	policies := authboss.FilterValidators(authboss.Cfg.Policies, authboss.Cfg.PrimaryID, authboss.StorePassword) | ||||
| 	validationErrs := ctx.Validate(policies, authboss.Cfg.ConfirmFields...) | ||||
| 	policies := authboss.FilterValidators(authboss.a.Policies, authboss.a.PrimaryID, authboss.StorePassword) | ||||
| 	validationErrs := ctx.Validate(policies, authboss.a.ConfirmFields...) | ||||
|  | ||||
| 	if len(validationErrs) != 0 { | ||||
| 		data := authboss.HTMLData{ | ||||
| 			"primaryID":      authboss.Cfg.PrimaryID, | ||||
| 			"primaryID":      authboss.a.PrimaryID, | ||||
| 			"primaryIDValue": key, | ||||
| 			"errs":           validationErrs.Map(), | ||||
| 		} | ||||
| @@ -99,30 +99,30 @@ func (reg *Register) registerPostHandler(ctx *authboss.Context, w http.ResponseW | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	pass, err := bcrypt.GenerateFromPassword([]byte(password), authboss.Cfg.BCryptCost) | ||||
| 	pass, err := bcrypt.GenerateFromPassword([]byte(password), authboss.a.BCryptCost) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	attr[authboss.Cfg.PrimaryID] = key | ||||
| 	attr[authboss.a.PrimaryID] = key | ||||
| 	attr[authboss.StorePassword] = string(pass) | ||||
| 	ctx.User = attr | ||||
|  | ||||
| 	if err := authboss.Cfg.Storer.(RegisterStorer).Create(key, attr); err != nil { | ||||
| 	if err := authboss.a.Storer.(RegisterStorer).Create(key, attr); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	if err := authboss.Cfg.Callbacks.FireAfter(authboss.EventRegister, ctx); err != nil { | ||||
| 	if err := authboss.a.Callbacks.FireAfter(authboss.EventRegister, ctx); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	if authboss.IsLoaded("confirm") { | ||||
| 		response.Redirect(ctx, w, r, authboss.Cfg.RegisterOKPath, "Account successfully created, please verify your e-mail address.", "", true) | ||||
| 		response.Redirect(ctx, w, r, authboss.a.RegisterOKPath, "Account successfully created, please verify your e-mail address.", "", true) | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	ctx.SessionStorer.Put(authboss.SessionKey, key) | ||||
| 	response.Redirect(ctx, w, r, authboss.Cfg.RegisterOKPath, "Account successfully created, you are now logged in.", "", true) | ||||
| 	response.Redirect(ctx, w, r, authboss.a.RegisterOKPath, "Account successfully created, you are now logged in.", "", true) | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|   | ||||
| @@ -15,14 +15,14 @@ import ( | ||||
|  | ||||
| func setup() *Register { | ||||
| 	authboss.Cfg = authboss.NewConfig() | ||||
| 	authboss.Cfg.RegisterOKPath = "/regsuccess" | ||||
| 	authboss.Cfg.Layout = template.Must(template.New("").Parse(`{{template "authboss" .}}`)) | ||||
| 	authboss.Cfg.XSRFName = "xsrf" | ||||
| 	authboss.Cfg.XSRFMaker = func(_ http.ResponseWriter, _ *http.Request) string { | ||||
| 	authboss.a.RegisterOKPath = "/regsuccess" | ||||
| 	authboss.a.Layout = template.Must(template.New("").Parse(`{{template "authboss" .}}`)) | ||||
| 	authboss.a.XSRFName = "xsrf" | ||||
| 	authboss.a.XSRFMaker = func(_ http.ResponseWriter, _ *http.Request) string { | ||||
| 		return "xsrfvalue" | ||||
| 	} | ||||
| 	authboss.Cfg.ConfirmFields = []string{"password", "confirm_password"} | ||||
| 	authboss.Cfg.Storer = mocks.NewMockStorer() | ||||
| 	authboss.a.ConfirmFields = []string{"password", "confirm_password"} | ||||
| 	authboss.a.Storer = mocks.NewMockStorer() | ||||
|  | ||||
| 	reg := Register{} | ||||
| 	if err := reg.Initialize(); err != nil { | ||||
| @@ -34,7 +34,7 @@ func setup() *Register { | ||||
|  | ||||
| func TestRegister(t *testing.T) { | ||||
| 	authboss.Cfg = authboss.NewConfig() | ||||
| 	authboss.Cfg.Storer = mocks.NewMockStorer() | ||||
| 	authboss.a.Storer = mocks.NewMockStorer() | ||||
| 	r := Register{} | ||||
|  | ||||
| 	if err := r.Initialize(); err != nil { | ||||
| @@ -46,7 +46,7 @@ func TestRegister(t *testing.T) { | ||||
| 	} | ||||
|  | ||||
| 	sto := r.Storage() | ||||
| 	if sto[authboss.Cfg.PrimaryID] != authboss.String { | ||||
| 	if sto[authboss.a.PrimaryID] != authboss.String { | ||||
| 		t.Error("Wanted primary ID to be a string.") | ||||
| 	} | ||||
| 	if sto[authboss.StorePassword] != authboss.String { | ||||
| @@ -76,7 +76,7 @@ func TestRegisterGet(t *testing.T) { | ||||
|  | ||||
| 	if str := w.Body.String(); !strings.Contains(str, "<form") { | ||||
| 		t.Error("It should have rendered a nice form:", str) | ||||
| 	} else if !strings.Contains(str, `name="`+authboss.Cfg.PrimaryID) { | ||||
| 	} else if !strings.Contains(str, `name="`+authboss.a.PrimaryID) { | ||||
| 		t.Error("Form should contain the primary ID:", str) | ||||
| 	} | ||||
| } | ||||
| @@ -88,7 +88,7 @@ func TestRegisterPostValidationErrs(t *testing.T) { | ||||
| 	vals := url.Values{} | ||||
|  | ||||
| 	email := "email@address.com" | ||||
| 	vals.Set(authboss.Cfg.PrimaryID, email) | ||||
| 	vals.Set(authboss.a.PrimaryID, email) | ||||
| 	vals.Set(authboss.StorePassword, "pass") | ||||
| 	vals.Set(authboss.ConfirmPrefix+authboss.StorePassword, "pass2") | ||||
|  | ||||
| @@ -113,7 +113,7 @@ func TestRegisterPostValidationErrs(t *testing.T) { | ||||
| 		t.Error("Confirm password should have an error:", str) | ||||
| 	} | ||||
|  | ||||
| 	if _, err := authboss.Cfg.Storer.Get(email); err != authboss.ErrUserNotFound { | ||||
| 	if _, err := authboss.a.Storer.Get(email); err != authboss.ErrUserNotFound { | ||||
| 		t.Error("The user should not have been saved.") | ||||
| 	} | ||||
| } | ||||
| @@ -125,7 +125,7 @@ func TestRegisterPostSuccess(t *testing.T) { | ||||
| 	vals := url.Values{} | ||||
|  | ||||
| 	email := "email@address.com" | ||||
| 	vals.Set(authboss.Cfg.PrimaryID, email) | ||||
| 	vals.Set(authboss.a.PrimaryID, email) | ||||
| 	vals.Set(authboss.StorePassword, "pass") | ||||
| 	vals.Set(authboss.ConfirmPrefix+authboss.StorePassword, "pass") | ||||
|  | ||||
| @@ -142,17 +142,17 @@ func TestRegisterPostSuccess(t *testing.T) { | ||||
| 		t.Error("It should have written a redirect:", w.Code) | ||||
| 	} | ||||
|  | ||||
| 	if loc := w.Header().Get("Location"); loc != authboss.Cfg.RegisterOKPath { | ||||
| 	if loc := w.Header().Get("Location"); loc != authboss.a.RegisterOKPath { | ||||
| 		t.Error("Redirected to the wrong location", loc) | ||||
| 	} | ||||
|  | ||||
| 	user, err := authboss.Cfg.Storer.Get(email) | ||||
| 	user, err := authboss.a.Storer.Get(email) | ||||
| 	if err == authboss.ErrUserNotFound { | ||||
| 		t.Error("The user have been saved.") | ||||
| 	} | ||||
|  | ||||
| 	attrs := authboss.Unbind(user) | ||||
| 	if e, err := attrs.StringErr(authboss.Cfg.PrimaryID); err != nil { | ||||
| 	if e, err := attrs.StringErr(authboss.a.PrimaryID); err != nil { | ||||
| 		t.Error(err) | ||||
| 	} else if e != email { | ||||
| 		t.Errorf("Email was not set properly, want: %s, got: %s", email, e) | ||||
|   | ||||
| @@ -47,20 +47,20 @@ type Remember struct{} | ||||
|  | ||||
| // Initialize module | ||||
| func (r *Remember) Initialize() error { | ||||
| 	if authboss.Cfg.Storer == nil && authboss.Cfg.OAuth2Storer == nil { | ||||
| 	if authboss.a.Storer == nil && authboss.a.OAuth2Storer == nil { | ||||
| 		return errors.New("remember: Need a RememberStorer") | ||||
| 	} | ||||
|  | ||||
| 	if _, ok := authboss.Cfg.Storer.(RememberStorer); !ok { | ||||
| 		if _, ok := authboss.Cfg.OAuth2Storer.(RememberStorer); !ok { | ||||
| 	if _, ok := authboss.a.Storer.(RememberStorer); !ok { | ||||
| 		if _, ok := authboss.a.OAuth2Storer.(RememberStorer); !ok { | ||||
| 			return errors.New("remember: RememberStorer required for remember functionality") | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	authboss.Cfg.Callbacks.Before(authboss.EventGetUserSession, r.auth) | ||||
| 	authboss.Cfg.Callbacks.After(authboss.EventAuth, r.afterAuth) | ||||
| 	authboss.Cfg.Callbacks.After(authboss.EventOAuth, r.afterOAuth) | ||||
| 	authboss.Cfg.Callbacks.After(authboss.EventPasswordReset, r.afterPassword) | ||||
| 	authboss.a.Callbacks.Before(authboss.EventGetUserSession, r.auth) | ||||
| 	authboss.a.Callbacks.After(authboss.EventAuth, r.afterAuth) | ||||
| 	authboss.a.Callbacks.After(authboss.EventOAuth, r.afterOAuth) | ||||
| 	authboss.a.Callbacks.After(authboss.EventPasswordReset, r.afterPassword) | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
| @@ -73,7 +73,7 @@ func (r *Remember) Routes() authboss.RouteTable { | ||||
| // Storage requirements | ||||
| func (r *Remember) Storage() authboss.StorageOptions { | ||||
| 	return authboss.StorageOptions{ | ||||
| 		authboss.Cfg.PrimaryID: authboss.String, | ||||
| 		authboss.a.PrimaryID: authboss.String, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @@ -87,7 +87,7 @@ func (r *Remember) afterAuth(ctx *authboss.Context) error { | ||||
| 		return errUserMissing | ||||
| 	} | ||||
|  | ||||
| 	key, err := ctx.User.StringErr(authboss.Cfg.PrimaryID) | ||||
| 	key, err := ctx.User.StringErr(authboss.a.PrimaryID) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @@ -146,7 +146,7 @@ func (r *Remember) afterPassword(ctx *authboss.Context) error { | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	id, ok := ctx.User.String(authboss.Cfg.PrimaryID) | ||||
| 	id, ok := ctx.User.String(authboss.a.PrimaryID) | ||||
| 	if !ok { | ||||
| 		return nil | ||||
| 	} | ||||
| @@ -154,8 +154,8 @@ func (r *Remember) afterPassword(ctx *authboss.Context) error { | ||||
| 	ctx.CookieStorer.Del(authboss.CookieRemember) | ||||
|  | ||||
| 	var storer RememberStorer | ||||
| 	if storer, ok = authboss.Cfg.Storer.(RememberStorer); !ok { | ||||
| 		if storer, ok = authboss.Cfg.OAuth2Storer.(RememberStorer); !ok { | ||||
| 	if storer, ok = authboss.a.Storer.(RememberStorer); !ok { | ||||
| 		if storer, ok = authboss.a.OAuth2Storer.(RememberStorer); !ok { | ||||
| 			return nil | ||||
| 		} | ||||
| 	} | ||||
| @@ -181,8 +181,8 @@ func (r *Remember) new(cstorer authboss.ClientStorer, storageKey string) (string | ||||
|  | ||||
| 	var storer RememberStorer | ||||
| 	var ok bool | ||||
| 	if storer, ok = authboss.Cfg.Storer.(RememberStorer); !ok { | ||||
| 		storer, ok = authboss.Cfg.OAuth2Storer.(RememberStorer) | ||||
| 	if storer, ok = authboss.a.Storer.(RememberStorer); !ok { | ||||
| 		storer, ok = authboss.a.OAuth2Storer.(RememberStorer) | ||||
| 	} | ||||
|  | ||||
| 	// Save the token in the DB | ||||
| @@ -226,8 +226,8 @@ func (r *Remember) auth(ctx *authboss.Context) (authboss.Interrupt, error) { | ||||
| 	sum := md5.Sum(token) | ||||
|  | ||||
| 	var storer RememberStorer | ||||
| 	if storer, ok = authboss.Cfg.Storer.(RememberStorer); !ok { | ||||
| 		storer, ok = authboss.Cfg.OAuth2Storer.(RememberStorer) | ||||
| 	if storer, ok = authboss.a.Storer.(RememberStorer); !ok { | ||||
| 		storer, ok = authboss.a.OAuth2Storer.(RememberStorer) | ||||
| 	} | ||||
|  | ||||
| 	err = storer.UseToken(givenKey, base64.StdEncoding.EncodeToString(sum[:])) | ||||
|   | ||||
| @@ -19,13 +19,13 @@ func TestInitialize(t *testing.T) { | ||||
| 		t.Error("Expected error about token storers.") | ||||
| 	} | ||||
|  | ||||
| 	authboss.Cfg.Storer = mocks.MockFailStorer{} | ||||
| 	authboss.a.Storer = mocks.MockFailStorer{} | ||||
| 	err = r.Initialize() | ||||
| 	if err == nil { | ||||
| 		t.Error("Expected error about token storers.") | ||||
| 	} | ||||
|  | ||||
| 	authboss.Cfg.Storer = mocks.NewMockStorer() | ||||
| 	authboss.a.Storer = mocks.NewMockStorer() | ||||
| 	err = r.Initialize() | ||||
| 	if err != nil { | ||||
| 		t.Error("Unexpected error:", err) | ||||
| @@ -36,7 +36,7 @@ func TestAfterAuth(t *testing.T) { | ||||
| 	r := Remember{} | ||||
| 	authboss.NewConfig() | ||||
| 	storer := mocks.NewMockStorer() | ||||
| 	authboss.Cfg.Storer = storer | ||||
| 	authboss.a.Storer = storer | ||||
|  | ||||
| 	cookies := mocks.NewMockClientStorer() | ||||
| 	session := mocks.NewMockClientStorer() | ||||
| @@ -54,7 +54,7 @@ func TestAfterAuth(t *testing.T) { | ||||
|  | ||||
| 	ctx.SessionStorer = session | ||||
| 	ctx.CookieStorer = cookies | ||||
| 	ctx.User = authboss.Attributes{authboss.Cfg.PrimaryID: "test@email.com"} | ||||
| 	ctx.User = authboss.Attributes{authboss.a.PrimaryID: "test@email.com"} | ||||
|  | ||||
| 	if err := r.afterAuth(ctx); err != nil { | ||||
| 		t.Error(err) | ||||
| @@ -69,7 +69,7 @@ func TestAfterOAuth(t *testing.T) { | ||||
| 	r := Remember{} | ||||
| 	authboss.NewConfig() | ||||
| 	storer := mocks.NewMockStorer() | ||||
| 	authboss.Cfg.Storer = storer | ||||
| 	authboss.a.Storer = storer | ||||
|  | ||||
| 	cookies := mocks.NewMockClientStorer() | ||||
| 	session := mocks.NewMockClientStorer(authboss.SessionOAuth2Params, `{"rm":"true"}`) | ||||
| @@ -108,14 +108,14 @@ func TestAfterPasswordReset(t *testing.T) { | ||||
| 	id := "test@email.com" | ||||
|  | ||||
| 	storer := mocks.NewMockStorer() | ||||
| 	authboss.Cfg.Storer = storer | ||||
| 	authboss.a.Storer = storer | ||||
| 	session := mocks.NewMockClientStorer() | ||||
| 	cookies := mocks.NewMockClientStorer() | ||||
| 	storer.Tokens[id] = []string{"one", "two"} | ||||
| 	cookies.Values[authboss.CookieRemember] = "token" | ||||
|  | ||||
| 	ctx := authboss.NewContext() | ||||
| 	ctx.User = authboss.Attributes{authboss.Cfg.PrimaryID: id} | ||||
| 	ctx.User = authboss.Attributes{authboss.a.PrimaryID: id} | ||||
| 	ctx.SessionStorer = session | ||||
| 	ctx.CookieStorer = cookies | ||||
|  | ||||
| @@ -136,7 +136,7 @@ func TestNew(t *testing.T) { | ||||
| 	r := &Remember{} | ||||
| 	authboss.NewConfig() | ||||
| 	storer := mocks.NewMockStorer() | ||||
| 	authboss.Cfg.Storer = storer | ||||
| 	authboss.a.Storer = storer | ||||
| 	cookies := mocks.NewMockClientStorer() | ||||
|  | ||||
| 	key := "tester" | ||||
| @@ -165,7 +165,7 @@ func TestAuth(t *testing.T) { | ||||
| 	r := &Remember{} | ||||
| 	authboss.NewConfig() | ||||
| 	storer := mocks.NewMockStorer() | ||||
| 	authboss.Cfg.Storer = storer | ||||
| 	authboss.a.Storer = storer | ||||
|  | ||||
| 	cookies := mocks.NewMockClientStorer() | ||||
| 	session := mocks.NewMockClientStorer() | ||||
|   | ||||
							
								
								
									
										29
									
								
								router.go
									
									
									
									
									
								
							
							
						
						
									
										29
									
								
								router.go
									
									
									
									
									
								
							| @@ -14,19 +14,19 @@ type HandlerFunc func(*Context, http.ResponseWriter, *http.Request) error | ||||
| type RouteTable map[string]HandlerFunc | ||||
|  | ||||
| // NewRouter returns a router to be mounted at some mountpoint. | ||||
| func NewRouter() http.Handler { | ||||
| func (a *Authboss) NewRouter() http.Handler { | ||||
| 	mux := http.NewServeMux() | ||||
|  | ||||
| 	for name, mod := range modules { | ||||
| 		for route, handler := range mod.Routes() { | ||||
| 			fmt.Fprintf(Cfg.LogWriter, "%-10s Route: %s\n", "["+name+"]", path.Join(Cfg.MountPath, route)) | ||||
| 			mux.Handle(path.Join(Cfg.MountPath, route), contextRoute{handler}) | ||||
| 			fmt.Fprintf(a.LogWriter, "%-10s Route: %s\n", "["+name+"]", path.Join(a.MountPath, route)) | ||||
| 			mux.Handle(path.Join(a.MountPath, route), contextRoute{a, handler}) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { | ||||
| 		if Cfg.NotFoundHandler != nil { | ||||
| 			Cfg.NotFoundHandler.ServeHTTP(w, r) | ||||
| 		if a.NotFoundHandler != nil { | ||||
| 			a.NotFoundHandler.ServeHTTP(w, r) | ||||
| 		} else { | ||||
| 			w.WriteHeader(http.StatusNotFound) | ||||
| 			io.WriteString(w, "404 Page not found") | ||||
| @@ -37,25 +37,26 @@ func NewRouter() http.Handler { | ||||
| } | ||||
|  | ||||
| type contextRoute struct { | ||||
| 	*Authboss | ||||
| 	fn HandlerFunc | ||||
| } | ||||
|  | ||||
| func (c contextRoute) ServeHTTP(w http.ResponseWriter, r *http.Request) { | ||||
| 	ctx, err := ContextFromRequest(r) | ||||
| 	ctx, err := c.Authboss.ContextFromRequest(r) | ||||
| 	if err != nil { | ||||
| 		fmt.Fprintf(Cfg.LogWriter, "route: Malformed request, could not create context: %v", err) | ||||
| 		fmt.Fprintf(c.LogWriter, "route: Malformed request, could not create context: %v", err) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	ctx.CookieStorer = clientStoreWrapper{Cfg.CookieStoreMaker(w, r)} | ||||
| 	ctx.SessionStorer = clientStoreWrapper{Cfg.SessionStoreMaker(w, r)} | ||||
| 	ctx.CookieStorer = clientStoreWrapper{c.CookieStoreMaker(w, r)} | ||||
| 	ctx.SessionStorer = clientStoreWrapper{c.SessionStoreMaker(w, r)} | ||||
|  | ||||
| 	err = c.fn(ctx, w, r) | ||||
| 	if err == nil { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	fmt.Fprintf(Cfg.LogWriter, "Error Occurred at %s: %v", r.URL.Path, err) | ||||
| 	fmt.Fprintf(c.LogWriter, "Error Occurred at %s: %v", r.URL.Path, err) | ||||
|  | ||||
| 	switch e := err.(type) { | ||||
| 	case ErrAndRedirect: | ||||
| @@ -67,15 +68,15 @@ func (c contextRoute) ServeHTTP(w http.ResponseWriter, r *http.Request) { | ||||
| 		} | ||||
| 		http.Redirect(w, r, e.Location, http.StatusFound) | ||||
| 	case ClientDataErr: | ||||
| 		if Cfg.BadRequestHandler != nil { | ||||
| 			Cfg.BadRequestHandler.ServeHTTP(w, r) | ||||
| 		if c.BadRequestHandler != nil { | ||||
| 			c.BadRequestHandler.ServeHTTP(w, r) | ||||
| 		} else { | ||||
| 			w.WriteHeader(http.StatusBadRequest) | ||||
| 			io.WriteString(w, "400 Bad request") | ||||
| 		} | ||||
| 	default: | ||||
| 		if Cfg.ErrorHandler != nil { | ||||
| 			Cfg.ErrorHandler.ServeHTTP(w, r) | ||||
| 		if c.ErrorHandler != nil { | ||||
| 			c.ErrorHandler.ServeHTTP(w, r) | ||||
| 		} else { | ||||
| 			w.WriteHeader(http.StatusInternalServerError) | ||||
| 			io.WriteString(w, "500 An error has occurred") | ||||
|   | ||||
| @@ -9,34 +9,31 @@ import ( | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| type testRouterMod struct { | ||||
| 	handler HandlerFunc | ||||
| 	routes  RouteTable | ||||
| type testRouterModule struct { | ||||
| 	routes RouteTable | ||||
| } | ||||
|  | ||||
| func (t testRouterMod) Initialize() error       { return nil } | ||||
| func (t testRouterMod) Routes() RouteTable      { return t.routes } | ||||
| func (t testRouterMod) Storage() StorageOptions { return nil } | ||||
| func (t testRouterModule) Initialize(ab *Authboss) error { return nil } | ||||
| func (t testRouterModule) Routes() RouteTable            { return t.routes } | ||||
| func (t testRouterModule) Storage() StorageOptions       { return nil } | ||||
|  | ||||
| func testRouterSetup() (http.Handler, *bytes.Buffer) { | ||||
| 	Cfg = NewConfig() | ||||
| 	Cfg.MountPath = "/prefix" | ||||
| 	Cfg.SessionStoreMaker = func(w http.ResponseWriter, r *http.Request) ClientStorer { return mockClientStore{} } | ||||
| 	Cfg.CookieStoreMaker = func(w http.ResponseWriter, r *http.Request) ClientStorer { return mockClientStore{} } | ||||
| func testRouterSetup() (*Authboss, http.Handler, *bytes.Buffer) { | ||||
| 	ab := New() | ||||
| 	ab.MountPath = "/prefix" | ||||
| 	ab.SessionStoreMaker = func(w http.ResponseWriter, r *http.Request) ClientStorer { return mockClientStore{} } | ||||
| 	ab.CookieStoreMaker = func(w http.ResponseWriter, r *http.Request) ClientStorer { return mockClientStore{} } | ||||
| 	logger := &bytes.Buffer{} | ||||
| 	Cfg.LogWriter = logger | ||||
| 	ab.LogWriter = logger | ||||
|  | ||||
| 	return NewRouter(), logger | ||||
| 	return ab, ab.NewRouter(), logger | ||||
| } | ||||
|  | ||||
| // testRouterCallbackSetup is NOT safe for use by multiple goroutines, don't use parallel | ||||
| func testRouterCallbackSetup(path string, h HandlerFunc) (w *httptest.ResponseRecorder, r *http.Request) { | ||||
| 	modules = map[string]Modularizer{ | ||||
| 		"test": testRouterMod{ | ||||
| 			routes: map[string]HandlerFunc{ | ||||
| 				path: h, | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
| 	modules = map[string]Modularizer{} | ||||
| 	RegisterModule("testrouter", testRouterModule{ | ||||
| 		routes: map[string]HandlerFunc{path: h}, | ||||
| 	}) | ||||
|  | ||||
| 	w = httptest.NewRecorder() | ||||
| 	r, _ = http.NewRequest("GET", "http://localhost/prefix"+path, nil) | ||||
| @@ -52,7 +49,7 @@ func TestRouter(t *testing.T) { | ||||
| 		return nil | ||||
| 	}) | ||||
|  | ||||
| 	router, _ := testRouterSetup() | ||||
| 	_, router, _ := testRouterSetup() | ||||
|  | ||||
| 	router.ServeHTTP(w, r) | ||||
|  | ||||
| @@ -62,7 +59,7 @@ func TestRouter(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func TestRouter_NotFound(t *testing.T) { | ||||
| 	router, _ := testRouterSetup() | ||||
| 	ab, router, _ := testRouterSetup() | ||||
| 	w := httptest.NewRecorder() | ||||
| 	r, _ := http.NewRequest("GET", "http://localhost/wat", nil) | ||||
|  | ||||
| @@ -75,7 +72,7 @@ func TestRouter_NotFound(t *testing.T) { | ||||
| 	} | ||||
|  | ||||
| 	called := false | ||||
| 	Cfg.NotFoundHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 	ab.NotFoundHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		called = true | ||||
| 	}) | ||||
|  | ||||
| @@ -93,7 +90,7 @@ func TestRouter_BadRequest(t *testing.T) { | ||||
| 		}, | ||||
| 	) | ||||
|  | ||||
| 	router, logger := testRouterSetup() | ||||
| 	ab, router, logger := testRouterSetup() | ||||
| 	logger.Reset() | ||||
| 	router.ServeHTTP(w, r) | ||||
|  | ||||
| @@ -109,7 +106,7 @@ func TestRouter_BadRequest(t *testing.T) { | ||||
| 	} | ||||
|  | ||||
| 	called := false | ||||
| 	Cfg.BadRequestHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 	ab.BadRequestHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		called = true | ||||
| 	}) | ||||
|  | ||||
| @@ -132,7 +129,7 @@ func TestRouter_Error(t *testing.T) { | ||||
| 		}, | ||||
| 	) | ||||
|  | ||||
| 	router, logger := testRouterSetup() | ||||
| 	ab, router, logger := testRouterSetup() | ||||
| 	logger.Reset() | ||||
| 	router.ServeHTTP(w, r) | ||||
|  | ||||
| @@ -148,7 +145,7 @@ func TestRouter_Error(t *testing.T) { | ||||
| 	} | ||||
|  | ||||
| 	called := false | ||||
| 	Cfg.ErrorHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 	ab.ErrorHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		called = true | ||||
| 	}) | ||||
|  | ||||
| @@ -177,10 +174,10 @@ func TestRouter_Redirect(t *testing.T) { | ||||
| 		}, | ||||
| 	) | ||||
|  | ||||
| 	router, logger := testRouterSetup() | ||||
| 	ab, router, logger := testRouterSetup() | ||||
|  | ||||
| 	session := mockClientStore{} | ||||
| 	Cfg.SessionStoreMaker = func(w http.ResponseWriter, r *http.Request) ClientStorer { return session } | ||||
| 	ab.SessionStoreMaker = func(w http.ResponseWriter, r *http.Request) ClientStorer { return session } | ||||
|  | ||||
| 	logger.Reset() | ||||
| 	router.ServeHTTP(w, r) | ||||
|   | ||||
| @@ -72,6 +72,8 @@ func TestAttributeMeta_Names(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func TestAttributeMeta_Helpers(t *testing.T) { | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	now := time.Now() | ||||
| 	attr := Attributes{ | ||||
| 		"integer":   int64(5), | ||||
|   | ||||
| @@ -64,7 +64,8 @@ func TestErrorList_Map(t *testing.T) { | ||||
| func TestValidate(t *testing.T) { | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	ctx := mockRequestContext(StoreUsername, "john", StoreEmail, "john@john.com") | ||||
| 	ab := New() | ||||
| 	ctx := mockRequestContext(ab, StoreUsername, "john", StoreEmail, "john@john.com") | ||||
|  | ||||
| 	errList := ctx.Validate([]Validator{ | ||||
| 		mockValidator{ | ||||
| @@ -95,19 +96,20 @@ func TestValidate(t *testing.T) { | ||||
| func TestValidate_Confirm(t *testing.T) { | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	ctx := mockRequestContext(StoreUsername, "john", "confirmUsername", "johnny") | ||||
| 	ab := New() | ||||
| 	ctx := mockRequestContext(ab, StoreUsername, "john", "confirmUsername", "johnny") | ||||
| 	errs := ctx.Validate(nil, StoreUsername, "confirmUsername").Map() | ||||
| 	if errs["confirmUsername"][0] != "Does not match username" { | ||||
| 		t.Error("Expected a different error for confirmUsername:", errs["confirmUsername"][0]) | ||||
| 	} | ||||
|  | ||||
| 	ctx = mockRequestContext(StoreUsername, "john", "confirmUsername", "john") | ||||
| 	ctx = mockRequestContext(ab, StoreUsername, "john", "confirmUsername", "john") | ||||
| 	errs = ctx.Validate(nil, StoreUsername, "confirmUsername").Map() | ||||
| 	if len(errs) != 0 { | ||||
| 		t.Error("Expected no errors:", errs) | ||||
| 	} | ||||
|  | ||||
| 	ctx = mockRequestContext(StoreUsername, "john", "confirmUsername", "john") | ||||
| 	ctx = mockRequestContext(ab, StoreUsername, "john", "confirmUsername", "john") | ||||
| 	errs = ctx.Validate(nil, StoreUsername).Map() | ||||
| 	if len(errs) != 0 { | ||||
| 		t.Error("Expected no errors:", errs) | ||||
|   | ||||
| @@ -3,6 +3,8 @@ package authboss | ||||
| import "testing" | ||||
|  | ||||
| func TestHTMLData(t *testing.T) { | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	data := NewHTMLData("a", "b").MergeKV("c", "d").Merge(NewHTMLData("e", "f")) | ||||
| 	if data["a"].(string) != "b" { | ||||
| 		t.Error("A was wrong:", data["a"]) | ||||
| @@ -16,6 +18,8 @@ func TestHTMLData(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func TestHTMLData_Panics(t *testing.T) { | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	nPanics := 0 | ||||
| 	panicCount := func() { | ||||
| 		if r := recover(); r != nil { | ||||
|   | ||||
		Reference in New Issue
	
	Block a user