mirror of
https://github.com/volatiletech/authboss.git
synced 2024-11-24 08:42:17 +02:00
Finish register module
This commit is contained in:
parent
1068509540
commit
948aa8a115
@ -2,28 +2,21 @@
|
||||
package register
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"sort"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/volatiletech/authboss"
|
||||
"github.com/volatiletech/authboss/internal/response"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// Pages
|
||||
const (
|
||||
tplRegister = "register.html.tpl"
|
||||
PageRegister = "register"
|
||||
)
|
||||
|
||||
// RegisterStorer must be implemented in order to satisfy the register module's
|
||||
// storage requirments.
|
||||
type RegisterStorer interface {
|
||||
authboss.Storer
|
||||
// Create is the same as put, except it refers to a non-existent key. If the key is
|
||||
// found simply return authboss.ErrUserFound
|
||||
Create(key string, attr authboss.Attributes) error
|
||||
}
|
||||
|
||||
func init() {
|
||||
authboss.RegisterModule("register", &Register{})
|
||||
}
|
||||
@ -31,127 +24,129 @@ func init() {
|
||||
// Register module.
|
||||
type Register struct {
|
||||
*authboss.Authboss
|
||||
templates response.Templates
|
||||
}
|
||||
|
||||
// Initialize the module.
|
||||
func (r *Register) Initialize(ab *authboss.Authboss) (err error) {
|
||||
// Init the module.
|
||||
func (r *Register) Init(ab *authboss.Authboss) (err error) {
|
||||
r.Authboss = ab
|
||||
|
||||
if r.Storer != nil {
|
||||
if _, ok := r.Storer.(RegisterStorer); !ok {
|
||||
return errors.New("registerStorer required for register functionality")
|
||||
}
|
||||
} else if r.StoreMaker == nil {
|
||||
return errors.New("need a registerStorer")
|
||||
if _, ok := ab.Config.Storage.Server.(authboss.CreatingServerStorer); !ok {
|
||||
return errors.New("register module activated but storer could not be upgraded to CreatingServerStorer")
|
||||
}
|
||||
|
||||
if r.templates, err = response.LoadTemplates(r.Authboss, r.Layout, r.ViewsPath, tplRegister); err != nil {
|
||||
if err := ab.Config.Core.ViewRenderer.Load(PageRegister); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sort.Strings(ab.Config.Modules.RegisterPreserveFields)
|
||||
|
||||
ab.Config.Core.Router.Get("/register", ab.Config.Core.ErrorHandler.Wrap(r.Get))
|
||||
ab.Config.Core.Router.Post("/register", ab.Config.Core.ErrorHandler.Wrap(r.Post))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Routes creates the routing table.
|
||||
func (r *Register) Routes() authboss.RouteTable {
|
||||
return authboss.RouteTable{
|
||||
"/register": r.registerHandler,
|
||||
}
|
||||
// Get the register page
|
||||
func (r *Register) Get(w http.ResponseWriter, req *http.Request) error {
|
||||
return r.Config.Core.Responder.Respond(w, req, http.StatusOK, PageRegister, nil)
|
||||
}
|
||||
|
||||
// Storage returns storage requirements.
|
||||
func (r *Register) Storage() authboss.StorageOptions {
|
||||
return authboss.StorageOptions{
|
||||
r.PrimaryID: authboss.String,
|
||||
authboss.StorePassword: authboss.String,
|
||||
}
|
||||
}
|
||||
|
||||
func (reg *Register) registerHandler(ctx *authboss.Context, w http.ResponseWriter, r *http.Request) error {
|
||||
switch r.Method {
|
||||
case "GET":
|
||||
primaryID := r.FormValue("primaryID")
|
||||
|
||||
data := authboss.HTMLData{
|
||||
"primaryID": reg.PrimaryID,
|
||||
"primaryIDValue": primaryID,
|
||||
"primaryIDReadonly": len(primaryID) > 0,
|
||||
}
|
||||
return reg.templates.Render(ctx, w, r, tplRegister, data)
|
||||
case "POST":
|
||||
return reg.registerPostHandler(ctx, w, r)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (reg *Register) registerPostHandler(ctx *authboss.Context, w http.ResponseWriter, r *http.Request) error {
|
||||
key := r.FormValue(reg.PrimaryID)
|
||||
password := r.FormValue(authboss.StorePassword)
|
||||
|
||||
validationErrs := authboss.Validate(r, reg.Policies, reg.ConfirmFields...)
|
||||
|
||||
if user, err := ctx.Storer.Get(key); err != nil && err != authboss.ErrUserNotFound {
|
||||
return err
|
||||
} else if user != nil {
|
||||
validationErrs = append(validationErrs, authboss.FieldError{Name: reg.PrimaryID, Err: errors.New("Already in use")})
|
||||
}
|
||||
|
||||
if len(validationErrs) != 0 {
|
||||
data := authboss.HTMLData{
|
||||
"primaryID": reg.PrimaryID,
|
||||
"primaryIDValue": key,
|
||||
"errs": validationErrs.Map(),
|
||||
}
|
||||
|
||||
for _, f := range reg.PreserveFields {
|
||||
data[f] = r.FormValue(f)
|
||||
}
|
||||
|
||||
return reg.templates.Render(ctx, w, r, tplRegister, data)
|
||||
}
|
||||
|
||||
attr, err := authboss.AttributesFromRequest(r) // Attributes from overriden forms
|
||||
// Post to the register page
|
||||
func (r *Register) Post(w http.ResponseWriter, req *http.Request) error {
|
||||
logger := r.RequestLogger(req)
|
||||
validatable, err := r.Core.BodyReader.Read(PageRegister, req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pass, err := bcrypt.GenerateFromPassword([]byte(password), reg.BCryptCost)
|
||||
var arbitrary map[string]string
|
||||
var preserve map[string]string
|
||||
if arb, ok := validatable.(authboss.ArbitraryValuer); ok {
|
||||
arbitrary = arb.GetValues()
|
||||
preserve = make(map[string]string)
|
||||
|
||||
for k, v := range arbitrary {
|
||||
if hasString(r.Config.Modules.RegisterPreserveFields, k) {
|
||||
preserve[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
errs := validatable.Validate()
|
||||
if errs != nil {
|
||||
logger.Info("registration validation failed")
|
||||
data := authboss.HTMLData{
|
||||
authboss.DataValidation: authboss.ErrorList(errs),
|
||||
}
|
||||
if preserve != nil {
|
||||
data[authboss.DataPreserve] = preserve
|
||||
}
|
||||
return r.Config.Core.Responder.Respond(w, req, http.StatusOK, PageRegister, data)
|
||||
}
|
||||
|
||||
// Get values from request
|
||||
userVals := authboss.MustHaveUserValues(validatable)
|
||||
pid, password := userVals.GetPID(), userVals.GetPassword()
|
||||
|
||||
// Put values into newly created user for storage
|
||||
storer := authboss.EnsureCanCreate(r.Config.Storage.Server)
|
||||
user := authboss.MustBeAuthable(storer.New(req.Context()))
|
||||
|
||||
pass, err := bcrypt.GenerateFromPassword([]byte(password), r.Config.Modules.RegisterBCryptCost)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
attr[reg.PrimaryID] = key
|
||||
attr[authboss.StorePassword] = string(pass)
|
||||
ctx.User = attr
|
||||
user.PutPID(pid)
|
||||
user.PutPassword(string(pass))
|
||||
|
||||
if err := ctx.Storer.(RegisterStorer).Create(key, attr); err == authboss.ErrUserFound {
|
||||
if arbUser, ok := user.(authboss.ArbitraryUser); ok && arbitrary != nil {
|
||||
arbUser.PutArbitrary(arbitrary)
|
||||
}
|
||||
|
||||
err = storer.Create(req.Context(), user)
|
||||
switch {
|
||||
case err == authboss.ErrUserFound:
|
||||
logger.Infof("user %s attempted to re-register", pid)
|
||||
errs = []error{errors.New("user already exists")}
|
||||
data := authboss.HTMLData{
|
||||
"primaryID": reg.PrimaryID,
|
||||
"primaryIDValue": key,
|
||||
"errs": map[string][]string{reg.PrimaryID: []string{"Already in use"}},
|
||||
authboss.DataValidation: authboss.ErrorList(errs),
|
||||
}
|
||||
|
||||
for _, f := range reg.PreserveFields {
|
||||
data[f] = r.FormValue(f)
|
||||
if preserve != nil {
|
||||
data[authboss.DataPreserve] = preserve
|
||||
}
|
||||
|
||||
return reg.templates.Render(ctx, w, r, tplRegister, data)
|
||||
} else if err != nil {
|
||||
return r.Config.Core.Responder.Respond(w, req, http.StatusOK, PageRegister, data)
|
||||
case err != nil:
|
||||
return err
|
||||
}
|
||||
|
||||
if err := reg.Events.FireAfter(authboss.EventRegister, ctx); err != nil {
|
||||
req = req.WithContext(context.WithValue(req.Context(), authboss.CTXKeyUser, user))
|
||||
handled, err := r.Events.FireAfter(authboss.EventRegister, w, req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if reg.IsLoaded("confirm") {
|
||||
response.Redirect(ctx, w, r, reg.RegisterOKPath, "Account successfully created, please verify your e-mail address.", "", true)
|
||||
} else if handled {
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx.SessionStorer.Put(authboss.SessionKey, key)
|
||||
response.Redirect(ctx, w, r, reg.RegisterOKPath, "Account successfully created, you are now logged in.", "", true)
|
||||
// Log the user in, but only if the response wasn't handled previously by a module
|
||||
// like confirm.
|
||||
authboss.PutSession(w, authboss.SessionKey, pid)
|
||||
|
||||
return nil
|
||||
logger.Infof("registered and logged in user %s", pid)
|
||||
ro := authboss.RedirectOptions{
|
||||
Code: http.StatusTemporaryRedirect,
|
||||
Success: "Account successfully created, you are now logged in",
|
||||
RedirectPath: r.Config.Paths.RegisterOK,
|
||||
}
|
||||
return r.Config.Core.Redirector.Redirect(w, req, ro)
|
||||
}
|
||||
|
||||
// hasString checks to see if a sorted (ascending) array of strings contains a string
|
||||
func hasString(arr []string, s string) bool {
|
||||
index := sort.SearchStrings(arr, s)
|
||||
if index < 0 || index >= len(arr) {
|
||||
return false
|
||||
}
|
||||
|
||||
return arr[index] == s
|
||||
}
|
||||
|
@ -1,18 +1,315 @@
|
||||
package register
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"html/template"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/volatiletech/authboss"
|
||||
"github.com/volatiletech/authboss/internal/mocks"
|
||||
)
|
||||
|
||||
func TestRegisterInit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := authboss.New()
|
||||
|
||||
router := &mocks.Router{}
|
||||
renderer := &mocks.Renderer{}
|
||||
errHandler := &mocks.ErrorHandler{}
|
||||
ab.Config.Core.Router = router
|
||||
ab.Config.Core.ViewRenderer = renderer
|
||||
ab.Config.Core.ErrorHandler = errHandler
|
||||
ab.Config.Storage.Server = &mocks.ServerStorer{}
|
||||
|
||||
reg := &Register{}
|
||||
if err := reg.Init(ab); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := renderer.HasLoadedViews(PageRegister); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if err := router.HasGets("/register"); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := router.HasPosts("/register"); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterGet(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := authboss.New()
|
||||
responder := &mocks.Responder{}
|
||||
ab.Config.Core.Responder = responder
|
||||
|
||||
a := &Register{ab}
|
||||
a.Get(nil, nil)
|
||||
|
||||
if responder.Page != PageRegister {
|
||||
t.Error("wanted login page, got:", responder.Page)
|
||||
}
|
||||
|
||||
if responder.Status != http.StatusOK {
|
||||
t.Error("wanted ok status, got:", responder.Status)
|
||||
}
|
||||
}
|
||||
|
||||
type testHarness struct {
|
||||
reg *Register
|
||||
ab *authboss.Authboss
|
||||
|
||||
bodyReader *mocks.BodyReader
|
||||
responder *mocks.Responder
|
||||
redirector *mocks.Redirector
|
||||
session *mocks.ClientStateRW
|
||||
storer *mocks.ServerStorer
|
||||
}
|
||||
|
||||
func testSetup() *testHarness {
|
||||
harness := &testHarness{}
|
||||
|
||||
harness.ab = authboss.New()
|
||||
harness.bodyReader = &mocks.BodyReader{}
|
||||
harness.redirector = &mocks.Redirector{}
|
||||
harness.responder = &mocks.Responder{}
|
||||
harness.session = mocks.NewClientRW()
|
||||
harness.storer = mocks.NewServerStorer()
|
||||
|
||||
harness.ab.Paths.RegisterOK = "/ok"
|
||||
|
||||
harness.ab.Config.Core.BodyReader = harness.bodyReader
|
||||
harness.ab.Config.Core.Logger = mocks.Logger{}
|
||||
harness.ab.Config.Core.Responder = harness.responder
|
||||
harness.ab.Config.Core.Redirector = harness.redirector
|
||||
harness.ab.Config.Storage.SessionState = harness.session
|
||||
harness.ab.Config.Storage.Server = harness.storer
|
||||
|
||||
harness.reg = &Register{harness.ab}
|
||||
|
||||
return harness
|
||||
}
|
||||
|
||||
func TestRegisterPostSuccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
setupMore := func(harness *testHarness) *testHarness {
|
||||
harness.ab.Modules.RegisterPreserveFields = []string{"email", "another"}
|
||||
harness.bodyReader.Return = mocks.ArbValues{
|
||||
Values: map[string]string{
|
||||
"email": "test@test.com",
|
||||
"password": "hello world",
|
||||
"another": "value",
|
||||
},
|
||||
}
|
||||
|
||||
return harness
|
||||
}
|
||||
|
||||
t.Run("normal", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
h := setupMore(testSetup())
|
||||
|
||||
r := mocks.Request("POST")
|
||||
resp := httptest.NewRecorder()
|
||||
w := h.ab.NewResponse(resp, r)
|
||||
|
||||
if err := h.reg.Post(w, r); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
user, ok := h.storer.Users["test@test.com"]
|
||||
if !ok {
|
||||
t.Error("user was not persisted in the DB")
|
||||
}
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte("hello world")); err != nil {
|
||||
t.Error("password was not properly encrypted:", err)
|
||||
}
|
||||
|
||||
if user.Arbitrary["another"] != "value" {
|
||||
t.Error("arbitrary values not saved")
|
||||
}
|
||||
|
||||
if h.session.ClientValues[authboss.SessionKey] != "test@test.com" {
|
||||
t.Error("user should have been logged in:", h.session.ClientValues)
|
||||
}
|
||||
|
||||
if resp.Code != http.StatusTemporaryRedirect {
|
||||
t.Error("code was wrong:", resp.Code)
|
||||
}
|
||||
if h.redirector.Options.RedirectPath != "/ok" {
|
||||
t.Error("redirect path was wrong:", h.redirector.Options.RedirectPath)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handledAfter", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
h := setupMore(testSetup())
|
||||
|
||||
var afterCalled bool
|
||||
h.ab.Events.After(authboss.EventRegister, func(w http.ResponseWriter, r *http.Request, handled bool) (bool, error) {
|
||||
w.WriteHeader(http.StatusTeapot)
|
||||
afterCalled = true
|
||||
return true, nil
|
||||
})
|
||||
|
||||
r := mocks.Request("POST")
|
||||
resp := httptest.NewRecorder()
|
||||
w := h.ab.NewResponse(resp, r)
|
||||
|
||||
if err := h.reg.Post(w, r); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
user, ok := h.storer.Users["test@test.com"]
|
||||
if !ok {
|
||||
t.Error("user was not persisted in the DB")
|
||||
}
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte("hello world")); err != nil {
|
||||
t.Error("password was not properly encrypted:", err)
|
||||
}
|
||||
|
||||
if val, ok := h.session.ClientValues[authboss.SessionKey]; ok {
|
||||
t.Error("user should not have been logged in:", val)
|
||||
}
|
||||
|
||||
if resp.Code != http.StatusTeapot {
|
||||
t.Error("code was wrong:", resp.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRegisterPostValidationFailure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
h := testSetup()
|
||||
|
||||
// Ensure the below is sorted, the sort normally happens in Init() that we don't call
|
||||
h.ab.Modules.RegisterPreserveFields = []string{"another", "email"}
|
||||
h.bodyReader.Return = mocks.ArbValues{
|
||||
Values: map[string]string{
|
||||
"email": "test@test.com",
|
||||
"password": "hello world",
|
||||
"another": "value",
|
||||
},
|
||||
Errors: []error{
|
||||
errors.New("bad password"),
|
||||
},
|
||||
}
|
||||
|
||||
r := mocks.Request("POST")
|
||||
resp := httptest.NewRecorder()
|
||||
w := h.ab.NewResponse(resp, r)
|
||||
|
||||
if err := h.reg.Post(w, r); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if h.responder.Status != http.StatusOK {
|
||||
t.Error("wrong status:", h.responder.Status)
|
||||
}
|
||||
if h.responder.Page != PageRegister {
|
||||
t.Error("rendered wrong page:", h.responder.Page)
|
||||
}
|
||||
|
||||
errList := h.responder.Data[authboss.DataValidation].(authboss.ErrorList)
|
||||
if e := errList[0].Error(); e != "bad password" {
|
||||
t.Error("validation error wrong:", e)
|
||||
}
|
||||
|
||||
intfD, ok := h.responder.Data[authboss.DataPreserve]
|
||||
if !ok {
|
||||
t.Fatal("there was no preserved data")
|
||||
}
|
||||
|
||||
d := intfD.(map[string]string)
|
||||
if d["email"] != "test@test.com" {
|
||||
t.Error("e-mail was not preserved:", d)
|
||||
} else if d["another"] != "value" {
|
||||
t.Error("another value was not preserved", d)
|
||||
} else if _, ok = d["password"]; ok {
|
||||
t.Error("password was preserved", d)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterPostUserExists(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
h := testSetup()
|
||||
|
||||
// Ensure the below is sorted, the sort normally happens in Init() that we don't call
|
||||
h.ab.Modules.RegisterPreserveFields = []string{"another", "email"}
|
||||
h.storer.Users["test@test.com"] = &mocks.User{}
|
||||
h.bodyReader.Return = mocks.ArbValues{
|
||||
Values: map[string]string{
|
||||
"email": "test@test.com",
|
||||
"password": "hello world",
|
||||
"another": "value",
|
||||
},
|
||||
}
|
||||
|
||||
r := mocks.Request("POST")
|
||||
resp := httptest.NewRecorder()
|
||||
w := h.ab.NewResponse(resp, r)
|
||||
|
||||
if err := h.reg.Post(w, r); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if h.responder.Status != http.StatusOK {
|
||||
t.Error("wrong status:", h.responder.Status)
|
||||
}
|
||||
if h.responder.Page != PageRegister {
|
||||
t.Error("rendered wrong page:", h.responder.Page)
|
||||
}
|
||||
|
||||
errList := h.responder.Data[authboss.DataValidation].(authboss.ErrorList)
|
||||
if e := errList[0].Error(); e != "user already exists" {
|
||||
t.Error("validation error wrong:", e)
|
||||
}
|
||||
|
||||
intfD, ok := h.responder.Data[authboss.DataPreserve]
|
||||
if !ok {
|
||||
t.Fatal("there was no preserved data")
|
||||
}
|
||||
|
||||
d := intfD.(map[string]string)
|
||||
if d["email"] != "test@test.com" {
|
||||
t.Error("e-mail was not preserved:", d)
|
||||
} else if d["another"] != "value" {
|
||||
t.Error("another value was not preserved", d)
|
||||
} else if _, ok = d["password"]; ok {
|
||||
t.Error("password was preserved", d)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasString(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
strs := []string{"b", "c", "d", "e"}
|
||||
|
||||
if !hasString(strs, "b") {
|
||||
t.Error("should have a")
|
||||
}
|
||||
if !hasString(strs, "e") {
|
||||
t.Error("should have d")
|
||||
}
|
||||
|
||||
if hasString(strs, "a") {
|
||||
t.Error("should not have a")
|
||||
}
|
||||
if hasString(strs, "f") {
|
||||
t.Error("should not have f")
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
func setup() *Register {
|
||||
ab := authboss.New()
|
||||
ab.RegisterOKPath = "/regsuccess"
|
||||
@ -164,3 +461,4 @@ func TestRegisterPostSuccess(t *testing.T) {
|
||||
t.Error("Password was not hashed.")
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
Loading…
Reference in New Issue
Block a user