1
0
mirror of https://github.com/volatiletech/authboss.git synced 2025-02-05 13:24:54 +02:00

More gigantic edits.

- Change response to be more central to Authboss. Make sure it has
  useful methods and works with the new rendering idioms.
- Change the load user methods to all work with context keys, and even
  be able to set context keys on the current request to avoid setting
  contexts everywhere in the code base.
This commit is contained in:
Aaron L 2017-02-23 16:13:25 -08:00
parent f65d9f6bb6
commit fa6ba517db
25 changed files with 889 additions and 1216 deletions

View File

@ -20,6 +20,9 @@ type Authboss struct {
loadedModules map[string]Modularizer
mux *http.ServeMux
templateNames []string
renderer Renderer
}
// New makes a new instance of authboss with a default
@ -47,68 +50,15 @@ func (a *Authboss) Init(modulesToLoad ...string) error {
}
}
renderer, err := a.ViewLoader.Init(a.templateNames)
if err != nil {
return errors.Wrap(err, "failed to init view loader")
}
a.renderer = renderer
return nil
}
// CurrentUser retrieves the current user from the session and the database.
func (a *Authboss) CurrentUser(w http.ResponseWriter, r *http.Request) (interface{}, error) {
return nil, errors.New("TODO")
}
func (a *Authboss) currentUser(w http.ResponseWriter, r *http.Request) (interface{}, error) {
/*
_, err := a.Callbacks.FireBefore(EventGetUserSession, ctx)
if err != nil {
return nil, err
}
key, ok := ctx.SessionStorer.Get(SessionKey)
if !ok {
return nil, nil
}
_, err = a.Callbacks.FireBefore(EventGetUser, ctx)
if err != nil {
return nil, err
}
var user interface{}
if index := strings.IndexByte(key, ';'); index > 0 {
user, err = a.OAuth2Storer.GetOAuth(key[:index], key[index+1:])
} else {
user, err = a.Storer.Get(key)
}
if err != nil {
return nil, err
}
ctx.User = Unbind(user)
err = a.Callbacks.FireAfter(EventGetUser, ctx)
if err != nil {
return nil, err
}
return user, err
*/
return nil, errors.New("not implemented")
}
// CurrentUserP retrieves the current user but panics if it's not available for
// any reason.
func (a *Authboss) CurrentUserP(w http.ResponseWriter, r *http.Request) interface{} {
/*
i, err := a.CurrentUser(w, r)
if err != nil {
panic(err.Error())
}
return i
*/
panic("TODO")
}
/*
UpdatePassword should be called to recalculate hashes and do any cleanup
that should occur on password resets. Updater should return an error if the
@ -130,43 +80,28 @@ The error returned is returned either from the updater if that produced an error
or from the cleanup routines.
*/
func (a *Authboss) UpdatePassword(w http.ResponseWriter, r *http.Request,
ptPassword string, user interface{}, updater func() error) error {
ptPassword string, user Storer, updater func() error) error {
/*
updatePwd := len(ptPassword) > 0
/*updatePwd := len(ptPassword) > 0
if updatePwd {
pass, err := bcrypt.GenerateFromPassword([]byte(ptPassword), a.BCryptCost)
if err != nil {
return err
}
val := reflect.ValueOf(user).Elem()
field := val.FieldByName("Password")
if !field.CanSet() {
return errors.New("authboss: updatePassword called without a modifyable user struct")
}
fieldPtr := field.Addr()
if scanner, ok := fieldPtr.Interface().(sql.Scanner); ok {
if err := scanner.Scan(string(pass)); err != nil {
return err
}
} else {
field.SetString(string(pass))
}
}
if err := updater(); err != nil {
if updatePwd {
pass, err := bcrypt.GenerateFromPassword([]byte(ptPassword), a.BCryptCost)
if err != nil {
return err
}
if !updatePwd {
return nil
}
user.PutPassword(r.Context(),
}
return a.Callbacks.FireAfter(EventPasswordReset, a.InitContext(w, r))
*/
if err := updater(); err != nil {
return err
}
return errors.New("TODO")
if !updatePwd {
return nil
}
return a.Callbacks.FireAfter(EventPasswordReset, r.Context())*/
// TODO(aarondl): Fix
return errors.New("not implemented")
}

View File

@ -2,7 +2,6 @@ package authboss
import (
"context"
"database/sql"
"io/ioutil"
"net/http"
"net/http/httptest"
@ -16,6 +15,7 @@ func TestAuthBossInit(t *testing.T) {
ab := New()
ab.LogWriter = ioutil.Discard
ab.ViewLoader = mockRenderLoader{}
err := ab.Init()
if err != nil {
t.Error("Unexpected error:", err)
@ -28,12 +28,9 @@ func TestAuthBossCurrentUser(t *testing.T) {
ab := New()
ab.LogWriter = ioutil.Discard
ab.StoreLoader = mockStoreLoader{"joe": mockUser{Email: "john@john.com", Password: "lies"}}
ab.SessionStoreMaker = func(_ http.ResponseWriter, _ *http.Request) ClientStorer {
return mockClientStore{SessionKey: "joe"}
}
ab.CookieStoreMaker = func(_ http.ResponseWriter, _ *http.Request) ClientStorer {
return mockClientStore{}
}
ab.ViewLoader = mockRenderLoader{}
ab.SessionStoreMaker = newMockClientStoreMaker(mockClientStore{SessionKey: "joe"})
ab.CookieStoreMaker = newMockClientStoreMaker(mockClientStore{})
if err := ab.Init(); err != nil {
t.Error("Unexpected error:", err)
@ -43,7 +40,7 @@ func TestAuthBossCurrentUser(t *testing.T) {
req, _ := http.NewRequest("GET", "localhost", nil)
userStruct := ab.CurrentUserP(rec, req)
us := userStruct.(*mockUser)
us := userStruct.(mockStoredUser)
if us.Email != "john@john.com" || us.Password != "lies" {
t.Error("Wrong user found!")
@ -56,12 +53,9 @@ func TestAuthBossCurrentUserCallbacks(t *testing.T) {
ab := New()
ab.LogWriter = ioutil.Discard
ab.StoreLoader = mockStoreLoader{"joe": mockUser{Email: "john@john.com", Password: "lies"}}
ab.SessionStoreMaker = func(_ http.ResponseWriter, _ *http.Request) ClientStorer {
return mockClientStore{SessionKey: "joe"}
}
ab.CookieStoreMaker = func(_ http.ResponseWriter, _ *http.Request) ClientStorer {
return mockClientStore{}
}
ab.ViewLoader = mockRenderLoader{}
ab.SessionStoreMaker = newMockClientStoreMaker(mockClientStore{SessionKey: "joe"})
ab.CookieStoreMaker = newMockClientStoreMaker(mockClientStore{})
if err := ab.Init(); err != nil {
t.Error("Unexpected error:", err)
@ -97,86 +91,88 @@ func TestAuthBossCurrentUserCallbacks(t *testing.T) {
}
func TestAuthbossUpdatePassword(t *testing.T) {
t.Parallel()
t.Skip("TODO(aarondl): Implement")
/*
t.Parallel()
ab := New()
session := mockClientStore{}
cookies := mockClientStore{}
ab.SessionStoreMaker = func(_ http.ResponseWriter, _ *http.Request) ClientStorer {
return session
}
ab.CookieStoreMaker = func(_ http.ResponseWriter, _ *http.Request) ClientStorer {
return cookies
}
ab := New()
session := mockClientStore{}
cookies := mockClientStore{}
ab.SessionStoreMaker = newMockClientStoreMaker(session)
ab.CookieStoreMaker = newMockClientStoreMaker(cookies)
called := false
ab.Callbacks.After(EventPasswordReset, func(ctx context.Context) error {
called = true
return nil
})
called := false
ab.Callbacks.After(EventPasswordReset, func(ctx context.Context) error {
called = true
return nil
})
user1 := struct {
Password string
}{}
user2 := struct {
Password sql.NullString
}{}
user1 := struct {
Password string
}{}
user2 := struct {
Password sql.NullString
}{}
r, _ := http.NewRequest("GET", "http://localhost", nil)
r, _ := http.NewRequest("GET", "http://localhost", nil)
called = false
err := ab.UpdatePassword(nil, r, "newpassword", &user1, func() error { return nil })
if err != nil {
t.Error(err)
}
called = false
err := ab.UpdatePassword(nil, r, "newpassword", &user1, func() error { return nil })
if err != nil {
t.Error(err)
}
if len(user1.Password) == 0 {
t.Error("Password not updated")
}
if !called {
t.Error("Callbacks should have been called.")
}
if len(user1.Password) == 0 {
t.Error("Password not updated")
}
if !called {
t.Error("Callbacks should have been called.")
}
called = false
err = ab.UpdatePassword(nil, r, "newpassword", &user2, func() error { return nil })
if err != nil {
t.Error(err)
}
called = false
err = ab.UpdatePassword(nil, r, "newpassword", &user2, func() error { return nil })
if err != nil {
t.Error(err)
}
if !user2.Password.Valid || len(user2.Password.String) == 0 {
t.Error("Password not updated")
}
if !called {
t.Error("Callbacks should have been called.")
}
if !user2.Password.Valid || len(user2.Password.String) == 0 {
t.Error("Password not updated")
}
if !called {
t.Error("Callbacks should have been called.")
}
called = false
oldPassword := user1.Password
err = ab.UpdatePassword(nil, r, "", &user1, func() error { return nil })
if err != nil {
t.Error(err)
}
called = false
oldPassword := user1.Password
err = ab.UpdatePassword(nil, r, "", &user1, func() error { return nil })
if err != nil {
t.Error(err)
}
if user1.Password != oldPassword {
t.Error("Password not updated")
}
if called {
t.Error("Callbacks should not have been called")
}
if user1.Password != oldPassword {
t.Error("Password not updated")
}
if called {
t.Error("Callbacks should not have been called")
}
*/
}
func TestAuthbossUpdatePasswordFail(t *testing.T) {
t.Parallel()
t.Skip("TODO(aarondl): Implement")
/*
t.Parallel()
ab := New()
ab := New()
user1 := struct {
Password string
}{}
user1 := struct {
Password string
}{}
anErr := errors.New("anError")
err := ab.UpdatePassword(nil, nil, "update", &user1, func() error { return anErr })
if err != anErr {
t.Error("Expected an specific error:", err)
}
anErr := errors.New("anError")
err := ab.UpdatePassword(nil, nil, "update", &user1, func() error { return anErr })
if err != anErr {
t.Error("Expected an specific error:", err)
}
*/
}

View File

@ -1,12 +1,6 @@
package authboss
import (
"context"
"fmt"
"io/ioutil"
"reflect"
"runtime"
)
import "context"
//go:generate stringer -output stringers.go -type "Event,Interrupt"
@ -65,8 +59,8 @@ type Callbacks struct {
// Called only by authboss internals and for testing.
func NewCallbacks() *Callbacks {
return &Callbacks{
make(map[Event][]Before),
make(map[Event][]After),
before: make(map[Event][]Before),
after: make(map[Event][]After),
}
}
@ -95,8 +89,6 @@ func (c *Callbacks) FireBefore(e Event, ctx context.Context) (interrupt Interrup
for _, fn := range callbacks {
interrupt, err = fn(ctx)
if err != nil {
// TODO(aarondl): logwriter fail
fmt.Fprintf(ioutil.Discard, "Callback error (%s): %v\n", runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name(), err)
return InterruptNone, err
}
if interrupt != InterruptNone {
@ -113,8 +105,6 @@ func (c *Callbacks) FireAfter(e Event, ctx context.Context) (err error) {
callbacks := c.after[e]
for _, fn := range callbacks {
if err = fn(ctx); err != nil {
// TODO(aarondl): logwriter fail
fmt.Fprintf(ioutil.Discard, "Callback error (%s): %v\n", runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name(), err)
return err
}
}

View File

@ -3,7 +3,6 @@ package authboss
import (
"bytes"
"context"
"strings"
"testing"
"github.com/pkg/errors"
@ -116,10 +115,6 @@ func TestCallbacksBeforeErrors(t *testing.T) {
if before2 {
t.Error("Before2 should not have been called.")
}
if estr := log.String(); !strings.Contains(estr, errValue.Error()) {
t.Error("Error string wrong:", estr)
}
}
func TestCallbacksAfterErrors(t *testing.T) {
@ -153,10 +148,6 @@ func TestCallbacksAfterErrors(t *testing.T) {
if after2 {
t.Error("After2 should not have been called.")
}
if estr := log.String(); !strings.Contains(estr, errValue.Error()) {
t.Error("Error string wrong:", estr)
}
}
func TestEventString(t *testing.T) {

View File

@ -25,8 +25,9 @@ const (
FlashErrorKey = "flash_error"
)
// ClientStoreMaker is used to create a cookie storer from an http request. Keep in mind
// security considerations for your implementation, Secure, HTTP-Only, etc flags.
// ClientStoreMaker is used to create a cookie storer from an http request.
// Keep in mind security considerations for your implementation, Secure,
// HTTP-Only, etc flags.
//
// There's two major uses for this. To create session storage, and remember me
// cookies.
@ -64,7 +65,7 @@ func (c clientStoreWrapper) GetErr(key string) (string, error) {
// FlashSuccess returns FlashSuccessKey from the session and removes it.
func (a *Authboss) FlashSuccess(w http.ResponseWriter, r *http.Request) string {
storer := a.SessionStoreMaker(w, r)
storer := a.SessionStoreMaker.Make(w, r)
msg, ok := storer.Get(FlashSuccessKey)
if ok {
storer.Del(FlashSuccessKey)
@ -75,7 +76,7 @@ func (a *Authboss) FlashSuccess(w http.ResponseWriter, r *http.Request) string {
// FlashError returns FlashError from the session and removes it.
func (a *Authboss) FlashError(w http.ResponseWriter, r *http.Request) string {
storer := a.SessionStoreMaker(w, r)
storer := a.SessionStoreMaker.Make(w, r)
msg, ok := storer.Get(FlashErrorKey)
if ok {
storer.Del(FlashErrorKey)

View File

@ -1,9 +1,6 @@
package authboss
import (
"net/http"
"testing"
)
import "testing"
type testClientStorerErr string
@ -36,9 +33,7 @@ func TestFlashClearer(t *testing.T) {
session := mockClientStore{FlashSuccessKey: "success", FlashErrorKey: "error"}
ab := New()
ab.SessionStoreMaker = func(w http.ResponseWriter, r *http.Request) ClientStorer {
return session
}
ab.SessionStoreMaker = newMockClientStoreMaker(session)
if msg := ab.FlashSuccess(nil, nil); msg != "success" {
t.Error("Unexpected flash success:", msg)

View File

@ -2,6 +2,7 @@
package confirm
import (
"context"
"crypto/md5"
"crypto/rand"
"encoding/base64"
@ -31,14 +32,25 @@ var (
errUserMissing = errors.New("after registration user must be loaded")
)
// ConfirmStorer must be implemented in order to satisfy the confirm module's
// storage requirements.
type ConfirmStorer interface {
authboss.Storer
// ConfirmStoreLoader allows lookup of users by different parameters.
type ConfirmStoreLoader interface {
authboss.StoreLoader
// ConfirmUser looks up a user by a confirm token. See confirm module for
// attribute names. If the token is not found in the data store,
// simply return nil, ErrUserNotFound.
ConfirmUser(confirmToken string) (interface{}, error)
LoadByConfirmToken(confirmToken string) (ConfirmStorer, error)
}
// ConfirmStorer defines attributes for the confirm module.
type ConfirmStorer interface {
authboss.Storer
PutConfirmed(ctx context.Context, confirmed bool) error
PutConfirmToken(ctx context.Context, token string) error
GetConfirmed(ctx context.Context) (confirmed bool, err error)
GetConfirmToken(ctx context.Context) (token string, err error)
}
func init() {
@ -48,30 +60,13 @@ func init() {
// Confirm module
type Confirm struct {
*authboss.Authboss
emailHTMLTemplates response.Templates
emailTextTemplates response.Templates
}
// Initialize the module
func (c *Confirm) Initialize(ab *authboss.Authboss) (err error) {
c.Authboss = ab
var ok bool
storer, ok := c.Storer.(ConfirmStorer)
if c.StoreMaker == nil && (storer == nil || !ok) {
return errors.New("need a confirmStorer")
}
c.emailHTMLTemplates, err = response.LoadTemplates(ab, c.LayoutHTMLEmail, c.ViewsPath, tplConfirmHTML)
if err != nil {
return err
}
c.emailTextTemplates, err = response.LoadTemplates(ab, c.LayoutTextEmail, c.ViewsPath, tplConfirmText)
if err != nil {
return err
}
c.Callbacks.After(authboss.EventGetUser, func(ctx *authboss.Context) error {
c.Callbacks.After(authboss.EventGetUser, func(ctx context.Context) error {
_, err := c.beforeGet(ctx)
return err
})
@ -88,17 +83,12 @@ func (c *Confirm) Routes() authboss.RouteTable {
}
}
// Storage requirements
func (c *Confirm) Storage() authboss.StorageOptions {
return authboss.StorageOptions{
c.PrimaryID: authboss.String,
authboss.StoreEmail: authboss.String,
StoreConfirmToken: authboss.String,
StoreConfirmed: authboss.Bool,
}
// Templates returns the list of templates required by this module
func (c *Confirm) Templates() []string {
return []string{tplConfirmHTML, tplConfirmText}
}
func (c *Confirm) beforeGet(ctx *authboss.Context) (authboss.Interrupt, error) {
func (c *Confirm) beforeGet(ctx context.Context) (authboss.Interrupt, error) {
if confirmed, err := ctx.User.BoolErr(StoreConfirmed); err != nil {
return authboss.InterruptNone, err
} else if !confirmed {
@ -109,7 +99,7 @@ func (c *Confirm) beforeGet(ctx *authboss.Context) (authboss.Interrupt, error) {
}
// AfterRegister ensures the account is not activated.
func (c *Confirm) afterRegister(ctx *authboss.Context) error {
func (c *Confirm) afterRegister(ctx context.Context) error {
if ctx.User == nil {
return errUserMissing
}
@ -136,7 +126,7 @@ func (c *Confirm) afterRegister(ctx *authboss.Context) error {
return nil
}
var goConfirmEmail = func(c *Confirm, ctx *authboss.Context, to, token string) {
var goConfirmEmail = func(c *Confirm, ctx context.Context, to, token string) {
if ctx.MailMaker != nil {
c.confirmEmail(ctx, to, token)
} else {
@ -145,7 +135,7 @@ var goConfirmEmail = func(c *Confirm, ctx *authboss.Context, to, token string) {
}
// confirmEmail sends a confirmation e-mail.
func (c *Confirm) confirmEmail(ctx *authboss.Context, to, token string) {
func (c *Confirm) confirmEmail(ctx context.Context, to, token string) {
p := path.Join(c.MountPath, "confirm")
url := fmt.Sprintf("%s%s?%s=%s", c.RootURL, p, url.QueryEscape(FormValueConfirm), url.QueryEscape(token))
@ -161,7 +151,7 @@ func (c *Confirm) confirmEmail(ctx *authboss.Context, to, token string) {
}
}
func (c *Confirm) confirmHandler(ctx *authboss.Context, w http.ResponseWriter, r *http.Request) error {
func (c *Confirm) confirmHandler(w http.ResponseWriter, r *http.Request) error {
token := r.FormValue(FormValueConfirm)
if len(token) == 0 {
return authboss.ClientDataErr{Name: FormValueConfirm}

150
context.go Normal file
View File

@ -0,0 +1,150 @@
package authboss
import (
"context"
"net/http"
)
type contextKey string
const (
ctxKeyPID contextKey = "pid"
ctxKeyUser contextKey = "user"
)
func (c contextKey) String() string {
return "authboss ctx key " + string(c)
}
// CurrentUserID retrieves the current user from the session.
func (a *Authboss) CurrentUserID(w http.ResponseWriter, r *http.Request) (string, error) {
_, err := a.Callbacks.FireBefore(EventGetUserSession, r.Context())
if err != nil {
return "", err
}
session := a.SessionStoreMaker.Make(w, r)
key, _ := session.Get(SessionKey)
return key, nil
}
// CurrentUserIDP retrieves the current user but panics if it's not available for
// any reason.
func (a *Authboss) CurrentUserIDP(w http.ResponseWriter, r *http.Request) string {
i, err := a.CurrentUserID(w, r)
if err != nil {
panic(err)
} else if len(i) == 0 {
panic(ErrUserNotFound)
}
return i
}
// CurrentUser retrieves the current user from the session and the database.
func (a *Authboss) CurrentUser(w http.ResponseWriter, r *http.Request) (Storer, error) {
pid, err := a.CurrentUserID(w, r)
if err != nil {
return nil, err
} else if len(pid) == 0 {
return nil, nil
}
return a.currentUser(r.Context(), pid)
}
// CurrentUserP retrieves the current user but panics if it's not available for
// any reason.
func (a *Authboss) CurrentUserP(w http.ResponseWriter, r *http.Request) Storer {
i, err := a.CurrentUser(w, r)
if err != nil {
panic(err)
}
return i
}
func (a *Authboss) currentUser(ctx context.Context, pid string) (Storer, error) {
_, err := a.Callbacks.FireBefore(EventGetUser, ctx)
if err != nil {
return nil, err
}
user, err := a.StoreLoader.Load(ctx, pid)
if err != nil {
return nil, err
}
ctx = context.WithValue(ctx, ctxKeyUser, user)
err = a.Callbacks.FireAfter(EventGetUser, ctx)
if err != nil {
return nil, err
}
return user, nil
}
// LoadCurrentUser takes a pointer to a pointer to the request in order to
// change the current method's request pointer itself to the new request that
// contains the new context that has the pid in it.
func (a *Authboss) LoadCurrentUserID(w http.ResponseWriter, r **http.Request) (string, error) {
pid, err := a.CurrentUserID(w, *r)
if err != nil {
return "", err
}
if len(pid) == 0 {
return "", nil
}
ctx := context.WithValue((**r).Context(), ctxKeyPID, pid)
*r = (**r).WithContext(ctx)
return pid, nil
}
func (a *Authboss) LoadCurrentUserIDP(w http.ResponseWriter, r **http.Request) string {
pid, err := a.LoadCurrentUserID(w, r)
if err != nil {
panic(err)
} else if len(pid) == 0 {
panic(ErrUserNotFound)
}
return pid
}
// LoadCurrentUser takes a pointer to a pointer to the request in order to
// change the current method's request pointer itself to the new request that
// contains the new context that has the user in it. Calls LoadCurrentUserID
// so the primary id is also put in the context.
func (a *Authboss) LoadCurrentUser(w http.ResponseWriter, r **http.Request) (Storer, error) {
pid, err := a.LoadCurrentUserID(w, r)
if err != nil {
return nil, err
}
if len(pid) == 0 {
return nil, nil
}
ctx := (**r).Context()
user, err := a.currentUser(ctx, pid)
if err != nil {
return nil, err
}
ctx = context.WithValue(ctx, ctxKeyUser, user)
*r = (**r).WithContext(ctx)
return user, nil
}
func (a *Authboss) LoadCurrentUserP(w http.ResponseWriter, r **http.Request) Storer {
user, err := a.LoadCurrentUser(w, r)
if err != nil {
panic(err)
} else if user == nil {
panic(ErrUserNotFound)
}
return user
}

170
context_test.go Normal file
View File

@ -0,0 +1,170 @@
package authboss
import (
"context"
"net/http/httptest"
"testing"
)
func TestCurrentUserID(t *testing.T) {
t.Parallel()
ab := New()
ab.SessionStoreMaker = newMockClientStoreMaker(mockClientStore{
SessionKey: "george-pid",
})
id, err := ab.CurrentUserID(nil, httptest.NewRequest("GET", "/", nil))
if err != nil {
t.Error(err)
}
if id != "george-pid" {
t.Error("got:", id)
}
}
func TestCurrentUserIDP(t *testing.T) {
t.Parallel()
ab := New()
ab.SessionStoreMaker = newMockClientStoreMaker(mockClientStore{})
defer func() {
if recover().(error) != ErrUserNotFound {
t.Failed()
}
}()
_ = ab.CurrentUserIDP(nil, httptest.NewRequest("GET", "/", nil))
}
func TestCurrentUser(t *testing.T) {
t.Parallel()
ab := New()
ab.SessionStoreMaker = newMockClientStoreMaker(mockClientStore{
SessionKey: "george-pid",
})
ab.StoreLoader = mockStoreLoader{
"george-pid": mockUser{Email: "george-pid", Password: "unreadable"},
}
user, err := ab.CurrentUser(nil, httptest.NewRequest("GET", "/", nil))
if err != nil {
t.Error(err)
}
if got, err := user.GetEmail(context.TODO()); err != nil {
t.Error(err)
} else if got != "george-pid" {
t.Error("got:", got)
}
}
func TestCurrentUserP(t *testing.T) {
t.Parallel()
ab := New()
ab.SessionStoreMaker = newMockClientStoreMaker(mockClientStore{
SessionKey: "george-pid",
})
ab.StoreLoader = mockStoreLoader{}
defer func() {
if recover().(error) != ErrUserNotFound {
t.Failed()
}
}()
_ = ab.CurrentUserP(nil, httptest.NewRequest("GET", "/", nil))
}
func TestLoadCurrentUserID(t *testing.T) {
t.Parallel()
ab := New()
ab.SessionStoreMaker = newMockClientStoreMaker(mockClientStore{
SessionKey: "george-pid",
})
req := httptest.NewRequest("GET", "/", nil)
id, err := ab.LoadCurrentUserID(nil, &req)
if err != nil {
t.Error(err)
}
if id != "george-pid" {
t.Error("got:", id)
}
if req.Context().Value(ctxKeyPID).(string) != "george-pid" {
t.Error("context was not updated in local request")
}
}
func TestLoadCurrentUserIDP(t *testing.T) {
t.Parallel()
ab := New()
ab.SessionStoreMaker = newMockClientStoreMaker(mockClientStore{})
defer func() {
if recover().(error) != ErrUserNotFound {
t.Failed()
}
}()
req := httptest.NewRequest("GET", "/", nil)
_ = ab.LoadCurrentUserIDP(nil, &req)
}
func TestLoadCurrentUser(t *testing.T) {
t.Parallel()
ab := New()
ab.SessionStoreMaker = newMockClientStoreMaker(mockClientStore{
SessionKey: "george-pid",
})
ab.StoreLoader = mockStoreLoader{
"george-pid": mockUser{Email: "george-pid", Password: "unreadable"},
}
req := httptest.NewRequest("GET", "/", nil)
user, err := ab.LoadCurrentUser(nil, &req)
if err != nil {
t.Error(err)
}
if got, err := user.GetEmail(context.TODO()); err != nil {
t.Error(err)
} else if got != "george-pid" {
t.Error("got:", got)
}
want := user.(mockStoredUser).mockUser
got := req.Context().Value(ctxKeyUser).(mockStoredUser).mockUser
if got != want {
t.Error("users mismatched:\nwant: %#v\ngot: %#v", want, got)
}
}
func TestLoadCurrentUserP(t *testing.T) {
t.Parallel()
ab := New()
ab.SessionStoreMaker = newMockClientStoreMaker(mockClientStore{
SessionKey: "george-pid",
})
ab.StoreLoader = mockStoreLoader{}
defer func() {
if recover().(error) != ErrUserNotFound {
t.Failed()
}
}()
req := httptest.NewRequest("GET", "/", nil)
_ = ab.LoadCurrentUserP(nil, &req)
}

View File

@ -2,32 +2,6 @@ package authboss
import "fmt"
// AttributeErr represents a failure to retrieve a critical
// piece of data from the storer.
type AttributeErr struct {
Name string
WantKind DataType
GotKind string
}
// NewAttributeErr creates a new attribute err type. Useful for when you want
// to have a type mismatch error.
func NewAttributeErr(name string, kind DataType, val interface{}) AttributeErr {
return AttributeErr{
Name: name,
WantKind: kind,
GotKind: fmt.Sprintf("%T", val),
}
}
func (a AttributeErr) Error() string {
if len(a.GotKind) == 0 {
return fmt.Sprintf("Failed to retrieve database attribute: %s", a.Name)
}
return fmt.Sprintf("Failed to retrieve database attribute, type was wrong: %s (want: %v, got: %s)", a.Name, a.WantKind, a.GotKind)
}
// ClientDataErr represents a failure to retrieve a critical
// piece of client information such as a cookie or session value.
type ClientDataErr struct {
@ -38,19 +12,6 @@ func (c ClientDataErr) Error() string {
return fmt.Sprintf("Failed to retrieve client attribute: %s", c.Name)
}
// ErrAndRedirect represents a general error whose response should
// be to redirect.
type ErrAndRedirect struct {
Err error
Location string
FlashSuccess string
FlashError string
}
func (e ErrAndRedirect) Error() string {
return fmt.Sprintf("Error: %v, Redirecting to: %s", e.Err, e.Location)
}
// RenderErr represents an error that occured during rendering
// of a template.
type RenderErr struct {
@ -60,5 +21,5 @@ type RenderErr struct {
}
func (r RenderErr) Error() string {
return fmt.Sprintf("Error rendering template %q: %v, data: %#v", r.TemplateName, r.Err, r.Data)
return fmt.Sprintf("error rendering response %q: %v, data: %#v", r.TemplateName, r.Err, r.Data)
}

View File

@ -6,21 +6,6 @@ import (
"github.com/pkg/errors"
)
func TestAttributeErr(t *testing.T) {
t.Parallel()
estr := "Failed to retrieve database attribute, type was wrong: lol (want: String, got: int)"
if str := NewAttributeErr("lol", String, 5).Error(); str != estr {
t.Error("Error was wrong:", str)
}
estr = "Failed to retrieve database attribute: lol"
err := AttributeErr{Name: "lol"}
if str := err.Error(); str != estr {
t.Error("Error was wrong:", str)
}
}
func TestClientDataErr(t *testing.T) {
t.Parallel()
@ -31,20 +16,10 @@ func TestClientDataErr(t *testing.T) {
}
}
func TestErrAndRedirect(t *testing.T) {
t.Parallel()
estr := "Error: cause, Redirecting to: /"
err := ErrAndRedirect{errors.New("cause"), "/", "success", "failure"}
if str := err.Error(); str != estr {
t.Error("Error was wrong:", str)
}
}
func TestRenderErr(t *testing.T) {
t.Parallel()
estr := `Error rendering template "lol": cause, data: authboss.HTMLData{"a":5}`
estr := `error rendering response "lol": cause, data: authboss.HTMLData{"a":5}`
err := RenderErr{"lol", NewHTMLData("a", 5), errors.New("cause")}
if str := err.Error(); str != estr {
t.Error("Error was wrong:", str)

View File

@ -9,7 +9,7 @@ var nowTime = time.Now
// TimeToExpiry returns zero if the user session is expired else the time until expiry.
func (a *Authboss) TimeToExpiry(w http.ResponseWriter, r *http.Request) time.Duration {
return a.timeToExpiry(a.SessionStoreMaker(w, r))
return a.timeToExpiry(a.SessionStoreMaker.Make(w, r))
}
func (a *Authboss) timeToExpiry(session ClientStorer) time.Duration {
@ -33,7 +33,7 @@ func (a *Authboss) timeToExpiry(session ClientStorer) time.Duration {
// RefreshExpiry updates the last action for the user, so he doesn't become expired.
func (a *Authboss) RefreshExpiry(w http.ResponseWriter, r *http.Request) {
session := a.SessionStoreMaker(w, r)
session := a.SessionStoreMaker.Make(w, r)
a.refreshExpiry(session)
}
@ -57,7 +57,7 @@ func (a *Authboss) ExpireMiddleware(next http.Handler) http.Handler {
// ServeHTTP removes the session if it's passed the expire time.
func (m expireMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
session := m.ab.SessionStoreMaker(w, r)
session := m.ab.SessionStoreMaker.Make(w, r)
if _, ok := session.Get(SessionKey); ok {
ttl := m.ab.timeToExpiry(session)
if ttl == 0 {

View File

@ -11,15 +11,18 @@ import (
func TestDudeIsExpired(t *testing.T) {
ab := New()
session := mockClientStore{SessionKey: "username"}
ab.refreshExpiry(session)
// No t.Parallel()
nowTime = func() time.Time {
return time.Now().UTC().Add(ab.ExpireAfter * 2)
}
ab.SessionStoreMaker = func(_ http.ResponseWriter, _ *http.Request) ClientStorer {
return session
}
defer func() {
nowTime = time.Now
}()
ab.SessionStoreMaker = newMockClientStoreMaker(session)
r, _ := http.NewRequest("GET", "tra/la/la", nil)
w := httptest.NewRecorder()
@ -36,25 +39,28 @@ func TestDudeIsExpired(t *testing.T) {
}
if key, ok := session.Get(SessionKey); ok {
t.Error("Unexpcted session key:", key)
t.Error("Unexpected session key:", key)
}
if key, ok := session.Get(SessionLastAction); ok {
t.Error("Unexpcted last action key:", key)
t.Error("Unexpected last action key:", key)
}
}
func TestDudeIsNotExpired(t *testing.T) {
ab := New()
session := mockClientStore{SessionKey: "username"}
ab.refreshExpiry(session)
// No t.Parallel()
nowTime = func() time.Time {
return time.Now().UTC().Add(ab.ExpireAfter / 2)
}
ab.SessionStoreMaker = func(_ http.ResponseWriter, _ *http.Request) ClientStorer {
return session
}
defer func() {
nowTime = time.Now
}()
ab.SessionStoreMaker = newMockClientStoreMaker(session)
r, _ := http.NewRequest("GET", "tra/la/la", nil)
w := httptest.NewRecorder()

View File

@ -2,6 +2,7 @@
package mocks
import (
"context"
"io"
"net/http"
"net/url"
@ -25,113 +26,99 @@ type MockUser struct {
Locked bool
AttemptNumber int
AttemptTime time.Time
OauthToken string
OauthRefresh string
OauthExpiry time.Time
OAuthToken string
OAuthRefresh string
OAuthExpiry time.Time
}
func (m MockUser) GetUsername(context.Context) (string, error) { return m.Username, nil }
func (m MockUser) GetEmail(context.Context) (string, error) { return m.Email, nil }
func (m MockUser) GetPassword(context.Context) (string, error) { return m.Password, nil }
func (m MockUser) GetRecoverToken(context.Context) (string, error) { return m.RecoverToken, nil }
func (m MockUser) GetRecoverTokenExpiry(context.Context) (time.Time, error) {
return m.RecoverTokenExpiry, nil
}
func (m MockUser) GetConfirmToken(context.Context) (string, error) { return m.ConfirmToken, nil }
func (m MockUser) GetConfirmed(context.Context) (bool, error) { return m.Confirmed, nil }
func (m MockUser) GetLocked(context.Context) (bool, error) { return m.Locked, nil }
func (m MockUser) GetAttemptNumber(context.Context) (int, error) { return m.AttemptNumber, nil }
func (m MockUser) GetAttemptTime(context.Context) (time.Time, error) {
return m.AttemptTime, nil
}
func (m MockUser) GetOAuthToken(context.Context) (string, error) { return m.OAuthToken, nil }
func (m MockUser) GetOAuthRefresh(context.Context) (string, error) { return m.OAuthRefresh, nil }
func (m MockUser) GetOAuthExpiry(context.Context) (time.Time, error) {
return m.OAuthExpiry, nil
}
func (m *MockUser) SetUsername(ctx context.Context, username string) error {
m.Username = username
return nil
}
func (m *MockUser) SetEmail(ctx context.Context, email string) error {
m.Email = email
return nil
}
func (m *MockUser) SetPassword(ctx context.Context, password string) error {
m.Password = password
return nil
}
func (m *MockUser) SetRecoverToken(ctx context.Context, recoverToken string) error {
m.RecoverToken = recoverToken
return nil
}
func (m *MockUser) SetRecoverTokenExpiry(ctx context.Context, recoverTokenExpiry time.Time) error {
m.RecoverTokenExpiry = recoverTokenExpiry
return nil
}
func (m *MockUser) SetConfirmToken(ctx context.Context, confirmToken string) error {
m.ConfirmToken = confirmToken
return nil
}
func (m *MockUser) SetConfirmed(ctx context.Context, confirmed bool) error {
m.Confirmed = confirmed
return nil
}
func (m *MockUser) SetLocked(ctx context.Context, locked bool) error {
m.Locked = locked
return nil
}
func (m *MockUser) SetAttemptNumber(ctx context.Context, attemptNumber int) error {
m.AttemptNumber = attemptNumber
return nil
}
func (m *MockUser) SetAttemptTime(ctx context.Context, attemptTime time.Time) error {
m.AttemptTime = attemptTime
return nil
}
func (m *MockUser) SetOAuthToken(ctx context.Context, oAuthToken string) error {
m.OAuthToken = oAuthToken
return nil
}
func (m *MockUser) SetOAuthRefresh(ctx context.Context, oAuthRefresh string) error {
m.OAuthRefresh = oAuthRefresh
return nil
}
func (m *MockUser) SetOAuthExpiry(ctx context.Context, oAuthExpiry time.Time) error {
m.OAuthExpiry = oAuthExpiry
return nil
}
// MockStorer should be valid for any module storer defined in authboss.
type MockStorer struct {
Users map[string]authboss.Attributes
Tokens map[string][]string
CreateErr string
PutErr string
GetErr string
AddTokenErr string
DelTokensErr string
UseTokenErr string
RecoverUserErr string
ConfirmUserErr string
type MockStoreLoader struct {
Users map[string]*MockUser
RMTokens map[string][]string
}
// NewMockStorer constructor
func NewMockStorer() *MockStorer {
return &MockStorer{
Users: make(map[string]authboss.Attributes),
Tokens: make(map[string][]string),
func NewMockStoreLoader() *MockStoreLoader {
return &MockStoreLoader{
Users: make(map[string]*MockUser),
RMTokens: make(map[string][]string),
}
}
// Create a new user
func (m *MockStorer) Create(key string, attr authboss.Attributes) error {
if len(m.CreateErr) > 0 {
return errors.New(m.CreateErr)
}
m.Users[key] = attr
return nil
}
// Put updates to a user
func (m *MockStorer) Put(key string, attr authboss.Attributes) error {
if len(m.PutErr) > 0 {
return errors.New(m.PutErr)
}
if _, ok := m.Users[key]; !ok {
m.Users[key] = attr
return nil
}
for k, v := range attr {
m.Users[key][k] = v
}
return nil
}
// Get a user
func (m *MockStorer) Get(key string) (result interface{}, err error) {
if len(m.GetErr) > 0 {
return nil, errors.New(m.GetErr)
}
userAttrs, ok := m.Users[key]
if !ok {
return nil, authboss.ErrUserNotFound
}
u := &MockUser{}
if err := userAttrs.Bind(u, true); err != nil {
panic(err)
}
return u, nil
}
// PutOAuth user
func (m *MockStorer) PutOAuth(uid, provider string, attr authboss.Attributes) error {
if len(m.PutErr) > 0 {
return errors.New(m.PutErr)
}
if _, ok := m.Users[uid+provider]; !ok {
m.Users[uid+provider] = attr
return nil
}
for k, v := range attr {
m.Users[uid+provider][k] = v
}
return nil
}
// GetOAuth user
func (m *MockStorer) GetOAuth(uid, provider string) (result interface{}, err error) {
if len(m.GetErr) > 0 {
return nil, errors.New(m.GetErr)
}
userAttrs, ok := m.Users[uid+provider]
if !ok {
return nil, authboss.ErrUserNotFound
}
u := &MockUser{}
if err := userAttrs.Bind(u, true); err != nil {
panic(err)
}
return u, nil
}
/*
// AddToken for remember me
func (m *MockStorer) AddToken(key, token string) error {
if len(m.AddTokenErr) > 0 {
@ -211,23 +198,26 @@ func (m *MockStorer) ConfirmUser(confirmToken string) (result interface{}, err e
return nil, authboss.ErrUserNotFound
}
*/
// MockFailStorer is used for testing module initialize functions that recover more than the base storer
type MockFailStorer struct{}
type MockFailStorer struct {
MockUser
}
// Create fails
func (_ MockFailStorer) Create(_ string, _ authboss.Attributes) error {
func (_ MockFailStorer) Create(context.Context) error {
return errors.New("fail storer: create")
}
// Put fails
func (_ MockFailStorer) Put(_ string, _ authboss.Attributes) error {
func (_ MockFailStorer) Save(context.Context) error {
return errors.New("fail storer: put")
}
// Get fails
func (_ MockFailStorer) Get(_ string) (interface{}, error) {
return nil, errors.New("fail storer: get")
func (_ MockFailStorer) Load(context.Context) error {
return errors.New("fail storer: get")
}
// MockClientStorer is used for testing the client stores on context
@ -322,7 +312,7 @@ func NewMockMailer() *MockMailer {
}
// Send an e-mail
func (m *MockMailer) Send(email authboss.Email) error {
func (m *MockMailer) Send(ctx context.Context, email authboss.Email) error {
if len(m.SendErr) > 0 {
return errors.New(m.SendErr)
}
@ -341,7 +331,7 @@ type MockAfterCallback struct {
func NewMockAfterCallback() *MockAfterCallback {
m := MockAfterCallback{}
m.Fn = func(_ *authboss.Context) error {
m.Fn = func(context.Context) error {
m.HasBeenCalled = true
return nil
}

View File

@ -3,30 +3,13 @@ package response
//go:generate go-bindata -pkg=response -prefix=templates templates
import (
"bytes"
"html/template"
"io"
"io/ioutil"
"net/http"
"os"
"path"
"path/filepath"
"strings"
"github.com/pkg/errors"
"github.com/go-authboss/authboss"
)
var (
// ErrTemplateNotFound should be returned from Get when the view is not found
ErrTemplateNotFound = errors.New("template not found")
)
// TODO(aarondl): Extract this into the default "template" provider
/*
// Templates is a map depicting the forms a template needs wrapped within the specified layout
type Templates map[string]*template.Template
// LoadTemplates parses all specified files located in fpath. Each template is wrapped
// in a unique clone of layout. All templates are expecting {{authboss}} handlebars
// for parsing. It will check the override directory specified in the config, replacing any
@ -158,3 +141,4 @@ func Redirect(ctx *authboss.Context, w http.ResponseWriter, r *http.Request, pat
}
http.Redirect(w, r, path, http.StatusFound)
}
*/

View File

@ -76,7 +76,7 @@ func TestBoundary(t *testing.T) {
t.Parallel()
mailer := smtpMailer{"server", nil, rand.New(rand.NewSource(3))}
if got := mailer.boundary(); got != "ntadoe" {
if got := mailer.boundary(); got != "fe3fhpsm69lx8jvnrnju0wr" {
t.Error("boundary was wrong", got)
}
}

View File

@ -2,6 +2,7 @@ package authboss
import (
"context"
"encoding/json"
"net/http"
"net/url"
"strings"
@ -77,8 +78,20 @@ func (m mockStoredUser) GetPassword(ctx context.Context) (password string, err e
return m.Password, nil
}
type mockClientStoreMaker struct {
store mockClientStore
}
type mockClientStore map[string]string
func newMockClientStoreMaker(store mockClientStore) mockClientStoreMaker {
return mockClientStoreMaker{
store: store,
}
}
func (m mockClientStoreMaker) Make(w http.ResponseWriter, r *http.Request) ClientStorer {
return m.store
}
func (m mockClientStore) Get(key string) (string, bool) {
v, ok := m[key]
return v, ok
@ -125,3 +138,20 @@ func (m mockValidator) Errors(in string) ErrorList {
func (m mockValidator) Rules() []string {
return m.Ruleset
}
type mockRenderLoader struct{}
func (m mockRenderLoader) Init(names []string) (Renderer, error) {
return mockRenderer{}, nil
}
type mockRenderer struct{}
func (m mockRenderer) Render(ctx context.Context, name string, data HTMLData) ([]byte, string, error) {
b, err := json.Marshal(data)
if err != nil {
return nil, "", err
}
return b, "application/json", nil
}

View File

@ -8,6 +8,7 @@ var registeredModules = make(map[string]Modularizer)
type Modularizer interface {
Initialize(*Authboss) error
Routes() RouteTable
Templates() []string
}
// RegisterModule with the core providing all the necessary information to
@ -55,6 +56,7 @@ func (a *Authboss) loadModule(name string) error {
}
mod, ok := value.Interface().(Modularizer)
a.loadedModules[name] = mod
a.templateNames = append(a.templateNames, mod.Templates()...)
return mod.Initialize(a)
}

View File

@ -29,6 +29,7 @@ func testHandler(w http.ResponseWriter, r *http.Request) error {
func (t *testModule) Initialize(a *Authboss) error { return nil }
func (t *testModule) Routes() RouteTable { return t.r }
func (t *testModule) Templates() []string { return []string{"template1.tpl"} }
func TestRegister(t *testing.T) {
// RegisterModule called by init()
@ -59,6 +60,7 @@ func TestLoadedModules(t *testing.T) {
func TestIsLoaded(t *testing.T) {
ab := New()
ab.LogWriter = ioutil.Discard
ab.ViewLoader = mockRenderLoader{}
if err := ab.Init(testModName); err != nil {
t.Error(err)
}

View File

@ -6,10 +6,10 @@ import "context"
// It's possible that Init() is a no-op if the responses are JSON or anything
// else.
type RenderLoader interface {
Init(names string) (Renderer, error)
Init(names []string) (Renderer, error)
}
// Renderer is a type that can render a given template with some data.
type Renderer interface {
Render(ctx context.Context, data HTMLData) ([]byte, error)
Render(ctx context.Context, name string, data HTMLData) (output []byte, contentType string, err error)
}

168
response.go Normal file
View File

@ -0,0 +1,168 @@
package authboss
import (
"html/template"
"net/http"
"github.com/pkg/errors"
)
// RedirectOptions packages up all the pieces a module needs to write out a
// response.
type RedirectOptions struct {
// Success & Failure are used to set Flash messages / JSON messages
// if set. They should be mutually exclusive.
Success string
Failure string
// Code is used when it's an API request instead of 200.
Code int
// When a request should redirect a user somewhere on completion, these
// should be set. RedirectURL tells it where to go. And optionally set
// FollowRedirParam to override the RedirectURL if the form parameter defined
// by FormValueRedirect is passed in the request.
//
// Redirecting works differently whether it's an API request or not.
// If it's an API request, then it will leave the URL in a "redirect"
// parameter.
RedirectPath string
FollowRedirParam bool
}
// Respond to an HTTP request.
func (a *Authboss) Respond(w http.ResponseWriter, r *http.Request, code int, templateName string, data HTMLData) error {
data.MergeKV(
"xsrfName", template.HTML(a.XSRFName),
"xsrfToken", template.HTML(a.XSRFMaker(w, r)),
)
if a.LayoutDataMaker != nil {
data.Merge(a.LayoutDataMaker(w, r))
}
session := a.SessionStoreMaker.Make(w, r)
if flash, ok := session.Get(FlashSuccessKey); ok {
session.Del(FlashSuccessKey)
data.MergeKV(FlashSuccessKey, flash)
}
if flash, ok := session.Get(FlashErrorKey); ok {
session.Del(FlashErrorKey)
data.MergeKV(FlashErrorKey, flash)
}
rendered, mime, err := a.renderer.Render(r.Context(), templateName, data)
if err != nil {
return err
}
w.Header().Set("Content-Type", mime)
w.WriteHeader(code)
_, err = w.Write(rendered)
return err
}
// EmailResponseOptions controls how e-mails are rendered and sent
type EmailResponseOptions struct {
Data HTMLData
HTMLTemplate string
TextTemplate string
}
// Email renders the e-mail templates and sends it using the mailer.
func (a *Authboss) Email(w http.ResponseWriter, r *http.Request, email Email, ro EmailResponseOptions) error {
ctx := r.Context()
if len(ro.HTMLTemplate) != 0 {
htmlBody, _, err := a.renderer.Render(ctx, ro.HTMLTemplate, ro.Data)
if err != nil {
return errors.Wrap(err, "failed to render e-mail html body")
}
email.HTMLBody = string(htmlBody)
}
if len(ro.TextTemplate) != 0 {
textBody, _, err := a.renderer.Render(ctx, ro.TextTemplate, ro.Data)
if err != nil {
return errors.Wrap(err, "failed to render e-mail text body")
}
email.TextBody = string(textBody)
}
return a.Mailer.Send(ctx, email)
}
// Redirect the client elsewhere. If it's an API request it will simply render
// a JSON response with information that should help a client to decide what
// to do.
func (a *Authboss) Redirect(w http.ResponseWriter, r *http.Request, ro RedirectOptions) error {
var redirectFunction = a.redirectNonAPI
if isAPIRequest(r) {
redirectFunction = a.redirectAPI
}
return redirectFunction(w, r, ro)
}
func (a *Authboss) redirectAPI(w http.ResponseWriter, r *http.Request, ro RedirectOptions) error {
path := ro.RedirectPath
redir := r.FormValue(FormValueRedirect)
if len(redir) != 0 && ro.FollowRedirParam {
path = redir
}
var status, message string
if len(ro.Success) != 0 {
status = "success"
message = ro.Success
}
if len(ro.Failure) != 0 {
status = "failure"
message = ro.Failure
}
data := HTMLData{
"path": path,
}
if len(status) != 0 {
data["status"] = status
data["message"] = message
}
body, mime, err := a.renderer.Render(r.Context(), "redirect", data)
if err != nil {
return err
}
if len(body) != 0 {
w.Header().Set("Content-Type", mime)
}
if ro.Code != 0 {
w.WriteHeader(ro.Code)
}
_, err = w.Write(body)
return err
}
func (a *Authboss) redirectNonAPI(w http.ResponseWriter, r *http.Request, ro RedirectOptions) error {
path := ro.RedirectPath
redir := r.FormValue(FormValueRedirect)
if len(redir) != 0 && ro.FollowRedirParam {
path = redir
}
if len(ro.Success) != 0 {
session := a.SessionStoreMaker.Make(w, r)
session.Put(FlashSuccessKey, ro.Success)
}
if len(ro.Failure) != 0 {
session := a.SessionStoreMaker.Make(w, r)
session.Put(FlashErrorKey, ro.Failure)
}
http.Redirect(w, r, path, http.StatusFound)
return nil
}

197
router.go
View File

@ -28,7 +28,7 @@ func (a *Authboss) NewRouter() http.Handler {
for name, mod := range a.loadedModules {
for route, handler := range mod.Routes() {
fmt.Fprintf(a.LogWriter, "%-10s Route: %s\n", "["+name+"]", path.Join(a.MountPath, route))
a.mux.Handle(path.Join(a.MountPath, route), contextRoute{a, handler})
a.mux.Handle(path.Join(a.MountPath, route), abHandler{a, handler})
}
}
@ -44,106 +44,105 @@ func (a *Authboss) NewRouter() http.Handler {
return a.mux
}
type contextRoute struct {
type abHandler struct {
*Authboss
fn HandlerFunc
}
func (c contextRoute) ServeHTTP(w http.ResponseWriter, r *http.Request) {
/*
// Instantiate the context
ctx := c.Authboss.InitContext(w, r)
// Check to make sure we actually need to visit this route
if redirectIfLoggedIn(ctx, w, r) {
return
}
// Call the handler
err := c.fn(ctx, w, r)
if err == nil {
return
}
// Log the error
fmt.Fprintf(c.LogWriter, "Error Occurred at %s: %v", r.URL.Path, err)
// Do specific error handling for special kinds of errors.
switch e := err.(type) {
case ErrAndRedirect:
if len(e.FlashSuccess) > 0 {
ctx.SessionStorer.Put(FlashSuccessKey, e.FlashSuccess)
}
if len(e.FlashError) > 0 {
ctx.SessionStorer.Put(FlashErrorKey, e.FlashError)
}
http.Redirect(w, r, e.Location, http.StatusFound)
case ClientDataErr:
if c.BadRequestHandler != nil {
c.BadRequestHandler.ServeHTTP(w, r)
} else {
w.WriteHeader(http.StatusBadRequest)
io.WriteString(w, "400 Bad request")
}
default:
if c.ErrorHandler != nil {
c.ErrorHandler.ServeHTTP(w, r)
} else {
w.WriteHeader(http.StatusInternalServerError)
io.WriteString(w, "500 An error has occurred")
}
}
}
// redirectIfLoggedIn checks a user's existence by using currentUser. This is done instead of
// a simple Session cookie check so that the remember module has a chance to log the user in
// before they are determined to "not be logged in".
//
// The exceptional routes are sort of hardcoded in a terrible way in here, later on this could move to some
// configuration or something more interesting.
func redirectIfLoggedIn(ctx *Context, w http.ResponseWriter, r *http.Request) (handled bool) {
// If it's a log out url, always let it pass through.
if strings.HasSuffix(r.URL.Path, "/logout") {
return false
}
// If it's an auth url, allow them through if they're half-authed.
if strings.HasSuffix(r.URL.Path, "/auth") || strings.Contains(r.URL.Path, "/oauth2/") {
if halfAuthed, ok := ctx.SessionStorer.Get(SessionHalfAuthKey); ok && halfAuthed == "true" {
return false
}
}
// Otherwise, check if they're logged in, this uses hooks to allow remember
// to set the session cookie
cu, err := ctx.currentUser(ctx, w, r)
// if the user was not found, that means the user was deleted from the underlying
// storer and we should just remove this session cookie and allow them through.
// if it's a generic error, 500
// if the user is found, redirect them away from this page, because they don't need
// to see it.
if err == ErrUserNotFound {
uname, _ := ctx.SessionStorer.Get(SessionKey)
fmt.Fprintf(ctx.LogWriter, "user (%s) has session cookie but user not found, removing cookie", uname)
ctx.SessionStorer.Del(SessionKey)
return false
} else if err != nil {
fmt.Fprintf(ctx.LogWriter, "error occurred reading current user at %s: %v", r.URL.Path, err)
w.WriteHeader(http.StatusInternalServerError)
io.WriteString(w, "500 An error has occurred")
return true
}
if cu != nil {
if redir := r.FormValue(FormValueRedirect); len(redir) > 0 {
http.Redirect(w, r, redir, http.StatusFound)
} else {
http.Redirect(w, r, ctx.AuthLoginOKPath, http.StatusFound)
}
return true
}
return false
*/
// TODO(aarondl): Move this somewhere reasonable
func isAPIRequest(r *http.Request) bool {
return r.Header.Get("Content-Type") == "application/json"
}
func (a abHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Put uid in the context
_, err := a.LoadCurrentUserID(w, &r)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
io.WriteString(w, "500 An error has occurred")
fmt.Fprintf(a.LogWriter, "failed to load current user id: %v", err)
return
}
// Call the handler
err = a.fn(w, r)
if err == nil {
return
}
// Log the error
fmt.Fprintf(a.LogWriter, "Error Occurred at %s: %v", r.URL.Path, err)
// Do specific error handling for special kinds of errors.
if _, ok := err.(ClientDataErr); ok {
if a.BadRequestHandler != nil {
a.BadRequestHandler.ServeHTTP(w, r)
} else {
w.WriteHeader(http.StatusBadRequest)
io.WriteString(w, "400 Bad request")
}
return
}
if a.ErrorHandler != nil {
a.ErrorHandler.ServeHTTP(w, r)
} else {
w.WriteHeader(http.StatusInternalServerError)
io.WriteString(w, "500 An error has occurred")
}
}
/*
// TODO(aarondl): Throw away this function
// redirectIfLoggedIn checks a user's existence by using currentUser. This is done instead of
// a simple Session cookie check so that the remember module has a chance to log the user in
// before they are determined to "not be logged in".
//
// The exceptional routes are sort of hardcoded in a terrible way in here, later on this could move to some
// configuration or something more interesting.
func redirectIfLoggedIn(w http.ResponseWriter, r *http.Request) (handled bool) {
// If it's a log out url, always let it pass through.
if strings.HasSuffix(r.URL.Path, "/logout") {
return false
}
// If it's an auth url, allow them through if they're half-authed.
if strings.HasSuffix(r.URL.Path, "/auth") || strings.Contains(r.URL.Path, "/oauth2/") {
if halfAuthed, ok := ctx.SessionStorer.Get(SessionHalfAuthKey); ok && halfAuthed == "true" {
return false
}
}
// Otherwise, check if they're logged in, this uses hooks to allow remember
// to set the session cookie
cu, err := ctx.currentUser(ctx, w, r)
// if the user was not found, that means the user was deleted from the underlying
// storer and we should just remove this session cookie and allow them through.
// if it's a generic error, 500
// if the user is found, redirect them away from this page, because they don't need
// to see it.
if err == ErrUserNotFound {
uname, _ := ctx.SessionStorer.Get(SessionKey)
fmt.Fprintf(ctx.LogWriter, "user (%s) has session cookie but user not found, removing cookie", uname)
ctx.SessionStorer.Del(SessionKey)
return false
} else if err != nil {
fmt.Fprintf(ctx.LogWriter, "error occurred reading current user at %s: %v", r.URL.Path, err)
w.WriteHeader(http.StatusInternalServerError)
io.WriteString(w, "500 An error has occurred")
return true
}
if cu != nil {
if redir := r.FormValue(FormValueRedirect); len(redir) > 0 {
http.Redirect(w, r, redir, http.StatusFound)
} else {
http.Redirect(w, r, ctx.AuthLoginOKPath, http.StatusFound)
}
return true
}
return false
}
*/

View File

@ -2,8 +2,6 @@ package authboss
import (
"bytes"
"context"
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
@ -24,15 +22,17 @@ type testRouterModule struct {
func (t testRouterModule) Initialize(ab *Authboss) error { return nil }
func (t testRouterModule) Routes() RouteTable { return t.routes }
func (t testRouterModule) Templates() []string { return []string{"template1.tpl"} }
func testRouterSetup() (*Authboss, http.Handler, *bytes.Buffer) {
ab := New()
logger := &bytes.Buffer{}
ab.LogWriter = logger
ab.ViewLoader = mockRenderLoader{}
ab.Init(testRouterModName)
ab.MountPath = "/prefix"
ab.SessionStoreMaker = func(w http.ResponseWriter, r *http.Request) ClientStorer { return mockClientStore{} }
ab.CookieStoreMaker = func(w http.ResponseWriter, r *http.Request) ClientStorer { return mockClientStore{} }
ab.SessionStoreMaker = newMockClientStoreMaker(mockClientStore{})
ab.CookieStoreMaker = newMockClientStoreMaker(mockClientStore{})
logger.Reset() // Clear out the module load messages
@ -169,166 +169,3 @@ func TestRouter_Error(t *testing.T) {
t.Error(str)
}
}
func TestRouter_Redirect(t *testing.T) {
err := ErrAndRedirect{
Err: errors.New("error"),
Location: "/",
FlashSuccess: "yay",
FlashError: "nay",
}
w, r := testRouterCallbackSetup("/error",
func(http.ResponseWriter, *http.Request) error {
return err
},
)
ab, router, logger := testRouterSetup()
session := mockClientStore{}
ab.SessionStoreMaker = func(w http.ResponseWriter, r *http.Request) ClientStorer { return session }
logger.Reset()
router.ServeHTTP(w, r)
if w.Code != http.StatusFound {
t.Error("Wrong code:", w.Code)
}
if loc := w.Header().Get("Location"); loc != err.Location {
t.Error("Wrong location:", loc)
}
if succ, ok := session.Get(FlashSuccessKey); !ok || succ != err.FlashSuccess {
t.Error(succ, ok)
}
if fail, ok := session.Get(FlashErrorKey); !ok || fail != err.FlashError {
t.Error(fail, ok)
}
}
func TestRouter_redirectIfLoggedIn(t *testing.T) {
t.Parallel()
tests := []struct {
Path string
LoggedIn bool
HalfAuthed bool
ShouldRedirect bool
}{
// These routes will be accessed depending on logged in and half auth's value
{"/auth", false, false, false},
{"/auth", true, false, true},
{"/auth", true, true, false},
{"/oauth2/facebook", false, false, false},
{"/oauth2/facebook", true, false, true},
{"/oauth2/facebook", true, true, false},
{"/oauth2/callback/facebook", false, false, false},
{"/oauth2/callback/facebook", true, false, true},
{"/oauth2/callback/facebook", true, true, false},
// These are logout routes and never redirect
{"/logout", true, false, false},
{"/logout", true, true, false},
{"/oauth2/logout", true, false, false},
{"/oauth2/logout", true, true, false},
// These routes should always redirect despite half auth
{"/register", true, true, true},
{"/recover", true, true, true},
{"/register", false, false, false},
{"/recover", false, false, false},
}
storer := mockStoreLoader{"john@john.com": mockUser{
Email: "john@john.com",
Password: "password",
}}
ab := New()
ab.StoreLoader = storer
for i, test := range tests {
session := mockClientStore{}
cookies := mockClientStore{}
ctx := context.TODO()
ctx.SessionStorer = session
ctx.CookieStorer = cookies
if test.LoggedIn {
session[SessionKey] = "john@john.com"
}
if test.HalfAuthed {
session[SessionHalfAuthKey] = "true"
}
r, _ := http.NewRequest("GET", test.Path, nil)
w := httptest.NewRecorder()
handled := redirectIfLoggedIn(ctx, w, r)
if test.ShouldRedirect && (!handled || w.Code != http.StatusFound) {
t.Errorf("%d) It should have redirected the request: %q %t %d", i, test.Path, handled, w.Code)
} else if !test.ShouldRedirect && (handled || w.Code != http.StatusOK) {
t.Errorf("%d) It should have NOT redirected the request: %q %t %d", i, test.Path, handled, w.Code)
}
}
}
type deathStorer struct{}
func (d deathStorer) Create(key string, attributes Attributes) error { return nil }
func (d deathStorer) Put(key string, attributes Attributes) error { return nil }
func (d deathStorer) Get(key string) (interface{}, error) { return nil, errors.New("explosion") }
func TestRouter_redirectIfLoggedInError(t *testing.T) {
t.Parallel()
ab := New()
ab.LogWriter = ioutil.Discard
ab.Storer = deathStorer{}
session := mockClientStore{SessionKey: "john"}
cookies := mockClientStore{}
ctx := context.TODO()
ctx.SessionStorer = session
ctx.CookieStorer = cookies
r, _ := http.NewRequest("GET", "/auth", nil)
w := httptest.NewRecorder()
handled := redirectIfLoggedIn(ctx, w, r)
if !handled {
t.Error("It should have been handled.")
}
if w.Code != http.StatusInternalServerError {
t.Error("It should have internal server error'd:", w.Code)
}
}
type notFoundStorer struct{}
func (n notFoundStorer) Create(key string, attributes Attributes) error { return nil }
func (n notFoundStorer) Put(key string, attributes Attributes) error { return nil }
func (n notFoundStorer) Get(key string) (interface{}, error) { return nil, ErrUserNotFound }
func TestRouter_redirectIfLoggedInUserNotFound(t *testing.T) {
t.Parallel()
ab := New()
ab.LogWriter = ioutil.Discard
ab.Storer = notFoundStorer{}
session := mockClientStore{SessionKey: "john"}
cookies := mockClientStore{}
ctx := context.TODO()
ctx.SessionStorer = session
ctx.CookieStorer = cookies
r, _ := http.NewRequest("GET", "/auth", nil)
w := httptest.NewRecorder()
handled := redirectIfLoggedIn(ctx, w, r)
if handled {
t.Error("It should not have been handled.")
}
if _, ok := session.Get(SessionKey); ok {
t.Error("It should have removed the bad session cookie")
}
}

View File

@ -3,7 +3,6 @@ package authboss
import (
"bytes"
"context"
"reflect"
"time"
"github.com/pkg/errors"
@ -73,10 +72,12 @@ type ArbitraryStorer interface {
GetArbitrary(ctx context.Context) (arbitrary map[string]string, err error)
}
// OAuth2Storer allows reading and writing values
// OAuth2Storer allows reading and writing values relating to OAuth2
type OAuth2Storer interface {
Storer
IsOAuth2User(ctx context.Context) (bool, error)
PutUID(ctx context.Context, uid string) error
PutProvider(ctx context.Context, provider string) error
PutToken(ctx context.Context, token string) error
@ -90,36 +91,6 @@ type OAuth2Storer interface {
GetExpiry(ctx context.Context) (expiry time.Duration, err error)
}
// DataType represents the various types that clients must be able to store.
type DataType int
// DataType constants
const (
Integer DataType = iota
String
Bool
DateTime
)
var (
dateTimeType = reflect.TypeOf(time.Time{})
)
// String returns a string for the DataType representation.
func (d DataType) String() string {
switch d {
case Integer:
return "Integer"
case String:
return "String"
case Bool:
return "Bool"
case DateTime:
return "DateTime"
}
return ""
}
func camelToUnder(in string) string {
out := bytes.Buffer{}
for i := 0; i < len(in); i++ {

View File

@ -1,476 +1,6 @@
package authboss
import (
"bytes"
"database/sql"
"database/sql/driver"
"net/http"
"net/url"
"strings"
"testing"
"time"
)
type NullTime struct {
Time time.Time
Valid bool
}
func (nt *NullTime) Scan(value interface{}) error {
nt.Time, nt.Valid = value.(time.Time)
return nil
}
func (nt NullTime) Value() (driver.Value, error) {
if !nt.Valid {
return nil, nil
}
return nt.Time, nil
}
func TestAttributes_FromRequest(t *testing.T) {
t.Parallel()
now := time.Now().UTC()
vals := make(url.Values)
vals.Set("a", "a")
vals.Set("b_int", "5")
vals.Set("wildcard", "")
vals.Set("c_date", now.Format(time.RFC3339))
req, err := http.NewRequest("POST", "/", strings.NewReader(vals.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
if err != nil {
t.Error(err)
}
attr, err := AttributesFromRequest(req)
if err != nil {
t.Error(err)
}
if got := attr["a"].(string); got != "a" {
t.Error("a's value is wrong:", got)
}
if got := attr["b"].(int); got != 5 {
t.Error("b's value is wrong:", got)
}
if got := attr["c"].(time.Time); got.Unix() != now.Unix() {
t.Error("c's value is wrong:", now, got)
}
if _, ok := attr["wildcard"]; ok {
t.Error("We don't need totally empty fields.")
}
}
func TestAttributes_Names(t *testing.T) {
t.Parallel()
attr := Attributes{
"integer": 5,
"string": "string",
"bool": true,
"date_time": time.Now(),
}
names := attr.Names()
found := map[string]bool{"integer": false, "string": false, "bool": false, "date_time": false}
for _, n := range names {
found[n] = true
}
for k, v := range found {
if !v {
t.Error("Could not find:", k)
}
}
}
func TestAttributeMeta_Names(t *testing.T) {
t.Parallel()
meta := AttributeMeta{
"integer": Integer,
"string": String,
"bool": Bool,
"date_time": DateTime,
}
names := meta.Names()
found := map[string]bool{"integer": false, "string": false, "bool": false, "date_time": false}
for _, n := range names {
found[n] = true
}
for k, v := range found {
if !v {
t.Error("Could not find:", k)
}
}
}
func TestAttributeMeta_Helpers(t *testing.T) {
t.Parallel()
now := time.Now()
attr := Attributes{
"integer": int64(5),
"string": "a",
"bool": true,
"date_time": now,
}
if str, ok := attr.String("string"); !ok || str != "a" {
t.Error(str, ok)
}
if str, err := attr.StringErr("string"); err != nil || str != "a" {
t.Error(str, err)
}
if str, ok := attr.String("notstring"); ok {
t.Error(str, ok)
}
if str, err := attr.StringErr("notstring"); err == nil {
t.Error(str, err)
}
if integer, ok := attr.Int64("integer"); !ok || integer != 5 {
t.Error(integer, ok)
}
if integer, err := attr.Int64Err("integer"); err != nil || integer != 5 {
t.Error(integer, err)
}
if integer, ok := attr.Int64("notinteger"); ok {
t.Error(integer, ok)
}
if integer, err := attr.Int64Err("notinteger"); err == nil {
t.Error(integer, err)
}
if boolean, ok := attr.Bool("bool"); !ok || !boolean {
t.Error(boolean, ok)
}
if boolean, err := attr.BoolErr("bool"); err != nil || !boolean {
t.Error(boolean, err)
}
if boolean, ok := attr.Bool("notbool"); ok {
t.Error(boolean, ok)
}
if boolean, err := attr.BoolErr("notbool"); err == nil {
t.Error(boolean, err)
}
if date, ok := attr.DateTime("date_time"); !ok || date != now {
t.Error(date, ok)
}
if date, err := attr.DateTimeErr("date_time"); err != nil || date != now {
t.Error(date, err)
}
if date, ok := attr.DateTime("notdate_time"); ok {
t.Error(date, ok)
}
if date, err := attr.DateTimeErr("notdate_time"); err == nil {
t.Error(date, err)
}
}
func TestDataType_String(t *testing.T) {
t.Parallel()
if Integer.String() != "Integer" {
t.Error("Expected Integer:", Integer)
}
if String.String() != "String" {
t.Error("Expected String:", String)
}
if Bool.String() != "Bool" {
t.Error("Expected Bool:", String)
}
if DateTime.String() != "DateTime" {
t.Error("Expected DateTime:", DateTime)
}
}
func TestAttributes_Bind(t *testing.T) {
t.Parallel()
anInteger := 5
aString := "string"
aBool := true
aTime := time.Now()
anUnknown := []byte("I'm not a recognizable type")
data := Attributes{
"integer": anInteger,
"string": aString,
"bool": aBool,
"date_time": aTime,
"unknown": anUnknown,
}
s := struct {
Integer int
String string
Bool bool
DateTime time.Time
Unknown []byte
}{}
if err := data.Bind(&s, false); err != nil {
t.Error("Unexpected Error:", err)
}
if s.Integer != anInteger {
t.Error("Integer was not set.")
}
if s.String != aString {
t.Error("String was not set.")
}
if s.Bool != aBool {
t.Error("Bool was not set.")
}
if s.DateTime != aTime {
t.Error("DateTime was not set.")
}
if 0 != bytes.Compare(s.Unknown, anUnknown) {
t.Error("The []byte slice was not set.")
}
}
func TestAttributes_BindIgnoreMissing(t *testing.T) {
t.Parallel()
anInteger := 5
aString := "string"
data := Attributes{
"integer": anInteger,
"string": aString,
}
s := struct {
Integer int
}{}
if err := data.Bind(&s, false); err == nil {
t.Error("Expected error about missing attributes:", err)
}
if err := data.Bind(&s, true); err != nil {
t.Error(err)
}
if s.Integer != anInteger {
t.Error("Integer was not set.")
}
}
func TestAttributes_BindNoPtr(t *testing.T) {
t.Parallel()
data := Attributes{}
s := struct{}{}
if err := data.Bind(s, false); err == nil {
t.Error("Expected an error.")
} else if !strings.Contains(err.Error(), "struct pointer") {
t.Error("Expected an error about pointers got:", err)
}
}
func TestAttributes_BindMissingField(t *testing.T) {
t.Parallel()
data := Attributes{"Integer": 5}
s := struct{}{}
if err := data.Bind(&s, false); err == nil {
t.Error("Expected an error.")
} else if !strings.Contains(err.Error(), "missing") {
t.Error("Expected an error about missing fields, got:", err)
}
}
func TestAttributes_BindTypeFail(t *testing.T) {
t.Parallel()
tests := []struct {
Attr Attributes
Err string
ToBind interface{}
}{
{
Attr: Attributes{"integer": 5},
Err: "should be int",
ToBind: &struct {
Integer string
}{},
},
{
Attr: Attributes{"string": ""},
Err: "should be string",
ToBind: &struct {
String int
}{},
},
{
Attr: Attributes{"bool": true},
Err: "should be bool",
ToBind: &struct {
Bool string
}{},
},
{
Attr: Attributes{"date": time.Time{}},
Err: "should be time.Time",
ToBind: &struct {
Date int
}{},
},
}
for i, test := range tests {
if err := test.Attr.Bind(test.ToBind, false); err == nil {
t.Errorf("%d> Expected an error.", i)
} else if !strings.Contains(err.Error(), test.Err) {
t.Errorf("%d> Expected an error about %q got: %q", i, test.Err, err)
}
}
}
func TestAttributes_BindScannerValues(t *testing.T) {
t.Parallel()
s1 := struct {
Count sql.NullInt64
Time NullTime
}{
sql.NullInt64{},
NullTime{},
}
nowTime := time.Now()
attrs := Attributes{"count": 12, "time": nowTime}
if err := attrs.Bind(&s1, false); err != nil {
t.Error("Unexpected error:", err)
}
if !s1.Count.Valid {
t.Error("Expected valid NullInt64")
}
if s1.Count.Int64 != 12 {
t.Error("Unexpected value:", s1.Count.Int64)
}
if !s1.Time.Valid {
t.Error("Expected valid time.Time")
}
if !s1.Time.Time.Equal(nowTime) {
t.Error("Unexpected value:", s1.Time.Time)
}
}
func TestUnbind(t *testing.T) {
t.Parallel()
s1 := struct {
Integer int
String string
Bool bool
Time time.Time
Int32 int32
ConfigStruct *Config
unexported int
}{5, "string", true, time.Now(), 5, &Config{}, 5}
attr := Unbind(&s1)
if len(attr) != 6 {
t.Error("Expected 6 fields, got:", len(attr))
}
if v, ok := attr["integer"]; !ok {
t.Error("Could not find Integer entry.")
} else if val, ok := v.(int); !ok {
t.Errorf("Underlying type is wrong: %T", v)
} else if s1.Integer != val {
t.Error("Underlying value is wrong:", val)
}
if v, ok := attr["string"]; !ok {
t.Error("Could not find String entry.")
} else if val, ok := v.(string); !ok {
t.Errorf("Underlying type is wrong: %T", v)
} else if s1.String != val {
t.Error("Underlying value is wrong:", val)
}
if v, ok := attr["bool"]; !ok {
t.Error("Could not find String entry.")
} else if val, ok := v.(bool); !ok {
t.Errorf("Underlying type is wrong: %T", v)
} else if s1.Bool != val {
t.Error("Underlying value is wrong:", val)
}
if v, ok := attr["time"]; !ok {
t.Error("Could not find Time entry.")
} else if val, ok := v.(time.Time); !ok {
t.Errorf("Underlying type is wrong: %T", v)
} else if s1.Time != val {
t.Error("Underlying value is wrong:", val)
}
if v, ok := attr["int32"]; !ok {
t.Error("Could not find Int32 entry.")
} else if val, ok := v.(int32); !ok {
t.Errorf("Underlying type is wrong: %T", v)
} else if s1.Int32 != val {
t.Error("Underlying value is wrong:", val)
}
if v, ok := attr["config_struct"]; !ok {
t.Error("Could not find ConfigStruct entry.")
} else if val, ok := v.(*Config); !ok {
t.Errorf("Underlying type is wrong: %T", v)
} else if s1.ConfigStruct != val {
t.Error("Underlying value is wrong:", val)
}
}
func TestUnbind_Valuer(t *testing.T) {
t.Parallel()
nowTime := time.Now()
s1 := struct {
Count sql.NullInt64
Time NullTime
}{
sql.NullInt64{Int64: 12, Valid: true},
NullTime{nowTime, true},
}
attr := Unbind(&s1)
if v, ok := attr["count"]; !ok {
t.Error("Could not find NullInt64 entry.")
} else if val, ok := v.(int64); !ok {
t.Errorf("Underlying type is wrong: %T", v)
} else if 12 != val {
t.Error("Underlying value is wrong:", val)
}
if v, ok := attr["time"]; !ok {
t.Error("Could not find NullTime entry.")
} else if val, ok := v.(time.Time); !ok {
t.Errorf("Underlying type is wrong: %T", v)
} else if !nowTime.Equal(val) {
t.Error("Underlying value is wrong:", val)
}
}
import "testing"
func TestCasingStyleConversions(t *testing.T) {
t.Parallel()