mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2025-01-10 04:18:14 +02:00
217 lines
6.1 KiB
Go
217 lines
6.1 KiB
Go
|
package persistence
|
||
|
|
||
|
import (
|
||
|
"crypto/aes"
|
||
|
"crypto/cipher"
|
||
|
"crypto/rand"
|
||
|
"encoding/base64"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"io"
|
||
|
"testing"
|
||
|
"time"
|
||
|
|
||
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options"
|
||
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
|
||
|
. "github.com/onsi/ginkgo"
|
||
|
. "github.com/onsi/ginkgo/extensions/table"
|
||
|
. "github.com/onsi/gomega"
|
||
|
)
|
||
|
|
||
|
var _ = Describe("Session Ticket Tests", func() {
|
||
|
Context("encodeTicket & decodeTicket", func() {
|
||
|
type ticketTableInput struct {
|
||
|
ticket *ticket
|
||
|
encodedTicket string
|
||
|
expectedError error
|
||
|
}
|
||
|
|
||
|
DescribeTable("encodeTicket should decodeTicket back when valid",
|
||
|
func(in ticketTableInput) {
|
||
|
if in.ticket != nil {
|
||
|
enc := in.ticket.encodeTicket()
|
||
|
Expect(enc).To(Equal(in.encodedTicket))
|
||
|
|
||
|
dec, err := decodeTicket(enc, in.ticket.options)
|
||
|
Expect(err).ToNot(HaveOccurred())
|
||
|
Expect(dec).To(Equal(in.ticket))
|
||
|
} else {
|
||
|
_, err := decodeTicket(in.encodedTicket, nil)
|
||
|
Expect(err).To(MatchError(in.expectedError))
|
||
|
}
|
||
|
},
|
||
|
Entry("with a valid ticket", ticketTableInput{
|
||
|
ticket: &ticket{
|
||
|
id: "dummy-0123456789abcdef",
|
||
|
secret: []byte("0123456789abcdef"),
|
||
|
options: &options.Cookie{
|
||
|
Name: "dummy",
|
||
|
},
|
||
|
},
|
||
|
encodedTicket: fmt.Sprintf("%s.%s",
|
||
|
"dummy-0123456789abcdef",
|
||
|
base64.RawURLEncoding.EncodeToString([]byte("0123456789abcdef"))),
|
||
|
expectedError: nil,
|
||
|
}),
|
||
|
Entry("with an invalid encoded ticket with 1 part", ticketTableInput{
|
||
|
ticket: nil,
|
||
|
encodedTicket: "dummy-0123456789abcdef",
|
||
|
expectedError: errors.New("failed to decode ticket"),
|
||
|
}),
|
||
|
Entry("with an invalid base64 encoded secret", ticketTableInput{
|
||
|
ticket: nil,
|
||
|
encodedTicket: "dummy-0123456789abcdef.@)#($*@)#(*$@)#(*$",
|
||
|
expectedError: fmt.Errorf("failed to decode encryption secret: illegal base64 data at input byte 0"),
|
||
|
}),
|
||
|
)
|
||
|
})
|
||
|
|
||
|
Context("saveSession", func() {
|
||
|
It("uses the passed save function", func() {
|
||
|
t, err := newTicket(&options.Cookie{Name: "dummy"})
|
||
|
Expect(err).ToNot(HaveOccurred())
|
||
|
|
||
|
c, err := t.makeCipher()
|
||
|
Expect(err).ToNot(HaveOccurred())
|
||
|
|
||
|
ss := &sessions.SessionState{User: "foobar"}
|
||
|
store := map[string][]byte{}
|
||
|
err = t.saveSession(ss, func(k string, v []byte, e time.Duration) error {
|
||
|
store[k] = v
|
||
|
return nil
|
||
|
})
|
||
|
Expect(err).ToNot(HaveOccurred())
|
||
|
|
||
|
stored, err := sessions.DecodeSessionState(store[t.id], c, false)
|
||
|
Expect(err).ToNot(HaveOccurred())
|
||
|
Expect(stored).To(Equal(ss))
|
||
|
})
|
||
|
|
||
|
It("errors when the saveFunc errors", func() {
|
||
|
t, err := newTicket(&options.Cookie{Name: "dummy"})
|
||
|
Expect(err).ToNot(HaveOccurred())
|
||
|
|
||
|
err = t.saveSession(
|
||
|
&sessions.SessionState{User: "foobar"},
|
||
|
func(k string, v []byte, e time.Duration) error {
|
||
|
return errors.New("save error")
|
||
|
})
|
||
|
Expect(err).To(MatchError(errors.New("save error")))
|
||
|
})
|
||
|
})
|
||
|
|
||
|
Context("loadSession", func() {
|
||
|
It("uses the passed load function", func() {
|
||
|
t, err := newTicket(&options.Cookie{Name: "dummy"})
|
||
|
Expect(err).ToNot(HaveOccurred())
|
||
|
|
||
|
c, err := t.makeCipher()
|
||
|
Expect(err).ToNot(HaveOccurred())
|
||
|
|
||
|
ss := &sessions.SessionState{User: "foobar"}
|
||
|
loadedSession, err := t.loadSession(func(k string) ([]byte, error) {
|
||
|
return ss.EncodeSessionState(c, false)
|
||
|
})
|
||
|
Expect(err).ToNot(HaveOccurred())
|
||
|
Expect(loadedSession).To(Equal(ss))
|
||
|
})
|
||
|
|
||
|
It("errors when the loadFunc errors", func() {
|
||
|
t, err := newTicket(&options.Cookie{Name: "dummy"})
|
||
|
Expect(err).ToNot(HaveOccurred())
|
||
|
|
||
|
data, err := t.loadSession(func(k string) ([]byte, error) {
|
||
|
return nil, errors.New("load error")
|
||
|
})
|
||
|
Expect(data).To(BeNil())
|
||
|
Expect(err).To(MatchError(errors.New("failed to load the session state with the ticket: load error")))
|
||
|
})
|
||
|
})
|
||
|
|
||
|
Context("clearSession", func() {
|
||
|
It("uses the passed clear function", func() {
|
||
|
t, err := newTicket(&options.Cookie{Name: "dummy"})
|
||
|
Expect(err).ToNot(HaveOccurred())
|
||
|
|
||
|
var tracker string
|
||
|
err = t.clearSession(func(k string) error {
|
||
|
tracker = k
|
||
|
return nil
|
||
|
})
|
||
|
Expect(err).ToNot(HaveOccurred())
|
||
|
Expect(tracker).To(Equal(t.id))
|
||
|
})
|
||
|
|
||
|
It("errors when the clearFunc errors", func() {
|
||
|
t, err := newTicket(&options.Cookie{Name: "dummy"})
|
||
|
Expect(err).ToNot(HaveOccurred())
|
||
|
|
||
|
err = t.clearSession(func(k string) error {
|
||
|
return errors.New("clear error")
|
||
|
})
|
||
|
Expect(err).To(MatchError(errors.New("clear error")))
|
||
|
})
|
||
|
})
|
||
|
})
|
||
|
|
||
|
// TestLegacyV5DecodeSession tests the fallback to LegacyV5DecodeSession
|
||
|
// when a V5 encoded session is in Redis
|
||
|
//
|
||
|
// TODO (@NickMeves): Remove when this is deprecated (likely V7)
|
||
|
func Test_legacyV5LoadSession(t *testing.T) {
|
||
|
testCases, _, _ := sessions.CreateLegacyV5TestCases(t)
|
||
|
|
||
|
for testName, tc := range testCases {
|
||
|
t.Run(testName, func(t *testing.T) {
|
||
|
g := NewWithT(t)
|
||
|
|
||
|
secret := make([]byte, aes.BlockSize)
|
||
|
_, err := io.ReadFull(rand.Reader, secret)
|
||
|
g.Expect(err).ToNot(HaveOccurred())
|
||
|
tckt := &ticket{
|
||
|
secret: secret,
|
||
|
options: &options.Cookie{
|
||
|
Secret: base64.RawURLEncoding.EncodeToString([]byte(sessions.LegacyV5TestSecret)),
|
||
|
},
|
||
|
}
|
||
|
|
||
|
encrypted, err := legacyStoreValue(tc.Input, tckt.secret)
|
||
|
g.Expect(err).ToNot(HaveOccurred())
|
||
|
|
||
|
ss, err := tckt.legacyV5LoadSession(encrypted)
|
||
|
if tc.Error {
|
||
|
g.Expect(err).To(HaveOccurred())
|
||
|
g.Expect(ss).To(BeNil())
|
||
|
return
|
||
|
}
|
||
|
g.Expect(err).ToNot(HaveOccurred())
|
||
|
|
||
|
// Compare sessions without *time.Time fields
|
||
|
exp := *tc.Output
|
||
|
exp.CreatedAt = nil
|
||
|
exp.ExpiresOn = nil
|
||
|
act := *ss
|
||
|
act.CreatedAt = nil
|
||
|
act.ExpiresOn = nil
|
||
|
g.Expect(exp).To(Equal(act))
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// legacyStoreValue implements the legacy V5 Redis persistence AES-CFB value encryption
//
// ticketSecret is used both as the AES key and as the CFB initialization
// vector, so it must be exactly one AES block (aes.BlockSize bytes) long. A
// secret that is not a valid AES key, or that is a valid key of a different
// length (24 or 32 bytes), is reported as an error rather than causing a panic
// in the cipher package. On success the returned ciphertext has the same
// length as value.
//
// TODO (@NickMeves): Remove when this is deprecated (likely V7)
func legacyStoreValue(value string, ticketSecret []byte) ([]byte, error) {
	block, err := aes.NewCipher(ticketSecret)
	if err != nil {
		return nil, fmt.Errorf("error initiating cipher block: %w", err)
	}

	// Use secret as the Initialization Vector too, because each entry has its own key.
	// cipher.NewCFBEncrypter panics when the IV length differs from the block
	// size, so reject such secrets up front.
	if len(ticketSecret) != block.BlockSize() {
		return nil, fmt.Errorf(
			"ticket secret must be %d bytes to be used as the initialization vector, got %d",
			block.BlockSize(), len(ticketSecret))
	}

	ciphertext := make([]byte, len(value))
	stream := cipher.NewCFBEncrypter(block, ticketSecret)
	stream.XORKeyStream(ciphertext, []byte(value))

	return ciphertext, nil
}
|