2015-01-11 08:52:39 +02:00
|
|
|
package remember
|
|
|
|
|
|
|
|
import (
|
2015-01-15 12:56:13 +02:00
|
|
|
"bytes"
|
2018-03-08 01:13:06 +02:00
|
|
|
"context"
|
|
|
|
"crypto/sha512"
|
|
|
|
"encoding/base64"
|
2015-01-15 12:56:13 +02:00
|
|
|
"net/http"
|
2018-03-08 01:13:06 +02:00
|
|
|
"net/http/httptest"
|
2015-01-11 08:52:39 +02:00
|
|
|
"testing"
|
|
|
|
|
2017-07-31 04:39:33 +02:00
|
|
|
"github.com/volatiletech/authboss"
|
|
|
|
"github.com/volatiletech/authboss/internal/mocks"
|
2015-01-11 08:52:39 +02:00
|
|
|
)
|
|
|
|
|
2018-03-08 01:13:06 +02:00
|
|
|
func TestInit(t *testing.T) {
|
2015-04-01 00:27:47 +02:00
|
|
|
t.Parallel()
|
2015-01-13 00:02:07 +02:00
|
|
|
|
2015-04-01 00:27:47 +02:00
|
|
|
ab := authboss.New()
|
2015-01-13 00:02:07 +02:00
|
|
|
r := &Remember{}
|
2018-03-08 01:13:06 +02:00
|
|
|
err := r.Init(ab)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatal(err)
|
2015-01-13 00:02:07 +02:00
|
|
|
}
|
2018-03-08 01:13:06 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
type testHarness struct {
|
|
|
|
remember *Remember
|
|
|
|
ab *authboss.Authboss
|
|
|
|
|
|
|
|
session *mocks.ClientStateRW
|
|
|
|
cookies *mocks.ClientStateRW
|
|
|
|
storer *mocks.ServerStorer
|
|
|
|
}
|
|
|
|
|
|
|
|
func testSetup() *testHarness {
|
|
|
|
harness := &testHarness{}
|
|
|
|
|
|
|
|
harness.ab = authboss.New()
|
|
|
|
harness.session = mocks.NewClientRW()
|
|
|
|
harness.cookies = mocks.NewClientRW()
|
|
|
|
harness.storer = mocks.NewServerStorer()
|
|
|
|
|
|
|
|
harness.ab.Config.Core.Logger = mocks.Logger{}
|
|
|
|
harness.ab.Config.Storage.SessionState = harness.session
|
|
|
|
harness.ab.Config.Storage.CookieState = harness.cookies
|
|
|
|
harness.ab.Config.Storage.Server = harness.storer
|
|
|
|
|
|
|
|
harness.remember = &Remember{harness.ab}
|
2015-01-13 00:02:07 +02:00
|
|
|
|
2018-03-08 01:13:06 +02:00
|
|
|
return harness
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestRememberAfterAuth(t *testing.T) {
|
|
|
|
t.Parallel()
|
|
|
|
|
|
|
|
h := testSetup()
|
|
|
|
|
|
|
|
user := &mocks.User{Email: "test@test.com"}
|
|
|
|
|
|
|
|
r := mocks.Request("POST")
|
|
|
|
r = r.WithContext(context.WithValue(r.Context(), authboss.CTXKeyRM, true))
|
|
|
|
r = r.WithContext(context.WithValue(r.Context(), authboss.CTXKeyUser, user))
|
|
|
|
rec := httptest.NewRecorder()
|
|
|
|
w := h.ab.NewResponse(rec, r)
|
|
|
|
|
|
|
|
if handled, err := h.remember.RememberAfterAuth(w, r, false); err != nil {
|
|
|
|
t.Fatal(err)
|
|
|
|
} else if handled {
|
|
|
|
t.Error("should never be handled")
|
2015-01-13 00:02:07 +02:00
|
|
|
}
|
|
|
|
|
2018-03-08 01:13:06 +02:00
|
|
|
// Force flush of headers so cookies are written
|
|
|
|
w.WriteHeader(http.StatusOK)
|
|
|
|
|
|
|
|
if len(h.storer.RMTokens["test@test.com"]) != 1 {
|
|
|
|
t.Error("token was not persisted:", h.storer.RMTokens)
|
|
|
|
}
|
|
|
|
|
|
|
|
if cookie, ok := h.cookies.ClientValues[authboss.CookieRemember]; !ok || len(cookie) == 0 {
|
|
|
|
t.Error("remember me cookie was not set")
|
2015-01-13 00:02:07 +02:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2018-03-08 01:13:06 +02:00
|
|
|
func TestRememberAfterAuthSkip(t *testing.T) {
|
2015-04-01 00:27:47 +02:00
|
|
|
t.Parallel()
|
|
|
|
|
2018-03-08 01:13:06 +02:00
|
|
|
h := testSetup()
|
2015-02-16 06:07:36 +02:00
|
|
|
|
2018-03-08 01:13:06 +02:00
|
|
|
r := mocks.Request("POST")
|
|
|
|
rec := httptest.NewRecorder()
|
|
|
|
w := h.ab.NewResponse(rec, r)
|
2015-01-15 12:56:13 +02:00
|
|
|
|
2018-03-08 01:13:06 +02:00
|
|
|
if handled, err := h.remember.RememberAfterAuth(w, r, false); err != nil {
|
|
|
|
t.Fatal(err)
|
|
|
|
} else if handled {
|
|
|
|
t.Error("should never be handled")
|
2015-01-15 12:56:13 +02:00
|
|
|
}
|
|
|
|
|
2018-03-08 01:13:06 +02:00
|
|
|
if len(h.storer.RMTokens["test@test.com"]) != 0 {
|
|
|
|
t.Error("expected no tokens to be created")
|
|
|
|
}
|
2015-01-15 12:56:13 +02:00
|
|
|
|
2018-03-08 01:13:06 +02:00
|
|
|
r = r.WithContext(context.WithValue(r.Context(), authboss.CTXKeyRM, false))
|
2015-08-02 23:02:14 +02:00
|
|
|
|
2018-03-08 01:13:06 +02:00
|
|
|
if handled, err := h.remember.RememberAfterAuth(w, r, false); err != nil {
|
|
|
|
t.Fatal(err)
|
|
|
|
} else if handled {
|
|
|
|
t.Error("should never be handled")
|
2015-02-22 22:55:09 +02:00
|
|
|
}
|
2015-01-15 12:56:13 +02:00
|
|
|
|
2018-03-08 01:13:06 +02:00
|
|
|
if len(h.storer.RMTokens["test@test.com"]) != 0 {
|
|
|
|
t.Error("expected no tokens to be created")
|
2015-01-15 12:56:13 +02:00
|
|
|
}
|
2015-01-13 00:02:07 +02:00
|
|
|
}
|
|
|
|
|
2018-03-08 01:13:06 +02:00
|
|
|
func TestMiddlewareAuth(t *testing.T) {
|
2015-04-01 00:27:47 +02:00
|
|
|
t.Parallel()
|
|
|
|
|
2018-03-08 01:13:06 +02:00
|
|
|
h := testSetup()
|
|
|
|
|
|
|
|
user := &mocks.User{Email: "test@test.com"}
|
|
|
|
hash, token, _ := GenerateToken(user.Email)
|
2015-03-14 01:23:43 +02:00
|
|
|
|
2018-03-08 01:13:06 +02:00
|
|
|
h.storer.Users[user.Email] = user
|
|
|
|
h.storer.RMTokens[user.Email] = []string{hash}
|
|
|
|
h.cookies.ClientValues[authboss.CookieRemember] = token
|
2015-03-14 01:23:43 +02:00
|
|
|
|
2018-03-08 01:13:06 +02:00
|
|
|
r := mocks.Request("POST")
|
|
|
|
rec := httptest.NewRecorder()
|
|
|
|
w := h.ab.NewResponse(rec, r)
|
|
|
|
|
|
|
|
var err error
|
|
|
|
r, err = h.ab.LoadClientState(w, r)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatal(err)
|
2015-03-14 01:23:43 +02:00
|
|
|
}
|
|
|
|
|
2018-03-08 01:13:06 +02:00
|
|
|
called := false
|
|
|
|
server := Middleware(h.ab)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
|
|
called = true
|
|
|
|
w.WriteHeader(http.StatusOK)
|
|
|
|
}))
|
|
|
|
|
|
|
|
server.ServeHTTP(w, r)
|
|
|
|
|
|
|
|
if !called {
|
|
|
|
t.Error("it should have called the underlying handler")
|
2015-03-14 01:23:43 +02:00
|
|
|
}
|
|
|
|
|
2018-03-08 01:13:06 +02:00
|
|
|
if h.session.ClientValues[authboss.SessionKey] != user.Email {
|
|
|
|
t.Error("should have saved the pid in the session")
|
2015-03-14 01:23:43 +02:00
|
|
|
}
|
2018-03-08 01:13:06 +02:00
|
|
|
// Elided the rest of the checks, authenticate tests do this
|
2015-03-14 01:23:43 +02:00
|
|
|
}
|
|
|
|
|
2018-03-08 01:13:06 +02:00
|
|
|
func TestAuthenticateSuccess(t *testing.T) {
|
2015-04-01 00:27:47 +02:00
|
|
|
t.Parallel()
|
|
|
|
|
2018-03-08 01:13:06 +02:00
|
|
|
h := testSetup()
|
2015-03-06 06:05:47 +02:00
|
|
|
|
2018-03-08 01:13:06 +02:00
|
|
|
user := &mocks.User{Email: "test@test.com"}
|
|
|
|
hash, token, _ := GenerateToken(user.Email)
|
2015-03-06 06:05:47 +02:00
|
|
|
|
2018-03-08 01:13:06 +02:00
|
|
|
h.storer.Users[user.Email] = user
|
|
|
|
h.storer.RMTokens[user.Email] = []string{hash}
|
|
|
|
h.cookies.ClientValues[authboss.CookieRemember] = token
|
2015-03-06 06:05:47 +02:00
|
|
|
|
2018-03-08 01:13:06 +02:00
|
|
|
r := mocks.Request("POST")
|
|
|
|
rec := httptest.NewRecorder()
|
|
|
|
w := h.ab.NewResponse(rec, r)
|
2015-03-06 06:05:47 +02:00
|
|
|
|
2018-03-08 01:13:06 +02:00
|
|
|
var err error
|
|
|
|
r, err = h.ab.LoadClientState(w, r)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatal(err)
|
2015-03-06 06:05:47 +02:00
|
|
|
}
|
|
|
|
|
2018-03-08 01:13:06 +02:00
|
|
|
if err = Authenticate(h.ab, w, r); err != nil {
|
|
|
|
t.Fatal(err)
|
2015-03-06 06:05:47 +02:00
|
|
|
}
|
|
|
|
|
2018-03-08 01:13:06 +02:00
|
|
|
w.WriteHeader(http.StatusOK)
|
|
|
|
|
|
|
|
if cookie := h.cookies.ClientValues[authboss.CookieRemember]; cookie == token {
|
|
|
|
t.Error("the cookie should have been replaced with a new token")
|
|
|
|
}
|
|
|
|
|
|
|
|
if len(h.storer.RMTokens[user.Email]) != 1 {
|
|
|
|
t.Error("one token should have been removed, and one should have been added")
|
|
|
|
} else if h.storer.RMTokens[user.Email][0] == token {
|
|
|
|
t.Error("a new token should have been saved")
|
|
|
|
}
|
|
|
|
|
|
|
|
if h.session.ClientValues[authboss.SessionKey] != user.Email {
|
|
|
|
t.Error("should have saved the pid in the session")
|
|
|
|
}
|
|
|
|
if h.session.ClientValues[authboss.SessionHalfAuthKey] != "true" {
|
|
|
|
t.Error("it should have become a half-authed session")
|
2015-03-06 06:05:47 +02:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2018-03-08 01:13:06 +02:00
|
|
|
func TestAuthenticateTokenNotFound(t *testing.T) {
|
2015-04-01 00:27:47 +02:00
|
|
|
t.Parallel()
|
|
|
|
|
2018-03-08 01:13:06 +02:00
|
|
|
h := testSetup()
|
|
|
|
|
|
|
|
user := &mocks.User{Email: "test@test.com"}
|
|
|
|
_, token, _ := GenerateToken(user.Email)
|
2015-01-13 00:02:07 +02:00
|
|
|
|
2018-03-08 01:13:06 +02:00
|
|
|
h.storer.Users[user.Email] = user
|
|
|
|
h.cookies.ClientValues[authboss.CookieRemember] = token
|
2015-01-13 00:02:07 +02:00
|
|
|
|
2018-03-08 01:13:06 +02:00
|
|
|
r := mocks.Request("POST")
|
|
|
|
rec := httptest.NewRecorder()
|
|
|
|
w := h.ab.NewResponse(rec, r)
|
|
|
|
|
|
|
|
var err error
|
|
|
|
r, err = h.ab.LoadClientState(w, r)
|
2015-01-13 00:02:07 +02:00
|
|
|
if err != nil {
|
2018-03-08 01:13:06 +02:00
|
|
|
t.Fatal(err)
|
2015-01-13 00:02:07 +02:00
|
|
|
}
|
|
|
|
|
2018-03-08 01:13:06 +02:00
|
|
|
if err = Authenticate(h.ab, w, r); err != nil {
|
|
|
|
t.Fatal(err)
|
2015-01-13 00:02:07 +02:00
|
|
|
}
|
|
|
|
|
2018-03-08 01:13:06 +02:00
|
|
|
w.WriteHeader(http.StatusOK)
|
|
|
|
|
|
|
|
if len(h.cookies.ClientValues[authboss.CookieRemember]) != 0 {
|
|
|
|
t.Error("there should be no remember cookie left")
|
2015-01-13 00:02:07 +02:00
|
|
|
}
|
|
|
|
|
2018-03-08 01:13:06 +02:00
|
|
|
if len(h.session.ClientValues[authboss.SessionKey]) != 0 {
|
|
|
|
t.Error("it should have not logged the user in")
|
2015-01-13 00:02:07 +02:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2018-03-08 01:13:06 +02:00
|
|
|
func TestAuthenticateBadTokens(t *testing.T) {
|
2015-04-01 00:27:47 +02:00
|
|
|
t.Parallel()
|
|
|
|
|
2018-03-08 01:13:06 +02:00
|
|
|
h := testSetup()
|
2015-02-27 09:09:37 +02:00
|
|
|
|
2018-03-08 01:13:06 +02:00
|
|
|
doTest := func(t *testing.T) {
|
|
|
|
t.Helper()
|
2015-01-13 00:02:07 +02:00
|
|
|
|
2018-03-08 01:13:06 +02:00
|
|
|
r := mocks.Request("POST")
|
|
|
|
rec := httptest.NewRecorder()
|
|
|
|
w := h.ab.NewResponse(rec, r)
|
|
|
|
|
|
|
|
var err error
|
|
|
|
r, err = h.ab.LoadClientState(w, r)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatal(err)
|
|
|
|
}
|
|
|
|
|
|
|
|
if err = Authenticate(h.ab, w, r); err != nil {
|
|
|
|
t.Fatal(err)
|
|
|
|
}
|
|
|
|
|
|
|
|
w.WriteHeader(http.StatusOK)
|
|
|
|
|
|
|
|
if len(h.cookies.ClientValues[authboss.CookieRemember]) != 0 {
|
|
|
|
t.Error("there should be no remember cookie left")
|
|
|
|
}
|
|
|
|
|
|
|
|
if len(h.session.ClientValues[authboss.SessionKey]) != 0 {
|
|
|
|
t.Error("it should have not logged the user in")
|
|
|
|
}
|
2015-01-13 00:02:07 +02:00
|
|
|
}
|
|
|
|
|
2018-03-08 01:13:06 +02:00
|
|
|
t.Run("base64", func(t *testing.T) {
|
|
|
|
h.cookies.ClientValues[authboss.CookieRemember] = "a"
|
|
|
|
doTest(t)
|
|
|
|
})
|
|
|
|
t.Run("cookieformat", func(t *testing.T) {
|
|
|
|
h.cookies.ClientValues[authboss.CookieRemember] = `aGVsbG8=` // hello
|
|
|
|
doTest(t)
|
|
|
|
})
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestAfterPasswordReset(t *testing.T) {
|
|
|
|
t.Parallel()
|
2015-03-02 06:40:09 +02:00
|
|
|
|
2018-03-08 01:13:06 +02:00
|
|
|
h := testSetup()
|
|
|
|
|
|
|
|
user := &mocks.User{Email: "test@test.com"}
|
|
|
|
hash1, _, _ := GenerateToken(user.Email)
|
|
|
|
hash2, token2, _ := GenerateToken(user.Email)
|
|
|
|
|
|
|
|
h.storer.Users[user.Email] = user
|
|
|
|
h.storer.RMTokens[user.Email] = []string{hash1, hash2}
|
|
|
|
h.cookies.ClientValues[authboss.CookieRemember] = token2
|
|
|
|
|
|
|
|
r := mocks.Request("POST")
|
|
|
|
r = r.WithContext(context.WithValue(r.Context(), authboss.CTXKeyUser, user))
|
|
|
|
rec := httptest.NewRecorder()
|
|
|
|
w := h.ab.NewResponse(rec, r)
|
|
|
|
|
|
|
|
if handled, err := h.remember.AfterPasswordReset(w, r, false); err != nil {
|
|
|
|
t.Error(err)
|
|
|
|
} else if handled {
|
|
|
|
t.Error("it should never be handled")
|
|
|
|
}
|
|
|
|
|
|
|
|
w.WriteHeader(http.StatusOK) // Force header flush
|
|
|
|
|
|
|
|
if len(h.storer.RMTokens[user.Email]) != 0 {
|
|
|
|
t.Error("all remember me tokens should have been removed")
|
|
|
|
}
|
|
|
|
if len(h.cookies.ClientValues[authboss.CookieRemember]) != 0 {
|
|
|
|
t.Error("there should be no remember cookie left")
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestGenerateToken(t *testing.T) {
|
|
|
|
t.Parallel()
|
|
|
|
|
|
|
|
hash, tok, err := GenerateToken("test")
|
2015-01-13 00:02:07 +02:00
|
|
|
if err != nil {
|
2018-03-08 01:13:06 +02:00
|
|
|
t.Fatal(err)
|
2015-01-13 00:02:07 +02:00
|
|
|
}
|
|
|
|
|
2018-03-08 01:13:06 +02:00
|
|
|
rawToken, err := base64.URLEncoding.DecodeString(tok)
|
|
|
|
if err != nil {
|
|
|
|
t.Error(err)
|
2015-01-13 00:02:07 +02:00
|
|
|
}
|
|
|
|
|
2018-03-08 01:13:06 +02:00
|
|
|
index := bytes.IndexByte(rawToken, ';')
|
|
|
|
if index < 0 {
|
|
|
|
t.Fatalf("problem with the token format: %v", rawToken)
|
2015-01-13 00:02:07 +02:00
|
|
|
}
|
|
|
|
|
2018-03-08 01:13:06 +02:00
|
|
|
bytPID := rawToken[:index]
|
|
|
|
if string(bytPID) != "test" {
|
|
|
|
t.Errorf("pid wrong: %s", bytPID)
|
2015-03-02 06:40:09 +02:00
|
|
|
}
|
|
|
|
|
2018-03-08 01:13:06 +02:00
|
|
|
sum := sha512.Sum512(rawToken)
|
|
|
|
gotHash := base64.StdEncoding.EncodeToString(sum[:])
|
|
|
|
if hash != gotHash {
|
|
|
|
t.Errorf("hash wrong, want: %s, got: %s", hash, gotHash)
|
2015-01-11 08:52:39 +02:00
|
|
|
}
|
|
|
|
}
|