1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-02-09 13:46:51 +02:00

Improvements to Session State code (#536)

* Drop SessionStateJSON wrapper
* Use EncrpytInto/DecryptInto to reduce sessionstate

Co-authored-by: Henry Jenkins <henry@henryjenkins.name>
This commit is contained in:
Joel Speed 2020-05-30 08:53:38 +01:00 committed by GitHub
parent 6a88da7f7a
commit f7b28cb1d3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 185 additions and 164 deletions

View File

@ -51,6 +51,7 @@
## Changes since v5.1.1
- [#536](https://github.com/oauth2-proxy/oauth2-proxy/pull/536) Improvements to Session State code (@JoelSpeed)
- [#573](https://github.com/oauth2-proxy/oauth2-proxy/pull/573) Properly parse redis urls for cluster and sentinel connections (@amnay-mo)
- [#574](https://github.com/oauth2-proxy/oauth2-proxy/pull/574) render error page on 502 proxy status (@amnay-mo)
- [#559](https://github.com/oauth2-proxy/oauth2-proxy/pull/559) Rename cookie-domain config to cookie-domains (@JoelSpeed)

View File

@ -484,8 +484,7 @@ func TestBasicAuthPassword(t *testing.T) {
})
rw := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/oauth2/callback?code=callback_code&state=nonce:",
strings.NewReader(""))
req, _ := http.NewRequest("GET", "/oauth2/callback?code=callback_code&state=nonce:", strings.NewReader(""))
req.AddCookie(proxy.MakeCSRFCookie(req, "nonce", proxy.CookieExpire, time.Now()))
proxy.ServeHTTP(rw, req)
if rw.Code >= 400 {
@ -541,11 +540,12 @@ func TestBasicAuthWithEmail(t *testing.T) {
expectedEmailHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte(emailAddress+":"+opts.BasicAuthPassword))
expectedUserHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte(userName+":"+opts.BasicAuthPassword))
created := time.Now()
session := &sessions.SessionState{
User: userName,
Email: emailAddress,
AccessToken: "oauth_token",
CreatedAt: time.Now(),
CreatedAt: &created,
}
{
rw := httptest.NewRecorder()
@ -582,11 +582,12 @@ func TestPassUserHeadersWithEmail(t *testing.T) {
const emailAddress = "john.doe@example.com"
const userName = "9fcab5c9b889a557"
created := time.Now()
session := &sessions.SessionState{
User: userName,
Email: emailAddress,
AccessToken: "oauth_token",
CreatedAt: time.Now(),
CreatedAt: &created,
}
{
rw := httptest.NewRecorder()
@ -959,7 +960,8 @@ func (p *ProcessCookieTest) LoadCookiedSession() (*sessions.SessionState, error)
func TestLoadCookiedSession(t *testing.T) {
pcTest := NewProcessCookieTestWithDefaults()
startSession := &sessions.SessionState{Email: "john.doe@example.com", AccessToken: "my_access_token", CreatedAt: time.Now()}
created := time.Now()
startSession := &sessions.SessionState{Email: "john.doe@example.com", AccessToken: "my_access_token", CreatedAt: &created}
pcTest.SaveSession(startSession)
session, err := pcTest.LoadCookiedSession()
@ -985,7 +987,7 @@ func TestProcessCookieRefreshNotSet(t *testing.T) {
})
reference := time.Now().Add(time.Duration(-2) * time.Hour)
startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: reference}
startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &reference}
pcTest.SaveSession(startSession)
session, err := pcTest.LoadCookiedSession()
@ -1001,7 +1003,7 @@ func TestProcessCookieFailIfCookieExpired(t *testing.T) {
opts.Cookie.Expire = time.Duration(24) * time.Hour
})
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: reference}
startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &reference}
pcTest.SaveSession(startSession)
session, err := pcTest.LoadCookiedSession()
@ -1016,7 +1018,7 @@ func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) {
opts.Cookie.Expire = time.Duration(24) * time.Hour
})
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: reference}
startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &reference}
pcTest.SaveSession(startSession)
pcTest.proxy.CookieRefresh = time.Hour
@ -1062,8 +1064,9 @@ func NewAuthOnlyEndpointTest(modifiers ...OptionsModifier) *ProcessCookieTest {
func TestAuthOnlyEndpointAccepted(t *testing.T) {
test := NewAuthOnlyEndpointTest()
created := time.Now()
startSession := &sessions.SessionState{
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: time.Now()}
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &created}
test.SaveSession(startSession)
test.proxy.ServeHTTP(test.rw, test.req)
@ -1087,7 +1090,7 @@ func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) {
})
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
startSession := &sessions.SessionState{
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: reference}
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &reference}
test.SaveSession(startSession)
test.proxy.ServeHTTP(test.rw, test.req)
@ -1098,8 +1101,9 @@ func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) {
func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) {
test := NewAuthOnlyEndpointTest()
created := time.Now()
startSession := &sessions.SessionState{
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: time.Now()}
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: &created}
test.SaveSession(startSession)
test.validateUser = false
@ -1129,8 +1133,9 @@ func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) {
pcTest.req, _ = http.NewRequest("GET",
pcTest.opts.ProxyPrefix+"/auth", nil)
created := time.Now()
startSession := &sessions.SessionState{
User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: time.Now()}
User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: &created}
pcTest.SaveSession(startSession)
pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req)
@ -1160,8 +1165,9 @@ func TestAuthOnlyEndpointSetBasicAuthTrueRequestHeaders(t *testing.T) {
pcTest.req, _ = http.NewRequest("GET",
pcTest.opts.ProxyPrefix+"/auth", nil)
created := time.Now()
startSession := &sessions.SessionState{
User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: time.Now()}
User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: &created}
pcTest.SaveSession(startSession)
pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req)
@ -1193,8 +1199,9 @@ func TestAuthOnlyEndpointSetBasicAuthFalseRequestHeaders(t *testing.T) {
pcTest.req, _ = http.NewRequest("GET",
pcTest.opts.ProxyPrefix+"/auth", nil)
created := time.Now()
startSession := &sessions.SessionState{
User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: time.Now()}
User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: &created}
pcTest.SaveSession(startSession)
pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req)
@ -1569,10 +1576,11 @@ func TestGetJwtSession(t *testing.T) {
}
// Bearer
expires := time.Unix(1912151821, 0)
session, _ := test.proxy.GetJwtSession(test.req)
assert.Equal(t, session.User, "john@example.com")
assert.Equal(t, session.Email, "john@example.com")
assert.Equal(t, session.ExpiresOn, time.Unix(1912151821, 0))
assert.Equal(t, session.ExpiresOn, &expires)
assert.Equal(t, session.IDToken, goodJwt)
test.proxy.ServeHTTP(test.rw, test.req)

View File

@ -2,7 +2,6 @@ package sessions
import (
"encoding/json"
"errors"
"fmt"
"time"
@ -11,26 +10,19 @@ import (
// SessionState is used to store information about the currently authenticated user session
type SessionState struct {
AccessToken string `json:",omitempty"`
IDToken string `json:",omitempty"`
CreatedAt time.Time `json:"-"`
ExpiresOn time.Time `json:"-"`
RefreshToken string `json:",omitempty"`
Email string `json:",omitempty"`
User string `json:",omitempty"`
PreferredUsername string `json:",omitempty"`
}
// SessionStateJSON is used to encode SessionState into JSON without exposing time.Time zero value
type SessionStateJSON struct {
*SessionState
CreatedAt *time.Time `json:",omitempty"`
ExpiresOn *time.Time `json:",omitempty"`
AccessToken string `json:",omitempty"`
IDToken string `json:",omitempty"`
CreatedAt *time.Time `json:",omitempty"`
ExpiresOn *time.Time `json:",omitempty"`
RefreshToken string `json:",omitempty"`
Email string `json:",omitempty"`
User string `json:",omitempty"`
PreferredUsername string `json:",omitempty"`
}
// IsExpired checks whether the session has expired
func (s *SessionState) IsExpired() bool {
if !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(time.Now()) {
if s.ExpiresOn != nil && !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(time.Now()) {
return true
}
return false
@ -38,8 +30,8 @@ func (s *SessionState) IsExpired() bool {
// Age returns the age of a session
func (s *SessionState) Age() time.Duration {
if !s.CreatedAt.IsZero() {
return time.Now().Truncate(time.Second).Sub(s.CreatedAt)
if s.CreatedAt != nil && !s.CreatedAt.IsZero() {
return time.Now().Truncate(time.Second).Sub(*s.CreatedAt)
}
return 0
}
@ -75,80 +67,36 @@ func (s *SessionState) EncodeSessionState(c *encryption.Cipher) (string, error)
ss.PreferredUsername = s.PreferredUsername
} else {
ss = *s
var err error
if ss.Email != "" {
ss.Email, err = c.Encrypt(ss.Email)
if err != nil {
return "", err
}
}
if ss.User != "" {
ss.User, err = c.Encrypt(ss.User)
if err != nil {
return "", err
}
}
if ss.PreferredUsername != "" {
ss.PreferredUsername, err = c.Encrypt(ss.PreferredUsername)
if err != nil {
return "", err
}
}
if ss.AccessToken != "" {
ss.AccessToken, err = c.Encrypt(ss.AccessToken)
if err != nil {
return "", err
}
}
if ss.IDToken != "" {
ss.IDToken, err = c.Encrypt(ss.IDToken)
if err != nil {
return "", err
}
}
if ss.RefreshToken != "" {
ss.RefreshToken, err = c.Encrypt(ss.RefreshToken)
for _, s := range []*string{
&ss.Email,
&ss.User,
&ss.PreferredUsername,
&ss.AccessToken,
&ss.IDToken,
&ss.RefreshToken,
} {
err := c.EncryptInto(s)
if err != nil {
return "", err
}
}
}
// Embed SessionState and ExpiresOn pointer into SessionStateJSON
ssj := &SessionStateJSON{SessionState: &ss}
if !ss.CreatedAt.IsZero() {
ssj.CreatedAt = &ss.CreatedAt
}
if !ss.ExpiresOn.IsZero() {
ssj.ExpiresOn = &ss.ExpiresOn
}
b, err := json.Marshal(ssj)
b, err := json.Marshal(ss)
return string(b), err
}
// DecodeSessionState decodes the session cookie string into a SessionState
func DecodeSessionState(v string, c *encryption.Cipher) (*SessionState, error) {
var ssj SessionStateJSON
var ss *SessionState
err := json.Unmarshal([]byte(v), &ssj)
var ss SessionState
err := json.Unmarshal([]byte(v), &ss)
if err != nil {
return nil, fmt.Errorf("error unmarshalling session: %w", err)
}
if ssj.SessionState == nil {
return nil, errors.New("expected session state to not be nil")
}
// Extract SessionState and CreatedAt,ExpiresOn value from SessionStateJSON
ss = ssj.SessionState
if ssj.CreatedAt != nil {
ss.CreatedAt = *ssj.CreatedAt
}
if ssj.ExpiresOn != nil {
ss.ExpiresOn = *ssj.ExpiresOn
}
if c == nil {
// Load only Email and User when cipher is unavailable
ss = &SessionState{
ss = SessionState{
Email: ss.Email,
User: ss.User,
PreferredUsername: ss.PreferredUsername,
@ -168,30 +116,18 @@ func DecodeSessionState(v string, c *encryption.Cipher) (*SessionState, error) {
ss.User = decryptedUser
}
}
if ss.PreferredUsername != "" {
ss.PreferredUsername, err = c.Decrypt(ss.PreferredUsername)
if err != nil {
return nil, err
}
}
if ss.AccessToken != "" {
ss.AccessToken, err = c.Decrypt(ss.AccessToken)
if err != nil {
return nil, err
}
}
if ss.IDToken != "" {
ss.IDToken, err = c.Decrypt(ss.IDToken)
if err != nil {
return nil, err
}
}
if ss.RefreshToken != "" {
ss.RefreshToken, err = c.Decrypt(ss.RefreshToken)
for _, s := range []*string{
&ss.PreferredUsername,
&ss.AccessToken,
&ss.IDToken,
&ss.RefreshToken,
} {
err := c.DecryptInto(s)
if err != nil {
return nil, err
}
}
}
return ss, nil
return &ss, nil
}

View File

@ -13,6 +13,10 @@ import (
const secret = "0123456789abcdefghijklmnopqrstuv"
const altSecret = "0000000000abcdefghijklmnopqrstuv"
func timePtr(t time.Time) *time.Time {
return &t
}
func TestSessionStateSerialization(t *testing.T) {
c, err := encryption.NewCipher([]byte(secret))
assert.Equal(t, nil, err)
@ -23,8 +27,8 @@ func TestSessionStateSerialization(t *testing.T) {
PreferredUsername: "user",
AccessToken: "token1234",
IDToken: "rawtoken1234",
CreatedAt: time.Now(),
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
CreatedAt: timePtr(time.Now()),
ExpiresOn: timePtr(time.Now().Add(time.Duration(1) * time.Hour)),
RefreshToken: "refresh4321",
}
encoded, err := s.EncodeSessionState(c)
@ -66,8 +70,8 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
PreferredUsername: "ju",
Email: "user@domain.com",
AccessToken: "token1234",
CreatedAt: time.Now(),
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
CreatedAt: timePtr(time.Now()),
ExpiresOn: timePtr(time.Now().Add(time.Duration(1) * time.Hour)),
RefreshToken: "refresh4321",
}
encoded, err := s.EncodeSessionState(c)
@ -102,8 +106,8 @@ func TestSessionStateSerializationNoCipher(t *testing.T) {
Email: "user@domain.com",
PreferredUsername: "user",
AccessToken: "token1234",
CreatedAt: time.Now(),
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
CreatedAt: timePtr(time.Now()),
ExpiresOn: timePtr(time.Now().Add(time.Duration(1) * time.Hour)),
RefreshToken: "refresh4321",
}
encoded, err := s.EncodeSessionState(nil)
@ -125,8 +129,8 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
Email: "user@domain.com",
PreferredUsername: "user",
AccessToken: "token1234",
CreatedAt: time.Now(),
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
CreatedAt: timePtr(time.Now()),
ExpiresOn: timePtr(time.Now().Add(time.Duration(1) * time.Hour)),
RefreshToken: "refresh4321",
}
encoded, err := s.EncodeSessionState(nil)
@ -143,10 +147,10 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
}
func TestExpired(t *testing.T) {
s := &sessions.SessionState{ExpiresOn: time.Now().Add(time.Duration(-1) * time.Minute)}
s := &sessions.SessionState{ExpiresOn: timePtr(time.Now().Add(time.Duration(-1) * time.Minute))}
assert.Equal(t, true, s.IsExpired())
s = &sessions.SessionState{ExpiresOn: time.Now().Add(time.Duration(1) * time.Minute)}
s = &sessions.SessionState{ExpiresOn: timePtr(time.Now().Add(time.Duration(1) * time.Minute))}
assert.Equal(t, false, s.IsExpired())
s = &sessions.SessionState{}
@ -182,8 +186,8 @@ func TestEncodeSessionState(t *testing.T) {
User: "just-user",
AccessToken: "token1234",
IDToken: "rawtoken1234",
CreatedAt: c,
ExpiresOn: e,
CreatedAt: &c,
ExpiresOn: &e,
RefreshToken: "refresh4321",
},
Encoded: `{"Email":"user@domain.com","User":"just-user"}`,
@ -249,8 +253,8 @@ func TestDecodeSessionState(t *testing.T) {
User: "just-user",
AccessToken: "token1234",
IDToken: "rawtoken1234",
CreatedAt: created,
ExpiresOn: e,
CreatedAt: &created,
ExpiresOn: &e,
RefreshToken: "refresh4321",
},
Encoded: fmt.Sprintf(`{"Email":"FsKKYrTWZWrxSOAqA/fTNAUZS5QWCqOBjuAbBlbVOw==","User":"rT6JP3dxQhxUhkWrrd7yt6c1mDVyQCVVxw==","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","CreatedAt":%s,"ExpiresOn":%s}`, createdString, eString),
@ -291,7 +295,10 @@ func TestDecodeSessionState(t *testing.T) {
assert.Equal(t, tc.AccessToken, ss.AccessToken)
assert.Equal(t, tc.RefreshToken, ss.RefreshToken)
assert.Equal(t, tc.IDToken, ss.IDToken)
assert.Equal(t, tc.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
if tc.ExpiresOn != nil {
assert.NotEqual(t, nil, ss.ExpiresOn)
assert.Equal(t, tc.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
}
}
}
}
@ -303,6 +310,6 @@ func TestSessionStateAge(t *testing.T) {
assert.Equal(t, time.Duration(0), ss.Age())
// Set CreatedAt to 1 hour ago
ss.CreatedAt = time.Now().Add(-1 * time.Hour)
ss.CreatedAt = timePtr(time.Now().Add(-1 * time.Hour))
assert.Equal(t, time.Hour, ss.Age().Round(time.Minute))
}

View File

@ -156,3 +156,30 @@ func (c *Cipher) Decrypt(s string) (string, error) {
return string(encrypted), nil
}
// EncryptInto encrypts the value and stores it back in the string pointer
func (c *Cipher) EncryptInto(s *string) error {
return into(c.Encrypt, s)
}
// DecryptInto decrypts the value and stores it back in the string pointer
func (c *Cipher) DecryptInto(s *string) error {
return into(c.Decrypt, s)
}
// codecFunc is a function that takes a string and encodes/decodes it
type codecFunc func(string) (string, error)
func into(f codecFunc, s *string) error {
// Do not encrypt/decrypt nil or empty strings
if s == nil || *s == "" {
return nil
}
d, err := f(*s)
if err != nil {
return err
}
*s = d
return nil
}

View File

@ -133,3 +133,25 @@ func TestEncodeAndDecodeAccessTokenB64(t *testing.T) {
assert.NotEqual(t, token, encoded)
assert.Equal(t, token, decoded)
}
func TestEncodeIntoAndDecodeIntoAccessToken(t *testing.T) {
const secret = "0123456789abcdefghijklmnopqrstuv"
c, err := NewCipher([]byte(secret))
assert.Equal(t, nil, err)
token := "my access token"
originalToken := token
assert.Equal(t, nil, c.EncryptInto(&token))
assert.NotEqual(t, originalToken, token)
assert.Equal(t, nil, c.DecryptInto(&token))
assert.Equal(t, originalToken, token)
// Check no errors with empty or nil strings
empty := ""
assert.Equal(t, nil, c.EncryptInto(&empty))
assert.Equal(t, nil, c.DecryptInto(&empty))
assert.Equal(t, nil, c.EncryptInto(nil))
assert.Equal(t, nil, c.DecryptInto(nil))
}

View File

@ -34,14 +34,15 @@ type SessionStore struct {
// Save takes a sessions.SessionState and stores the information from it
// within Cookies set on the HTTP response writer
func (s *SessionStore) Save(rw http.ResponseWriter, req *http.Request, ss *sessions.SessionState) error {
if ss.CreatedAt.IsZero() {
ss.CreatedAt = time.Now()
if ss.CreatedAt == nil || ss.CreatedAt.IsZero() {
now := time.Now()
ss.CreatedAt = &now
}
value, err := cookieForSession(ss, s.CookieCipher)
if err != nil {
return err
}
s.setSessionCookie(rw, req, value, ss.CreatedAt)
s.setSessionCookie(rw, req, value, *ss.CreatedAt)
return nil
}

View File

@ -133,8 +133,9 @@ func parseRedisURLs(urls []string) ([]string, error) {
// Save takes a sessions.SessionState and stores the information from it
// to redies, and adds a new ticket cookie on the HTTP response writer
func (store *SessionStore) Save(rw http.ResponseWriter, req *http.Request, s *sessions.SessionState) error {
if s.CreatedAt.IsZero() {
s.CreatedAt = time.Now()
if s.CreatedAt == nil || s.CreatedAt.IsZero() {
now := time.Now()
s.CreatedAt = &now
}
// Old sessions that we are refreshing would have a request cookie
@ -154,7 +155,7 @@ func (store *SessionStore) Save(rw http.ResponseWriter, req *http.Request, s *se
req,
ticketString,
store.CookieOptions.Expire,
s.CreatedAt,
*s.CreatedAt,
)
http.SetCookie(rw, ticketCookie)

View File

@ -15,6 +15,7 @@ import (
sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
cookiesapi "github.com/oauth2-proxy/oauth2-proxy/pkg/cookies"
"github.com/oauth2-proxy/oauth2-proxy/pkg/encryption"
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
"github.com/oauth2-proxy/oauth2-proxy/pkg/sessions"
sessionscookie "github.com/oauth2-proxy/oauth2-proxy/pkg/sessions/cookie"
"github.com/oauth2-proxy/oauth2-proxy/pkg/sessions/redis"
@ -23,6 +24,8 @@ import (
)
func TestSessionStore(t *testing.T) {
logger.SetOutput(GinkgoWriter)
RegisterFailHandler(Fail)
RunSpecs(t, "SessionStore")
}
@ -253,16 +256,16 @@ var _ = Describe("NewSessionStore", func() {
// Can't compare time.Time using Equal() so remove ExpiresOn from sessions
l := *loadedSession
l.CreatedAt = time.Time{}
l.ExpiresOn = time.Time{}
l.CreatedAt = nil
l.ExpiresOn = nil
s := *session
s.CreatedAt = time.Time{}
s.ExpiresOn = time.Time{}
s.CreatedAt = nil
s.ExpiresOn = nil
Expect(l).To(Equal(s))
// Compare time.Time separately
Expect(loadedSession.CreatedAt.Equal(session.CreatedAt)).To(BeTrue())
Expect(loadedSession.ExpiresOn.Equal(session.ExpiresOn)).To(BeTrue())
Expect(loadedSession.CreatedAt.Equal(*session.CreatedAt)).To(BeTrue())
Expect(loadedSession.ExpiresOn.Equal(*session.ExpiresOn)).To(BeTrue())
}
})
}
@ -392,10 +395,11 @@ var _ = Describe("NewSessionStore", func() {
SameSite: "",
}
expires := time.Now().Add(1 * time.Hour)
session = &sessionsapi.SessionState{
AccessToken: "AccessToken",
IDToken: "IDToken",
ExpiresOn: time.Now().Add(1 * time.Hour),
ExpiresOn: &expires,
RefreshToken: "RefreshToken",
Email: "john.doe@example.com",
User: "john.doe",

View File

@ -126,11 +126,13 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s
return
}
created := time.Now()
expires := time.Unix(jsonResponse.ExpiresOn, 0)
s = &sessions.SessionState{
AccessToken: jsonResponse.AccessToken,
IDToken: jsonResponse.IDToken,
CreatedAt: time.Now(),
ExpiresOn: time.Unix(jsonResponse.ExpiresOn, 0),
CreatedAt: &created,
ExpiresOn: &expires,
RefreshToken: jsonResponse.RefreshToken,
}
return

View File

@ -67,7 +67,7 @@ func (p *GitLabProvider) Redeem(ctx context.Context, redirectURL, code string) (
// RefreshSessionIfNeeded checks if the session has expired and uses the
// RefreshToken to fetch a new ID token if required
func (p *GitLabProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" {
return false, nil
}
@ -209,12 +209,13 @@ func (p *GitLabProvider) createSessionState(ctx context.Context, token *oauth2.T
return nil, fmt.Errorf("could not verify id_token: %v", err)
}
created := time.Now()
return &sessions.SessionState{
AccessToken: token.AccessToken,
IDToken: rawIDToken,
RefreshToken: token.RefreshToken,
CreatedAt: time.Now(),
ExpiresOn: idToken.Expiry,
CreatedAt: &created,
ExpiresOn: &idToken.Expiry,
}, nil
}

View File

@ -153,11 +153,14 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) (
if err != nil {
return
}
created := time.Now()
expires := time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second)
s = &sessions.SessionState{
AccessToken: jsonResponse.AccessToken,
IDToken: jsonResponse.IDToken,
CreatedAt: time.Now(),
ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second),
CreatedAt: &created,
ExpiresOn: &expires,
RefreshToken: jsonResponse.RefreshToken,
Email: c.Email,
User: c.Subject,
@ -245,7 +248,7 @@ func (p *GoogleProvider) ValidateGroup(email string) bool {
// RefreshSessionIfNeeded checks if the session has expired and uses the
// RefreshToken to fetch a new ID token if required
func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" {
return false, nil
}
@ -260,9 +263,10 @@ func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions
}
origExpiration := s.ExpiresOn
expires := time.Now().Add(duration).Truncate(time.Second)
s.AccessToken = newToken
s.IDToken = newIDToken
s.ExpiresOn = time.Now().Add(duration).Truncate(time.Second)
s.ExpiresOn = &expires
logger.Printf("refreshed access token %s (expired on %s)", s, origExpiration)
return true, nil
}

View File

@ -250,12 +250,15 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string)
return
}
created := time.Now()
expires := time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second)
// Store the data that we found in the session state
s = &sessions.SessionState{
AccessToken: jsonResponse.AccessToken,
IDToken: jsonResponse.IDToken,
CreatedAt: time.Now(),
ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second),
CreatedAt: &created,
ExpiresOn: &expires,
Email: email,
}
return

View File

@ -72,7 +72,7 @@ func (p *OIDCProvider) Redeem(ctx context.Context, redirectURL, code string) (s
// RefreshSessionIfNeeded checks if the session has expired and uses the
// RefreshToken to fetch a new Access Token (and optional ID token) if required
func (p *OIDCProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) {
if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" {
return false, nil
}
@ -163,10 +163,11 @@ func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Tok
}
}
created := time.Now()
newSession.AccessToken = token.AccessToken
newSession.RefreshToken = token.RefreshToken
newSession.CreatedAt = time.Now()
newSession.ExpiresOn = token.Expiry
newSession.CreatedAt = &created
newSession.ExpiresOn = &token.Expiry
return newSession, nil
}
@ -179,7 +180,7 @@ func (p *OIDCProvider) CreateSessionStateFromBearerToken(ctx context.Context, ra
newSession.AccessToken = rawIDToken
newSession.IDToken = rawIDToken
newSession.RefreshToken = ""
newSession.ExpiresOn = idToken.Expiry
newSession.ExpiresOn = &idToken.Expiry
return newSession, nil
}

View File

@ -204,8 +204,8 @@ func TestOIDCProviderRefreshSessionIfNeededWithoutIdToken(t *testing.T) {
existingSession := &sessions.SessionState{
AccessToken: "changeit",
IDToken: idToken,
CreatedAt: time.Time{},
ExpiresOn: time.Time{},
CreatedAt: nil,
ExpiresOn: nil,
RefreshToken: refreshToken,
Email: "janedoe@example.com",
User: "11223344",
@ -238,8 +238,8 @@ func TestOIDCProviderRefreshSessionIfNeededWithIdToken(t *testing.T) {
existingSession := &sessions.SessionState{
AccessToken: "changeit",
IDToken: "changeit",
CreatedAt: time.Time{},
ExpiresOn: time.Time{},
CreatedAt: nil,
ExpiresOn: nil,
RefreshToken: refreshToken,
Email: "changeit",
User: "changeit",

View File

@ -81,7 +81,8 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (s
return
}
if a := v.Get("access_token"); a != "" {
s = &sessions.SessionState{AccessToken: a, CreatedAt: time.Now()}
created := time.Now()
s = &sessions.SessionState{AccessToken: a, CreatedAt: &created}
} else {
err = fmt.Errorf("no access token found %s", body)
}
@ -168,7 +169,7 @@ func (p *ProviderData) CreateSessionStateFromBearerToken(ctx context.Context, ra
newSession.AccessToken = rawIDToken
newSession.IDToken = rawIDToken
newSession.RefreshToken = ""
newSession.ExpiresOn = idToken.Expiry
newSession.ExpiresOn = &idToken.Expiry
return newSession, nil
}

View File

@ -11,8 +11,10 @@ import (
func TestRefresh(t *testing.T) {
p := &ProviderData{}
expires := time.Now().Add(time.Duration(-11) * time.Minute)
refreshed, err := p.RefreshSessionIfNeeded(context.Background(), &sessions.SessionState{
ExpiresOn: time.Now().Add(time.Duration(-11) * time.Minute),
ExpiresOn: &expires,
})
assert.Equal(t, false, refreshed)
assert.Equal(t, nil, err)