mirror of
https://github.com/volatiletech/authboss.git
synced 2025-01-08 04:03:53 +02:00
Clean up context and client state
- Remove extraneous http.ResponseWriter from all read-only queries against the request context (for the ClientState) - Instead of using a context.Context on the ClientStateResponseWriter just store variables for the things we'd like to store, it should be less expensive and it's much easier to work with and more clear. - Save the loaded client state into both the ResponseWriter itself and the Request context, the ResponseWriter will store them simply to send them into the WriteState() method later on, the Request will store them to be able to query data.
This commit is contained in:
parent
ce2d3dac09
commit
37ace55579
@ -54,7 +54,7 @@ type ClientStateEvent struct {
|
||||
type ClientStateReadWriter interface {
|
||||
// ReadState should return a map like structure allowing it to look up
|
||||
// any values in the current session, or any cookie in the request
|
||||
ReadState(http.ResponseWriter, *http.Request) (ClientState, error)
|
||||
ReadState(*http.Request) (ClientState, error)
|
||||
// WriteState can sometimes be called with a nil ClientState in the event
|
||||
// that no ClientState was recovered from the request context.
|
||||
WriteState(http.ResponseWriter, ClientState, []ClientStateEvent) error
|
||||
@ -81,62 +81,66 @@ type ClientState interface {
|
||||
type ClientStateResponseWriter struct {
|
||||
http.ResponseWriter
|
||||
|
||||
cookieState ClientStateReadWriter
|
||||
sessionState ClientStateReadWriter
|
||||
cookieStateRW ClientStateReadWriter
|
||||
sessionStateRW ClientStateReadWriter
|
||||
|
||||
cookieState ClientState
|
||||
sessionState ClientState
|
||||
|
||||
hasWritten bool
|
||||
ctx context.Context
|
||||
sessionStateEvents []ClientStateEvent
|
||||
cookieStateEvents []ClientStateEvent
|
||||
sessionStateEvents []ClientStateEvent
|
||||
}
|
||||
|
||||
// LoadClientStateMiddleware wraps all requests with the ClientStateResponseWriter
|
||||
// as well as loading the current client state into the context for use.
|
||||
func (a *Authboss) LoadClientStateMiddleware(h http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
request, err := a.LoadClientState(w, r)
|
||||
writer := a.NewResponse(w)
|
||||
request, err := a.LoadClientState(writer, r)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("failed to load client state: %+v", err))
|
||||
}
|
||||
|
||||
writer := a.NewResponse(w, request)
|
||||
|
||||
h.ServeHTTP(writer, request)
|
||||
})
|
||||
}
|
||||
|
||||
// NewResponse wraps the ResponseWriter with a ClientStateResponseWriter
|
||||
func (a *Authboss) NewResponse(w http.ResponseWriter, r *http.Request) *ClientStateResponseWriter {
|
||||
func (a *Authboss) NewResponse(w http.ResponseWriter) *ClientStateResponseWriter {
|
||||
return &ClientStateResponseWriter{
|
||||
ResponseWriter: w,
|
||||
cookieState: a.Config.Storage.CookieState,
|
||||
sessionState: a.Config.Storage.SessionState,
|
||||
ctx: r.Context(),
|
||||
cookieStateRW: a.Config.Storage.CookieState,
|
||||
sessionStateRW: a.Config.Storage.SessionState,
|
||||
}
|
||||
}
|
||||
|
||||
// LoadClientState loads the state from sessions and cookies into the request context
|
||||
// LoadClientState loads the state from sessions and cookies into the ResponseWriter
|
||||
// for later use.
|
||||
func (a *Authboss) LoadClientState(w http.ResponseWriter, r *http.Request) (*http.Request, error) {
|
||||
if a.Storage.SessionState != nil {
|
||||
state, err := a.Storage.SessionState.ReadState(w, r)
|
||||
state, err := a.Storage.SessionState.ReadState(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if state == nil {
|
||||
return r, nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), CTXKeySessionState, state)
|
||||
r = r.WithContext(ctx)
|
||||
c := MustClientStateResponseWriter(w)
|
||||
c.sessionState = state
|
||||
r = r.WithContext(context.WithValue(r.Context(), CTXKeySessionState, state))
|
||||
}
|
||||
if a.Storage.CookieState != nil {
|
||||
state, err := a.Storage.CookieState.ReadState(w, r)
|
||||
state, err := a.Storage.CookieState.ReadState(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if state == nil {
|
||||
return r, nil
|
||||
return nil, nil
|
||||
}
|
||||
ctx := context.WithValue(r.Context(), CTXKeyCookieState, state)
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
c := MustClientStateResponseWriter(w)
|
||||
c.cookieState = state
|
||||
r = r.WithContext(context.WithValue(r.Context(), CTXKeyCookieState, state))
|
||||
}
|
||||
|
||||
return r, nil
|
||||
@ -201,28 +205,14 @@ func (c *ClientStateResponseWriter) putClientState() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if c.sessionState != nil && len(c.sessionStateEvents) > 0 {
|
||||
sessionStateIntf := c.ctx.Value(CTXKeySessionState)
|
||||
|
||||
var session ClientState
|
||||
if sessionStateIntf != nil {
|
||||
session = sessionStateIntf.(ClientState)
|
||||
}
|
||||
|
||||
err := c.sessionState.WriteState(c, session, c.sessionStateEvents)
|
||||
if c.sessionStateRW != nil && len(c.sessionStateEvents) > 0 {
|
||||
err := c.sessionStateRW.WriteState(c, c.sessionState, c.sessionStateEvents)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if c.cookieState != nil && len(c.cookieStateEvents) > 0 {
|
||||
cookieStateIntf := c.ctx.Value(CTXKeyCookieState)
|
||||
|
||||
var cookie ClientState
|
||||
if cookieStateIntf != nil {
|
||||
cookie = cookieStateIntf.(ClientState)
|
||||
}
|
||||
|
||||
err := c.cookieState.WriteState(c, cookie, c.cookieStateEvents)
|
||||
if c.cookieStateRW != nil && len(c.cookieStateEvents) > 0 {
|
||||
err := c.cookieStateRW.WriteState(c, c.cookieState, c.cookieStateEvents)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -269,7 +259,7 @@ func delState(w http.ResponseWriter, CTXKey contextKey, key string) {
|
||||
setState(w, CTXKey, ClientStateEventDel, key, "")
|
||||
}
|
||||
|
||||
func setState(w http.ResponseWriter, CTXKey contextKey, op ClientStateEventKind, key, val string) {
|
||||
func setState(w http.ResponseWriter, ctxKey contextKey, op ClientStateEventKind, key, val string) {
|
||||
csrw := MustClientStateResponseWriter(w)
|
||||
ev := ClientStateEvent{
|
||||
Kind: op,
|
||||
@ -280,7 +270,7 @@ func setState(w http.ResponseWriter, CTXKey contextKey, op ClientStateEventKind,
|
||||
ev.Value = val
|
||||
}
|
||||
|
||||
switch CTXKey {
|
||||
switch ctxKey {
|
||||
case CTXKeySessionState:
|
||||
csrw.sessionStateEvents = append(csrw.sessionStateEvents, ev)
|
||||
case CTXKeyCookieState:
|
||||
@ -288,8 +278,8 @@ func setState(w http.ResponseWriter, CTXKey contextKey, op ClientStateEventKind,
|
||||
}
|
||||
}
|
||||
|
||||
func getState(r *http.Request, CTXKey contextKey, key string) (string, bool) {
|
||||
val := r.Context().Value(CTXKey)
|
||||
func getState(r *http.Request, ctxKey contextKey, key string) (string, bool) {
|
||||
val := r.Context().Value(ctxKey)
|
||||
if val == nil {
|
||||
return "", false
|
||||
}
|
||||
|
@ -15,7 +15,7 @@ func TestStateGet(t *testing.T) {
|
||||
ab.Storage.CookieState = newMockClientStateRW("three", "four")
|
||||
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
w := ab.NewResponse(httptest.NewRecorder(), r)
|
||||
w := ab.NewResponse(httptest.NewRecorder())
|
||||
|
||||
var err error
|
||||
r, err = ab.LoadClientState(w, r)
|
||||
@ -37,8 +37,7 @@ func TestStateResponseWriterDoubleWritePanic(t *testing.T) {
|
||||
ab := New()
|
||||
ab.Storage.SessionState = newMockClientStateRW("one", "two")
|
||||
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
w := ab.NewResponse(httptest.NewRecorder(), r)
|
||||
w := ab.NewResponse(httptest.NewRecorder())
|
||||
|
||||
w.WriteHeader(200)
|
||||
// Check this doesn't panic
|
||||
@ -59,8 +58,7 @@ func TestStateResponseWriterLastSecondWriteHeader(t *testing.T) {
|
||||
ab := New()
|
||||
ab.Storage.SessionState = newMockClientStateRW()
|
||||
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
w := ab.NewResponse(httptest.NewRecorder(), r)
|
||||
w := ab.NewResponse(httptest.NewRecorder())
|
||||
|
||||
PutSession(w, "one", "two")
|
||||
|
||||
@ -77,8 +75,7 @@ func TestStateResponseWriterLastSecondWriteWrite(t *testing.T) {
|
||||
ab := New()
|
||||
ab.Storage.SessionState = newMockClientStateRW()
|
||||
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
w := ab.NewResponse(httptest.NewRecorder(), r)
|
||||
w := ab.NewResponse(httptest.NewRecorder())
|
||||
|
||||
PutSession(w, "one", "two")
|
||||
|
||||
@ -94,8 +91,7 @@ func TestStateResponseWriterEvents(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := New()
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
w := ab.NewResponse(httptest.NewRecorder(), r)
|
||||
w := ab.NewResponse(httptest.NewRecorder())
|
||||
|
||||
PutSession(w, "one", "two")
|
||||
DelSession(w, "one")
|
||||
@ -130,14 +126,14 @@ func TestFlashClearer(t *testing.T) {
|
||||
ab.Storage.SessionState = newMockClientStateRW(FlashSuccessKey, "a", FlashErrorKey, "b")
|
||||
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
w := ab.NewResponse(httptest.NewRecorder(), r)
|
||||
w := ab.NewResponse(httptest.NewRecorder())
|
||||
|
||||
if msg := FlashSuccess(w, r); msg != "" {
|
||||
t.Error("Unexpected flash success:", msg)
|
||||
t.Error("unexpected flash success:", msg)
|
||||
}
|
||||
|
||||
if msg := FlashError(w, r); msg != "" {
|
||||
t.Error("Unexpected flash error:", msg)
|
||||
t.Error("unexpected flash error:", msg)
|
||||
}
|
||||
|
||||
var err error
|
||||
|
34
context.go
34
context.go
@ -32,7 +32,7 @@ func (c contextKey) String() string {
|
||||
}
|
||||
|
||||
// CurrentUserID retrieves the current user from the session.
|
||||
func (a *Authboss) CurrentUserID(w http.ResponseWriter, r *http.Request) (string, error) {
|
||||
func (a *Authboss) CurrentUserID(r *http.Request) (string, error) {
|
||||
if pid := r.Context().Value(CTXKeyPID); pid != nil {
|
||||
return pid.(string), nil
|
||||
}
|
||||
@ -43,8 +43,8 @@ func (a *Authboss) CurrentUserID(w http.ResponseWriter, r *http.Request) (string
|
||||
|
||||
// CurrentUserIDP retrieves the current user but panics if it's not available for
|
||||
// any reason.
|
||||
func (a *Authboss) CurrentUserIDP(w http.ResponseWriter, r *http.Request) string {
|
||||
i, err := a.CurrentUserID(w, r)
|
||||
func (a *Authboss) CurrentUserIDP(r *http.Request) string {
|
||||
i, err := a.CurrentUserID(r)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
} else if len(i) == 0 {
|
||||
@ -57,12 +57,12 @@ func (a *Authboss) CurrentUserIDP(w http.ResponseWriter, r *http.Request) string
|
||||
// CurrentUser retrieves the current user from the session and the database.
|
||||
// Before the user is loaded from the database the context key is checked.
|
||||
// If the session doesn't have the user ID ErrUserNotFound will be returned.
|
||||
func (a *Authboss) CurrentUser(w http.ResponseWriter, r *http.Request) (User, error) {
|
||||
func (a *Authboss) CurrentUser(r *http.Request) (User, error) {
|
||||
if user := r.Context().Value(CTXKeyUser); user != nil {
|
||||
return user.(User), nil
|
||||
}
|
||||
|
||||
pid, err := a.CurrentUserID(w, r)
|
||||
pid, err := a.CurrentUserID(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if len(pid) == 0 {
|
||||
@ -74,8 +74,8 @@ func (a *Authboss) CurrentUser(w http.ResponseWriter, r *http.Request) (User, er
|
||||
|
||||
// CurrentUserP retrieves the current user but panics if it's not available for
|
||||
// any reason.
|
||||
func (a *Authboss) CurrentUserP(w http.ResponseWriter, r *http.Request) User {
|
||||
i, err := a.CurrentUser(w, r)
|
||||
func (a *Authboss) CurrentUserP(r *http.Request) User {
|
||||
i, err := a.CurrentUser(r)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
} else if i == nil {
|
||||
@ -91,12 +91,8 @@ func (a *Authboss) currentUser(ctx context.Context, pid string) (User, error) {
|
||||
// LoadCurrentUserID takes a pointer to a pointer to the request in order to
|
||||
// change the current method's request pointer itself to the new request that
|
||||
// contains the new context that has the pid in it.
|
||||
func (a *Authboss) LoadCurrentUserID(w http.ResponseWriter, r **http.Request) (string, error) {
|
||||
if pid := (*r).Context().Value(CTXKeyPID); pid != nil {
|
||||
return pid.(string), nil
|
||||
}
|
||||
|
||||
pid, err := a.CurrentUserID(w, *r)
|
||||
func (a *Authboss) LoadCurrentUserID(r **http.Request) (string, error) {
|
||||
pid, err := a.CurrentUserID(*r)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@ -112,8 +108,8 @@ func (a *Authboss) LoadCurrentUserID(w http.ResponseWriter, r **http.Request) (s
|
||||
}
|
||||
|
||||
// LoadCurrentUserIDP loads the current user id and panics if it's not found
|
||||
func (a *Authboss) LoadCurrentUserIDP(w http.ResponseWriter, r **http.Request) string {
|
||||
pid, err := a.LoadCurrentUserID(w, r)
|
||||
func (a *Authboss) LoadCurrentUserIDP(r **http.Request) string {
|
||||
pid, err := a.LoadCurrentUserID(r)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
} else if len(pid) == 0 {
|
||||
@ -127,12 +123,12 @@ func (a *Authboss) LoadCurrentUserIDP(w http.ResponseWriter, r **http.Request) s
|
||||
// change the current method's request pointer itself to the new request that
|
||||
// contains the new context that has the user in it. Calls LoadCurrentUserID
|
||||
// so the primary id is also put in the context.
|
||||
func (a *Authboss) LoadCurrentUser(w http.ResponseWriter, r **http.Request) (User, error) {
|
||||
func (a *Authboss) LoadCurrentUser(r **http.Request) (User, error) {
|
||||
if user := (*r).Context().Value(CTXKeyUser); user != nil {
|
||||
return user.(User), nil
|
||||
}
|
||||
|
||||
pid, err := a.LoadCurrentUserID(w, r)
|
||||
pid, err := a.LoadCurrentUserID(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -154,8 +150,8 @@ func (a *Authboss) LoadCurrentUser(w http.ResponseWriter, r **http.Request) (Use
|
||||
|
||||
// LoadCurrentUserP does the same as LoadCurrentUser but panics if
|
||||
// the current user is not found.
|
||||
func (a *Authboss) LoadCurrentUserP(w http.ResponseWriter, r **http.Request) User {
|
||||
user, err := a.LoadCurrentUser(w, r)
|
||||
func (a *Authboss) LoadCurrentUserP(r **http.Request) User {
|
||||
user, err := a.LoadCurrentUser(r)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
} else if user == nil {
|
||||
|
@ -12,6 +12,7 @@ func loadClientStateP(ab *Authboss, w http.ResponseWriter, r *http.Request) *htt
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
@ -21,7 +22,9 @@ func testSetupContext() (*Authboss, *http.Request) {
|
||||
ab.Storage.Server = mockServerStorer{
|
||||
"george-pid": mockUser{Email: "george-pid", Password: "unreadable"},
|
||||
}
|
||||
r := loadClientStateP(ab, nil, httptest.NewRequest("GET", "/", nil))
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
w := ab.NewResponse(httptest.NewRecorder())
|
||||
r = loadClientStateP(ab, w, r)
|
||||
|
||||
return ab, r
|
||||
}
|
||||
@ -50,7 +53,7 @@ func TestCurrentUserID(t *testing.T) {
|
||||
|
||||
ab, r := testSetupContext()
|
||||
|
||||
id, err := ab.CurrentUserID(nil, r)
|
||||
id, err := ab.CurrentUserID(r)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
@ -65,7 +68,7 @@ func TestCurrentUserIDContext(t *testing.T) {
|
||||
|
||||
ab, r := testSetupContext()
|
||||
|
||||
id, err := ab.CurrentUserID(nil, r)
|
||||
id, err := ab.CurrentUserID(r)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
@ -88,7 +91,7 @@ func TestCurrentUserIDP(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
_ = ab.CurrentUserIDP(nil, httptest.NewRequest("GET", "/", nil))
|
||||
_ = ab.CurrentUserIDP(httptest.NewRequest("GET", "/", nil))
|
||||
}
|
||||
|
||||
func TestCurrentUser(t *testing.T) {
|
||||
@ -96,7 +99,7 @@ func TestCurrentUser(t *testing.T) {
|
||||
|
||||
ab, r := testSetupContext()
|
||||
|
||||
user, err := ab.CurrentUser(nil, r)
|
||||
user, err := ab.CurrentUser(r)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
@ -111,7 +114,7 @@ func TestCurrentUserContext(t *testing.T) {
|
||||
|
||||
ab, _, r := testSetupContextCached()
|
||||
|
||||
user, err := ab.CurrentUser(nil, r)
|
||||
user, err := ab.CurrentUser(r)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
@ -132,7 +135,7 @@ func TestCurrentUserP(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
_ = ab.CurrentUserP(nil, httptest.NewRequest("GET", "/", nil))
|
||||
_ = ab.CurrentUserP(httptest.NewRequest("GET", "/", nil))
|
||||
}
|
||||
|
||||
func TestLoadCurrentUserID(t *testing.T) {
|
||||
@ -140,7 +143,7 @@ func TestLoadCurrentUserID(t *testing.T) {
|
||||
|
||||
ab, r := testSetupContext()
|
||||
|
||||
id, err := ab.LoadCurrentUserID(nil, &r)
|
||||
id, err := ab.LoadCurrentUserID(&r)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
@ -159,7 +162,7 @@ func TestLoadCurrentUserIDContext(t *testing.T) {
|
||||
|
||||
ab, _, r := testSetupContextCached()
|
||||
|
||||
pid, err := ab.LoadCurrentUserID(nil, &r)
|
||||
pid, err := ab.LoadCurrentUserID(&r)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
@ -181,7 +184,7 @@ func TestLoadCurrentUserIDP(t *testing.T) {
|
||||
}()
|
||||
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
_ = ab.LoadCurrentUserIDP(nil, &r)
|
||||
_ = ab.LoadCurrentUserIDP(&r)
|
||||
}
|
||||
|
||||
func TestLoadCurrentUser(t *testing.T) {
|
||||
@ -189,7 +192,7 @@ func TestLoadCurrentUser(t *testing.T) {
|
||||
|
||||
ab, r := testSetupContext()
|
||||
|
||||
user, err := ab.LoadCurrentUser(nil, &r)
|
||||
user, err := ab.LoadCurrentUser(&r)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
@ -210,7 +213,7 @@ func TestLoadCurrentUserContext(t *testing.T) {
|
||||
|
||||
ab, wantUser, r := testSetupContextCached()
|
||||
|
||||
user, err := ab.LoadCurrentUser(nil, &r)
|
||||
user, err := ab.LoadCurrentUser(&r)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
@ -233,7 +236,7 @@ func TestLoadCurrentUserP(t *testing.T) {
|
||||
}()
|
||||
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
_ = ab.LoadCurrentUserP(nil, &r)
|
||||
_ = ab.LoadCurrentUserP(&r)
|
||||
}
|
||||
|
||||
func TestCTXKeyString(t *testing.T) {
|
||||
|
@ -65,7 +65,7 @@ func newMockClientStateRW(keyValue ...string) mockClientStateReadWriter {
|
||||
return mockClientStateReadWriter{state}
|
||||
}
|
||||
|
||||
func (m mockClientStateReadWriter) ReadState(w http.ResponseWriter, r *http.Request) (ClientState, error) {
|
||||
func (m mockClientStateReadWriter) ReadState(r *http.Request) (ClientState, error) {
|
||||
return m.state, nil
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user