1
0
mirror of https://github.com/volatiletech/authboss.git synced 2025-01-08 04:03:53 +02:00

Clean up context and client state

- Remove extraneous http.ResponseWriter from all read-only queries
  against the request context (for the ClientState)
- Instead of using a context.Context on the ClientStateResponseWriter
  just store variables for the things we'd like to store, it should be
  less expensive and it's much easier to work with and more clear.
- Save the loaded client state into both the ResponseWriter itself and
  the Request context, the ResponseWriter will store them simply to send
  them into the WriteState() method later on, the Request will store
  them to be able to query data.
This commit is contained in:
Aaron L 2018-03-07 16:21:37 -08:00
parent ce2d3dac09
commit 37ace55579
5 changed files with 73 additions and 88 deletions

View File

@ -54,7 +54,7 @@ type ClientStateEvent struct {
type ClientStateReadWriter interface {
// ReadState should return a map like structure allowing it to look up
// any values in the current session, or any cookie in the request
ReadState(http.ResponseWriter, *http.Request) (ClientState, error)
ReadState(*http.Request) (ClientState, error)
// WriteState can sometimes be called with a nil ClientState in the event
// that no ClientState was recovered from the request context.
WriteState(http.ResponseWriter, ClientState, []ClientStateEvent) error
@ -81,62 +81,66 @@ type ClientState interface {
type ClientStateResponseWriter struct {
http.ResponseWriter
cookieState ClientStateReadWriter
sessionState ClientStateReadWriter
cookieStateRW ClientStateReadWriter
sessionStateRW ClientStateReadWriter
cookieState ClientState
sessionState ClientState
hasWritten bool
ctx context.Context
sessionStateEvents []ClientStateEvent
cookieStateEvents []ClientStateEvent
sessionStateEvents []ClientStateEvent
}
// LoadClientStateMiddleware wraps all requests with the ClientStateResponseWriter
// as well as loading the current client state into the context for use.
func (a *Authboss) LoadClientStateMiddleware(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
request, err := a.LoadClientState(w, r)
writer := a.NewResponse(w)
request, err := a.LoadClientState(writer, r)
if err != nil {
panic(fmt.Sprintf("failed to load client state: %+v", err))
}
writer := a.NewResponse(w, request)
h.ServeHTTP(writer, request)
})
}
// NewResponse wraps the ResponseWriter with a ClientStateResponseWriter
func (a *Authboss) NewResponse(w http.ResponseWriter, r *http.Request) *ClientStateResponseWriter {
func (a *Authboss) NewResponse(w http.ResponseWriter) *ClientStateResponseWriter {
return &ClientStateResponseWriter{
ResponseWriter: w,
cookieState: a.Config.Storage.CookieState,
sessionState: a.Config.Storage.SessionState,
ctx: r.Context(),
cookieStateRW: a.Config.Storage.CookieState,
sessionStateRW: a.Config.Storage.SessionState,
}
}
// LoadClientState loads the state from sessions and cookies into the request context
// LoadClientState loads the state from sessions and cookies into the ResponseWriter
// for later use.
func (a *Authboss) LoadClientState(w http.ResponseWriter, r *http.Request) (*http.Request, error) {
if a.Storage.SessionState != nil {
state, err := a.Storage.SessionState.ReadState(w, r)
state, err := a.Storage.SessionState.ReadState(r)
if err != nil {
return nil, err
} else if state == nil {
return r, nil
return nil, nil
}
ctx := context.WithValue(r.Context(), CTXKeySessionState, state)
r = r.WithContext(ctx)
c := MustClientStateResponseWriter(w)
c.sessionState = state
r = r.WithContext(context.WithValue(r.Context(), CTXKeySessionState, state))
}
if a.Storage.CookieState != nil {
state, err := a.Storage.CookieState.ReadState(w, r)
state, err := a.Storage.CookieState.ReadState(r)
if err != nil {
return nil, err
} else if state == nil {
return r, nil
return nil, nil
}
ctx := context.WithValue(r.Context(), CTXKeyCookieState, state)
r = r.WithContext(ctx)
c := MustClientStateResponseWriter(w)
c.cookieState = state
r = r.WithContext(context.WithValue(r.Context(), CTXKeyCookieState, state))
}
return r, nil
@ -201,28 +205,14 @@ func (c *ClientStateResponseWriter) putClientState() error {
return nil
}
if c.sessionState != nil && len(c.sessionStateEvents) > 0 {
sessionStateIntf := c.ctx.Value(CTXKeySessionState)
var session ClientState
if sessionStateIntf != nil {
session = sessionStateIntf.(ClientState)
}
err := c.sessionState.WriteState(c, session, c.sessionStateEvents)
if c.sessionStateRW != nil && len(c.sessionStateEvents) > 0 {
err := c.sessionStateRW.WriteState(c, c.sessionState, c.sessionStateEvents)
if err != nil {
return err
}
}
if c.cookieState != nil && len(c.cookieStateEvents) > 0 {
cookieStateIntf := c.ctx.Value(CTXKeyCookieState)
var cookie ClientState
if cookieStateIntf != nil {
cookie = cookieStateIntf.(ClientState)
}
err := c.cookieState.WriteState(c, cookie, c.cookieStateEvents)
if c.cookieStateRW != nil && len(c.cookieStateEvents) > 0 {
err := c.cookieStateRW.WriteState(c, c.cookieState, c.cookieStateEvents)
if err != nil {
return err
}
@ -269,7 +259,7 @@ func delState(w http.ResponseWriter, CTXKey contextKey, key string) {
setState(w, CTXKey, ClientStateEventDel, key, "")
}
func setState(w http.ResponseWriter, CTXKey contextKey, op ClientStateEventKind, key, val string) {
func setState(w http.ResponseWriter, ctxKey contextKey, op ClientStateEventKind, key, val string) {
csrw := MustClientStateResponseWriter(w)
ev := ClientStateEvent{
Kind: op,
@ -280,7 +270,7 @@ func setState(w http.ResponseWriter, CTXKey contextKey, op ClientStateEventKind,
ev.Value = val
}
switch CTXKey {
switch ctxKey {
case CTXKeySessionState:
csrw.sessionStateEvents = append(csrw.sessionStateEvents, ev)
case CTXKeyCookieState:
@ -288,8 +278,8 @@ func setState(w http.ResponseWriter, CTXKey contextKey, op ClientStateEventKind,
}
}
func getState(r *http.Request, CTXKey contextKey, key string) (string, bool) {
val := r.Context().Value(CTXKey)
func getState(r *http.Request, ctxKey contextKey, key string) (string, bool) {
val := r.Context().Value(ctxKey)
if val == nil {
return "", false
}

View File

@ -15,7 +15,7 @@ func TestStateGet(t *testing.T) {
ab.Storage.CookieState = newMockClientStateRW("three", "four")
r := httptest.NewRequest("GET", "/", nil)
w := ab.NewResponse(httptest.NewRecorder(), r)
w := ab.NewResponse(httptest.NewRecorder())
var err error
r, err = ab.LoadClientState(w, r)
@ -37,8 +37,7 @@ func TestStateResponseWriterDoubleWritePanic(t *testing.T) {
ab := New()
ab.Storage.SessionState = newMockClientStateRW("one", "two")
r := httptest.NewRequest("GET", "/", nil)
w := ab.NewResponse(httptest.NewRecorder(), r)
w := ab.NewResponse(httptest.NewRecorder())
w.WriteHeader(200)
// Check this doesn't panic
@ -59,8 +58,7 @@ func TestStateResponseWriterLastSecondWriteHeader(t *testing.T) {
ab := New()
ab.Storage.SessionState = newMockClientStateRW()
r := httptest.NewRequest("GET", "/", nil)
w := ab.NewResponse(httptest.NewRecorder(), r)
w := ab.NewResponse(httptest.NewRecorder())
PutSession(w, "one", "two")
@ -77,8 +75,7 @@ func TestStateResponseWriterLastSecondWriteWrite(t *testing.T) {
ab := New()
ab.Storage.SessionState = newMockClientStateRW()
r := httptest.NewRequest("GET", "/", nil)
w := ab.NewResponse(httptest.NewRecorder(), r)
w := ab.NewResponse(httptest.NewRecorder())
PutSession(w, "one", "two")
@ -94,8 +91,7 @@ func TestStateResponseWriterEvents(t *testing.T) {
t.Parallel()
ab := New()
r := httptest.NewRequest("GET", "/", nil)
w := ab.NewResponse(httptest.NewRecorder(), r)
w := ab.NewResponse(httptest.NewRecorder())
PutSession(w, "one", "two")
DelSession(w, "one")
@ -130,14 +126,14 @@ func TestFlashClearer(t *testing.T) {
ab.Storage.SessionState = newMockClientStateRW(FlashSuccessKey, "a", FlashErrorKey, "b")
r := httptest.NewRequest("GET", "/", nil)
w := ab.NewResponse(httptest.NewRecorder(), r)
w := ab.NewResponse(httptest.NewRecorder())
if msg := FlashSuccess(w, r); msg != "" {
t.Error("Unexpected flash success:", msg)
t.Error("unexpected flash success:", msg)
}
if msg := FlashError(w, r); msg != "" {
t.Error("Unexpected flash error:", msg)
t.Error("unexpected flash error:", msg)
}
var err error

View File

@ -32,7 +32,7 @@ func (c contextKey) String() string {
}
// CurrentUserID retrieves the current user from the session.
func (a *Authboss) CurrentUserID(w http.ResponseWriter, r *http.Request) (string, error) {
func (a *Authboss) CurrentUserID(r *http.Request) (string, error) {
if pid := r.Context().Value(CTXKeyPID); pid != nil {
return pid.(string), nil
}
@ -43,8 +43,8 @@ func (a *Authboss) CurrentUserID(w http.ResponseWriter, r *http.Request) (string
// CurrentUserIDP retrieves the current user but panics if it's not available for
// any reason.
func (a *Authboss) CurrentUserIDP(w http.ResponseWriter, r *http.Request) string {
i, err := a.CurrentUserID(w, r)
func (a *Authboss) CurrentUserIDP(r *http.Request) string {
i, err := a.CurrentUserID(r)
if err != nil {
panic(err)
} else if len(i) == 0 {
@ -57,12 +57,12 @@ func (a *Authboss) CurrentUserIDP(w http.ResponseWriter, r *http.Request) string
// CurrentUser retrieves the current user from the session and the database.
// Before the user is loaded from the database the context key is checked.
// If the session doesn't have the user ID ErrUserNotFound will be returned.
func (a *Authboss) CurrentUser(w http.ResponseWriter, r *http.Request) (User, error) {
func (a *Authboss) CurrentUser(r *http.Request) (User, error) {
if user := r.Context().Value(CTXKeyUser); user != nil {
return user.(User), nil
}
pid, err := a.CurrentUserID(w, r)
pid, err := a.CurrentUserID(r)
if err != nil {
return nil, err
} else if len(pid) == 0 {
@ -74,8 +74,8 @@ func (a *Authboss) CurrentUser(w http.ResponseWriter, r *http.Request) (User, er
// CurrentUserP retrieves the current user but panics if it's not available for
// any reason.
func (a *Authboss) CurrentUserP(w http.ResponseWriter, r *http.Request) User {
i, err := a.CurrentUser(w, r)
func (a *Authboss) CurrentUserP(r *http.Request) User {
i, err := a.CurrentUser(r)
if err != nil {
panic(err)
} else if i == nil {
@ -91,12 +91,8 @@ func (a *Authboss) currentUser(ctx context.Context, pid string) (User, error) {
// LoadCurrentUserID takes a pointer to a pointer to the request in order to
// change the current method's request pointer itself to the new request that
// contains the new context that has the pid in it.
func (a *Authboss) LoadCurrentUserID(w http.ResponseWriter, r **http.Request) (string, error) {
if pid := (*r).Context().Value(CTXKeyPID); pid != nil {
return pid.(string), nil
}
pid, err := a.CurrentUserID(w, *r)
func (a *Authboss) LoadCurrentUserID(r **http.Request) (string, error) {
pid, err := a.CurrentUserID(*r)
if err != nil {
return "", err
}
@ -112,8 +108,8 @@ func (a *Authboss) LoadCurrentUserID(w http.ResponseWriter, r **http.Request) (s
}
// LoadCurrentUserIDP loads the current user id and panics if it's not found
func (a *Authboss) LoadCurrentUserIDP(w http.ResponseWriter, r **http.Request) string {
pid, err := a.LoadCurrentUserID(w, r)
func (a *Authboss) LoadCurrentUserIDP(r **http.Request) string {
pid, err := a.LoadCurrentUserID(r)
if err != nil {
panic(err)
} else if len(pid) == 0 {
@ -127,12 +123,12 @@ func (a *Authboss) LoadCurrentUserIDP(w http.ResponseWriter, r **http.Request) s
// change the current method's request pointer itself to the new request that
// contains the new context that has the user in it. Calls LoadCurrentUserID
// so the primary id is also put in the context.
func (a *Authboss) LoadCurrentUser(w http.ResponseWriter, r **http.Request) (User, error) {
func (a *Authboss) LoadCurrentUser(r **http.Request) (User, error) {
if user := (*r).Context().Value(CTXKeyUser); user != nil {
return user.(User), nil
}
pid, err := a.LoadCurrentUserID(w, r)
pid, err := a.LoadCurrentUserID(r)
if err != nil {
return nil, err
}
@ -154,8 +150,8 @@ func (a *Authboss) LoadCurrentUser(w http.ResponseWriter, r **http.Request) (Use
// LoadCurrentUserP does the same as LoadCurrentUser but panics if
// the current user is not found.
func (a *Authboss) LoadCurrentUserP(w http.ResponseWriter, r **http.Request) User {
user, err := a.LoadCurrentUser(w, r)
func (a *Authboss) LoadCurrentUserP(r **http.Request) User {
user, err := a.LoadCurrentUser(r)
if err != nil {
panic(err)
} else if user == nil {

View File

@ -12,6 +12,7 @@ func loadClientStateP(ab *Authboss, w http.ResponseWriter, r *http.Request) *htt
if err != nil {
panic(err)
}
return r
}
@ -21,7 +22,9 @@ func testSetupContext() (*Authboss, *http.Request) {
ab.Storage.Server = mockServerStorer{
"george-pid": mockUser{Email: "george-pid", Password: "unreadable"},
}
r := loadClientStateP(ab, nil, httptest.NewRequest("GET", "/", nil))
r := httptest.NewRequest("GET", "/", nil)
w := ab.NewResponse(httptest.NewRecorder())
r = loadClientStateP(ab, w, r)
return ab, r
}
@ -50,7 +53,7 @@ func TestCurrentUserID(t *testing.T) {
ab, r := testSetupContext()
id, err := ab.CurrentUserID(nil, r)
id, err := ab.CurrentUserID(r)
if err != nil {
t.Error(err)
}
@ -65,7 +68,7 @@ func TestCurrentUserIDContext(t *testing.T) {
ab, r := testSetupContext()
id, err := ab.CurrentUserID(nil, r)
id, err := ab.CurrentUserID(r)
if err != nil {
t.Error(err)
}
@ -88,7 +91,7 @@ func TestCurrentUserIDP(t *testing.T) {
}
}()
_ = ab.CurrentUserIDP(nil, httptest.NewRequest("GET", "/", nil))
_ = ab.CurrentUserIDP(httptest.NewRequest("GET", "/", nil))
}
func TestCurrentUser(t *testing.T) {
@ -96,7 +99,7 @@ func TestCurrentUser(t *testing.T) {
ab, r := testSetupContext()
user, err := ab.CurrentUser(nil, r)
user, err := ab.CurrentUser(r)
if err != nil {
t.Error(err)
}
@ -111,7 +114,7 @@ func TestCurrentUserContext(t *testing.T) {
ab, _, r := testSetupContextCached()
user, err := ab.CurrentUser(nil, r)
user, err := ab.CurrentUser(r)
if err != nil {
t.Error(err)
}
@ -132,7 +135,7 @@ func TestCurrentUserP(t *testing.T) {
}
}()
_ = ab.CurrentUserP(nil, httptest.NewRequest("GET", "/", nil))
_ = ab.CurrentUserP(httptest.NewRequest("GET", "/", nil))
}
func TestLoadCurrentUserID(t *testing.T) {
@ -140,7 +143,7 @@ func TestLoadCurrentUserID(t *testing.T) {
ab, r := testSetupContext()
id, err := ab.LoadCurrentUserID(nil, &r)
id, err := ab.LoadCurrentUserID(&r)
if err != nil {
t.Error(err)
}
@ -159,7 +162,7 @@ func TestLoadCurrentUserIDContext(t *testing.T) {
ab, _, r := testSetupContextCached()
pid, err := ab.LoadCurrentUserID(nil, &r)
pid, err := ab.LoadCurrentUserID(&r)
if err != nil {
t.Error(err)
}
@ -181,7 +184,7 @@ func TestLoadCurrentUserIDP(t *testing.T) {
}()
r := httptest.NewRequest("GET", "/", nil)
_ = ab.LoadCurrentUserIDP(nil, &r)
_ = ab.LoadCurrentUserIDP(&r)
}
func TestLoadCurrentUser(t *testing.T) {
@ -189,7 +192,7 @@ func TestLoadCurrentUser(t *testing.T) {
ab, r := testSetupContext()
user, err := ab.LoadCurrentUser(nil, &r)
user, err := ab.LoadCurrentUser(&r)
if err != nil {
t.Error(err)
}
@ -210,7 +213,7 @@ func TestLoadCurrentUserContext(t *testing.T) {
ab, wantUser, r := testSetupContextCached()
user, err := ab.LoadCurrentUser(nil, &r)
user, err := ab.LoadCurrentUser(&r)
if err != nil {
t.Error(err)
}
@ -233,7 +236,7 @@ func TestLoadCurrentUserP(t *testing.T) {
}()
r := httptest.NewRequest("GET", "/", nil)
_ = ab.LoadCurrentUserP(nil, &r)
_ = ab.LoadCurrentUserP(&r)
}
func TestCTXKeyString(t *testing.T) {

View File

@ -65,7 +65,7 @@ func newMockClientStateRW(keyValue ...string) mockClientStateReadWriter {
return mockClientStateReadWriter{state}
}
func (m mockClientStateReadWriter) ReadState(w http.ResponseWriter, r *http.Request) (ClientState, error) {
func (m mockClientStateReadWriter) ReadState(r *http.Request) (ClientState, error) {
return m.state, nil
}