1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-08-10 22:51:31 +02:00

Move SessionState to its own package

This commit is contained in:
Joel Speed
2019-05-05 13:33:13 +01:00
parent a1130e41a3
commit 2ab8a7d95d
20 changed files with 127 additions and 109 deletions

View File

@@ -0,0 +1,224 @@
package sessions
import (
"encoding/json"
"fmt"
"strconv"
"strings"
"time"
"github.com/pusher/oauth2_proxy/cookie"
)
// SessionState is used to store information about the currently authenticated user session
type SessionState struct {
AccessToken string `json:",omitempty"`
IDToken string `json:",omitempty"`
ExpiresOn time.Time `json:"-"`
RefreshToken string `json:",omitempty"`
Email string `json:",omitempty"`
User string `json:",omitempty"`
}
// SessionStateJSON is used to encode SessionState into JSON without exposing time.Time zero value
type SessionStateJSON struct {
*SessionState
ExpiresOn *time.Time `json:",omitempty"`
}
// IsExpired checks whether the session has expired
func (s *SessionState) IsExpired() bool {
if !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(time.Now()) {
return true
}
return false
}
// String constructs a summary of the session state
func (s *SessionState) String() string {
o := fmt.Sprintf("Session{email:%s user:%s", s.Email, s.User)
if s.AccessToken != "" {
o += " token:true"
}
if s.IDToken != "" {
o += " id_token:true"
}
if !s.ExpiresOn.IsZero() {
o += fmt.Sprintf(" expires:%s", s.ExpiresOn)
}
if s.RefreshToken != "" {
o += " refresh_token:true"
}
return o + "}"
}
// EncodeSessionState returns string representation of the current session
func (s *SessionState) EncodeSessionState(c *cookie.Cipher) (string, error) {
var ss SessionState
if c == nil {
// Store only Email and User when cipher is unavailable
ss.Email = s.Email
ss.User = s.User
} else {
ss = *s
var err error
if ss.Email != "" {
ss.Email, err = c.Encrypt(ss.Email)
if err != nil {
return "", err
}
}
if ss.User != "" {
ss.User, err = c.Encrypt(ss.User)
if err != nil {
return "", err
}
}
if ss.AccessToken != "" {
ss.AccessToken, err = c.Encrypt(ss.AccessToken)
if err != nil {
return "", err
}
}
if ss.IDToken != "" {
ss.IDToken, err = c.Encrypt(ss.IDToken)
if err != nil {
return "", err
}
}
if ss.RefreshToken != "" {
ss.RefreshToken, err = c.Encrypt(ss.RefreshToken)
if err != nil {
return "", err
}
}
}
// Embed SessionState and ExpiresOn pointer into SessionStateJSON
ssj := &SessionStateJSON{SessionState: &ss}
if !ss.ExpiresOn.IsZero() {
ssj.ExpiresOn = &ss.ExpiresOn
}
b, err := json.Marshal(ssj)
return string(b), err
}
// legacyDecodeSessionStatePlain decodes older plain session state string
func legacyDecodeSessionStatePlain(v string) (*SessionState, error) {
chunks := strings.Split(v, " ")
if len(chunks) != 2 {
return nil, fmt.Errorf("invalid session state (legacy: expected 2 chunks for user/email got %d)", len(chunks))
}
user := strings.TrimPrefix(chunks[1], "user:")
email := strings.TrimPrefix(chunks[0], "email:")
return &SessionState{User: user, Email: email}, nil
}
// legacyDecodeSessionState attempts to decode the session state string
// generated by v3.1.0 or older
func legacyDecodeSessionState(v string, c *cookie.Cipher) (*SessionState, error) {
chunks := strings.Split(v, "|")
if c == nil {
if len(chunks) != 1 {
return nil, fmt.Errorf("invalid session state (legacy: expected 1 chunk for plain got %d)", len(chunks))
}
return legacyDecodeSessionStatePlain(chunks[0])
}
if len(chunks) != 4 && len(chunks) != 5 {
return nil, fmt.Errorf("invalid session state (legacy: expected 4 or 5 chunks for full got %d)", len(chunks))
}
i := 0
ss, err := legacyDecodeSessionStatePlain(chunks[i])
if err != nil {
return nil, err
}
i++
ss.AccessToken = chunks[i]
if len(chunks) == 5 {
// SessionState with IDToken in v3.1.0
i++
ss.IDToken = chunks[i]
}
i++
ts, err := strconv.Atoi(chunks[i])
if err != nil {
return nil, fmt.Errorf("invalid session state (legacy: wrong expiration time: %s)", err)
}
ss.ExpiresOn = time.Unix(int64(ts), 0)
i++
ss.RefreshToken = chunks[i]
return ss, nil
}
// DecodeSessionState decodes the session cookie string into a SessionState
func DecodeSessionState(v string, c *cookie.Cipher) (*SessionState, error) {
var ssj SessionStateJSON
var ss *SessionState
err := json.Unmarshal([]byte(v), &ssj)
if err == nil && ssj.SessionState != nil {
// Extract SessionState and ExpiresOn value from SessionStateJSON
ss = ssj.SessionState
if ssj.ExpiresOn != nil {
ss.ExpiresOn = *ssj.ExpiresOn
}
} else {
// Try to decode a legacy string when json.Unmarshal failed
ss, err = legacyDecodeSessionState(v, c)
if err != nil {
return nil, err
}
}
if c == nil {
// Load only Email and User when cipher is unavailable
ss = &SessionState{
Email: ss.Email,
User: ss.User,
}
} else {
// Backward compatibility with using unecrypted Email
if ss.Email != "" {
decryptedEmail, errEmail := c.Decrypt(ss.Email)
if errEmail == nil {
ss.Email = decryptedEmail
}
}
// Backward compatibility with using unecrypted User
if ss.User != "" {
decryptedUser, errUser := c.Decrypt(ss.User)
if errUser == nil {
ss.User = decryptedUser
}
}
if ss.AccessToken != "" {
ss.AccessToken, err = c.Decrypt(ss.AccessToken)
if err != nil {
return nil, err
}
}
if ss.IDToken != "" {
ss.IDToken, err = c.Decrypt(ss.IDToken)
if err != nil {
return nil, err
}
}
if ss.RefreshToken != "" {
ss.RefreshToken, err = c.Decrypt(ss.RefreshToken)
if err != nil {
return nil, err
}
}
}
if ss.User == "" {
ss.User = ss.Email
}
return ss, nil
}

View File

@@ -0,0 +1,318 @@
package sessions_test
import (
"fmt"
"testing"
"time"
"github.com/pusher/oauth2_proxy/cookie"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
"github.com/stretchr/testify/assert"
)
const secret = "0123456789abcdefghijklmnopqrstuv"
const altSecret = "0000000000abcdefghijklmnopqrstuv"
func TestSessionStateSerialization(t *testing.T) {
c, err := cookie.NewCipher([]byte(secret))
assert.Equal(t, nil, err)
c2, err := cookie.NewCipher([]byte(altSecret))
assert.Equal(t, nil, err)
s := &sessions.SessionState{
Email: "user@domain.com",
AccessToken: "token1234",
IDToken: "rawtoken1234",
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
RefreshToken: "refresh4321",
}
encoded, err := s.EncodeSessionState(c)
assert.Equal(t, nil, err)
ss, err := sessions.DecodeSessionState(encoded, c)
t.Logf("%#v", ss)
assert.Equal(t, nil, err)
assert.Equal(t, "user@domain.com", ss.User)
assert.Equal(t, s.Email, ss.Email)
assert.Equal(t, s.AccessToken, ss.AccessToken)
assert.Equal(t, s.IDToken, ss.IDToken)
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
assert.Equal(t, s.RefreshToken, ss.RefreshToken)
// ensure a different cipher can't decode properly (ie: it gets gibberish)
ss, err = sessions.DecodeSessionState(encoded, c2)
t.Logf("%#v", ss)
assert.Equal(t, nil, err)
assert.NotEqual(t, "user@domain.com", ss.User)
assert.NotEqual(t, s.Email, ss.Email)
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
assert.NotEqual(t, s.AccessToken, ss.AccessToken)
assert.NotEqual(t, s.IDToken, ss.IDToken)
assert.NotEqual(t, s.RefreshToken, ss.RefreshToken)
}
func TestSessionStateSerializationWithUser(t *testing.T) {
c, err := cookie.NewCipher([]byte(secret))
assert.Equal(t, nil, err)
c2, err := cookie.NewCipher([]byte(altSecret))
assert.Equal(t, nil, err)
s := &sessions.SessionState{
User: "just-user",
Email: "user@domain.com",
AccessToken: "token1234",
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
RefreshToken: "refresh4321",
}
encoded, err := s.EncodeSessionState(c)
assert.Equal(t, nil, err)
ss, err := sessions.DecodeSessionState(encoded, c)
t.Logf("%#v", ss)
assert.Equal(t, nil, err)
assert.Equal(t, s.User, ss.User)
assert.Equal(t, s.Email, ss.Email)
assert.Equal(t, s.AccessToken, ss.AccessToken)
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
assert.Equal(t, s.RefreshToken, ss.RefreshToken)
// ensure a different cipher can't decode properly (ie: it gets gibberish)
ss, err = sessions.DecodeSessionState(encoded, c2)
t.Logf("%#v", ss)
assert.Equal(t, nil, err)
assert.NotEqual(t, s.User, ss.User)
assert.NotEqual(t, s.Email, ss.Email)
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
assert.NotEqual(t, s.AccessToken, ss.AccessToken)
assert.NotEqual(t, s.RefreshToken, ss.RefreshToken)
}
func TestSessionStateSerializationNoCipher(t *testing.T) {
s := &sessions.SessionState{
Email: "user@domain.com",
AccessToken: "token1234",
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
RefreshToken: "refresh4321",
}
encoded, err := s.EncodeSessionState(nil)
assert.Equal(t, nil, err)
// only email should have been serialized
ss, err := sessions.DecodeSessionState(encoded, nil)
assert.Equal(t, nil, err)
assert.Equal(t, "user@domain.com", ss.User)
assert.Equal(t, s.Email, ss.Email)
assert.Equal(t, "", ss.AccessToken)
assert.Equal(t, "", ss.RefreshToken)
}
func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
s := &sessions.SessionState{
User: "just-user",
Email: "user@domain.com",
AccessToken: "token1234",
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
RefreshToken: "refresh4321",
}
encoded, err := s.EncodeSessionState(nil)
assert.Equal(t, nil, err)
// only email should have been serialized
ss, err := sessions.DecodeSessionState(encoded, nil)
assert.Equal(t, nil, err)
assert.Equal(t, s.User, ss.User)
assert.Equal(t, s.Email, ss.Email)
assert.Equal(t, "", ss.AccessToken)
assert.Equal(t, "", ss.RefreshToken)
}
func TestExpired(t *testing.T) {
s := &sessions.SessionState{ExpiresOn: time.Now().Add(time.Duration(-1) * time.Minute)}
assert.Equal(t, true, s.IsExpired())
s = &sessions.SessionState{ExpiresOn: time.Now().Add(time.Duration(1) * time.Minute)}
assert.Equal(t, false, s.IsExpired())
s = &sessions.SessionState{}
assert.Equal(t, false, s.IsExpired())
}
type testCase struct {
sessions.SessionState
Encoded string
Cipher *cookie.Cipher
Error bool
}
// TestEncodeSessionState tests EncodeSessionState with the test vector
//
// Currently only tests without cipher here because we have no way to mock
// the random generator used in EncodeSessionState.
func TestEncodeSessionState(t *testing.T) {
e := time.Now().Add(time.Duration(1) * time.Hour)
testCases := []testCase{
{
SessionState: sessions.SessionState{
Email: "user@domain.com",
User: "just-user",
},
Encoded: `{"Email":"user@domain.com","User":"just-user"}`,
},
{
SessionState: sessions.SessionState{
Email: "user@domain.com",
User: "just-user",
AccessToken: "token1234",
IDToken: "rawtoken1234",
ExpiresOn: e,
RefreshToken: "refresh4321",
},
Encoded: `{"Email":"user@domain.com","User":"just-user"}`,
},
}
for i, tc := range testCases {
encoded, err := tc.EncodeSessionState(tc.Cipher)
t.Logf("i:%d Encoded:%#vsessions.SessionState:%#v Error:%#v", i, encoded, tc.SessionState, err)
if tc.Error {
assert.Error(t, err)
assert.Empty(t, encoded)
continue
}
assert.NoError(t, err)
assert.JSONEq(t, tc.Encoded, encoded)
}
}
// TestDecodeSessionState testssessions.DecodeSessionState with the test vector
func TestDecodeSessionState(t *testing.T) {
e := time.Now().Add(time.Duration(1) * time.Hour)
eJSON, _ := e.MarshalJSON()
eString := string(eJSON)
eUnix := e.Unix()
c, err := cookie.NewCipher([]byte(secret))
assert.NoError(t, err)
testCases := []testCase{
{
SessionState: sessions.SessionState{
Email: "user@domain.com",
User: "just-user",
},
Encoded: `{"Email":"user@domain.com","User":"just-user"}`,
},
{
SessionState: sessions.SessionState{
Email: "user@domain.com",
User: "user@domain.com",
},
Encoded: `{"Email":"user@domain.com"}`,
},
{
SessionState: sessions.SessionState{
User: "just-user",
},
Encoded: `{"User":"just-user"}`,
},
{
SessionState: sessions.SessionState{
Email: "user@domain.com",
User: "just-user",
},
Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","ExpiresOn":%s}`, eString),
},
{
SessionState: sessions.SessionState{
Email: "user@domain.com",
User: "just-user",
AccessToken: "token1234",
IDToken: "rawtoken1234",
ExpiresOn: e,
RefreshToken: "refresh4321",
},
Encoded: fmt.Sprintf(`{"Email":"FsKKYrTWZWrxSOAqA/fTNAUZS5QWCqOBjuAbBlbVOw==","User":"rT6JP3dxQhxUhkWrrd7yt6c1mDVyQCVVxw==","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","ExpiresOn":%s}`, eString),
Cipher: c,
},
{
SessionState: sessions.SessionState{
Email: "user@domain.com",
User: "just-user",
},
Encoded: `{"Email":"EGTllJcOFC16b7LBYzLekaHAC5SMMSPdyUrg8hd25g==","User":"rT6JP3dxQhxUhkWrrd7yt6c1mDVyQCVVxw=="}`,
Cipher: c,
},
{
Encoded: `{"Email":"user@domain.com","User":"just-user","AccessToken":"X"}`,
Cipher: c,
Error: true,
},
{
Encoded: `{"Email":"user@domain.com","User":"just-user","IDToken":"XXXX"}`,
Cipher: c,
Error: true,
},
{
SessionState: sessions.SessionState{
User: "just-user",
Email: "user@domain.com",
},
Encoded: "email:user@domain.com user:just-user",
},
{
Encoded: "email:user@domain.com user:just-user||||",
Error: true,
},
{
Encoded: "email:user@domain.com user:just-user",
Cipher: c,
Error: true,
},
{
Encoded: "email:user@domain.com user:just-user|||99999999999999999999|",
Cipher: c,
Error: true,
},
{
SessionState: sessions.SessionState{
Email: "user@domain.com",
User: "just-user",
AccessToken: "token1234",
ExpiresOn: e,
RefreshToken: "refresh4321",
},
Encoded: fmt.Sprintf("email:user@domain.com user:just-user|I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==|%d|qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K", eUnix),
Cipher: c,
},
{
SessionState: sessions.SessionState{
Email: "user@domain.com",
User: "just-user",
AccessToken: "token1234",
IDToken: "rawtoken1234",
ExpiresOn: e,
RefreshToken: "refresh4321",
},
Encoded: fmt.Sprintf("email:user@domain.com user:just-user|I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==|xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==|%d|qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K", eUnix),
Cipher: c,
},
}
for i, tc := range testCases {
ss, err := sessions.DecodeSessionState(tc.Encoded, tc.Cipher)
t.Logf("i:%d Encoded:%#vsessions.SessionState:%#v Error:%#v", i, tc.Encoded, ss, err)
if tc.Error {
assert.Error(t, err)
assert.Nil(t, ss)
continue
}
assert.NoError(t, err)
if assert.NotNil(t, ss) {
assert.Equal(t, tc.User, ss.User)
assert.Equal(t, tc.Email, ss.Email)
assert.Equal(t, tc.AccessToken, ss.AccessToken)
assert.Equal(t, tc.RefreshToken, ss.RefreshToken)
assert.Equal(t, tc.IDToken, ss.IDToken)
assert.Equal(t, tc.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
}
}
}