mirror of
https://github.com/volatiletech/authboss.git
synced 2024-11-28 08:58:38 +02:00
f65d9f6bb6
- Fix many compilation errors
335 lines
8.4 KiB
Go
335 lines
8.4 KiB
Go
package authboss
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"io/ioutil"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/pkg/errors"
|
|
)
|
|
|
|
// testRouterModName is the name under which the test router module is
// registered and later loaded by the router tests.
const testRouterModName = "testrouter"
|
|
|
|
func init() {
|
|
RegisterModule(testRouterModName, testRouterModule{})
|
|
}
|
|
|
|
type testRouterModule struct {
|
|
routes RouteTable
|
|
}
|
|
|
|
func (t testRouterModule) Initialize(ab *Authboss) error { return nil }
|
|
func (t testRouterModule) Routes() RouteTable { return t.routes }
|
|
|
|
func testRouterSetup() (*Authboss, http.Handler, *bytes.Buffer) {
|
|
ab := New()
|
|
logger := &bytes.Buffer{}
|
|
ab.LogWriter = logger
|
|
ab.Init(testRouterModName)
|
|
ab.MountPath = "/prefix"
|
|
ab.SessionStoreMaker = func(w http.ResponseWriter, r *http.Request) ClientStorer { return mockClientStore{} }
|
|
ab.CookieStoreMaker = func(w http.ResponseWriter, r *http.Request) ClientStorer { return mockClientStore{} }
|
|
|
|
logger.Reset() // Clear out the module load messages
|
|
|
|
return ab, ab.NewRouter(), logger
|
|
}
|
|
|
|
// testRouterCallbackSetup is NOT safe for use by multiple goroutines, don't use parallel
|
|
func testRouterCallbackSetup(path string, h HandlerFunc) (w *httptest.ResponseRecorder, r *http.Request) {
|
|
registeredModules[testRouterModName] = testRouterModule{
|
|
routes: map[string]HandlerFunc{path: h},
|
|
}
|
|
|
|
w = httptest.NewRecorder()
|
|
r, _ = http.NewRequest("GET", "http://localhost/prefix"+path, nil)
|
|
|
|
return w, r
|
|
}
|
|
|
|
func TestRouter(t *testing.T) {
|
|
called := false
|
|
|
|
w, r := testRouterCallbackSetup("/called", func(http.ResponseWriter, *http.Request) error {
|
|
called = true
|
|
return nil
|
|
})
|
|
|
|
_, router, _ := testRouterSetup()
|
|
|
|
router.ServeHTTP(w, r)
|
|
|
|
if !called {
|
|
t.Error("Expected handler to be called.")
|
|
}
|
|
}
|
|
|
|
func TestRouter_NotFound(t *testing.T) {
|
|
ab, router, _ := testRouterSetup()
|
|
w := httptest.NewRecorder()
|
|
r, _ := http.NewRequest("GET", "http://localhost/wat", nil)
|
|
|
|
router.ServeHTTP(w, r)
|
|
if w.Code != http.StatusNotFound {
|
|
t.Error("Wrong code:", w.Code)
|
|
}
|
|
if body := w.Body.String(); body != "404 Page not found" {
|
|
t.Error("Wrong body:", body)
|
|
}
|
|
|
|
called := false
|
|
ab.NotFoundHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
called = true
|
|
})
|
|
|
|
router.ServeHTTP(w, r)
|
|
if !called {
|
|
t.Error("Should be called.")
|
|
}
|
|
}
|
|
|
|
func TestRouter_BadRequest(t *testing.T) {
|
|
err := ClientDataErr{"what"}
|
|
w, r := testRouterCallbackSetup("/badrequest",
|
|
func(http.ResponseWriter, *http.Request) error {
|
|
return err
|
|
},
|
|
)
|
|
|
|
ab, router, logger := testRouterSetup()
|
|
logger.Reset()
|
|
router.ServeHTTP(w, r)
|
|
|
|
if w.Code != http.StatusBadRequest {
|
|
t.Error("Wrong code:", w.Code)
|
|
}
|
|
if body := w.Body.String(); body != "400 Bad request" {
|
|
t.Error("Wrong body:", body)
|
|
}
|
|
|
|
if str := logger.String(); !strings.Contains(str, err.Error()) {
|
|
t.Error(str)
|
|
}
|
|
|
|
called := false
|
|
ab.BadRequestHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
called = true
|
|
})
|
|
|
|
logger.Reset()
|
|
router.ServeHTTP(w, r)
|
|
if !called {
|
|
t.Error("Should be called.")
|
|
}
|
|
|
|
if str := logger.String(); !strings.Contains(str, err.Error()) {
|
|
t.Error(str)
|
|
}
|
|
}
|
|
|
|
func TestRouter_Error(t *testing.T) {
|
|
err := errors.New("error")
|
|
w, r := testRouterCallbackSetup("/error",
|
|
func(http.ResponseWriter, *http.Request) error {
|
|
return err
|
|
},
|
|
)
|
|
|
|
ab, router, logger := testRouterSetup()
|
|
logger.Reset()
|
|
router.ServeHTTP(w, r)
|
|
|
|
if w.Code != http.StatusInternalServerError {
|
|
t.Error("Wrong code:", w.Code)
|
|
}
|
|
if body := w.Body.String(); body != "500 An error has occurred" {
|
|
t.Error("Wrong body:", body)
|
|
}
|
|
|
|
if str := logger.String(); !strings.Contains(str, err.Error()) {
|
|
t.Error(str)
|
|
}
|
|
|
|
called := false
|
|
ab.ErrorHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
called = true
|
|
})
|
|
|
|
logger.Reset()
|
|
router.ServeHTTP(w, r)
|
|
if !called {
|
|
t.Error("Should be called.")
|
|
}
|
|
|
|
if str := logger.String(); !strings.Contains(str, err.Error()) {
|
|
t.Error(str)
|
|
}
|
|
}
|
|
|
|
func TestRouter_Redirect(t *testing.T) {
|
|
err := ErrAndRedirect{
|
|
Err: errors.New("error"),
|
|
Location: "/",
|
|
FlashSuccess: "yay",
|
|
FlashError: "nay",
|
|
}
|
|
|
|
w, r := testRouterCallbackSetup("/error",
|
|
func(http.ResponseWriter, *http.Request) error {
|
|
return err
|
|
},
|
|
)
|
|
|
|
ab, router, logger := testRouterSetup()
|
|
|
|
session := mockClientStore{}
|
|
ab.SessionStoreMaker = func(w http.ResponseWriter, r *http.Request) ClientStorer { return session }
|
|
|
|
logger.Reset()
|
|
router.ServeHTTP(w, r)
|
|
|
|
if w.Code != http.StatusFound {
|
|
t.Error("Wrong code:", w.Code)
|
|
}
|
|
if loc := w.Header().Get("Location"); loc != err.Location {
|
|
t.Error("Wrong location:", loc)
|
|
}
|
|
if succ, ok := session.Get(FlashSuccessKey); !ok || succ != err.FlashSuccess {
|
|
t.Error(succ, ok)
|
|
}
|
|
if fail, ok := session.Get(FlashErrorKey); !ok || fail != err.FlashError {
|
|
t.Error(fail, ok)
|
|
}
|
|
}
|
|
|
|
func TestRouter_redirectIfLoggedIn(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
Path string
|
|
LoggedIn bool
|
|
HalfAuthed bool
|
|
|
|
ShouldRedirect bool
|
|
}{
|
|
// These routes will be accessed depending on logged in and half auth's value
|
|
{"/auth", false, false, false},
|
|
{"/auth", true, false, true},
|
|
{"/auth", true, true, false},
|
|
{"/oauth2/facebook", false, false, false},
|
|
{"/oauth2/facebook", true, false, true},
|
|
{"/oauth2/facebook", true, true, false},
|
|
{"/oauth2/callback/facebook", false, false, false},
|
|
{"/oauth2/callback/facebook", true, false, true},
|
|
{"/oauth2/callback/facebook", true, true, false},
|
|
// These are logout routes and never redirect
|
|
{"/logout", true, false, false},
|
|
{"/logout", true, true, false},
|
|
{"/oauth2/logout", true, false, false},
|
|
{"/oauth2/logout", true, true, false},
|
|
// These routes should always redirect despite half auth
|
|
{"/register", true, true, true},
|
|
{"/recover", true, true, true},
|
|
{"/register", false, false, false},
|
|
{"/recover", false, false, false},
|
|
}
|
|
|
|
storer := mockStoreLoader{"john@john.com": mockUser{
|
|
Email: "john@john.com",
|
|
Password: "password",
|
|
}}
|
|
ab := New()
|
|
ab.StoreLoader = storer
|
|
|
|
for i, test := range tests {
|
|
session := mockClientStore{}
|
|
cookies := mockClientStore{}
|
|
ctx := context.TODO()
|
|
ctx.SessionStorer = session
|
|
ctx.CookieStorer = cookies
|
|
|
|
if test.LoggedIn {
|
|
session[SessionKey] = "john@john.com"
|
|
}
|
|
if test.HalfAuthed {
|
|
session[SessionHalfAuthKey] = "true"
|
|
}
|
|
|
|
r, _ := http.NewRequest("GET", test.Path, nil)
|
|
w := httptest.NewRecorder()
|
|
handled := redirectIfLoggedIn(ctx, w, r)
|
|
|
|
if test.ShouldRedirect && (!handled || w.Code != http.StatusFound) {
|
|
t.Errorf("%d) It should have redirected the request: %q %t %d", i, test.Path, handled, w.Code)
|
|
} else if !test.ShouldRedirect && (handled || w.Code != http.StatusOK) {
|
|
t.Errorf("%d) It should have NOT redirected the request: %q %t %d", i, test.Path, handled, w.Code)
|
|
}
|
|
}
|
|
}
|
|
|
|
type deathStorer struct{}
|
|
|
|
func (d deathStorer) Create(key string, attributes Attributes) error { return nil }
|
|
func (d deathStorer) Put(key string, attributes Attributes) error { return nil }
|
|
func (d deathStorer) Get(key string) (interface{}, error) { return nil, errors.New("explosion") }
|
|
|
|
func TestRouter_redirectIfLoggedInError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ab := New()
|
|
ab.LogWriter = ioutil.Discard
|
|
ab.Storer = deathStorer{}
|
|
|
|
session := mockClientStore{SessionKey: "john"}
|
|
cookies := mockClientStore{}
|
|
ctx := context.TODO()
|
|
ctx.SessionStorer = session
|
|
ctx.CookieStorer = cookies
|
|
|
|
r, _ := http.NewRequest("GET", "/auth", nil)
|
|
w := httptest.NewRecorder()
|
|
handled := redirectIfLoggedIn(ctx, w, r)
|
|
|
|
if !handled {
|
|
t.Error("It should have been handled.")
|
|
}
|
|
if w.Code != http.StatusInternalServerError {
|
|
t.Error("It should have internal server error'd:", w.Code)
|
|
}
|
|
}
|
|
|
|
type notFoundStorer struct{}
|
|
|
|
func (n notFoundStorer) Create(key string, attributes Attributes) error { return nil }
|
|
func (n notFoundStorer) Put(key string, attributes Attributes) error { return nil }
|
|
func (n notFoundStorer) Get(key string) (interface{}, error) { return nil, ErrUserNotFound }
|
|
|
|
func TestRouter_redirectIfLoggedInUserNotFound(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ab := New()
|
|
ab.LogWriter = ioutil.Discard
|
|
ab.Storer = notFoundStorer{}
|
|
|
|
session := mockClientStore{SessionKey: "john"}
|
|
cookies := mockClientStore{}
|
|
ctx := context.TODO()
|
|
ctx.SessionStorer = session
|
|
ctx.CookieStorer = cookies
|
|
|
|
r, _ := http.NewRequest("GET", "/auth", nil)
|
|
w := httptest.NewRecorder()
|
|
handled := redirectIfLoggedIn(ctx, w, r)
|
|
|
|
if handled {
|
|
t.Error("It should not have been handled.")
|
|
}
|
|
if _, ok := session.Get(SessionKey); ok {
|
|
t.Error("It should have removed the bad session cookie")
|
|
}
|
|
}
|