mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2024-11-28 09:08:44 +02:00
Centralize Ticket management of persistent stores (#682)
* Centralize Ticket management of persistent stores persistence package with Manager & Ticket will handle all the details about keys, secrets, ticket into cookies, etc. Persistent stores just need to pass Save, Load & Clear function handles to the persistent manager now. * Shift to persistence.Manager wrapping a persistence.Store * Break up the Redis client builder logic * Move error messages to Store from Manager * Convert ticket to private for Manager use only * Add persistence Manager & ticket tests * Make a custom MockStore that handles time FastForwards
This commit is contained in:
parent
f141f7cea0
commit
9643a0b10c
@ -11,6 +11,7 @@
|
|||||||
|
|
||||||
## Changes since v6.0.0
|
## Changes since v6.0.0
|
||||||
|
|
||||||
|
- [#682](https://github.com/oauth2-proxy/oauth2-proxy/pull/682) Refactor persistent session store session ticket management (@NickMeves)
|
||||||
- [#688](https://github.com/oauth2-proxy/oauth2-proxy/pull/688) Refactor session loading to make use of middleware pattern (@JoelSpeed)
|
- [#688](https://github.com/oauth2-proxy/oauth2-proxy/pull/688) Refactor session loading to make use of middleware pattern (@JoelSpeed)
|
||||||
- [#593](https://github.com/oauth2-proxy/oauth2-proxy/pull/593) Integrate upstream package with OAuth2 Proxy (@JoelSpeed)
|
- [#593](https://github.com/oauth2-proxy/oauth2-proxy/pull/593) Integrate upstream package with OAuth2 Proxy (@JoelSpeed)
|
||||||
- [#687](https://github.com/oauth2-proxy/oauth2-proxy/pull/687) Refactor HTPasswd Validator (@JoelSpeed)
|
- [#687](https://github.com/oauth2-proxy/oauth2-proxy/pull/687) Refactor HTPasswd Validator (@JoelSpeed)
|
||||||
|
@ -9,6 +9,8 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const LegacyV5TestSecret = "0123456789abcdefghijklmnopqrstuv"
|
||||||
|
|
||||||
// LegacyV5TestCase provides V5 JSON based test cases for legacy fallback code
|
// LegacyV5TestCase provides V5 JSON based test cases for legacy fallback code
|
||||||
type LegacyV5TestCase struct {
|
type LegacyV5TestCase struct {
|
||||||
Input string
|
Input string
|
||||||
@ -22,8 +24,6 @@ type LegacyV5TestCase struct {
|
|||||||
//
|
//
|
||||||
// TODO: Remove when this is deprecated (likely V7)
|
// TODO: Remove when this is deprecated (likely V7)
|
||||||
func CreateLegacyV5TestCases(t *testing.T) (map[string]LegacyV5TestCase, encryption.Cipher, encryption.Cipher) {
|
func CreateLegacyV5TestCases(t *testing.T) (map[string]LegacyV5TestCase, encryption.Cipher, encryption.Cipher) {
|
||||||
const secret = "0123456789abcdefghijklmnopqrstuv"
|
|
||||||
|
|
||||||
created := time.Now()
|
created := time.Now()
|
||||||
createdJSON, err := created.MarshalJSON()
|
createdJSON, err := created.MarshalJSON()
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
@ -33,7 +33,7 @@ func CreateLegacyV5TestCases(t *testing.T) (map[string]LegacyV5TestCase, encrypt
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
eString := string(eJSON)
|
eString := string(eJSON)
|
||||||
|
|
||||||
cfbCipher, err := encryption.NewCFBCipher([]byte(secret))
|
cfbCipher, err := encryption.NewCFBCipher([]byte(LegacyV5TestSecret))
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
legacyCipher := encryption.NewBase64Cipher(cfbCipher)
|
legacyCipher := encryption.NewBase64Cipher(cfbCipher)
|
||||||
|
|
||||||
|
15
pkg/sessions/persistence/interfaces.go
Normal file
15
pkg/sessions/persistence/interfaces.go
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
package persistence
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Store is used for persistent session stores (IE not Cookie)
|
||||||
|
// Implementing this interface allows it to easily use the persistence.Manager
|
||||||
|
// for session ticket + encryption details.
|
||||||
|
type Store interface {
|
||||||
|
Save(context.Context, string, []byte, time.Duration) error
|
||||||
|
Load(context.Context, string) ([]byte, error)
|
||||||
|
Clear(context.Context, string) error
|
||||||
|
}
|
91
pkg/sessions/persistence/manager.go
Normal file
91
pkg/sessions/persistence/manager.go
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
package persistence
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options"
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Manager wraps a Store and handles the implementation details of the
|
||||||
|
// sessions.SessionStore with its use of session tickets
|
||||||
|
type Manager struct {
|
||||||
|
Store Store
|
||||||
|
Options *options.Cookie
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewManager creates a Manager that can wrap a Store and manage the
|
||||||
|
// sessions.SessionStore implementation details
|
||||||
|
func NewManager(store Store, cookieOpts *options.Cookie) *Manager {
|
||||||
|
return &Manager{
|
||||||
|
Store: store,
|
||||||
|
Options: cookieOpts,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save saves a session in a persistent Store. Save will generate (or reuse an
|
||||||
|
// existing) ticket which manages unique per session encryption & retrieval
|
||||||
|
// from the persistent data store.
|
||||||
|
func (m *Manager) Save(rw http.ResponseWriter, req *http.Request, s *sessions.SessionState) error {
|
||||||
|
if s.CreatedAt == nil || s.CreatedAt.IsZero() {
|
||||||
|
now := time.Now()
|
||||||
|
s.CreatedAt = &now
|
||||||
|
}
|
||||||
|
|
||||||
|
tckt, err := decodeTicketFromRequest(req, m.Options)
|
||||||
|
if err != nil {
|
||||||
|
tckt, err = newTicket(m.Options)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error creating a session ticket: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = tckt.saveSession(s, func(key string, val []byte, exp time.Duration) error {
|
||||||
|
return m.Store.Save(req.Context(), key, val, exp)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
tckt.setCookie(rw, req, s)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load reads sessions.SessionState information from a session store. It will
|
||||||
|
// use the session ticket from the http.Request's cookie.
|
||||||
|
func (m *Manager) Load(req *http.Request) (*sessions.SessionState, error) {
|
||||||
|
tckt, err := decodeTicketFromRequest(req, m.Options)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return tckt.loadSession(func(key string) ([]byte, error) {
|
||||||
|
return m.Store.Load(req.Context(), key)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear clears any saved session information for a given ticket cookie.
|
||||||
|
// Then it clears all session data for that ticket in the Store.
|
||||||
|
func (m *Manager) Clear(rw http.ResponseWriter, req *http.Request) error {
|
||||||
|
tckt, err := decodeTicketFromRequest(req, m.Options)
|
||||||
|
if err != nil {
|
||||||
|
// Always clear the cookie, even when we can't load a cookie from
|
||||||
|
// the request
|
||||||
|
tckt = &ticket{
|
||||||
|
options: m.Options,
|
||||||
|
}
|
||||||
|
tckt.clearCookie(rw, req)
|
||||||
|
// Don't raise an error if we didn't have a Cookie
|
||||||
|
if err == http.ErrNoCookie {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("error decoding ticket to clear session: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tckt.clearCookie(rw, req)
|
||||||
|
return tckt.clearSession(func(key string) error {
|
||||||
|
return m.Store.Clear(req.Context(), key)
|
||||||
|
})
|
||||||
|
}
|
34
pkg/sessions/persistence/manager_test.go
Normal file
34
pkg/sessions/persistence/manager_test.go
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
package persistence
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options"
|
||||||
|
sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/sessions/tests"
|
||||||
|
. "github.com/onsi/ginkgo"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestManager(t *testing.T) {
|
||||||
|
logger.SetOutput(GinkgoWriter)
|
||||||
|
RegisterFailHandler(Fail)
|
||||||
|
RunSpecs(t, "Persistence Manager SessionStore")
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ = Describe("Persistence Manager SessionStore Tests", func() {
|
||||||
|
var ms *tests.MockStore
|
||||||
|
BeforeEach(func() {
|
||||||
|
ms = tests.NewMockStore()
|
||||||
|
})
|
||||||
|
tests.RunSessionStoreTests(
|
||||||
|
func(_ *options.SessionOptions, cookieOpts *options.Cookie) (sessionsapi.SessionStore, error) {
|
||||||
|
return NewManager(ms, cookieOpts), nil
|
||||||
|
},
|
||||||
|
func(d time.Duration) error {
|
||||||
|
ms.FastForward(d)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
})
|
221
pkg/sessions/persistence/ticket.go
Normal file
221
pkg/sessions/persistence/ticket.go
Normal file
@ -0,0 +1,221 @@
|
|||||||
|
package persistence
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/aes"
|
||||||
|
"crypto/cipher"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/hex"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options"
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/cookies"
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/encryption"
|
||||||
|
)
|
||||||
|
|
||||||
|
// saveFunc performs a persistent store's save functionality using
|
||||||
|
// a key string, value []byte & (optional) expiration time.Duration
|
||||||
|
type saveFunc func(string, []byte, time.Duration) error
|
||||||
|
|
||||||
|
// loadFunc performs a load from a persistent store using a
|
||||||
|
// string key and returning the stored value as []byte
|
||||||
|
type loadFunc func(string) ([]byte, error)
|
||||||
|
|
||||||
|
// clearFunc performs a persistent store's clear functionality using
|
||||||
|
// a string key for the target of the deletion.
|
||||||
|
type clearFunc func(string) error
|
||||||
|
|
||||||
|
// ticket is a structure representing the ticket used in server based
|
||||||
|
// session storage. It provides a unique per session decryption secret giving
|
||||||
|
// more security than the shared CookieSecret.
|
||||||
|
type ticket struct {
|
||||||
|
id string
|
||||||
|
secret []byte
|
||||||
|
options *options.Cookie
|
||||||
|
}
|
||||||
|
|
||||||
|
// newTicket creates a new ticket. The ID & secret will be randomly created
|
||||||
|
// with 16 byte sizes. The ID will be prefixed & hex encoded.
|
||||||
|
func newTicket(cookieOpts *options.Cookie) (*ticket, error) {
|
||||||
|
rawID := make([]byte, 16)
|
||||||
|
if _, err := io.ReadFull(rand.Reader, rawID); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create new ticket ID: %v", err)
|
||||||
|
}
|
||||||
|
// ticketID is hex encoded
|
||||||
|
ticketID := fmt.Sprintf("%s-%s", cookieOpts.Name, hex.EncodeToString(rawID))
|
||||||
|
|
||||||
|
secret := make([]byte, aes.BlockSize)
|
||||||
|
if _, err := io.ReadFull(rand.Reader, secret); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create encryption secret: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &ticket{
|
||||||
|
id: ticketID,
|
||||||
|
secret: secret,
|
||||||
|
options: cookieOpts,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// encodeTicket encodes the Ticket to a string for usage in cookies
|
||||||
|
func (t *ticket) encodeTicket() string {
|
||||||
|
return fmt.Sprintf("%s.%s", t.id, base64.RawURLEncoding.EncodeToString(t.secret))
|
||||||
|
}
|
||||||
|
|
||||||
|
// decodeTicket decodes an encoded ticket string
|
||||||
|
func decodeTicket(encTicket string, cookieOpts *options.Cookie) (*ticket, error) {
|
||||||
|
ticketParts := strings.Split(encTicket, ".")
|
||||||
|
if len(ticketParts) != 2 {
|
||||||
|
return nil, errors.New("failed to decode ticket")
|
||||||
|
}
|
||||||
|
ticketID, secretBase64 := ticketParts[0], ticketParts[1]
|
||||||
|
|
||||||
|
secret, err := base64.RawURLEncoding.DecodeString(secretBase64)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to decode encryption secret: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &ticket{
|
||||||
|
id: ticketID,
|
||||||
|
secret: secret,
|
||||||
|
options: cookieOpts,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// decodeTicketFromRequest retrieves a potential ticket cookie from a request
|
||||||
|
// and decodes it to a ticket.
|
||||||
|
func decodeTicketFromRequest(req *http.Request, cookieOpts *options.Cookie) (*ticket, error) {
|
||||||
|
requestCookie, err := req.Cookie(cookieOpts.Name)
|
||||||
|
if err != nil {
|
||||||
|
// Don't wrap this error to allow `err == http.ErrNoCookie` checks
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// An existing cookie exists, try to retrieve the ticket
|
||||||
|
val, _, ok := encryption.Validate(requestCookie, cookieOpts.Secret, cookieOpts.Expire)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("session ticket cookie failed validation: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Valid cookie, decode the ticket
|
||||||
|
return decodeTicket(string(val), cookieOpts)
|
||||||
|
}
|
||||||
|
|
||||||
|
// saveSession encodes the SessionState with the ticket's secret and persists
|
||||||
|
// it to disk via the passed saveFunc.
|
||||||
|
func (t *ticket) saveSession(s *sessions.SessionState, saver saveFunc) error {
|
||||||
|
c, err := t.makeCipher()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
ciphertext, err := s.EncodeSessionState(c, false)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to encode the session state with the ticket: %v", err)
|
||||||
|
}
|
||||||
|
return saver(t.id, ciphertext, t.options.Expire)
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadSession loads a session from the disk store via the passed loadFunc
|
||||||
|
// using the ticket.id as the key. It then decodes the SessionState using
|
||||||
|
// ticket.secret to make the AES-GCM cipher.
|
||||||
|
//
|
||||||
|
// TODO (@NickMeves): Remove legacyV5LoadSession support in V7
|
||||||
|
func (t *ticket) loadSession(loader loadFunc) (*sessions.SessionState, error) {
|
||||||
|
ciphertext, err := loader(t.id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to load the session state with the ticket: %v", err)
|
||||||
|
}
|
||||||
|
c, err := t.makeCipher()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ss, err := sessions.DecodeSessionState(ciphertext, c, false)
|
||||||
|
if err != nil {
|
||||||
|
return t.legacyV5LoadSession(ciphertext)
|
||||||
|
}
|
||||||
|
return ss, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// clearSession uses the passed clearFunc to delete a session stored with a
|
||||||
|
// key of ticket.id
|
||||||
|
func (t *ticket) clearSession(clearer clearFunc) error {
|
||||||
|
return clearer(t.id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// setCookie sets the encoded ticket as a cookie
|
||||||
|
func (t *ticket) setCookie(rw http.ResponseWriter, req *http.Request, s *sessions.SessionState) {
|
||||||
|
ticketCookie := t.makeCookie(
|
||||||
|
req,
|
||||||
|
t.encodeTicket(),
|
||||||
|
t.options.Expire,
|
||||||
|
*s.CreatedAt,
|
||||||
|
)
|
||||||
|
|
||||||
|
http.SetCookie(rw, ticketCookie)
|
||||||
|
}
|
||||||
|
|
||||||
|
// clearCookie removes any cookies that would be where this ticket
|
||||||
|
// would set them
|
||||||
|
func (t *ticket) clearCookie(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
clearCookie := t.makeCookie(
|
||||||
|
req,
|
||||||
|
"",
|
||||||
|
time.Hour*-1,
|
||||||
|
time.Now(),
|
||||||
|
)
|
||||||
|
http.SetCookie(rw, clearCookie)
|
||||||
|
}
|
||||||
|
|
||||||
|
// makeCookie makes a cookie, signing the value if present
|
||||||
|
func (t *ticket) makeCookie(req *http.Request, value string, expires time.Duration, now time.Time) *http.Cookie {
|
||||||
|
if value != "" {
|
||||||
|
value = encryption.SignedValue(t.options.Secret, t.options.Name, []byte(value), now)
|
||||||
|
}
|
||||||
|
return cookies.MakeCookieFromOptions(
|
||||||
|
req,
|
||||||
|
t.options.Name,
|
||||||
|
value,
|
||||||
|
t.options,
|
||||||
|
expires,
|
||||||
|
now,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// makeCipher makes a AES-GCM cipher out of the ticket's secret
|
||||||
|
func (t *ticket) makeCipher() (encryption.Cipher, error) {
|
||||||
|
c, err := encryption.NewGCMCipher(t.secret)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to make an AES-GCM cipher from the ticket secret: %v", err)
|
||||||
|
}
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// legacyV5LoadSession loads a Redis session created in V5 with historical logic
|
||||||
|
//
|
||||||
|
// TODO (@NickMeves): Remove in V7
|
||||||
|
func (t *ticket) legacyV5LoadSession(resultBytes []byte) (*sessions.SessionState, error) {
|
||||||
|
block, err := aes.NewCipher(t.secret)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create a legacy AES-CFB cipher from the ticket secret: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
stream := cipher.NewCFBDecrypter(block, t.secret)
|
||||||
|
stream.XORKeyStream(resultBytes, resultBytes)
|
||||||
|
|
||||||
|
cfbCipher, err := encryption.NewCFBCipher(encryption.SecretBytes(t.options.Secret))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
legacyCipher := encryption.NewBase64Cipher(cfbCipher)
|
||||||
|
|
||||||
|
session, err := sessions.LegacyV5DecodeSessionState(string(resultBytes), legacyCipher)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return session, nil
|
||||||
|
}
|
223
pkg/sessions/persistence/ticket_test.go
Normal file
223
pkg/sessions/persistence/ticket_test.go
Normal file
@ -0,0 +1,223 @@
|
|||||||
|
package persistence
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/aes"
|
||||||
|
"crypto/cipher"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options"
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
|
||||||
|
. "github.com/onsi/ginkgo"
|
||||||
|
. "github.com/onsi/ginkgo/extensions/table"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_ticket(t *testing.T) {
|
||||||
|
logger.SetOutput(GinkgoWriter)
|
||||||
|
RegisterFailHandler(Fail)
|
||||||
|
RunSpecs(t, "Session Ticket")
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ = Describe("Session Ticket Tests", func() {
|
||||||
|
Context("encodeTicket & decodeTicket", func() {
|
||||||
|
type ticketTableInput struct {
|
||||||
|
ticket *ticket
|
||||||
|
encodedTicket string
|
||||||
|
expectedError error
|
||||||
|
}
|
||||||
|
|
||||||
|
DescribeTable("encodeTicket should decodeTicket back when valid",
|
||||||
|
func(in ticketTableInput) {
|
||||||
|
if in.ticket != nil {
|
||||||
|
enc := in.ticket.encodeTicket()
|
||||||
|
Expect(enc).To(Equal(in.encodedTicket))
|
||||||
|
|
||||||
|
dec, err := decodeTicket(enc, in.ticket.options)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(dec).To(Equal(in.ticket))
|
||||||
|
} else {
|
||||||
|
_, err := decodeTicket(in.encodedTicket, nil)
|
||||||
|
Expect(err).To(MatchError(in.expectedError))
|
||||||
|
}
|
||||||
|
},
|
||||||
|
Entry("with a valid ticket", ticketTableInput{
|
||||||
|
ticket: &ticket{
|
||||||
|
id: "dummy-0123456789abcdef",
|
||||||
|
secret: []byte("0123456789abcdef"),
|
||||||
|
options: &options.Cookie{
|
||||||
|
Name: "dummy",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
encodedTicket: fmt.Sprintf("%s.%s",
|
||||||
|
"dummy-0123456789abcdef",
|
||||||
|
base64.RawURLEncoding.EncodeToString([]byte("0123456789abcdef"))),
|
||||||
|
expectedError: nil,
|
||||||
|
}),
|
||||||
|
Entry("with an invalid encoded ticket with 1 part", ticketTableInput{
|
||||||
|
ticket: nil,
|
||||||
|
encodedTicket: "dummy-0123456789abcdef",
|
||||||
|
expectedError: errors.New("failed to decode ticket"),
|
||||||
|
}),
|
||||||
|
Entry("with an invalid base64 encoded secret", ticketTableInput{
|
||||||
|
ticket: nil,
|
||||||
|
encodedTicket: "dummy-0123456789abcdef.@)#($*@)#(*$@)#(*$",
|
||||||
|
expectedError: fmt.Errorf("failed to decode encryption secret: illegal base64 data at input byte 0"),
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("saveSession", func() {
|
||||||
|
It("uses the passed save function", func() {
|
||||||
|
t, err := newTicket(&options.Cookie{Name: "dummy"})
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
c, err := t.makeCipher()
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
ss := &sessions.SessionState{User: "foobar"}
|
||||||
|
store := map[string][]byte{}
|
||||||
|
err = t.saveSession(ss, func(k string, v []byte, e time.Duration) error {
|
||||||
|
store[k] = v
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
stored, err := sessions.DecodeSessionState(store[t.id], c, false)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(stored).To(Equal(ss))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("errors when the saveFunc errors", func() {
|
||||||
|
t, err := newTicket(&options.Cookie{Name: "dummy"})
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
err = t.saveSession(
|
||||||
|
&sessions.SessionState{User: "foobar"},
|
||||||
|
func(k string, v []byte, e time.Duration) error {
|
||||||
|
return errors.New("save error")
|
||||||
|
})
|
||||||
|
Expect(err).To(MatchError(errors.New("save error")))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("loadSession", func() {
|
||||||
|
It("uses the passed load function", func() {
|
||||||
|
t, err := newTicket(&options.Cookie{Name: "dummy"})
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
c, err := t.makeCipher()
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
ss := &sessions.SessionState{User: "foobar"}
|
||||||
|
loadedSession, err := t.loadSession(func(k string) ([]byte, error) {
|
||||||
|
return ss.EncodeSessionState(c, false)
|
||||||
|
})
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(loadedSession).To(Equal(ss))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("errors when the loadFunc errors", func() {
|
||||||
|
t, err := newTicket(&options.Cookie{Name: "dummy"})
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
data, err := t.loadSession(func(k string) ([]byte, error) {
|
||||||
|
return nil, errors.New("load error")
|
||||||
|
})
|
||||||
|
Expect(data).To(BeNil())
|
||||||
|
Expect(err).To(MatchError(errors.New("failed to load the session state with the ticket: load error")))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("clearSession", func() {
|
||||||
|
It("uses the passed clear function", func() {
|
||||||
|
t, err := newTicket(&options.Cookie{Name: "dummy"})
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
var tracker string
|
||||||
|
err = t.clearSession(func(k string) error {
|
||||||
|
tracker = k
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(tracker).To(Equal(t.id))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("errors when the clearFunc errors", func() {
|
||||||
|
t, err := newTicket(&options.Cookie{Name: "dummy"})
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
err = t.clearSession(func(k string) error {
|
||||||
|
return errors.New("clear error")
|
||||||
|
})
|
||||||
|
Expect(err).To(MatchError(errors.New("clear error")))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// TestLegacyV5DecodeSession tests the fallback to LegacyV5DecodeSession
|
||||||
|
// when a V5 encoded session is in Redis
|
||||||
|
//
|
||||||
|
// TODO (@NickMeves): Remove when this is deprecated (likely V7)
|
||||||
|
func Test_legacyV5LoadSession(t *testing.T) {
|
||||||
|
testCases, _, _ := sessions.CreateLegacyV5TestCases(t)
|
||||||
|
|
||||||
|
for testName, tc := range testCases {
|
||||||
|
t.Run(testName, func(t *testing.T) {
|
||||||
|
g := NewWithT(t)
|
||||||
|
|
||||||
|
secret := make([]byte, aes.BlockSize)
|
||||||
|
_, err := io.ReadFull(rand.Reader, secret)
|
||||||
|
g.Expect(err).ToNot(HaveOccurred())
|
||||||
|
tckt := &ticket{
|
||||||
|
secret: secret,
|
||||||
|
options: &options.Cookie{
|
||||||
|
Secret: base64.RawURLEncoding.EncodeToString([]byte(sessions.LegacyV5TestSecret)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
encrypted, err := legacyStoreValue(tc.Input, tckt.secret)
|
||||||
|
g.Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
ss, err := tckt.legacyV5LoadSession(encrypted)
|
||||||
|
if tc.Error {
|
||||||
|
g.Expect(err).To(HaveOccurred())
|
||||||
|
g.Expect(ss).To(BeNil())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
g.Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
// Compare sessions without *time.Time fields
|
||||||
|
exp := *tc.Output
|
||||||
|
exp.CreatedAt = nil
|
||||||
|
exp.ExpiresOn = nil
|
||||||
|
act := *ss
|
||||||
|
act.CreatedAt = nil
|
||||||
|
act.ExpiresOn = nil
|
||||||
|
g.Expect(exp).To(Equal(act))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// legacyStoreValue implements the legacy V5 Redis persistence AES-CFB value encryption
|
||||||
|
//
|
||||||
|
// TODO (@NickMeves): Remove when this is deprecated (likely V7)
|
||||||
|
func legacyStoreValue(value string, ticketSecret []byte) ([]byte, error) {
|
||||||
|
ciphertext := make([]byte, len(value))
|
||||||
|
block, err := aes.NewCipher(ticketSecret)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error initiating cipher block: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use secret as the Initialization Vector too, because each entry has it's own key
|
||||||
|
stream := cipher.NewCFBEncrypter(block, ticketSecret)
|
||||||
|
stream.XORKeyStream(ciphertext, []byte(value))
|
||||||
|
|
||||||
|
return ciphertext, nil
|
||||||
|
}
|
@ -2,69 +2,87 @@ package redis
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/aes"
|
|
||||||
"crypto/cipher"
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/base64"
|
|
||||||
"encoding/hex"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-redis/redis/v7"
|
"github.com/go-redis/redis/v7"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options"
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/cookies"
|
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/encryption"
|
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/sessions/persistence"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TicketData is a structure representing the ticket used in server session storage
|
// SessionStore is an implementation of the persistence.Store
|
||||||
type TicketData struct {
|
|
||||||
TicketID string
|
|
||||||
Secret []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
// SessionStore is an implementation of the sessions.SessionStore
|
|
||||||
// interface that stores sessions in redis
|
// interface that stores sessions in redis
|
||||||
type SessionStore struct {
|
type SessionStore struct {
|
||||||
CookieCipher encryption.Cipher
|
|
||||||
Cookie *options.Cookie
|
|
||||||
Client Client
|
Client Client
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRedisSessionStore initialises a new instance of the SessionStore from
|
// NewRedisSessionStore initialises a new instance of the SessionStore and wraps
|
||||||
// the configuration given
|
// it in a persistence.Manager
|
||||||
func NewRedisSessionStore(opts *options.SessionOptions, cookieOpts *options.Cookie) (sessions.SessionStore, error) {
|
func NewRedisSessionStore(opts *options.SessionOptions, cookieOpts *options.Cookie) (sessions.SessionStore, error) {
|
||||||
cfbCipher, err := encryption.NewCFBCipher(encryption.SecretBytes(cookieOpts.Secret))
|
client, err := newRedisClient(opts.Redis)
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error initialising cipher: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
client, err := newRedisCmdable(opts.Redis)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error constructing redis client: %v", err)
|
return nil, fmt.Errorf("error constructing redis client: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
rs := &SessionStore{
|
rs := &SessionStore{
|
||||||
Client: client,
|
Client: client,
|
||||||
CookieCipher: cfbCipher,
|
|
||||||
Cookie: cookieOpts,
|
|
||||||
}
|
}
|
||||||
return rs, nil
|
return persistence.NewManager(rs, cookieOpts), nil
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func newRedisCmdable(opts options.RedisStoreOptions) (Client, error) {
|
// Save takes a sessions.SessionState and stores the information from it
|
||||||
|
// to redies, and adds a new persistence cookie on the HTTP response writer
|
||||||
|
func (store *SessionStore) Save(ctx context.Context, key string, value []byte, exp time.Duration) error {
|
||||||
|
err := store.Client.Set(ctx, key, value, exp)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error saving redis session: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load reads sessions.SessionState information from a persistence
|
||||||
|
// cookie within the HTTP request object
|
||||||
|
func (store *SessionStore) Load(ctx context.Context, key string) ([]byte, error) {
|
||||||
|
value, err := store.Client.Get(ctx, key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error loading redis session: %v", err)
|
||||||
|
}
|
||||||
|
return value, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear clears any saved session information for a given persistence cookie
|
||||||
|
// from redis, and then clears the session
|
||||||
|
func (store *SessionStore) Clear(ctx context.Context, key string) error {
|
||||||
|
err := store.Client.Del(ctx, key)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error clearing the session from redis: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// newRedisClient makes a redis.Client (either standalone, sentinel aware, or
|
||||||
|
// redis cluster)
|
||||||
|
func newRedisClient(opts options.RedisStoreOptions) (Client, error) {
|
||||||
if opts.UseSentinel && opts.UseCluster {
|
if opts.UseSentinel && opts.UseCluster {
|
||||||
return nil, fmt.Errorf("options redis-use-sentinel and redis-use-cluster are mutually exclusive")
|
return nil, fmt.Errorf("options redis-use-sentinel and redis-use-cluster are mutually exclusive")
|
||||||
}
|
}
|
||||||
|
|
||||||
if opts.UseSentinel {
|
if opts.UseSentinel {
|
||||||
|
return buildSentinelClient(opts)
|
||||||
|
}
|
||||||
|
if opts.UseCluster {
|
||||||
|
return buildClusterClient(opts)
|
||||||
|
}
|
||||||
|
|
||||||
|
return buildStandaloneClient(opts)
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildSentinelClient makes a redis.Client that connects to Redis Sentinel
|
||||||
|
// for Primary/Replica Redis node coordination
|
||||||
|
func buildSentinelClient(opts options.RedisStoreOptions) (Client, error) {
|
||||||
addrs, err := parseRedisURLs(opts.SentinelConnectionURLs)
|
addrs, err := parseRedisURLs(opts.SentinelConnectionURLs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("could not parse redis urls: %v", err)
|
return nil, fmt.Errorf("could not parse redis urls: %v", err)
|
||||||
@ -74,9 +92,10 @@ func newRedisCmdable(opts options.RedisStoreOptions) (Client, error) {
|
|||||||
SentinelAddrs: addrs,
|
SentinelAddrs: addrs,
|
||||||
})
|
})
|
||||||
return newClient(client), nil
|
return newClient(client), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if opts.UseCluster {
|
// buildClusterClient makes a redis.Client that is Redis Cluster aware
|
||||||
|
func buildClusterClient(opts options.RedisStoreOptions) (Client, error) {
|
||||||
addrs, err := parseRedisURLs(opts.ClusterConnectionURLs)
|
addrs, err := parseRedisURLs(opts.ClusterConnectionURLs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("could not parse redis urls: %v", err)
|
return nil, fmt.Errorf("could not parse redis urls: %v", err)
|
||||||
@ -85,8 +104,11 @@ func newRedisCmdable(opts options.RedisStoreOptions) (Client, error) {
|
|||||||
Addrs: addrs,
|
Addrs: addrs,
|
||||||
})
|
})
|
||||||
return newClusterClient(client), nil
|
return newClusterClient(client), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// buildStandaloneClient makes a redis.Client that connects to a simple
|
||||||
|
// Redis node
|
||||||
|
func buildStandaloneClient(opts options.RedisStoreOptions) (Client, error) {
|
||||||
opt, err := redis.ParseURL(opts.ConnectionURL)
|
opt, err := redis.ParseURL(opts.ConnectionURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("unable to parse redis url: %s", err)
|
return nil, fmt.Errorf("unable to parse redis url: %s", err)
|
||||||
@ -134,261 +156,3 @@ func parseRedisURLs(urls []string) ([]string, error) {
|
|||||||
}
|
}
|
||||||
return addrs, nil
|
return addrs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save takes a sessions.SessionState and stores the information from it
|
|
||||||
// to redies, and adds a new ticket cookie on the HTTP response writer
|
|
||||||
func (store *SessionStore) Save(rw http.ResponseWriter, req *http.Request, s *sessions.SessionState) error {
|
|
||||||
if s.CreatedAt == nil || s.CreatedAt.IsZero() {
|
|
||||||
now := time.Now()
|
|
||||||
s.CreatedAt = &now
|
|
||||||
}
|
|
||||||
|
|
||||||
// Old sessions that we are refreshing would have a request cookie
|
|
||||||
// New sessions don't, so we ignore the error. storeValue will check requestCookie
|
|
||||||
requestCookie, _ := req.Cookie(store.Cookie.Name)
|
|
||||||
ctx := req.Context()
|
|
||||||
ticketString, err := store.saveSession(ctx, s, store.Cookie.Expire, requestCookie)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
ticketCookie := store.makeCookie(
|
|
||||||
req,
|
|
||||||
ticketString,
|
|
||||||
store.Cookie.Expire,
|
|
||||||
*s.CreatedAt,
|
|
||||||
)
|
|
||||||
|
|
||||||
http.SetCookie(rw, ticketCookie)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load reads sessions.SessionState information from a ticket
|
|
||||||
// cookie within the HTTP request object
|
|
||||||
func (store *SessionStore) Load(req *http.Request) (*sessions.SessionState, error) {
|
|
||||||
requestCookie, err := req.Cookie(store.Cookie.Name)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error loading session: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
val, _, ok := encryption.Validate(requestCookie, store.Cookie.Secret, store.Cookie.Expire)
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("cookie signature not valid")
|
|
||||||
}
|
|
||||||
ctx := req.Context()
|
|
||||||
session, err := store.loadSessionFromTicket(ctx, string(val))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error loading session: %s", err)
|
|
||||||
}
|
|
||||||
return session, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clear clears any saved session information for a given ticket cookie
|
|
||||||
// from redis, and then clears the session
|
|
||||||
func (store *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) error {
|
|
||||||
// We go ahead and clear the cookie first, always.
|
|
||||||
clearCookie := store.makeCookie(
|
|
||||||
req,
|
|
||||||
"",
|
|
||||||
time.Hour*-1,
|
|
||||||
time.Now(),
|
|
||||||
)
|
|
||||||
http.SetCookie(rw, clearCookie)
|
|
||||||
|
|
||||||
// If there was an existing cookie we should clear the session in redis
|
|
||||||
requestCookie, err := req.Cookie(store.Cookie.Name)
|
|
||||||
if err != nil && err == http.ErrNoCookie {
|
|
||||||
// No existing cookie so can't clear redis
|
|
||||||
return nil
|
|
||||||
} else if err != nil {
|
|
||||||
return fmt.Errorf("error retrieving cookie: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
val, _, ok := encryption.Validate(requestCookie, store.Cookie.Secret, store.Cookie.Expire)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("cookie signature not valid")
|
|
||||||
}
|
|
||||||
|
|
||||||
// We only return an error if we had an issue with redis
|
|
||||||
// If there's an issue decoding the ticket, ignore it
|
|
||||||
ticket, _ := decodeTicket(store.Cookie.Name, string(val))
|
|
||||||
if ticket != nil {
|
|
||||||
ctx := req.Context()
|
|
||||||
err := store.Client.Del(ctx, ticket.asHandle(store.Cookie.Name))
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error clearing cookie from redis: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// saveSession encodes a session with a GCM cipher & saves the data into Redis
|
|
||||||
func (store *SessionStore) saveSession(ctx context.Context, s *sessions.SessionState, expiration time.Duration, requestCookie *http.Cookie) (string, error) {
|
|
||||||
ticket, err := store.getTicket(requestCookie)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("error getting ticket: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
c, err := encryption.NewGCMCipher(ticket.Secret)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("error initiating cipher block %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use AES-GCM since it provides authenticated encryption
|
|
||||||
// AES-CFB used in cookies has the cookie signing SHA to get around the lack of
|
|
||||||
// authentication in AES-CFB
|
|
||||||
ciphertext, err := s.EncodeSessionState(c, false)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
handle := ticket.asHandle(store.Cookie.Name)
|
|
||||||
err = store.Client.Set(ctx, handle, ciphertext, expiration)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return ticket.encodeTicket(store.Cookie.Name), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// loadSessionFromTicket loads the session based on the ticket value
|
|
||||||
func (store *SessionStore) loadSessionFromTicket(ctx context.Context, value string) (*sessions.SessionState, error) {
|
|
||||||
ticket, err := decodeTicket(store.Cookie.Name, value)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
resultBytes, err := store.Client.Get(ctx, ticket.asHandle(store.Cookie.Name))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
c, err := encryption.NewGCMCipher(ticket.Secret)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
session, err := sessions.DecodeSessionState(resultBytes, c, false)
|
|
||||||
if err != nil {
|
|
||||||
// The GCM cipher will error due to a legacy JSON payload not passing
|
|
||||||
// the authentication check part of AES GCM encryption.
|
|
||||||
// In that case, we can attempt to fallback to try a legacy load
|
|
||||||
legacyCipher := encryption.NewBase64Cipher(store.CookieCipher)
|
|
||||||
return legacyV5DecodeSession(resultBytes, ticket, legacyCipher)
|
|
||||||
}
|
|
||||||
return session, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// legacyV5DecodeSession loads the session based on the ticket value
|
|
||||||
// This fallback uses V5 style encryption of Base64 + AES CFB
|
|
||||||
func legacyV5DecodeSession(resultBytes []byte, ticket *TicketData, c encryption.Cipher) (*sessions.SessionState, error) {
|
|
||||||
block, err := aes.NewCipher(ticket.Secret)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
// Use secret as the IV too, because each entry has it's own key
|
|
||||||
stream := cipher.NewCFBDecrypter(block, ticket.Secret)
|
|
||||||
stream.XORKeyStream(resultBytes, resultBytes)
|
|
||||||
|
|
||||||
session, err := sessions.LegacyV5DecodeSessionState(string(resultBytes), c)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return session, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// makeCookie makes a cookie, signing the value if present
|
|
||||||
func (store *SessionStore) makeCookie(req *http.Request, value string, expires time.Duration, now time.Time) *http.Cookie {
|
|
||||||
if value != "" {
|
|
||||||
value = encryption.SignedValue(store.Cookie.Secret, store.Cookie.Name, []byte(value), now)
|
|
||||||
}
|
|
||||||
return cookies.MakeCookieFromOptions(
|
|
||||||
req,
|
|
||||||
store.Cookie.Name,
|
|
||||||
value,
|
|
||||||
store.Cookie,
|
|
||||||
expires,
|
|
||||||
now,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
// getTicket retrieves an existing ticket from the cookie if present,
|
|
||||||
// or creates a new ticket
|
|
||||||
func (store *SessionStore) getTicket(requestCookie *http.Cookie) (*TicketData, error) {
|
|
||||||
if requestCookie == nil {
|
|
||||||
return newTicket()
|
|
||||||
}
|
|
||||||
|
|
||||||
// An existing cookie exists, try to retrieve the ticket
|
|
||||||
val, _, ok := encryption.Validate(requestCookie, store.Cookie.Secret, store.Cookie.Expire)
|
|
||||||
if !ok {
|
|
||||||
// Cookie is invalid, create a new ticket
|
|
||||||
return newTicket()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Valid cookie, decode the ticket
|
|
||||||
ticket, err := decodeTicket(store.Cookie.Name, string(val))
|
|
||||||
if err != nil {
|
|
||||||
// If we can't decode the ticket we have to create a new one
|
|
||||||
return newTicket()
|
|
||||||
}
|
|
||||||
return ticket, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func newTicket() (*TicketData, error) {
|
|
||||||
rawID := make([]byte, 16)
|
|
||||||
if _, err := io.ReadFull(rand.Reader, rawID); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create new ticket ID %s", err)
|
|
||||||
}
|
|
||||||
// ticketID is hex encoded
|
|
||||||
ticketID := hex.EncodeToString(rawID)
|
|
||||||
|
|
||||||
secret := make([]byte, aes.BlockSize)
|
|
||||||
if _, err := io.ReadFull(rand.Reader, secret); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create initialization vector %s", err)
|
|
||||||
}
|
|
||||||
ticket := &TicketData{
|
|
||||||
TicketID: ticketID,
|
|
||||||
Secret: secret,
|
|
||||||
}
|
|
||||||
return ticket, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ticket *TicketData) asHandle(prefix string) string {
|
|
||||||
return fmt.Sprintf("%s-%s", prefix, ticket.TicketID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func decodeTicket(cookieName string, ticketString string) (*TicketData, error) {
|
|
||||||
prefix := cookieName + "-"
|
|
||||||
if !strings.HasPrefix(ticketString, prefix) {
|
|
||||||
return nil, fmt.Errorf("failed to decode ticket handle")
|
|
||||||
}
|
|
||||||
trimmedTicket := strings.TrimPrefix(ticketString, prefix)
|
|
||||||
|
|
||||||
ticketParts := strings.Split(trimmedTicket, ".")
|
|
||||||
if len(ticketParts) != 2 {
|
|
||||||
return nil, fmt.Errorf("failed to decode ticket")
|
|
||||||
}
|
|
||||||
ticketID, secretBase64 := ticketParts[0], ticketParts[1]
|
|
||||||
|
|
||||||
// ticketID must be a hexadecimal string
|
|
||||||
_, err := hex.DecodeString(ticketID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("server ticket failed sanity checks")
|
|
||||||
}
|
|
||||||
|
|
||||||
secret, err := base64.RawURLEncoding.DecodeString(secretBase64)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to decode initialization vector %s", err)
|
|
||||||
}
|
|
||||||
ticketData := &TicketData{
|
|
||||||
TicketID: ticketID,
|
|
||||||
Secret: secret,
|
|
||||||
}
|
|
||||||
return ticketData, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ticket *TicketData) encodeTicket(prefix string) string {
|
|
||||||
handle := ticket.asHandle(prefix)
|
|
||||||
ticketString := handle + "." + base64.RawURLEncoding.EncodeToString(ticket.Secret)
|
|
||||||
return ticketString
|
|
||||||
}
|
|
||||||
|
@ -1,11 +1,6 @@
|
|||||||
package redis
|
package redis
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/aes"
|
|
||||||
"crypto/cipher"
|
|
||||||
"crypto/rand"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
@ -17,70 +12,12 @@ import (
|
|||||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options"
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options"
|
||||||
sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
|
sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/sessions/persistence"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/sessions/tests"
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/sessions/tests"
|
||||||
. "github.com/onsi/ginkgo"
|
. "github.com/onsi/ginkgo"
|
||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestLegacyV5DecodeSession tests the fallback to LegacyV5DecodeSession
|
|
||||||
// when a V5 encoded session is in Redis
|
|
||||||
//
|
|
||||||
// TODO: Remove when this is deprecated (likely V7)
|
|
||||||
func Test_legacyV5DecodeSession(t *testing.T) {
|
|
||||||
testCases, _, legacyCipher := sessionsapi.CreateLegacyV5TestCases(t)
|
|
||||||
|
|
||||||
for testName, tc := range testCases {
|
|
||||||
t.Run(testName, func(t *testing.T) {
|
|
||||||
g := NewWithT(t)
|
|
||||||
|
|
||||||
secret := make([]byte, aes.BlockSize)
|
|
||||||
_, err := io.ReadFull(rand.Reader, secret)
|
|
||||||
g.Expect(err).ToNot(HaveOccurred())
|
|
||||||
ticket := &TicketData{
|
|
||||||
TicketID: "",
|
|
||||||
Secret: secret,
|
|
||||||
}
|
|
||||||
|
|
||||||
encrypted, err := legacyStoreValue(tc.Input, ticket.Secret)
|
|
||||||
g.Expect(err).ToNot(HaveOccurred())
|
|
||||||
|
|
||||||
ss, err := legacyV5DecodeSession(encrypted, ticket, legacyCipher)
|
|
||||||
if tc.Error {
|
|
||||||
g.Expect(err).To(HaveOccurred())
|
|
||||||
g.Expect(ss).To(BeNil())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
g.Expect(err).ToNot(HaveOccurred())
|
|
||||||
|
|
||||||
// Compare sessions without *time.Time fields
|
|
||||||
exp := *tc.Output
|
|
||||||
exp.CreatedAt = nil
|
|
||||||
exp.ExpiresOn = nil
|
|
||||||
act := *ss
|
|
||||||
act.CreatedAt = nil
|
|
||||||
act.ExpiresOn = nil
|
|
||||||
g.Expect(exp).To(Equal(act))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// legacyStoreValue implements the legacy V5 Redis store AES-CFB value encryption
|
|
||||||
//
|
|
||||||
// TODO: Remove when this is deprecated (likely V7)
|
|
||||||
func legacyStoreValue(value string, ticketSecret []byte) ([]byte, error) {
|
|
||||||
ciphertext := make([]byte, len(value))
|
|
||||||
block, err := aes.NewCipher(ticketSecret)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error initiating cipher block: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use secret as the Initialization Vector too, because each entry has it's own key
|
|
||||||
stream := cipher.NewCFBEncrypter(block, ticketSecret)
|
|
||||||
stream.XORKeyStream(ciphertext, []byte(value))
|
|
||||||
|
|
||||||
return ciphertext, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSessionStore(t *testing.T) {
|
func TestSessionStore(t *testing.T) {
|
||||||
logger.SetOutput(GinkgoWriter)
|
logger.SetOutput(GinkgoWriter)
|
||||||
|
|
||||||
@ -114,9 +51,9 @@ var _ = Describe("Redis SessionStore Tests", func() {
|
|||||||
|
|
||||||
JustAfterEach(func() {
|
JustAfterEach(func() {
|
||||||
// Release any connections immediately after the test ends
|
// Release any connections immediately after the test ends
|
||||||
if redisStore, ok := ss.(*SessionStore); ok {
|
if redisManager, ok := ss.(*persistence.Manager); ok {
|
||||||
if redisStore.Client != nil {
|
if redisManager.Store.(*SessionStore).Client != nil {
|
||||||
Expect(redisStore.Client.(closer).Close()).To(Succeed())
|
Expect(redisManager.Store.(*SessionStore).Client.(closer).Close()).To(Succeed())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -10,6 +10,7 @@ import (
|
|||||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/logger"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/sessions"
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/sessions"
|
||||||
sessionscookie "github.com/oauth2-proxy/oauth2-proxy/pkg/sessions/cookie"
|
sessionscookie "github.com/oauth2-proxy/oauth2-proxy/pkg/sessions/cookie"
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/sessions/persistence"
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/sessions/redis"
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/sessions/redis"
|
||||||
. "github.com/onsi/ginkgo"
|
. "github.com/onsi/ginkgo"
|
||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
@ -66,10 +67,11 @@ var _ = Describe("NewSessionStore", func() {
|
|||||||
opts.Redis.ConnectionURL = "redis://"
|
opts.Redis.ConnectionURL = "redis://"
|
||||||
})
|
})
|
||||||
|
|
||||||
It("creates a redis.SessionStore", func() {
|
It("creates a persistence.Manager that wraps a redis.SessionStore", func() {
|
||||||
ss, err := sessions.NewSessionStore(opts, cookieOpts)
|
ss, err := sessions.NewSessionStore(opts, cookieOpts)
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
Expect(ss).To(BeAssignableToTypeOf(&redis.SessionStore{}))
|
Expect(ss).To(BeAssignableToTypeOf(&persistence.Manager{}))
|
||||||
|
Expect(ss.(*persistence.Manager).Store).To(BeAssignableToTypeOf(&redis.SessionStore{}))
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
58
pkg/sessions/tests/mock_store.go
Normal file
58
pkg/sessions/tests/mock_store.go
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
package tests
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// entry is a MockStore cache entry with an expiration
|
||||||
|
type entry struct {
|
||||||
|
data []byte
|
||||||
|
expiration time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockStore is a generic in-memory implementation of persistence.Store
|
||||||
|
// for mocking in tests
|
||||||
|
type MockStore struct {
|
||||||
|
cache map[string]entry
|
||||||
|
elapsed time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMockStore creates a MockStore
|
||||||
|
func NewMockStore() *MockStore {
|
||||||
|
return &MockStore{
|
||||||
|
cache: map[string]entry{},
|
||||||
|
elapsed: 0 * time.Second,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save sets a key to the data to the memory cache
|
||||||
|
func (s *MockStore) Save(_ context.Context, key string, value []byte, exp time.Duration) error {
|
||||||
|
s.cache[key] = entry{
|
||||||
|
data: value,
|
||||||
|
expiration: exp,
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load gets data from the memory cache via a key
|
||||||
|
func (s *MockStore) Load(_ context.Context, key string) ([]byte, error) {
|
||||||
|
entry, ok := s.cache[key]
|
||||||
|
if !ok || entry.expiration <= s.elapsed {
|
||||||
|
delete(s.cache, key)
|
||||||
|
return nil, fmt.Errorf("key not found: %s", key)
|
||||||
|
}
|
||||||
|
return entry.data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear deletes an entry from the memory cache
|
||||||
|
func (s *MockStore) Clear(_ context.Context, key string) error {
|
||||||
|
delete(s.cache, key)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FastForward simulates the flow of time to test expirations
|
||||||
|
func (s *MockStore) FastForward(duration time.Duration) {
|
||||||
|
s.elapsed += duration
|
||||||
|
}
|
@ -133,18 +133,6 @@ func RunSessionStoreTests(newSS NewSessionStoreFunc, persistentFastForward Persi
|
|||||||
PersistentSessionStoreInterfaceTests(&input)
|
PersistentSessionStoreInterfaceTests(&input)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("with an invalid cookie secret", func() {
|
|
||||||
BeforeEach(func() {
|
|
||||||
input.cookieOpts.Secret = "invalid"
|
|
||||||
})
|
|
||||||
|
|
||||||
It("returns an error when initialising the session store", func() {
|
|
||||||
ss, err := newSS(opts, input.cookieOpts)
|
|
||||||
Expect(err).To(MatchError("error initialising cipher: crypto/aes: invalid key size 7"))
|
|
||||||
Expect(ss).To(BeNil())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user