mirror of
https://github.com/volatiletech/authboss.git
synced 2025-09-16 09:06:20 +02:00
Re-enable tests, add more tests
This commit is contained in:
@@ -69,6 +69,8 @@ func (a *Authboss) CurrentUserP(w http.ResponseWriter, r *http.Request) Storer {
|
||||
i, err := a.CurrentUser(w, r)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
} else if i == nil {
|
||||
panic(ErrUserFound)
|
||||
}
|
||||
return i
|
||||
}
|
||||
|
176
context_test.go
176
context_test.go
@@ -1,16 +1,59 @@
|
||||
package authboss
|
||||
|
||||
/* TODO(aarondl): Re-enable
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func loadClientStateP(ab *Authboss, w http.ResponseWriter, r *http.Request) *http.Request {
|
||||
r, err := ab.LoadClientState(w, r)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func testSetupContext() (*Authboss, *http.Request) {
|
||||
ab := New()
|
||||
ab.SessionStateStorer = newMockClientStateRW(SessionKey, "george-pid")
|
||||
ab.StoreLoader = mockStoreLoader{
|
||||
"george-pid": mockUser{Email: "george-pid", Password: "unreadable"},
|
||||
}
|
||||
r := loadClientStateP(ab, nil, httptest.NewRequest("GET", "/", nil))
|
||||
|
||||
return ab, r
|
||||
}
|
||||
|
||||
func testSetupContextCached() (*Authboss, mockUser, *http.Request) {
|
||||
ab := New()
|
||||
wantUser := mockUser{Email: "george-pid", Password: "unreadable"}
|
||||
storer := mockStoredUser{
|
||||
mockUser: wantUser,
|
||||
}
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
ctx := context.WithValue(req.Context(), ctxKeyPID, "george-pid")
|
||||
ctx = context.WithValue(ctx, ctxKeyUser, storer)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
return ab, wantUser, req
|
||||
}
|
||||
|
||||
func testSetupContextPanic() *Authboss {
|
||||
ab := New()
|
||||
ab.SessionStateStorer = newMockClientStateRW(SessionKey, "george-pid")
|
||||
ab.StoreLoader = mockStoreLoader{}
|
||||
|
||||
return ab
|
||||
}
|
||||
|
||||
func TestCurrentUserID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := New()
|
||||
ab.SessionStoreMaker = newMockClientStoreMaker(mockClientStore{
|
||||
SessionKey: "george-pid",
|
||||
})
|
||||
ab, r := testSetupContext()
|
||||
|
||||
id, err := ab.CurrentUserID(nil, httptest.NewRequest("GET", "/", nil))
|
||||
id, err := ab.CurrentUserID(nil, r)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
@@ -23,10 +66,9 @@ func TestCurrentUserID(t *testing.T) {
|
||||
func TestCurrentUserIDContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := New()
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req = req.WithContext(context.WithValue(req.Context(), ctxKeyPID, "george-pid"))
|
||||
id, err := ab.CurrentUserID(nil, req)
|
||||
ab, r := testSetupContext()
|
||||
|
||||
id, err := ab.CurrentUserID(nil, r)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
@@ -39,8 +81,9 @@ func TestCurrentUserIDContext(t *testing.T) {
|
||||
func TestCurrentUserIDP(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := New()
|
||||
ab.SessionStoreMaker = newMockClientStoreMaker(mockClientStore{})
|
||||
ab := testSetupContextPanic()
|
||||
// Overwrite the setup functions state storer
|
||||
ab.SessionStateStorer = newMockClientStateRW()
|
||||
|
||||
defer func() {
|
||||
if recover().(error) != ErrUserNotFound {
|
||||
@@ -54,15 +97,9 @@ func TestCurrentUserIDP(t *testing.T) {
|
||||
func TestCurrentUser(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := New()
|
||||
ab.SessionStoreMaker = newMockClientStoreMaker(mockClientStore{
|
||||
SessionKey: "george-pid",
|
||||
})
|
||||
ab.StoreLoader = mockStoreLoader{
|
||||
"george-pid": mockUser{Email: "george-pid", Password: "unreadable"},
|
||||
}
|
||||
ab, r := testSetupContext()
|
||||
|
||||
user, err := ab.CurrentUser(nil, httptest.NewRequest("GET", "/", nil))
|
||||
user, err := ab.CurrentUser(nil, r)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
@@ -77,17 +114,9 @@ func TestCurrentUser(t *testing.T) {
|
||||
func TestCurrentUserContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := New()
|
||||
wantUser := mockStoredUser{
|
||||
mockUser: mockUser{Email: "george-pid", Password: "unreadable"},
|
||||
}
|
||||
ab, _, r := testSetupContextCached()
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
ctx := context.WithValue(req.Context(), ctxKeyPID, "george-id")
|
||||
ctx = context.WithValue(ctx, ctxKeyUser, wantUser)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
user, err := ab.CurrentUser(nil, req)
|
||||
user, err := ab.CurrentUser(nil, r)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
@@ -102,11 +131,7 @@ func TestCurrentUserContext(t *testing.T) {
|
||||
func TestCurrentUserP(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := New()
|
||||
ab.SessionStoreMaker = newMockClientStoreMaker(mockClientStore{
|
||||
SessionKey: "george-pid",
|
||||
})
|
||||
ab.StoreLoader = mockStoreLoader{}
|
||||
ab := testSetupContextPanic()
|
||||
|
||||
defer func() {
|
||||
if recover().(error) != ErrUserNotFound {
|
||||
@@ -120,14 +145,9 @@ func TestCurrentUserP(t *testing.T) {
|
||||
func TestLoadCurrentUserID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := New()
|
||||
ab.SessionStoreMaker = newMockClientStoreMaker(mockClientStore{
|
||||
SessionKey: "george-pid",
|
||||
})
|
||||
ab, r := testSetupContext()
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
id, err := ab.LoadCurrentUserID(nil, &req)
|
||||
id, err := ab.LoadCurrentUserID(nil, &r)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
@@ -136,16 +156,30 @@ func TestLoadCurrentUserID(t *testing.T) {
|
||||
t.Error("got:", id)
|
||||
}
|
||||
|
||||
if req.Context().Value(ctxKeyPID).(string) != "george-pid" {
|
||||
if r.Context().Value(ctxKeyPID).(string) != "george-pid" {
|
||||
t.Error("context was not updated in local request")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadCurrentUserIDContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab, _, r := testSetupContextCached()
|
||||
|
||||
pid, err := ab.LoadCurrentUserID(nil, &r)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if pid != "george-pid" {
|
||||
t.Error("got:", pid)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadCurrentUserIDP(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := New()
|
||||
ab.SessionStoreMaker = newMockClientStoreMaker(mockClientStore{})
|
||||
ab := testSetupContextPanic()
|
||||
|
||||
defer func() {
|
||||
if recover().(error) != ErrUserNotFound {
|
||||
@@ -153,23 +187,16 @@ func TestLoadCurrentUserIDP(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
_ = ab.LoadCurrentUserIDP(nil, &req)
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
_ = ab.LoadCurrentUserIDP(nil, &r)
|
||||
}
|
||||
|
||||
func TestLoadCurrentUser(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := New()
|
||||
ab.SessionStoreMaker = newMockClientStoreMaker(mockClientStore{
|
||||
SessionKey: "george-pid",
|
||||
})
|
||||
ab.StoreLoader = mockStoreLoader{
|
||||
"george-pid": mockUser{Email: "george-pid", Password: "unreadable"},
|
||||
}
|
||||
ab, r := testSetupContext()
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
user, err := ab.LoadCurrentUser(nil, &req)
|
||||
user, err := ab.LoadCurrentUser(nil, &r)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
@@ -181,7 +208,7 @@ func TestLoadCurrentUser(t *testing.T) {
|
||||
}
|
||||
|
||||
want := user.(mockStoredUser).mockUser
|
||||
got := req.Context().Value(ctxKeyUser).(mockStoredUser).mockUser
|
||||
got := r.Context().Value(ctxKeyUser).(mockStoredUser).mockUser
|
||||
if got != want {
|
||||
t.Error("users mismatched:\nwant: %#v\ngot: %#v", want, got)
|
||||
}
|
||||
@@ -190,35 +217,23 @@ func TestLoadCurrentUser(t *testing.T) {
|
||||
func TestLoadCurrentUserContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := New()
|
||||
wantUser := mockStoredUser{
|
||||
mockUser: mockUser{Email: "george-pid", Password: "unreadable"},
|
||||
}
|
||||
ab, wantUser, r := testSetupContextCached()
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
ctx := context.WithValue(req.Context(), ctxKeyPID, "george-id")
|
||||
ctx = context.WithValue(ctx, ctxKeyUser, wantUser)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
user, err := ab.LoadCurrentUser(nil, &req)
|
||||
user, err := ab.LoadCurrentUser(nil, &r)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
got := user.(mockStoredUser).mockUser
|
||||
if got != wantUser.mockUser {
|
||||
t.Error("users mismatched:\nwant: %#v\ngot: %#v", wantUser.mockUser, got)
|
||||
if got != wantUser {
|
||||
t.Error("users mismatched:\nwant: %#v\ngot: %#v", wantUser, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadCurrentUserP(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := New()
|
||||
ab.SessionStoreMaker = newMockClientStoreMaker(mockClientStore{
|
||||
SessionKey: "george-pid",
|
||||
})
|
||||
ab.StoreLoader = mockStoreLoader{}
|
||||
ab := testSetupContextPanic()
|
||||
|
||||
defer func() {
|
||||
if recover().(error) != ErrUserNotFound {
|
||||
@@ -226,7 +241,14 @@ func TestLoadCurrentUserP(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
_ = ab.LoadCurrentUserP(nil, &req)
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
_ = ab.LoadCurrentUserP(nil, &r)
|
||||
}
|
||||
|
||||
func TestCtxKeyString(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if got := ctxKeyPID.String(); got != "authboss ctx key pid" {
|
||||
t.Error(got)
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
57
expire.go
57
expire.go
@@ -1,23 +1,23 @@
|
||||
package authboss
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
var nowTime = time.Now
|
||||
|
||||
// TimeToExpiry returns zero if the user session is expired else the time until expiry.
|
||||
func (a *Authboss) TimeToExpiry(w http.ResponseWriter, r *http.Request) time.Duration {
|
||||
//TODO(aarondl): Rewrite this so it makes sense with new ClientStorer idioms
|
||||
//return a.timeToExpiry(state.(ClientState))
|
||||
return 0
|
||||
// TimeToExpiry returns zero if the user session is expired else the time
|
||||
// until expiry. Takes in the allowed idle duration.
|
||||
func TimeToExpiry(w http.ResponseWriter, r *http.Request, expireAfter time.Duration) time.Duration {
|
||||
return timeToExpiry(r, expireAfter)
|
||||
}
|
||||
|
||||
func (a *Authboss) timeToExpiry(session ClientState) time.Duration {
|
||||
dateStr, ok := session.Get(SessionLastAction)
|
||||
func timeToExpiry(r *http.Request, expireAfter time.Duration) time.Duration {
|
||||
dateStr, ok := GetSession(r, SessionLastAction)
|
||||
if !ok {
|
||||
return a.ExpireAfter
|
||||
return expireAfter
|
||||
}
|
||||
|
||||
date, err := time.Parse(time.RFC3339, dateStr)
|
||||
@@ -25,7 +25,7 @@ func (a *Authboss) timeToExpiry(session ClientState) time.Duration {
|
||||
panic("last_action is not a valid RFC3339 date")
|
||||
}
|
||||
|
||||
remaining := date.Add(a.ExpireAfter).Sub(nowTime().UTC())
|
||||
remaining := date.Add(expireAfter).Sub(nowTime().UTC())
|
||||
if remaining > 0 {
|
||||
return remaining
|
||||
}
|
||||
@@ -33,15 +33,14 @@ func (a *Authboss) timeToExpiry(session ClientState) time.Duration {
|
||||
return 0
|
||||
}
|
||||
|
||||
// RefreshExpiry updates the last action for the user, so he doesn't become expired.
|
||||
func (a *Authboss) RefreshExpiry(w http.ResponseWriter, r *http.Request) {
|
||||
//TODO(aarondl): Fix
|
||||
//a.refreshExpiry(session)
|
||||
// RefreshExpiry updates the last action for the user, so he doesn't
|
||||
// become expired.
|
||||
func RefreshExpiry(w http.ResponseWriter, r *http.Request) {
|
||||
refreshExpiry(w)
|
||||
}
|
||||
|
||||
func (a *Authboss) refreshExpiry(session ClientState) {
|
||||
//TODO(aarondl): Fix
|
||||
PutSession(nil, SessionLastAction, nowTime().UTC().Format(time.RFC3339))
|
||||
func refreshExpiry(w http.ResponseWriter) {
|
||||
PutSession(w, SessionLastAction, nowTime().UTC().Format(time.RFC3339))
|
||||
}
|
||||
|
||||
type expireMiddleware struct {
|
||||
@@ -58,19 +57,21 @@ func (a *Authboss) ExpireMiddleware(next http.Handler) http.Handler {
|
||||
return expireMiddleware{a, next}
|
||||
}
|
||||
|
||||
// ServeHTTP removes the session if it's passed the expire time.
|
||||
// ServeHTTP removes the session and hides the loaded user from the handlers
|
||||
// below it.
|
||||
func (m expireMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
//TODO(aarondl): Fix
|
||||
/*
|
||||
if _, ok := GetSession(r, SessionKey); ok {
|
||||
ttl := m.ab.timeToExpiry(session)
|
||||
if ttl == 0 {
|
||||
session.Del(SessionKey)
|
||||
session.Del(SessionLastAction)
|
||||
} else {
|
||||
m.ab.refreshExpiry(session)
|
||||
}
|
||||
if _, ok := GetSession(r, SessionKey); ok {
|
||||
ttl := timeToExpiry(r, m.ab.ExpireAfter)
|
||||
if ttl == 0 {
|
||||
DelSession(w, SessionKey)
|
||||
DelSession(w, SessionLastAction)
|
||||
ctx := context.WithValue(r.Context(), ctxKeyPID, nil)
|
||||
ctx = context.WithValue(ctx, ctxKeyUser, nil)
|
||||
r = r.WithContext(ctx)
|
||||
} else {
|
||||
refreshExpiry(w)
|
||||
}
|
||||
}
|
||||
|
||||
m.next.ServeHTTP(w, r)*/
|
||||
m.next.ServeHTTP(w, r)
|
||||
}
|
||||
|
149
expire_test.go
149
expire_test.go
@@ -1,15 +1,30 @@
|
||||
package authboss
|
||||
|
||||
/* TODO(aarondl): Re-enable
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// These tests use the global variable nowTime so cannot be parallelized
|
||||
|
||||
func TestDudeIsExpired(t *testing.T) {
|
||||
func TestExpireIsExpired(t *testing.T) {
|
||||
ab := New()
|
||||
session := mockClientStore{SessionKey: "username"}
|
||||
ab.refreshExpiry(session)
|
||||
ab.SessionStateStorer = newMockClientStateRW(
|
||||
SessionKey, "username",
|
||||
SessionLastAction, time.Now().UTC().Format(time.RFC3339),
|
||||
)
|
||||
|
||||
// No t.Parallel()
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
r = r.WithContext(context.WithValue(r.Context(), ctxKeyPID, "primaryid"))
|
||||
r = r.WithContext(context.WithValue(r.Context(), ctxKeyUser, struct{}{}))
|
||||
w := ab.NewResponse(httptest.NewRecorder(), r)
|
||||
r, err := ab.LoadClientState(w, r)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
// No t.Parallel() - Also must be after refreshExpiry() call
|
||||
nowTime = func() time.Time {
|
||||
return time.Now().UTC().Add(ab.ExpireAfter * 2)
|
||||
}
|
||||
@@ -17,62 +32,132 @@ func TestDudeIsExpired(t *testing.T) {
|
||||
nowTime = time.Now
|
||||
}()
|
||||
|
||||
ab.SessionStoreMaker = newMockClientStoreMaker(session)
|
||||
|
||||
r, _ := http.NewRequest("GET", "tra/la/la", nil)
|
||||
w := httptest.NewRecorder()
|
||||
called := false
|
||||
|
||||
hadUser := false
|
||||
m := ab.ExpireMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
called = true
|
||||
|
||||
if r.Context().Value(ctxKeyPID) != nil {
|
||||
hadUser = true
|
||||
}
|
||||
if r.Context().Value(ctxKeyUser) != nil {
|
||||
hadUser = true
|
||||
}
|
||||
}))
|
||||
|
||||
m.ServeHTTP(w, r)
|
||||
|
||||
if !called {
|
||||
t.Error("Expected middleware to call handler")
|
||||
t.Error("expected middleware to call handler")
|
||||
}
|
||||
if hadUser {
|
||||
t.Error("expected user not to be present")
|
||||
}
|
||||
|
||||
if key, ok := session.Get(SessionKey); ok {
|
||||
t.Error("Unexpected session key:", key)
|
||||
}
|
||||
csrw := w.(*ClientStateResponseWriter)
|
||||
|
||||
if key, ok := session.Get(SessionLastAction); ok {
|
||||
t.Error("Unexpected last action key:", key)
|
||||
want := ClientStateEvent{
|
||||
Kind: ClientStateEventDel,
|
||||
Key: SessionKey,
|
||||
}
|
||||
if got := csrw.sessionStateEvents[0]; got != want {
|
||||
t.Error("want:", want, "got:", got)
|
||||
}
|
||||
want = ClientStateEvent{
|
||||
Kind: ClientStateEventDel,
|
||||
Key: SessionLastAction,
|
||||
}
|
||||
if got := csrw.sessionStateEvents[1]; got != want {
|
||||
t.Error("want:", want, "got:", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDudeIsNotExpired(t *testing.T) {
|
||||
func TestExpireNotExpired(t *testing.T) {
|
||||
ab := New()
|
||||
session := mockClientStore{SessionKey: "username"}
|
||||
ab.refreshExpiry(session)
|
||||
ab.SessionStateStorer = newMockClientStateRW(
|
||||
SessionKey, "username",
|
||||
SessionLastAction, time.Now().UTC().Format(time.RFC3339),
|
||||
)
|
||||
|
||||
// No t.Parallel()
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
r = r.WithContext(context.WithValue(r.Context(), ctxKeyPID, "primaryid"))
|
||||
r = r.WithContext(context.WithValue(r.Context(), ctxKeyUser, struct{}{}))
|
||||
w := ab.NewResponse(httptest.NewRecorder(), r)
|
||||
r, err := ab.LoadClientState(w, r)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
// No t.Parallel() - Also must be after refreshExpiry() call
|
||||
newTime := time.Now().UTC().Add(ab.ExpireAfter / 2)
|
||||
nowTime = func() time.Time {
|
||||
return time.Now().UTC().Add(ab.ExpireAfter / 2)
|
||||
return newTime
|
||||
}
|
||||
defer func() {
|
||||
nowTime = time.Now
|
||||
}()
|
||||
|
||||
ab.SessionStoreMaker = newMockClientStoreMaker(session)
|
||||
|
||||
r, _ := http.NewRequest("GET", "tra/la/la", nil)
|
||||
w := httptest.NewRecorder()
|
||||
called := false
|
||||
|
||||
hadUser := true
|
||||
m := ab.ExpireMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
called = true
|
||||
|
||||
if r.Context().Value(ctxKeyPID) == nil {
|
||||
hadUser = false
|
||||
}
|
||||
if r.Context().Value(ctxKeyUser) == nil {
|
||||
hadUser = false
|
||||
}
|
||||
}))
|
||||
|
||||
m.ServeHTTP(w, r)
|
||||
|
||||
if !called {
|
||||
t.Error("Expected middleware to call handler")
|
||||
t.Error("expected middleware to call handler")
|
||||
}
|
||||
if !hadUser {
|
||||
t.Error("expected user to be present")
|
||||
}
|
||||
|
||||
if key, ok := session.Get(SessionKey); !ok {
|
||||
t.Error("Expected session key:", key)
|
||||
csrw := w.(*ClientStateResponseWriter)
|
||||
|
||||
want := ClientStateEvent{
|
||||
Kind: ClientStateEventPut,
|
||||
Key: SessionLastAction,
|
||||
Value: newTime.Format(time.RFC3339),
|
||||
}
|
||||
if got := csrw.sessionStateEvents[0]; got != want {
|
||||
t.Error("want:", want, "got:", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpireTimeToExpiry(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := New()
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
w := ab.NewResponse(httptest.NewRecorder(), r)
|
||||
|
||||
want := 5 * time.Second
|
||||
dur := TimeToExpiry(w, r, want)
|
||||
if dur != want {
|
||||
t.Error("duration was wrong:", dur)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpireRefreshExpiry(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := New()
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
w := ab.NewResponse(httptest.NewRecorder(), r)
|
||||
|
||||
RefreshExpiry(w, r)
|
||||
csrw := w.(*ClientStateResponseWriter)
|
||||
if got := csrw.sessionStateEvents[0].Kind; got != ClientStateEventPut {
|
||||
t.Error("wrong event:", got)
|
||||
}
|
||||
if got := csrw.sessionStateEvents[0].Key; got != SessionLastAction {
|
||||
t.Error("wrong key:", got)
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
@@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
@@ -191,13 +192,28 @@ func (m mockRenderLoader) Init(names []string) (Renderer, error) {
|
||||
return mockRenderer{}, nil
|
||||
}
|
||||
|
||||
type mockRenderer struct{}
|
||||
type mockRenderer struct {
|
||||
expectName string
|
||||
}
|
||||
|
||||
func (m mockRenderer) Render(ctx context.Context, name string, data HTMLData) ([]byte, string, error) {
|
||||
b, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
if len(m.expectName) != 0 && m.expectName != name {
|
||||
panic(fmt.Sprintf("want template name: %s, but got: %s", m.expectName, name))
|
||||
}
|
||||
|
||||
return b, "application/json", nil
|
||||
b, err := json.Marshal(data)
|
||||
return b, "application/json", err
|
||||
}
|
||||
|
||||
type mockEmailRenderer struct{}
|
||||
|
||||
func (m mockEmailRenderer) Render(ctx context.Context, name string, data HTMLData) ([]byte, string, error) {
|
||||
switch name {
|
||||
case "text":
|
||||
return []byte("a development text e-mail template"), "text/plain", nil
|
||||
case "html":
|
||||
return []byte("a development html e-mail template"), "text/html", nil
|
||||
default:
|
||||
panic("shouldn't get here")
|
||||
}
|
||||
}
|
||||
|
@@ -1,7 +1,6 @@
|
||||
package authboss
|
||||
|
||||
import (
|
||||
"html/template"
|
||||
"net/http"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
@@ -11,8 +10,8 @@ import (
|
||||
// and writes the headers out.
|
||||
func (a *Authboss) Respond(w http.ResponseWriter, r *http.Request, code int, templateName string, data HTMLData) error {
|
||||
data.MergeKV(
|
||||
"xsrfName", template.HTML(a.XSRFName),
|
||||
"xsrfToken", template.HTML(a.XSRFMaker(w, r)),
|
||||
"xsrfName", a.XSRFName,
|
||||
"xsrfToken", a.XSRFMaker(w, r),
|
||||
)
|
||||
|
||||
if a.LayoutDataMaker != nil {
|
||||
@@ -123,7 +122,7 @@ func (a *Authboss) redirectAPI(w http.ResponseWriter, r *http.Request, ro Redire
|
||||
}
|
||||
|
||||
data := HTMLData{
|
||||
"path": path,
|
||||
"location": path,
|
||||
}
|
||||
|
||||
if len(status) != 0 {
|
||||
|
256
response_test.go
Normal file
256
response_test.go
Normal file
@@ -0,0 +1,256 @@
|
||||
package authboss
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestResponseRespond(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := New()
|
||||
ab.renderer = mockRenderer{expectName: "some_template.tpl"}
|
||||
ab.SessionStateStorer = newMockClientStateRW(
|
||||
FlashSuccessKey, "flash_success",
|
||||
FlashErrorKey, "flash_error",
|
||||
)
|
||||
ab.XSRFName = "xsrf"
|
||||
ab.XSRFMaker = func(w http.ResponseWriter, r *http.Request) string {
|
||||
return "xsrftoken"
|
||||
}
|
||||
ab.LayoutDataMaker = func(w http.ResponseWriter, r *http.Request) HTMLData {
|
||||
return HTMLData{"hello": "world"}
|
||||
}
|
||||
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
wr := httptest.NewRecorder()
|
||||
w := ab.NewResponse(wr, r)
|
||||
r = loadClientStateP(ab, w, r)
|
||||
err := ab.Respond(w, r, http.StatusCreated, "some_template.tpl", HTMLData{"auth_happy": true})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if wr.Code != http.StatusCreated {
|
||||
t.Error("code was wrong:", wr.Code)
|
||||
}
|
||||
|
||||
if got := wr.HeaderMap.Get("Content-Type"); got != "application/json" {
|
||||
t.Error("content type was wrong:", got)
|
||||
}
|
||||
|
||||
expectData := HTMLData{
|
||||
"xsrfName": "xsrf",
|
||||
"xsrfToken": "xsrftoken",
|
||||
"hello": "world",
|
||||
FlashSuccessKey: "flash_success",
|
||||
FlashErrorKey: "flash_error",
|
||||
"auth_happy": true,
|
||||
}
|
||||
|
||||
var gotData HTMLData
|
||||
if err := json.Unmarshal(wr.Body.Bytes(), &gotData); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(gotData, expectData) {
|
||||
t.Errorf("data mismatched:\nwant: %#v\ngot: %#v", expectData, gotData)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponseEmail(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := New()
|
||||
ab.renderer = mockEmailRenderer{}
|
||||
ab.SessionStateStorer = newMockClientStateRW(
|
||||
FlashSuccessKey, "flash_success",
|
||||
FlashErrorKey, "flash_error",
|
||||
)
|
||||
ab.XSRFName = "xsrf"
|
||||
ab.XSRFMaker = func(w http.ResponseWriter, r *http.Request) string {
|
||||
return "xsrftoken"
|
||||
}
|
||||
ab.LayoutDataMaker = func(w http.ResponseWriter, r *http.Request) HTMLData {
|
||||
return HTMLData{"hello": "world"}
|
||||
}
|
||||
|
||||
output := &bytes.Buffer{}
|
||||
ab.Mailer = LogMailer(output)
|
||||
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
wr := httptest.NewRecorder()
|
||||
w := ab.NewResponse(wr, r)
|
||||
|
||||
email := Email{
|
||||
To: []string{"test@example.com"},
|
||||
From: "test@example.com",
|
||||
Subject: "subject",
|
||||
}
|
||||
ro := EmailResponseOptions{Data: nil, HTMLTemplate: "html", TextTemplate: "text"}
|
||||
err := ab.Email(w, r, email, ro)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
wantStrings := []string{
|
||||
"To: test@example.com",
|
||||
"From: test@example.com",
|
||||
"Subject: subject",
|
||||
"development text e-mail",
|
||||
"development html e-mail",
|
||||
}
|
||||
|
||||
out := output.String()
|
||||
for i, test := range wantStrings {
|
||||
if !strings.Contains(out, test) {
|
||||
t.Errorf("output missing string(%d): %s\n%s", i, test, out)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponseRedirectAPI(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := New()
|
||||
ab.renderer = mockRenderer{}
|
||||
r := httptest.NewRequest("POST", "/?redir=/pow", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
|
||||
ro := RedirectOptions{
|
||||
Success: "ok!",
|
||||
Code: http.StatusTeapot,
|
||||
RedirectPath: "/redirect", FollowRedirParam: false,
|
||||
}
|
||||
|
||||
if err := ab.Redirect(w, r, ro); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if w.Code != http.StatusTeapot {
|
||||
t.Error("code is wrong:", w.Code)
|
||||
}
|
||||
|
||||
var gotData map[string]string
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &gotData); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if got := gotData["status"]; got != "success" {
|
||||
t.Error("status was wrong:", got)
|
||||
}
|
||||
if got := gotData["message"]; got != "ok!" {
|
||||
t.Error("message was wrong:", got)
|
||||
}
|
||||
if got := gotData["location"]; got != "/redirect" {
|
||||
t.Error("location was wrong:", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponseRedirectAPIFollowRedir(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := New()
|
||||
ab.renderer = mockRenderer{}
|
||||
r := httptest.NewRequest("POST", "/?redir=/pow", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
|
||||
ro := RedirectOptions{
|
||||
Failure: ":(",
|
||||
Code: http.StatusTeapot,
|
||||
RedirectPath: "/redirect", FollowRedirParam: true,
|
||||
}
|
||||
|
||||
if err := ab.Redirect(w, r, ro); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if w.Code != http.StatusTeapot {
|
||||
t.Error("code is wrong:", w.Code)
|
||||
}
|
||||
|
||||
var gotData map[string]string
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &gotData); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if got := gotData["status"]; got != "failure" {
|
||||
t.Error("status was wrong:", got)
|
||||
}
|
||||
if got := gotData["message"]; got != ":(" {
|
||||
t.Error("message was wrong:", got)
|
||||
}
|
||||
if got := gotData["location"]; got != "/pow" {
|
||||
t.Error("location was wrong:", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponseRedirectNonAPI(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := New()
|
||||
r := httptest.NewRequest("POST", "/?redir=/pow", nil)
|
||||
wr := httptest.NewRecorder()
|
||||
w := ab.NewResponse(wr, r)
|
||||
|
||||
ro := RedirectOptions{
|
||||
Success: "success", Failure: "failure",
|
||||
RedirectPath: "/redirect", FollowRedirParam: false,
|
||||
}
|
||||
|
||||
if err := ab.Redirect(w, r, ro); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
csrw := w.(*ClientStateResponseWriter)
|
||||
want := ClientStateEvent{Kind: ClientStateEventPut, Key: FlashSuccessKey, Value: "success"}
|
||||
if csrw.sessionStateEvents[0] != want {
|
||||
t.Error("event was wrong:", csrw.sessionStateEvents[0])
|
||||
}
|
||||
want = ClientStateEvent{Kind: ClientStateEventPut, Key: FlashErrorKey, Value: "failure"}
|
||||
if csrw.sessionStateEvents[1] != want {
|
||||
t.Error("event was wrong:", csrw.sessionStateEvents[1])
|
||||
}
|
||||
if wr.Code != http.StatusFound {
|
||||
t.Error("code is wrong:", wr.Code)
|
||||
}
|
||||
if got := wr.Header().Get("Location"); got != "/redirect" {
|
||||
t.Error("redirect location was wrong:", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponseRedirectNonAPIFollowRedir(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ab := New()
|
||||
r := httptest.NewRequest("POST", "/?redir=/pow", nil)
|
||||
wr := httptest.NewRecorder()
|
||||
w := ab.NewResponse(wr, r)
|
||||
|
||||
ro := RedirectOptions{
|
||||
RedirectPath: "/redirect", FollowRedirParam: true,
|
||||
}
|
||||
if err := ab.Redirect(w, r, ro); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
csrw := w.(*ClientStateResponseWriter)
|
||||
if len(csrw.sessionStateEvents) != 0 {
|
||||
t.Error("session state events should be empty:", csrw.sessionStateEvents)
|
||||
}
|
||||
if wr.Code != http.StatusFound {
|
||||
t.Error("code is wrong:", wr.Code)
|
||||
}
|
||||
if got := wr.Header().Get("Location"); got != "/pow" {
|
||||
t.Error("redirect location was wrong:", got)
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user