mirror of
https://github.com/volatiletech/authboss.git
synced 2024-11-28 08:58:38 +02:00
Finish otp module
This commit is contained in:
parent
b7cec028b9
commit
6164dd8da4
15
CHANGELOG.md
15
CHANGELOG.md
@ -7,9 +7,24 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
|
||||
|
||||
### Added
|
||||
|
||||
- Add OTP module so users can create one time passwords and use them to
|
||||
log in.
|
||||
- Add more documentation about how RegisterPreserveFields works so people
|
||||
don't have to chase the godocs to figure out how to implement it.
|
||||
|
||||
### Changed
|
||||
|
||||
- authboss.Middleware now has two boolean flags to provide more control over
|
||||
how unathenticated users are dealt with. It can now redirect users to
|
||||
the login screen with a redirect to the page they were attempting to reach
|
||||
and it can also protect against half-authed users.
|
||||
|
||||
### Fixed
|
||||
|
||||
- Ensure all uses of crypto/rand.Read are replaced by io.ReadFull(rand.Reader)
|
||||
to ensure that we never get a read that's full of zeroes. This was a bug
|
||||
present in a uuid library, we don't want to make the same mistake.
|
||||
|
||||
## [2.0.0-rc5] - 2018-07-04
|
||||
|
||||
### Changed
|
||||
|
206
authboss_test.go
206
authboss_test.go
@ -35,64 +35,176 @@ func TestAuthbossUpdatePassword(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
type testRedirector struct {
|
||||
Opts RedirectOptions
|
||||
}
|
||||
|
||||
func (r *testRedirector) Redirect(w http.ResponseWriter, req *http.Request, ro RedirectOptions) error {
|
||||
r.Opts = ro
|
||||
if len(ro.RedirectPath) == 0 {
|
||||
panic("no redirect path on redirect call")
|
||||
}
|
||||
http.Redirect(w, req, ro.RedirectPath, ro.Code)
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestAuthbossMiddleware(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := New()
|
||||
ab.Core.Logger = mockLogger{}
|
||||
|
||||
mid := Middleware(ab)
|
||||
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
w := ab.NewResponse(rec)
|
||||
|
||||
called := false
|
||||
hadUser := false
|
||||
server := mid(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
called = true
|
||||
hadUser = r.Context().Value(CTXKeyUser) != nil
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
var err error
|
||||
r, err = ab.LoadClientState(w, r)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
server.ServeHTTP(w, r)
|
||||
if called || hadUser {
|
||||
t.Error("should not be called or have a user when no session variables have been provided")
|
||||
}
|
||||
if rec.Code != http.StatusNotFound {
|
||||
t.Error("want a not found code")
|
||||
}
|
||||
|
||||
ab.Storage.SessionState = mockClientStateReadWriter{
|
||||
state: mockClientState{SessionKey: "test@test.com"},
|
||||
}
|
||||
ab.Storage.Server = &mockServerStorer{
|
||||
Users: map[string]*mockUser{
|
||||
"test@test.com": &mockUser{},
|
||||
},
|
||||
}
|
||||
|
||||
r = httptest.NewRequest("GET", "/", nil)
|
||||
rec = httptest.NewRecorder()
|
||||
w = ab.NewResponse(rec)
|
||||
setupMore := func(redirect bool, allowHalfAuth bool) (*httptest.ResponseRecorder, bool, bool) {
|
||||
r := httptest.NewRequest("GET", "/super/secret", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
w := ab.NewResponse(rec)
|
||||
|
||||
r, err = ab.LoadClientState(w, r)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
server.ServeHTTP(w, r)
|
||||
if !called {
|
||||
t.Error("it should have been called")
|
||||
}
|
||||
if !hadUser {
|
||||
t.Error("it should have had a user loaded")
|
||||
}
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Error("want a not found code")
|
||||
var err error
|
||||
r, err = ab.LoadClientState(w, r)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
mid := Middleware(ab, redirect, allowHalfAuth)
|
||||
var called, hadUser bool
|
||||
server := mid(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
called = true
|
||||
hadUser = r.Context().Value(CTXKeyUser) != nil
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
server.ServeHTTP(w, r)
|
||||
|
||||
return rec, called, hadUser
|
||||
}
|
||||
|
||||
t.Run("Accept", func(t *testing.T) {
|
||||
ab.Storage.SessionState = mockClientStateReadWriter{
|
||||
state: mockClientState{SessionKey: "test@test.com"},
|
||||
}
|
||||
|
||||
_, called, hadUser := setupMore(false, false)
|
||||
|
||||
if !called {
|
||||
t.Error("should have been called")
|
||||
}
|
||||
if !hadUser {
|
||||
t.Error("should have had user")
|
||||
}
|
||||
})
|
||||
t.Run("AcceptHalfAuth", func(t *testing.T) {
|
||||
ab.Storage.SessionState = mockClientStateReadWriter{
|
||||
state: mockClientState{SessionKey: "test@test.com", SessionHalfAuthKey: "true"},
|
||||
}
|
||||
|
||||
_, called, hadUser := setupMore(false, true)
|
||||
|
||||
if !called {
|
||||
t.Error("should have been called")
|
||||
}
|
||||
if !hadUser {
|
||||
t.Error("should have had user")
|
||||
}
|
||||
})
|
||||
t.Run("Reject404", func(t *testing.T) {
|
||||
ab.Storage.SessionState = mockClientStateReadWriter{}
|
||||
|
||||
rec, called, hadUser := setupMore(false, false)
|
||||
|
||||
if rec.Code != http.StatusNotFound {
|
||||
t.Error("wrong code:", rec.Code)
|
||||
}
|
||||
if called {
|
||||
t.Error("should not have been called")
|
||||
}
|
||||
if hadUser {
|
||||
t.Error("should not have had user")
|
||||
}
|
||||
})
|
||||
t.Run("RejectRedirect", func(t *testing.T) {
|
||||
redir := &testRedirector{}
|
||||
ab.Config.Core.Redirector = redir
|
||||
|
||||
ab.Storage.SessionState = mockClientStateReadWriter{}
|
||||
|
||||
_, called, hadUser := setupMore(true, false)
|
||||
|
||||
if redir.Opts.Code != http.StatusTemporaryRedirect {
|
||||
t.Error("code was wrong:", redir.Opts.Code)
|
||||
}
|
||||
if redir.Opts.RedirectPath != "/auth/login?redir=%2Fsuper%2Fsecret" {
|
||||
t.Error("redirect path was wrong:", redir.Opts.RedirectPath)
|
||||
}
|
||||
if called {
|
||||
t.Error("should not have been called")
|
||||
}
|
||||
if hadUser {
|
||||
t.Error("should not have had user")
|
||||
}
|
||||
})
|
||||
t.Run("RejectHalfAuth", func(t *testing.T) {
|
||||
ab.Storage.SessionState = mockClientStateReadWriter{
|
||||
state: mockClientState{SessionKey: "test@test.com", SessionHalfAuthKey: "true"},
|
||||
}
|
||||
|
||||
rec, called, hadUser := setupMore(false, false)
|
||||
|
||||
if rec.Code != http.StatusNotFound {
|
||||
t.Error("wrong code:", rec.Code)
|
||||
}
|
||||
if called {
|
||||
t.Error("should not have been called")
|
||||
}
|
||||
if hadUser {
|
||||
t.Error("should not have had user")
|
||||
}
|
||||
})
|
||||
|
||||
/*
|
||||
var err error
|
||||
r, err = ab.LoadClientState(w, r)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
server.ServeHTTP(w, r)
|
||||
if called || hadUser {
|
||||
t.Error("should not be called or have a user when no session variables have been provided")
|
||||
}
|
||||
if rec.Code != http.StatusNotFound {
|
||||
t.Error("want a not found code")
|
||||
}
|
||||
|
||||
ab.Storage.SessionState = mockClientStateReadWriter{
|
||||
state: mockClientState{SessionKey: "test@test.com"},
|
||||
}
|
||||
ab.Storage.Server = &mockServerStorer{
|
||||
Users: map[string]*mockUser{
|
||||
"test@test.com": &mockUser{},
|
||||
},
|
||||
}
|
||||
|
||||
r = httptest.NewRequest("GET", "/", nil)
|
||||
rec = httptest.NewRecorder()
|
||||
w = ab.NewResponse(rec)
|
||||
|
||||
r, err = ab.LoadClientState(w, r)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
server.ServeHTTP(w, r)
|
||||
if !called {
|
||||
t.Error("it should have been called")
|
||||
}
|
||||
if !hadUser {
|
||||
t.Error("it should have had a user loaded")
|
||||
}
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Error("want a not found code")
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
@ -22,7 +22,7 @@ func SetCore(config *authboss.Config, readJSON, useUsername bool) {
|
||||
config.Core.Router = NewRouter()
|
||||
config.Core.ErrorHandler = NewErrorHandler(logger)
|
||||
config.Core.Responder = NewResponder(config.Core.ViewRenderer)
|
||||
config.Core.Redirector = NewRedirector(config.Core.ViewRenderer, RedirectFormValueName)
|
||||
config.Core.Redirector = NewRedirector(config.Core.ViewRenderer, authboss.FormValueRedirect)
|
||||
config.Core.BodyReader = NewHTTPBodyReader(readJSON, useUsername)
|
||||
config.Core.Mailer = NewLogMailer(os.Stdout)
|
||||
config.Core.Logger = logger
|
||||
|
@ -34,6 +34,8 @@ type User struct {
|
||||
OAuth2Refresh string
|
||||
OAuth2Expiry time.Time
|
||||
|
||||
OTPs string
|
||||
|
||||
Arbitrary map[string]string
|
||||
}
|
||||
|
||||
@ -97,6 +99,9 @@ func (u User) GetOAuth2Expiry() time.Time { return u.OAuth2Expiry }
|
||||
// GetArbitrary from user
|
||||
func (u User) GetArbitrary() map[string]string { return u.Arbitrary }
|
||||
|
||||
// GetOTPs from user
|
||||
func (u User) GetOTPs() string { return u.OTPs }
|
||||
|
||||
// PutPID into user
|
||||
func (u *User) PutPID(email string) { u.Email = email }
|
||||
|
||||
@ -156,6 +161,9 @@ func (u *User) PutOAuth2Expiry(expiry time.Time) { u.OAuth2Expiry = expiry }
|
||||
// PutArbitrary into user
|
||||
func (u *User) PutArbitrary(arb map[string]string) { u.Arbitrary = arb }
|
||||
|
||||
// PutOTPs into user
|
||||
func (u *User) PutOTPs(otps string) { u.OTPs = otps }
|
||||
|
||||
// ServerStorer should be valid for any module storer defined in authboss.
|
||||
type ServerStorer struct {
|
||||
Users map[string]*User
|
||||
|
123
otp/otp.go
123
otp/otp.go
@ -4,31 +4,40 @@ package otp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha512"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/volatiletech/authboss"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
const (
|
||||
otpSize = 16
|
||||
|
||||
// PageLogin is for identifying the login page for parsing & validation
|
||||
PageLogin = "loginotp"
|
||||
PageLogin = "otplogin"
|
||||
// PageAdd is for adding an otp to the user
|
||||
PageAdd = "addotp"
|
||||
PageAdd = "otpadd"
|
||||
// PageClear is for deleting all the otps from the user
|
||||
PageClear = "clearotp"
|
||||
PageClear = "otpclear"
|
||||
|
||||
// DataNumberOTPs shows the number of otps for add/clear operations
|
||||
DataNumberOTPs = "notps"
|
||||
// DataNewOTP shows the new otp that was added
|
||||
DataNumberOTPs = "otp_count"
|
||||
// DataOTP shows the new otp that was added
|
||||
DataOTP = "otp"
|
||||
)
|
||||
|
||||
// User for one time passwords
|
||||
type User interface {
|
||||
authboss.User
|
||||
|
||||
// GetOTPs retrieves a string of comma separated bcrypt'd one time passwords
|
||||
GetOTPs() string
|
||||
// PutOTPs puts a string of comma separated bcrypt'd one time passwords
|
||||
@ -39,7 +48,7 @@ type User interface {
|
||||
func MustBeOTPable(user authboss.User) User {
|
||||
u, ok := user.(User)
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("could not upgrade user to an authable user, type: %T", u))
|
||||
panic(fmt.Sprintf("could not upgrade user to an otpable user, type: %T", u))
|
||||
}
|
||||
|
||||
return u
|
||||
@ -105,21 +114,25 @@ func (o *OTP) LoginPost(w http.ResponseWriter, r *http.Request) error {
|
||||
}
|
||||
|
||||
otpUser := MustBeOTPable(pidUser)
|
||||
passwords := decodeOTPs(otpUser.GetOTPs())
|
||||
passwords := splitOTPs(otpUser.GetOTPs())
|
||||
|
||||
r = r.WithContext(context.WithValue(r.Context(), authboss.CTXKeyUser, pidUser))
|
||||
|
||||
input := creds.GetPassword()
|
||||
inputSum := sha512.Sum512([]byte(creds.GetPassword()))
|
||||
matchPassword := -1
|
||||
handled := false
|
||||
for i, p := range passwords {
|
||||
err = bcrypt.CompareHashAndPassword([]byte(p), []byte(input))
|
||||
if err == nil {
|
||||
dbSum, err := base64.StdEncoding.DecodeString(p)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "otp in database was not valid base64")
|
||||
}
|
||||
|
||||
if 1 == subtle.ConstantTimeCompare(inputSum[:], dbSum) {
|
||||
matchPassword = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
handled := false
|
||||
if matchPassword < 0 {
|
||||
handled, err = o.Authboss.Events.FireAfter(authboss.EventAuthFail, w, r)
|
||||
if err != nil {
|
||||
@ -133,9 +146,10 @@ func (o *OTP) LoginPost(w http.ResponseWriter, r *http.Request) error {
|
||||
return o.Authboss.Core.Responder.Respond(w, r, http.StatusOK, PageLogin, data)
|
||||
}
|
||||
|
||||
logger.Infof("removing otp password from %s", pid)
|
||||
passwords[matchPassword] = passwords[len(passwords)-1]
|
||||
passwords = passwords[:len(passwords)-1]
|
||||
otpUser.PutOTPs(encodeOTPs(passwords))
|
||||
otpUser.PutOTPs(joinOTPs(passwords))
|
||||
if err = o.Authboss.Config.Storage.Server.Save(r.Context(), pidUser); err != nil {
|
||||
return err
|
||||
}
|
||||
@ -170,17 +184,7 @@ func (o *OTP) LoginPost(w http.ResponseWriter, r *http.Request) error {
|
||||
|
||||
// AddGet shows how many passwords exist and allows the user to create a new one
|
||||
func (o *OTP) AddGet(w http.ResponseWriter, r *http.Request) error {
|
||||
logger := o.RequestLogger(r)
|
||||
|
||||
user, err := o.Authboss.CurrentUser(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
otpUser := MustBeOTPable(user)
|
||||
ln := strconv.Itoa(len(decodeOTPs(otpUser.GetOTPs())))
|
||||
|
||||
return o.Core.Responder.Respond(w, r, http.StatusOK, PageAdd, authboss.HTMLData{NumberOTPS: ln})
|
||||
return o.showOTPCount(w, r, PageAdd)
|
||||
}
|
||||
|
||||
// AddPost adds a new password to the user and displays it
|
||||
@ -192,14 +196,16 @@ func (o *OTP) AddPost(w http.ResponseWriter, r *http.Request) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// GENERATE AN OTP
|
||||
panic("otp not generated")
|
||||
otp := ""
|
||||
logger.Infof("generating otp for %s", user.GetPID())
|
||||
otp, hash, err := generateOTP()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
otpUser := MustBeOTPable(user)
|
||||
currentOTPs := decodeOTPs(otpUser.GetOTPs())
|
||||
currentOTPs = append(currentOTPs, otp)
|
||||
otpUser.PutOTPs(encodeOTPs(currentOTPs))
|
||||
currentOTPs := splitOTPs(otpUser.GetOTPs())
|
||||
currentOTPs = append(currentOTPs, hash)
|
||||
otpUser.PutOTPs(joinOTPs(currentOTPs))
|
||||
|
||||
if err := o.Authboss.Config.Storage.Server.Save(r.Context(), user); err != nil {
|
||||
return err
|
||||
@ -210,23 +216,70 @@ func (o *OTP) AddPost(w http.ResponseWriter, r *http.Request) error {
|
||||
|
||||
// ClearGet shows how many passwords exist and allows the user to clear them all
|
||||
func (o *OTP) ClearGet(w http.ResponseWriter, r *http.Request) error {
|
||||
return o.Core.Responder.Respond(w, r, http.StatusOK, PageClear, nil)
|
||||
return o.showOTPCount(w, r, PageClear)
|
||||
}
|
||||
|
||||
// ClearPost clears all otps that are stored for the user.
|
||||
func (o *OTP) ClearPost(w http.ResponseWriter, r *http.Request) error {
|
||||
panic("not implemented")
|
||||
return nil
|
||||
logger := o.RequestLogger(r)
|
||||
|
||||
user, err := o.Authboss.CurrentUser(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
logger.Infof("clearing all otps for user: %s", user.GetPID())
|
||||
otpUser := MustBeOTPable(user)
|
||||
otpUser.PutOTPs("")
|
||||
|
||||
if err := o.Authboss.Config.Storage.Server.Save(r.Context(), user); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return o.Core.Responder.Respond(w, r, http.StatusOK, PageAdd, authboss.HTMLData{DataNumberOTPs: "0"})
|
||||
}
|
||||
|
||||
func encodeOTPs(otps []string) string {
|
||||
func (o *OTP) showOTPCount(w http.ResponseWriter, r *http.Request, page string) error {
|
||||
user, err := o.Authboss.CurrentUser(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
otpUser := MustBeOTPable(user)
|
||||
ln := strconv.Itoa(len(splitOTPs(otpUser.GetOTPs())))
|
||||
|
||||
return o.Core.Responder.Respond(w, r, http.StatusOK, page, authboss.HTMLData{DataNumberOTPs: ln})
|
||||
}
|
||||
|
||||
func joinOTPs(otps []string) string {
|
||||
return strings.Join(otps, ",")
|
||||
}
|
||||
|
||||
func decodeOTPs(otps string) []string {
|
||||
func splitOTPs(otps string) []string {
|
||||
if len(otps) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return strings.Split(otps, ",")
|
||||
}
|
||||
|
||||
func generateOTP() (otp string, hash string, err error) {
|
||||
secret := make([]byte, otpSize)
|
||||
if _, err = io.ReadFull(rand.Reader, secret); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
otp = fmt.Sprintf("%x-%x-%x-%x",
|
||||
secret[0:4],
|
||||
secret[4:8],
|
||||
secret[8:12],
|
||||
secret[12:16],
|
||||
)
|
||||
|
||||
sum := sha512.Sum512([]byte(otp))
|
||||
encoded := make([]byte, base64.StdEncoding.EncodedLen(sha512.Size))
|
||||
base64.StdEncoding.Encode(encoded, sum[:])
|
||||
hash = string(encoded)
|
||||
|
||||
return otp, hash, nil
|
||||
}
|
||||
|
589
otp/otp_test.go
Normal file
589
otp/otp_test.go
Normal file
@ -0,0 +1,589 @@
|
||||
package otp
|
||||
|
||||
import (
|
||||
"crypto/sha512"
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/volatiletech/authboss"
|
||||
"github.com/volatiletech/authboss/internal/mocks"
|
||||
)
|
||||
|
||||
type testUser struct {
|
||||
PID string
|
||||
OTPs string
|
||||
}
|
||||
|
||||
func (t *testUser) GetPID() string { return t.PID }
|
||||
func (t *testUser) PutPID(pid string) { t.PID = pid }
|
||||
func (t *testUser) GetOTPs() string { return t.OTPs }
|
||||
func (t *testUser) PutOTPs(otps string) { t.OTPs = otps }
|
||||
|
||||
func TestMustBeOTPable(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var user authboss.User = &testUser{}
|
||||
_ = MustBeOTPable(user)
|
||||
}
|
||||
|
||||
func TestInit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := authboss.New()
|
||||
router := &mocks.Router{}
|
||||
renderer := &mocks.Renderer{}
|
||||
errHandler := &mocks.ErrorHandler{}
|
||||
|
||||
ab.Config.Core.Router = router
|
||||
ab.Config.Core.ViewRenderer = renderer
|
||||
ab.Config.Core.ErrorHandler = errHandler
|
||||
|
||||
o := &OTP{}
|
||||
if err := o.Init(ab); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
routes := []string{"/otp/login", "/otp/add", "/otp/clear"}
|
||||
if err := router.HasGets(routes...); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := router.HasPosts(routes...); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoginGet(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := authboss.New()
|
||||
responder := &mocks.Responder{}
|
||||
ab.Config.Core.Responder = responder
|
||||
|
||||
a := &OTP{ab}
|
||||
a.LoginGet(nil, nil)
|
||||
|
||||
if responder.Page != PageLogin {
|
||||
t.Error("wanted login page, got:", responder.Page)
|
||||
}
|
||||
|
||||
if responder.Status != http.StatusOK {
|
||||
t.Error("wanted ok status, got:", responder.Status)
|
||||
}
|
||||
}
|
||||
|
||||
type testHarness struct {
|
||||
otp *OTP
|
||||
ab *authboss.Authboss
|
||||
|
||||
bodyReader *mocks.BodyReader
|
||||
responder *mocks.Responder
|
||||
redirector *mocks.Redirector
|
||||
session *mocks.ClientStateRW
|
||||
storer *mocks.ServerStorer
|
||||
}
|
||||
|
||||
func testSetup() *testHarness {
|
||||
harness := &testHarness{}
|
||||
|
||||
harness.ab = authboss.New()
|
||||
harness.bodyReader = &mocks.BodyReader{}
|
||||
harness.redirector = &mocks.Redirector{}
|
||||
harness.responder = &mocks.Responder{}
|
||||
harness.session = mocks.NewClientRW()
|
||||
harness.storer = mocks.NewServerStorer()
|
||||
|
||||
harness.ab.Config.Paths.AuthLoginOK = "/login/ok"
|
||||
|
||||
harness.ab.Config.Core.BodyReader = harness.bodyReader
|
||||
harness.ab.Config.Core.Logger = mocks.Logger{}
|
||||
harness.ab.Config.Core.Responder = harness.responder
|
||||
harness.ab.Config.Core.Redirector = harness.redirector
|
||||
harness.ab.Config.Storage.SessionState = harness.session
|
||||
harness.ab.Config.Storage.Server = harness.storer
|
||||
|
||||
harness.otp = &OTP{harness.ab}
|
||||
|
||||
return harness
|
||||
}
|
||||
|
||||
func TestLoginPostSuccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setupMore := func(h *testHarness) *testHarness {
|
||||
h.bodyReader.Return = mocks.Values{
|
||||
PID: "test@test.com",
|
||||
Password: "3cc94671-958a912d-bd5a3ba7-3326a380",
|
||||
}
|
||||
h.storer.Users["test@test.com"] = &mocks.User{
|
||||
Email: "test@test.com",
|
||||
// 3cc94671-958a912d-bd5a3ba7-3326a380
|
||||
OTPs: "2aID,2aIDHxmTIy1W7Uyz9c+iqhOJSE0a2Yna3zTRTs2q/X7Bv3xdVjExoztBEG4sQ2Nn3jcaPxdIuhslvSsjaYK5uA==",
|
||||
}
|
||||
h.session.ClientValues[authboss.SessionHalfAuthKey] = "true"
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
t.Run("normal", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
h := setupMore(testSetup())
|
||||
|
||||
var beforeCalled, afterCalled bool
|
||||
var beforeHasValues, afterHasValues bool
|
||||
h.ab.Events.Before(authboss.EventAuth, func(w http.ResponseWriter, r *http.Request, handled bool) (bool, error) {
|
||||
beforeCalled = true
|
||||
beforeHasValues = r.Context().Value(authboss.CTXKeyValues) != nil
|
||||
return false, nil
|
||||
})
|
||||
h.ab.Events.After(authboss.EventAuth, func(w http.ResponseWriter, r *http.Request, handled bool) (bool, error) {
|
||||
afterCalled = true
|
||||
afterHasValues = r.Context().Value(authboss.CTXKeyValues) != nil
|
||||
return false, nil
|
||||
})
|
||||
|
||||
r := mocks.Request("POST")
|
||||
resp := httptest.NewRecorder()
|
||||
w := h.ab.NewResponse(resp)
|
||||
|
||||
if err := h.otp.LoginPost(w, r); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if resp.Code != http.StatusTemporaryRedirect {
|
||||
t.Error("code was wrong:", resp.Code)
|
||||
}
|
||||
if h.redirector.Options.RedirectPath != "/login/ok" {
|
||||
t.Error("redirect path was wrong:", h.redirector.Options.RedirectPath)
|
||||
}
|
||||
|
||||
if _, ok := h.session.ClientValues[authboss.SessionHalfAuthKey]; ok {
|
||||
t.Error("half auth should have been deleted")
|
||||
}
|
||||
if pid := h.session.ClientValues[authboss.SessionKey]; pid != "test@test.com" {
|
||||
t.Error("pid was wrong:", pid)
|
||||
}
|
||||
|
||||
// Remaining length of the chunk of base64 is 4 characters
|
||||
if len(h.storer.Users["test@test.com"].OTPs) != 4 {
|
||||
t.Error("the user should have used one of his OTPs")
|
||||
}
|
||||
|
||||
if !beforeCalled {
|
||||
t.Error("before should have been called")
|
||||
}
|
||||
if !afterCalled {
|
||||
t.Error("after should have been called")
|
||||
}
|
||||
if !beforeHasValues {
|
||||
t.Error("before callback should have access to values")
|
||||
}
|
||||
if !afterHasValues {
|
||||
t.Error("after callback should have access to values")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handledBefore", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
h := setupMore(testSetup())
|
||||
|
||||
var beforeCalled bool
|
||||
h.ab.Events.Before(authboss.EventAuth, func(w http.ResponseWriter, r *http.Request, handled bool) (bool, error) {
|
||||
w.WriteHeader(http.StatusTeapot)
|
||||
beforeCalled = true
|
||||
return true, nil
|
||||
})
|
||||
|
||||
r := mocks.Request("POST")
|
||||
resp := httptest.NewRecorder()
|
||||
w := h.ab.NewResponse(resp)
|
||||
|
||||
if err := h.otp.LoginPost(w, r); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if h.responder.Status != 0 {
|
||||
t.Error("a status should never have been sent back")
|
||||
}
|
||||
if _, ok := h.session.ClientValues[authboss.SessionKey]; ok {
|
||||
t.Error("session key should not have been set")
|
||||
}
|
||||
|
||||
if !beforeCalled {
|
||||
t.Error("before should have been called")
|
||||
}
|
||||
if resp.Code != http.StatusTeapot {
|
||||
t.Error("should have left the response alone once teapot was sent")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handledAfter", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
h := setupMore(testSetup())
|
||||
|
||||
var afterCalled bool
|
||||
h.ab.Events.After(authboss.EventAuth, func(w http.ResponseWriter, r *http.Request, handled bool) (bool, error) {
|
||||
w.WriteHeader(http.StatusTeapot)
|
||||
afterCalled = true
|
||||
return true, nil
|
||||
})
|
||||
|
||||
r := mocks.Request("POST")
|
||||
resp := httptest.NewRecorder()
|
||||
w := h.ab.NewResponse(resp)
|
||||
|
||||
if err := h.otp.LoginPost(w, r); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if h.responder.Status != 0 {
|
||||
t.Error("a status should never have been sent back")
|
||||
}
|
||||
if _, ok := h.session.ClientValues[authboss.SessionKey]; !ok {
|
||||
t.Error("session key should have been set")
|
||||
}
|
||||
|
||||
if !afterCalled {
|
||||
t.Error("after should have been called")
|
||||
}
|
||||
if resp.Code != http.StatusTeapot {
|
||||
t.Error("should have left the response alone once teapot was sent")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestLoginPostBadPassword(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setupMore := func(h *testHarness) *testHarness {
|
||||
h.bodyReader.Return = mocks.Values{
|
||||
PID: "test@test.com",
|
||||
Password: "nope",
|
||||
}
|
||||
h.storer.Users["test@test.com"] = &mocks.User{
|
||||
Email: "test@test.com",
|
||||
Password: "", // hello world
|
||||
}
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
t.Run("normal", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
h := setupMore(testSetup())
|
||||
|
||||
r := mocks.Request("POST")
|
||||
resp := httptest.NewRecorder()
|
||||
w := h.ab.NewResponse(resp)
|
||||
|
||||
var afterCalled bool
|
||||
h.ab.Events.After(authboss.EventAuthFail, func(w http.ResponseWriter, r *http.Request, handled bool) (bool, error) {
|
||||
afterCalled = true
|
||||
return false, nil
|
||||
})
|
||||
|
||||
if err := h.otp.LoginPost(w, r); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if resp.Code != 200 {
|
||||
t.Error("wanted a 200:", resp.Code)
|
||||
}
|
||||
|
||||
if h.responder.Data[authboss.DataErr] != "Invalid Credentials" {
|
||||
t.Error("wrong error:", h.responder.Data)
|
||||
}
|
||||
|
||||
if _, ok := h.session.ClientValues[authboss.SessionKey]; ok {
|
||||
t.Error("user should not be logged in")
|
||||
}
|
||||
|
||||
if !afterCalled {
|
||||
t.Error("after should have been called")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handledAfter", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
h := setupMore(testSetup())
|
||||
|
||||
r := mocks.Request("POST")
|
||||
resp := httptest.NewRecorder()
|
||||
w := h.ab.NewResponse(resp)
|
||||
|
||||
var afterCalled bool
|
||||
h.ab.Events.After(authboss.EventAuthFail, func(w http.ResponseWriter, r *http.Request, handled bool) (bool, error) {
|
||||
w.WriteHeader(http.StatusTeapot)
|
||||
afterCalled = true
|
||||
return true, nil
|
||||
})
|
||||
|
||||
if err := h.otp.LoginPost(w, r); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if h.responder.Status != 0 {
|
||||
t.Error("responder should not have been called to give a status")
|
||||
}
|
||||
if _, ok := h.session.ClientValues[authboss.SessionKey]; ok {
|
||||
t.Error("user should not be logged in")
|
||||
}
|
||||
|
||||
if !afterCalled {
|
||||
t.Error("after should have been called")
|
||||
}
|
||||
if resp.Code != http.StatusTeapot {
|
||||
t.Error("should have left the response alone once teapot was sent")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthPostUserNotFound(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
harness := testSetup()
|
||||
harness.bodyReader.Return = mocks.Values{
|
||||
PID: "test@test.com",
|
||||
Password: "world hello",
|
||||
}
|
||||
|
||||
r := mocks.Request("POST")
|
||||
resp := httptest.NewRecorder()
|
||||
w := harness.ab.NewResponse(resp)
|
||||
|
||||
// This event is really the only thing that separates "user not found" from "bad password"
|
||||
var afterCalled bool
|
||||
harness.ab.Events.After(authboss.EventAuthFail, func(w http.ResponseWriter, r *http.Request, handled bool) (bool, error) {
|
||||
afterCalled = true
|
||||
return false, nil
|
||||
})
|
||||
|
||||
if err := harness.otp.LoginPost(w, r); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if resp.Code != 200 {
|
||||
t.Error("wanted a 200:", resp.Code)
|
||||
}
|
||||
|
||||
if harness.responder.Data[authboss.DataErr] != "Invalid Credentials" {
|
||||
t.Error("wrong error:", harness.responder.Data)
|
||||
}
|
||||
|
||||
if _, ok := harness.session.ClientValues[authboss.SessionKey]; ok {
|
||||
t.Error("user should not be logged in")
|
||||
}
|
||||
|
||||
if afterCalled {
|
||||
t.Error("after should not have been called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddGet(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
h := testSetup()
|
||||
h.storer.Users["test@test.com"] = &mocks.User{
|
||||
Email: "test@test.com",
|
||||
// 3cc94671-958a912d-bd5a3ba7-3326a380
|
||||
OTPs: "2aID,2aIDHxmTIy1W7Uyz9c+iqhOJSE0a2Yna3zTRTs2q/X7Bv3xdVjExoztBEG4sQ2Nn3jcaPxdIuhslvSsjaYK5uA==",
|
||||
}
|
||||
h.session.ClientValues[authboss.SessionKey] = "test@test.com"
|
||||
|
||||
r := mocks.Request("POST")
|
||||
w := h.ab.NewResponse(httptest.NewRecorder())
|
||||
|
||||
var err error
|
||||
r, err = h.ab.LoadClientState(w, r)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := h.otp.AddGet(w, r); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if h.responder.Page != PageAdd {
|
||||
t.Error("wanted add page, got:", h.responder.Page)
|
||||
}
|
||||
|
||||
if h.responder.Status != http.StatusOK {
|
||||
t.Error("wanted ok status, got:", h.responder.Status)
|
||||
}
|
||||
|
||||
if ln := h.responder.Data[DataNumberOTPs]; ln != "2" {
|
||||
t.Error("want two otps:", ln)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddPost(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
h := testSetup()
|
||||
uname := "test@test.com"
|
||||
h.storer.Users[uname] = &mocks.User{
|
||||
Email: uname,
|
||||
// 3cc94671-958a912d-bd5a3ba7-3326a380
|
||||
OTPs: "2aID,2aIDHxmTIy1W7Uyz9c+iqhOJSE0a2Yna3zTRTs2q/X7Bv3xdVjExoztBEG4sQ2Nn3jcaPxdIuhslvSsjaYK5uA==",
|
||||
}
|
||||
h.session.ClientValues[authboss.SessionKey] = uname
|
||||
|
||||
r := mocks.Request("POST")
|
||||
w := h.ab.NewResponse(httptest.NewRecorder())
|
||||
|
||||
var err error
|
||||
r, err = h.ab.LoadClientState(w, r)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := h.otp.AddPost(w, r); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if h.responder.Page != PageAdd {
|
||||
t.Error("wanted add page, got:", h.responder.Page)
|
||||
}
|
||||
|
||||
if h.responder.Status != http.StatusOK {
|
||||
t.Error("wanted ok status, got:", h.responder.Status)
|
||||
}
|
||||
|
||||
sum := sha512.Sum512([]byte(h.responder.Data[DataOTP].(string)))
|
||||
encoded := base64.StdEncoding.EncodeToString(sum[:])
|
||||
|
||||
otps := splitOTPs(h.storer.Users[uname].OTPs)
|
||||
if len(otps) != 3 || encoded != otps[2] {
|
||||
t.Error("expected one new otp to be appended to the end")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddGetUserNotFound(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
h := testSetup()
|
||||
|
||||
r := mocks.Request("GET")
|
||||
w := h.ab.NewResponse(httptest.NewRecorder())
|
||||
|
||||
if err := h.otp.AddGet(w, r); err != authboss.ErrUserNotFound {
|
||||
t.Error("it should have failed with user not found")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddPostUserNotFound(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
h := testSetup()
|
||||
|
||||
r := mocks.Request("POST")
|
||||
w := h.ab.NewResponse(httptest.NewRecorder())
|
||||
|
||||
if err := h.otp.AddPost(w, r); err != authboss.ErrUserNotFound {
|
||||
t.Error("it should have failed with user not found")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearGet(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
h := testSetup()
|
||||
|
||||
h.storer.Users["test@test.com"] = &mocks.User{
|
||||
Email: "test@test.com",
|
||||
// 3cc94671-958a912d-bd5a3ba7-3326a380
|
||||
OTPs: "2aID,2aIDHxmTIy1W7Uyz9c+iqhOJSE0a2Yna3zTRTs2q/X7Bv3xdVjExoztBEG4sQ2Nn3jcaPxdIuhslvSsjaYK5uA==",
|
||||
}
|
||||
h.session.ClientValues[authboss.SessionKey] = "test@test.com"
|
||||
|
||||
r := mocks.Request("POST")
|
||||
w := h.ab.NewResponse(httptest.NewRecorder())
|
||||
|
||||
var err error
|
||||
r, err = h.ab.LoadClientState(w, r)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := h.otp.ClearGet(w, r); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if h.responder.Page != PageClear {
|
||||
t.Error("wanted clear page, got:", h.responder.Page)
|
||||
}
|
||||
|
||||
if h.responder.Status != http.StatusOK {
|
||||
t.Error("wanted ok status, got:", h.responder.Status)
|
||||
}
|
||||
|
||||
if ln := h.responder.Data[DataNumberOTPs]; ln != "2" {
|
||||
t.Error("want two otps:", ln)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearPost(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
h := testSetup()
|
||||
uname := "test@test.com"
|
||||
h.storer.Users[uname] = &mocks.User{
|
||||
Email: uname,
|
||||
// 3cc94671-958a912d-bd5a3ba7-3326a380
|
||||
OTPs: "2aID,2aIDHxmTIy1W7Uyz9c+iqhOJSE0a2Yna3zTRTs2q/X7Bv3xdVjExoztBEG4sQ2Nn3jcaPxdIuhslvSsjaYK5uA==",
|
||||
}
|
||||
h.session.ClientValues[authboss.SessionKey] = uname
|
||||
|
||||
r := mocks.Request("POST")
|
||||
w := h.ab.NewResponse(httptest.NewRecorder())
|
||||
|
||||
var err error
|
||||
r, err = h.ab.LoadClientState(w, r)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := h.otp.ClearPost(w, r); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if h.responder.Page != PageAdd {
|
||||
t.Error("wanted add page, got:", h.responder.Page)
|
||||
}
|
||||
|
||||
if h.responder.Status != http.StatusOK {
|
||||
t.Error("wanted ok status, got:", h.responder.Status)
|
||||
}
|
||||
|
||||
otps := splitOTPs(h.storer.Users[uname].OTPs)
|
||||
if len(otps) != 0 {
|
||||
t.Error("expected all otps to be gone")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearGetUserNotFound(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
h := testSetup()
|
||||
|
||||
r := mocks.Request("GET")
|
||||
w := h.ab.NewResponse(httptest.NewRecorder())
|
||||
|
||||
if err := h.otp.ClearGet(w, r); err != authboss.ErrUserNotFound {
|
||||
t.Error("it should have failed with user not found")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearPostUserNotFound(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
h := testSetup()
|
||||
|
||||
r := mocks.Request("POST")
|
||||
w := h.ab.NewResponse(httptest.NewRecorder())
|
||||
|
||||
if err := h.otp.AddPost(w, r); err != authboss.ErrUserNotFound {
|
||||
t.Error("it should have failed with user not found")
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user