package auth import ( "bytes" "errors" "html/template" "io/ioutil" "net/http" "net/http/httptest" "strings" "testing" "gopkg.in/authboss.v0" "gopkg.in/authboss.v0/internal/mocks" ) func testSetup() (a *Auth, s *mocks.MockStorer) { s = mocks.NewMockStorer() ab := authboss.New() ab.LogWriter = ioutil.Discard ab.Layout = template.Must(template.New("").Parse(`{{template "authboss" .}}`)) ab.Storer = s ab.XSRFName = "xsrf" ab.XSRFMaker = func(_ http.ResponseWriter, _ *http.Request) string { return "xsrfvalue" } ab.PrimaryID = authboss.StoreUsername a = &Auth{} if err := a.Initialize(ab); err != nil { panic(err) } return a, s } func testRequest(ab *authboss.Authboss, method string, postFormValues ...string) (*authboss.Context, *httptest.ResponseRecorder, *http.Request, authboss.ClientStorerErr) { r, err := http.NewRequest(method, "", nil) if err != nil { panic(err) } sessionStorer := mocks.NewMockClientStorer() ctx := mocks.MockRequestContext(ab, postFormValues...) ctx.SessionStorer = sessionStorer return ctx, httptest.NewRecorder(), r, sessionStorer } func TestAuth(t *testing.T) { t.Parallel() a, _ := testSetup() storage := a.Storage() if storage[a.PrimaryID] != authboss.String { t.Error("Expected storage KV:", a.PrimaryID, authboss.String) } if storage[authboss.StorePassword] != authboss.String { t.Error("Expected storage KV:", authboss.StorePassword, authboss.String) } routes := a.Routes() if routes["/login"] == nil { t.Error("Expected route '/login' with handleFunc") } if routes["/logout"] == nil { t.Error("Expected route '/logout' with handleFunc") } } func TestAuth_loginHandlerFunc_GET(t *testing.T) { t.Parallel() a, _ := testSetup() ctx, w, r, _ := testRequest(a.Authboss, "GET") if err := a.loginHandlerFunc(ctx, w, r); err != nil { t.Error("Unexpected error:", err) } if w.Code != http.StatusOK { t.Error("Unexpected status:", w.Code) } body := w.Body.String() if !strings.Contains(body, " Unexpected error '%s'", i, err) } w := httptest.NewRecorder() if err := a.loginHandlerFunc(nil, w, r); err != nil { t.Errorf("%d> Unexpected error: %s", i, err) } if http.StatusMethodNotAllowed != w.Code { t.Errorf("%d> Expected status code %d, got %d", i, http.StatusMethodNotAllowed, w.Code) continue } } } func TestAuth_validateCredentials(t *testing.T) { t.Parallel() ab := authboss.New() storer := mocks.NewMockStorer() ab.Storer = storer ctx := ab.NewContext() storer.Users["john"] = authboss.Attributes{"password": "$2a$10$pgFsuQwdhwOdZp/v52dvHeEi53ZaI7dGmtwK4bAzGGN5A4nT6doqm"} if _, err := validateCredentials(ctx, "john", "a"); err != nil { t.Error("Unexpected error:", err) } ctx = ab.NewContext() if valid, err := validateCredentials(ctx, "jake", "a"); err != nil { t.Error("Expect no error when user not found:", err) } else if valid { t.Error("Expect invalid when not user found") } ctx = ab.NewContext() storer.GetErr = "Failed to load user" if _, err := validateCredentials(ctx, "", ""); err.Error() != "Failed to load user" { t.Error("Unexpected error:", err) } } func TestAuth_logoutHandlerFunc_GET(t *testing.T) { t.Parallel() a, _ := testSetup() a.AuthLogoutOKPath = "/dashboard" ctx, w, r, sessionStorer := testRequest(a.Authboss, "GET") sessionStorer.Put(authboss.SessionKey, "asdf") sessionStorer.Put(authboss.SessionLastAction, "1234") cookieStorer := mocks.NewMockClientStorer(authboss.CookieRemember, "qwert") ctx.CookieStorer = cookieStorer if err := a.logoutHandlerFunc(ctx, w, r); err != nil { t.Error("Unexpected error:", err) } if val, ok := sessionStorer.Get(authboss.SessionKey); ok { t.Error("Unexpected session key:", val) } if val, ok := sessionStorer.Get(authboss.SessionLastAction); ok { t.Error("Unexpected last action:", val) } if val, ok := cookieStorer.Get(authboss.CookieRemember); ok { t.Error("Unexpected rm cookie:", val) } if http.StatusFound != w.Code { t.Errorf("Expected status code %d, got %d", http.StatusFound, w.Code) } location := w.Header().Get("Location") if location != "/dashboard" { t.Errorf("Expected lcoation %s, got %s", "/dashboard", location) } } func TestAuth_logoutHandlerFunc_OtherMethods(t *testing.T) { a, _ := testSetup() methods := []string{"HEAD", "POST", "PUT", "DELETE", "TRACE", "CONNECT"} for i, method := range methods { r, err := http.NewRequest(method, "/logout", nil) if err != nil { t.Errorf("%d> Unexpected error '%s'", i, err) } w := httptest.NewRecorder() if err := a.logoutHandlerFunc(nil, w, r); err != nil { t.Errorf("%d> Unexpected error: %s", i, err) } if http.StatusMethodNotAllowed != w.Code { t.Errorf("%d> Expected status code %d, got %d", i, http.StatusMethodNotAllowed, w.Code) continue } } }