1
0
mirror of https://github.com/volatiletech/authboss.git synced 2025-09-16 09:06:20 +02:00

Re-enable tests, add more tests

This commit is contained in:
Aaron L
2017-03-05 10:01:46 -08:00
parent 24fc6196c7
commit a92fb4d069
7 changed files with 527 additions and 146 deletions

View File

@@ -69,6 +69,8 @@ func (a *Authboss) CurrentUserP(w http.ResponseWriter, r *http.Request) Storer {
i, err := a.CurrentUser(w, r)
if err != nil {
panic(err)
} else if i == nil {
panic(ErrUserFound)
}
return i
}

View File

@@ -1,16 +1,59 @@
package authboss
/* TODO(aarondl): Re-enable
import (
"context"
"net/http"
"net/http/httptest"
"testing"
)
func loadClientStateP(ab *Authboss, w http.ResponseWriter, r *http.Request) *http.Request {
r, err := ab.LoadClientState(w, r)
if err != nil {
panic(err)
}
return r
}
func testSetupContext() (*Authboss, *http.Request) {
ab := New()
ab.SessionStateStorer = newMockClientStateRW(SessionKey, "george-pid")
ab.StoreLoader = mockStoreLoader{
"george-pid": mockUser{Email: "george-pid", Password: "unreadable"},
}
r := loadClientStateP(ab, nil, httptest.NewRequest("GET", "/", nil))
return ab, r
}
func testSetupContextCached() (*Authboss, mockUser, *http.Request) {
ab := New()
wantUser := mockUser{Email: "george-pid", Password: "unreadable"}
storer := mockStoredUser{
mockUser: wantUser,
}
req := httptest.NewRequest("GET", "/", nil)
ctx := context.WithValue(req.Context(), ctxKeyPID, "george-pid")
ctx = context.WithValue(ctx, ctxKeyUser, storer)
req = req.WithContext(ctx)
return ab, wantUser, req
}
func testSetupContextPanic() *Authboss {
ab := New()
ab.SessionStateStorer = newMockClientStateRW(SessionKey, "george-pid")
ab.StoreLoader = mockStoreLoader{}
return ab
}
func TestCurrentUserID(t *testing.T) {
t.Parallel()
ab := New()
ab.SessionStoreMaker = newMockClientStoreMaker(mockClientStore{
SessionKey: "george-pid",
})
ab, r := testSetupContext()
id, err := ab.CurrentUserID(nil, httptest.NewRequest("GET", "/", nil))
id, err := ab.CurrentUserID(nil, r)
if err != nil {
t.Error(err)
}
@@ -23,10 +66,9 @@ func TestCurrentUserID(t *testing.T) {
func TestCurrentUserIDContext(t *testing.T) {
t.Parallel()
ab := New()
req := httptest.NewRequest("GET", "/", nil)
req = req.WithContext(context.WithValue(req.Context(), ctxKeyPID, "george-pid"))
id, err := ab.CurrentUserID(nil, req)
ab, r := testSetupContext()
id, err := ab.CurrentUserID(nil, r)
if err != nil {
t.Error(err)
}
@@ -39,8 +81,9 @@ func TestCurrentUserIDContext(t *testing.T) {
func TestCurrentUserIDP(t *testing.T) {
t.Parallel()
ab := New()
ab.SessionStoreMaker = newMockClientStoreMaker(mockClientStore{})
ab := testSetupContextPanic()
// Overwrite the setup functions state storer
ab.SessionStateStorer = newMockClientStateRW()
defer func() {
if recover().(error) != ErrUserNotFound {
@@ -54,15 +97,9 @@ func TestCurrentUserIDP(t *testing.T) {
func TestCurrentUser(t *testing.T) {
t.Parallel()
ab := New()
ab.SessionStoreMaker = newMockClientStoreMaker(mockClientStore{
SessionKey: "george-pid",
})
ab.StoreLoader = mockStoreLoader{
"george-pid": mockUser{Email: "george-pid", Password: "unreadable"},
}
ab, r := testSetupContext()
user, err := ab.CurrentUser(nil, httptest.NewRequest("GET", "/", nil))
user, err := ab.CurrentUser(nil, r)
if err != nil {
t.Error(err)
}
@@ -77,17 +114,9 @@ func TestCurrentUser(t *testing.T) {
func TestCurrentUserContext(t *testing.T) {
t.Parallel()
ab := New()
wantUser := mockStoredUser{
mockUser: mockUser{Email: "george-pid", Password: "unreadable"},
}
ab, _, r := testSetupContextCached()
req := httptest.NewRequest("GET", "/", nil)
ctx := context.WithValue(req.Context(), ctxKeyPID, "george-id")
ctx = context.WithValue(ctx, ctxKeyUser, wantUser)
req = req.WithContext(ctx)
user, err := ab.CurrentUser(nil, req)
user, err := ab.CurrentUser(nil, r)
if err != nil {
t.Error(err)
}
@@ -102,11 +131,7 @@ func TestCurrentUserContext(t *testing.T) {
func TestCurrentUserP(t *testing.T) {
t.Parallel()
ab := New()
ab.SessionStoreMaker = newMockClientStoreMaker(mockClientStore{
SessionKey: "george-pid",
})
ab.StoreLoader = mockStoreLoader{}
ab := testSetupContextPanic()
defer func() {
if recover().(error) != ErrUserNotFound {
@@ -120,14 +145,9 @@ func TestCurrentUserP(t *testing.T) {
func TestLoadCurrentUserID(t *testing.T) {
t.Parallel()
ab := New()
ab.SessionStoreMaker = newMockClientStoreMaker(mockClientStore{
SessionKey: "george-pid",
})
ab, r := testSetupContext()
req := httptest.NewRequest("GET", "/", nil)
id, err := ab.LoadCurrentUserID(nil, &req)
id, err := ab.LoadCurrentUserID(nil, &r)
if err != nil {
t.Error(err)
}
@@ -136,16 +156,30 @@ func TestLoadCurrentUserID(t *testing.T) {
t.Error("got:", id)
}
if req.Context().Value(ctxKeyPID).(string) != "george-pid" {
if r.Context().Value(ctxKeyPID).(string) != "george-pid" {
t.Error("context was not updated in local request")
}
}
func TestLoadCurrentUserIDContext(t *testing.T) {
t.Parallel()
ab, _, r := testSetupContextCached()
pid, err := ab.LoadCurrentUserID(nil, &r)
if err != nil {
t.Error(err)
}
if pid != "george-pid" {
t.Error("got:", pid)
}
}
func TestLoadCurrentUserIDP(t *testing.T) {
t.Parallel()
ab := New()
ab.SessionStoreMaker = newMockClientStoreMaker(mockClientStore{})
ab := testSetupContextPanic()
defer func() {
if recover().(error) != ErrUserNotFound {
@@ -153,23 +187,16 @@ func TestLoadCurrentUserIDP(t *testing.T) {
}
}()
req := httptest.NewRequest("GET", "/", nil)
_ = ab.LoadCurrentUserIDP(nil, &req)
r := httptest.NewRequest("GET", "/", nil)
_ = ab.LoadCurrentUserIDP(nil, &r)
}
func TestLoadCurrentUser(t *testing.T) {
t.Parallel()
ab := New()
ab.SessionStoreMaker = newMockClientStoreMaker(mockClientStore{
SessionKey: "george-pid",
})
ab.StoreLoader = mockStoreLoader{
"george-pid": mockUser{Email: "george-pid", Password: "unreadable"},
}
ab, r := testSetupContext()
req := httptest.NewRequest("GET", "/", nil)
user, err := ab.LoadCurrentUser(nil, &req)
user, err := ab.LoadCurrentUser(nil, &r)
if err != nil {
t.Error(err)
}
@@ -181,7 +208,7 @@ func TestLoadCurrentUser(t *testing.T) {
}
want := user.(mockStoredUser).mockUser
got := req.Context().Value(ctxKeyUser).(mockStoredUser).mockUser
got := r.Context().Value(ctxKeyUser).(mockStoredUser).mockUser
if got != want {
t.Error("users mismatched:\nwant: %#v\ngot: %#v", want, got)
}
@@ -190,35 +217,23 @@ func TestLoadCurrentUser(t *testing.T) {
func TestLoadCurrentUserContext(t *testing.T) {
t.Parallel()
ab := New()
wantUser := mockStoredUser{
mockUser: mockUser{Email: "george-pid", Password: "unreadable"},
}
ab, wantUser, r := testSetupContextCached()
req := httptest.NewRequest("GET", "/", nil)
ctx := context.WithValue(req.Context(), ctxKeyPID, "george-id")
ctx = context.WithValue(ctx, ctxKeyUser, wantUser)
req = req.WithContext(ctx)
user, err := ab.LoadCurrentUser(nil, &req)
user, err := ab.LoadCurrentUser(nil, &r)
if err != nil {
t.Error(err)
}
got := user.(mockStoredUser).mockUser
if got != wantUser.mockUser {
t.Error("users mismatched:\nwant: %#v\ngot: %#v", wantUser.mockUser, got)
if got != wantUser {
t.Error("users mismatched:\nwant: %#v\ngot: %#v", wantUser, got)
}
}
func TestLoadCurrentUserP(t *testing.T) {
t.Parallel()
ab := New()
ab.SessionStoreMaker = newMockClientStoreMaker(mockClientStore{
SessionKey: "george-pid",
})
ab.StoreLoader = mockStoreLoader{}
ab := testSetupContextPanic()
defer func() {
if recover().(error) != ErrUserNotFound {
@@ -226,7 +241,14 @@ func TestLoadCurrentUserP(t *testing.T) {
}
}()
req := httptest.NewRequest("GET", "/", nil)
_ = ab.LoadCurrentUserP(nil, &req)
r := httptest.NewRequest("GET", "/", nil)
_ = ab.LoadCurrentUserP(nil, &r)
}
func TestCtxKeyString(t *testing.T) {
t.Parallel()
if got := ctxKeyPID.String(); got != "authboss ctx key pid" {
t.Error(got)
}
}
*/

View File

@@ -1,23 +1,23 @@
package authboss
import (
"context"
"net/http"
"time"
)
var nowTime = time.Now
// TimeToExpiry returns zero if the user session is expired else the time until expiry.
func (a *Authboss) TimeToExpiry(w http.ResponseWriter, r *http.Request) time.Duration {
//TODO(aarondl): Rewrite this so it makes sense with new ClientStorer idioms
//return a.timeToExpiry(state.(ClientState))
return 0
// TimeToExpiry returns zero if the user session is expired else the time
// until expiry. Takes in the allowed idle duration.
func TimeToExpiry(w http.ResponseWriter, r *http.Request, expireAfter time.Duration) time.Duration {
return timeToExpiry(r, expireAfter)
}
func (a *Authboss) timeToExpiry(session ClientState) time.Duration {
dateStr, ok := session.Get(SessionLastAction)
func timeToExpiry(r *http.Request, expireAfter time.Duration) time.Duration {
dateStr, ok := GetSession(r, SessionLastAction)
if !ok {
return a.ExpireAfter
return expireAfter
}
date, err := time.Parse(time.RFC3339, dateStr)
@@ -25,7 +25,7 @@ func (a *Authboss) timeToExpiry(session ClientState) time.Duration {
panic("last_action is not a valid RFC3339 date")
}
remaining := date.Add(a.ExpireAfter).Sub(nowTime().UTC())
remaining := date.Add(expireAfter).Sub(nowTime().UTC())
if remaining > 0 {
return remaining
}
@@ -33,15 +33,14 @@ func (a *Authboss) timeToExpiry(session ClientState) time.Duration {
return 0
}
// RefreshExpiry updates the last action for the user, so he doesn't become expired.
func (a *Authboss) RefreshExpiry(w http.ResponseWriter, r *http.Request) {
//TODO(aarondl): Fix
//a.refreshExpiry(session)
// RefreshExpiry updates the last action for the user, so he doesn't
// become expired.
func RefreshExpiry(w http.ResponseWriter, r *http.Request) {
refreshExpiry(w)
}
func (a *Authboss) refreshExpiry(session ClientState) {
//TODO(aarondl): Fix
PutSession(nil, SessionLastAction, nowTime().UTC().Format(time.RFC3339))
func refreshExpiry(w http.ResponseWriter) {
PutSession(w, SessionLastAction, nowTime().UTC().Format(time.RFC3339))
}
type expireMiddleware struct {
@@ -58,19 +57,21 @@ func (a *Authboss) ExpireMiddleware(next http.Handler) http.Handler {
return expireMiddleware{a, next}
}
// ServeHTTP removes the session if it's passed the expire time.
// ServeHTTP removes the session and hides the loaded user from the handlers
// below it.
func (m expireMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
//TODO(aarondl): Fix
/*
if _, ok := GetSession(r, SessionKey); ok {
ttl := m.ab.timeToExpiry(session)
if ttl == 0 {
session.Del(SessionKey)
session.Del(SessionLastAction)
} else {
m.ab.refreshExpiry(session)
}
if _, ok := GetSession(r, SessionKey); ok {
ttl := timeToExpiry(r, m.ab.ExpireAfter)
if ttl == 0 {
DelSession(w, SessionKey)
DelSession(w, SessionLastAction)
ctx := context.WithValue(r.Context(), ctxKeyPID, nil)
ctx = context.WithValue(ctx, ctxKeyUser, nil)
r = r.WithContext(ctx)
} else {
refreshExpiry(w)
}
}
m.next.ServeHTTP(w, r)*/
m.next.ServeHTTP(w, r)
}

View File

@@ -1,15 +1,30 @@
package authboss
/* TODO(aarondl): Re-enable
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
)
// These tests use the global variable nowTime so cannot be parallelized
func TestDudeIsExpired(t *testing.T) {
func TestExpireIsExpired(t *testing.T) {
ab := New()
session := mockClientStore{SessionKey: "username"}
ab.refreshExpiry(session)
ab.SessionStateStorer = newMockClientStateRW(
SessionKey, "username",
SessionLastAction, time.Now().UTC().Format(time.RFC3339),
)
// No t.Parallel()
r := httptest.NewRequest("GET", "/", nil)
r = r.WithContext(context.WithValue(r.Context(), ctxKeyPID, "primaryid"))
r = r.WithContext(context.WithValue(r.Context(), ctxKeyUser, struct{}{}))
w := ab.NewResponse(httptest.NewRecorder(), r)
r, err := ab.LoadClientState(w, r)
if err != nil {
t.Error(err)
}
// No t.Parallel() - Also must be after refreshExpiry() call
nowTime = func() time.Time {
return time.Now().UTC().Add(ab.ExpireAfter * 2)
}
@@ -17,62 +32,132 @@ func TestDudeIsExpired(t *testing.T) {
nowTime = time.Now
}()
ab.SessionStoreMaker = newMockClientStoreMaker(session)
r, _ := http.NewRequest("GET", "tra/la/la", nil)
w := httptest.NewRecorder()
called := false
hadUser := false
m := ab.ExpireMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called = true
if r.Context().Value(ctxKeyPID) != nil {
hadUser = true
}
if r.Context().Value(ctxKeyUser) != nil {
hadUser = true
}
}))
m.ServeHTTP(w, r)
if !called {
t.Error("Expected middleware to call handler")
t.Error("expected middleware to call handler")
}
if hadUser {
t.Error("expected user not to be present")
}
if key, ok := session.Get(SessionKey); ok {
t.Error("Unexpected session key:", key)
}
csrw := w.(*ClientStateResponseWriter)
if key, ok := session.Get(SessionLastAction); ok {
t.Error("Unexpected last action key:", key)
want := ClientStateEvent{
Kind: ClientStateEventDel,
Key: SessionKey,
}
if got := csrw.sessionStateEvents[0]; got != want {
t.Error("want:", want, "got:", got)
}
want = ClientStateEvent{
Kind: ClientStateEventDel,
Key: SessionLastAction,
}
if got := csrw.sessionStateEvents[1]; got != want {
t.Error("want:", want, "got:", got)
}
}
func TestDudeIsNotExpired(t *testing.T) {
func TestExpireNotExpired(t *testing.T) {
ab := New()
session := mockClientStore{SessionKey: "username"}
ab.refreshExpiry(session)
ab.SessionStateStorer = newMockClientStateRW(
SessionKey, "username",
SessionLastAction, time.Now().UTC().Format(time.RFC3339),
)
// No t.Parallel()
r := httptest.NewRequest("GET", "/", nil)
r = r.WithContext(context.WithValue(r.Context(), ctxKeyPID, "primaryid"))
r = r.WithContext(context.WithValue(r.Context(), ctxKeyUser, struct{}{}))
w := ab.NewResponse(httptest.NewRecorder(), r)
r, err := ab.LoadClientState(w, r)
if err != nil {
t.Error(err)
}
// No t.Parallel() - Also must be after refreshExpiry() call
newTime := time.Now().UTC().Add(ab.ExpireAfter / 2)
nowTime = func() time.Time {
return time.Now().UTC().Add(ab.ExpireAfter / 2)
return newTime
}
defer func() {
nowTime = time.Now
}()
ab.SessionStoreMaker = newMockClientStoreMaker(session)
r, _ := http.NewRequest("GET", "tra/la/la", nil)
w := httptest.NewRecorder()
called := false
hadUser := true
m := ab.ExpireMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called = true
if r.Context().Value(ctxKeyPID) == nil {
hadUser = false
}
if r.Context().Value(ctxKeyUser) == nil {
hadUser = false
}
}))
m.ServeHTTP(w, r)
if !called {
t.Error("Expected middleware to call handler")
t.Error("expected middleware to call handler")
}
if !hadUser {
t.Error("expected user to be present")
}
if key, ok := session.Get(SessionKey); !ok {
t.Error("Expected session key:", key)
csrw := w.(*ClientStateResponseWriter)
want := ClientStateEvent{
Kind: ClientStateEventPut,
Key: SessionLastAction,
Value: newTime.Format(time.RFC3339),
}
if got := csrw.sessionStateEvents[0]; got != want {
t.Error("want:", want, "got:", got)
}
}
func TestExpireTimeToExpiry(t *testing.T) {
t.Parallel()
ab := New()
r := httptest.NewRequest("GET", "/", nil)
w := ab.NewResponse(httptest.NewRecorder(), r)
want := 5 * time.Second
dur := TimeToExpiry(w, r, want)
if dur != want {
t.Error("duration was wrong:", dur)
}
}
func TestExpireRefreshExpiry(t *testing.T) {
t.Parallel()
ab := New()
r := httptest.NewRequest("GET", "/", nil)
w := ab.NewResponse(httptest.NewRecorder(), r)
RefreshExpiry(w, r)
csrw := w.(*ClientStateResponseWriter)
if got := csrw.sessionStateEvents[0].Kind; got != ClientStateEventPut {
t.Error("wrong event:", got)
}
if got := csrw.sessionStateEvents[0].Key; got != SessionLastAction {
t.Error("wrong key:", got)
}
}
*/

View File

@@ -4,6 +4,7 @@ import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
@@ -191,13 +192,28 @@ func (m mockRenderLoader) Init(names []string) (Renderer, error) {
return mockRenderer{}, nil
}
type mockRenderer struct{}
type mockRenderer struct {
expectName string
}
func (m mockRenderer) Render(ctx context.Context, name string, data HTMLData) ([]byte, string, error) {
b, err := json.Marshal(data)
if err != nil {
return nil, "", err
if len(m.expectName) != 0 && m.expectName != name {
panic(fmt.Sprintf("want template name: %s, but got: %s", m.expectName, name))
}
return b, "application/json", nil
b, err := json.Marshal(data)
return b, "application/json", err
}
type mockEmailRenderer struct{}
func (m mockEmailRenderer) Render(ctx context.Context, name string, data HTMLData) ([]byte, string, error) {
switch name {
case "text":
return []byte("a development text e-mail template"), "text/plain", nil
case "html":
return []byte("a development html e-mail template"), "text/html", nil
default:
panic("shouldn't get here")
}
}

View File

@@ -1,7 +1,6 @@
package authboss
import (
"html/template"
"net/http"
"github.com/pkg/errors"
@@ -11,8 +10,8 @@ import (
// and writes the headers out.
func (a *Authboss) Respond(w http.ResponseWriter, r *http.Request, code int, templateName string, data HTMLData) error {
data.MergeKV(
"xsrfName", template.HTML(a.XSRFName),
"xsrfToken", template.HTML(a.XSRFMaker(w, r)),
"xsrfName", a.XSRFName,
"xsrfToken", a.XSRFMaker(w, r),
)
if a.LayoutDataMaker != nil {
@@ -123,7 +122,7 @@ func (a *Authboss) redirectAPI(w http.ResponseWriter, r *http.Request, ro Redire
}
data := HTMLData{
"path": path,
"location": path,
}
if len(status) != 0 {

256
response_test.go Normal file
View File

@@ -0,0 +1,256 @@
package authboss
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"reflect"
"strings"
"testing"
)
func TestResponseRespond(t *testing.T) {
t.Parallel()
ab := New()
ab.renderer = mockRenderer{expectName: "some_template.tpl"}
ab.SessionStateStorer = newMockClientStateRW(
FlashSuccessKey, "flash_success",
FlashErrorKey, "flash_error",
)
ab.XSRFName = "xsrf"
ab.XSRFMaker = func(w http.ResponseWriter, r *http.Request) string {
return "xsrftoken"
}
ab.LayoutDataMaker = func(w http.ResponseWriter, r *http.Request) HTMLData {
return HTMLData{"hello": "world"}
}
r := httptest.NewRequest("GET", "/", nil)
wr := httptest.NewRecorder()
w := ab.NewResponse(wr, r)
r = loadClientStateP(ab, w, r)
err := ab.Respond(w, r, http.StatusCreated, "some_template.tpl", HTMLData{"auth_happy": true})
if err != nil {
t.Error(err)
}
if wr.Code != http.StatusCreated {
t.Error("code was wrong:", wr.Code)
}
if got := wr.HeaderMap.Get("Content-Type"); got != "application/json" {
t.Error("content type was wrong:", got)
}
expectData := HTMLData{
"xsrfName": "xsrf",
"xsrfToken": "xsrftoken",
"hello": "world",
FlashSuccessKey: "flash_success",
FlashErrorKey: "flash_error",
"auth_happy": true,
}
var gotData HTMLData
if err := json.Unmarshal(wr.Body.Bytes(), &gotData); err != nil {
t.Error(err)
}
if !reflect.DeepEqual(gotData, expectData) {
t.Errorf("data mismatched:\nwant: %#v\ngot: %#v", expectData, gotData)
}
}
func TestResponseEmail(t *testing.T) {
t.Parallel()
ab := New()
ab.renderer = mockEmailRenderer{}
ab.SessionStateStorer = newMockClientStateRW(
FlashSuccessKey, "flash_success",
FlashErrorKey, "flash_error",
)
ab.XSRFName = "xsrf"
ab.XSRFMaker = func(w http.ResponseWriter, r *http.Request) string {
return "xsrftoken"
}
ab.LayoutDataMaker = func(w http.ResponseWriter, r *http.Request) HTMLData {
return HTMLData{"hello": "world"}
}
output := &bytes.Buffer{}
ab.Mailer = LogMailer(output)
r := httptest.NewRequest("GET", "/", nil)
wr := httptest.NewRecorder()
w := ab.NewResponse(wr, r)
email := Email{
To: []string{"test@example.com"},
From: "test@example.com",
Subject: "subject",
}
ro := EmailResponseOptions{Data: nil, HTMLTemplate: "html", TextTemplate: "text"}
err := ab.Email(w, r, email, ro)
if err != nil {
t.Error(err)
}
wantStrings := []string{
"To: test@example.com",
"From: test@example.com",
"Subject: subject",
"development text e-mail",
"development html e-mail",
}
out := output.String()
for i, test := range wantStrings {
if !strings.Contains(out, test) {
t.Errorf("output missing string(%d): %s\n%s", i, test, out)
}
}
}
func TestResponseRedirectAPI(t *testing.T) {
t.Parallel()
ab := New()
ab.renderer = mockRenderer{}
r := httptest.NewRequest("POST", "/?redir=/pow", nil)
w := httptest.NewRecorder()
r.Header.Set("Content-Type", "application/json")
ro := RedirectOptions{
Success: "ok!",
Code: http.StatusTeapot,
RedirectPath: "/redirect", FollowRedirParam: false,
}
if err := ab.Redirect(w, r, ro); err != nil {
t.Error(err)
}
if w.Code != http.StatusTeapot {
t.Error("code is wrong:", w.Code)
}
var gotData map[string]string
if err := json.Unmarshal(w.Body.Bytes(), &gotData); err != nil {
t.Error(err)
}
if got := gotData["status"]; got != "success" {
t.Error("status was wrong:", got)
}
if got := gotData["message"]; got != "ok!" {
t.Error("message was wrong:", got)
}
if got := gotData["location"]; got != "/redirect" {
t.Error("location was wrong:", got)
}
}
func TestResponseRedirectAPIFollowRedir(t *testing.T) {
t.Parallel()
ab := New()
ab.renderer = mockRenderer{}
r := httptest.NewRequest("POST", "/?redir=/pow", nil)
w := httptest.NewRecorder()
r.Header.Set("Content-Type", "application/json")
ro := RedirectOptions{
Failure: ":(",
Code: http.StatusTeapot,
RedirectPath: "/redirect", FollowRedirParam: true,
}
if err := ab.Redirect(w, r, ro); err != nil {
t.Error(err)
}
if w.Code != http.StatusTeapot {
t.Error("code is wrong:", w.Code)
}
var gotData map[string]string
if err := json.Unmarshal(w.Body.Bytes(), &gotData); err != nil {
t.Error(err)
}
if got := gotData["status"]; got != "failure" {
t.Error("status was wrong:", got)
}
if got := gotData["message"]; got != ":(" {
t.Error("message was wrong:", got)
}
if got := gotData["location"]; got != "/pow" {
t.Error("location was wrong:", got)
}
}
func TestResponseRedirectNonAPI(t *testing.T) {
t.Parallel()
ab := New()
r := httptest.NewRequest("POST", "/?redir=/pow", nil)
wr := httptest.NewRecorder()
w := ab.NewResponse(wr, r)
ro := RedirectOptions{
Success: "success", Failure: "failure",
RedirectPath: "/redirect", FollowRedirParam: false,
}
if err := ab.Redirect(w, r, ro); err != nil {
t.Error(err)
}
csrw := w.(*ClientStateResponseWriter)
want := ClientStateEvent{Kind: ClientStateEventPut, Key: FlashSuccessKey, Value: "success"}
if csrw.sessionStateEvents[0] != want {
t.Error("event was wrong:", csrw.sessionStateEvents[0])
}
want = ClientStateEvent{Kind: ClientStateEventPut, Key: FlashErrorKey, Value: "failure"}
if csrw.sessionStateEvents[1] != want {
t.Error("event was wrong:", csrw.sessionStateEvents[1])
}
if wr.Code != http.StatusFound {
t.Error("code is wrong:", wr.Code)
}
if got := wr.Header().Get("Location"); got != "/redirect" {
t.Error("redirect location was wrong:", got)
}
}
func TestResponseRedirectNonAPIFollowRedir(t *testing.T) {
t.Parallel()
ab := New()
r := httptest.NewRequest("POST", "/?redir=/pow", nil)
wr := httptest.NewRecorder()
w := ab.NewResponse(wr, r)
ro := RedirectOptions{
RedirectPath: "/redirect", FollowRedirParam: true,
}
if err := ab.Redirect(w, r, ro); err != nil {
t.Error(err)
}
csrw := w.(*ClientStateResponseWriter)
if len(csrw.sessionStateEvents) != 0 {
t.Error("session state events should be empty:", csrw.sessionStateEvents)
}
if wr.Code != http.StatusFound {
t.Error("code is wrong:", wr.Code)
}
if got := wr.Header().Get("Location"); got != "/pow" {
t.Error("redirect location was wrong:", got)
}
}