mirror of
https://github.com/volatiletech/authboss.git
synced 2025-02-05 13:24:54 +02:00
More gigantic edits.
- Change response to be more central to Authboss. Make sure it has useful methods and works with the new rendering idioms. - Change the load user methods to all work with context keys, and even be able to set context keys on the current request to avoid setting contexts everywhere in the code base.
This commit is contained in:
parent
f65d9f6bb6
commit
fa6ba517db
117
authboss.go
117
authboss.go
@ -20,6 +20,9 @@ type Authboss struct {
|
||||
|
||||
loadedModules map[string]Modularizer
|
||||
mux *http.ServeMux
|
||||
|
||||
templateNames []string
|
||||
renderer Renderer
|
||||
}
|
||||
|
||||
// New makes a new instance of authboss with a default
|
||||
@ -47,68 +50,15 @@ func (a *Authboss) Init(modulesToLoad ...string) error {
|
||||
}
|
||||
}
|
||||
|
||||
renderer, err := a.ViewLoader.Init(a.templateNames)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to init view loader")
|
||||
}
|
||||
a.renderer = renderer
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CurrentUser retrieves the current user from the session and the database.
|
||||
func (a *Authboss) CurrentUser(w http.ResponseWriter, r *http.Request) (interface{}, error) {
|
||||
return nil, errors.New("TODO")
|
||||
}
|
||||
|
||||
func (a *Authboss) currentUser(w http.ResponseWriter, r *http.Request) (interface{}, error) {
|
||||
/*
|
||||
_, err := a.Callbacks.FireBefore(EventGetUserSession, ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key, ok := ctx.SessionStorer.Get(SessionKey)
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
_, err = a.Callbacks.FireBefore(EventGetUser, ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var user interface{}
|
||||
|
||||
if index := strings.IndexByte(key, ';'); index > 0 {
|
||||
user, err = a.OAuth2Storer.GetOAuth(key[:index], key[index+1:])
|
||||
} else {
|
||||
user, err = a.Storer.Get(key)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx.User = Unbind(user)
|
||||
|
||||
err = a.Callbacks.FireAfter(EventGetUser, ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return user, err
|
||||
*/
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
// CurrentUserP retrieves the current user but panics if it's not available for
|
||||
// any reason.
|
||||
func (a *Authboss) CurrentUserP(w http.ResponseWriter, r *http.Request) interface{} {
|
||||
/*
|
||||
i, err := a.CurrentUser(w, r)
|
||||
if err != nil {
|
||||
panic(err.Error())
|
||||
}
|
||||
return i
|
||||
*/
|
||||
panic("TODO")
|
||||
}
|
||||
|
||||
/*
|
||||
UpdatePassword should be called to recalculate hashes and do any cleanup
|
||||
that should occur on password resets. Updater should return an error if the
|
||||
@ -130,43 +80,28 @@ The error returned is returned either from the updater if that produced an error
|
||||
or from the cleanup routines.
|
||||
*/
|
||||
func (a *Authboss) UpdatePassword(w http.ResponseWriter, r *http.Request,
|
||||
ptPassword string, user interface{}, updater func() error) error {
|
||||
ptPassword string, user Storer, updater func() error) error {
|
||||
|
||||
/*
|
||||
updatePwd := len(ptPassword) > 0
|
||||
/*updatePwd := len(ptPassword) > 0
|
||||
|
||||
if updatePwd {
|
||||
pass, err := bcrypt.GenerateFromPassword([]byte(ptPassword), a.BCryptCost)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
val := reflect.ValueOf(user).Elem()
|
||||
field := val.FieldByName("Password")
|
||||
if !field.CanSet() {
|
||||
return errors.New("authboss: updatePassword called without a modifyable user struct")
|
||||
}
|
||||
fieldPtr := field.Addr()
|
||||
|
||||
if scanner, ok := fieldPtr.Interface().(sql.Scanner); ok {
|
||||
if err := scanner.Scan(string(pass)); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
field.SetString(string(pass))
|
||||
}
|
||||
}
|
||||
|
||||
if err := updater(); err != nil {
|
||||
if updatePwd {
|
||||
pass, err := bcrypt.GenerateFromPassword([]byte(ptPassword), a.BCryptCost)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !updatePwd {
|
||||
return nil
|
||||
}
|
||||
user.PutPassword(r.Context(),
|
||||
}
|
||||
|
||||
return a.Callbacks.FireAfter(EventPasswordReset, a.InitContext(w, r))
|
||||
*/
|
||||
if err := updater(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return errors.New("TODO")
|
||||
if !updatePwd {
|
||||
return nil
|
||||
}
|
||||
|
||||
return a.Callbacks.FireAfter(EventPasswordReset, r.Context())*/
|
||||
// TODO(aarondl): Fix
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
156
authboss_test.go
156
authboss_test.go
@ -2,7 +2,6 @@ package authboss
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@ -16,6 +15,7 @@ func TestAuthBossInit(t *testing.T) {
|
||||
|
||||
ab := New()
|
||||
ab.LogWriter = ioutil.Discard
|
||||
ab.ViewLoader = mockRenderLoader{}
|
||||
err := ab.Init()
|
||||
if err != nil {
|
||||
t.Error("Unexpected error:", err)
|
||||
@ -28,12 +28,9 @@ func TestAuthBossCurrentUser(t *testing.T) {
|
||||
ab := New()
|
||||
ab.LogWriter = ioutil.Discard
|
||||
ab.StoreLoader = mockStoreLoader{"joe": mockUser{Email: "john@john.com", Password: "lies"}}
|
||||
ab.SessionStoreMaker = func(_ http.ResponseWriter, _ *http.Request) ClientStorer {
|
||||
return mockClientStore{SessionKey: "joe"}
|
||||
}
|
||||
ab.CookieStoreMaker = func(_ http.ResponseWriter, _ *http.Request) ClientStorer {
|
||||
return mockClientStore{}
|
||||
}
|
||||
ab.ViewLoader = mockRenderLoader{}
|
||||
ab.SessionStoreMaker = newMockClientStoreMaker(mockClientStore{SessionKey: "joe"})
|
||||
ab.CookieStoreMaker = newMockClientStoreMaker(mockClientStore{})
|
||||
|
||||
if err := ab.Init(); err != nil {
|
||||
t.Error("Unexpected error:", err)
|
||||
@ -43,7 +40,7 @@ func TestAuthBossCurrentUser(t *testing.T) {
|
||||
req, _ := http.NewRequest("GET", "localhost", nil)
|
||||
|
||||
userStruct := ab.CurrentUserP(rec, req)
|
||||
us := userStruct.(*mockUser)
|
||||
us := userStruct.(mockStoredUser)
|
||||
|
||||
if us.Email != "john@john.com" || us.Password != "lies" {
|
||||
t.Error("Wrong user found!")
|
||||
@ -56,12 +53,9 @@ func TestAuthBossCurrentUserCallbacks(t *testing.T) {
|
||||
ab := New()
|
||||
ab.LogWriter = ioutil.Discard
|
||||
ab.StoreLoader = mockStoreLoader{"joe": mockUser{Email: "john@john.com", Password: "lies"}}
|
||||
ab.SessionStoreMaker = func(_ http.ResponseWriter, _ *http.Request) ClientStorer {
|
||||
return mockClientStore{SessionKey: "joe"}
|
||||
}
|
||||
ab.CookieStoreMaker = func(_ http.ResponseWriter, _ *http.Request) ClientStorer {
|
||||
return mockClientStore{}
|
||||
}
|
||||
ab.ViewLoader = mockRenderLoader{}
|
||||
ab.SessionStoreMaker = newMockClientStoreMaker(mockClientStore{SessionKey: "joe"})
|
||||
ab.CookieStoreMaker = newMockClientStoreMaker(mockClientStore{})
|
||||
|
||||
if err := ab.Init(); err != nil {
|
||||
t.Error("Unexpected error:", err)
|
||||
@ -97,86 +91,88 @@ func TestAuthBossCurrentUserCallbacks(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAuthbossUpdatePassword(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Skip("TODO(aarondl): Implement")
|
||||
/*
|
||||
t.Parallel()
|
||||
|
||||
ab := New()
|
||||
session := mockClientStore{}
|
||||
cookies := mockClientStore{}
|
||||
ab.SessionStoreMaker = func(_ http.ResponseWriter, _ *http.Request) ClientStorer {
|
||||
return session
|
||||
}
|
||||
ab.CookieStoreMaker = func(_ http.ResponseWriter, _ *http.Request) ClientStorer {
|
||||
return cookies
|
||||
}
|
||||
ab := New()
|
||||
session := mockClientStore{}
|
||||
cookies := mockClientStore{}
|
||||
ab.SessionStoreMaker = newMockClientStoreMaker(session)
|
||||
ab.CookieStoreMaker = newMockClientStoreMaker(cookies)
|
||||
|
||||
called := false
|
||||
ab.Callbacks.After(EventPasswordReset, func(ctx context.Context) error {
|
||||
called = true
|
||||
return nil
|
||||
})
|
||||
called := false
|
||||
ab.Callbacks.After(EventPasswordReset, func(ctx context.Context) error {
|
||||
called = true
|
||||
return nil
|
||||
})
|
||||
|
||||
user1 := struct {
|
||||
Password string
|
||||
}{}
|
||||
user2 := struct {
|
||||
Password sql.NullString
|
||||
}{}
|
||||
user1 := struct {
|
||||
Password string
|
||||
}{}
|
||||
user2 := struct {
|
||||
Password sql.NullString
|
||||
}{}
|
||||
|
||||
r, _ := http.NewRequest("GET", "http://localhost", nil)
|
||||
r, _ := http.NewRequest("GET", "http://localhost", nil)
|
||||
|
||||
called = false
|
||||
err := ab.UpdatePassword(nil, r, "newpassword", &user1, func() error { return nil })
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
called = false
|
||||
err := ab.UpdatePassword(nil, r, "newpassword", &user1, func() error { return nil })
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if len(user1.Password) == 0 {
|
||||
t.Error("Password not updated")
|
||||
}
|
||||
if !called {
|
||||
t.Error("Callbacks should have been called.")
|
||||
}
|
||||
if len(user1.Password) == 0 {
|
||||
t.Error("Password not updated")
|
||||
}
|
||||
if !called {
|
||||
t.Error("Callbacks should have been called.")
|
||||
}
|
||||
|
||||
called = false
|
||||
err = ab.UpdatePassword(nil, r, "newpassword", &user2, func() error { return nil })
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
called = false
|
||||
err = ab.UpdatePassword(nil, r, "newpassword", &user2, func() error { return nil })
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if !user2.Password.Valid || len(user2.Password.String) == 0 {
|
||||
t.Error("Password not updated")
|
||||
}
|
||||
if !called {
|
||||
t.Error("Callbacks should have been called.")
|
||||
}
|
||||
if !user2.Password.Valid || len(user2.Password.String) == 0 {
|
||||
t.Error("Password not updated")
|
||||
}
|
||||
if !called {
|
||||
t.Error("Callbacks should have been called.")
|
||||
}
|
||||
|
||||
called = false
|
||||
oldPassword := user1.Password
|
||||
err = ab.UpdatePassword(nil, r, "", &user1, func() error { return nil })
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
called = false
|
||||
oldPassword := user1.Password
|
||||
err = ab.UpdatePassword(nil, r, "", &user1, func() error { return nil })
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if user1.Password != oldPassword {
|
||||
t.Error("Password not updated")
|
||||
}
|
||||
if called {
|
||||
t.Error("Callbacks should not have been called")
|
||||
}
|
||||
if user1.Password != oldPassword {
|
||||
t.Error("Password not updated")
|
||||
}
|
||||
if called {
|
||||
t.Error("Callbacks should not have been called")
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
||||
func TestAuthbossUpdatePasswordFail(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Skip("TODO(aarondl): Implement")
|
||||
/*
|
||||
t.Parallel()
|
||||
|
||||
ab := New()
|
||||
ab := New()
|
||||
|
||||
user1 := struct {
|
||||
Password string
|
||||
}{}
|
||||
user1 := struct {
|
||||
Password string
|
||||
}{}
|
||||
|
||||
anErr := errors.New("anError")
|
||||
err := ab.UpdatePassword(nil, nil, "update", &user1, func() error { return anErr })
|
||||
if err != anErr {
|
||||
t.Error("Expected an specific error:", err)
|
||||
}
|
||||
anErr := errors.New("anError")
|
||||
err := ab.UpdatePassword(nil, nil, "update", &user1, func() error { return anErr })
|
||||
if err != anErr {
|
||||
t.Error("Expected an specific error:", err)
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
16
callbacks.go
16
callbacks.go
@ -1,12 +1,6 @@
|
||||
package authboss
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"reflect"
|
||||
"runtime"
|
||||
)
|
||||
import "context"
|
||||
|
||||
//go:generate stringer -output stringers.go -type "Event,Interrupt"
|
||||
|
||||
@ -65,8 +59,8 @@ type Callbacks struct {
|
||||
// Called only by authboss internals and for testing.
|
||||
func NewCallbacks() *Callbacks {
|
||||
return &Callbacks{
|
||||
make(map[Event][]Before),
|
||||
make(map[Event][]After),
|
||||
before: make(map[Event][]Before),
|
||||
after: make(map[Event][]After),
|
||||
}
|
||||
}
|
||||
|
||||
@ -95,8 +89,6 @@ func (c *Callbacks) FireBefore(e Event, ctx context.Context) (interrupt Interrup
|
||||
for _, fn := range callbacks {
|
||||
interrupt, err = fn(ctx)
|
||||
if err != nil {
|
||||
// TODO(aarondl): logwriter fail
|
||||
fmt.Fprintf(ioutil.Discard, "Callback error (%s): %v\n", runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name(), err)
|
||||
return InterruptNone, err
|
||||
}
|
||||
if interrupt != InterruptNone {
|
||||
@ -113,8 +105,6 @@ func (c *Callbacks) FireAfter(e Event, ctx context.Context) (err error) {
|
||||
callbacks := c.after[e]
|
||||
for _, fn := range callbacks {
|
||||
if err = fn(ctx); err != nil {
|
||||
// TODO(aarondl): logwriter fail
|
||||
fmt.Fprintf(ioutil.Discard, "Callback error (%s): %v\n", runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name(), err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -3,7 +3,6 @@ package authboss
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
@ -116,10 +115,6 @@ func TestCallbacksBeforeErrors(t *testing.T) {
|
||||
if before2 {
|
||||
t.Error("Before2 should not have been called.")
|
||||
}
|
||||
|
||||
if estr := log.String(); !strings.Contains(estr, errValue.Error()) {
|
||||
t.Error("Error string wrong:", estr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCallbacksAfterErrors(t *testing.T) {
|
||||
@ -153,10 +148,6 @@ func TestCallbacksAfterErrors(t *testing.T) {
|
||||
if after2 {
|
||||
t.Error("After2 should not have been called.")
|
||||
}
|
||||
|
||||
if estr := log.String(); !strings.Contains(estr, errValue.Error()) {
|
||||
t.Error("Error string wrong:", estr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventString(t *testing.T) {
|
||||
|
@ -25,8 +25,9 @@ const (
|
||||
FlashErrorKey = "flash_error"
|
||||
)
|
||||
|
||||
// ClientStoreMaker is used to create a cookie storer from an http request. Keep in mind
|
||||
// security considerations for your implementation, Secure, HTTP-Only, etc flags.
|
||||
// ClientStoreMaker is used to create a cookie storer from an http request.
|
||||
// Keep in mind security considerations for your implementation, Secure,
|
||||
// HTTP-Only, etc flags.
|
||||
//
|
||||
// There's two major uses for this. To create session storage, and remember me
|
||||
// cookies.
|
||||
@ -64,7 +65,7 @@ func (c clientStoreWrapper) GetErr(key string) (string, error) {
|
||||
|
||||
// FlashSuccess returns FlashSuccessKey from the session and removes it.
|
||||
func (a *Authboss) FlashSuccess(w http.ResponseWriter, r *http.Request) string {
|
||||
storer := a.SessionStoreMaker(w, r)
|
||||
storer := a.SessionStoreMaker.Make(w, r)
|
||||
msg, ok := storer.Get(FlashSuccessKey)
|
||||
if ok {
|
||||
storer.Del(FlashSuccessKey)
|
||||
@ -75,7 +76,7 @@ func (a *Authboss) FlashSuccess(w http.ResponseWriter, r *http.Request) string {
|
||||
|
||||
// FlashError returns FlashError from the session and removes it.
|
||||
func (a *Authboss) FlashError(w http.ResponseWriter, r *http.Request) string {
|
||||
storer := a.SessionStoreMaker(w, r)
|
||||
storer := a.SessionStoreMaker.Make(w, r)
|
||||
msg, ok := storer.Get(FlashErrorKey)
|
||||
if ok {
|
||||
storer.Del(FlashErrorKey)
|
||||
|
@ -1,9 +1,6 @@
|
||||
package authboss
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
import "testing"
|
||||
|
||||
type testClientStorerErr string
|
||||
|
||||
@ -36,9 +33,7 @@ func TestFlashClearer(t *testing.T) {
|
||||
|
||||
session := mockClientStore{FlashSuccessKey: "success", FlashErrorKey: "error"}
|
||||
ab := New()
|
||||
ab.SessionStoreMaker = func(w http.ResponseWriter, r *http.Request) ClientStorer {
|
||||
return session
|
||||
}
|
||||
ab.SessionStoreMaker = newMockClientStoreMaker(session)
|
||||
|
||||
if msg := ab.FlashSuccess(nil, nil); msg != "success" {
|
||||
t.Error("Unexpected flash success:", msg)
|
||||
|
@ -2,6 +2,7 @@
|
||||
package confirm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
@ -31,14 +32,25 @@ var (
|
||||
errUserMissing = errors.New("after registration user must be loaded")
|
||||
)
|
||||
|
||||
// ConfirmStorer must be implemented in order to satisfy the confirm module's
|
||||
// storage requirements.
|
||||
type ConfirmStorer interface {
|
||||
authboss.Storer
|
||||
// ConfirmStoreLoader allows lookup of users by different parameters.
|
||||
type ConfirmStoreLoader interface {
|
||||
authboss.StoreLoader
|
||||
|
||||
// ConfirmUser looks up a user by a confirm token. See confirm module for
|
||||
// attribute names. If the token is not found in the data store,
|
||||
// simply return nil, ErrUserNotFound.
|
||||
ConfirmUser(confirmToken string) (interface{}, error)
|
||||
LoadByConfirmToken(confirmToken string) (ConfirmStorer, error)
|
||||
}
|
||||
|
||||
// ConfirmStorer defines attributes for the confirm module.
|
||||
type ConfirmStorer interface {
|
||||
authboss.Storer
|
||||
|
||||
PutConfirmed(ctx context.Context, confirmed bool) error
|
||||
PutConfirmToken(ctx context.Context, token string) error
|
||||
|
||||
GetConfirmed(ctx context.Context) (confirmed bool, err error)
|
||||
GetConfirmToken(ctx context.Context) (token string, err error)
|
||||
}
|
||||
|
||||
func init() {
|
||||
@ -48,30 +60,13 @@ func init() {
|
||||
// Confirm module
|
||||
type Confirm struct {
|
||||
*authboss.Authboss
|
||||
emailHTMLTemplates response.Templates
|
||||
emailTextTemplates response.Templates
|
||||
}
|
||||
|
||||
// Initialize the module
|
||||
func (c *Confirm) Initialize(ab *authboss.Authboss) (err error) {
|
||||
c.Authboss = ab
|
||||
|
||||
var ok bool
|
||||
storer, ok := c.Storer.(ConfirmStorer)
|
||||
if c.StoreMaker == nil && (storer == nil || !ok) {
|
||||
return errors.New("need a confirmStorer")
|
||||
}
|
||||
|
||||
c.emailHTMLTemplates, err = response.LoadTemplates(ab, c.LayoutHTMLEmail, c.ViewsPath, tplConfirmHTML)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.emailTextTemplates, err = response.LoadTemplates(ab, c.LayoutTextEmail, c.ViewsPath, tplConfirmText)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.Callbacks.After(authboss.EventGetUser, func(ctx *authboss.Context) error {
|
||||
c.Callbacks.After(authboss.EventGetUser, func(ctx context.Context) error {
|
||||
_, err := c.beforeGet(ctx)
|
||||
return err
|
||||
})
|
||||
@ -88,17 +83,12 @@ func (c *Confirm) Routes() authboss.RouteTable {
|
||||
}
|
||||
}
|
||||
|
||||
// Storage requirements
|
||||
func (c *Confirm) Storage() authboss.StorageOptions {
|
||||
return authboss.StorageOptions{
|
||||
c.PrimaryID: authboss.String,
|
||||
authboss.StoreEmail: authboss.String,
|
||||
StoreConfirmToken: authboss.String,
|
||||
StoreConfirmed: authboss.Bool,
|
||||
}
|
||||
// Templates returns the list of templates required by this module
|
||||
func (c *Confirm) Templates() []string {
|
||||
return []string{tplConfirmHTML, tplConfirmText}
|
||||
}
|
||||
|
||||
func (c *Confirm) beforeGet(ctx *authboss.Context) (authboss.Interrupt, error) {
|
||||
func (c *Confirm) beforeGet(ctx context.Context) (authboss.Interrupt, error) {
|
||||
if confirmed, err := ctx.User.BoolErr(StoreConfirmed); err != nil {
|
||||
return authboss.InterruptNone, err
|
||||
} else if !confirmed {
|
||||
@ -109,7 +99,7 @@ func (c *Confirm) beforeGet(ctx *authboss.Context) (authboss.Interrupt, error) {
|
||||
}
|
||||
|
||||
// AfterRegister ensures the account is not activated.
|
||||
func (c *Confirm) afterRegister(ctx *authboss.Context) error {
|
||||
func (c *Confirm) afterRegister(ctx context.Context) error {
|
||||
if ctx.User == nil {
|
||||
return errUserMissing
|
||||
}
|
||||
@ -136,7 +126,7 @@ func (c *Confirm) afterRegister(ctx *authboss.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var goConfirmEmail = func(c *Confirm, ctx *authboss.Context, to, token string) {
|
||||
var goConfirmEmail = func(c *Confirm, ctx context.Context, to, token string) {
|
||||
if ctx.MailMaker != nil {
|
||||
c.confirmEmail(ctx, to, token)
|
||||
} else {
|
||||
@ -145,7 +135,7 @@ var goConfirmEmail = func(c *Confirm, ctx *authboss.Context, to, token string) {
|
||||
}
|
||||
|
||||
// confirmEmail sends a confirmation e-mail.
|
||||
func (c *Confirm) confirmEmail(ctx *authboss.Context, to, token string) {
|
||||
func (c *Confirm) confirmEmail(ctx context.Context, to, token string) {
|
||||
p := path.Join(c.MountPath, "confirm")
|
||||
url := fmt.Sprintf("%s%s?%s=%s", c.RootURL, p, url.QueryEscape(FormValueConfirm), url.QueryEscape(token))
|
||||
|
||||
@ -161,7 +151,7 @@ func (c *Confirm) confirmEmail(ctx *authboss.Context, to, token string) {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Confirm) confirmHandler(ctx *authboss.Context, w http.ResponseWriter, r *http.Request) error {
|
||||
func (c *Confirm) confirmHandler(w http.ResponseWriter, r *http.Request) error {
|
||||
token := r.FormValue(FormValueConfirm)
|
||||
if len(token) == 0 {
|
||||
return authboss.ClientDataErr{Name: FormValueConfirm}
|
||||
|
150
context.go
Normal file
150
context.go
Normal file
@ -0,0 +1,150 @@
|
||||
package authboss
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type contextKey string
|
||||
|
||||
const (
|
||||
ctxKeyPID contextKey = "pid"
|
||||
ctxKeyUser contextKey = "user"
|
||||
)
|
||||
|
||||
func (c contextKey) String() string {
|
||||
return "authboss ctx key " + string(c)
|
||||
}
|
||||
|
||||
// CurrentUserID retrieves the current user from the session.
|
||||
func (a *Authboss) CurrentUserID(w http.ResponseWriter, r *http.Request) (string, error) {
|
||||
_, err := a.Callbacks.FireBefore(EventGetUserSession, r.Context())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
session := a.SessionStoreMaker.Make(w, r)
|
||||
key, _ := session.Get(SessionKey)
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// CurrentUserIDP retrieves the current user but panics if it's not available for
|
||||
// any reason.
|
||||
func (a *Authboss) CurrentUserIDP(w http.ResponseWriter, r *http.Request) string {
|
||||
i, err := a.CurrentUserID(w, r)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
} else if len(i) == 0 {
|
||||
panic(ErrUserNotFound)
|
||||
}
|
||||
|
||||
return i
|
||||
}
|
||||
|
||||
// CurrentUser retrieves the current user from the session and the database.
|
||||
func (a *Authboss) CurrentUser(w http.ResponseWriter, r *http.Request) (Storer, error) {
|
||||
pid, err := a.CurrentUserID(w, r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if len(pid) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return a.currentUser(r.Context(), pid)
|
||||
}
|
||||
|
||||
// CurrentUserP retrieves the current user but panics if it's not available for
|
||||
// any reason.
|
||||
func (a *Authboss) CurrentUserP(w http.ResponseWriter, r *http.Request) Storer {
|
||||
i, err := a.CurrentUser(w, r)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return i
|
||||
}
|
||||
|
||||
func (a *Authboss) currentUser(ctx context.Context, pid string) (Storer, error) {
|
||||
_, err := a.Callbacks.FireBefore(EventGetUser, ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := a.StoreLoader.Load(ctx, pid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx = context.WithValue(ctx, ctxKeyUser, user)
|
||||
err = a.Callbacks.FireAfter(EventGetUser, ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// LoadCurrentUser takes a pointer to a pointer to the request in order to
|
||||
// change the current method's request pointer itself to the new request that
|
||||
// contains the new context that has the pid in it.
|
||||
func (a *Authboss) LoadCurrentUserID(w http.ResponseWriter, r **http.Request) (string, error) {
|
||||
pid, err := a.CurrentUserID(w, *r)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if len(pid) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
ctx := context.WithValue((**r).Context(), ctxKeyPID, pid)
|
||||
*r = (**r).WithContext(ctx)
|
||||
|
||||
return pid, nil
|
||||
}
|
||||
|
||||
func (a *Authboss) LoadCurrentUserIDP(w http.ResponseWriter, r **http.Request) string {
|
||||
pid, err := a.LoadCurrentUserID(w, r)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
} else if len(pid) == 0 {
|
||||
panic(ErrUserNotFound)
|
||||
}
|
||||
|
||||
return pid
|
||||
}
|
||||
|
||||
// LoadCurrentUser takes a pointer to a pointer to the request in order to
|
||||
// change the current method's request pointer itself to the new request that
|
||||
// contains the new context that has the user in it. Calls LoadCurrentUserID
|
||||
// so the primary id is also put in the context.
|
||||
func (a *Authboss) LoadCurrentUser(w http.ResponseWriter, r **http.Request) (Storer, error) {
|
||||
pid, err := a.LoadCurrentUserID(w, r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(pid) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
ctx := (**r).Context()
|
||||
user, err := a.currentUser(ctx, pid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx = context.WithValue(ctx, ctxKeyUser, user)
|
||||
*r = (**r).WithContext(ctx)
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (a *Authboss) LoadCurrentUserP(w http.ResponseWriter, r **http.Request) Storer {
|
||||
user, err := a.LoadCurrentUser(w, r)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
} else if user == nil {
|
||||
panic(ErrUserNotFound)
|
||||
}
|
||||
|
||||
return user
|
||||
}
|
170
context_test.go
Normal file
170
context_test.go
Normal file
@ -0,0 +1,170 @@
|
||||
package authboss
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCurrentUserID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := New()
|
||||
ab.SessionStoreMaker = newMockClientStoreMaker(mockClientStore{
|
||||
SessionKey: "george-pid",
|
||||
})
|
||||
|
||||
id, err := ab.CurrentUserID(nil, httptest.NewRequest("GET", "/", nil))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if id != "george-pid" {
|
||||
t.Error("got:", id)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCurrentUserIDP(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := New()
|
||||
ab.SessionStoreMaker = newMockClientStoreMaker(mockClientStore{})
|
||||
|
||||
defer func() {
|
||||
if recover().(error) != ErrUserNotFound {
|
||||
t.Failed()
|
||||
}
|
||||
}()
|
||||
|
||||
_ = ab.CurrentUserIDP(nil, httptest.NewRequest("GET", "/", nil))
|
||||
}
|
||||
|
||||
func TestCurrentUser(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := New()
|
||||
ab.SessionStoreMaker = newMockClientStoreMaker(mockClientStore{
|
||||
SessionKey: "george-pid",
|
||||
})
|
||||
ab.StoreLoader = mockStoreLoader{
|
||||
"george-pid": mockUser{Email: "george-pid", Password: "unreadable"},
|
||||
}
|
||||
|
||||
user, err := ab.CurrentUser(nil, httptest.NewRequest("GET", "/", nil))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if got, err := user.GetEmail(context.TODO()); err != nil {
|
||||
t.Error(err)
|
||||
} else if got != "george-pid" {
|
||||
t.Error("got:", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCurrentUserP(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := New()
|
||||
ab.SessionStoreMaker = newMockClientStoreMaker(mockClientStore{
|
||||
SessionKey: "george-pid",
|
||||
})
|
||||
ab.StoreLoader = mockStoreLoader{}
|
||||
|
||||
defer func() {
|
||||
if recover().(error) != ErrUserNotFound {
|
||||
t.Failed()
|
||||
}
|
||||
}()
|
||||
|
||||
_ = ab.CurrentUserP(nil, httptest.NewRequest("GET", "/", nil))
|
||||
}
|
||||
|
||||
func TestLoadCurrentUserID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := New()
|
||||
ab.SessionStoreMaker = newMockClientStoreMaker(mockClientStore{
|
||||
SessionKey: "george-pid",
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
id, err := ab.LoadCurrentUserID(nil, &req)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if id != "george-pid" {
|
||||
t.Error("got:", id)
|
||||
}
|
||||
|
||||
if req.Context().Value(ctxKeyPID).(string) != "george-pid" {
|
||||
t.Error("context was not updated in local request")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadCurrentUserIDP(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := New()
|
||||
ab.SessionStoreMaker = newMockClientStoreMaker(mockClientStore{})
|
||||
|
||||
defer func() {
|
||||
if recover().(error) != ErrUserNotFound {
|
||||
t.Failed()
|
||||
}
|
||||
}()
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
_ = ab.LoadCurrentUserIDP(nil, &req)
|
||||
}
|
||||
|
||||
func TestLoadCurrentUser(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := New()
|
||||
ab.SessionStoreMaker = newMockClientStoreMaker(mockClientStore{
|
||||
SessionKey: "george-pid",
|
||||
})
|
||||
ab.StoreLoader = mockStoreLoader{
|
||||
"george-pid": mockUser{Email: "george-pid", Password: "unreadable"},
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
user, err := ab.LoadCurrentUser(nil, &req)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if got, err := user.GetEmail(context.TODO()); err != nil {
|
||||
t.Error(err)
|
||||
} else if got != "george-pid" {
|
||||
t.Error("got:", got)
|
||||
}
|
||||
|
||||
want := user.(mockStoredUser).mockUser
|
||||
got := req.Context().Value(ctxKeyUser).(mockStoredUser).mockUser
|
||||
if got != want {
|
||||
t.Error("users mismatched:\nwant: %#v\ngot: %#v", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadCurrentUserP(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := New()
|
||||
ab.SessionStoreMaker = newMockClientStoreMaker(mockClientStore{
|
||||
SessionKey: "george-pid",
|
||||
})
|
||||
ab.StoreLoader = mockStoreLoader{}
|
||||
|
||||
defer func() {
|
||||
if recover().(error) != ErrUserNotFound {
|
||||
t.Failed()
|
||||
}
|
||||
}()
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
_ = ab.LoadCurrentUserP(nil, &req)
|
||||
}
|
41
errors.go
41
errors.go
@ -2,32 +2,6 @@ package authboss
|
||||
|
||||
import "fmt"
|
||||
|
||||
// AttributeErr represents a failure to retrieve a critical
|
||||
// piece of data from the storer.
|
||||
type AttributeErr struct {
|
||||
Name string
|
||||
WantKind DataType
|
||||
GotKind string
|
||||
}
|
||||
|
||||
// NewAttributeErr creates a new attribute err type. Useful for when you want
|
||||
// to have a type mismatch error.
|
||||
func NewAttributeErr(name string, kind DataType, val interface{}) AttributeErr {
|
||||
return AttributeErr{
|
||||
Name: name,
|
||||
WantKind: kind,
|
||||
GotKind: fmt.Sprintf("%T", val),
|
||||
}
|
||||
}
|
||||
|
||||
func (a AttributeErr) Error() string {
|
||||
if len(a.GotKind) == 0 {
|
||||
return fmt.Sprintf("Failed to retrieve database attribute: %s", a.Name)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("Failed to retrieve database attribute, type was wrong: %s (want: %v, got: %s)", a.Name, a.WantKind, a.GotKind)
|
||||
}
|
||||
|
||||
// ClientDataErr represents a failure to retrieve a critical
|
||||
// piece of client information such as a cookie or session value.
|
||||
type ClientDataErr struct {
|
||||
@ -38,19 +12,6 @@ func (c ClientDataErr) Error() string {
|
||||
return fmt.Sprintf("Failed to retrieve client attribute: %s", c.Name)
|
||||
}
|
||||
|
||||
// ErrAndRedirect represents a general error whose response should
|
||||
// be to redirect.
|
||||
type ErrAndRedirect struct {
|
||||
Err error
|
||||
Location string
|
||||
FlashSuccess string
|
||||
FlashError string
|
||||
}
|
||||
|
||||
func (e ErrAndRedirect) Error() string {
|
||||
return fmt.Sprintf("Error: %v, Redirecting to: %s", e.Err, e.Location)
|
||||
}
|
||||
|
||||
// RenderErr represents an error that occured during rendering
|
||||
// of a template.
|
||||
type RenderErr struct {
|
||||
@ -60,5 +21,5 @@ type RenderErr struct {
|
||||
}
|
||||
|
||||
func (r RenderErr) Error() string {
|
||||
return fmt.Sprintf("Error rendering template %q: %v, data: %#v", r.TemplateName, r.Err, r.Data)
|
||||
return fmt.Sprintf("error rendering response %q: %v, data: %#v", r.TemplateName, r.Err, r.Data)
|
||||
}
|
||||
|
@ -6,21 +6,6 @@ import (
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func TestAttributeErr(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
estr := "Failed to retrieve database attribute, type was wrong: lol (want: String, got: int)"
|
||||
if str := NewAttributeErr("lol", String, 5).Error(); str != estr {
|
||||
t.Error("Error was wrong:", str)
|
||||
}
|
||||
|
||||
estr = "Failed to retrieve database attribute: lol"
|
||||
err := AttributeErr{Name: "lol"}
|
||||
if str := err.Error(); str != estr {
|
||||
t.Error("Error was wrong:", str)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientDataErr(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@ -31,20 +16,10 @@ func TestClientDataErr(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrAndRedirect(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
estr := "Error: cause, Redirecting to: /"
|
||||
err := ErrAndRedirect{errors.New("cause"), "/", "success", "failure"}
|
||||
if str := err.Error(); str != estr {
|
||||
t.Error("Error was wrong:", str)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderErr(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
estr := `Error rendering template "lol": cause, data: authboss.HTMLData{"a":5}`
|
||||
estr := `error rendering response "lol": cause, data: authboss.HTMLData{"a":5}`
|
||||
err := RenderErr{"lol", NewHTMLData("a", 5), errors.New("cause")}
|
||||
if str := err.Error(); str != estr {
|
||||
t.Error("Error was wrong:", str)
|
||||
|
@ -9,7 +9,7 @@ var nowTime = time.Now
|
||||
|
||||
// TimeToExpiry returns zero if the user session is expired else the time until expiry.
|
||||
func (a *Authboss) TimeToExpiry(w http.ResponseWriter, r *http.Request) time.Duration {
|
||||
return a.timeToExpiry(a.SessionStoreMaker(w, r))
|
||||
return a.timeToExpiry(a.SessionStoreMaker.Make(w, r))
|
||||
}
|
||||
|
||||
func (a *Authboss) timeToExpiry(session ClientStorer) time.Duration {
|
||||
@ -33,7 +33,7 @@ func (a *Authboss) timeToExpiry(session ClientStorer) time.Duration {
|
||||
|
||||
// RefreshExpiry updates the last action for the user, so he doesn't become expired.
|
||||
func (a *Authboss) RefreshExpiry(w http.ResponseWriter, r *http.Request) {
|
||||
session := a.SessionStoreMaker(w, r)
|
||||
session := a.SessionStoreMaker.Make(w, r)
|
||||
a.refreshExpiry(session)
|
||||
}
|
||||
|
||||
@ -57,7 +57,7 @@ func (a *Authboss) ExpireMiddleware(next http.Handler) http.Handler {
|
||||
|
||||
// ServeHTTP removes the session if it's passed the expire time.
|
||||
func (m expireMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
session := m.ab.SessionStoreMaker(w, r)
|
||||
session := m.ab.SessionStoreMaker.Make(w, r)
|
||||
if _, ok := session.Get(SessionKey); ok {
|
||||
ttl := m.ab.timeToExpiry(session)
|
||||
if ttl == 0 {
|
||||
|
@ -11,15 +11,18 @@ import (
|
||||
|
||||
func TestDudeIsExpired(t *testing.T) {
|
||||
ab := New()
|
||||
|
||||
session := mockClientStore{SessionKey: "username"}
|
||||
ab.refreshExpiry(session)
|
||||
|
||||
// No t.Parallel()
|
||||
nowTime = func() time.Time {
|
||||
return time.Now().UTC().Add(ab.ExpireAfter * 2)
|
||||
}
|
||||
ab.SessionStoreMaker = func(_ http.ResponseWriter, _ *http.Request) ClientStorer {
|
||||
return session
|
||||
}
|
||||
defer func() {
|
||||
nowTime = time.Now
|
||||
}()
|
||||
|
||||
ab.SessionStoreMaker = newMockClientStoreMaker(session)
|
||||
|
||||
r, _ := http.NewRequest("GET", "tra/la/la", nil)
|
||||
w := httptest.NewRecorder()
|
||||
@ -36,25 +39,28 @@ func TestDudeIsExpired(t *testing.T) {
|
||||
}
|
||||
|
||||
if key, ok := session.Get(SessionKey); ok {
|
||||
t.Error("Unexpcted session key:", key)
|
||||
t.Error("Unexpected session key:", key)
|
||||
}
|
||||
|
||||
if key, ok := session.Get(SessionLastAction); ok {
|
||||
t.Error("Unexpcted last action key:", key)
|
||||
t.Error("Unexpected last action key:", key)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDudeIsNotExpired(t *testing.T) {
|
||||
ab := New()
|
||||
|
||||
session := mockClientStore{SessionKey: "username"}
|
||||
ab.refreshExpiry(session)
|
||||
|
||||
// No t.Parallel()
|
||||
nowTime = func() time.Time {
|
||||
return time.Now().UTC().Add(ab.ExpireAfter / 2)
|
||||
}
|
||||
ab.SessionStoreMaker = func(_ http.ResponseWriter, _ *http.Request) ClientStorer {
|
||||
return session
|
||||
}
|
||||
defer func() {
|
||||
nowTime = time.Now
|
||||
}()
|
||||
|
||||
ab.SessionStoreMaker = newMockClientStoreMaker(session)
|
||||
|
||||
r, _ := http.NewRequest("GET", "tra/la/la", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
@ -2,6 +2,7 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@ -25,113 +26,99 @@ type MockUser struct {
|
||||
Locked bool
|
||||
AttemptNumber int
|
||||
AttemptTime time.Time
|
||||
OauthToken string
|
||||
OauthRefresh string
|
||||
OauthExpiry time.Time
|
||||
OAuthToken string
|
||||
OAuthRefresh string
|
||||
OAuthExpiry time.Time
|
||||
}
|
||||
|
||||
func (m MockUser) GetUsername(context.Context) (string, error) { return m.Username, nil }
|
||||
func (m MockUser) GetEmail(context.Context) (string, error) { return m.Email, nil }
|
||||
func (m MockUser) GetPassword(context.Context) (string, error) { return m.Password, nil }
|
||||
func (m MockUser) GetRecoverToken(context.Context) (string, error) { return m.RecoverToken, nil }
|
||||
func (m MockUser) GetRecoverTokenExpiry(context.Context) (time.Time, error) {
|
||||
return m.RecoverTokenExpiry, nil
|
||||
}
|
||||
func (m MockUser) GetConfirmToken(context.Context) (string, error) { return m.ConfirmToken, nil }
|
||||
func (m MockUser) GetConfirmed(context.Context) (bool, error) { return m.Confirmed, nil }
|
||||
func (m MockUser) GetLocked(context.Context) (bool, error) { return m.Locked, nil }
|
||||
func (m MockUser) GetAttemptNumber(context.Context) (int, error) { return m.AttemptNumber, nil }
|
||||
func (m MockUser) GetAttemptTime(context.Context) (time.Time, error) {
|
||||
return m.AttemptTime, nil
|
||||
}
|
||||
func (m MockUser) GetOAuthToken(context.Context) (string, error) { return m.OAuthToken, nil }
|
||||
func (m MockUser) GetOAuthRefresh(context.Context) (string, error) { return m.OAuthRefresh, nil }
|
||||
func (m MockUser) GetOAuthExpiry(context.Context) (time.Time, error) {
|
||||
return m.OAuthExpiry, nil
|
||||
}
|
||||
|
||||
func (m *MockUser) SetUsername(ctx context.Context, username string) error {
|
||||
m.Username = username
|
||||
return nil
|
||||
}
|
||||
func (m *MockUser) SetEmail(ctx context.Context, email string) error {
|
||||
m.Email = email
|
||||
return nil
|
||||
}
|
||||
func (m *MockUser) SetPassword(ctx context.Context, password string) error {
|
||||
m.Password = password
|
||||
return nil
|
||||
}
|
||||
func (m *MockUser) SetRecoverToken(ctx context.Context, recoverToken string) error {
|
||||
m.RecoverToken = recoverToken
|
||||
return nil
|
||||
}
|
||||
func (m *MockUser) SetRecoverTokenExpiry(ctx context.Context, recoverTokenExpiry time.Time) error {
|
||||
m.RecoverTokenExpiry = recoverTokenExpiry
|
||||
return nil
|
||||
}
|
||||
func (m *MockUser) SetConfirmToken(ctx context.Context, confirmToken string) error {
|
||||
m.ConfirmToken = confirmToken
|
||||
return nil
|
||||
}
|
||||
func (m *MockUser) SetConfirmed(ctx context.Context, confirmed bool) error {
|
||||
m.Confirmed = confirmed
|
||||
return nil
|
||||
}
|
||||
func (m *MockUser) SetLocked(ctx context.Context, locked bool) error {
|
||||
m.Locked = locked
|
||||
return nil
|
||||
}
|
||||
func (m *MockUser) SetAttemptNumber(ctx context.Context, attemptNumber int) error {
|
||||
m.AttemptNumber = attemptNumber
|
||||
return nil
|
||||
}
|
||||
func (m *MockUser) SetAttemptTime(ctx context.Context, attemptTime time.Time) error {
|
||||
m.AttemptTime = attemptTime
|
||||
return nil
|
||||
}
|
||||
func (m *MockUser) SetOAuthToken(ctx context.Context, oAuthToken string) error {
|
||||
m.OAuthToken = oAuthToken
|
||||
return nil
|
||||
}
|
||||
func (m *MockUser) SetOAuthRefresh(ctx context.Context, oAuthRefresh string) error {
|
||||
m.OAuthRefresh = oAuthRefresh
|
||||
return nil
|
||||
}
|
||||
func (m *MockUser) SetOAuthExpiry(ctx context.Context, oAuthExpiry time.Time) error {
|
||||
m.OAuthExpiry = oAuthExpiry
|
||||
return nil
|
||||
}
|
||||
|
||||
// MockStorer should be valid for any module storer defined in authboss.
|
||||
type MockStorer struct {
|
||||
Users map[string]authboss.Attributes
|
||||
Tokens map[string][]string
|
||||
CreateErr string
|
||||
PutErr string
|
||||
GetErr string
|
||||
AddTokenErr string
|
||||
DelTokensErr string
|
||||
UseTokenErr string
|
||||
RecoverUserErr string
|
||||
ConfirmUserErr string
|
||||
type MockStoreLoader struct {
|
||||
Users map[string]*MockUser
|
||||
RMTokens map[string][]string
|
||||
}
|
||||
|
||||
// NewMockStorer constructor
|
||||
func NewMockStorer() *MockStorer {
|
||||
return &MockStorer{
|
||||
Users: make(map[string]authboss.Attributes),
|
||||
Tokens: make(map[string][]string),
|
||||
func NewMockStoreLoader() *MockStoreLoader {
|
||||
return &MockStoreLoader{
|
||||
Users: make(map[string]*MockUser),
|
||||
RMTokens: make(map[string][]string),
|
||||
}
|
||||
}
|
||||
|
||||
// Create a new user
|
||||
func (m *MockStorer) Create(key string, attr authboss.Attributes) error {
|
||||
if len(m.CreateErr) > 0 {
|
||||
return errors.New(m.CreateErr)
|
||||
}
|
||||
|
||||
m.Users[key] = attr
|
||||
return nil
|
||||
}
|
||||
|
||||
// Put updates to a user
|
||||
func (m *MockStorer) Put(key string, attr authboss.Attributes) error {
|
||||
if len(m.PutErr) > 0 {
|
||||
return errors.New(m.PutErr)
|
||||
}
|
||||
|
||||
if _, ok := m.Users[key]; !ok {
|
||||
m.Users[key] = attr
|
||||
return nil
|
||||
}
|
||||
for k, v := range attr {
|
||||
m.Users[key][k] = v
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get a user
|
||||
func (m *MockStorer) Get(key string) (result interface{}, err error) {
|
||||
if len(m.GetErr) > 0 {
|
||||
return nil, errors.New(m.GetErr)
|
||||
}
|
||||
|
||||
userAttrs, ok := m.Users[key]
|
||||
if !ok {
|
||||
return nil, authboss.ErrUserNotFound
|
||||
}
|
||||
|
||||
u := &MockUser{}
|
||||
if err := userAttrs.Bind(u, true); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return u, nil
|
||||
}
|
||||
|
||||
// PutOAuth user
|
||||
func (m *MockStorer) PutOAuth(uid, provider string, attr authboss.Attributes) error {
|
||||
if len(m.PutErr) > 0 {
|
||||
return errors.New(m.PutErr)
|
||||
}
|
||||
|
||||
if _, ok := m.Users[uid+provider]; !ok {
|
||||
m.Users[uid+provider] = attr
|
||||
return nil
|
||||
}
|
||||
for k, v := range attr {
|
||||
m.Users[uid+provider][k] = v
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetOAuth user
|
||||
func (m *MockStorer) GetOAuth(uid, provider string) (result interface{}, err error) {
|
||||
if len(m.GetErr) > 0 {
|
||||
return nil, errors.New(m.GetErr)
|
||||
}
|
||||
|
||||
userAttrs, ok := m.Users[uid+provider]
|
||||
if !ok {
|
||||
return nil, authboss.ErrUserNotFound
|
||||
}
|
||||
|
||||
u := &MockUser{}
|
||||
if err := userAttrs.Bind(u, true); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return u, nil
|
||||
}
|
||||
|
||||
/*
|
||||
// AddToken for remember me
|
||||
func (m *MockStorer) AddToken(key, token string) error {
|
||||
if len(m.AddTokenErr) > 0 {
|
||||
@ -211,23 +198,26 @@ func (m *MockStorer) ConfirmUser(confirmToken string) (result interface{}, err e
|
||||
|
||||
return nil, authboss.ErrUserNotFound
|
||||
}
|
||||
*/
|
||||
|
||||
// MockFailStorer is used for testing module initialize functions that recover more than the base storer
|
||||
type MockFailStorer struct{}
|
||||
type MockFailStorer struct {
|
||||
MockUser
|
||||
}
|
||||
|
||||
// Create fails
|
||||
func (_ MockFailStorer) Create(_ string, _ authboss.Attributes) error {
|
||||
func (_ MockFailStorer) Create(context.Context) error {
|
||||
return errors.New("fail storer: create")
|
||||
}
|
||||
|
||||
// Put fails
|
||||
func (_ MockFailStorer) Put(_ string, _ authboss.Attributes) error {
|
||||
func (_ MockFailStorer) Save(context.Context) error {
|
||||
return errors.New("fail storer: put")
|
||||
}
|
||||
|
||||
// Get fails
|
||||
func (_ MockFailStorer) Get(_ string) (interface{}, error) {
|
||||
return nil, errors.New("fail storer: get")
|
||||
func (_ MockFailStorer) Load(context.Context) error {
|
||||
return errors.New("fail storer: get")
|
||||
}
|
||||
|
||||
// MockClientStorer is used for testing the client stores on context
|
||||
@ -322,7 +312,7 @@ func NewMockMailer() *MockMailer {
|
||||
}
|
||||
|
||||
// Send an e-mail
|
||||
func (m *MockMailer) Send(email authboss.Email) error {
|
||||
func (m *MockMailer) Send(ctx context.Context, email authboss.Email) error {
|
||||
if len(m.SendErr) > 0 {
|
||||
return errors.New(m.SendErr)
|
||||
}
|
||||
@ -341,7 +331,7 @@ type MockAfterCallback struct {
|
||||
func NewMockAfterCallback() *MockAfterCallback {
|
||||
m := MockAfterCallback{}
|
||||
|
||||
m.Fn = func(_ *authboss.Context) error {
|
||||
m.Fn = func(context.Context) error {
|
||||
m.HasBeenCalled = true
|
||||
return nil
|
||||
}
|
||||
|
@ -3,30 +3,13 @@ package response
|
||||
|
||||
//go:generate go-bindata -pkg=response -prefix=templates templates
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"html/template"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/go-authboss/authboss"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrTemplateNotFound should be returned from Get when the view is not found
|
||||
ErrTemplateNotFound = errors.New("template not found")
|
||||
)
|
||||
// TODO(aarondl): Extract this into the default "template" provider
|
||||
|
||||
/*
|
||||
// Templates is a map depicting the forms a template needs wrapped within the specified layout
|
||||
type Templates map[string]*template.Template
|
||||
|
||||
|
||||
// LoadTemplates parses all specified files located in fpath. Each template is wrapped
|
||||
// in a unique clone of layout. All templates are expecting {{authboss}} handlebars
|
||||
// for parsing. It will check the override directory specified in the config, replacing any
|
||||
@ -158,3 +141,4 @@ func Redirect(ctx *authboss.Context, w http.ResponseWriter, r *http.Request, pat
|
||||
}
|
||||
http.Redirect(w, r, path, http.StatusFound)
|
||||
}
|
||||
*/
|
||||
|
@ -76,7 +76,7 @@ func TestBoundary(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mailer := smtpMailer{"server", nil, rand.New(rand.NewSource(3))}
|
||||
if got := mailer.boundary(); got != "ntadoe" {
|
||||
if got := mailer.boundary(); got != "fe3fhpsm69lx8jvnrnju0wr" {
|
||||
t.Error("boundary was wrong", got)
|
||||
}
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ package authboss
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
@ -77,8 +78,20 @@ func (m mockStoredUser) GetPassword(ctx context.Context) (password string, err e
|
||||
return m.Password, nil
|
||||
}
|
||||
|
||||
type mockClientStoreMaker struct {
|
||||
store mockClientStore
|
||||
}
|
||||
type mockClientStore map[string]string
|
||||
|
||||
func newMockClientStoreMaker(store mockClientStore) mockClientStoreMaker {
|
||||
return mockClientStoreMaker{
|
||||
store: store,
|
||||
}
|
||||
}
|
||||
func (m mockClientStoreMaker) Make(w http.ResponseWriter, r *http.Request) ClientStorer {
|
||||
return m.store
|
||||
}
|
||||
|
||||
func (m mockClientStore) Get(key string) (string, bool) {
|
||||
v, ok := m[key]
|
||||
return v, ok
|
||||
@ -125,3 +138,20 @@ func (m mockValidator) Errors(in string) ErrorList {
|
||||
func (m mockValidator) Rules() []string {
|
||||
return m.Ruleset
|
||||
}
|
||||
|
||||
type mockRenderLoader struct{}
|
||||
|
||||
func (m mockRenderLoader) Init(names []string) (Renderer, error) {
|
||||
return mockRenderer{}, nil
|
||||
}
|
||||
|
||||
type mockRenderer struct{}
|
||||
|
||||
func (m mockRenderer) Render(ctx context.Context, name string, data HTMLData) ([]byte, string, error) {
|
||||
b, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
return b, "application/json", nil
|
||||
}
|
||||
|
@ -8,6 +8,7 @@ var registeredModules = make(map[string]Modularizer)
|
||||
type Modularizer interface {
|
||||
Initialize(*Authboss) error
|
||||
Routes() RouteTable
|
||||
Templates() []string
|
||||
}
|
||||
|
||||
// RegisterModule with the core providing all the necessary information to
|
||||
@ -55,6 +56,7 @@ func (a *Authboss) loadModule(name string) error {
|
||||
}
|
||||
mod, ok := value.Interface().(Modularizer)
|
||||
a.loadedModules[name] = mod
|
||||
a.templateNames = append(a.templateNames, mod.Templates()...)
|
||||
return mod.Initialize(a)
|
||||
}
|
||||
|
||||
|
@ -29,6 +29,7 @@ func testHandler(w http.ResponseWriter, r *http.Request) error {
|
||||
|
||||
func (t *testModule) Initialize(a *Authboss) error { return nil }
|
||||
func (t *testModule) Routes() RouteTable { return t.r }
|
||||
func (t *testModule) Templates() []string { return []string{"template1.tpl"} }
|
||||
|
||||
func TestRegister(t *testing.T) {
|
||||
// RegisterModule called by init()
|
||||
@ -59,6 +60,7 @@ func TestLoadedModules(t *testing.T) {
|
||||
func TestIsLoaded(t *testing.T) {
|
||||
ab := New()
|
||||
ab.LogWriter = ioutil.Discard
|
||||
ab.ViewLoader = mockRenderLoader{}
|
||||
if err := ab.Init(testModName); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
@ -6,10 +6,10 @@ import "context"
|
||||
// It's possible that Init() is a no-op if the responses are JSON or anything
|
||||
// else.
|
||||
type RenderLoader interface {
|
||||
Init(names string) (Renderer, error)
|
||||
Init(names []string) (Renderer, error)
|
||||
}
|
||||
|
||||
// Renderer is a type that can render a given template with some data.
|
||||
type Renderer interface {
|
||||
Render(ctx context.Context, data HTMLData) ([]byte, error)
|
||||
Render(ctx context.Context, name string, data HTMLData) (output []byte, contentType string, err error)
|
||||
}
|
||||
|
168
response.go
Normal file
168
response.go
Normal file
@ -0,0 +1,168 @@
|
||||
package authboss
|
||||
|
||||
import (
|
||||
"html/template"
|
||||
"net/http"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// RedirectOptions packages up all the pieces a module needs to write out a
|
||||
// response.
|
||||
type RedirectOptions struct {
|
||||
// Success & Failure are used to set Flash messages / JSON messages
|
||||
// if set. They should be mutually exclusive.
|
||||
Success string
|
||||
Failure string
|
||||
|
||||
// Code is used when it's an API request instead of 200.
|
||||
Code int
|
||||
|
||||
// When a request should redirect a user somewhere on completion, these
|
||||
// should be set. RedirectURL tells it where to go. And optionally set
|
||||
// FollowRedirParam to override the RedirectURL if the form parameter defined
|
||||
// by FormValueRedirect is passed in the request.
|
||||
//
|
||||
// Redirecting works differently whether it's an API request or not.
|
||||
// If it's an API request, then it will leave the URL in a "redirect"
|
||||
// parameter.
|
||||
RedirectPath string
|
||||
FollowRedirParam bool
|
||||
}
|
||||
|
||||
// Respond to an HTTP request.
|
||||
func (a *Authboss) Respond(w http.ResponseWriter, r *http.Request, code int, templateName string, data HTMLData) error {
|
||||
data.MergeKV(
|
||||
"xsrfName", template.HTML(a.XSRFName),
|
||||
"xsrfToken", template.HTML(a.XSRFMaker(w, r)),
|
||||
)
|
||||
|
||||
if a.LayoutDataMaker != nil {
|
||||
data.Merge(a.LayoutDataMaker(w, r))
|
||||
}
|
||||
|
||||
session := a.SessionStoreMaker.Make(w, r)
|
||||
if flash, ok := session.Get(FlashSuccessKey); ok {
|
||||
session.Del(FlashSuccessKey)
|
||||
data.MergeKV(FlashSuccessKey, flash)
|
||||
}
|
||||
if flash, ok := session.Get(FlashErrorKey); ok {
|
||||
session.Del(FlashErrorKey)
|
||||
data.MergeKV(FlashErrorKey, flash)
|
||||
}
|
||||
|
||||
rendered, mime, err := a.renderer.Render(r.Context(), templateName, data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", mime)
|
||||
w.WriteHeader(code)
|
||||
|
||||
_, err = w.Write(rendered)
|
||||
return err
|
||||
}
|
||||
|
||||
// EmailResponseOptions controls how e-mails are rendered and sent
|
||||
type EmailResponseOptions struct {
|
||||
Data HTMLData
|
||||
HTMLTemplate string
|
||||
TextTemplate string
|
||||
}
|
||||
|
||||
// Email renders the e-mail templates and sends it using the mailer.
|
||||
func (a *Authboss) Email(w http.ResponseWriter, r *http.Request, email Email, ro EmailResponseOptions) error {
|
||||
ctx := r.Context()
|
||||
|
||||
if len(ro.HTMLTemplate) != 0 {
|
||||
htmlBody, _, err := a.renderer.Render(ctx, ro.HTMLTemplate, ro.Data)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to render e-mail html body")
|
||||
}
|
||||
email.HTMLBody = string(htmlBody)
|
||||
}
|
||||
|
||||
if len(ro.TextTemplate) != 0 {
|
||||
textBody, _, err := a.renderer.Render(ctx, ro.TextTemplate, ro.Data)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to render e-mail text body")
|
||||
}
|
||||
email.TextBody = string(textBody)
|
||||
}
|
||||
|
||||
return a.Mailer.Send(ctx, email)
|
||||
}
|
||||
|
||||
// Redirect the client elsewhere. If it's an API request it will simply render
|
||||
// a JSON response with information that should help a client to decide what
|
||||
// to do.
|
||||
func (a *Authboss) Redirect(w http.ResponseWriter, r *http.Request, ro RedirectOptions) error {
|
||||
var redirectFunction = a.redirectNonAPI
|
||||
if isAPIRequest(r) {
|
||||
redirectFunction = a.redirectAPI
|
||||
}
|
||||
|
||||
return redirectFunction(w, r, ro)
|
||||
}
|
||||
|
||||
func (a *Authboss) redirectAPI(w http.ResponseWriter, r *http.Request, ro RedirectOptions) error {
|
||||
path := ro.RedirectPath
|
||||
redir := r.FormValue(FormValueRedirect)
|
||||
if len(redir) != 0 && ro.FollowRedirParam {
|
||||
path = redir
|
||||
}
|
||||
|
||||
var status, message string
|
||||
if len(ro.Success) != 0 {
|
||||
status = "success"
|
||||
message = ro.Success
|
||||
}
|
||||
if len(ro.Failure) != 0 {
|
||||
status = "failure"
|
||||
message = ro.Failure
|
||||
}
|
||||
|
||||
data := HTMLData{
|
||||
"path": path,
|
||||
}
|
||||
|
||||
if len(status) != 0 {
|
||||
data["status"] = status
|
||||
data["message"] = message
|
||||
}
|
||||
|
||||
body, mime, err := a.renderer.Render(r.Context(), "redirect", data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(body) != 0 {
|
||||
w.Header().Set("Content-Type", mime)
|
||||
}
|
||||
|
||||
if ro.Code != 0 {
|
||||
w.WriteHeader(ro.Code)
|
||||
}
|
||||
_, err = w.Write(body)
|
||||
return err
|
||||
}
|
||||
|
||||
func (a *Authboss) redirectNonAPI(w http.ResponseWriter, r *http.Request, ro RedirectOptions) error {
|
||||
path := ro.RedirectPath
|
||||
redir := r.FormValue(FormValueRedirect)
|
||||
if len(redir) != 0 && ro.FollowRedirParam {
|
||||
path = redir
|
||||
}
|
||||
|
||||
if len(ro.Success) != 0 {
|
||||
session := a.SessionStoreMaker.Make(w, r)
|
||||
session.Put(FlashSuccessKey, ro.Success)
|
||||
}
|
||||
if len(ro.Failure) != 0 {
|
||||
session := a.SessionStoreMaker.Make(w, r)
|
||||
session.Put(FlashErrorKey, ro.Failure)
|
||||
}
|
||||
|
||||
http.Redirect(w, r, path, http.StatusFound)
|
||||
return nil
|
||||
}
|
197
router.go
197
router.go
@ -28,7 +28,7 @@ func (a *Authboss) NewRouter() http.Handler {
|
||||
for name, mod := range a.loadedModules {
|
||||
for route, handler := range mod.Routes() {
|
||||
fmt.Fprintf(a.LogWriter, "%-10s Route: %s\n", "["+name+"]", path.Join(a.MountPath, route))
|
||||
a.mux.Handle(path.Join(a.MountPath, route), contextRoute{a, handler})
|
||||
a.mux.Handle(path.Join(a.MountPath, route), abHandler{a, handler})
|
||||
}
|
||||
}
|
||||
|
||||
@ -44,106 +44,105 @@ func (a *Authboss) NewRouter() http.Handler {
|
||||
return a.mux
|
||||
}
|
||||
|
||||
type contextRoute struct {
|
||||
type abHandler struct {
|
||||
*Authboss
|
||||
fn HandlerFunc
|
||||
}
|
||||
|
||||
func (c contextRoute) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
/*
|
||||
// Instantiate the context
|
||||
ctx := c.Authboss.InitContext(w, r)
|
||||
|
||||
// Check to make sure we actually need to visit this route
|
||||
if redirectIfLoggedIn(ctx, w, r) {
|
||||
return
|
||||
}
|
||||
|
||||
// Call the handler
|
||||
err := c.fn(ctx, w, r)
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Log the error
|
||||
fmt.Fprintf(c.LogWriter, "Error Occurred at %s: %v", r.URL.Path, err)
|
||||
|
||||
// Do specific error handling for special kinds of errors.
|
||||
switch e := err.(type) {
|
||||
case ErrAndRedirect:
|
||||
if len(e.FlashSuccess) > 0 {
|
||||
ctx.SessionStorer.Put(FlashSuccessKey, e.FlashSuccess)
|
||||
}
|
||||
if len(e.FlashError) > 0 {
|
||||
ctx.SessionStorer.Put(FlashErrorKey, e.FlashError)
|
||||
}
|
||||
http.Redirect(w, r, e.Location, http.StatusFound)
|
||||
case ClientDataErr:
|
||||
if c.BadRequestHandler != nil {
|
||||
c.BadRequestHandler.ServeHTTP(w, r)
|
||||
} else {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
io.WriteString(w, "400 Bad request")
|
||||
}
|
||||
default:
|
||||
if c.ErrorHandler != nil {
|
||||
c.ErrorHandler.ServeHTTP(w, r)
|
||||
} else {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
io.WriteString(w, "500 An error has occurred")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// redirectIfLoggedIn checks a user's existence by using currentUser. This is done instead of
|
||||
// a simple Session cookie check so that the remember module has a chance to log the user in
|
||||
// before they are determined to "not be logged in".
|
||||
//
|
||||
// The exceptional routes are sort of hardcoded in a terrible way in here, later on this could move to some
|
||||
// configuration or something more interesting.
|
||||
func redirectIfLoggedIn(ctx *Context, w http.ResponseWriter, r *http.Request) (handled bool) {
|
||||
// If it's a log out url, always let it pass through.
|
||||
if strings.HasSuffix(r.URL.Path, "/logout") {
|
||||
return false
|
||||
}
|
||||
|
||||
// If it's an auth url, allow them through if they're half-authed.
|
||||
if strings.HasSuffix(r.URL.Path, "/auth") || strings.Contains(r.URL.Path, "/oauth2/") {
|
||||
if halfAuthed, ok := ctx.SessionStorer.Get(SessionHalfAuthKey); ok && halfAuthed == "true" {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Otherwise, check if they're logged in, this uses hooks to allow remember
|
||||
// to set the session cookie
|
||||
cu, err := ctx.currentUser(ctx, w, r)
|
||||
|
||||
// if the user was not found, that means the user was deleted from the underlying
|
||||
// storer and we should just remove this session cookie and allow them through.
|
||||
// if it's a generic error, 500
|
||||
// if the user is found, redirect them away from this page, because they don't need
|
||||
// to see it.
|
||||
if err == ErrUserNotFound {
|
||||
uname, _ := ctx.SessionStorer.Get(SessionKey)
|
||||
fmt.Fprintf(ctx.LogWriter, "user (%s) has session cookie but user not found, removing cookie", uname)
|
||||
ctx.SessionStorer.Del(SessionKey)
|
||||
return false
|
||||
} else if err != nil {
|
||||
fmt.Fprintf(ctx.LogWriter, "error occurred reading current user at %s: %v", r.URL.Path, err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
io.WriteString(w, "500 An error has occurred")
|
||||
return true
|
||||
}
|
||||
|
||||
if cu != nil {
|
||||
if redir := r.FormValue(FormValueRedirect); len(redir) > 0 {
|
||||
http.Redirect(w, r, redir, http.StatusFound)
|
||||
} else {
|
||||
http.Redirect(w, r, ctx.AuthLoginOKPath, http.StatusFound)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
*/
|
||||
// TODO(aarondl): Move this somewhere reasonable
|
||||
func isAPIRequest(r *http.Request) bool {
|
||||
return r.Header.Get("Content-Type") == "application/json"
|
||||
}
|
||||
|
||||
func (a abHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Put uid in the context
|
||||
_, err := a.LoadCurrentUserID(w, &r)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
io.WriteString(w, "500 An error has occurred")
|
||||
fmt.Fprintf(a.LogWriter, "failed to load current user id: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Call the handler
|
||||
err = a.fn(w, r)
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Log the error
|
||||
fmt.Fprintf(a.LogWriter, "Error Occurred at %s: %v", r.URL.Path, err)
|
||||
|
||||
// Do specific error handling for special kinds of errors.
|
||||
if _, ok := err.(ClientDataErr); ok {
|
||||
if a.BadRequestHandler != nil {
|
||||
a.BadRequestHandler.ServeHTTP(w, r)
|
||||
} else {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
io.WriteString(w, "400 Bad request")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if a.ErrorHandler != nil {
|
||||
a.ErrorHandler.ServeHTTP(w, r)
|
||||
} else {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
io.WriteString(w, "500 An error has occurred")
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
// TODO(aarondl): Throw away this function
|
||||
// redirectIfLoggedIn checks a user's existence by using currentUser. This is done instead of
|
||||
// a simple Session cookie check so that the remember module has a chance to log the user in
|
||||
// before they are determined to "not be logged in".
|
||||
//
|
||||
// The exceptional routes are sort of hardcoded in a terrible way in here, later on this could move to some
|
||||
// configuration or something more interesting.
|
||||
func redirectIfLoggedIn(w http.ResponseWriter, r *http.Request) (handled bool) {
|
||||
// If it's a log out url, always let it pass through.
|
||||
if strings.HasSuffix(r.URL.Path, "/logout") {
|
||||
return false
|
||||
}
|
||||
|
||||
// If it's an auth url, allow them through if they're half-authed.
|
||||
if strings.HasSuffix(r.URL.Path, "/auth") || strings.Contains(r.URL.Path, "/oauth2/") {
|
||||
if halfAuthed, ok := ctx.SessionStorer.Get(SessionHalfAuthKey); ok && halfAuthed == "true" {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Otherwise, check if they're logged in, this uses hooks to allow remember
|
||||
// to set the session cookie
|
||||
cu, err := ctx.currentUser(ctx, w, r)
|
||||
|
||||
// if the user was not found, that means the user was deleted from the underlying
|
||||
// storer and we should just remove this session cookie and allow them through.
|
||||
// if it's a generic error, 500
|
||||
// if the user is found, redirect them away from this page, because they don't need
|
||||
// to see it.
|
||||
if err == ErrUserNotFound {
|
||||
uname, _ := ctx.SessionStorer.Get(SessionKey)
|
||||
fmt.Fprintf(ctx.LogWriter, "user (%s) has session cookie but user not found, removing cookie", uname)
|
||||
ctx.SessionStorer.Del(SessionKey)
|
||||
return false
|
||||
} else if err != nil {
|
||||
fmt.Fprintf(ctx.LogWriter, "error occurred reading current user at %s: %v", r.URL.Path, err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
io.WriteString(w, "500 An error has occurred")
|
||||
return true
|
||||
}
|
||||
|
||||
if cu != nil {
|
||||
if redir := r.FormValue(FormValueRedirect); len(redir) > 0 {
|
||||
http.Redirect(w, r, redir, http.StatusFound)
|
||||
} else {
|
||||
http.Redirect(w, r, ctx.AuthLoginOKPath, http.StatusFound)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
*/
|
||||
|
171
router_test.go
171
router_test.go
@ -2,8 +2,6 @@ package authboss
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
@ -24,15 +22,17 @@ type testRouterModule struct {
|
||||
|
||||
func (t testRouterModule) Initialize(ab *Authboss) error { return nil }
|
||||
func (t testRouterModule) Routes() RouteTable { return t.routes }
|
||||
func (t testRouterModule) Templates() []string { return []string{"template1.tpl"} }
|
||||
|
||||
func testRouterSetup() (*Authboss, http.Handler, *bytes.Buffer) {
|
||||
ab := New()
|
||||
logger := &bytes.Buffer{}
|
||||
ab.LogWriter = logger
|
||||
ab.ViewLoader = mockRenderLoader{}
|
||||
ab.Init(testRouterModName)
|
||||
ab.MountPath = "/prefix"
|
||||
ab.SessionStoreMaker = func(w http.ResponseWriter, r *http.Request) ClientStorer { return mockClientStore{} }
|
||||
ab.CookieStoreMaker = func(w http.ResponseWriter, r *http.Request) ClientStorer { return mockClientStore{} }
|
||||
ab.SessionStoreMaker = newMockClientStoreMaker(mockClientStore{})
|
||||
ab.CookieStoreMaker = newMockClientStoreMaker(mockClientStore{})
|
||||
|
||||
logger.Reset() // Clear out the module load messages
|
||||
|
||||
@ -169,166 +169,3 @@ func TestRouter_Error(t *testing.T) {
|
||||
t.Error(str)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_Redirect(t *testing.T) {
|
||||
err := ErrAndRedirect{
|
||||
Err: errors.New("error"),
|
||||
Location: "/",
|
||||
FlashSuccess: "yay",
|
||||
FlashError: "nay",
|
||||
}
|
||||
|
||||
w, r := testRouterCallbackSetup("/error",
|
||||
func(http.ResponseWriter, *http.Request) error {
|
||||
return err
|
||||
},
|
||||
)
|
||||
|
||||
ab, router, logger := testRouterSetup()
|
||||
|
||||
session := mockClientStore{}
|
||||
ab.SessionStoreMaker = func(w http.ResponseWriter, r *http.Request) ClientStorer { return session }
|
||||
|
||||
logger.Reset()
|
||||
router.ServeHTTP(w, r)
|
||||
|
||||
if w.Code != http.StatusFound {
|
||||
t.Error("Wrong code:", w.Code)
|
||||
}
|
||||
if loc := w.Header().Get("Location"); loc != err.Location {
|
||||
t.Error("Wrong location:", loc)
|
||||
}
|
||||
if succ, ok := session.Get(FlashSuccessKey); !ok || succ != err.FlashSuccess {
|
||||
t.Error(succ, ok)
|
||||
}
|
||||
if fail, ok := session.Get(FlashErrorKey); !ok || fail != err.FlashError {
|
||||
t.Error(fail, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouter_redirectIfLoggedIn(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
Path string
|
||||
LoggedIn bool
|
||||
HalfAuthed bool
|
||||
|
||||
ShouldRedirect bool
|
||||
}{
|
||||
// These routes will be accessed depending on logged in and half auth's value
|
||||
{"/auth", false, false, false},
|
||||
{"/auth", true, false, true},
|
||||
{"/auth", true, true, false},
|
||||
{"/oauth2/facebook", false, false, false},
|
||||
{"/oauth2/facebook", true, false, true},
|
||||
{"/oauth2/facebook", true, true, false},
|
||||
{"/oauth2/callback/facebook", false, false, false},
|
||||
{"/oauth2/callback/facebook", true, false, true},
|
||||
{"/oauth2/callback/facebook", true, true, false},
|
||||
// These are logout routes and never redirect
|
||||
{"/logout", true, false, false},
|
||||
{"/logout", true, true, false},
|
||||
{"/oauth2/logout", true, false, false},
|
||||
{"/oauth2/logout", true, true, false},
|
||||
// These routes should always redirect despite half auth
|
||||
{"/register", true, true, true},
|
||||
{"/recover", true, true, true},
|
||||
{"/register", false, false, false},
|
||||
{"/recover", false, false, false},
|
||||
}
|
||||
|
||||
storer := mockStoreLoader{"john@john.com": mockUser{
|
||||
Email: "john@john.com",
|
||||
Password: "password",
|
||||
}}
|
||||
ab := New()
|
||||
ab.StoreLoader = storer
|
||||
|
||||
for i, test := range tests {
|
||||
session := mockClientStore{}
|
||||
cookies := mockClientStore{}
|
||||
ctx := context.TODO()
|
||||
ctx.SessionStorer = session
|
||||
ctx.CookieStorer = cookies
|
||||
|
||||
if test.LoggedIn {
|
||||
session[SessionKey] = "john@john.com"
|
||||
}
|
||||
if test.HalfAuthed {
|
||||
session[SessionHalfAuthKey] = "true"
|
||||
}
|
||||
|
||||
r, _ := http.NewRequest("GET", test.Path, nil)
|
||||
w := httptest.NewRecorder()
|
||||
handled := redirectIfLoggedIn(ctx, w, r)
|
||||
|
||||
if test.ShouldRedirect && (!handled || w.Code != http.StatusFound) {
|
||||
t.Errorf("%d) It should have redirected the request: %q %t %d", i, test.Path, handled, w.Code)
|
||||
} else if !test.ShouldRedirect && (handled || w.Code != http.StatusOK) {
|
||||
t.Errorf("%d) It should have NOT redirected the request: %q %t %d", i, test.Path, handled, w.Code)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type deathStorer struct{}
|
||||
|
||||
func (d deathStorer) Create(key string, attributes Attributes) error { return nil }
|
||||
func (d deathStorer) Put(key string, attributes Attributes) error { return nil }
|
||||
func (d deathStorer) Get(key string) (interface{}, error) { return nil, errors.New("explosion") }
|
||||
|
||||
func TestRouter_redirectIfLoggedInError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := New()
|
||||
ab.LogWriter = ioutil.Discard
|
||||
ab.Storer = deathStorer{}
|
||||
|
||||
session := mockClientStore{SessionKey: "john"}
|
||||
cookies := mockClientStore{}
|
||||
ctx := context.TODO()
|
||||
ctx.SessionStorer = session
|
||||
ctx.CookieStorer = cookies
|
||||
|
||||
r, _ := http.NewRequest("GET", "/auth", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handled := redirectIfLoggedIn(ctx, w, r)
|
||||
|
||||
if !handled {
|
||||
t.Error("It should have been handled.")
|
||||
}
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Error("It should have internal server error'd:", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
type notFoundStorer struct{}
|
||||
|
||||
func (n notFoundStorer) Create(key string, attributes Attributes) error { return nil }
|
||||
func (n notFoundStorer) Put(key string, attributes Attributes) error { return nil }
|
||||
func (n notFoundStorer) Get(key string) (interface{}, error) { return nil, ErrUserNotFound }
|
||||
|
||||
func TestRouter_redirectIfLoggedInUserNotFound(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := New()
|
||||
ab.LogWriter = ioutil.Discard
|
||||
ab.Storer = notFoundStorer{}
|
||||
|
||||
session := mockClientStore{SessionKey: "john"}
|
||||
cookies := mockClientStore{}
|
||||
ctx := context.TODO()
|
||||
ctx.SessionStorer = session
|
||||
ctx.CookieStorer = cookies
|
||||
|
||||
r, _ := http.NewRequest("GET", "/auth", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handled := redirectIfLoggedIn(ctx, w, r)
|
||||
|
||||
if handled {
|
||||
t.Error("It should not have been handled.")
|
||||
}
|
||||
if _, ok := session.Get(SessionKey); ok {
|
||||
t.Error("It should have removed the bad session cookie")
|
||||
}
|
||||
}
|
||||
|
35
storer.go
35
storer.go
@ -3,7 +3,6 @@ package authboss
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
@ -73,10 +72,12 @@ type ArbitraryStorer interface {
|
||||
GetArbitrary(ctx context.Context) (arbitrary map[string]string, err error)
|
||||
}
|
||||
|
||||
// OAuth2Storer allows reading and writing values
|
||||
// OAuth2Storer allows reading and writing values relating to OAuth2
|
||||
type OAuth2Storer interface {
|
||||
Storer
|
||||
|
||||
IsOAuth2User(ctx context.Context) (bool, error)
|
||||
|
||||
PutUID(ctx context.Context, uid string) error
|
||||
PutProvider(ctx context.Context, provider string) error
|
||||
PutToken(ctx context.Context, token string) error
|
||||
@ -90,36 +91,6 @@ type OAuth2Storer interface {
|
||||
GetExpiry(ctx context.Context) (expiry time.Duration, err error)
|
||||
}
|
||||
|
||||
// DataType represents the various types that clients must be able to store.
|
||||
type DataType int
|
||||
|
||||
// DataType constants
|
||||
const (
|
||||
Integer DataType = iota
|
||||
String
|
||||
Bool
|
||||
DateTime
|
||||
)
|
||||
|
||||
var (
|
||||
dateTimeType = reflect.TypeOf(time.Time{})
|
||||
)
|
||||
|
||||
// String returns a string for the DataType representation.
|
||||
func (d DataType) String() string {
|
||||
switch d {
|
||||
case Integer:
|
||||
return "Integer"
|
||||
case String:
|
||||
return "String"
|
||||
case Bool:
|
||||
return "Bool"
|
||||
case DateTime:
|
||||
return "DateTime"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func camelToUnder(in string) string {
|
||||
out := bytes.Buffer{}
|
||||
for i := 0; i < len(in); i++ {
|
||||
|
472
storer_test.go
472
storer_test.go
@ -1,476 +1,6 @@
|
||||
package authboss
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type NullTime struct {
|
||||
Time time.Time
|
||||
Valid bool
|
||||
}
|
||||
|
||||
func (nt *NullTime) Scan(value interface{}) error {
|
||||
nt.Time, nt.Valid = value.(time.Time)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (nt NullTime) Value() (driver.Value, error) {
|
||||
if !nt.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
return nt.Time, nil
|
||||
}
|
||||
|
||||
func TestAttributes_FromRequest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
vals := make(url.Values)
|
||||
vals.Set("a", "a")
|
||||
vals.Set("b_int", "5")
|
||||
vals.Set("wildcard", "")
|
||||
vals.Set("c_date", now.Format(time.RFC3339))
|
||||
req, err := http.NewRequest("POST", "/", strings.NewReader(vals.Encode()))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
attr, err := AttributesFromRequest(req)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if got := attr["a"].(string); got != "a" {
|
||||
t.Error("a's value is wrong:", got)
|
||||
}
|
||||
if got := attr["b"].(int); got != 5 {
|
||||
t.Error("b's value is wrong:", got)
|
||||
}
|
||||
if got := attr["c"].(time.Time); got.Unix() != now.Unix() {
|
||||
t.Error("c's value is wrong:", now, got)
|
||||
}
|
||||
if _, ok := attr["wildcard"]; ok {
|
||||
t.Error("We don't need totally empty fields.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAttributes_Names(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
attr := Attributes{
|
||||
"integer": 5,
|
||||
"string": "string",
|
||||
"bool": true,
|
||||
"date_time": time.Now(),
|
||||
}
|
||||
names := attr.Names()
|
||||
|
||||
found := map[string]bool{"integer": false, "string": false, "bool": false, "date_time": false}
|
||||
for _, n := range names {
|
||||
found[n] = true
|
||||
}
|
||||
|
||||
for k, v := range found {
|
||||
if !v {
|
||||
t.Error("Could not find:", k)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAttributeMeta_Names(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
meta := AttributeMeta{
|
||||
"integer": Integer,
|
||||
"string": String,
|
||||
"bool": Bool,
|
||||
"date_time": DateTime,
|
||||
}
|
||||
names := meta.Names()
|
||||
|
||||
found := map[string]bool{"integer": false, "string": false, "bool": false, "date_time": false}
|
||||
for _, n := range names {
|
||||
found[n] = true
|
||||
}
|
||||
|
||||
for k, v := range found {
|
||||
if !v {
|
||||
t.Error("Could not find:", k)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAttributeMeta_Helpers(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
now := time.Now()
|
||||
attr := Attributes{
|
||||
"integer": int64(5),
|
||||
"string": "a",
|
||||
"bool": true,
|
||||
"date_time": now,
|
||||
}
|
||||
|
||||
if str, ok := attr.String("string"); !ok || str != "a" {
|
||||
t.Error(str, ok)
|
||||
}
|
||||
if str, err := attr.StringErr("string"); err != nil || str != "a" {
|
||||
t.Error(str, err)
|
||||
}
|
||||
if str, ok := attr.String("notstring"); ok {
|
||||
t.Error(str, ok)
|
||||
}
|
||||
if str, err := attr.StringErr("notstring"); err == nil {
|
||||
t.Error(str, err)
|
||||
}
|
||||
|
||||
if integer, ok := attr.Int64("integer"); !ok || integer != 5 {
|
||||
t.Error(integer, ok)
|
||||
}
|
||||
if integer, err := attr.Int64Err("integer"); err != nil || integer != 5 {
|
||||
t.Error(integer, err)
|
||||
}
|
||||
if integer, ok := attr.Int64("notinteger"); ok {
|
||||
t.Error(integer, ok)
|
||||
}
|
||||
if integer, err := attr.Int64Err("notinteger"); err == nil {
|
||||
t.Error(integer, err)
|
||||
}
|
||||
|
||||
if boolean, ok := attr.Bool("bool"); !ok || !boolean {
|
||||
t.Error(boolean, ok)
|
||||
}
|
||||
if boolean, err := attr.BoolErr("bool"); err != nil || !boolean {
|
||||
t.Error(boolean, err)
|
||||
}
|
||||
if boolean, ok := attr.Bool("notbool"); ok {
|
||||
t.Error(boolean, ok)
|
||||
}
|
||||
if boolean, err := attr.BoolErr("notbool"); err == nil {
|
||||
t.Error(boolean, err)
|
||||
}
|
||||
|
||||
if date, ok := attr.DateTime("date_time"); !ok || date != now {
|
||||
t.Error(date, ok)
|
||||
}
|
||||
if date, err := attr.DateTimeErr("date_time"); err != nil || date != now {
|
||||
t.Error(date, err)
|
||||
}
|
||||
if date, ok := attr.DateTime("notdate_time"); ok {
|
||||
t.Error(date, ok)
|
||||
}
|
||||
if date, err := attr.DateTimeErr("notdate_time"); err == nil {
|
||||
t.Error(date, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDataType_String(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if Integer.String() != "Integer" {
|
||||
t.Error("Expected Integer:", Integer)
|
||||
}
|
||||
if String.String() != "String" {
|
||||
t.Error("Expected String:", String)
|
||||
}
|
||||
if Bool.String() != "Bool" {
|
||||
t.Error("Expected Bool:", String)
|
||||
}
|
||||
if DateTime.String() != "DateTime" {
|
||||
t.Error("Expected DateTime:", DateTime)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAttributes_Bind(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
anInteger := 5
|
||||
aString := "string"
|
||||
aBool := true
|
||||
aTime := time.Now()
|
||||
anUnknown := []byte("I'm not a recognizable type")
|
||||
|
||||
data := Attributes{
|
||||
"integer": anInteger,
|
||||
"string": aString,
|
||||
"bool": aBool,
|
||||
"date_time": aTime,
|
||||
"unknown": anUnknown,
|
||||
}
|
||||
|
||||
s := struct {
|
||||
Integer int
|
||||
String string
|
||||
Bool bool
|
||||
DateTime time.Time
|
||||
Unknown []byte
|
||||
}{}
|
||||
|
||||
if err := data.Bind(&s, false); err != nil {
|
||||
t.Error("Unexpected Error:", err)
|
||||
}
|
||||
|
||||
if s.Integer != anInteger {
|
||||
t.Error("Integer was not set.")
|
||||
}
|
||||
if s.String != aString {
|
||||
t.Error("String was not set.")
|
||||
}
|
||||
if s.Bool != aBool {
|
||||
t.Error("Bool was not set.")
|
||||
}
|
||||
if s.DateTime != aTime {
|
||||
t.Error("DateTime was not set.")
|
||||
}
|
||||
if 0 != bytes.Compare(s.Unknown, anUnknown) {
|
||||
t.Error("The []byte slice was not set.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAttributes_BindIgnoreMissing(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
anInteger := 5
|
||||
aString := "string"
|
||||
|
||||
data := Attributes{
|
||||
"integer": anInteger,
|
||||
"string": aString,
|
||||
}
|
||||
|
||||
s := struct {
|
||||
Integer int
|
||||
}{}
|
||||
|
||||
if err := data.Bind(&s, false); err == nil {
|
||||
t.Error("Expected error about missing attributes:", err)
|
||||
}
|
||||
|
||||
if err := data.Bind(&s, true); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if s.Integer != anInteger {
|
||||
t.Error("Integer was not set.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAttributes_BindNoPtr(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
data := Attributes{}
|
||||
s := struct{}{}
|
||||
|
||||
if err := data.Bind(s, false); err == nil {
|
||||
t.Error("Expected an error.")
|
||||
} else if !strings.Contains(err.Error(), "struct pointer") {
|
||||
t.Error("Expected an error about pointers got:", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAttributes_BindMissingField(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
data := Attributes{"Integer": 5}
|
||||
s := struct{}{}
|
||||
|
||||
if err := data.Bind(&s, false); err == nil {
|
||||
t.Error("Expected an error.")
|
||||
} else if !strings.Contains(err.Error(), "missing") {
|
||||
t.Error("Expected an error about missing fields, got:", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAttributes_BindTypeFail(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
Attr Attributes
|
||||
Err string
|
||||
ToBind interface{}
|
||||
}{
|
||||
{
|
||||
Attr: Attributes{"integer": 5},
|
||||
Err: "should be int",
|
||||
ToBind: &struct {
|
||||
Integer string
|
||||
}{},
|
||||
},
|
||||
{
|
||||
Attr: Attributes{"string": ""},
|
||||
Err: "should be string",
|
||||
ToBind: &struct {
|
||||
String int
|
||||
}{},
|
||||
},
|
||||
{
|
||||
Attr: Attributes{"bool": true},
|
||||
Err: "should be bool",
|
||||
ToBind: &struct {
|
||||
Bool string
|
||||
}{},
|
||||
},
|
||||
{
|
||||
Attr: Attributes{"date": time.Time{}},
|
||||
Err: "should be time.Time",
|
||||
ToBind: &struct {
|
||||
Date int
|
||||
}{},
|
||||
},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
if err := test.Attr.Bind(test.ToBind, false); err == nil {
|
||||
t.Errorf("%d> Expected an error.", i)
|
||||
} else if !strings.Contains(err.Error(), test.Err) {
|
||||
t.Errorf("%d> Expected an error about %q got: %q", i, test.Err, err)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestAttributes_BindScannerValues(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
s1 := struct {
|
||||
Count sql.NullInt64
|
||||
Time NullTime
|
||||
}{
|
||||
sql.NullInt64{},
|
||||
NullTime{},
|
||||
}
|
||||
|
||||
nowTime := time.Now()
|
||||
|
||||
attrs := Attributes{"count": 12, "time": nowTime}
|
||||
if err := attrs.Bind(&s1, false); err != nil {
|
||||
t.Error("Unexpected error:", err)
|
||||
}
|
||||
|
||||
if !s1.Count.Valid {
|
||||
t.Error("Expected valid NullInt64")
|
||||
}
|
||||
if s1.Count.Int64 != 12 {
|
||||
t.Error("Unexpected value:", s1.Count.Int64)
|
||||
}
|
||||
|
||||
if !s1.Time.Valid {
|
||||
t.Error("Expected valid time.Time")
|
||||
}
|
||||
if !s1.Time.Time.Equal(nowTime) {
|
||||
t.Error("Unexpected value:", s1.Time.Time)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnbind(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
s1 := struct {
|
||||
Integer int
|
||||
String string
|
||||
Bool bool
|
||||
Time time.Time
|
||||
|
||||
Int32 int32
|
||||
ConfigStruct *Config
|
||||
|
||||
unexported int
|
||||
}{5, "string", true, time.Now(), 5, &Config{}, 5}
|
||||
|
||||
attr := Unbind(&s1)
|
||||
if len(attr) != 6 {
|
||||
t.Error("Expected 6 fields, got:", len(attr))
|
||||
}
|
||||
|
||||
if v, ok := attr["integer"]; !ok {
|
||||
t.Error("Could not find Integer entry.")
|
||||
} else if val, ok := v.(int); !ok {
|
||||
t.Errorf("Underlying type is wrong: %T", v)
|
||||
} else if s1.Integer != val {
|
||||
t.Error("Underlying value is wrong:", val)
|
||||
}
|
||||
|
||||
if v, ok := attr["string"]; !ok {
|
||||
t.Error("Could not find String entry.")
|
||||
} else if val, ok := v.(string); !ok {
|
||||
t.Errorf("Underlying type is wrong: %T", v)
|
||||
} else if s1.String != val {
|
||||
t.Error("Underlying value is wrong:", val)
|
||||
}
|
||||
|
||||
if v, ok := attr["bool"]; !ok {
|
||||
t.Error("Could not find String entry.")
|
||||
} else if val, ok := v.(bool); !ok {
|
||||
t.Errorf("Underlying type is wrong: %T", v)
|
||||
} else if s1.Bool != val {
|
||||
t.Error("Underlying value is wrong:", val)
|
||||
}
|
||||
|
||||
if v, ok := attr["time"]; !ok {
|
||||
t.Error("Could not find Time entry.")
|
||||
} else if val, ok := v.(time.Time); !ok {
|
||||
t.Errorf("Underlying type is wrong: %T", v)
|
||||
} else if s1.Time != val {
|
||||
t.Error("Underlying value is wrong:", val)
|
||||
}
|
||||
|
||||
if v, ok := attr["int32"]; !ok {
|
||||
t.Error("Could not find Int32 entry.")
|
||||
} else if val, ok := v.(int32); !ok {
|
||||
t.Errorf("Underlying type is wrong: %T", v)
|
||||
} else if s1.Int32 != val {
|
||||
t.Error("Underlying value is wrong:", val)
|
||||
}
|
||||
|
||||
if v, ok := attr["config_struct"]; !ok {
|
||||
t.Error("Could not find ConfigStruct entry.")
|
||||
} else if val, ok := v.(*Config); !ok {
|
||||
t.Errorf("Underlying type is wrong: %T", v)
|
||||
} else if s1.ConfigStruct != val {
|
||||
t.Error("Underlying value is wrong:", val)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnbind_Valuer(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
nowTime := time.Now()
|
||||
|
||||
s1 := struct {
|
||||
Count sql.NullInt64
|
||||
Time NullTime
|
||||
}{
|
||||
sql.NullInt64{Int64: 12, Valid: true},
|
||||
NullTime{nowTime, true},
|
||||
}
|
||||
|
||||
attr := Unbind(&s1)
|
||||
|
||||
if v, ok := attr["count"]; !ok {
|
||||
t.Error("Could not find NullInt64 entry.")
|
||||
} else if val, ok := v.(int64); !ok {
|
||||
t.Errorf("Underlying type is wrong: %T", v)
|
||||
} else if 12 != val {
|
||||
t.Error("Underlying value is wrong:", val)
|
||||
}
|
||||
|
||||
if v, ok := attr["time"]; !ok {
|
||||
t.Error("Could not find NullTime entry.")
|
||||
} else if val, ok := v.(time.Time); !ok {
|
||||
t.Errorf("Underlying type is wrong: %T", v)
|
||||
} else if !nowTime.Equal(val) {
|
||||
t.Error("Underlying value is wrong:", val)
|
||||
}
|
||||
}
|
||||
import "testing"
|
||||
|
||||
func TestCasingStyleConversions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
Loading…
x
Reference in New Issue
Block a user