mirror of
https://github.com/volatiletech/authboss.git
synced 2025-01-10 04:17:59 +02:00
Rewrite the remember module
- Add context keys and storage pieces for remember
This commit is contained in:
parent
792f7381fd
commit
ac3d2846f8
12
context.go
12
context.go
@ -19,6 +19,11 @@ const (
|
||||
// map[string]interface{} (authboss.HTMLData) to pass to the
|
||||
// renderer
|
||||
CTXKeyData contextKey = "data"
|
||||
|
||||
// CTXKeyRM is used to flag the remember me module to actually do the
|
||||
// remembering, since this is a per-user operation, authentication modules
|
||||
// need to supply this key if they wish to allow users to be remembered.
|
||||
CTXKeyRM contextKey = "rm"
|
||||
)
|
||||
|
||||
func (c contextKey) String() string {
|
||||
@ -79,12 +84,7 @@ func (a *Authboss) CurrentUserP(w http.ResponseWriter, r *http.Request) User {
|
||||
}
|
||||
|
||||
func (a *Authboss) currentUser(ctx context.Context, pid string) (User, error) {
|
||||
user, err := a.Storage.Server.Load(ctx, pid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return user, nil
|
||||
return a.Storage.Server.Load(ctx, pid)
|
||||
}
|
||||
|
||||
// LoadCurrentUserID takes a pointer to a pointer to the request in order to
|
||||
|
@ -137,38 +137,25 @@ func (s *ServerStorer) LoadByRecoverToken(ctx context.Context, token string) (au
|
||||
return nil, authboss.ErrUserNotFound
|
||||
}
|
||||
|
||||
/*
|
||||
// TODO(aarondl): What is this?
|
||||
// AddToken for remember me
|
||||
func (m *Storer) AddToken(key, token string) error {
|
||||
if len(m.AddTokenErr) > 0 {
|
||||
return errors.New(m.AddTokenErr)
|
||||
}
|
||||
|
||||
arr := m.Tokens[key]
|
||||
m.Tokens[key] = append(arr, token)
|
||||
// AddRememberToken for remember me
|
||||
func (s *ServerStorer) AddRememberToken(key, token string) error {
|
||||
arr := s.RMTokens[key]
|
||||
s.RMTokens[key] = append(arr, token)
|
||||
return nil
|
||||
}
|
||||
|
||||
// DelTokens for a user
|
||||
func (m *Storer) DelTokens(key string) error {
|
||||
if len(m.DelTokensErr) > 0 {
|
||||
return errors.New(m.DelTokensErr)
|
||||
}
|
||||
|
||||
delete(m.Tokens, key)
|
||||
// DelRememberTokens for a user
|
||||
func (s *ServerStorer) DelRememberTokens(key string) error {
|
||||
delete(s.RMTokens, key)
|
||||
return nil
|
||||
}
|
||||
|
||||
// UseToken if it exists, deleting it in the process
|
||||
func (m *Storer) UseToken(givenKey, token string) (err error) {
|
||||
if len(m.UseTokenErr) > 0 {
|
||||
return errors.New(m.UseTokenErr)
|
||||
}
|
||||
|
||||
if arr, ok := m.Tokens[givenKey]; ok {
|
||||
// UseRememberToken if it exists, deleting it in the process
|
||||
func (s *ServerStorer) UseRememberToken(givenKey, token string) (err error) {
|
||||
if arr, ok := s.RMTokens[givenKey]; ok {
|
||||
for _, tok := range arr {
|
||||
if tok == token {
|
||||
delete(s.RMTokens, givenKey)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@ -177,49 +164,6 @@ func (m *Storer) UseToken(givenKey, token string) (err error) {
|
||||
return authboss.ErrTokenNotFound
|
||||
}
|
||||
|
||||
// RecoverUser by the token.
|
||||
func (m *Storer) RecoverUser(token string) (result interface{}, err error) {
|
||||
if len(m.RecoverUserErr) > 0 {
|
||||
return nil, errors.New(m.RecoverUserErr)
|
||||
}
|
||||
|
||||
for _, user := range m.Users {
|
||||
if user["recover_token"] == token {
|
||||
|
||||
u := &User{}
|
||||
if err = user.Bind(u, false); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return u, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, authboss.ErrUserNotFound
|
||||
}
|
||||
|
||||
// ConfirmUser via their token
|
||||
func (m *Storer) ConfirmUser(confirmToken string) (result interface{}, err error) {
|
||||
if len(m.ConfirmUserErr) > 0 {
|
||||
return nil, errors.New(m.ConfirmUserErr)
|
||||
}
|
||||
|
||||
for _, user := range m.Users {
|
||||
if user["confirm_token"] == confirmToken {
|
||||
|
||||
u := &User{}
|
||||
if err = user.Bind(u, false); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return u, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, authboss.ErrUserNotFound
|
||||
}
|
||||
*/
|
||||
|
||||
// FailStorer is used for testing module initialize functions that recover more than the base storer
|
||||
type FailStorer struct {
|
||||
User
|
||||
|
@ -1,12 +1,12 @@
|
||||
// Package remember implements persistent logins through the cookie storer.
|
||||
// Package remember implements persistent logins using cookies
|
||||
package remember
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/md5"
|
||||
"crypto/rand"
|
||||
"crypto/sha512"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
@ -14,30 +14,13 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
nRandBytes = 32
|
||||
nNonceSize = 32
|
||||
)
|
||||
|
||||
var (
|
||||
errUserMissing = errors.New("user not loaded in callback")
|
||||
)
|
||||
|
||||
// RememberStorer must be implemented in order to satisfy the remember module's
|
||||
// storage requirements. If the implementer is a typical database then
|
||||
// the tokens should be stored in a separate table since they require a 1-n
|
||||
// with the user for each device the user wishes to remain logged in on.
|
||||
//
|
||||
// Remember storer will look at both authboss's configured Storer and OAuth2Storer
|
||||
// for compatibility.
|
||||
type RememberStorer interface {
|
||||
// AddToken saves a new token for the key.
|
||||
AddToken(key, token string) error
|
||||
// DelTokens removes all tokens for a given key.
|
||||
DelTokens(key string) error
|
||||
// UseToken finds the key-token pair, removes the entry in the store
|
||||
// and returns nil. If the token could not be found return ErrTokenNotFound.
|
||||
UseToken(givenKey, token string) (err error)
|
||||
}
|
||||
|
||||
func init() {
|
||||
authboss.RegisterModule("remember", &Remember{})
|
||||
}
|
||||
@ -47,62 +30,45 @@ type Remember struct {
|
||||
*authboss.Authboss
|
||||
}
|
||||
|
||||
// Initialize module
|
||||
func (r *Remember) Initialize(ab *authboss.Authboss) error {
|
||||
// Init module
|
||||
func (r *Remember) Init(ab *authboss.Authboss) error {
|
||||
r.Authboss = ab
|
||||
|
||||
if r.Storer != nil || r.OAuth2Storer != nil {
|
||||
if _, ok := r.Storer.(RememberStorer); !ok {
|
||||
if _, ok := r.OAuth2Storer.(RememberStorer); !ok {
|
||||
return errors.New("rememberStorer required for remember functionality")
|
||||
}
|
||||
}
|
||||
} else if r.StoreMaker == nil && r.OAuth2StoreMaker == nil {
|
||||
return errors.New("need a rememberStorer")
|
||||
}
|
||||
|
||||
r.Events.Before(authboss.EventGetUserSession, r.auth)
|
||||
r.Events.After(authboss.EventAuth, r.afterAuth)
|
||||
r.Events.After(authboss.EventOAuth, r.afterOAuth)
|
||||
r.Events.After(authboss.EventPasswordReset, r.afterPassword)
|
||||
r.Events.After(authboss.EventAuth, r.RememberAfterAuth)
|
||||
//TODO(aarondl): Rectify this once oauth2 is done
|
||||
// r.Events.After(authboss.EventOAuth, r.RememberAfterAuth)
|
||||
r.Events.After(authboss.EventPasswordReset, r.AfterPasswordReset)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Routes for module
|
||||
func (r *Remember) Routes() authboss.RouteTable {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Storage requirements
|
||||
func (r *Remember) Storage() authboss.StorageOptions {
|
||||
return authboss.StorageOptions{
|
||||
r.PrimaryID: authboss.String,
|
||||
}
|
||||
}
|
||||
|
||||
// afterAuth is called after authentication is successful.
|
||||
func (r *Remember) afterAuth(ctx *authboss.Context) error {
|
||||
if val := ctx.Values[authboss.CookieRemember]; val != "true" {
|
||||
return nil
|
||||
// RememberAfterAuth creates a remember token and saves it in the user's cookies.
|
||||
func (r *Remember) RememberAfterAuth(w http.ResponseWriter, req *http.Request, handled bool) (bool, error) {
|
||||
rmIntf := req.Context().Value(authboss.CTXKeyRM)
|
||||
if rmIntf == nil {
|
||||
return false, nil
|
||||
} else if rm, ok := rmIntf.(bool); ok && !rm {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if ctx.User == nil {
|
||||
return errUserMissing
|
||||
}
|
||||
|
||||
key, err := ctx.User.StringErr(r.PrimaryID)
|
||||
user := r.Authboss.CurrentUserP(w, req)
|
||||
hash, token, err := GenerateToken(user.GetPID())
|
||||
if err != nil {
|
||||
return err
|
||||
return false, err
|
||||
}
|
||||
|
||||
if _, err := r.new(ctx.CookieStorer, key); err != nil {
|
||||
return errors.Wrapf(err, "failed to create remember token")
|
||||
storer := authboss.EnsureCanRemember(r.Authboss.Config.Storage.Server)
|
||||
if err = storer.AddRememberToken(user.GetPID(), hash); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return nil
|
||||
authboss.PutCookie(w, authboss.CookieRemember, token)
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
/*
|
||||
// TODO(aarondl): Either discard or make this useful later after oauth2
|
||||
// afterOAuth is called after oauth authentication is successful.
|
||||
// Has to pander to horrible state variable packing to figure out if we want
|
||||
// to be remembered.
|
||||
@ -143,113 +109,113 @@ func (r *Remember) afterOAuth(ctx *authboss.Context) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
*/
|
||||
|
||||
// afterPassword is called after the password has been reset.
|
||||
func (r *Remember) afterPassword(ctx *authboss.Context) error {
|
||||
if ctx.User == nil {
|
||||
return nil
|
||||
// Middleware automatically authenticates users if they have remember me tokens
|
||||
// If the user has been loaded already, it returns early
|
||||
func Middleware(ab *authboss.Authboss) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Context().Value(authboss.CTXKeyPID) == nil && r.Context().Value(authboss.CTXKeyUser) == nil {
|
||||
if err := Authenticate(ab, w, r); err != nil {
|
||||
logger := ab.RequestLogger(r)
|
||||
logger.Errorf("failed to authenticate user via remember me: %+v", err)
|
||||
}
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
id, ok := ctx.User.String(r.PrimaryID)
|
||||
// Authenticate the user using their remember cookie.
|
||||
// If the cookie proves unusable it will be deleted. A cookie
|
||||
// may be unusable for the following reasons:
|
||||
// - Can't decode the base64
|
||||
// - Invalid token format
|
||||
// - Can't find token in DB
|
||||
func Authenticate(ab *authboss.Authboss, w http.ResponseWriter, req *http.Request) error {
|
||||
logger := ab.RequestLogger(req)
|
||||
cookie, ok := authboss.GetCookie(req, authboss.CookieRemember)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx.CookieStorer.Del(authboss.CookieRemember)
|
||||
|
||||
var storer RememberStorer
|
||||
if storer, ok = ctx.Storer.(RememberStorer); !ok {
|
||||
if storer, ok = ctx.OAuth2Storer.(RememberStorer); !ok {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return storer.DelTokens(id)
|
||||
}
|
||||
|
||||
// new generates a new remember token and stores it in the configured RememberStorer.
|
||||
// The return value is a token that should only be given to a user if the delivery
|
||||
// method is secure which means at least signed if not encrypted.
|
||||
func (r *Remember) new(cstorer authboss.ClientStorer, storageKey string) (string, error) {
|
||||
token := make([]byte, nRandBytes+len(storageKey)+1)
|
||||
copy(token, []byte(storageKey))
|
||||
token[len(storageKey)] = ';'
|
||||
|
||||
if _, err := rand.Read(token[len(storageKey)+1:]); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
sum := md5.Sum(token)
|
||||
finalToken := base64.URLEncoding.EncodeToString(token)
|
||||
storageToken := base64.StdEncoding.EncodeToString(sum[:])
|
||||
|
||||
var storer RememberStorer
|
||||
var ok bool
|
||||
if storer, ok = r.Storer.(RememberStorer); !ok {
|
||||
storer, ok = r.OAuth2Storer.(RememberStorer)
|
||||
}
|
||||
|
||||
// Save the token in the DB
|
||||
if err := storer.AddToken(storageKey, storageToken); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Write the finalToken to the cookie
|
||||
cstorer.Put(authboss.CookieRemember, finalToken)
|
||||
|
||||
return finalToken, nil
|
||||
}
|
||||
|
||||
// auth takes a token that was given to a user and checks to see if something
|
||||
// is matching in the database. If something is found the old token is deleted
|
||||
// and a new one should be generated.
|
||||
func (r *Remember) auth(ctx *authboss.Context) (authboss.Interrupt, error) {
|
||||
if val, ok := ctx.SessionStorer.Get(authboss.SessionKey); ok || len(val) > 0 {
|
||||
return authboss.InterruptNone, nil
|
||||
}
|
||||
|
||||
finalToken, ok := ctx.CookieStorer.Get(authboss.CookieRemember)
|
||||
if !ok {
|
||||
return authboss.InterruptNone, nil
|
||||
}
|
||||
|
||||
token, err := base64.URLEncoding.DecodeString(finalToken)
|
||||
rawToken, err := base64.URLEncoding.DecodeString(cookie)
|
||||
if err != nil {
|
||||
return authboss.InterruptNone, err
|
||||
authboss.DelCookie(w, authboss.CookieRemember)
|
||||
logger.Infof("failed to decode remember me cookie, deleting cookie")
|
||||
return nil
|
||||
}
|
||||
|
||||
index := bytes.IndexByte(token, ';')
|
||||
index := bytes.IndexByte(rawToken, ';')
|
||||
if index < 0 {
|
||||
return authboss.InterruptNone, errors.New("invalid remember token")
|
||||
authboss.DelCookie(w, authboss.CookieRemember)
|
||||
logger.Infof("failed to decode remember me token, deleting cookie")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get the key.
|
||||
givenKey := string(token[:index])
|
||||
pid := string(rawToken[:index])
|
||||
sum := sha512.Sum512(rawToken)
|
||||
hash := base64.StdEncoding.EncodeToString(sum[:])
|
||||
|
||||
// Verify the tokens match.
|
||||
sum := md5.Sum(token)
|
||||
|
||||
var storer RememberStorer
|
||||
if storer, ok = ctx.Storer.(RememberStorer); !ok {
|
||||
storer, ok = ctx.OAuth2Storer.(RememberStorer)
|
||||
storer := authboss.EnsureCanRemember(ab.Config.Storage.Server)
|
||||
err = storer.UseRememberToken(pid, hash)
|
||||
switch {
|
||||
case err == authboss.ErrTokenNotFound:
|
||||
logger.Infof("remember me cookie had a token that was not in storage, deleting cookie")
|
||||
authboss.DelCookie(w, authboss.CookieRemember)
|
||||
return nil
|
||||
case err != nil:
|
||||
return err
|
||||
}
|
||||
|
||||
err = storer.UseToken(givenKey, base64.StdEncoding.EncodeToString(sum[:]))
|
||||
if err == authboss.ErrTokenNotFound {
|
||||
return authboss.InterruptNone, nil
|
||||
} else if err != nil {
|
||||
return authboss.InterruptNone, err
|
||||
}
|
||||
|
||||
_, err = r.new(ctx.CookieStorer, givenKey)
|
||||
hash, token, err := GenerateToken(pid)
|
||||
if err != nil {
|
||||
return authboss.InterruptNone, err
|
||||
return err
|
||||
}
|
||||
|
||||
// Ensure a half-auth.
|
||||
ctx.SessionStorer.Put(authboss.SessionHalfAuthKey, "true")
|
||||
// Log the user in.
|
||||
ctx.SessionStorer.Put(authboss.SessionKey, givenKey)
|
||||
if err = storer.AddRememberToken(pid, hash); err != nil {
|
||||
return errors.Wrap(err, "failed to save me token")
|
||||
}
|
||||
|
||||
return authboss.InterruptNone, nil
|
||||
authboss.PutSession(w, authboss.SessionKey, pid)
|
||||
authboss.PutSession(w, authboss.SessionHalfAuthKey, "true")
|
||||
authboss.DelCookie(w, authboss.CookieRemember)
|
||||
authboss.PutCookie(w, authboss.CookieRemember, token)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AfterPasswordReset is called after the password has been reset, since
|
||||
// it should invalidate all tokens associated to that user.
|
||||
func (r *Remember) AfterPasswordReset(w http.ResponseWriter, req *http.Request, handled bool) (bool, error) {
|
||||
user, err := r.Authboss.CurrentUser(w, req)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
logger := r.Authboss.RequestLogger(req)
|
||||
storer := authboss.EnsureCanRemember(r.Authboss.Config.Storage.Server)
|
||||
|
||||
pid := user.GetPID()
|
||||
authboss.DelCookie(w, authboss.CookieRemember)
|
||||
|
||||
logger.Infof("deleting tokens and rm cookies for user %s due to password reset", pid)
|
||||
|
||||
return false, storer.DelRememberTokens(pid)
|
||||
}
|
||||
|
||||
// GenerateToken creates a remember me token
|
||||
func GenerateToken(pid string) (hash string, token string, err error) {
|
||||
rawToken := make([]byte, nNonceSize+len(pid)+1)
|
||||
copy(rawToken, []byte(pid))
|
||||
rawToken[len(pid)] = ';'
|
||||
|
||||
if _, err := rand.Read(rawToken[len(pid)+1:]); err != nil {
|
||||
return "", "", errors.Wrap(err, "failed to create remember me nonce")
|
||||
}
|
||||
|
||||
sum := sha512.Sum512(rawToken)
|
||||
return base64.StdEncoding.EncodeToString(sum[:]), base64.URLEncoding.EncodeToString(rawToken), nil
|
||||
}
|
||||
|
@ -2,196 +2,343 @@ package remember
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha512"
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/volatiletech/authboss"
|
||||
"github.com/volatiletech/authboss/internal/mocks"
|
||||
)
|
||||
|
||||
func TestInitialize(t *testing.T) {
|
||||
func TestInit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := authboss.New()
|
||||
r := &Remember{}
|
||||
err := r.Initialize(ab)
|
||||
if err == nil {
|
||||
t.Error("Expected error about token storers.")
|
||||
}
|
||||
|
||||
ab.Storage.Server = mocks.MockFailStorer{}
|
||||
err = r.Initialize(ab)
|
||||
if err == nil {
|
||||
t.Error("Expected error about token storers.")
|
||||
}
|
||||
|
||||
ab.Storage.Server = mocks.NewMockStorer()
|
||||
err = r.Initialize(ab)
|
||||
err := r.Init(ab)
|
||||
if err != nil {
|
||||
t.Error("Unexpected error:", err)
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAfterAuth(t *testing.T) {
|
||||
type testHarness struct {
|
||||
remember *Remember
|
||||
ab *authboss.Authboss
|
||||
|
||||
session *mocks.ClientStateRW
|
||||
cookies *mocks.ClientStateRW
|
||||
storer *mocks.ServerStorer
|
||||
}
|
||||
|
||||
func testSetup() *testHarness {
|
||||
harness := &testHarness{}
|
||||
|
||||
harness.ab = authboss.New()
|
||||
harness.session = mocks.NewClientRW()
|
||||
harness.cookies = mocks.NewClientRW()
|
||||
harness.storer = mocks.NewServerStorer()
|
||||
|
||||
harness.ab.Config.Core.Logger = mocks.Logger{}
|
||||
harness.ab.Config.Storage.SessionState = harness.session
|
||||
harness.ab.Config.Storage.CookieState = harness.cookies
|
||||
harness.ab.Config.Storage.Server = harness.storer
|
||||
|
||||
harness.remember = &Remember{harness.ab}
|
||||
|
||||
return harness
|
||||
}
|
||||
|
||||
func TestRememberAfterAuth(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r := Remember{authboss.New()}
|
||||
storer := mocks.NewMockStorer()
|
||||
r.Storer = storer
|
||||
h := testSetup()
|
||||
|
||||
cookies := mocks.NewMockClientStorer()
|
||||
session := mocks.NewMockClientStorer()
|
||||
user := &mocks.User{Email: "test@test.com"}
|
||||
|
||||
req, err := http.NewRequest("POST", "http://localhost", bytes.NewBufferString("rm=true"))
|
||||
if err != nil {
|
||||
t.Error("Unexpected Error:", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
r := mocks.Request("POST")
|
||||
r = r.WithContext(context.WithValue(r.Context(), authboss.CTXKeyRM, true))
|
||||
r = r.WithContext(context.WithValue(r.Context(), authboss.CTXKeyUser, user))
|
||||
rec := httptest.NewRecorder()
|
||||
w := h.ab.NewResponse(rec, r)
|
||||
|
||||
ctx := r.NewContext()
|
||||
ctx.SessionStorer = session
|
||||
ctx.CookieStorer = cookies
|
||||
ctx.User = authboss.Attributes{r.PrimaryID: "test@email.com"}
|
||||
|
||||
ctx.Values = map[string]string{authboss.CookieRemember: "true"}
|
||||
|
||||
if err := r.afterAuth(ctx); err != nil {
|
||||
t.Error(err)
|
||||
if handled, err := h.remember.RememberAfterAuth(w, r, false); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if handled {
|
||||
t.Error("should never be handled")
|
||||
}
|
||||
|
||||
if _, ok := cookies.Values[authboss.CookieRemember]; !ok {
|
||||
t.Error("Expected a cookie to have been set.")
|
||||
// Force flush of headers so cookies are written
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
if len(h.storer.RMTokens["test@test.com"]) != 1 {
|
||||
t.Error("token was not persisted:", h.storer.RMTokens)
|
||||
}
|
||||
|
||||
if cookie, ok := h.cookies.ClientValues[authboss.CookieRemember]; !ok || len(cookie) == 0 {
|
||||
t.Error("remember me cookie was not set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAfterOAuth(t *testing.T) {
|
||||
func TestRememberAfterAuthSkip(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r := Remember{authboss.New()}
|
||||
storer := mocks.NewMockStorer()
|
||||
r.Storer = storer
|
||||
h := testSetup()
|
||||
|
||||
cookies := mocks.NewMockClientStorer()
|
||||
session := mocks.NewMockClientStorer(authboss.SessionOAuth2Params, `{"rm":"true"}`)
|
||||
r := mocks.Request("POST")
|
||||
rec := httptest.NewRecorder()
|
||||
w := h.ab.NewResponse(rec, r)
|
||||
|
||||
ctx := r.NewContext()
|
||||
ctx.SessionStorer = session
|
||||
ctx.CookieStorer = cookies
|
||||
ctx.User = authboss.Attributes{
|
||||
authboss.StoreOAuth2UID: "uid",
|
||||
authboss.StoreOAuth2Provider: "google",
|
||||
if handled, err := h.remember.RememberAfterAuth(w, r, false); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if handled {
|
||||
t.Error("should never be handled")
|
||||
}
|
||||
|
||||
if err := r.afterOAuth(ctx); err != nil {
|
||||
t.Error(err)
|
||||
if len(h.storer.RMTokens["test@test.com"]) != 0 {
|
||||
t.Error("expected no tokens to be created")
|
||||
}
|
||||
|
||||
if _, ok := cookies.Values[authboss.CookieRemember]; !ok {
|
||||
t.Error("Expected a cookie to have been set.")
|
||||
r = r.WithContext(context.WithValue(r.Context(), authboss.CTXKeyRM, false))
|
||||
|
||||
if handled, err := h.remember.RememberAfterAuth(w, r, false); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if handled {
|
||||
t.Error("should never be handled")
|
||||
}
|
||||
|
||||
if len(h.storer.RMTokens["test@test.com"]) != 0 {
|
||||
t.Error("expected no tokens to be created")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddlewareAuth(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
h := testSetup()
|
||||
|
||||
user := &mocks.User{Email: "test@test.com"}
|
||||
hash, token, _ := GenerateToken(user.Email)
|
||||
|
||||
h.storer.Users[user.Email] = user
|
||||
h.storer.RMTokens[user.Email] = []string{hash}
|
||||
h.cookies.ClientValues[authboss.CookieRemember] = token
|
||||
|
||||
r := mocks.Request("POST")
|
||||
rec := httptest.NewRecorder()
|
||||
w := h.ab.NewResponse(rec, r)
|
||||
|
||||
var err error
|
||||
r, err = h.ab.LoadClientState(w, r)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
called := false
|
||||
server := Middleware(h.ab)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
called = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
server.ServeHTTP(w, r)
|
||||
|
||||
if !called {
|
||||
t.Error("it should have called the underlying handler")
|
||||
}
|
||||
|
||||
if h.session.ClientValues[authboss.SessionKey] != user.Email {
|
||||
t.Error("should have saved the pid in the session")
|
||||
}
|
||||
// Elided the rest of the checks, authenticate tests do this
|
||||
}
|
||||
|
||||
func TestAuthenticateSuccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
h := testSetup()
|
||||
|
||||
user := &mocks.User{Email: "test@test.com"}
|
||||
hash, token, _ := GenerateToken(user.Email)
|
||||
|
||||
h.storer.Users[user.Email] = user
|
||||
h.storer.RMTokens[user.Email] = []string{hash}
|
||||
h.cookies.ClientValues[authboss.CookieRemember] = token
|
||||
|
||||
r := mocks.Request("POST")
|
||||
rec := httptest.NewRecorder()
|
||||
w := h.ab.NewResponse(rec, r)
|
||||
|
||||
var err error
|
||||
r, err = h.ab.LoadClientState(w, r)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err = Authenticate(h.ab, w, r); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
if cookie := h.cookies.ClientValues[authboss.CookieRemember]; cookie == token {
|
||||
t.Error("the cookie should have been replaced with a new token")
|
||||
}
|
||||
|
||||
if len(h.storer.RMTokens[user.Email]) != 1 {
|
||||
t.Error("one token should have been removed, and one should have been added")
|
||||
} else if h.storer.RMTokens[user.Email][0] == token {
|
||||
t.Error("a new token should have been saved")
|
||||
}
|
||||
|
||||
if h.session.ClientValues[authboss.SessionKey] != user.Email {
|
||||
t.Error("should have saved the pid in the session")
|
||||
}
|
||||
if h.session.ClientValues[authboss.SessionHalfAuthKey] != "true" {
|
||||
t.Error("it should have become a half-authed session")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthenticateTokenNotFound(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
h := testSetup()
|
||||
|
||||
user := &mocks.User{Email: "test@test.com"}
|
||||
_, token, _ := GenerateToken(user.Email)
|
||||
|
||||
h.storer.Users[user.Email] = user
|
||||
h.cookies.ClientValues[authboss.CookieRemember] = token
|
||||
|
||||
r := mocks.Request("POST")
|
||||
rec := httptest.NewRecorder()
|
||||
w := h.ab.NewResponse(rec, r)
|
||||
|
||||
var err error
|
||||
r, err = h.ab.LoadClientState(w, r)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err = Authenticate(h.ab, w, r); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
if len(h.cookies.ClientValues[authboss.CookieRemember]) != 0 {
|
||||
t.Error("there should be no remember cookie left")
|
||||
}
|
||||
|
||||
if len(h.session.ClientValues[authboss.SessionKey]) != 0 {
|
||||
t.Error("it should have not logged the user in")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthenticateBadTokens(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
h := testSetup()
|
||||
|
||||
doTest := func(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
r := mocks.Request("POST")
|
||||
rec := httptest.NewRecorder()
|
||||
w := h.ab.NewResponse(rec, r)
|
||||
|
||||
var err error
|
||||
r, err = h.ab.LoadClientState(w, r)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err = Authenticate(h.ab, w, r); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
if len(h.cookies.ClientValues[authboss.CookieRemember]) != 0 {
|
||||
t.Error("there should be no remember cookie left")
|
||||
}
|
||||
|
||||
if len(h.session.ClientValues[authboss.SessionKey]) != 0 {
|
||||
t.Error("it should have not logged the user in")
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("base64", func(t *testing.T) {
|
||||
h.cookies.ClientValues[authboss.CookieRemember] = "a"
|
||||
doTest(t)
|
||||
})
|
||||
t.Run("cookieformat", func(t *testing.T) {
|
||||
h.cookies.ClientValues[authboss.CookieRemember] = `aGVsbG8=` // hello
|
||||
doTest(t)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAfterPasswordReset(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r := Remember{authboss.New()}
|
||||
h := testSetup()
|
||||
|
||||
id := "test@email.com"
|
||||
user := &mocks.User{Email: "test@test.com"}
|
||||
hash1, _, _ := GenerateToken(user.Email)
|
||||
hash2, token2, _ := GenerateToken(user.Email)
|
||||
|
||||
storer := mocks.NewMockStorer()
|
||||
r.Storer = storer
|
||||
session := mocks.NewMockClientStorer()
|
||||
cookies := mocks.NewMockClientStorer()
|
||||
storer.Tokens[id] = []string{"one", "two"}
|
||||
cookies.Values[authboss.CookieRemember] = "token"
|
||||
h.storer.Users[user.Email] = user
|
||||
h.storer.RMTokens[user.Email] = []string{hash1, hash2}
|
||||
h.cookies.ClientValues[authboss.CookieRemember] = token2
|
||||
|
||||
ctx := r.NewContext()
|
||||
ctx.User = authboss.Attributes{r.PrimaryID: id}
|
||||
ctx.SessionStorer = session
|
||||
ctx.CookieStorer = cookies
|
||||
r := mocks.Request("POST")
|
||||
r = r.WithContext(context.WithValue(r.Context(), authboss.CTXKeyUser, user))
|
||||
rec := httptest.NewRecorder()
|
||||
w := h.ab.NewResponse(rec, r)
|
||||
|
||||
if err := r.afterPassword(ctx); err != nil {
|
||||
if handled, err := h.remember.AfterPasswordReset(w, r, false); err != nil {
|
||||
t.Error(err)
|
||||
} else if handled {
|
||||
t.Error("it should never be handled")
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK) // Force header flush
|
||||
|
||||
if len(h.storer.RMTokens[user.Email]) != 0 {
|
||||
t.Error("all remember me tokens should have been removed")
|
||||
}
|
||||
if len(h.cookies.ClientValues[authboss.CookieRemember]) != 0 {
|
||||
t.Error("there should be no remember cookie left")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
hash, tok, err := GenerateToken("test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
rawToken, err := base64.URLEncoding.DecodeString(tok)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if _, ok := cookies.Values[authboss.CookieRemember]; ok {
|
||||
t.Error("Expected the remember cookie to be deleted.")
|
||||
index := bytes.IndexByte(rawToken, ';')
|
||||
if index < 0 {
|
||||
t.Fatalf("problem with the token format: %v", rawToken)
|
||||
}
|
||||
|
||||
if len(storer.Tokens) != 0 {
|
||||
t.Error("Should have wiped out all tokens.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r := &Remember{authboss.New()}
|
||||
storer := mocks.NewMockStorer()
|
||||
r.Storer = storer
|
||||
cookies := mocks.NewMockClientStorer()
|
||||
|
||||
key := "tester"
|
||||
token, err := r.new(cookies, key)
|
||||
|
||||
if err != nil {
|
||||
t.Error("Unexpected error:", err)
|
||||
}
|
||||
|
||||
if len(token) == 0 {
|
||||
t.Error("Expected a token.")
|
||||
}
|
||||
|
||||
if tok, ok := storer.Tokens[key]; !ok {
|
||||
t.Error("Expected it to store against the key:", key)
|
||||
} else if len(tok) != 1 || len(tok[0]) == 0 {
|
||||
t.Error("Expected a token to be saved.")
|
||||
}
|
||||
|
||||
if token != cookies.Values[authboss.CookieRemember] {
|
||||
t.Error("Expected a cookie set with the token.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuth(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r := &Remember{authboss.New()}
|
||||
storer := mocks.NewMockStorer()
|
||||
r.Storer = storer
|
||||
|
||||
cookies := mocks.NewMockClientStorer()
|
||||
session := mocks.NewMockClientStorer()
|
||||
ctx := r.NewContext()
|
||||
ctx.CookieStorer = cookies
|
||||
ctx.SessionStorer = session
|
||||
|
||||
key := "tester"
|
||||
_, err := r.new(cookies, key)
|
||||
if err != nil {
|
||||
t.Error("Unexpected error:", err)
|
||||
}
|
||||
|
||||
cookie, _ := cookies.Get(authboss.CookieRemember)
|
||||
|
||||
interrupt, err := r.auth(ctx)
|
||||
if err != nil {
|
||||
t.Error("Unexpected error:", err)
|
||||
}
|
||||
|
||||
if session.Values[authboss.SessionHalfAuthKey] != "true" {
|
||||
t.Error("The user should have been half-authed.")
|
||||
}
|
||||
|
||||
if session.Values[authboss.SessionKey] != key {
|
||||
t.Error("The user should have been logged in.")
|
||||
}
|
||||
|
||||
if chocolateChip, _ := cookies.Get(authboss.CookieRemember); chocolateChip == cookie {
|
||||
t.Error("Expected cookie to be different")
|
||||
}
|
||||
|
||||
if authboss.InterruptNone != interrupt {
|
||||
t.Error("Keys should have matched:", interrupt)
|
||||
bytPID := rawToken[:index]
|
||||
if string(bytPID) != "test" {
|
||||
t.Errorf("pid wrong: %s", bytPID)
|
||||
}
|
||||
|
||||
sum := sha512.Sum512(rawToken)
|
||||
gotHash := base64.StdEncoding.EncodeToString(sum[:])
|
||||
if hash != gotHash {
|
||||
t.Errorf("hash wrong, want: %s, got: %s", hash, gotHash)
|
||||
}
|
||||
}
|
||||
|
33
storage.go
33
storage.go
@ -30,13 +30,13 @@ const (
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrUserFound should be returned from Create (see ConfirmUser) when the primaryID
|
||||
// of the record is found.
|
||||
ErrUserFound = errors.New("user found")
|
||||
// ErrUserNotFound should be returned from Get when the record is not found.
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
// ErrTokenNotFound should be returned from UseToken when the record is not found.
|
||||
ErrTokenNotFound = errors.New("token not found")
|
||||
// ErrUserFound should be returned from Create (see ConfirmUser) when the primaryID
|
||||
// of the record is found.
|
||||
ErrUserFound = errors.New("user found")
|
||||
)
|
||||
|
||||
// ServerStorer represents the data store that's capable of loading users
|
||||
@ -82,11 +82,22 @@ type RecoveringServerStorer interface {
|
||||
LoadByRecoverToken(ctx context.Context, token string) (RecoverableUser, error)
|
||||
}
|
||||
|
||||
// RememberingServerStorer allows users to be remembered across sessions
|
||||
type RememberingServerStorer interface {
|
||||
// AddRememberToken to a user
|
||||
AddRememberToken(pid, token string) error
|
||||
// DelRememberTokens removes all tokens for the given pid
|
||||
DelRememberTokens(pid string) error
|
||||
// UseRememberToken finds the pid-token pair and deletes it.
|
||||
// If the token could not be found return ErrTokenNotFound
|
||||
UseRememberToken(pid, token string) error
|
||||
}
|
||||
|
||||
// EnsureCanCreate makes sure the server storer supports create operations
|
||||
func EnsureCanCreate(storer ServerStorer) CreatingServerStorer {
|
||||
s, ok := storer.(CreatingServerStorer)
|
||||
if !ok {
|
||||
panic("could not upgrade serverstorer to creatingserverstorer, check your struct")
|
||||
panic("could not upgrade ServerStorer to CreatingServerStorer, check your struct")
|
||||
}
|
||||
|
||||
return s
|
||||
@ -96,7 +107,7 @@ func EnsureCanCreate(storer ServerStorer) CreatingServerStorer {
|
||||
func EnsureCanConfirm(storer ServerStorer) ConfirmingServerStorer {
|
||||
s, ok := storer.(ConfirmingServerStorer)
|
||||
if !ok {
|
||||
panic("could not upgrade serverstorer to confirmingserverstorer, check your struct")
|
||||
panic("could not upgrade ServerStorer to ConfirmingServerStorer, check your struct")
|
||||
}
|
||||
|
||||
return s
|
||||
@ -106,7 +117,17 @@ func EnsureCanConfirm(storer ServerStorer) ConfirmingServerStorer {
|
||||
func EnsureCanRecover(storer ServerStorer) RecoveringServerStorer {
|
||||
s, ok := storer.(RecoveringServerStorer)
|
||||
if !ok {
|
||||
panic("could not upgrade serverstorer to recoveringserverstorer, check your struct")
|
||||
panic("could not upgrade ServerStorer to RecoveringServerStorer, check your struct")
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// EnsureCanRemember makes sure the server storer supports remember operations
|
||||
func EnsureCanRemember(storer ServerStorer) RememberingServerStorer {
|
||||
s, ok := storer.(RememberingServerStorer)
|
||||
if !ok {
|
||||
panic("could not upgrade ServerStorer to RememberingServerStorer, check your struct")
|
||||
}
|
||||
|
||||
return s
|
||||
|
Loading…
Reference in New Issue
Block a user