mirror of
https://github.com/volatiletech/authboss.git
synced 2025-01-24 05:17:10 +02:00
Fix tests back up again.
- Remove a test that was obsoleted by optimizations. Not 100% sure this is correct, but it seems like if nothing has changed since the previous session/cookie read then we shouldn't need to write any new headers for them. This is especially true in the typical "I use cookies for everything" use case, but may not be true of other use cases... Remains to be seen. Since they're optimizations they should be able to removed "safely" later.
This commit is contained in:
parent
f585b35cfb
commit
2137c827d3
@ -90,8 +90,9 @@ type ClientStateResponseWriter struct {
|
|||||||
cookieStateEvents []ClientStateEvent
|
cookieStateEvents []ClientStateEvent
|
||||||
}
|
}
|
||||||
|
|
||||||
// ClientStateMiddleware wraps all requests with the ClientStateResponseWriter
|
// LoadClientStateMiddleware wraps all requests with the ClientStateResponseWriter
|
||||||
func (a *Authboss) ClientStateMiddleware(h http.Handler) http.Handler {
|
// as well as loading the current client state into the context for use.
|
||||||
|
func (a *Authboss) LoadClientStateMiddleware(h http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
request, err := a.LoadClientState(w, r)
|
request, err := a.LoadClientState(w, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -174,7 +175,8 @@ func (c ClientStateResponseWriter) Header() http.Header {
|
|||||||
return c.ResponseWriter.Header()
|
return c.ResponseWriter.Header()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write ensures that the
|
// Write ensures that the client state is written before any writes
|
||||||
|
// to the body occur (before header flush to http client)
|
||||||
func (c *ClientStateResponseWriter) Write(b []byte) (int, error) {
|
func (c *ClientStateResponseWriter) Write(b []byte) (int, error) {
|
||||||
if !c.hasWritten {
|
if !c.hasWritten {
|
||||||
if err := c.putClientState(); err != nil {
|
if err := c.putClientState(); err != nil {
|
||||||
@ -184,7 +186,7 @@ func (c *ClientStateResponseWriter) Write(b []byte) (int, error) {
|
|||||||
return c.ResponseWriter.Write(b)
|
return c.ResponseWriter.Write(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UnderlyingResponseWriter for this isnstance
|
// UnderlyingResponseWriter for this instance
|
||||||
func (c *ClientStateResponseWriter) UnderlyingResponseWriter() http.ResponseWriter {
|
func (c *ClientStateResponseWriter) UnderlyingResponseWriter() http.ResponseWriter {
|
||||||
return c.ResponseWriter
|
return c.ResponseWriter
|
||||||
}
|
}
|
||||||
|
@ -2,7 +2,6 @@ package authboss
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
@ -54,33 +53,6 @@ func TestStateResponseWriterDoubleWritePanic(t *testing.T) {
|
|||||||
w.putClientState()
|
w.putClientState()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestStateResponseWriterLastSecondWriteWithPrevious(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
ab := New()
|
|
||||||
ab.Storage.SessionState = newMockClientStateRW("one", "two")
|
|
||||||
ab.Storage.CookieState = newMockClientStateRW("three", "four")
|
|
||||||
|
|
||||||
r := httptest.NewRequest("GET", "/", nil)
|
|
||||||
var w http.ResponseWriter = httptest.NewRecorder()
|
|
||||||
|
|
||||||
var err error
|
|
||||||
r, err = ab.LoadClientState(w, r)
|
|
||||||
if err != nil {
|
|
||||||
t.Error(err)
|
|
||||||
}
|
|
||||||
w = ab.NewResponse(w, r)
|
|
||||||
|
|
||||||
w.WriteHeader(200)
|
|
||||||
|
|
||||||
// This is an odd test, since the mock will always overwrite the previous
|
|
||||||
// write with the cookie values. Keeping it anyway for code coverage
|
|
||||||
got := strings.TrimSpace(w.Header().Get("test_session"))
|
|
||||||
if got != `{"three":"four"}` {
|
|
||||||
t.Error("got:", got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStateResponseWriterLastSecondWriteHeader(t *testing.T) {
|
func TestStateResponseWriterLastSecondWriteHeader(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
@ -30,8 +30,8 @@ func testSetupContextCached() (*Authboss, mockUser, *http.Request) {
|
|||||||
ab := New()
|
ab := New()
|
||||||
wantUser := mockUser{Email: "george-pid", Password: "unreadable"}
|
wantUser := mockUser{Email: "george-pid", Password: "unreadable"}
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
ctx := context.WithValue(req.Context(), ctxKeyPID, "george-pid")
|
ctx := context.WithValue(req.Context(), CTXKeyPID, "george-pid")
|
||||||
ctx = context.WithValue(ctx, ctxKeyUser, wantUser)
|
ctx = context.WithValue(ctx, CTXKeyUser, wantUser)
|
||||||
req = req.WithContext(ctx)
|
req = req.WithContext(ctx)
|
||||||
|
|
||||||
return ab, wantUser, req
|
return ab, wantUser, req
|
||||||
@ -101,9 +101,7 @@ func TestCurrentUser(t *testing.T) {
|
|||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if got, err := user.GetPID(context.TODO()); err != nil {
|
if got := user.GetPID(context.TODO()); got != "george-pid" {
|
||||||
t.Error(err)
|
|
||||||
} else if got != "george-pid" {
|
|
||||||
t.Error("got:", got)
|
t.Error("got:", got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -118,9 +116,7 @@ func TestCurrentUserContext(t *testing.T) {
|
|||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if got, err := user.GetPID(context.TODO()); err != nil {
|
if got := user.GetPID(context.TODO()); got != "george-pid" {
|
||||||
t.Error(err)
|
|
||||||
} else if got != "george-pid" {
|
|
||||||
t.Error("got:", got)
|
t.Error("got:", got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -153,7 +149,7 @@ func TestLoadCurrentUserID(t *testing.T) {
|
|||||||
t.Error("got:", id)
|
t.Error("got:", id)
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.Context().Value(ctxKeyPID).(string) != "george-pid" {
|
if r.Context().Value(CTXKeyPID).(string) != "george-pid" {
|
||||||
t.Error("context was not updated in local request")
|
t.Error("context was not updated in local request")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -198,14 +194,12 @@ func TestLoadCurrentUser(t *testing.T) {
|
|||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if got, err := user.GetPID(context.TODO()); err != nil {
|
if got := user.GetPID(context.TODO()); got != "george-pid" {
|
||||||
t.Error(err)
|
|
||||||
} else if got != "george-pid" {
|
|
||||||
t.Error("got:", got)
|
t.Error("got:", got)
|
||||||
}
|
}
|
||||||
|
|
||||||
want := user.(mockUser)
|
want := user.(mockUser)
|
||||||
got := r.Context().Value(ctxKeyUser).(mockUser)
|
got := r.Context().Value(CTXKeyUser).(mockUser)
|
||||||
if got != want {
|
if got != want {
|
||||||
t.Errorf("users mismatched:\nwant: %#v\ngot: %#v", want, got)
|
t.Errorf("users mismatched:\nwant: %#v\ngot: %#v", want, got)
|
||||||
}
|
}
|
||||||
@ -242,10 +236,10 @@ func TestLoadCurrentUserP(t *testing.T) {
|
|||||||
_ = ab.LoadCurrentUserP(nil, &r)
|
_ = ab.LoadCurrentUserP(nil, &r)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCtxKeyString(t *testing.T) {
|
func TestCTXKeyString(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
if got := ctxKeyPID.String(); got != "authboss ctx key pid" {
|
if got := CTXKeyPID.String(); got != "authboss ctx key pid" {
|
||||||
t.Error(got)
|
t.Error(got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -8,8 +8,6 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type mockUser struct {
|
type mockUser struct {
|
||||||
@ -29,40 +27,26 @@ func (m mockServerStorer) Load(ctx context.Context, key string) (User, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m mockServerStorer) Save(ctx context.Context, user User) error {
|
func (m mockServerStorer) Save(ctx context.Context, user User) error {
|
||||||
e, err := user.GetPID(ctx)
|
pid := user.GetPID(ctx)
|
||||||
if err != nil {
|
m[pid] = user.(mockUser)
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
m[e] = user.(mockUser)
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m mockUser) PutPID(ctx context.Context, email string) error {
|
func (m mockUser) PutPID(ctx context.Context, email string) {
|
||||||
m.Email = email
|
m.Email = email
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m mockUser) PutUsername(ctx context.Context, username string) error {
|
func (m mockUser) PutPassword(ctx context.Context, password string) {
|
||||||
return errors.New("not impl")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m mockUser) PutPassword(ctx context.Context, password string) error {
|
|
||||||
m.Password = password
|
m.Password = password
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m mockUser) GetPID(ctx context.Context) (email string, err error) {
|
func (m mockUser) GetPID(ctx context.Context) (email string) {
|
||||||
return m.Email, nil
|
return m.Email
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m mockUser) GetUsername(ctx context.Context) (username string, err error) {
|
func (m mockUser) GetPassword(ctx context.Context) (password string) {
|
||||||
return "", errors.New("not impl")
|
return m.Password
|
||||||
}
|
|
||||||
|
|
||||||
func (m mockUser) GetPassword(ctx context.Context) (password string, err error) {
|
|
||||||
return m.Password, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockClientStateReadWriter struct {
|
type mockClientStateReadWriter struct {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user