diff --git a/remember/remember.go b/remember/remember.go index 052923e..690dd7d 100644 --- a/remember/remember.go +++ b/remember/remember.go @@ -3,6 +3,7 @@ package remember import ( "bytes" + "context" "crypto/rand" "crypto/sha512" "encoding/base64" @@ -72,7 +73,7 @@ func Middleware(ab *authboss.Authboss) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Context().Value(authboss.CTXKeyPID) == nil && r.Context().Value(authboss.CTXKeyUser) == nil { - if err := Authenticate(ab, w, r); err != nil { + if err := Authenticate(ab, w, &r); err != nil { logger := ab.RequestLogger(r) logger.Errorf("failed to authenticate user via remember me: %+v", err) } @@ -89,9 +90,12 @@ func Middleware(ab *authboss.Authboss) func(http.Handler) http.Handler { // - Can't decode the base64 // - Invalid token format // - Can't find token in DB -func Authenticate(ab *authboss.Authboss, w http.ResponseWriter, req *http.Request) error { - logger := ab.RequestLogger(req) - cookie, ok := authboss.GetCookie(req, authboss.CookieRemember) +// +// In order to authenticate it adds to the request context as well as to the +// cookie and session states. +func Authenticate(ab *authboss.Authboss, w http.ResponseWriter, req **http.Request) error { + logger := ab.RequestLogger(*req) + cookie, ok := authboss.GetCookie(*req, authboss.CookieRemember) if !ok { return nil } @@ -131,9 +135,10 @@ func Authenticate(ab *authboss.Authboss, w http.ResponseWriter, req *http.Reques } if err = storer.AddRememberToken(pid, hash); err != nil { - return errors.Wrap(err, "failed to save me token") + return errors.Wrap(err, "failed to save remember me token") } + *req = (*req).WithContext(context.WithValue((*req).Context(), authboss.CTXKeyPID, pid)) authboss.PutSession(w, authboss.SessionKey, pid) authboss.PutSession(w, authboss.SessionHalfAuthKey, "true") authboss.DelCookie(w, authboss.CookieRemember) diff --git a/remember/remember_test.go b/remember/remember_test.go index 0a79bc2..24eccae 100644 --- a/remember/remember_test.go +++ b/remember/remember_test.go @@ -176,7 +176,7 @@ func TestAuthenticateSuccess(t *testing.T) { t.Fatal(err) } - if err = Authenticate(h.ab, w, r); err != nil { + if err = Authenticate(h.ab, w, &r); err != nil { t.Fatal(err) } @@ -198,6 +198,10 @@ func TestAuthenticateSuccess(t *testing.T) { if h.session.ClientValues[authboss.SessionHalfAuthKey] != "true" { t.Error("it should have become a half-authed session") } + + if r.Context().Value(authboss.CTXKeyPID).(string) != "test@test.com" { + t.Error("should have set the context value to log the user in") + } } func TestAuthenticateTokenNotFound(t *testing.T) { @@ -221,7 +225,7 @@ func TestAuthenticateTokenNotFound(t *testing.T) { t.Fatal(err) } - if err = Authenticate(h.ab, w, r); err != nil { + if err = Authenticate(h.ab, w, &r); err != nil { t.Fatal(err) } @@ -234,6 +238,10 @@ func TestAuthenticateTokenNotFound(t *testing.T) { if len(h.session.ClientValues[authboss.SessionKey]) != 0 { t.Error("it should have not logged the user in") } + + if r.Context().Value(authboss.CTXKeyPID) != nil { + t.Error("the context's pid should be empty") + } } func TestAuthenticateBadTokens(t *testing.T) { @@ -254,7 +262,7 @@ func TestAuthenticateBadTokens(t *testing.T) { t.Fatal(err) } - if err = Authenticate(h.ab, w, r); err != nil { + if err = Authenticate(h.ab, w, &r); err != nil { t.Fatal(err) } @@ -267,6 +275,10 @@ func TestAuthenticateBadTokens(t *testing.T) { if len(h.session.ClientValues[authboss.SessionKey]) != 0 { t.Error("it should have not logged the user in") } + + if r.Context().Value(authboss.CTXKeyPID) != nil { + t.Error("the context's pid should be empty") + } } t.Run("base64", func(t *testing.T) {