From ac3d2846f8e44ac9baeff0b1a1aee8cca265e20c Mon Sep 17 00:00:00 2001 From: Aaron L Date: Wed, 7 Mar 2018 15:13:06 -0800 Subject: [PATCH] Rewrite the remember module - Add context keys and storage pieces for remember --- context.go | 12 +- internal/mocks/mocks.go | 78 +------ remember/remember.go | 268 ++++++++++------------- remember/remember_test.go | 433 +++++++++++++++++++++++++------------- storage.go | 33 ++- 5 files changed, 451 insertions(+), 373 deletions(-) diff --git a/context.go b/context.go index e7adc6b..5f32a4a 100644 --- a/context.go +++ b/context.go @@ -19,6 +19,11 @@ const ( // map[string]interface{} (authboss.HTMLData) to pass to the // renderer CTXKeyData contextKey = "data" + + // CTXKeyRM is used to flag the remember me module to actually do the + // remembering, since this is a per-user operation, authentication modules + // need to supply this key if they wish to allow users to be remembered. + CTXKeyRM contextKey = "rm" ) func (c contextKey) String() string { @@ -79,12 +84,7 @@ func (a *Authboss) CurrentUserP(w http.ResponseWriter, r *http.Request) User { } func (a *Authboss) currentUser(ctx context.Context, pid string) (User, error) { - user, err := a.Storage.Server.Load(ctx, pid) - if err != nil { - return nil, err - } - - return user, nil + return a.Storage.Server.Load(ctx, pid) } // LoadCurrentUserID takes a pointer to a pointer to the request in order to diff --git a/internal/mocks/mocks.go b/internal/mocks/mocks.go index d47ec7e..faf6dcd 100644 --- a/internal/mocks/mocks.go +++ b/internal/mocks/mocks.go @@ -137,38 +137,25 @@ func (s *ServerStorer) LoadByRecoverToken(ctx context.Context, token string) (au return nil, authboss.ErrUserNotFound } -/* -// TODO(aarondl): What is this? -// AddToken for remember me -func (m *Storer) AddToken(key, token string) error { - if len(m.AddTokenErr) > 0 { - return errors.New(m.AddTokenErr) - } - - arr := m.Tokens[key] - m.Tokens[key] = append(arr, token) +// AddRememberToken for remember me +func (s *ServerStorer) AddRememberToken(key, token string) error { + arr := s.RMTokens[key] + s.RMTokens[key] = append(arr, token) return nil } -// DelTokens for a user -func (m *Storer) DelTokens(key string) error { - if len(m.DelTokensErr) > 0 { - return errors.New(m.DelTokensErr) - } - - delete(m.Tokens, key) +// DelRememberTokens for a user +func (s *ServerStorer) DelRememberTokens(key string) error { + delete(s.RMTokens, key) return nil } -// UseToken if it exists, deleting it in the process -func (m *Storer) UseToken(givenKey, token string) (err error) { - if len(m.UseTokenErr) > 0 { - return errors.New(m.UseTokenErr) - } - - if arr, ok := m.Tokens[givenKey]; ok { +// UseRememberToken if it exists, deleting it in the process +func (s *ServerStorer) UseRememberToken(givenKey, token string) (err error) { + if arr, ok := s.RMTokens[givenKey]; ok { for _, tok := range arr { if tok == token { + delete(s.RMTokens, givenKey) return nil } } @@ -177,49 +164,6 @@ func (m *Storer) UseToken(givenKey, token string) (err error) { return authboss.ErrTokenNotFound } -// RecoverUser by the token. -func (m *Storer) RecoverUser(token string) (result interface{}, err error) { - if len(m.RecoverUserErr) > 0 { - return nil, errors.New(m.RecoverUserErr) - } - - for _, user := range m.Users { - if user["recover_token"] == token { - - u := &User{} - if err = user.Bind(u, false); err != nil { - panic(err) - } - - return u, nil - } - } - - return nil, authboss.ErrUserNotFound -} - -// ConfirmUser via their token -func (m *Storer) ConfirmUser(confirmToken string) (result interface{}, err error) { - if len(m.ConfirmUserErr) > 0 { - return nil, errors.New(m.ConfirmUserErr) - } - - for _, user := range m.Users { - if user["confirm_token"] == confirmToken { - - u := &User{} - if err = user.Bind(u, false); err != nil { - panic(err) - } - - return u, nil - } - } - - return nil, authboss.ErrUserNotFound -} -*/ - // FailStorer is used for testing module initialize functions that recover more than the base storer type FailStorer struct { User diff --git a/remember/remember.go b/remember/remember.go index 57c880d..ea55643 100644 --- a/remember/remember.go +++ b/remember/remember.go @@ -1,12 +1,12 @@ -// Package remember implements persistent logins through the cookie storer. +// Package remember implements persistent logins using cookies package remember import ( "bytes" - "crypto/md5" "crypto/rand" + "crypto/sha512" "encoding/base64" - "encoding/json" + "net/http" "github.com/pkg/errors" @@ -14,30 +14,13 @@ import ( ) const ( - nRandBytes = 32 + nNonceSize = 32 ) var ( errUserMissing = errors.New("user not loaded in callback") ) -// RememberStorer must be implemented in order to satisfy the remember module's -// storage requirements. If the implementer is a typical database then -// the tokens should be stored in a separate table since they require a 1-n -// with the user for each device the user wishes to remain logged in on. -// -// Remember storer will look at both authboss's configured Storer and OAuth2Storer -// for compatibility. -type RememberStorer interface { - // AddToken saves a new token for the key. - AddToken(key, token string) error - // DelTokens removes all tokens for a given key. - DelTokens(key string) error - // UseToken finds the key-token pair, removes the entry in the store - // and returns nil. If the token could not be found return ErrTokenNotFound. - UseToken(givenKey, token string) (err error) -} - func init() { authboss.RegisterModule("remember", &Remember{}) } @@ -47,62 +30,45 @@ type Remember struct { *authboss.Authboss } -// Initialize module -func (r *Remember) Initialize(ab *authboss.Authboss) error { +// Init module +func (r *Remember) Init(ab *authboss.Authboss) error { r.Authboss = ab - if r.Storer != nil || r.OAuth2Storer != nil { - if _, ok := r.Storer.(RememberStorer); !ok { - if _, ok := r.OAuth2Storer.(RememberStorer); !ok { - return errors.New("rememberStorer required for remember functionality") - } - } - } else if r.StoreMaker == nil && r.OAuth2StoreMaker == nil { - return errors.New("need a rememberStorer") - } - - r.Events.Before(authboss.EventGetUserSession, r.auth) - r.Events.After(authboss.EventAuth, r.afterAuth) - r.Events.After(authboss.EventOAuth, r.afterOAuth) - r.Events.After(authboss.EventPasswordReset, r.afterPassword) + r.Events.After(authboss.EventAuth, r.RememberAfterAuth) + //TODO(aarondl): Rectify this once oauth2 is done + // r.Events.After(authboss.EventOAuth, r.RememberAfterAuth) + r.Events.After(authboss.EventPasswordReset, r.AfterPasswordReset) return nil } -// Routes for module -func (r *Remember) Routes() authboss.RouteTable { - return nil -} - -// Storage requirements -func (r *Remember) Storage() authboss.StorageOptions { - return authboss.StorageOptions{ - r.PrimaryID: authboss.String, - } -} - -// afterAuth is called after authentication is successful. -func (r *Remember) afterAuth(ctx *authboss.Context) error { - if val := ctx.Values[authboss.CookieRemember]; val != "true" { - return nil +// RememberAfterAuth creates a remember token and saves it in the user's cookies. +func (r *Remember) RememberAfterAuth(w http.ResponseWriter, req *http.Request, handled bool) (bool, error) { + rmIntf := req.Context().Value(authboss.CTXKeyRM) + if rmIntf == nil { + return false, nil + } else if rm, ok := rmIntf.(bool); ok && !rm { + return false, nil } - if ctx.User == nil { - return errUserMissing - } - - key, err := ctx.User.StringErr(r.PrimaryID) + user := r.Authboss.CurrentUserP(w, req) + hash, token, err := GenerateToken(user.GetPID()) if err != nil { - return err + return false, err } - if _, err := r.new(ctx.CookieStorer, key); err != nil { - return errors.Wrapf(err, "failed to create remember token") + storer := authboss.EnsureCanRemember(r.Authboss.Config.Storage.Server) + if err = storer.AddRememberToken(user.GetPID(), hash); err != nil { + return false, err } - return nil + authboss.PutCookie(w, authboss.CookieRemember, token) + + return false, nil } +/* +// TODO(aarondl): Either discard or make this useful later after oauth2 // afterOAuth is called after oauth authentication is successful. // Has to pander to horrible state variable packing to figure out if we want // to be remembered. @@ -143,113 +109,113 @@ func (r *Remember) afterOAuth(ctx *authboss.Context) error { return nil } +*/ -// afterPassword is called after the password has been reset. -func (r *Remember) afterPassword(ctx *authboss.Context) error { - if ctx.User == nil { - return nil +// Middleware automatically authenticates users if they have remember me tokens +// If the user has been loaded already, it returns early +func Middleware(ab *authboss.Authboss) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Context().Value(authboss.CTXKeyPID) == nil && r.Context().Value(authboss.CTXKeyUser) == nil { + if err := Authenticate(ab, w, r); err != nil { + logger := ab.RequestLogger(r) + logger.Errorf("failed to authenticate user via remember me: %+v", err) + } + } + + next.ServeHTTP(w, r) + }) } +} - id, ok := ctx.User.String(r.PrimaryID) +// Authenticate the user using their remember cookie. +// If the cookie proves unusable it will be deleted. A cookie +// may be unusable for the following reasons: +// - Can't decode the base64 +// - Invalid token format +// - Can't find token in DB +func Authenticate(ab *authboss.Authboss, w http.ResponseWriter, req *http.Request) error { + logger := ab.RequestLogger(req) + cookie, ok := authboss.GetCookie(req, authboss.CookieRemember) if !ok { return nil } - ctx.CookieStorer.Del(authboss.CookieRemember) - - var storer RememberStorer - if storer, ok = ctx.Storer.(RememberStorer); !ok { - if storer, ok = ctx.OAuth2Storer.(RememberStorer); !ok { - return nil - } - } - - return storer.DelTokens(id) -} - -// new generates a new remember token and stores it in the configured RememberStorer. -// The return value is a token that should only be given to a user if the delivery -// method is secure which means at least signed if not encrypted. -func (r *Remember) new(cstorer authboss.ClientStorer, storageKey string) (string, error) { - token := make([]byte, nRandBytes+len(storageKey)+1) - copy(token, []byte(storageKey)) - token[len(storageKey)] = ';' - - if _, err := rand.Read(token[len(storageKey)+1:]); err != nil { - return "", err - } - - sum := md5.Sum(token) - finalToken := base64.URLEncoding.EncodeToString(token) - storageToken := base64.StdEncoding.EncodeToString(sum[:]) - - var storer RememberStorer - var ok bool - if storer, ok = r.Storer.(RememberStorer); !ok { - storer, ok = r.OAuth2Storer.(RememberStorer) - } - - // Save the token in the DB - if err := storer.AddToken(storageKey, storageToken); err != nil { - return "", err - } - - // Write the finalToken to the cookie - cstorer.Put(authboss.CookieRemember, finalToken) - - return finalToken, nil -} - -// auth takes a token that was given to a user and checks to see if something -// is matching in the database. If something is found the old token is deleted -// and a new one should be generated. -func (r *Remember) auth(ctx *authboss.Context) (authboss.Interrupt, error) { - if val, ok := ctx.SessionStorer.Get(authboss.SessionKey); ok || len(val) > 0 { - return authboss.InterruptNone, nil - } - - finalToken, ok := ctx.CookieStorer.Get(authboss.CookieRemember) - if !ok { - return authboss.InterruptNone, nil - } - - token, err := base64.URLEncoding.DecodeString(finalToken) + rawToken, err := base64.URLEncoding.DecodeString(cookie) if err != nil { - return authboss.InterruptNone, err + authboss.DelCookie(w, authboss.CookieRemember) + logger.Infof("failed to decode remember me cookie, deleting cookie") + return nil } - index := bytes.IndexByte(token, ';') + index := bytes.IndexByte(rawToken, ';') if index < 0 { - return authboss.InterruptNone, errors.New("invalid remember token") + authboss.DelCookie(w, authboss.CookieRemember) + logger.Infof("failed to decode remember me token, deleting cookie") + return nil } - // Get the key. - givenKey := string(token[:index]) + pid := string(rawToken[:index]) + sum := sha512.Sum512(rawToken) + hash := base64.StdEncoding.EncodeToString(sum[:]) - // Verify the tokens match. - sum := md5.Sum(token) - - var storer RememberStorer - if storer, ok = ctx.Storer.(RememberStorer); !ok { - storer, ok = ctx.OAuth2Storer.(RememberStorer) + storer := authboss.EnsureCanRemember(ab.Config.Storage.Server) + err = storer.UseRememberToken(pid, hash) + switch { + case err == authboss.ErrTokenNotFound: + logger.Infof("remember me cookie had a token that was not in storage, deleting cookie") + authboss.DelCookie(w, authboss.CookieRemember) + return nil + case err != nil: + return err } - err = storer.UseToken(givenKey, base64.StdEncoding.EncodeToString(sum[:])) - if err == authboss.ErrTokenNotFound { - return authboss.InterruptNone, nil - } else if err != nil { - return authboss.InterruptNone, err - } - - _, err = r.new(ctx.CookieStorer, givenKey) + hash, token, err := GenerateToken(pid) if err != nil { - return authboss.InterruptNone, err + return err } - // Ensure a half-auth. - ctx.SessionStorer.Put(authboss.SessionHalfAuthKey, "true") - // Log the user in. - ctx.SessionStorer.Put(authboss.SessionKey, givenKey) + if err = storer.AddRememberToken(pid, hash); err != nil { + return errors.Wrap(err, "failed to save me token") + } - return authboss.InterruptNone, nil + authboss.PutSession(w, authboss.SessionKey, pid) + authboss.PutSession(w, authboss.SessionHalfAuthKey, "true") + authboss.DelCookie(w, authboss.CookieRemember) + authboss.PutCookie(w, authboss.CookieRemember, token) + + return nil +} + +// AfterPasswordReset is called after the password has been reset, since +// it should invalidate all tokens associated to that user. +func (r *Remember) AfterPasswordReset(w http.ResponseWriter, req *http.Request, handled bool) (bool, error) { + user, err := r.Authboss.CurrentUser(w, req) + if err != nil { + return false, err + } + + logger := r.Authboss.RequestLogger(req) + storer := authboss.EnsureCanRemember(r.Authboss.Config.Storage.Server) + + pid := user.GetPID() + authboss.DelCookie(w, authboss.CookieRemember) + + logger.Infof("deleting tokens and rm cookies for user %s due to password reset", pid) + + return false, storer.DelRememberTokens(pid) +} + +// GenerateToken creates a remember me token +func GenerateToken(pid string) (hash string, token string, err error) { + rawToken := make([]byte, nNonceSize+len(pid)+1) + copy(rawToken, []byte(pid)) + rawToken[len(pid)] = ';' + + if _, err := rand.Read(rawToken[len(pid)+1:]); err != nil { + return "", "", errors.Wrap(err, "failed to create remember me nonce") + } + + sum := sha512.Sum512(rawToken) + return base64.StdEncoding.EncodeToString(sum[:]), base64.URLEncoding.EncodeToString(rawToken), nil } diff --git a/remember/remember_test.go b/remember/remember_test.go index 547749e..81c06ac 100644 --- a/remember/remember_test.go +++ b/remember/remember_test.go @@ -2,196 +2,343 @@ package remember import ( "bytes" + "context" + "crypto/sha512" + "encoding/base64" "net/http" + "net/http/httptest" "testing" "github.com/volatiletech/authboss" "github.com/volatiletech/authboss/internal/mocks" ) -func TestInitialize(t *testing.T) { +func TestInit(t *testing.T) { t.Parallel() ab := authboss.New() r := &Remember{} - err := r.Initialize(ab) - if err == nil { - t.Error("Expected error about token storers.") - } - - ab.Storage.Server = mocks.MockFailStorer{} - err = r.Initialize(ab) - if err == nil { - t.Error("Expected error about token storers.") - } - - ab.Storage.Server = mocks.NewMockStorer() - err = r.Initialize(ab) + err := r.Init(ab) if err != nil { - t.Error("Unexpected error:", err) + t.Fatal(err) } } -func TestAfterAuth(t *testing.T) { +type testHarness struct { + remember *Remember + ab *authboss.Authboss + + session *mocks.ClientStateRW + cookies *mocks.ClientStateRW + storer *mocks.ServerStorer +} + +func testSetup() *testHarness { + harness := &testHarness{} + + harness.ab = authboss.New() + harness.session = mocks.NewClientRW() + harness.cookies = mocks.NewClientRW() + harness.storer = mocks.NewServerStorer() + + harness.ab.Config.Core.Logger = mocks.Logger{} + harness.ab.Config.Storage.SessionState = harness.session + harness.ab.Config.Storage.CookieState = harness.cookies + harness.ab.Config.Storage.Server = harness.storer + + harness.remember = &Remember{harness.ab} + + return harness +} + +func TestRememberAfterAuth(t *testing.T) { t.Parallel() - r := Remember{authboss.New()} - storer := mocks.NewMockStorer() - r.Storer = storer + h := testSetup() - cookies := mocks.NewMockClientStorer() - session := mocks.NewMockClientStorer() + user := &mocks.User{Email: "test@test.com"} - req, err := http.NewRequest("POST", "http://localhost", bytes.NewBufferString("rm=true")) - if err != nil { - t.Error("Unexpected Error:", err) - } - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + r := mocks.Request("POST") + r = r.WithContext(context.WithValue(r.Context(), authboss.CTXKeyRM, true)) + r = r.WithContext(context.WithValue(r.Context(), authboss.CTXKeyUser, user)) + rec := httptest.NewRecorder() + w := h.ab.NewResponse(rec, r) - ctx := r.NewContext() - ctx.SessionStorer = session - ctx.CookieStorer = cookies - ctx.User = authboss.Attributes{r.PrimaryID: "test@email.com"} - - ctx.Values = map[string]string{authboss.CookieRemember: "true"} - - if err := r.afterAuth(ctx); err != nil { - t.Error(err) + if handled, err := h.remember.RememberAfterAuth(w, r, false); err != nil { + t.Fatal(err) + } else if handled { + t.Error("should never be handled") } - if _, ok := cookies.Values[authboss.CookieRemember]; !ok { - t.Error("Expected a cookie to have been set.") + // Force flush of headers so cookies are written + w.WriteHeader(http.StatusOK) + + if len(h.storer.RMTokens["test@test.com"]) != 1 { + t.Error("token was not persisted:", h.storer.RMTokens) + } + + if cookie, ok := h.cookies.ClientValues[authboss.CookieRemember]; !ok || len(cookie) == 0 { + t.Error("remember me cookie was not set") } } -func TestAfterOAuth(t *testing.T) { +func TestRememberAfterAuthSkip(t *testing.T) { t.Parallel() - r := Remember{authboss.New()} - storer := mocks.NewMockStorer() - r.Storer = storer + h := testSetup() - cookies := mocks.NewMockClientStorer() - session := mocks.NewMockClientStorer(authboss.SessionOAuth2Params, `{"rm":"true"}`) + r := mocks.Request("POST") + rec := httptest.NewRecorder() + w := h.ab.NewResponse(rec, r) - ctx := r.NewContext() - ctx.SessionStorer = session - ctx.CookieStorer = cookies - ctx.User = authboss.Attributes{ - authboss.StoreOAuth2UID: "uid", - authboss.StoreOAuth2Provider: "google", + if handled, err := h.remember.RememberAfterAuth(w, r, false); err != nil { + t.Fatal(err) + } else if handled { + t.Error("should never be handled") } - if err := r.afterOAuth(ctx); err != nil { - t.Error(err) + if len(h.storer.RMTokens["test@test.com"]) != 0 { + t.Error("expected no tokens to be created") } - if _, ok := cookies.Values[authboss.CookieRemember]; !ok { - t.Error("Expected a cookie to have been set.") + r = r.WithContext(context.WithValue(r.Context(), authboss.CTXKeyRM, false)) + + if handled, err := h.remember.RememberAfterAuth(w, r, false); err != nil { + t.Fatal(err) + } else if handled { + t.Error("should never be handled") } + + if len(h.storer.RMTokens["test@test.com"]) != 0 { + t.Error("expected no tokens to be created") + } +} + +func TestMiddlewareAuth(t *testing.T) { + t.Parallel() + + h := testSetup() + + user := &mocks.User{Email: "test@test.com"} + hash, token, _ := GenerateToken(user.Email) + + h.storer.Users[user.Email] = user + h.storer.RMTokens[user.Email] = []string{hash} + h.cookies.ClientValues[authboss.CookieRemember] = token + + r := mocks.Request("POST") + rec := httptest.NewRecorder() + w := h.ab.NewResponse(rec, r) + + var err error + r, err = h.ab.LoadClientState(w, r) + if err != nil { + t.Fatal(err) + } + + called := false + server := Middleware(h.ab)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + w.WriteHeader(http.StatusOK) + })) + + server.ServeHTTP(w, r) + + if !called { + t.Error("it should have called the underlying handler") + } + + if h.session.ClientValues[authboss.SessionKey] != user.Email { + t.Error("should have saved the pid in the session") + } + // Elided the rest of the checks, authenticate tests do this +} + +func TestAuthenticateSuccess(t *testing.T) { + t.Parallel() + + h := testSetup() + + user := &mocks.User{Email: "test@test.com"} + hash, token, _ := GenerateToken(user.Email) + + h.storer.Users[user.Email] = user + h.storer.RMTokens[user.Email] = []string{hash} + h.cookies.ClientValues[authboss.CookieRemember] = token + + r := mocks.Request("POST") + rec := httptest.NewRecorder() + w := h.ab.NewResponse(rec, r) + + var err error + r, err = h.ab.LoadClientState(w, r) + if err != nil { + t.Fatal(err) + } + + if err = Authenticate(h.ab, w, r); err != nil { + t.Fatal(err) + } + + w.WriteHeader(http.StatusOK) + + if cookie := h.cookies.ClientValues[authboss.CookieRemember]; cookie == token { + t.Error("the cookie should have been replaced with a new token") + } + + if len(h.storer.RMTokens[user.Email]) != 1 { + t.Error("one token should have been removed, and one should have been added") + } else if h.storer.RMTokens[user.Email][0] == token { + t.Error("a new token should have been saved") + } + + if h.session.ClientValues[authboss.SessionKey] != user.Email { + t.Error("should have saved the pid in the session") + } + if h.session.ClientValues[authboss.SessionHalfAuthKey] != "true" { + t.Error("it should have become a half-authed session") + } +} + +func TestAuthenticateTokenNotFound(t *testing.T) { + t.Parallel() + + h := testSetup() + + user := &mocks.User{Email: "test@test.com"} + _, token, _ := GenerateToken(user.Email) + + h.storer.Users[user.Email] = user + h.cookies.ClientValues[authboss.CookieRemember] = token + + r := mocks.Request("POST") + rec := httptest.NewRecorder() + w := h.ab.NewResponse(rec, r) + + var err error + r, err = h.ab.LoadClientState(w, r) + if err != nil { + t.Fatal(err) + } + + if err = Authenticate(h.ab, w, r); err != nil { + t.Fatal(err) + } + + w.WriteHeader(http.StatusOK) + + if len(h.cookies.ClientValues[authboss.CookieRemember]) != 0 { + t.Error("there should be no remember cookie left") + } + + if len(h.session.ClientValues[authboss.SessionKey]) != 0 { + t.Error("it should have not logged the user in") + } +} + +func TestAuthenticateBadTokens(t *testing.T) { + t.Parallel() + + h := testSetup() + + doTest := func(t *testing.T) { + t.Helper() + + r := mocks.Request("POST") + rec := httptest.NewRecorder() + w := h.ab.NewResponse(rec, r) + + var err error + r, err = h.ab.LoadClientState(w, r) + if err != nil { + t.Fatal(err) + } + + if err = Authenticate(h.ab, w, r); err != nil { + t.Fatal(err) + } + + w.WriteHeader(http.StatusOK) + + if len(h.cookies.ClientValues[authboss.CookieRemember]) != 0 { + t.Error("there should be no remember cookie left") + } + + if len(h.session.ClientValues[authboss.SessionKey]) != 0 { + t.Error("it should have not logged the user in") + } + } + + t.Run("base64", func(t *testing.T) { + h.cookies.ClientValues[authboss.CookieRemember] = "a" + doTest(t) + }) + t.Run("cookieformat", func(t *testing.T) { + h.cookies.ClientValues[authboss.CookieRemember] = `aGVsbG8=` // hello + doTest(t) + }) } func TestAfterPasswordReset(t *testing.T) { t.Parallel() - r := Remember{authboss.New()} + h := testSetup() - id := "test@email.com" + user := &mocks.User{Email: "test@test.com"} + hash1, _, _ := GenerateToken(user.Email) + hash2, token2, _ := GenerateToken(user.Email) - storer := mocks.NewMockStorer() - r.Storer = storer - session := mocks.NewMockClientStorer() - cookies := mocks.NewMockClientStorer() - storer.Tokens[id] = []string{"one", "two"} - cookies.Values[authboss.CookieRemember] = "token" + h.storer.Users[user.Email] = user + h.storer.RMTokens[user.Email] = []string{hash1, hash2} + h.cookies.ClientValues[authboss.CookieRemember] = token2 - ctx := r.NewContext() - ctx.User = authboss.Attributes{r.PrimaryID: id} - ctx.SessionStorer = session - ctx.CookieStorer = cookies + r := mocks.Request("POST") + r = r.WithContext(context.WithValue(r.Context(), authboss.CTXKeyUser, user)) + rec := httptest.NewRecorder() + w := h.ab.NewResponse(rec, r) - if err := r.afterPassword(ctx); err != nil { + if handled, err := h.remember.AfterPasswordReset(w, r, false); err != nil { + t.Error(err) + } else if handled { + t.Error("it should never be handled") + } + + w.WriteHeader(http.StatusOK) // Force header flush + + if len(h.storer.RMTokens[user.Email]) != 0 { + t.Error("all remember me tokens should have been removed") + } + if len(h.cookies.ClientValues[authboss.CookieRemember]) != 0 { + t.Error("there should be no remember cookie left") + } +} + +func TestGenerateToken(t *testing.T) { + t.Parallel() + + hash, tok, err := GenerateToken("test") + if err != nil { + t.Fatal(err) + } + + rawToken, err := base64.URLEncoding.DecodeString(tok) + if err != nil { t.Error(err) } - if _, ok := cookies.Values[authboss.CookieRemember]; ok { - t.Error("Expected the remember cookie to be deleted.") + index := bytes.IndexByte(rawToken, ';') + if index < 0 { + t.Fatalf("problem with the token format: %v", rawToken) } - if len(storer.Tokens) != 0 { - t.Error("Should have wiped out all tokens.") - } -} - -func TestNew(t *testing.T) { - t.Parallel() - - r := &Remember{authboss.New()} - storer := mocks.NewMockStorer() - r.Storer = storer - cookies := mocks.NewMockClientStorer() - - key := "tester" - token, err := r.new(cookies, key) - - if err != nil { - t.Error("Unexpected error:", err) - } - - if len(token) == 0 { - t.Error("Expected a token.") - } - - if tok, ok := storer.Tokens[key]; !ok { - t.Error("Expected it to store against the key:", key) - } else if len(tok) != 1 || len(tok[0]) == 0 { - t.Error("Expected a token to be saved.") - } - - if token != cookies.Values[authboss.CookieRemember] { - t.Error("Expected a cookie set with the token.") - } -} - -func TestAuth(t *testing.T) { - t.Parallel() - - r := &Remember{authboss.New()} - storer := mocks.NewMockStorer() - r.Storer = storer - - cookies := mocks.NewMockClientStorer() - session := mocks.NewMockClientStorer() - ctx := r.NewContext() - ctx.CookieStorer = cookies - ctx.SessionStorer = session - - key := "tester" - _, err := r.new(cookies, key) - if err != nil { - t.Error("Unexpected error:", err) - } - - cookie, _ := cookies.Get(authboss.CookieRemember) - - interrupt, err := r.auth(ctx) - if err != nil { - t.Error("Unexpected error:", err) - } - - if session.Values[authboss.SessionHalfAuthKey] != "true" { - t.Error("The user should have been half-authed.") - } - - if session.Values[authboss.SessionKey] != key { - t.Error("The user should have been logged in.") - } - - if chocolateChip, _ := cookies.Get(authboss.CookieRemember); chocolateChip == cookie { - t.Error("Expected cookie to be different") - } - - if authboss.InterruptNone != interrupt { - t.Error("Keys should have matched:", interrupt) + bytPID := rawToken[:index] + if string(bytPID) != "test" { + t.Errorf("pid wrong: %s", bytPID) + } + + sum := sha512.Sum512(rawToken) + gotHash := base64.StdEncoding.EncodeToString(sum[:]) + if hash != gotHash { + t.Errorf("hash wrong, want: %s, got: %s", hash, gotHash) } } diff --git a/storage.go b/storage.go index e66d005..70a4eaa 100644 --- a/storage.go +++ b/storage.go @@ -30,13 +30,13 @@ const ( ) var ( + // ErrUserFound should be returned from Create (see ConfirmUser) when the primaryID + // of the record is found. + ErrUserFound = errors.New("user found") // ErrUserNotFound should be returned from Get when the record is not found. ErrUserNotFound = errors.New("user not found") // ErrTokenNotFound should be returned from UseToken when the record is not found. ErrTokenNotFound = errors.New("token not found") - // ErrUserFound should be returned from Create (see ConfirmUser) when the primaryID - // of the record is found. - ErrUserFound = errors.New("user found") ) // ServerStorer represents the data store that's capable of loading users @@ -82,11 +82,22 @@ type RecoveringServerStorer interface { LoadByRecoverToken(ctx context.Context, token string) (RecoverableUser, error) } +// RememberingServerStorer allows users to be remembered across sessions +type RememberingServerStorer interface { + // AddRememberToken to a user + AddRememberToken(pid, token string) error + // DelRememberTokens removes all tokens for the given pid + DelRememberTokens(pid string) error + // UseRememberToken finds the pid-token pair and deletes it. + // If the token could not be found return ErrTokenNotFound + UseRememberToken(pid, token string) error +} + // EnsureCanCreate makes sure the server storer supports create operations func EnsureCanCreate(storer ServerStorer) CreatingServerStorer { s, ok := storer.(CreatingServerStorer) if !ok { - panic("could not upgrade serverstorer to creatingserverstorer, check your struct") + panic("could not upgrade ServerStorer to CreatingServerStorer, check your struct") } return s @@ -96,7 +107,7 @@ func EnsureCanCreate(storer ServerStorer) CreatingServerStorer { func EnsureCanConfirm(storer ServerStorer) ConfirmingServerStorer { s, ok := storer.(ConfirmingServerStorer) if !ok { - panic("could not upgrade serverstorer to confirmingserverstorer, check your struct") + panic("could not upgrade ServerStorer to ConfirmingServerStorer, check your struct") } return s @@ -106,7 +117,17 @@ func EnsureCanConfirm(storer ServerStorer) ConfirmingServerStorer { func EnsureCanRecover(storer ServerStorer) RecoveringServerStorer { s, ok := storer.(RecoveringServerStorer) if !ok { - panic("could not upgrade serverstorer to recoveringserverstorer, check your struct") + panic("could not upgrade ServerStorer to RecoveringServerStorer, check your struct") + } + + return s +} + +// EnsureCanRemember makes sure the server storer supports remember operations +func EnsureCanRemember(storer ServerStorer) RememberingServerStorer { + s, ok := storer.(RememberingServerStorer) + if !ok { + panic("could not upgrade ServerStorer to RememberingServerStorer, check your struct") } return s