1
0
mirror of https://github.com/volatiletech/authboss.git synced 2025-01-24 05:17:10 +02:00
authboss/recover/recover_test.go
2018-05-18 06:49:21 -07:00

476 lines
11 KiB
Go

package recover
import (
"bytes"
"context"
"crypto/sha512"
"encoding/base64"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/volatiletech/authboss"
"github.com/volatiletech/authboss/internal/mocks"
)
// Fixture credentials shared by the EndPost tests: testToken is submitted as
// the recovery token in the request body, while testSelector and testVerifier
// are stored on the mock user as its RecoverSelector and RecoverVerifier.
// NOTE(review): presumably the selector and verifier are the base64-encoded
// SHA-512 digests of the two halves of testToken, mirroring what
// TestGenerateRecoverCreds checks for generated credentials — confirm before
// regenerating any one of them on its own.
const (
testSelector = `rnaGE8TDilrINHPxq/2xNU1FUTzsUSX8FvN5YzooyyWKk88fw1DjjbKBRGFtGew9OeZ+xeCC4mslfvQQMYspIg==`
testVerifier = `W1Mz30QhavVM4d8jKaFtxGBfb4GX+fOn7V0Pc1WeftgtyOtY5OX7sY9gIeY5CIY4n8LvfWy14W7/6rs2KO9pgA==`
testToken = `w5OZ51E61Q6wsJOVr9o7KmyepP7Od5VBHQ1ADDUBkiGGMjKfnMFPjtvNpLjLKJqffw72KWZzNLj0Cs8wqywdEQ==`
)
func TestInit(t *testing.T) {
t.Parallel()
ab := authboss.New()
router := &mocks.Router{}
renderer := &mocks.Renderer{}
mailRenderer := &mocks.Renderer{}
errHandler := &mocks.ErrorHandler{}
ab.Config.Core.Router = router
ab.Config.Core.ViewRenderer = renderer
ab.Config.Core.MailRenderer = mailRenderer
ab.Config.Core.ErrorHandler = errHandler
r := &Recover{}
if err := r.Init(ab); err != nil {
t.Fatal(err)
}
if err := renderer.HasLoadedViews(PageRecoverStart, PageRecoverEnd); err != nil {
t.Error(err)
}
if err := mailRenderer.HasLoadedViews(EmailRecoverHTML, EmailRecoverTxt); err != nil {
t.Error(err)
}
if err := router.HasGets("/recover", "/recover/end"); err != nil {
t.Error(err)
}
if err := router.HasPosts("/recover", "/recover/end"); err != nil {
t.Error(err)
}
}
type testHarness struct {
recover *Recover
ab *authboss.Authboss
bodyReader *mocks.BodyReader
mailer *mocks.Emailer
redirector *mocks.Redirector
renderer *mocks.Renderer
responder *mocks.Responder
session *mocks.ClientStateRW
storer *mocks.ServerStorer
}
// testSetup builds a fresh harness: a new Authboss instance whose body reader,
// logger, mailer, redirector, mail renderer, responder, session state and
// server storage are all mocks, plus a Recover module wrapping that instance.
// The RecoverOK path is set to "/recover/ok".
func testSetup() *testHarness {
	ab := authboss.New()

	h := &testHarness{
		ab:         ab,
		bodyReader: &mocks.BodyReader{},
		mailer:     &mocks.Emailer{},
		redirector: &mocks.Redirector{},
		renderer:   &mocks.Renderer{},
		responder:  &mocks.Responder{},
		session:    mocks.NewClientRW(),
		storer:     mocks.NewServerStorer(),
		recover:    &Recover{ab},
	}

	ab.Paths.RecoverOK = "/recover/ok"

	core := &ab.Config.Core
	core.BodyReader = h.bodyReader
	core.Logger = mocks.Logger{}
	core.Mailer = h.mailer
	core.Redirector = h.redirector
	core.MailRenderer = h.renderer
	core.Responder = h.responder

	ab.Config.Storage.SessionState = h.session
	ab.Config.Storage.Server = h.storer

	return h
}
// TestStartGet checks that StartGet responds with the recover start page, an
// OK status and no template data.
func TestStartGet(t *testing.T) {
	t.Parallel()

	h := testSetup()
	req := mocks.Request("GET")
	rec := httptest.NewRecorder()

	if err := h.recover.StartGet(rec, req); err != nil {
		t.Error(err)
	}

	if got := rec.Code; got != http.StatusOK {
		t.Error("code was wrong:", got)
	}
	if page := h.responder.Page; page != PageRecoverStart {
		t.Error("page was wrong:", page)
	}
	if data := h.responder.Data; data != nil {
		t.Error("expected no data:", data)
	}
}
// TestStartPostSuccess verifies that StartPost for a known user redirects to
// the RecoverOK path with a success message and sends a password reset e-mail
// whose template data contains a recovery URL.
func TestStartPostSuccess(t *testing.T) {
	// no t.Parallel(), global var mangling

	// Swap the package-level e-mail hook for a direct, synchronous call so the
	// mailer mock is populated before the assertions below run, and restore
	// the original hook when the test finishes.
	oldRecoverEmail := goRecoverEmail
	goRecoverEmail = func(r *Recover, ctx context.Context, to, token string) {
		r.SendRecoverEmail(ctx, to, token)
	}
	defer func() {
		goRecoverEmail = oldRecoverEmail
	}()

	h := testSetup()

	h.bodyReader.Return = &mocks.Values{
		PID: "test@test.com",
	}
	h.storer.Users["test@test.com"] = &mocks.User{
		Email:    "test@test.com",
		Password: "i can't recall, doesn't seem like something bcrypted though",
	}

	// StartPost is the POST handler; use the matching method like the other
	// POST handler tests in this file.
	r := mocks.Request("POST")
	w := httptest.NewRecorder()

	if err := h.recover.StartPost(w, r); err != nil {
		t.Error(err)
	}

	if w.Code != http.StatusTemporaryRedirect {
		t.Error("code was wrong:", w.Code)
	}
	if p := h.redirector.Options.RedirectPath; p != h.ab.Config.Paths.RecoverOK {
		t.Error("path was wrong:", p)
	}
	if len(h.redirector.Options.Success) == 0 {
		t.Error("expected a nice success message")
	}

	// Guard the index so a missing e-mail fails the test instead of panicking.
	if to := h.mailer.Email.To; len(to) == 0 {
		t.Error("expected a recovery e-mail to be sent")
	} else if to[0] != "test@test.com" {
		t.Error("e-mail to address is wrong:", to)
	}
	if !strings.HasSuffix(h.mailer.Email.Subject, "Password Reset") {
		t.Error("e-mail subject line is wrong:", h.mailer.Email.Subject)
	}

	// Checked type assertion: a missing or non-string URL is a test failure,
	// not a panic.
	if recoverURL, ok := h.renderer.Data[DataRecoverURL].(string); !ok || len(recoverURL) == 0 {
		t.Errorf("the renderer's url in data was missing: %#v", h.renderer.Data)
	}
}
// TestStartPostFailure verifies that StartPost for a PID with no matching user
// in the storer still redirects to the RecoverOK path with a success message,
// but does not send any e-mail.
func TestStartPostFailure(t *testing.T) {
	t.Parallel()

	h := testSetup()

	// No user is added to the storer, so the lookup for this PID finds nothing.
	h.bodyReader.Return = &mocks.Values{
		PID: "test@test.com",
	}

	// StartPost is the POST handler; use the matching method like the other
	// POST handler tests in this file.
	r := mocks.Request("POST")
	w := httptest.NewRecorder()

	if err := h.recover.StartPost(w, r); err != nil {
		t.Error(err)
	}

	if w.Code != http.StatusTemporaryRedirect {
		t.Error("code was wrong:", w.Code)
	}
	if p := h.redirector.Options.RedirectPath; p != h.ab.Config.Paths.RecoverOK {
		t.Error("path was wrong:", p)
	}
	if len(h.redirector.Options.Success) == 0 {
		t.Error("expected a nice success message")
	}
	if len(h.mailer.Email.To) != 0 {
		t.Error("should not have sent an e-mail out!")
	}
}
// TestEndGet verifies that EndGet renders the recover end page with an OK
// status and passes the token from the request through to the template data.
func TestEndGet(t *testing.T) {
	t.Parallel()

	h := testSetup()
	h.bodyReader.Return = &mocks.Values{
		Token: "abcd",
	}

	r := mocks.Request("GET")
	w := httptest.NewRecorder()

	if err := h.recover.EndGet(w, r); err != nil {
		t.Error(err)
	}

	if w.Code != http.StatusOK {
		t.Error("code was wrong:", w.Code)
	}
	if h.responder.Page != PageRecoverEnd {
		t.Error("page was wrong:", h.responder.Page)
	}

	// Checked type assertion: a missing or non-string token is reported as a
	// test failure instead of panicking.
	if token, ok := h.responder.Data[DataRecoverToken].(string); !ok || token != "abcd" {
		t.Errorf("recovery token is wrong: %#v", h.responder.Data)
	}
}
// TestEndPostSuccess verifies that EndPost with a matching, unexpired token
// redirects to the RecoverOK path, reports the password as updated and, with
// login-after-recovery left at its default, does not put the user into the
// session.
func TestEndPostSuccess(t *testing.T) {
	t.Parallel()

	h := testSetup()

	// The stored credentials expire one day from now, so the token is valid.
	h.storer.Users["test@test.com"] = &mocks.User{
		Email:              "test@test.com",
		Password:           "to-overwrite",
		RecoverSelector:    testSelector,
		RecoverVerifier:    testVerifier,
		RecoverTokenExpiry: time.Now().UTC().AddDate(0, 0, 1),
	}
	h.bodyReader.Return = &mocks.Values{Token: testToken}

	req := mocks.Request("POST")
	rec := httptest.NewRecorder()
	if err := h.recover.EndPost(rec, req); err != nil {
		t.Error(err)
	}

	if got := rec.Code; got != http.StatusTemporaryRedirect {
		t.Error("code was wrong:", got)
	}
	if got := h.redirector.Options.RedirectPath; got != h.ab.Paths.RecoverOK {
		t.Error("path was wrong:", got)
	}
	if sessionPID := h.session.ClientValues[authboss.SessionKey]; len(sessionPID) != 0 {
		t.Error("should not have logged in the user")
	}

	success := h.redirector.Options.Success
	if !strings.Contains(success, "updated password") {
		t.Error("should talk about recovering the password")
	}
	if strings.Contains(success, "logged in") {
		t.Error("should not talk about logging in")
	}
}
// TestEndPostSuccessLogin verifies that, with RecoverLoginAfterRecovery
// enabled, a successful EndPost redirects to the RecoverOK path, stores a
// session value under authboss.SessionKey and mentions logging in.
func TestEndPostSuccessLogin(t *testing.T) {
	t.Parallel()

	h := testSetup()
	h.ab.Config.Modules.RecoverLoginAfterRecovery = true

	// The stored credentials expire one day from now, so the token is valid.
	h.storer.Users["test@test.com"] = &mocks.User{
		Email:              "test@test.com",
		Password:           "to-overwrite",
		RecoverSelector:    testSelector,
		RecoverVerifier:    testVerifier,
		RecoverTokenExpiry: time.Now().UTC().AddDate(0, 0, 1),
	}
	h.bodyReader.Return = &mocks.Values{Token: testToken}

	req := mocks.Request("POST")
	rec := httptest.NewRecorder()

	// The recorder is wrapped with NewResponse here, unlike the non-login test.
	if err := h.recover.EndPost(h.ab.NewResponse(rec), req); err != nil {
		t.Error(err)
	}

	if got := rec.Code; got != http.StatusTemporaryRedirect {
		t.Error("code was wrong:", got)
	}
	if got := h.redirector.Options.RedirectPath; got != h.ab.Paths.RecoverOK {
		t.Error("path was wrong:", got)
	}
	if sessionPID := h.session.ClientValues[authboss.SessionKey]; len(sessionPID) == 0 {
		t.Error("it should have logged in the user")
	}
	if success := h.redirector.Options.Success; !strings.Contains(success, "logged in") {
		t.Error("should talk about logging in")
	}
}
// TestEndPostValidationFailure verifies that when the body reader reports a
// validation error, EndPost re-renders the recover end page with an OK status,
// exposes the error message in the validation data and does not log the user
// in.
func TestEndPostValidationFailure(t *testing.T) {
	t.Parallel()

	h := testSetup()

	h.bodyReader.Return = &mocks.Values{
		Errors: []error{errors.New("password is not sufficiently complex")},
	}
	h.storer.Users["test@test.com"] = &mocks.User{
		Email:              "test@test.com",
		Password:           "to-overwrite",
		RecoverSelector:    testSelector,
		RecoverVerifier:    testVerifier,
		RecoverTokenExpiry: time.Now().UTC().AddDate(0, 0, 1),
	}

	r := mocks.Request("POST")
	w := httptest.NewRecorder()

	if err := h.recover.EndPost(w, r); err != nil {
		t.Error(err)
	}

	if w.Code != http.StatusOK {
		t.Error("code was wrong:", w.Code)
	}
	if h.responder.Page != PageRecoverEnd {
		t.Error("rendered the wrong page")
	}

	// The messages are read from the "" key of the validation map. Check that
	// an entry exists before indexing so a missing message fails the test
	// instead of panicking.
	m, ok := h.responder.Data[authboss.DataValidation].(map[string][]string)
	switch {
	case !ok:
		t.Error("expected validation errors")
	case len(m[""]) == 0:
		t.Error("expected a validation error message:", m)
	case m[""][0] != "password is not sufficiently complex":
		t.Error("error message data was not correct:", m[""])
	}

	if len(h.session.ClientValues[authboss.SessionKey]) != 0 {
		t.Error("should not have logged in the user")
	}
}
// TestEndPostInvalidBase64 verifies that EndPost answers a token that is not
// valid base64 with the generic invalid-token response checked by
// invalidCheck.
func TestEndPostInvalidBase64(t *testing.T) {
	t.Parallel()

	h := testSetup()

	// A single character cannot be decoded as base64.
	h.bodyReader.Return = &mocks.Values{
		Token: "a",
	}

	// EndPost is the POST handler; use the matching method like the other
	// EndPost tests in this file.
	r := mocks.Request("POST")
	w := httptest.NewRecorder()

	if err := h.recover.EndPost(w, r); err != nil {
		t.Error(err)
	}

	invalidCheck(t, h, w)
}
// TestEndPostExpiredToken verifies that EndPost answers an otherwise matching
// token whose expiry lies in the past with the generic invalid-token response
// checked by invalidCheck.
func TestEndPostExpiredToken(t *testing.T) {
	t.Parallel()

	h := testSetup()

	h.bodyReader.Return = &mocks.Values{
		Token: testToken,
	}
	h.storer.Users["test@test.com"] = &mocks.User{
		Email:           "test@test.com",
		Password:        "to-overwrite",
		RecoverSelector: testSelector,
		RecoverVerifier: testVerifier,
		// Expired one day ago.
		RecoverTokenExpiry: time.Now().UTC().AddDate(0, 0, -1),
	}

	// EndPost is the POST handler; use the matching method like the other
	// EndPost tests in this file.
	r := mocks.Request("POST")
	w := httptest.NewRecorder()

	if err := h.recover.EndPost(w, r); err != nil {
		t.Error(err)
	}

	invalidCheck(t, h, w)
}
// TestEndPostUserNotExist verifies that EndPost answers a well-formed token
// with the generic invalid-token response checked by invalidCheck when the
// storer holds no user at all.
func TestEndPostUserNotExist(t *testing.T) {
	t.Parallel()

	h := testSetup()

	// No user is added to the storer for this token.
	h.bodyReader.Return = &mocks.Values{
		Token: testToken,
	}

	// EndPost is the POST handler; use the matching method like the other
	// EndPost tests in this file.
	r := mocks.Request("POST")
	w := httptest.NewRecorder()

	if err := h.recover.EndPost(w, r); err != nil {
		t.Error(err)
	}

	invalidCheck(t, h, w)
}
// invalidCheck asserts the response shared by all rejected-token cases: an OK
// status, the recover end page, and "recovery token is invalid" as the first
// message under the "" key of the validation data. Failures are attributed to
// the calling test via t.Helper.
func invalidCheck(t *testing.T, h *testHarness, w *httptest.ResponseRecorder) {
	t.Helper()

	if w.Code != http.StatusOK {
		t.Error("code was wrong:", w.Code)
	}
	if h.responder.Page != PageRecoverEnd {
		t.Error("page was wrong:", h.responder.Page)
	}

	// Use a checked type assertion and a length check so that absent
	// validation data is reported as a test failure instead of a panic.
	validation, ok := h.responder.Data[authboss.DataValidation].(map[string][]string)
	if !ok || len(validation[""]) == 0 {
		t.Errorf("expected validation errors in the response data: %#v", h.responder.Data)
		return
	}
	if validation[""][0] != "recovery token is invalid" {
		t.Error("expected a vague error to mislead")
	}
}
// TestGenerateRecoverCreds verifies the shape and relationship of the values
// returned by GenerateRecoverCreds: all three are 88-character base64 strings,
// the selector is the SHA-512 digest of the token bytes before
// recoverTokenSplit, and the verifier is the SHA-512 digest of the token bytes
// from recoverTokenSplit onwards.
func TestGenerateRecoverCreds(t *testing.T) {
	t.Parallel()

	selector, verifier, token, err := GenerateRecoverCreds()
	if err != nil {
		// Nothing below is meaningful without generated credentials.
		t.Fatal(err)
	}

	if verifier == selector {
		t.Error("the verifier and selector should be different")
	}

	// Padded base64 encodes n bytes as 4*ceil(n/3) characters; for n = 64
	// that is 4*22 = 88.
	if len(verifier) != 88 {
		t.Errorf("verifier length was wrong (%d): %s", len(verifier), verifier)
	}
	if len(selector) != 88 {
		t.Errorf("selector length was wrong (%d): %s", len(selector), selector)
	}
	if len(token) != 88 {
		t.Errorf("token length was wrong (%d): %s", len(token), token)
	}

	// The token is decoded with the URL-safe alphabet, the selector and
	// verifier with the standard alphabet. Decode failures are fatal because
	// the hashing below would otherwise slice invalid data.
	rawToken, err := base64.URLEncoding.DecodeString(token)
	if err != nil {
		t.Fatal(err)
	}
	rawSelector, err := base64.StdEncoding.DecodeString(selector)
	if err != nil {
		t.Fatal(err)
	}
	rawVerifier, err := base64.StdEncoding.DecodeString(verifier)
	if err != nil {
		t.Fatal(err)
	}

	checkSelector := sha512.Sum512(rawToken[:recoverTokenSplit])
	if !bytes.Equal(checkSelector[:], rawSelector) {
		t.Error("expected selector to match")
	}
	checkVerifier := sha512.Sum512(rawToken[recoverTokenSplit:])
	if !bytes.Equal(checkVerifier[:], rawVerifier) {
		t.Error("expected verifier to match")
	}
}