1
0
mirror of https://github.com/volatiletech/authboss.git synced 2025-01-24 05:17:10 +02:00

Clean up context and client state

- Remove extraneous http.ResponseWriter from all read-only queries
  against the request context (for the ClientState)
- Instead of using a context.Context on the ClientStateResponseWriter
  just store variables for the things we'd like to store, it should be
  less expensive and it's much easier to work with and more clear.
- Save the loaded client state into both the ResponseWriter itself and
  the Request context, the ResponseWriter will store them simply to send
  them into the WriteState() method later on, the Request will store
  them to be able to query data.
This commit is contained in:
Aaron L 2018-03-07 16:21:37 -08:00
parent ce2d3dac09
commit 37ace55579
5 changed files with 73 additions and 88 deletions

View File

@ -54,7 +54,7 @@ type ClientStateEvent struct {
type ClientStateReadWriter interface { type ClientStateReadWriter interface {
// ReadState should return a map like structure allowing it to look up // ReadState should return a map like structure allowing it to look up
// any values in the current session, or any cookie in the request // any values in the current session, or any cookie in the request
ReadState(http.ResponseWriter, *http.Request) (ClientState, error) ReadState(*http.Request) (ClientState, error)
// WriteState can sometimes be called with a nil ClientState in the event // WriteState can sometimes be called with a nil ClientState in the event
// that no ClientState was recovered from the request context. // that no ClientState was recovered from the request context.
WriteState(http.ResponseWriter, ClientState, []ClientStateEvent) error WriteState(http.ResponseWriter, ClientState, []ClientStateEvent) error
@ -81,62 +81,66 @@ type ClientState interface {
type ClientStateResponseWriter struct { type ClientStateResponseWriter struct {
http.ResponseWriter http.ResponseWriter
cookieState ClientStateReadWriter cookieStateRW ClientStateReadWriter
sessionState ClientStateReadWriter sessionStateRW ClientStateReadWriter
cookieState ClientState
sessionState ClientState
hasWritten bool hasWritten bool
ctx context.Context
sessionStateEvents []ClientStateEvent
cookieStateEvents []ClientStateEvent cookieStateEvents []ClientStateEvent
sessionStateEvents []ClientStateEvent
} }
// LoadClientStateMiddleware wraps all requests with the ClientStateResponseWriter // LoadClientStateMiddleware wraps all requests with the ClientStateResponseWriter
// as well as loading the current client state into the context for use. // as well as loading the current client state into the context for use.
func (a *Authboss) LoadClientStateMiddleware(h http.Handler) http.Handler { func (a *Authboss) LoadClientStateMiddleware(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
request, err := a.LoadClientState(w, r) writer := a.NewResponse(w)
request, err := a.LoadClientState(writer, r)
if err != nil { if err != nil {
panic(fmt.Sprintf("failed to load client state: %+v", err)) panic(fmt.Sprintf("failed to load client state: %+v", err))
} }
writer := a.NewResponse(w, request)
h.ServeHTTP(writer, request) h.ServeHTTP(writer, request)
}) })
} }
// NewResponse wraps the ResponseWriter with a ClientStateResponseWriter // NewResponse wraps the ResponseWriter with a ClientStateResponseWriter
func (a *Authboss) NewResponse(w http.ResponseWriter, r *http.Request) *ClientStateResponseWriter { func (a *Authboss) NewResponse(w http.ResponseWriter) *ClientStateResponseWriter {
return &ClientStateResponseWriter{ return &ClientStateResponseWriter{
ResponseWriter: w, ResponseWriter: w,
cookieState: a.Config.Storage.CookieState, cookieStateRW: a.Config.Storage.CookieState,
sessionState: a.Config.Storage.SessionState, sessionStateRW: a.Config.Storage.SessionState,
ctx: r.Context(),
} }
} }
// LoadClientState loads the state from sessions and cookies into the request context // LoadClientState loads the state from sessions and cookies into the ResponseWriter
// for later use.
func (a *Authboss) LoadClientState(w http.ResponseWriter, r *http.Request) (*http.Request, error) { func (a *Authboss) LoadClientState(w http.ResponseWriter, r *http.Request) (*http.Request, error) {
if a.Storage.SessionState != nil { if a.Storage.SessionState != nil {
state, err := a.Storage.SessionState.ReadState(w, r) state, err := a.Storage.SessionState.ReadState(r)
if err != nil { if err != nil {
return nil, err return nil, err
} else if state == nil { } else if state == nil {
return r, nil return nil, nil
} }
ctx := context.WithValue(r.Context(), CTXKeySessionState, state) c := MustClientStateResponseWriter(w)
r = r.WithContext(ctx) c.sessionState = state
r = r.WithContext(context.WithValue(r.Context(), CTXKeySessionState, state))
} }
if a.Storage.CookieState != nil { if a.Storage.CookieState != nil {
state, err := a.Storage.CookieState.ReadState(w, r) state, err := a.Storage.CookieState.ReadState(r)
if err != nil { if err != nil {
return nil, err return nil, err
} else if state == nil { } else if state == nil {
return r, nil return nil, nil
} }
ctx := context.WithValue(r.Context(), CTXKeyCookieState, state)
r = r.WithContext(ctx) c := MustClientStateResponseWriter(w)
c.cookieState = state
r = r.WithContext(context.WithValue(r.Context(), CTXKeyCookieState, state))
} }
return r, nil return r, nil
@ -201,28 +205,14 @@ func (c *ClientStateResponseWriter) putClientState() error {
return nil return nil
} }
if c.sessionState != nil && len(c.sessionStateEvents) > 0 { if c.sessionStateRW != nil && len(c.sessionStateEvents) > 0 {
sessionStateIntf := c.ctx.Value(CTXKeySessionState) err := c.sessionStateRW.WriteState(c, c.sessionState, c.sessionStateEvents)
var session ClientState
if sessionStateIntf != nil {
session = sessionStateIntf.(ClientState)
}
err := c.sessionState.WriteState(c, session, c.sessionStateEvents)
if err != nil { if err != nil {
return err return err
} }
} }
if c.cookieState != nil && len(c.cookieStateEvents) > 0 { if c.cookieStateRW != nil && len(c.cookieStateEvents) > 0 {
cookieStateIntf := c.ctx.Value(CTXKeyCookieState) err := c.cookieStateRW.WriteState(c, c.cookieState, c.cookieStateEvents)
var cookie ClientState
if cookieStateIntf != nil {
cookie = cookieStateIntf.(ClientState)
}
err := c.cookieState.WriteState(c, cookie, c.cookieStateEvents)
if err != nil { if err != nil {
return err return err
} }
@ -269,7 +259,7 @@ func delState(w http.ResponseWriter, CTXKey contextKey, key string) {
setState(w, CTXKey, ClientStateEventDel, key, "") setState(w, CTXKey, ClientStateEventDel, key, "")
} }
func setState(w http.ResponseWriter, CTXKey contextKey, op ClientStateEventKind, key, val string) { func setState(w http.ResponseWriter, ctxKey contextKey, op ClientStateEventKind, key, val string) {
csrw := MustClientStateResponseWriter(w) csrw := MustClientStateResponseWriter(w)
ev := ClientStateEvent{ ev := ClientStateEvent{
Kind: op, Kind: op,
@ -280,7 +270,7 @@ func setState(w http.ResponseWriter, CTXKey contextKey, op ClientStateEventKind,
ev.Value = val ev.Value = val
} }
switch CTXKey { switch ctxKey {
case CTXKeySessionState: case CTXKeySessionState:
csrw.sessionStateEvents = append(csrw.sessionStateEvents, ev) csrw.sessionStateEvents = append(csrw.sessionStateEvents, ev)
case CTXKeyCookieState: case CTXKeyCookieState:
@ -288,8 +278,8 @@ func setState(w http.ResponseWriter, CTXKey contextKey, op ClientStateEventKind,
} }
} }
func getState(r *http.Request, CTXKey contextKey, key string) (string, bool) { func getState(r *http.Request, ctxKey contextKey, key string) (string, bool) {
val := r.Context().Value(CTXKey) val := r.Context().Value(ctxKey)
if val == nil { if val == nil {
return "", false return "", false
} }

View File

@ -15,7 +15,7 @@ func TestStateGet(t *testing.T) {
ab.Storage.CookieState = newMockClientStateRW("three", "four") ab.Storage.CookieState = newMockClientStateRW("three", "four")
r := httptest.NewRequest("GET", "/", nil) r := httptest.NewRequest("GET", "/", nil)
w := ab.NewResponse(httptest.NewRecorder(), r) w := ab.NewResponse(httptest.NewRecorder())
var err error var err error
r, err = ab.LoadClientState(w, r) r, err = ab.LoadClientState(w, r)
@ -37,8 +37,7 @@ func TestStateResponseWriterDoubleWritePanic(t *testing.T) {
ab := New() ab := New()
ab.Storage.SessionState = newMockClientStateRW("one", "two") ab.Storage.SessionState = newMockClientStateRW("one", "two")
r := httptest.NewRequest("GET", "/", nil) w := ab.NewResponse(httptest.NewRecorder())
w := ab.NewResponse(httptest.NewRecorder(), r)
w.WriteHeader(200) w.WriteHeader(200)
// Check this doesn't panic // Check this doesn't panic
@ -59,8 +58,7 @@ func TestStateResponseWriterLastSecondWriteHeader(t *testing.T) {
ab := New() ab := New()
ab.Storage.SessionState = newMockClientStateRW() ab.Storage.SessionState = newMockClientStateRW()
r := httptest.NewRequest("GET", "/", nil) w := ab.NewResponse(httptest.NewRecorder())
w := ab.NewResponse(httptest.NewRecorder(), r)
PutSession(w, "one", "two") PutSession(w, "one", "two")
@ -77,8 +75,7 @@ func TestStateResponseWriterLastSecondWriteWrite(t *testing.T) {
ab := New() ab := New()
ab.Storage.SessionState = newMockClientStateRW() ab.Storage.SessionState = newMockClientStateRW()
r := httptest.NewRequest("GET", "/", nil) w := ab.NewResponse(httptest.NewRecorder())
w := ab.NewResponse(httptest.NewRecorder(), r)
PutSession(w, "one", "two") PutSession(w, "one", "two")
@ -94,8 +91,7 @@ func TestStateResponseWriterEvents(t *testing.T) {
t.Parallel() t.Parallel()
ab := New() ab := New()
r := httptest.NewRequest("GET", "/", nil) w := ab.NewResponse(httptest.NewRecorder())
w := ab.NewResponse(httptest.NewRecorder(), r)
PutSession(w, "one", "two") PutSession(w, "one", "two")
DelSession(w, "one") DelSession(w, "one")
@ -130,14 +126,14 @@ func TestFlashClearer(t *testing.T) {
ab.Storage.SessionState = newMockClientStateRW(FlashSuccessKey, "a", FlashErrorKey, "b") ab.Storage.SessionState = newMockClientStateRW(FlashSuccessKey, "a", FlashErrorKey, "b")
r := httptest.NewRequest("GET", "/", nil) r := httptest.NewRequest("GET", "/", nil)
w := ab.NewResponse(httptest.NewRecorder(), r) w := ab.NewResponse(httptest.NewRecorder())
if msg := FlashSuccess(w, r); msg != "" { if msg := FlashSuccess(w, r); msg != "" {
t.Error("Unexpected flash success:", msg) t.Error("unexpected flash success:", msg)
} }
if msg := FlashError(w, r); msg != "" { if msg := FlashError(w, r); msg != "" {
t.Error("Unexpected flash error:", msg) t.Error("unexpected flash error:", msg)
} }
var err error var err error

View File

@ -32,7 +32,7 @@ func (c contextKey) String() string {
} }
// CurrentUserID retrieves the current user from the session. // CurrentUserID retrieves the current user from the session.
func (a *Authboss) CurrentUserID(w http.ResponseWriter, r *http.Request) (string, error) { func (a *Authboss) CurrentUserID(r *http.Request) (string, error) {
if pid := r.Context().Value(CTXKeyPID); pid != nil { if pid := r.Context().Value(CTXKeyPID); pid != nil {
return pid.(string), nil return pid.(string), nil
} }
@ -43,8 +43,8 @@ func (a *Authboss) CurrentUserID(w http.ResponseWriter, r *http.Request) (string
// CurrentUserIDP retrieves the current user but panics if it's not available for // CurrentUserIDP retrieves the current user but panics if it's not available for
// any reason. // any reason.
func (a *Authboss) CurrentUserIDP(w http.ResponseWriter, r *http.Request) string { func (a *Authboss) CurrentUserIDP(r *http.Request) string {
i, err := a.CurrentUserID(w, r) i, err := a.CurrentUserID(r)
if err != nil { if err != nil {
panic(err) panic(err)
} else if len(i) == 0 { } else if len(i) == 0 {
@ -57,12 +57,12 @@ func (a *Authboss) CurrentUserIDP(w http.ResponseWriter, r *http.Request) string
// CurrentUser retrieves the current user from the session and the database. // CurrentUser retrieves the current user from the session and the database.
// Before the user is loaded from the database the context key is checked. // Before the user is loaded from the database the context key is checked.
// If the session doesn't have the user ID ErrUserNotFound will be returned. // If the session doesn't have the user ID ErrUserNotFound will be returned.
func (a *Authboss) CurrentUser(w http.ResponseWriter, r *http.Request) (User, error) { func (a *Authboss) CurrentUser(r *http.Request) (User, error) {
if user := r.Context().Value(CTXKeyUser); user != nil { if user := r.Context().Value(CTXKeyUser); user != nil {
return user.(User), nil return user.(User), nil
} }
pid, err := a.CurrentUserID(w, r) pid, err := a.CurrentUserID(r)
if err != nil { if err != nil {
return nil, err return nil, err
} else if len(pid) == 0 { } else if len(pid) == 0 {
@ -74,8 +74,8 @@ func (a *Authboss) CurrentUser(w http.ResponseWriter, r *http.Request) (User, er
// CurrentUserP retrieves the current user but panics if it's not available for // CurrentUserP retrieves the current user but panics if it's not available for
// any reason. // any reason.
func (a *Authboss) CurrentUserP(w http.ResponseWriter, r *http.Request) User { func (a *Authboss) CurrentUserP(r *http.Request) User {
i, err := a.CurrentUser(w, r) i, err := a.CurrentUser(r)
if err != nil { if err != nil {
panic(err) panic(err)
} else if i == nil { } else if i == nil {
@ -91,12 +91,8 @@ func (a *Authboss) currentUser(ctx context.Context, pid string) (User, error) {
// LoadCurrentUserID takes a pointer to a pointer to the request in order to // LoadCurrentUserID takes a pointer to a pointer to the request in order to
// change the current method's request pointer itself to the new request that // change the current method's request pointer itself to the new request that
// contains the new context that has the pid in it. // contains the new context that has the pid in it.
func (a *Authboss) LoadCurrentUserID(w http.ResponseWriter, r **http.Request) (string, error) { func (a *Authboss) LoadCurrentUserID(r **http.Request) (string, error) {
if pid := (*r).Context().Value(CTXKeyPID); pid != nil { pid, err := a.CurrentUserID(*r)
return pid.(string), nil
}
pid, err := a.CurrentUserID(w, *r)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -112,8 +108,8 @@ func (a *Authboss) LoadCurrentUserID(w http.ResponseWriter, r **http.Request) (s
} }
// LoadCurrentUserIDP loads the current user id and panics if it's not found // LoadCurrentUserIDP loads the current user id and panics if it's not found
func (a *Authboss) LoadCurrentUserIDP(w http.ResponseWriter, r **http.Request) string { func (a *Authboss) LoadCurrentUserIDP(r **http.Request) string {
pid, err := a.LoadCurrentUserID(w, r) pid, err := a.LoadCurrentUserID(r)
if err != nil { if err != nil {
panic(err) panic(err)
} else if len(pid) == 0 { } else if len(pid) == 0 {
@ -127,12 +123,12 @@ func (a *Authboss) LoadCurrentUserIDP(w http.ResponseWriter, r **http.Request) s
// change the current method's request pointer itself to the new request that // change the current method's request pointer itself to the new request that
// contains the new context that has the user in it. Calls LoadCurrentUserID // contains the new context that has the user in it. Calls LoadCurrentUserID
// so the primary id is also put in the context. // so the primary id is also put in the context.
func (a *Authboss) LoadCurrentUser(w http.ResponseWriter, r **http.Request) (User, error) { func (a *Authboss) LoadCurrentUser(r **http.Request) (User, error) {
if user := (*r).Context().Value(CTXKeyUser); user != nil { if user := (*r).Context().Value(CTXKeyUser); user != nil {
return user.(User), nil return user.(User), nil
} }
pid, err := a.LoadCurrentUserID(w, r) pid, err := a.LoadCurrentUserID(r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -154,8 +150,8 @@ func (a *Authboss) LoadCurrentUser(w http.ResponseWriter, r **http.Request) (Use
// LoadCurrentUserP does the same as LoadCurrentUser but panics if // LoadCurrentUserP does the same as LoadCurrentUser but panics if
// the current user is not found. // the current user is not found.
func (a *Authboss) LoadCurrentUserP(w http.ResponseWriter, r **http.Request) User { func (a *Authboss) LoadCurrentUserP(r **http.Request) User {
user, err := a.LoadCurrentUser(w, r) user, err := a.LoadCurrentUser(r)
if err != nil { if err != nil {
panic(err) panic(err)
} else if user == nil { } else if user == nil {

View File

@ -12,6 +12,7 @@ func loadClientStateP(ab *Authboss, w http.ResponseWriter, r *http.Request) *htt
if err != nil { if err != nil {
panic(err) panic(err)
} }
return r return r
} }
@ -21,7 +22,9 @@ func testSetupContext() (*Authboss, *http.Request) {
ab.Storage.Server = mockServerStorer{ ab.Storage.Server = mockServerStorer{
"george-pid": mockUser{Email: "george-pid", Password: "unreadable"}, "george-pid": mockUser{Email: "george-pid", Password: "unreadable"},
} }
r := loadClientStateP(ab, nil, httptest.NewRequest("GET", "/", nil)) r := httptest.NewRequest("GET", "/", nil)
w := ab.NewResponse(httptest.NewRecorder())
r = loadClientStateP(ab, w, r)
return ab, r return ab, r
} }
@ -50,7 +53,7 @@ func TestCurrentUserID(t *testing.T) {
ab, r := testSetupContext() ab, r := testSetupContext()
id, err := ab.CurrentUserID(nil, r) id, err := ab.CurrentUserID(r)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@ -65,7 +68,7 @@ func TestCurrentUserIDContext(t *testing.T) {
ab, r := testSetupContext() ab, r := testSetupContext()
id, err := ab.CurrentUserID(nil, r) id, err := ab.CurrentUserID(r)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@ -88,7 +91,7 @@ func TestCurrentUserIDP(t *testing.T) {
} }
}() }()
_ = ab.CurrentUserIDP(nil, httptest.NewRequest("GET", "/", nil)) _ = ab.CurrentUserIDP(httptest.NewRequest("GET", "/", nil))
} }
func TestCurrentUser(t *testing.T) { func TestCurrentUser(t *testing.T) {
@ -96,7 +99,7 @@ func TestCurrentUser(t *testing.T) {
ab, r := testSetupContext() ab, r := testSetupContext()
user, err := ab.CurrentUser(nil, r) user, err := ab.CurrentUser(r)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@ -111,7 +114,7 @@ func TestCurrentUserContext(t *testing.T) {
ab, _, r := testSetupContextCached() ab, _, r := testSetupContextCached()
user, err := ab.CurrentUser(nil, r) user, err := ab.CurrentUser(r)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@ -132,7 +135,7 @@ func TestCurrentUserP(t *testing.T) {
} }
}() }()
_ = ab.CurrentUserP(nil, httptest.NewRequest("GET", "/", nil)) _ = ab.CurrentUserP(httptest.NewRequest("GET", "/", nil))
} }
func TestLoadCurrentUserID(t *testing.T) { func TestLoadCurrentUserID(t *testing.T) {
@ -140,7 +143,7 @@ func TestLoadCurrentUserID(t *testing.T) {
ab, r := testSetupContext() ab, r := testSetupContext()
id, err := ab.LoadCurrentUserID(nil, &r) id, err := ab.LoadCurrentUserID(&r)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@ -159,7 +162,7 @@ func TestLoadCurrentUserIDContext(t *testing.T) {
ab, _, r := testSetupContextCached() ab, _, r := testSetupContextCached()
pid, err := ab.LoadCurrentUserID(nil, &r) pid, err := ab.LoadCurrentUserID(&r)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@ -181,7 +184,7 @@ func TestLoadCurrentUserIDP(t *testing.T) {
}() }()
r := httptest.NewRequest("GET", "/", nil) r := httptest.NewRequest("GET", "/", nil)
_ = ab.LoadCurrentUserIDP(nil, &r) _ = ab.LoadCurrentUserIDP(&r)
} }
func TestLoadCurrentUser(t *testing.T) { func TestLoadCurrentUser(t *testing.T) {
@ -189,7 +192,7 @@ func TestLoadCurrentUser(t *testing.T) {
ab, r := testSetupContext() ab, r := testSetupContext()
user, err := ab.LoadCurrentUser(nil, &r) user, err := ab.LoadCurrentUser(&r)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@ -210,7 +213,7 @@ func TestLoadCurrentUserContext(t *testing.T) {
ab, wantUser, r := testSetupContextCached() ab, wantUser, r := testSetupContextCached()
user, err := ab.LoadCurrentUser(nil, &r) user, err := ab.LoadCurrentUser(&r)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@ -233,7 +236,7 @@ func TestLoadCurrentUserP(t *testing.T) {
}() }()
r := httptest.NewRequest("GET", "/", nil) r := httptest.NewRequest("GET", "/", nil)
_ = ab.LoadCurrentUserP(nil, &r) _ = ab.LoadCurrentUserP(&r)
} }
func TestCTXKeyString(t *testing.T) { func TestCTXKeyString(t *testing.T) {

View File

@ -65,7 +65,7 @@ func newMockClientStateRW(keyValue ...string) mockClientStateReadWriter {
return mockClientStateReadWriter{state} return mockClientStateReadWriter{state}
} }
func (m mockClientStateReadWriter) ReadState(w http.ResponseWriter, r *http.Request) (ClientState, error) { func (m mockClientStateReadWriter) ReadState(r *http.Request) (ClientState, error) {
return m.state, nil return m.state, nil
} }