1
0
mirror of https://github.com/volatiletech/authboss.git synced 2025-01-10 04:17:59 +02:00

Fix tests back up again.

- Remove a test that was obsoleted by optimizations. Not 100% sure this
  is correct, but it seems like if nothing has changed since the
  previous session/cookie read then we shouldn't need to write any new
  headers for them. This is especially true in the typical "I use
  cookies for everything" use case, but may not be true of other use
  cases... Remains to be seen. Since they're optimizations they should
  be able to removed "safely" later.
This commit is contained in:
Aaron L 2018-02-14 15:16:44 -08:00
parent f585b35cfb
commit 2137c827d3
4 changed files with 23 additions and 71 deletions

View File

@ -90,8 +90,9 @@ type ClientStateResponseWriter struct {
cookieStateEvents []ClientStateEvent
}
// ClientStateMiddleware wraps all requests with the ClientStateResponseWriter
func (a *Authboss) ClientStateMiddleware(h http.Handler) http.Handler {
// LoadClientStateMiddleware wraps all requests with the ClientStateResponseWriter
// as well as loading the current client state into the context for use.
func (a *Authboss) LoadClientStateMiddleware(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
request, err := a.LoadClientState(w, r)
if err != nil {
@ -174,7 +175,8 @@ func (c ClientStateResponseWriter) Header() http.Header {
return c.ResponseWriter.Header()
}
// Write ensures that the
// Write ensures that the client state is written before any writes
// to the body occur (before header flush to http client)
func (c *ClientStateResponseWriter) Write(b []byte) (int, error) {
if !c.hasWritten {
if err := c.putClientState(); err != nil {
@ -184,7 +186,7 @@ func (c *ClientStateResponseWriter) Write(b []byte) (int, error) {
return c.ResponseWriter.Write(b)
}
// UnderlyingResponseWriter for this isnstance
// UnderlyingResponseWriter for this instance
func (c *ClientStateResponseWriter) UnderlyingResponseWriter() http.ResponseWriter {
return c.ResponseWriter
}

View File

@ -2,7 +2,6 @@ package authboss
import (
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
@ -54,33 +53,6 @@ func TestStateResponseWriterDoubleWritePanic(t *testing.T) {
w.putClientState()
}
func TestStateResponseWriterLastSecondWriteWithPrevious(t *testing.T) {
t.Parallel()
ab := New()
ab.Storage.SessionState = newMockClientStateRW("one", "two")
ab.Storage.CookieState = newMockClientStateRW("three", "four")
r := httptest.NewRequest("GET", "/", nil)
var w http.ResponseWriter = httptest.NewRecorder()
var err error
r, err = ab.LoadClientState(w, r)
if err != nil {
t.Error(err)
}
w = ab.NewResponse(w, r)
w.WriteHeader(200)
// This is an odd test, since the mock will always overwrite the previous
// write with the cookie values. Keeping it anyway for code coverage
got := strings.TrimSpace(w.Header().Get("test_session"))
if got != `{"three":"four"}` {
t.Error("got:", got)
}
}
func TestStateResponseWriterLastSecondWriteHeader(t *testing.T) {
t.Parallel()

View File

@ -30,8 +30,8 @@ func testSetupContextCached() (*Authboss, mockUser, *http.Request) {
ab := New()
wantUser := mockUser{Email: "george-pid", Password: "unreadable"}
req := httptest.NewRequest("GET", "/", nil)
ctx := context.WithValue(req.Context(), ctxKeyPID, "george-pid")
ctx = context.WithValue(ctx, ctxKeyUser, wantUser)
ctx := context.WithValue(req.Context(), CTXKeyPID, "george-pid")
ctx = context.WithValue(ctx, CTXKeyUser, wantUser)
req = req.WithContext(ctx)
return ab, wantUser, req
@ -101,9 +101,7 @@ func TestCurrentUser(t *testing.T) {
t.Error(err)
}
if got, err := user.GetPID(context.TODO()); err != nil {
t.Error(err)
} else if got != "george-pid" {
if got := user.GetPID(context.TODO()); got != "george-pid" {
t.Error("got:", got)
}
}
@ -118,9 +116,7 @@ func TestCurrentUserContext(t *testing.T) {
t.Error(err)
}
if got, err := user.GetPID(context.TODO()); err != nil {
t.Error(err)
} else if got != "george-pid" {
if got := user.GetPID(context.TODO()); got != "george-pid" {
t.Error("got:", got)
}
}
@ -153,7 +149,7 @@ func TestLoadCurrentUserID(t *testing.T) {
t.Error("got:", id)
}
if r.Context().Value(ctxKeyPID).(string) != "george-pid" {
if r.Context().Value(CTXKeyPID).(string) != "george-pid" {
t.Error("context was not updated in local request")
}
}
@ -198,14 +194,12 @@ func TestLoadCurrentUser(t *testing.T) {
t.Error(err)
}
if got, err := user.GetPID(context.TODO()); err != nil {
t.Error(err)
} else if got != "george-pid" {
if got := user.GetPID(context.TODO()); got != "george-pid" {
t.Error("got:", got)
}
want := user.(mockUser)
got := r.Context().Value(ctxKeyUser).(mockUser)
got := r.Context().Value(CTXKeyUser).(mockUser)
if got != want {
t.Errorf("users mismatched:\nwant: %#v\ngot: %#v", want, got)
}
@ -242,10 +236,10 @@ func TestLoadCurrentUserP(t *testing.T) {
_ = ab.LoadCurrentUserP(nil, &r)
}
func TestCtxKeyString(t *testing.T) {
func TestCTXKeyString(t *testing.T) {
t.Parallel()
if got := ctxKeyPID.String(); got != "authboss ctx key pid" {
if got := CTXKeyPID.String(); got != "authboss ctx key pid" {
t.Error(got)
}
}

View File

@ -8,8 +8,6 @@ import (
"net/http"
"net/url"
"strings"
"github.com/pkg/errors"
)
type mockUser struct {
@ -29,40 +27,26 @@ func (m mockServerStorer) Load(ctx context.Context, key string) (User, error) {
}
func (m mockServerStorer) Save(ctx context.Context, user User) error {
e, err := user.GetPID(ctx)
if err != nil {
panic(err)
}
m[e] = user.(mockUser)
pid := user.GetPID(ctx)
m[pid] = user.(mockUser)
return nil
}
func (m mockUser) PutPID(ctx context.Context, email string) error {
func (m mockUser) PutPID(ctx context.Context, email string) {
m.Email = email
return nil
}
func (m mockUser) PutUsername(ctx context.Context, username string) error {
return errors.New("not impl")
}
func (m mockUser) PutPassword(ctx context.Context, password string) error {
func (m mockUser) PutPassword(ctx context.Context, password string) {
m.Password = password
return nil
}
func (m mockUser) GetPID(ctx context.Context) (email string, err error) {
return m.Email, nil
func (m mockUser) GetPID(ctx context.Context) (email string) {
return m.Email
}
func (m mockUser) GetUsername(ctx context.Context) (username string, err error) {
return "", errors.New("not impl")
}
func (m mockUser) GetPassword(ctx context.Context) (password string, err error) {
return m.Password, nil
func (m mockUser) GetPassword(ctx context.Context) (password string) {
return m.Password
}
type mockClientStateReadWriter struct {