1
0
mirror of https://github.com/volatiletech/authboss.git synced 2025-01-24 05:17:10 +02:00

Fix tests back up again.

- Remove a test that was obsoleted by optimizations. Not 100% sure this
  is correct, but it seems like if nothing has changed since the
  previous session/cookie read then we shouldn't need to write any new
  headers for them. This is especially true in the typical "I use
  cookies for everything" use case, but may not be true of other use
  cases... Remains to be seen. Since they're optimizations they should
  be able to removed "safely" later.
This commit is contained in:
Aaron L 2018-02-14 15:16:44 -08:00
parent f585b35cfb
commit 2137c827d3
4 changed files with 23 additions and 71 deletions

View File

@ -90,8 +90,9 @@ type ClientStateResponseWriter struct {
cookieStateEvents []ClientStateEvent cookieStateEvents []ClientStateEvent
} }
// ClientStateMiddleware wraps all requests with the ClientStateResponseWriter // LoadClientStateMiddleware wraps all requests with the ClientStateResponseWriter
func (a *Authboss) ClientStateMiddleware(h http.Handler) http.Handler { // as well as loading the current client state into the context for use.
func (a *Authboss) LoadClientStateMiddleware(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
request, err := a.LoadClientState(w, r) request, err := a.LoadClientState(w, r)
if err != nil { if err != nil {
@ -174,7 +175,8 @@ func (c ClientStateResponseWriter) Header() http.Header {
return c.ResponseWriter.Header() return c.ResponseWriter.Header()
} }
// Write ensures that the // Write ensures that the client state is written before any writes
// to the body occur (before header flush to http client)
func (c *ClientStateResponseWriter) Write(b []byte) (int, error) { func (c *ClientStateResponseWriter) Write(b []byte) (int, error) {
if !c.hasWritten { if !c.hasWritten {
if err := c.putClientState(); err != nil { if err := c.putClientState(); err != nil {
@ -184,7 +186,7 @@ func (c *ClientStateResponseWriter) Write(b []byte) (int, error) {
return c.ResponseWriter.Write(b) return c.ResponseWriter.Write(b)
} }
// UnderlyingResponseWriter for this isnstance // UnderlyingResponseWriter for this instance
func (c *ClientStateResponseWriter) UnderlyingResponseWriter() http.ResponseWriter { func (c *ClientStateResponseWriter) UnderlyingResponseWriter() http.ResponseWriter {
return c.ResponseWriter return c.ResponseWriter
} }

View File

@ -2,7 +2,6 @@ package authboss
import ( import (
"io" "io"
"net/http"
"net/http/httptest" "net/http/httptest"
"strings" "strings"
"testing" "testing"
@ -54,33 +53,6 @@ func TestStateResponseWriterDoubleWritePanic(t *testing.T) {
w.putClientState() w.putClientState()
} }
func TestStateResponseWriterLastSecondWriteWithPrevious(t *testing.T) {
t.Parallel()
ab := New()
ab.Storage.SessionState = newMockClientStateRW("one", "two")
ab.Storage.CookieState = newMockClientStateRW("three", "four")
r := httptest.NewRequest("GET", "/", nil)
var w http.ResponseWriter = httptest.NewRecorder()
var err error
r, err = ab.LoadClientState(w, r)
if err != nil {
t.Error(err)
}
w = ab.NewResponse(w, r)
w.WriteHeader(200)
// This is an odd test, since the mock will always overwrite the previous
// write with the cookie values. Keeping it anyway for code coverage
got := strings.TrimSpace(w.Header().Get("test_session"))
if got != `{"three":"four"}` {
t.Error("got:", got)
}
}
func TestStateResponseWriterLastSecondWriteHeader(t *testing.T) { func TestStateResponseWriterLastSecondWriteHeader(t *testing.T) {
t.Parallel() t.Parallel()

View File

@ -30,8 +30,8 @@ func testSetupContextCached() (*Authboss, mockUser, *http.Request) {
ab := New() ab := New()
wantUser := mockUser{Email: "george-pid", Password: "unreadable"} wantUser := mockUser{Email: "george-pid", Password: "unreadable"}
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest("GET", "/", nil)
ctx := context.WithValue(req.Context(), ctxKeyPID, "george-pid") ctx := context.WithValue(req.Context(), CTXKeyPID, "george-pid")
ctx = context.WithValue(ctx, ctxKeyUser, wantUser) ctx = context.WithValue(ctx, CTXKeyUser, wantUser)
req = req.WithContext(ctx) req = req.WithContext(ctx)
return ab, wantUser, req return ab, wantUser, req
@ -101,9 +101,7 @@ func TestCurrentUser(t *testing.T) {
t.Error(err) t.Error(err)
} }
if got, err := user.GetPID(context.TODO()); err != nil { if got := user.GetPID(context.TODO()); got != "george-pid" {
t.Error(err)
} else if got != "george-pid" {
t.Error("got:", got) t.Error("got:", got)
} }
} }
@ -118,9 +116,7 @@ func TestCurrentUserContext(t *testing.T) {
t.Error(err) t.Error(err)
} }
if got, err := user.GetPID(context.TODO()); err != nil { if got := user.GetPID(context.TODO()); got != "george-pid" {
t.Error(err)
} else if got != "george-pid" {
t.Error("got:", got) t.Error("got:", got)
} }
} }
@ -153,7 +149,7 @@ func TestLoadCurrentUserID(t *testing.T) {
t.Error("got:", id) t.Error("got:", id)
} }
if r.Context().Value(ctxKeyPID).(string) != "george-pid" { if r.Context().Value(CTXKeyPID).(string) != "george-pid" {
t.Error("context was not updated in local request") t.Error("context was not updated in local request")
} }
} }
@ -198,14 +194,12 @@ func TestLoadCurrentUser(t *testing.T) {
t.Error(err) t.Error(err)
} }
if got, err := user.GetPID(context.TODO()); err != nil { if got := user.GetPID(context.TODO()); got != "george-pid" {
t.Error(err)
} else if got != "george-pid" {
t.Error("got:", got) t.Error("got:", got)
} }
want := user.(mockUser) want := user.(mockUser)
got := r.Context().Value(ctxKeyUser).(mockUser) got := r.Context().Value(CTXKeyUser).(mockUser)
if got != want { if got != want {
t.Errorf("users mismatched:\nwant: %#v\ngot: %#v", want, got) t.Errorf("users mismatched:\nwant: %#v\ngot: %#v", want, got)
} }
@ -242,10 +236,10 @@ func TestLoadCurrentUserP(t *testing.T) {
_ = ab.LoadCurrentUserP(nil, &r) _ = ab.LoadCurrentUserP(nil, &r)
} }
func TestCtxKeyString(t *testing.T) { func TestCTXKeyString(t *testing.T) {
t.Parallel() t.Parallel()
if got := ctxKeyPID.String(); got != "authboss ctx key pid" { if got := CTXKeyPID.String(); got != "authboss ctx key pid" {
t.Error(got) t.Error(got)
} }
} }

View File

@ -8,8 +8,6 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
"github.com/pkg/errors"
) )
type mockUser struct { type mockUser struct {
@ -29,40 +27,26 @@ func (m mockServerStorer) Load(ctx context.Context, key string) (User, error) {
} }
func (m mockServerStorer) Save(ctx context.Context, user User) error { func (m mockServerStorer) Save(ctx context.Context, user User) error {
e, err := user.GetPID(ctx) pid := user.GetPID(ctx)
if err != nil { m[pid] = user.(mockUser)
panic(err)
}
m[e] = user.(mockUser)
return nil return nil
} }
func (m mockUser) PutPID(ctx context.Context, email string) error { func (m mockUser) PutPID(ctx context.Context, email string) {
m.Email = email m.Email = email
return nil
} }
func (m mockUser) PutUsername(ctx context.Context, username string) error { func (m mockUser) PutPassword(ctx context.Context, password string) {
return errors.New("not impl")
}
func (m mockUser) PutPassword(ctx context.Context, password string) error {
m.Password = password m.Password = password
return nil
} }
func (m mockUser) GetPID(ctx context.Context) (email string, err error) { func (m mockUser) GetPID(ctx context.Context) (email string) {
return m.Email, nil return m.Email
} }
func (m mockUser) GetUsername(ctx context.Context) (username string, err error) { func (m mockUser) GetPassword(ctx context.Context) (password string) {
return "", errors.New("not impl") return m.Password
}
func (m mockUser) GetPassword(ctx context.Context) (password string, err error) {
return m.Password, nil
} }
type mockClientStateReadWriter struct { type mockClientStateReadWriter struct {