1
0
mirror of https://github.com/volatiletech/authboss.git synced 2025-07-15 01:24:33 +02:00

Introduce new type of client storage

- This addresses the problem of having to update multiple times during
  one request. It's hard to have a nice interface especially with JWT
  because you always end up having to decode the request, encode new
  response, write header, then a second write to it comes, and where do
  you grab the value from? Often you don't have access to the response
  as a "read" structure. So we store it as events instead, and play
  those events against the original data right before the response is
  written to set the headers.
This commit is contained in:
Aaron L
2017-02-24 16:45:47 -08:00
parent 3170cb8068
commit 24fc6196c7
15 changed files with 599 additions and 310 deletions

View File

@ -1,13 +1,8 @@
package authboss package authboss
import ( import (
"context"
"io/ioutil" "io/ioutil"
"net/http"
"net/http/httptest"
"testing" "testing"
"github.com/pkg/errors"
) )
func TestAuthBossInit(t *testing.T) { func TestAuthBossInit(t *testing.T) {
@ -22,74 +17,6 @@ func TestAuthBossInit(t *testing.T) {
} }
} }
func TestAuthBossCurrentUser(t *testing.T) {
t.Parallel()
ab := New()
ab.LogWriter = ioutil.Discard
ab.StoreLoader = mockStoreLoader{"joe": mockUser{Email: "john@john.com", Password: "lies"}}
ab.ViewLoader = mockRenderLoader{}
ab.SessionStoreMaker = newMockClientStoreMaker(mockClientStore{SessionKey: "joe"})
ab.CookieStoreMaker = newMockClientStoreMaker(mockClientStore{})
if err := ab.Init(); err != nil {
t.Error("Unexpected error:", err)
}
rec := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "localhost", nil)
userStruct := ab.CurrentUserP(rec, req)
us := userStruct.(mockStoredUser)
if us.Email != "john@john.com" || us.Password != "lies" {
t.Error("Wrong user found!")
}
}
func TestAuthBossCurrentUserCallbacks(t *testing.T) {
t.Parallel()
ab := New()
ab.LogWriter = ioutil.Discard
ab.StoreLoader = mockStoreLoader{"joe": mockUser{Email: "john@john.com", Password: "lies"}}
ab.ViewLoader = mockRenderLoader{}
ab.SessionStoreMaker = newMockClientStoreMaker(mockClientStore{SessionKey: "joe"})
ab.CookieStoreMaker = newMockClientStoreMaker(mockClientStore{})
if err := ab.Init(); err != nil {
t.Error("Unexpected error:", err)
}
rec := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "localhost", nil)
afterGetUser := errors.New("afterGetUser")
beforeGetUser := errors.New("beforeGetUser")
beforeGetUserSession := errors.New("beforeGetUserSession")
ab.Callbacks.After(EventGetUser, func(context.Context) error {
return afterGetUser
})
if _, err := ab.CurrentUser(rec, req); err != afterGetUser {
t.Error("Want:", afterGetUser, "Got:", err)
}
ab.Callbacks.Before(EventGetUser, func(context.Context) (Interrupt, error) {
return InterruptNone, beforeGetUser
})
if _, err := ab.CurrentUser(rec, req); err != beforeGetUser {
t.Error("Want:", beforeGetUser, "Got:", err)
}
ab.Callbacks.Before(EventGetUserSession, func(context.Context) (Interrupt, error) {
return InterruptNone, beforeGetUserSession
})
if _, err := ab.CurrentUser(rec, req); err != beforeGetUserSession {
t.Error("Want:", beforeGetUserSession, "Got:", err)
}
}
func TestAuthbossUpdatePassword(t *testing.T) { func TestAuthbossUpdatePassword(t *testing.T) {
t.Skip("TODO(aarondl): Implement") t.Skip("TODO(aarondl): Implement")
/* /*

254
client_state.go Normal file
View File

@ -0,0 +1,254 @@
package authboss
import (
"context"
"net/http"
)
const (
// SessionKey is the primarily used key by authboss.
SessionKey = "uid"
// SessionHalfAuthKey is used for sessions that have been authenticated by
// the remember module. This serves as a way to force full authentication
// by denying half-authed users acccess to sensitive areas.
SessionHalfAuthKey = "halfauth"
// SessionLastAction is the session key to retrieve the last action of a user.
SessionLastAction = "last_action"
// SessionOAuth2State is the xsrf protection key for oauth.
SessionOAuth2State = "oauth2_state"
// SessionOAuth2Params is the additional settings for oauth like redirection/remember.
SessionOAuth2Params = "oauth2_params"
// CookieRemember is used for cookies and form input names.
CookieRemember = "rm"
// FlashSuccessKey is used for storing sucess flash messages on the session
FlashSuccessKey = "flash_success"
// FlashErrorKey is used for storing sucess flash messages on the session
FlashErrorKey = "flash_error"
)
// ClientStateEventKind is an enum.
type ClientStateEventKind int
const (
ClientStateEventPut ClientStateEventKind = iota
ClientStateEventDel
)
// ClientStateEvent are the different events that can be recorded during
type ClientStateEvent struct {
Kind ClientStateEventKind
Key string
Value string
}
// ClientStateReadWriter is used to create a cookie storer from an http request.
// Keep in mind security considerations for your implementation, Secure,
// HTTP-Only, etc flags.
//
// There's two major uses for this. To create session storage, and remember me
// cookies.
type ClientStateReadWriter interface {
ReadState(http.ResponseWriter, *http.Request) (ClientState, error)
WriteState(http.ResponseWriter, ClientState, []ClientStateEvent) error
}
// ClientState represents the client's current state and can answer queries
// about it.
type ClientState interface {
Get(key string) (string, bool)
}
// clientStateResponseWriter is used to write out the client state at the last
// moment before the response code is written.
type ClientStateResponseWriter struct {
ab *Authboss
http.ResponseWriter
hasWritten bool
ctx context.Context
sessionStateEvents []ClientStateEvent
cookieStateEvents []ClientStateEvent
}
func (a *Authboss) NewResponse(w http.ResponseWriter, r *http.Request) http.ResponseWriter {
return &ClientStateResponseWriter{
ab: a,
ResponseWriter: w,
ctx: r.Context(),
}
}
func (a *Authboss) LoadClientState(w http.ResponseWriter, r *http.Request) (*http.Request, error) {
if a.SessionStateStorer != nil {
state, err := a.SessionStateStorer.ReadState(w, r)
if err != nil {
return nil, err
} else if state == nil {
return r, nil
}
ctx := context.WithValue(r.Context(), ctxKeySessionState, state)
r = r.WithContext(ctx)
}
if a.CookieStateStorer != nil {
state, err := a.CookieStateStorer.ReadState(w, r)
if err != nil {
return nil, err
} else if state == nil {
return r, nil
}
ctx := context.WithValue(r.Context(), ctxKeyCookieState, state)
r = r.WithContext(ctx)
}
return r, nil
}
// WriteHeader writes the header, but in order to handle errors from the
// underlying ClientStateReadWriter, it has to panic.
func (c *ClientStateResponseWriter) WriteHeader(code int) {
if !c.hasWritten {
if err := c.putClientState(); err != nil {
panic(err)
}
}
c.ResponseWriter.WriteHeader(code)
}
// Header retrieves the underlying headers
func (c ClientStateResponseWriter) Header() http.Header {
return c.ResponseWriter.Header()
}
// Write ensures that the
func (c *ClientStateResponseWriter) Write(b []byte) (int, error) {
if !c.hasWritten {
if err := c.putClientState(); err != nil {
return 0, err
}
}
return c.ResponseWriter.Write(b)
}
func (c *ClientStateResponseWriter) putClientState() error {
if c.hasWritten {
panic("should not call putClientState twice")
}
c.hasWritten = true
sessionStateIntf := c.ctx.Value(ctxKeySessionState)
cookieStateIntf := c.ctx.Value(ctxKeyCookieState)
var session, cookie ClientState
if sessionStateIntf != nil {
session = sessionStateIntf.(ClientState)
}
if cookieStateIntf != nil {
cookie = cookieStateIntf.(ClientState)
}
if c.ab.SessionStateStorer != nil {
err := c.ab.SessionStateStorer.WriteState(c, session, c.sessionStateEvents)
if err != nil {
return err
}
}
if c.ab.CookieStateStorer != nil {
err := c.ab.CookieStateStorer.WriteState(c, cookie, c.cookieStateEvents)
if err != nil {
return err
}
}
return nil
}
// PutSession puts a value into the session
func PutSession(w http.ResponseWriter, key, val string) {
putState(w, ctxKeySessionState, key, val)
}
// DelSession deletes a key-value from the session.
func DelSession(w http.ResponseWriter, key string) {
delState(w, ctxKeySessionState, key)
}
// GetSession fetches a value from the session
func GetSession(r *http.Request, key string) (string, bool) {
return getState(r, ctxKeySessionState, key)
}
// PutCookie puts a value into the session
func PutCookie(w http.ResponseWriter, key, val string) {
putState(w, ctxKeyCookieState, key, val)
}
// DelCookie deletes a key-value from the session.
func DelCookie(w http.ResponseWriter, key string) {
delState(w, ctxKeyCookieState, key)
}
// GetCookie fetches a value from the session
func GetCookie(r *http.Request, key string) (string, bool) {
return getState(r, ctxKeyCookieState, key)
}
func putState(w http.ResponseWriter, ctxKey contextKey, key, val string) {
setState(w, ctxKey, ClientStateEventPut, key, val)
}
func delState(w http.ResponseWriter, ctxKey contextKey, key string) {
setState(w, ctxKey, ClientStateEventDel, key, "")
}
func setState(w http.ResponseWriter, ctxKey contextKey, op ClientStateEventKind, key, val string) {
csrw := w.(*ClientStateResponseWriter)
ev := ClientStateEvent{
Kind: op,
Key: key,
}
if op == ClientStateEventPut {
ev.Value = val
}
switch ctxKey {
case ctxKeySessionState:
csrw.sessionStateEvents = append(csrw.sessionStateEvents, ev)
case ctxKeyCookieState:
csrw.cookieStateEvents = append(csrw.cookieStateEvents, ev)
}
}
func getState(r *http.Request, ctxKey contextKey, key string) (string, bool) {
val := r.Context().Value(ctxKey)
if val == nil {
return "", false
}
state := val.(ClientState)
return state.Get(key)
}
// FlashSuccess returns FlashSuccessKey from the session and removes it.
func FlashSuccess(w http.ResponseWriter, r *http.Request) string {
str, ok := GetSession(r, FlashSuccessKey)
if !ok {
return ""
}
DelSession(w, FlashSuccessKey)
return str
}
// FlashError returns FlashError from the session and removes it.
func FlashError(w http.ResponseWriter, r *http.Request) string {
str, ok := GetSession(r, FlashErrorKey)
if !ok {
return ""
}
DelSession(w, FlashErrorKey)
return str
}

197
client_state_test.go Normal file
View File

@ -0,0 +1,197 @@
package authboss
import (
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestStateGet(t *testing.T) {
t.Parallel()
ab := New()
ab.SessionStateStorer = newMockClientStateRW("one", "two")
ab.CookieStateStorer = newMockClientStateRW("three", "four")
r := httptest.NewRequest("GET", "/", nil)
w := ab.NewResponse(httptest.NewRecorder(), r)
var err error
r, err = ab.LoadClientState(w, r)
if err != nil {
t.Error(err)
}
if got, _ := GetSession(r, "one"); got != "two" {
t.Error("session value was wrong:", got)
}
if got, _ := GetCookie(r, "three"); got != "four" {
t.Error("cookie value was wrong:", got)
}
}
func TestStateResponseWriterDoubleWritePanic(t *testing.T) {
t.Parallel()
ab := New()
ab.SessionStateStorer = newMockClientStateRW("one", "two")
r := httptest.NewRequest("GET", "/", nil)
w := ab.NewResponse(httptest.NewRecorder(), r)
csrw := w.(*ClientStateResponseWriter)
w.WriteHeader(200)
// Check this doesn't panic
w.WriteHeader(200)
defer func() {
if recover() == nil {
t.Error("expected a panic")
}
}()
csrw.putClientState()
}
func TestStateResponseWriterLastSecondWriteWithPrevious(t *testing.T) {
t.Parallel()
ab := New()
ab.SessionStateStorer = newMockClientStateRW("one", "two")
ab.CookieStateStorer = newMockClientStateRW("three", "four")
r := httptest.NewRequest("GET", "/", nil)
var w http.ResponseWriter = httptest.NewRecorder()
var err error
r, err = ab.LoadClientState(w, r)
if err != nil {
t.Error(err)
}
w = ab.NewResponse(w, r)
w.WriteHeader(200)
// This is an odd test, since the mock will always overwrite the previous
// write with the cookie values. Keeping it anyway for code coverage
got := strings.TrimSpace(w.Header().Get("test_session"))
if got != `{"three":"four"}` {
t.Error("got:", got)
}
}
func TestStateResponseWriterLastSecondWriteHeader(t *testing.T) {
t.Parallel()
ab := New()
ab.SessionStateStorer = newMockClientStateRW()
r := httptest.NewRequest("GET", "/", nil)
w := ab.NewResponse(httptest.NewRecorder(), r)
PutSession(w, "one", "two")
w.WriteHeader(200)
got := strings.TrimSpace(w.Header().Get("test_session"))
if got != `{"one":"two"}` {
t.Error("got:", got)
}
}
func TestStateResponseWriterLastSecondWriteWrite(t *testing.T) {
t.Parallel()
ab := New()
ab.SessionStateStorer = newMockClientStateRW()
r := httptest.NewRequest("GET", "/", nil)
w := ab.NewResponse(httptest.NewRecorder(), r)
PutSession(w, "one", "two")
io.WriteString(w, "Hello world!")
got := strings.TrimSpace(w.Header().Get("test_session"))
if got != `{"one":"two"}` {
t.Error("got:", got)
}
}
func TestStateResponseWriterEvents(t *testing.T) {
t.Parallel()
ab := New()
r := httptest.NewRequest("GET", "/", nil)
w := ab.NewResponse(httptest.NewRecorder(), r)
csrw := w.(*ClientStateResponseWriter)
PutSession(w, "one", "two")
DelSession(w, "one")
DelCookie(w, "one")
PutCookie(w, "two", "one")
want := ClientStateEvent{Kind: ClientStateEventPut, Key: "one", Value: "two"}
if got := csrw.sessionStateEvents[0]; got != want {
t.Error("event was wrong", got)
}
want = ClientStateEvent{Kind: ClientStateEventDel, Key: "one"}
if got := csrw.sessionStateEvents[1]; got != want {
t.Error("event was wrong", got)
}
want = ClientStateEvent{Kind: ClientStateEventDel, Key: "one"}
if got := csrw.cookieStateEvents[0]; got != want {
t.Error("event was wrong", got)
}
want = ClientStateEvent{Kind: ClientStateEventPut, Key: "two", Value: "one"}
if got := csrw.cookieStateEvents[1]; got != want {
t.Error("event was wrong", got)
}
}
func TestFlashClearer(t *testing.T) {
t.Parallel()
ab := New()
ab.SessionStateStorer = newMockClientStateRW(FlashSuccessKey, "a", FlashErrorKey, "b")
r := httptest.NewRequest("GET", "/", nil)
w := ab.NewResponse(httptest.NewRecorder(), r)
csrw := w.(*ClientStateResponseWriter)
if msg := FlashSuccess(w, r); msg != "" {
t.Error("Unexpected flash success:", msg)
}
if msg := FlashError(w, r); msg != "" {
t.Error("Unexpected flash error:", msg)
}
var err error
r, err = ab.LoadClientState(w, r)
if err != nil {
t.Error(err)
}
if msg := FlashSuccess(w, r); msg != "a" {
t.Error("Unexpected flash success:", msg)
}
if msg := FlashError(w, r); msg != "b" {
t.Error("Unexpected flash error:", msg)
}
want := ClientStateEvent{Kind: ClientStateEventDel, Key: FlashSuccessKey}
if got := csrw.sessionStateEvents[0]; got != want {
t.Error("event was wrong", got)
}
want = ClientStateEvent{Kind: ClientStateEventDel, Key: FlashErrorKey}
if got := csrw.sessionStateEvents[1]; got != want {
t.Error("event was wrong", got)
}
}

View File

@ -1,86 +0,0 @@
package authboss
import "net/http"
const (
// SessionKey is the primarily used key by authboss.
SessionKey = "uid"
// SessionHalfAuthKey is used for sessions that have been authenticated by
// the remember module. This serves as a way to force full authentication
// by denying half-authed users acccess to sensitive areas.
SessionHalfAuthKey = "halfauth"
// SessionLastAction is the session key to retrieve the last action of a user.
SessionLastAction = "last_action"
// SessionOAuth2State is the xsrf protection key for oauth.
SessionOAuth2State = "oauth2_state"
// SessionOAuth2Params is the additional settings for oauth like redirection/remember.
SessionOAuth2Params = "oauth2_params"
// CookieRemember is used for cookies and form input names.
CookieRemember = "rm"
// FlashSuccessKey is used for storing sucess flash messages on the session
FlashSuccessKey = "flash_success"
// FlashErrorKey is used for storing sucess flash messages on the session
FlashErrorKey = "flash_error"
)
// ClientStoreMaker is used to create a cookie storer from an http request.
// Keep in mind security considerations for your implementation, Secure,
// HTTP-Only, etc flags.
//
// There's two major uses for this. To create session storage, and remember me
// cookies.
type ClientStoreMaker interface {
Make(http.ResponseWriter, *http.Request) ClientStorer
}
// ClientStorer should be able to store values on the clients machine. Cookie and
// Session storers are built with this interface.
type ClientStorer interface {
Put(key, value string)
Get(key string) (string, bool)
Del(key string)
}
// ClientStorerErr is a wrapper to return error values from failed Gets.
type ClientStorerErr interface {
ClientStorer
GetErr(key string) (string, error)
}
type clientStoreWrapper struct {
ClientStorer
}
// GetErr returns a value or an error.
func (c clientStoreWrapper) GetErr(key string) (string, error) {
str, ok := c.Get(key)
if !ok {
return str, ClientDataErr{key}
}
return str, nil
}
// FlashSuccess returns FlashSuccessKey from the session and removes it.
func (a *Authboss) FlashSuccess(w http.ResponseWriter, r *http.Request) string {
storer := a.SessionStoreMaker.Make(w, r)
msg, ok := storer.Get(FlashSuccessKey)
if ok {
storer.Del(FlashSuccessKey)
}
return msg
}
// FlashError returns FlashError from the session and removes it.
func (a *Authboss) FlashError(w http.ResponseWriter, r *http.Request) string {
storer := a.SessionStoreMaker.Make(w, r)
msg, ok := storer.Get(FlashErrorKey)
if ok {
storer.Del(FlashErrorKey)
}
return msg
}

View File

@ -1,52 +0,0 @@
package authboss
import "testing"
type testClientStorerErr string
func (t testClientStorerErr) Put(key, value string) {}
func (t testClientStorerErr) Get(key string) (string, bool) {
return string(t), key == string(t)
}
func (t testClientStorerErr) Del(key string) {}
func TestClientStorerErr(t *testing.T) {
t.Parallel()
var cs testClientStorerErr
csw := clientStoreWrapper{&cs}
if _, err := csw.GetErr("hello"); err == nil {
t.Error("Expected an error")
}
cs = "hello"
if str, err := csw.GetErr("hello"); err != nil {
t.Error(err)
} else if str != "hello" {
t.Error("Wrong value:", str)
}
}
func TestFlashClearer(t *testing.T) {
t.Parallel()
session := mockClientStore{FlashSuccessKey: "success", FlashErrorKey: "error"}
ab := New()
ab.SessionStoreMaker = newMockClientStoreMaker(session)
if msg := ab.FlashSuccess(nil, nil); msg != "success" {
t.Error("Unexpected flash success:", msg)
}
if msg, ok := session.Get(FlashSuccessKey); ok {
t.Error("Unexpected success flash:", msg)
}
if msg := ab.FlashError(nil, nil); msg != "error" {
t.Error("Unexpected flash error:", msg)
}
if msg, ok := session.Get(FlashErrorKey); ok {
t.Error("Unexpected error flash:", msg)
}
}

View File

@ -97,14 +97,16 @@ type Config struct {
// Storer is the interface through which Authboss accesses the web apps database. // Storer is the interface through which Authboss accesses the web apps database.
StoreLoader StoreLoader StoreLoader StoreLoader
// CookieStoreMaker must be defined to provide an interface capapable of storing cookies // CookieStateStorer must be defined to provide an interface capapable of
// for the given response, and reading them from the request. // storing cookies for the given response, and reading them from the request.
CookieStoreMaker ClientStoreMaker CookieStateStorer ClientStateReadWriter
// SessionStoreMaker must be defined to provide an interface capable of storing session-only // SessionStateStorer must be defined to provide an interface capable of
// values for the given response, and reading them from the request. // storing session-only values for the given response, and reading them
SessionStoreMaker ClientStoreMaker // from the request.
// LogWriter is written to when errors occur, as well as on startup to show which modules are loaded SessionStateStorer ClientStateReadWriter
// and which routes they registered. By default writes to io.Discard. // LogWriter is written to when errors occur, as well as on startup to show
// which modules are loaded and which routes they registered. By default
// writes to io.Discard.
LogWriter io.Writer LogWriter io.Writer
// Mailer is the mailer being used to send e-mails out. Authboss defines two loggers for use // Mailer is the mailer being used to send e-mails out. Authboss defines two loggers for use
// LogMailer and SMTPMailer, the default is a LogMailer to io.Discard. // LogMailer and SMTPMailer, the default is a LogMailer to io.Discard.

View File

@ -10,6 +10,9 @@ type contextKey string
const ( const (
ctxKeyPID contextKey = "pid" ctxKeyPID contextKey = "pid"
ctxKeyUser contextKey = "user" ctxKeyUser contextKey = "user"
ctxKeySessionState contextKey = "session"
ctxKeyCookieState contextKey = "cookie"
) )
func (c contextKey) String() string { func (c contextKey) String() string {
@ -27,8 +30,7 @@ func (a *Authboss) CurrentUserID(w http.ResponseWriter, r *http.Request) (string
return "", err return "", err
} }
session := a.SessionStoreMaker.Make(w, r) pid, _ := GetSession(r, SessionKey)
pid, _ := session.Get(SessionKey)
return pid, nil return pid, nil
} }

View File

@ -1,10 +1,6 @@
package authboss package authboss
import ( /* TODO(aarondl): Re-enable
"context"
"net/http/httptest"
"testing"
)
func TestCurrentUserID(t *testing.T) { func TestCurrentUserID(t *testing.T) {
t.Parallel() t.Parallel()
@ -233,3 +229,4 @@ func TestLoadCurrentUserP(t *testing.T) {
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest("GET", "/", nil)
_ = ab.LoadCurrentUserP(nil, &req) _ = ab.LoadCurrentUserP(nil, &req)
} }
*/

View File

@ -9,10 +9,12 @@ var nowTime = time.Now
// TimeToExpiry returns zero if the user session is expired else the time until expiry. // TimeToExpiry returns zero if the user session is expired else the time until expiry.
func (a *Authboss) TimeToExpiry(w http.ResponseWriter, r *http.Request) time.Duration { func (a *Authboss) TimeToExpiry(w http.ResponseWriter, r *http.Request) time.Duration {
return a.timeToExpiry(a.SessionStoreMaker.Make(w, r)) //TODO(aarondl): Rewrite this so it makes sense with new ClientStorer idioms
//return a.timeToExpiry(state.(ClientState))
return 0
} }
func (a *Authboss) timeToExpiry(session ClientStorer) time.Duration { func (a *Authboss) timeToExpiry(session ClientState) time.Duration {
dateStr, ok := session.Get(SessionLastAction) dateStr, ok := session.Get(SessionLastAction)
if !ok { if !ok {
return a.ExpireAfter return a.ExpireAfter
@ -33,12 +35,13 @@ func (a *Authboss) timeToExpiry(session ClientStorer) time.Duration {
// RefreshExpiry updates the last action for the user, so he doesn't become expired. // RefreshExpiry updates the last action for the user, so he doesn't become expired.
func (a *Authboss) RefreshExpiry(w http.ResponseWriter, r *http.Request) { func (a *Authboss) RefreshExpiry(w http.ResponseWriter, r *http.Request) {
session := a.SessionStoreMaker.Make(w, r) //TODO(aarondl): Fix
a.refreshExpiry(session) //a.refreshExpiry(session)
} }
func (a *Authboss) refreshExpiry(session ClientStorer) { func (a *Authboss) refreshExpiry(session ClientState) {
session.Put(SessionLastAction, nowTime().UTC().Format(time.RFC3339)) //TODO(aarondl): Fix
PutSession(nil, SessionLastAction, nowTime().UTC().Format(time.RFC3339))
} }
type expireMiddleware struct { type expireMiddleware struct {
@ -57,16 +60,17 @@ func (a *Authboss) ExpireMiddleware(next http.Handler) http.Handler {
// ServeHTTP removes the session if it's passed the expire time. // ServeHTTP removes the session if it's passed the expire time.
func (m expireMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (m expireMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
session := m.ab.SessionStoreMaker.Make(w, r) //TODO(aarondl): Fix
if _, ok := session.Get(SessionKey); ok { /*
ttl := m.ab.timeToExpiry(session) if _, ok := GetSession(r, SessionKey); ok {
if ttl == 0 { ttl := m.ab.timeToExpiry(session)
session.Del(SessionKey) if ttl == 0 {
session.Del(SessionLastAction) session.Del(SessionKey)
} else { session.Del(SessionLastAction)
m.ab.refreshExpiry(session) } else {
m.ab.refreshExpiry(session)
}
} }
}
m.next.ServeHTTP(w, r) m.next.ServeHTTP(w, r)*/
} }

View File

@ -1,11 +1,6 @@
package authboss package authboss
import ( /* TODO(aarondl): Re-enable
"net/http"
"net/http/httptest"
"testing"
"time"
)
// These tests use the global variable nowTime so cannot be parallelized // These tests use the global variable nowTime so cannot be parallelized
@ -80,3 +75,4 @@ func TestDudeIsNotExpired(t *testing.T) {
t.Error("Expected session key:", key) t.Error("Expected session key:", key)
} }
} }
*/

View File

@ -1,6 +1,7 @@
package authboss package authboss
import ( import (
"bytes"
"context" "context"
"encoding/json" "encoding/json"
"net/http" "net/http"
@ -78,35 +79,59 @@ func (m mockStoredUser) GetPassword(ctx context.Context) (password string, err e
return m.Password, nil return m.Password, nil
} }
type mockClientStoreMaker struct { type mockClientStateReadWriter struct {
store mockClientStore state mockClientState
} }
type mockClientStore map[string]string
func newMockClientStoreMaker(store mockClientStore) mockClientStoreMaker { type mockClientState map[string]string
return mockClientStoreMaker{
store: store, func newMockClientStateRW(keyValue ...string) mockClientStateReadWriter {
state := mockClientState{}
for i := 0; i < len(keyValue); i += 2 {
key, value := keyValue[i], keyValue[i+1]
state[key] = value
} }
}
func (m mockClientStoreMaker) Make(w http.ResponseWriter, r *http.Request) ClientStorer { return mockClientStateReadWriter{state}
return m.store
} }
func (m mockClientStore) Get(key string) (string, bool) { func (m mockClientStateReadWriter) ReadState(w http.ResponseWriter, r *http.Request) (ClientState, error) {
v, ok := m[key] return m.state, nil
return v, ok
} }
func (m mockClientStore) GetErr(key string) (string, error) {
v, ok := m[key] func (m mockClientStateReadWriter) WriteState(w http.ResponseWriter, cs ClientState, evs []ClientStateEvent) error {
if !ok { var state mockClientState
return v, ClientDataErr{key}
if cs != nil {
state = cs.(mockClientState)
} else {
state = mockClientState{}
} }
return v, nil
}
func (m mockClientStore) Put(key, val string) { m[key] = val }
func (m mockClientStore) Del(key string) { delete(m, key) }
func mockRequest(postKeyValues ...string) *http.Request { for _, ev := range evs {
switch ev.Kind {
case ClientStateEventPut:
state[ev.Key] = ev.Value
case ClientStateEventDel:
delete(state, ev.Key)
}
}
b, err := json.Marshal(state)
if err != nil {
return err
}
w.Header().Set("test_session", string(b))
return nil
}
func (m mockClientState) Get(key string) (string, bool) {
val, ok := m[key]
return val, ok
}
func newMockRequest(postKeyValues ...string) *http.Request {
urlValues := make(url.Values) urlValues := make(url.Values)
for i := 0; i < len(postKeyValues); i += 2 { for i := 0; i < len(postKeyValues); i += 2 {
urlValues.Set(postKeyValues[i], postKeyValues[i+1]) urlValues.Set(postKeyValues[i], postKeyValues[i+1])
@ -121,6 +146,27 @@ func mockRequest(postKeyValues ...string) *http.Request {
return req return req
} }
func newMockAPIRequest(postKeyValues ...string) *http.Request {
kv := map[string]string{}
for i := 0; i < len(postKeyValues); i += 2 {
key, value := postKeyValues[i], postKeyValues[i+1]
kv[key] = value
}
b, err := json.Marshal(kv)
if err != nil {
panic(err)
}
req, err := http.NewRequest("POST", "http://localhost", bytes.NewReader(b))
if err != nil {
panic(err)
}
req.Header.Set("Content-Type", "application/json")
return req
}
type mockValidator struct { type mockValidator struct {
FieldName string FieldName string
Errs ErrorList Errs ErrorList

View File

@ -7,30 +7,8 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
) )
// RedirectOptions packages up all the pieces a module needs to write out a // Respond to an HTTP request. Renders templates, flash messages, does XSRF
// response. // and writes the headers out.
type RedirectOptions struct {
// Success & Failure are used to set Flash messages / JSON messages
// if set. They should be mutually exclusive.
Success string
Failure string
// Code is used when it's an API request instead of 200.
Code int
// When a request should redirect a user somewhere on completion, these
// should be set. RedirectURL tells it where to go. And optionally set
// FollowRedirParam to override the RedirectURL if the form parameter defined
// by FormValueRedirect is passed in the request.
//
// Redirecting works differently whether it's an API request or not.
// If it's an API request, then it will leave the URL in a "redirect"
// parameter.
RedirectPath string
FollowRedirParam bool
}
// Respond to an HTTP request.
func (a *Authboss) Respond(w http.ResponseWriter, r *http.Request, code int, templateName string, data HTMLData) error { func (a *Authboss) Respond(w http.ResponseWriter, r *http.Request, code int, templateName string, data HTMLData) error {
data.MergeKV( data.MergeKV(
"xsrfName", template.HTML(a.XSRFName), "xsrfName", template.HTML(a.XSRFName),
@ -41,14 +19,13 @@ func (a *Authboss) Respond(w http.ResponseWriter, r *http.Request, code int, tem
data.Merge(a.LayoutDataMaker(w, r)) data.Merge(a.LayoutDataMaker(w, r))
} }
session := a.SessionStoreMaker.Make(w, r) flashSuccess := FlashSuccess(w, r)
if flash, ok := session.Get(FlashSuccessKey); ok { flashError := FlashError(w, r)
session.Del(FlashSuccessKey) if len(flashSuccess) != 0 {
data.MergeKV(FlashSuccessKey, flash) data.MergeKV(FlashSuccessKey, flashSuccess)
} }
if flash, ok := session.Get(FlashErrorKey); ok { if len(flashError) != 0 {
session.Del(FlashErrorKey) data.MergeKV(FlashErrorKey, flashError)
data.MergeKV(FlashErrorKey, flash)
} }
rendered, mime, err := a.renderer.Render(r.Context(), templateName, data) rendered, mime, err := a.renderer.Render(r.Context(), templateName, data)
@ -93,6 +70,29 @@ func (a *Authboss) Email(w http.ResponseWriter, r *http.Request, email Email, ro
return a.Mailer.Send(ctx, email) return a.Mailer.Send(ctx, email)
} }
// RedirectOptions packages up all the pieces a module needs to write out a
// response.
type RedirectOptions struct {
// Success & Failure are used to set Flash messages / JSON messages
// if set. They should be mutually exclusive.
Success string
Failure string
// Code is used when it's an API request instead of 200.
Code int
// When a request should redirect a user somewhere on completion, these
// should be set. RedirectURL tells it where to go. And optionally set
// FollowRedirParam to override the RedirectURL if the form parameter defined
// by FormValueRedirect is passed in the request.
//
// Redirecting works differently whether it's an API request or not.
// If it's an API request, then it will leave the URL in a "redirect"
// parameter.
RedirectPath string
FollowRedirParam bool
}
// Redirect the client elsewhere. If it's an API request it will simply render // Redirect the client elsewhere. If it's an API request it will simply render
// a JSON response with information that should help a client to decide what // a JSON response with information that should help a client to decide what
// to do. // to do.
@ -155,12 +155,10 @@ func (a *Authboss) redirectNonAPI(w http.ResponseWriter, r *http.Request, ro Red
} }
if len(ro.Success) != 0 { if len(ro.Success) != 0 {
session := a.SessionStoreMaker.Make(w, r) PutSession(w, FlashSuccessKey, ro.Success)
session.Put(FlashSuccessKey, ro.Success)
} }
if len(ro.Failure) != 0 { if len(ro.Failure) != 0 {
session := a.SessionStoreMaker.Make(w, r) PutSession(w, FlashErrorKey, ro.Failure)
session.Put(FlashErrorKey, ro.Failure)
} }
http.Redirect(w, r, path, http.StatusFound) http.Redirect(w, r, path, http.StatusFound)

View File

@ -31,8 +31,8 @@ func testRouterSetup() (*Authboss, http.Handler, *bytes.Buffer) {
ab.ViewLoader = mockRenderLoader{} ab.ViewLoader = mockRenderLoader{}
ab.Init(testRouterModName) ab.Init(testRouterModName)
ab.MountPath = "/prefix" ab.MountPath = "/prefix"
ab.SessionStoreMaker = newMockClientStoreMaker(mockClientStore{}) //ab.SessionStoreMaker = newMockClientStoreMaker(mockClientStore{})
ab.CookieStoreMaker = newMockClientStoreMaker(mockClientStore{}) //ab.CookieStoreMaker = newMockClientStoreMaker(mockClientStore{})
logger.Reset() // Clear out the module load messages logger.Reset() // Clear out the module load messages

View File

@ -61,6 +61,10 @@ type Storer interface {
Load(ctx context.Context) error Load(ctx context.Context) error
} }
// TODO(aarondl): Document & move to Register module
// ArbitraryStorer allows arbitrary data from the web form through. You should
// definitely only pull the keys you want from the map, since this is unfiltered
// input from a web request and is an attack vector.
type ArbitraryStorer interface { type ArbitraryStorer interface {
Storer Storer

View File

@ -65,7 +65,7 @@ func TestErrorList_Map(t *testing.T) {
func TestValidate(t *testing.T) { func TestValidate(t *testing.T) {
t.Parallel() t.Parallel()
req := mockRequest(StoreUsername, "john", StoreEmail, "john@john.com") req := newMockRequest(StoreUsername, "john", StoreEmail, "john@john.com")
errList := Validate(req, []Validator{ errList := Validate(req, []Validator{
mockValidator{ mockValidator{
@ -96,19 +96,19 @@ func TestValidate(t *testing.T) {
func TestValidate_Confirm(t *testing.T) { func TestValidate_Confirm(t *testing.T) {
t.Parallel() t.Parallel()
req := mockRequest(StoreUsername, "john", "confirmUsername", "johnny") req := newMockRequest(StoreUsername, "john", "confirmUsername", "johnny")
errs := Validate(req, nil, StoreUsername, "confirmUsername").Map() errs := Validate(req, nil, StoreUsername, "confirmUsername").Map()
if errs["confirmUsername"][0] != "Does not match username" { if errs["confirmUsername"][0] != "Does not match username" {
t.Error("Expected a different error for confirmUsername:", errs["confirmUsername"][0]) t.Error("Expected a different error for confirmUsername:", errs["confirmUsername"][0])
} }
req = mockRequest(StoreUsername, "john", "confirmUsername", "john") req = newMockRequest(StoreUsername, "john", "confirmUsername", "john")
errs = Validate(req, nil, StoreUsername, "confirmUsername").Map() errs = Validate(req, nil, StoreUsername, "confirmUsername").Map()
if len(errs) != 0 { if len(errs) != 0 {
t.Error("Expected no errors:", errs) t.Error("Expected no errors:", errs)
} }
req = mockRequest(StoreUsername, "john", "confirmUsername", "john") req = newMockRequest(StoreUsername, "john", "confirmUsername", "john")
errs = Validate(req, nil, StoreUsername).Map() errs = Validate(req, nil, StoreUsername).Map()
if len(errs) != 0 { if len(errs) != 0 {
t.Error("Expected no errors:", errs) t.Error("Expected no errors:", errs)