mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2025-06-13 00:07:26 +02:00
Merge remote-tracking branch 'upstream/master' into helm-example
# Conflicts: # CHANGELOG.md
This commit is contained in:
commit
9a495e996b
@ -56,6 +56,8 @@
|
|||||||
## Changes since v5.1.1
|
## Changes since v5.1.1
|
||||||
|
|
||||||
- [#615](https://github.com/oauth2-proxy/oauth2-proxy/pull/615) Helm Example based on Kind cluster and Nginx ingress (@EvgeniGordeev)
|
- [#615](https://github.com/oauth2-proxy/oauth2-proxy/pull/615) Helm Example based on Kind cluster and Nginx ingress (@EvgeniGordeev)
|
||||||
|
- [#604](https://github.com/oauth2-proxy/oauth2-proxy/pull/604) Add Keycloak local testing environment (@EvgeniGordeev)
|
||||||
|
- [#539](https://github.com/oauth2-proxy/oauth2-proxy/pull/539) Refactor encryption ciphers and add AES-GCM support (@NickMeves)
|
||||||
- [#601](https://github.com/oauth2-proxy/oauth2-proxy/pull/601) Ensure decrypted user/email are valid UTF8 (@JoelSpeed)
|
- [#601](https://github.com/oauth2-proxy/oauth2-proxy/pull/601) Ensure decrypted user/email are valid UTF8 (@JoelSpeed)
|
||||||
- [#560](https://github.com/oauth2-proxy/oauth2-proxy/pull/560) Fallback to UserInfo is User ID claim not present (@JoelSpeed)
|
- [#560](https://github.com/oauth2-proxy/oauth2-proxy/pull/560) Fallback to UserInfo is User ID claim not present (@JoelSpeed)
|
||||||
- [#598](https://github.com/oauth2-proxy/oauth2-proxy/pull/598) acr_values no longer sent to IdP when empty (@ScottGuymer)
|
- [#598](https://github.com/oauth2-proxy/oauth2-proxy/pull/598) acr_values no longer sent to IdP when empty (@ScottGuymer)
|
||||||
|
@ -13,3 +13,11 @@ nginx-up:
|
|||||||
.PHONY: nginx-%
|
.PHONY: nginx-%
|
||||||
nginx-%:
|
nginx-%:
|
||||||
docker-compose -f docker-compose.yaml -f docker-compose-nginx.yaml $*
|
docker-compose -f docker-compose.yaml -f docker-compose-nginx.yaml $*
|
||||||
|
|
||||||
|
.PHONY: keycloak-up
|
||||||
|
keycloak-up:
|
||||||
|
docker-compose -f docker-compose-keycloak.yaml up -d
|
||||||
|
|
||||||
|
.PHONY: keycloak-%
|
||||||
|
keycloak-%:
|
||||||
|
docker-compose -f docker-compose-keycloak.yaml $*
|
||||||
|
70
contrib/local-environment/docker-compose-keycloak.yaml
Normal file
70
contrib/local-environment/docker-compose-keycloak.yaml
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
# This docker-compose file can be used to bring up an example instance of oauth2-proxy
|
||||||
|
# for manual testing and exploration of features.
|
||||||
|
# Alongside OAuth2-Proxy, this file also starts Keycloak to act as the identity provider,
|
||||||
|
# HTTPBin as an example upstream.
|
||||||
|
#
|
||||||
|
# This can either be created using docker-compose
|
||||||
|
# docker-compose -f docker-compose-keycloak.yaml <command>
|
||||||
|
# Or:
|
||||||
|
# make keycloak-<command> (eg. make keycloak-up, make keycloak-down)
|
||||||
|
#
|
||||||
|
# Access http://oauth2-proxy.localtest.me:4180 to initiate a login cycle using user=admin@example.com, password=password
|
||||||
|
# Access http://keycloak.localtest.me:9080 with the same credentials to check out the settings
|
||||||
|
version: '3.0'
|
||||||
|
services:
|
||||||
|
|
||||||
|
oauth2-proxy:
|
||||||
|
container_name: oauth2-proxy
|
||||||
|
image: quay.io/oauth2-proxy/oauth2-proxy:v5.1.1
|
||||||
|
command: --config /oauth2-proxy.cfg
|
||||||
|
hostname: oauth2-proxy
|
||||||
|
volumes:
|
||||||
|
- "./oauth2-proxy-keycloak.cfg:/oauth2-proxy.cfg"
|
||||||
|
restart: unless-stopped
|
||||||
|
networks:
|
||||||
|
keycloak: {}
|
||||||
|
httpbin: {}
|
||||||
|
oauth2-proxy: {}
|
||||||
|
depends_on:
|
||||||
|
- httpbin
|
||||||
|
- keycloak
|
||||||
|
ports:
|
||||||
|
- 4180:4180/tcp
|
||||||
|
|
||||||
|
httpbin:
|
||||||
|
container_name: httpbin
|
||||||
|
image: kennethreitz/httpbin:latest
|
||||||
|
hostname: httpbin
|
||||||
|
networks:
|
||||||
|
httpbin: {}
|
||||||
|
|
||||||
|
keycloak:
|
||||||
|
container_name: keycloak
|
||||||
|
image: jboss/keycloak:10.0.0
|
||||||
|
hostname: keycloak
|
||||||
|
command:
|
||||||
|
[
|
||||||
|
'-b',
|
||||||
|
'0.0.0.0',
|
||||||
|
'-Djboss.socket.binding.port-offset=1000',
|
||||||
|
'-Dkeycloak.migration.action=import',
|
||||||
|
'-Dkeycloak.migration.provider=dir',
|
||||||
|
'-Dkeycloak.migration.dir=/realm-config',
|
||||||
|
'-Dkeycloak.migration.strategy=IGNORE_EXISTING',
|
||||||
|
]
|
||||||
|
volumes:
|
||||||
|
- ./keycloak:/realm-config
|
||||||
|
environment:
|
||||||
|
KEYCLOAK_USER: admin@example.com
|
||||||
|
KEYCLOAK_PASSWORD: password
|
||||||
|
networks:
|
||||||
|
keycloak:
|
||||||
|
aliases:
|
||||||
|
- keycloak.localtest.me
|
||||||
|
ports:
|
||||||
|
- 9080:9080/tcp
|
||||||
|
|
||||||
|
networks:
|
||||||
|
httpbin: {}
|
||||||
|
keycloak: {}
|
||||||
|
oauth2-proxy: {}
|
1684
contrib/local-environment/keycloak/master-realm.json
Normal file
1684
contrib/local-environment/keycloak/master-realm.json
Normal file
File diff suppressed because it is too large
Load Diff
27
contrib/local-environment/keycloak/master-users-0.json
Normal file
27
contrib/local-environment/keycloak/master-users-0.json
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
{
|
||||||
|
"realm" : "master",
|
||||||
|
"users" : [ {
|
||||||
|
"id" : "3356c0a0-d4d5-4436-9c5a-2299c71c08ec",
|
||||||
|
"createdTimestamp" : 1591297959169,
|
||||||
|
"username" : "admin@example.com",
|
||||||
|
"email" : "admin@example.com",
|
||||||
|
"enabled" : true,
|
||||||
|
"totp" : false,
|
||||||
|
"emailVerified" : true,
|
||||||
|
"credentials" : [ {
|
||||||
|
"id" : "a1a06ecd-fdc0-4e67-92cd-2da22d724e32",
|
||||||
|
"type" : "password",
|
||||||
|
"createdDate" : 1591297959315,
|
||||||
|
"secretData" : "{\"value\":\"6rt5zuqHVHopvd0FTFE0CYadXTtzY0mDY2BrqnNQGS51/7DfMJeGgj0roNnGMGvDv30imErNmiSOYl+cL9jiIA==\",\"salt\":\"LI0kqr09JB7J9wvr2Hxzzg==\"}",
|
||||||
|
"credentialData" : "{\"hashIterations\":27500,\"algorithm\":\"pbkdf2-sha256\"}"
|
||||||
|
} ],
|
||||||
|
"disableableCredentialTypes" : [ ],
|
||||||
|
"requiredActions" : [ ],
|
||||||
|
"realmRoles" : [ "offline_access", "admin", "uma_authorization" ],
|
||||||
|
"clientRoles" : {
|
||||||
|
"account" : [ "view-profile", "manage-account" ]
|
||||||
|
},
|
||||||
|
"notBefore" : 0,
|
||||||
|
"groups" : [ ]
|
||||||
|
} ]
|
||||||
|
}
|
20
contrib/local-environment/oauth2-proxy-keycloak.cfg
Normal file
20
contrib/local-environment/oauth2-proxy-keycloak.cfg
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
http_address="0.0.0.0:4180"
|
||||||
|
cookie_secret="OQINaROshtE9TcZkNAm-5Zs2Pv3xaWytBmc5W7sPX7w="
|
||||||
|
email_domains=["example.com"]
|
||||||
|
cookie_secure="false"
|
||||||
|
upstreams="http://httpbin"
|
||||||
|
cookie_domains=[".localtest.me"] # Required so cookie can be read on all subdomains.
|
||||||
|
whitelist_domains=[".localtest.me"] # Required to allow redirection back to original requested target.
|
||||||
|
|
||||||
|
# keycloak provider
|
||||||
|
client_secret="72341b6d-7065-4518-a0e4-50ee15025608"
|
||||||
|
client_id="oauth2-proxy"
|
||||||
|
redirect_url="http://oauth2-proxy.localtest.me:4180/oauth2/callback"
|
||||||
|
|
||||||
|
# in this case oauth2-proxy is going to visit
|
||||||
|
# http://keycloak.localtest.me:9080/auth/realms/master/.well-known/openid-configuration for configuration
|
||||||
|
oidc_issuer_url="http://keycloak.localtest.me:9080/auth/realms/master"
|
||||||
|
provider="oidc"
|
||||||
|
provider_display_name="Keycloak"
|
||||||
|
|
||||||
|
|
@ -44,7 +44,7 @@ An example [oauth2-proxy.cfg]({{ site.gitweb }}/contrib/oauth2-proxy.cfg.example
|
|||||||
| `--cookie-samesite` | string | set SameSite cookie attribute (ie: `"lax"`, `"strict"`, `"none"`, or `""`). | `""` |
|
| `--cookie-samesite` | string | set SameSite cookie attribute (ie: `"lax"`, `"strict"`, `"none"`, or `""`). | `""` |
|
||||||
| `--custom-templates-dir` | string | path to custom html templates | |
|
| `--custom-templates-dir` | string | path to custom html templates | |
|
||||||
| `--display-htpasswd-form` | bool | display username / password login form if an htpasswd file is provided | true |
|
| `--display-htpasswd-form` | bool | display username / password login form if an htpasswd file is provided | true |
|
||||||
| `--email-domain` | string | authenticate emails with the specified domain (may be given multiple times). Use `*` to authenticate any email | |
|
| `--email-domain` | string \| list | authenticate emails with the specified domain (may be given multiple times). Use `*` to authenticate any email | |
|
||||||
| `--extra-jwt-issuers` | string | if `--skip-jwt-bearer-tokens` is set, a list of extra JWT `issuer=audience` pairs (where the issuer URL has a `.well-known/openid-configuration` or a `.well-known/jwks.json`) | |
|
| `--extra-jwt-issuers` | string | if `--skip-jwt-bearer-tokens` is set, a list of extra JWT `issuer=audience` pairs (where the issuer URL has a `.well-known/openid-configuration` or a `.well-known/jwks.json`) | |
|
||||||
| `--exclude-logging-paths` | string | comma separated list of paths to exclude from logging, eg: `"/ping,/path2"` |`""` (no paths excluded) |
|
| `--exclude-logging-paths` | string | comma separated list of paths to exclude from logging, eg: `"/ping,/path2"` |`""` (no paths excluded) |
|
||||||
| `--flush-interval` | duration | period between flushing response buffers when streaming responses | `"1s"` |
|
| `--flush-interval` | duration | period between flushing response buffers when streaming responses | `"1s"` |
|
||||||
|
@ -5,7 +5,7 @@ import "github.com/oauth2-proxy/oauth2-proxy/pkg/encryption"
|
|||||||
// SessionOptions contains configuration options for the SessionStore providers.
|
// SessionOptions contains configuration options for the SessionStore providers.
|
||||||
type SessionOptions struct {
|
type SessionOptions struct {
|
||||||
Type string `flag:"session-store-type" cfg:"session_store_type"`
|
Type string `flag:"session-store-type" cfg:"session_store_type"`
|
||||||
Cipher *encryption.Cipher `cfg:",internal"`
|
Cipher encryption.Cipher `cfg:",internal"`
|
||||||
Redis RedisStoreOptions `cfg:",squash"`
|
Redis RedisStoreOptions `cfg:",squash"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -60,7 +60,7 @@ func (s *SessionState) String() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// EncodeSessionState returns string representation of the current session
|
// EncodeSessionState returns string representation of the current session
|
||||||
func (s *SessionState) EncodeSessionState(c *encryption.Cipher) (string, error) {
|
func (s *SessionState) EncodeSessionState(c encryption.Cipher) (string, error) {
|
||||||
var ss SessionState
|
var ss SessionState
|
||||||
if c == nil {
|
if c == nil {
|
||||||
// Store only Email and User when cipher is unavailable
|
// Store only Email and User when cipher is unavailable
|
||||||
@ -77,7 +77,7 @@ func (s *SessionState) EncodeSessionState(c *encryption.Cipher) (string, error)
|
|||||||
&ss.IDToken,
|
&ss.IDToken,
|
||||||
&ss.RefreshToken,
|
&ss.RefreshToken,
|
||||||
} {
|
} {
|
||||||
err := c.EncryptInto(s)
|
err := into(s, c.Encrypt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
@ -89,7 +89,7 @@ func (s *SessionState) EncodeSessionState(c *encryption.Cipher) (string, error)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// DecodeSessionState decodes the session cookie string into a SessionState
|
// DecodeSessionState decodes the session cookie string into a SessionState
|
||||||
func DecodeSessionState(v string, c *encryption.Cipher) (*SessionState, error) {
|
func DecodeSessionState(v string, c encryption.Cipher) (*SessionState, error) {
|
||||||
var ss SessionState
|
var ss SessionState
|
||||||
err := json.Unmarshal([]byte(v), &ss)
|
err := json.Unmarshal([]byte(v), &ss)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -104,25 +104,19 @@ func DecodeSessionState(v string, c *encryption.Cipher) (*SessionState, error) {
|
|||||||
PreferredUsername: ss.PreferredUsername,
|
PreferredUsername: ss.PreferredUsername,
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Backward compatibility with using unencrypted Email
|
// Backward compatibility with using unencrypted Email or User
|
||||||
if ss.Email != "" {
|
// Decryption errors will leave original string
|
||||||
decryptedEmail, errEmail := c.Decrypt(ss.Email)
|
err = into(&ss.Email, c.Decrypt)
|
||||||
if errEmail == nil {
|
if err == nil {
|
||||||
if !utf8.ValidString(decryptedEmail) {
|
if !utf8.ValidString(ss.Email) {
|
||||||
return nil, errors.New("invalid value for decrypted email")
|
return nil, errors.New("invalid value for decrypted email")
|
||||||
}
|
}
|
||||||
ss.Email = decryptedEmail
|
|
||||||
}
|
}
|
||||||
}
|
err = into(&ss.User, c.Decrypt)
|
||||||
// Backward compatibility with using unencrypted User
|
if err == nil {
|
||||||
if ss.User != "" {
|
if !utf8.ValidString(ss.User) {
|
||||||
decryptedUser, errUser := c.Decrypt(ss.User)
|
|
||||||
if errUser == nil {
|
|
||||||
if !utf8.ValidString(decryptedUser) {
|
|
||||||
return nil, errors.New("invalid value for decrypted user")
|
return nil, errors.New("invalid value for decrypted user")
|
||||||
}
|
}
|
||||||
ss.User = decryptedUser
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, s := range []*string{
|
for _, s := range []*string{
|
||||||
@ -131,7 +125,7 @@ func DecodeSessionState(v string, c *encryption.Cipher) (*SessionState, error) {
|
|||||||
&ss.IDToken,
|
&ss.IDToken,
|
||||||
&ss.RefreshToken,
|
&ss.RefreshToken,
|
||||||
} {
|
} {
|
||||||
err := c.DecryptInto(s)
|
err := into(s, c.Decrypt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -139,3 +133,20 @@ func DecodeSessionState(v string, c *encryption.Cipher) (*SessionState, error) {
|
|||||||
}
|
}
|
||||||
return &ss, nil
|
return &ss, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// codecFunc is a function that takes a []byte and encodes/decodes it
|
||||||
|
type codecFunc func([]byte) ([]byte, error)
|
||||||
|
|
||||||
|
func into(s *string, f codecFunc) error {
|
||||||
|
// Do not encrypt/decrypt nil or empty strings
|
||||||
|
if s == nil || *s == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
d, err := f([]byte(*s))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
*s = string(d)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
@ -1,11 +1,13 @@
|
|||||||
package sessions_test
|
package sessions
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/rand"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
|
mathrand "math/rand"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
|
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/pkg/encryption"
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/encryption"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
@ -17,12 +19,16 @@ func timePtr(t time.Time) *time.Time {
|
|||||||
return &t
|
return &t
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newTestCipher(secret []byte) (encryption.Cipher, error) {
|
||||||
|
return encryption.NewBase64Cipher(encryption.NewCFBCipher, secret)
|
||||||
|
}
|
||||||
|
|
||||||
func TestSessionStateSerialization(t *testing.T) {
|
func TestSessionStateSerialization(t *testing.T) {
|
||||||
c, err := encryption.NewCipher([]byte(secret))
|
c, err := newTestCipher([]byte(secret))
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
c2, err := encryption.NewCipher([]byte(altSecret))
|
c2, err := newTestCipher([]byte(altSecret))
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
s := &sessions.SessionState{
|
s := &SessionState{
|
||||||
Email: "user@domain.com",
|
Email: "user@domain.com",
|
||||||
PreferredUsername: "user",
|
PreferredUsername: "user",
|
||||||
AccessToken: "token1234",
|
AccessToken: "token1234",
|
||||||
@ -34,7 +40,7 @@ func TestSessionStateSerialization(t *testing.T) {
|
|||||||
encoded, err := s.EncodeSessionState(c)
|
encoded, err := s.EncodeSessionState(c)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
|
|
||||||
ss, err := sessions.DecodeSessionState(encoded, c)
|
ss, err := DecodeSessionState(encoded, c)
|
||||||
t.Logf("%#v", ss)
|
t.Logf("%#v", ss)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Equal(t, "", ss.User)
|
assert.Equal(t, "", ss.User)
|
||||||
@ -47,17 +53,17 @@ func TestSessionStateSerialization(t *testing.T) {
|
|||||||
assert.Equal(t, s.RefreshToken, ss.RefreshToken)
|
assert.Equal(t, s.RefreshToken, ss.RefreshToken)
|
||||||
|
|
||||||
// ensure a different cipher can't decode properly (ie: it gets gibberish)
|
// ensure a different cipher can't decode properly (ie: it gets gibberish)
|
||||||
ss, err = sessions.DecodeSessionState(encoded, c2)
|
ss, err = DecodeSessionState(encoded, c2)
|
||||||
t.Logf("%#v", ss)
|
t.Logf("%#v", ss)
|
||||||
assert.NotEqual(t, nil, err)
|
assert.NotEqual(t, nil, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSessionStateSerializationWithUser(t *testing.T) {
|
func TestSessionStateSerializationWithUser(t *testing.T) {
|
||||||
c, err := encryption.NewCipher([]byte(secret))
|
c, err := newTestCipher([]byte(secret))
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
c2, err := encryption.NewCipher([]byte(altSecret))
|
c2, err := newTestCipher([]byte(altSecret))
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
s := &sessions.SessionState{
|
s := &SessionState{
|
||||||
User: "just-user",
|
User: "just-user",
|
||||||
PreferredUsername: "ju",
|
PreferredUsername: "ju",
|
||||||
Email: "user@domain.com",
|
Email: "user@domain.com",
|
||||||
@ -69,7 +75,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
|
|||||||
encoded, err := s.EncodeSessionState(c)
|
encoded, err := s.EncodeSessionState(c)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
|
|
||||||
ss, err := sessions.DecodeSessionState(encoded, c)
|
ss, err := DecodeSessionState(encoded, c)
|
||||||
t.Logf("%#v", ss)
|
t.Logf("%#v", ss)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Equal(t, s.User, ss.User)
|
assert.Equal(t, s.User, ss.User)
|
||||||
@ -81,13 +87,13 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
|
|||||||
assert.Equal(t, s.RefreshToken, ss.RefreshToken)
|
assert.Equal(t, s.RefreshToken, ss.RefreshToken)
|
||||||
|
|
||||||
// ensure a different cipher can't decode properly (ie: it gets gibberish)
|
// ensure a different cipher can't decode properly (ie: it gets gibberish)
|
||||||
ss, err = sessions.DecodeSessionState(encoded, c2)
|
ss, err = DecodeSessionState(encoded, c2)
|
||||||
t.Logf("%#v", ss)
|
t.Logf("%#v", ss)
|
||||||
assert.NotEqual(t, nil, err)
|
assert.NotEqual(t, nil, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSessionStateSerializationNoCipher(t *testing.T) {
|
func TestSessionStateSerializationNoCipher(t *testing.T) {
|
||||||
s := &sessions.SessionState{
|
s := &SessionState{
|
||||||
Email: "user@domain.com",
|
Email: "user@domain.com",
|
||||||
PreferredUsername: "user",
|
PreferredUsername: "user",
|
||||||
AccessToken: "token1234",
|
AccessToken: "token1234",
|
||||||
@ -99,7 +105,7 @@ func TestSessionStateSerializationNoCipher(t *testing.T) {
|
|||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
|
|
||||||
// only email should have been serialized
|
// only email should have been serialized
|
||||||
ss, err := sessions.DecodeSessionState(encoded, nil)
|
ss, err := DecodeSessionState(encoded, nil)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Equal(t, "", ss.User)
|
assert.Equal(t, "", ss.User)
|
||||||
assert.Equal(t, s.Email, ss.Email)
|
assert.Equal(t, s.Email, ss.Email)
|
||||||
@ -109,7 +115,7 @@ func TestSessionStateSerializationNoCipher(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
|
func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
|
||||||
s := &sessions.SessionState{
|
s := &SessionState{
|
||||||
User: "just-user",
|
User: "just-user",
|
||||||
Email: "user@domain.com",
|
Email: "user@domain.com",
|
||||||
PreferredUsername: "user",
|
PreferredUsername: "user",
|
||||||
@ -122,7 +128,7 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
|
|||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
|
|
||||||
// only email should have been serialized
|
// only email should have been serialized
|
||||||
ss, err := sessions.DecodeSessionState(encoded, nil)
|
ss, err := DecodeSessionState(encoded, nil)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Equal(t, s.User, ss.User)
|
assert.Equal(t, s.User, ss.User)
|
||||||
assert.Equal(t, s.Email, ss.Email)
|
assert.Equal(t, s.Email, ss.Email)
|
||||||
@ -132,20 +138,20 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestExpired(t *testing.T) {
|
func TestExpired(t *testing.T) {
|
||||||
s := &sessions.SessionState{ExpiresOn: timePtr(time.Now().Add(time.Duration(-1) * time.Minute))}
|
s := &SessionState{ExpiresOn: timePtr(time.Now().Add(time.Duration(-1) * time.Minute))}
|
||||||
assert.Equal(t, true, s.IsExpired())
|
assert.Equal(t, true, s.IsExpired())
|
||||||
|
|
||||||
s = &sessions.SessionState{ExpiresOn: timePtr(time.Now().Add(time.Duration(1) * time.Minute))}
|
s = &SessionState{ExpiresOn: timePtr(time.Now().Add(time.Duration(1) * time.Minute))}
|
||||||
assert.Equal(t, false, s.IsExpired())
|
assert.Equal(t, false, s.IsExpired())
|
||||||
|
|
||||||
s = &sessions.SessionState{}
|
s = &SessionState{}
|
||||||
assert.Equal(t, false, s.IsExpired())
|
assert.Equal(t, false, s.IsExpired())
|
||||||
}
|
}
|
||||||
|
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
sessions.SessionState
|
SessionState
|
||||||
Encoded string
|
Encoded string
|
||||||
Cipher *encryption.Cipher
|
Cipher encryption.Cipher
|
||||||
Error bool
|
Error bool
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -159,14 +165,14 @@ func TestEncodeSessionState(t *testing.T) {
|
|||||||
|
|
||||||
testCases := []testCase{
|
testCases := []testCase{
|
||||||
{
|
{
|
||||||
SessionState: sessions.SessionState{
|
SessionState: SessionState{
|
||||||
Email: "user@domain.com",
|
Email: "user@domain.com",
|
||||||
User: "just-user",
|
User: "just-user",
|
||||||
},
|
},
|
||||||
Encoded: `{"Email":"user@domain.com","User":"just-user"}`,
|
Encoded: `{"Email":"user@domain.com","User":"just-user"}`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
SessionState: sessions.SessionState{
|
SessionState: SessionState{
|
||||||
Email: "user@domain.com",
|
Email: "user@domain.com",
|
||||||
User: "just-user",
|
User: "just-user",
|
||||||
AccessToken: "token1234",
|
AccessToken: "token1234",
|
||||||
@ -181,7 +187,7 @@ func TestEncodeSessionState(t *testing.T) {
|
|||||||
|
|
||||||
for i, tc := range testCases {
|
for i, tc := range testCases {
|
||||||
encoded, err := tc.EncodeSessionState(tc.Cipher)
|
encoded, err := tc.EncodeSessionState(tc.Cipher)
|
||||||
t.Logf("i:%d Encoded:%#vsessions.SessionState:%#v Error:%#v", i, encoded, tc.SessionState, err)
|
t.Logf("i:%d Encoded:%#vSessionState:%#v Error:%#v", i, encoded, tc.SessionState, err)
|
||||||
if tc.Error {
|
if tc.Error {
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.Empty(t, encoded)
|
assert.Empty(t, encoded)
|
||||||
@ -201,39 +207,39 @@ func TestDecodeSessionState(t *testing.T) {
|
|||||||
eJSON, _ := e.MarshalJSON()
|
eJSON, _ := e.MarshalJSON()
|
||||||
eString := string(eJSON)
|
eString := string(eJSON)
|
||||||
|
|
||||||
c, err := encryption.NewCipher([]byte(secret))
|
c, err := newTestCipher([]byte(secret))
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
testCases := []testCase{
|
testCases := []testCase{
|
||||||
{
|
{
|
||||||
SessionState: sessions.SessionState{
|
SessionState: SessionState{
|
||||||
Email: "user@domain.com",
|
Email: "user@domain.com",
|
||||||
User: "just-user",
|
User: "just-user",
|
||||||
},
|
},
|
||||||
Encoded: `{"Email":"user@domain.com","User":"just-user"}`,
|
Encoded: `{"Email":"user@domain.com","User":"just-user"}`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
SessionState: sessions.SessionState{
|
SessionState: SessionState{
|
||||||
Email: "user@domain.com",
|
Email: "user@domain.com",
|
||||||
User: "",
|
User: "",
|
||||||
},
|
},
|
||||||
Encoded: `{"Email":"user@domain.com"}`,
|
Encoded: `{"Email":"user@domain.com"}`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
SessionState: sessions.SessionState{
|
SessionState: SessionState{
|
||||||
User: "just-user",
|
User: "just-user",
|
||||||
},
|
},
|
||||||
Encoded: `{"User":"just-user"}`,
|
Encoded: `{"User":"just-user"}`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
SessionState: sessions.SessionState{
|
SessionState: SessionState{
|
||||||
Email: "user@domain.com",
|
Email: "user@domain.com",
|
||||||
User: "just-user",
|
User: "just-user",
|
||||||
},
|
},
|
||||||
Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","CreatedAt":%s,"ExpiresOn":%s}`, createdString, eString),
|
Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","CreatedAt":%s,"ExpiresOn":%s}`, createdString, eString),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
SessionState: sessions.SessionState{
|
SessionState: SessionState{
|
||||||
Email: "user@domain.com",
|
Email: "user@domain.com",
|
||||||
User: "just-user",
|
User: "just-user",
|
||||||
AccessToken: "token1234",
|
AccessToken: "token1234",
|
||||||
@ -246,7 +252,7 @@ func TestDecodeSessionState(t *testing.T) {
|
|||||||
Cipher: c,
|
Cipher: c,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
SessionState: sessions.SessionState{
|
SessionState: SessionState{
|
||||||
Email: "user@domain.com",
|
Email: "user@domain.com",
|
||||||
User: "just-user",
|
User: "just-user",
|
||||||
},
|
},
|
||||||
@ -264,7 +270,7 @@ func TestDecodeSessionState(t *testing.T) {
|
|||||||
Error: true,
|
Error: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
SessionState: sessions.SessionState{
|
SessionState: SessionState{
|
||||||
Email: "user@domain.com",
|
Email: "user@domain.com",
|
||||||
User: "YmFzZTY0LWVuY29kZWQtdXNlcgo=", // Base64 encoding of base64-encoded-user
|
User: "YmFzZTY0LWVuY29kZWQtdXNlcgo=", // Base64 encoding of base64-encoded-user
|
||||||
},
|
},
|
||||||
@ -274,8 +280,8 @@ func TestDecodeSessionState(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for i, tc := range testCases {
|
for i, tc := range testCases {
|
||||||
ss, err := sessions.DecodeSessionState(tc.Encoded, tc.Cipher)
|
ss, err := DecodeSessionState(tc.Encoded, tc.Cipher)
|
||||||
t.Logf("i:%d Encoded:%#vsessions.SessionState:%#v Error:%#v", i, tc.Encoded, ss, err)
|
t.Logf("i:%d Encoded:%#vSessionState:%#v Error:%#v", i, tc.Encoded, ss, err)
|
||||||
if tc.Error {
|
if tc.Error {
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.Nil(t, ss)
|
assert.Nil(t, ss)
|
||||||
@ -297,7 +303,7 @@ func TestDecodeSessionState(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSessionStateAge(t *testing.T) {
|
func TestSessionStateAge(t *testing.T) {
|
||||||
ss := &sessions.SessionState{}
|
ss := &SessionState{}
|
||||||
|
|
||||||
// Created at unset so should be 0
|
// Created at unset so should be 0
|
||||||
assert.Equal(t, time.Duration(0), ss.Age())
|
assert.Equal(t, time.Duration(0), ss.Age())
|
||||||
@ -306,3 +312,44 @@ func TestSessionStateAge(t *testing.T) {
|
|||||||
ss.CreatedAt = timePtr(time.Now().Add(-1 * time.Hour))
|
ss.CreatedAt = timePtr(time.Now().Add(-1 * time.Hour))
|
||||||
assert.Equal(t, time.Hour, ss.Age().Round(time.Minute))
|
assert.Equal(t, time.Hour, ss.Age().Round(time.Minute))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIntoEncryptAndIntoDecrypt(t *testing.T) {
|
||||||
|
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||||
|
|
||||||
|
// Test all 3 valid AES sizes
|
||||||
|
for _, secretSize := range []int{16, 24, 32} {
|
||||||
|
t.Run(fmt.Sprintf("%d", secretSize), func(t *testing.T) {
|
||||||
|
secret := make([]byte, secretSize)
|
||||||
|
_, err := io.ReadFull(rand.Reader, secret)
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
|
||||||
|
c, err := newTestCipher(secret)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Check no errors with empty or nil strings
|
||||||
|
empty := ""
|
||||||
|
assert.Equal(t, nil, into(&empty, c.Encrypt))
|
||||||
|
assert.Equal(t, nil, into(&empty, c.Decrypt))
|
||||||
|
assert.Equal(t, nil, into(nil, c.Encrypt))
|
||||||
|
assert.Equal(t, nil, into(nil, c.Decrypt))
|
||||||
|
|
||||||
|
// Test various sizes tokens might be
|
||||||
|
for _, dataSize := range []int{10, 100, 1000, 5000, 10000} {
|
||||||
|
t.Run(fmt.Sprintf("%d", dataSize), func(t *testing.T) {
|
||||||
|
b := make([]byte, dataSize)
|
||||||
|
for i := range b {
|
||||||
|
b[i] = charset[mathrand.Intn(len(charset))]
|
||||||
|
}
|
||||||
|
data := string(b)
|
||||||
|
originalData := data
|
||||||
|
|
||||||
|
assert.Equal(t, nil, into(&data, c.Encrypt))
|
||||||
|
assert.NotEqual(t, originalData, data)
|
||||||
|
|
||||||
|
assert.Equal(t, nil, into(&data, c.Decrypt))
|
||||||
|
assert.Equal(t, originalData, data)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -3,183 +3,134 @@ package encryption
|
|||||||
import (
|
import (
|
||||||
"crypto/aes"
|
"crypto/aes"
|
||||||
"crypto/cipher"
|
"crypto/cipher"
|
||||||
"crypto/hmac"
|
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/sha1"
|
|
||||||
"crypto/sha256"
|
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
"hash"
|
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// SecretBytes attempts to base64 decode the secret, if that fails it treats the secret as binary
|
// Cipher provides methods to encrypt and decrypt
|
||||||
func SecretBytes(secret string) []byte {
|
type Cipher interface {
|
||||||
b, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(secret, "="))
|
Encrypt(value []byte) ([]byte, error)
|
||||||
if err == nil {
|
Decrypt(ciphertext []byte) ([]byte, error)
|
||||||
// Only return decoded form if a valid AES length
|
|
||||||
// Don't want unintentional decoding resulting in invalid lengths confusing a user
|
|
||||||
// that thought they used a 16, 24, 32 length string
|
|
||||||
for _, i := range []int{16, 24, 32} {
|
|
||||||
if len(b) == i {
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// If decoding didn't work or resulted in non-AES compliant length,
|
|
||||||
// assume the raw string was the intended secret
|
|
||||||
return []byte(secret)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// cookies are stored in a 3 part (value + timestamp + signature) to enforce that the values are as originally set.
|
type base64Cipher struct {
|
||||||
// additionally, the 'value' is encrypted so it's opaque to the browser
|
Cipher Cipher
|
||||||
|
|
||||||
// Validate ensures a cookie is properly signed
|
|
||||||
func Validate(cookie *http.Cookie, seed string, expiration time.Duration) (value string, t time.Time, ok bool) {
|
|
||||||
// value, timestamp, sig
|
|
||||||
parts := strings.Split(cookie.Value, "|")
|
|
||||||
if len(parts) != 3 {
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
if checkSignature(parts[2], seed, cookie.Name, parts[0], parts[1]) {
|
|
||||||
ts, err := strconv.Atoi(parts[1])
|
// NewBase64Cipher returns a new AES Cipher for encrypting cookie values
|
||||||
|
// and wrapping them in Base64 -- Supports Legacy encryption scheme
|
||||||
|
func NewBase64Cipher(initCipher func([]byte) (Cipher, error), secret []byte) (Cipher, error) {
|
||||||
|
c, err := initCipher(secret)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
// The expiration timestamp set when the cookie was created
|
return &base64Cipher{Cipher: c}, nil
|
||||||
// isn't sent back by the browser. Hence, we check whether the
|
|
||||||
// creation timestamp stored in the cookie falls within the
|
|
||||||
// window defined by (Now()-expiration, Now()].
|
|
||||||
t = time.Unix(int64(ts), 0)
|
|
||||||
if t.After(time.Now().Add(expiration*-1)) && t.Before(time.Now().Add(time.Minute*5)) {
|
|
||||||
// it's a valid cookie. now get the contents
|
|
||||||
rawValue, err := base64.URLEncoding.DecodeString(parts[0])
|
|
||||||
if err == nil {
|
|
||||||
value = string(rawValue)
|
|
||||||
ok = true
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SignedValue returns a cookie that is signed and can later be checked with Validate
|
// Encrypt encrypts a value with the embedded Cipher & Base64 encodes it
|
||||||
func SignedValue(seed string, key string, value string, now time.Time) string {
|
func (c *base64Cipher) Encrypt(value []byte) ([]byte, error) {
|
||||||
encodedValue := base64.URLEncoding.EncodeToString([]byte(value))
|
encrypted, err := c.Cipher.Encrypt(value)
|
||||||
timeStr := fmt.Sprintf("%d", now.Unix())
|
if err != nil {
|
||||||
sig := cookieSignature(sha256.New, seed, key, encodedValue, timeStr)
|
return nil, err
|
||||||
cookieVal := fmt.Sprintf("%s|%s|%s", encodedValue, timeStr, sig)
|
|
||||||
return cookieVal
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func cookieSignature(signer func() hash.Hash, args ...string) string {
|
return []byte(base64.StdEncoding.EncodeToString(encrypted)), nil
|
||||||
h := hmac.New(signer, []byte(args[0]))
|
|
||||||
for _, arg := range args[1:] {
|
|
||||||
h.Write([]byte(arg))
|
|
||||||
}
|
|
||||||
var b []byte
|
|
||||||
b = h.Sum(b)
|
|
||||||
return base64.URLEncoding.EncodeToString(b)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkSignature(signature string, args ...string) bool {
|
// Decrypt Base64 decodes a value & decrypts it with the embedded Cipher
|
||||||
checkSig := cookieSignature(sha256.New, args...)
|
func (c *base64Cipher) Decrypt(ciphertext []byte) ([]byte, error) {
|
||||||
if checkHmac(signature, checkSig) {
|
encrypted, err := base64.StdEncoding.DecodeString(string(ciphertext))
|
||||||
return true
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to base64 decode value %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: After appropriate rollout window, remove support for SHA1
|
return c.Cipher.Decrypt(encrypted)
|
||||||
legacySig := cookieSignature(sha1.New, args...)
|
|
||||||
return checkHmac(signature, legacySig)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkHmac(input, expected string) bool {
|
type cfbCipher struct {
|
||||||
inputMAC, err1 := base64.URLEncoding.DecodeString(input)
|
|
||||||
if err1 == nil {
|
|
||||||
expectedMAC, err2 := base64.URLEncoding.DecodeString(expected)
|
|
||||||
if err2 == nil {
|
|
||||||
return hmac.Equal(inputMAC, expectedMAC)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Cipher provides methods to encrypt and decrypt cookie values
|
|
||||||
type Cipher struct {
|
|
||||||
cipher.Block
|
cipher.Block
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewCipher returns a new aes Cipher for encrypting cookie values
|
// NewCFBCipher returns a new AES CFB Cipher
|
||||||
func NewCipher(secret []byte) (*Cipher, error) {
|
func NewCFBCipher(secret []byte) (Cipher, error) {
|
||||||
c, err := aes.NewCipher(secret)
|
c, err := aes.NewCipher(secret)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &Cipher{Block: c}, err
|
return &cfbCipher{Block: c}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Encrypt a value for use in a cookie
|
// Encrypt with AES CFB
|
||||||
func (c *Cipher) Encrypt(value string) (string, error) {
|
func (c *cfbCipher) Encrypt(value []byte) ([]byte, error) {
|
||||||
ciphertext := make([]byte, aes.BlockSize+len(value))
|
ciphertext := make([]byte, aes.BlockSize+len(value))
|
||||||
iv := ciphertext[:aes.BlockSize]
|
iv := ciphertext[:aes.BlockSize]
|
||||||
if _, err := io.ReadFull(rand.Reader, iv); err != nil {
|
if _, err := io.ReadFull(rand.Reader, iv); err != nil {
|
||||||
return "", fmt.Errorf("failed to create initialization vector %s", err)
|
return nil, fmt.Errorf("failed to create initialization vector %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
stream := cipher.NewCFBEncrypter(c.Block, iv)
|
stream := cipher.NewCFBEncrypter(c.Block, iv)
|
||||||
stream.XORKeyStream(ciphertext[aes.BlockSize:], []byte(value))
|
stream.XORKeyStream(ciphertext[aes.BlockSize:], value)
|
||||||
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
return ciphertext, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Decrypt a value from a cookie to it's original string
|
// Decrypt an AES CFB ciphertext
|
||||||
func (c *Cipher) Decrypt(s string) (string, error) {
|
func (c *cfbCipher) Decrypt(ciphertext []byte) ([]byte, error) {
|
||||||
encrypted, err := base64.StdEncoding.DecodeString(s)
|
if len(ciphertext) < aes.BlockSize {
|
||||||
if err != nil {
|
return nil, fmt.Errorf("encrypted value should be at least %d bytes, but is only %d bytes", aes.BlockSize, len(ciphertext))
|
||||||
return "", fmt.Errorf("failed to decrypt cookie value %s", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(encrypted) < aes.BlockSize {
|
iv, ciphertext := ciphertext[:aes.BlockSize], ciphertext[aes.BlockSize:]
|
||||||
return "", fmt.Errorf("encrypted cookie value should be "+
|
plaintext := make([]byte, len(ciphertext))
|
||||||
"at least %d bytes, but is only %d bytes",
|
|
||||||
aes.BlockSize, len(encrypted))
|
|
||||||
}
|
|
||||||
|
|
||||||
iv := encrypted[:aes.BlockSize]
|
|
||||||
encrypted = encrypted[aes.BlockSize:]
|
|
||||||
stream := cipher.NewCFBDecrypter(c.Block, iv)
|
stream := cipher.NewCFBDecrypter(c.Block, iv)
|
||||||
stream.XORKeyStream(encrypted, encrypted)
|
stream.XORKeyStream(plaintext, ciphertext)
|
||||||
|
|
||||||
return string(encrypted), nil
|
return plaintext, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// EncryptInto encrypts the value and stores it back in the string pointer
|
type gcmCipher struct {
|
||||||
func (c *Cipher) EncryptInto(s *string) error {
|
cipher.Block
|
||||||
return into(c.Encrypt, s)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// DecryptInto decrypts the value and stores it back in the string pointer
|
// NewGCMCipher returns a new AES GCM Cipher
|
||||||
func (c *Cipher) DecryptInto(s *string) error {
|
func NewGCMCipher(secret []byte) (Cipher, error) {
|
||||||
return into(c.Decrypt, s)
|
c, err := aes.NewCipher(secret)
|
||||||
}
|
|
||||||
|
|
||||||
// codecFunc is a function that takes a string and encodes/decodes it
|
|
||||||
type codecFunc func(string) (string, error)
|
|
||||||
|
|
||||||
func into(f codecFunc, s *string) error {
|
|
||||||
// Do not encrypt/decrypt nil or empty strings
|
|
||||||
if s == nil || *s == "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
d, err := f(*s)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
*s = d
|
return &gcmCipher{Block: c}, err
|
||||||
return nil
|
}
|
||||||
|
|
||||||
|
// Encrypt with AES GCM on raw bytes
|
||||||
|
func (c *gcmCipher) Encrypt(value []byte) ([]byte, error) {
|
||||||
|
gcm, err := cipher.NewGCM(c.Block)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
nonce := make([]byte, gcm.NonceSize())
|
||||||
|
if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// Using nonce as Seal's dst argument results in it being the first
|
||||||
|
// chunk of bytes in the ciphertext. Decrypt retrieves the nonce/IV from this.
|
||||||
|
ciphertext := gcm.Seal(nonce, nonce, value, nil)
|
||||||
|
return ciphertext, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decrypt an AES GCM ciphertext
|
||||||
|
func (c *gcmCipher) Decrypt(ciphertext []byte) ([]byte, error) {
|
||||||
|
gcm, err := cipher.NewGCM(c.Block)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
nonceSize := gcm.NonceSize()
|
||||||
|
nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:]
|
||||||
|
|
||||||
|
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return plaintext, nil
|
||||||
}
|
}
|
||||||
|
@ -2,8 +2,6 @@ package encryption
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/sha1"
|
|
||||||
"crypto/sha256"
|
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@ -12,107 +10,20 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestSecretBytesEncoded(t *testing.T) {
|
|
||||||
for _, secretSize := range []int{16, 24, 32} {
|
|
||||||
t.Run(fmt.Sprintf("%d", secretSize), func(t *testing.T) {
|
|
||||||
secret := make([]byte, secretSize)
|
|
||||||
_, err := io.ReadFull(rand.Reader, secret)
|
|
||||||
assert.Equal(t, nil, err)
|
|
||||||
|
|
||||||
// We test both padded & raw Base64 to ensure we handle both
|
|
||||||
// potential user input routes for Base64
|
|
||||||
base64Padded := base64.URLEncoding.EncodeToString(secret)
|
|
||||||
sb := SecretBytes(base64Padded)
|
|
||||||
assert.Equal(t, secret, sb)
|
|
||||||
assert.Equal(t, len(sb), secretSize)
|
|
||||||
|
|
||||||
base64Raw := base64.RawURLEncoding.EncodeToString(secret)
|
|
||||||
sb = SecretBytes(base64Raw)
|
|
||||||
assert.Equal(t, secret, sb)
|
|
||||||
assert.Equal(t, len(sb), secretSize)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// A string that isn't intended as Base64 and still decodes (but to unintended length)
|
|
||||||
// will return the original secret as bytes
|
|
||||||
func TestSecretBytesEncodedWrongSize(t *testing.T) {
|
|
||||||
for _, secretSize := range []int{15, 20, 28, 33, 44} {
|
|
||||||
t.Run(fmt.Sprintf("%d", secretSize), func(t *testing.T) {
|
|
||||||
secret := make([]byte, secretSize)
|
|
||||||
_, err := io.ReadFull(rand.Reader, secret)
|
|
||||||
assert.Equal(t, nil, err)
|
|
||||||
|
|
||||||
// We test both padded & raw Base64 to ensure we handle both
|
|
||||||
// potential user input routes for Base64
|
|
||||||
base64Padded := base64.URLEncoding.EncodeToString(secret)
|
|
||||||
sb := SecretBytes(base64Padded)
|
|
||||||
assert.NotEqual(t, secret, sb)
|
|
||||||
assert.NotEqual(t, len(sb), secretSize)
|
|
||||||
// The given secret is returned as []byte
|
|
||||||
assert.Equal(t, base64Padded, string(sb))
|
|
||||||
|
|
||||||
base64Raw := base64.RawURLEncoding.EncodeToString(secret)
|
|
||||||
sb = SecretBytes(base64Raw)
|
|
||||||
assert.NotEqual(t, secret, sb)
|
|
||||||
assert.NotEqual(t, len(sb), secretSize)
|
|
||||||
// The given secret is returned as []byte
|
|
||||||
assert.Equal(t, base64Raw, string(sb))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSecretBytesNonBase64(t *testing.T) {
|
|
||||||
trailer := "equals=========="
|
|
||||||
assert.Equal(t, trailer, string(SecretBytes(trailer)))
|
|
||||||
|
|
||||||
raw16 := "asdflkjhqwer)(*&"
|
|
||||||
sb16 := SecretBytes(raw16)
|
|
||||||
assert.Equal(t, raw16, string(sb16))
|
|
||||||
assert.Equal(t, 16, len(sb16))
|
|
||||||
|
|
||||||
raw24 := "asdflkjhqwer)(*&CJEN#$%^"
|
|
||||||
sb24 := SecretBytes(raw24)
|
|
||||||
assert.Equal(t, raw24, string(sb24))
|
|
||||||
assert.Equal(t, 24, len(sb24))
|
|
||||||
|
|
||||||
raw32 := "asdflkjhqwer)(*&1234lkjhqwer)(*&"
|
|
||||||
sb32 := SecretBytes(raw32)
|
|
||||||
assert.Equal(t, raw32, string(sb32))
|
|
||||||
assert.Equal(t, 32, len(sb32))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSignAndValidate(t *testing.T) {
|
|
||||||
seed := "0123456789abcdef"
|
|
||||||
key := "cookie-name"
|
|
||||||
value := base64.URLEncoding.EncodeToString([]byte("I am soooo encoded"))
|
|
||||||
epoch := "123456789"
|
|
||||||
|
|
||||||
sha256sig := cookieSignature(sha256.New, seed, key, value, epoch)
|
|
||||||
sha1sig := cookieSignature(sha1.New, seed, key, value, epoch)
|
|
||||||
|
|
||||||
assert.True(t, checkSignature(sha256sig, seed, key, value, epoch))
|
|
||||||
// This should be switched to False after fully deprecating SHA1
|
|
||||||
assert.True(t, checkSignature(sha1sig, seed, key, value, epoch))
|
|
||||||
|
|
||||||
assert.False(t, checkSignature(sha256sig, seed, key, "tampered", epoch))
|
|
||||||
assert.False(t, checkSignature(sha1sig, seed, key, "tampered", epoch))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestEncodeAndDecodeAccessToken(t *testing.T) {
|
func TestEncodeAndDecodeAccessToken(t *testing.T) {
|
||||||
const secret = "0123456789abcdefghijklmnopqrstuv"
|
const secret = "0123456789abcdefghijklmnopqrstuv"
|
||||||
const token = "my access token"
|
const token = "my access token"
|
||||||
c, err := NewCipher([]byte(secret))
|
c, err := NewBase64Cipher(NewCFBCipher, []byte(secret))
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
|
|
||||||
encoded, err := c.Encrypt(token)
|
encoded, err := c.Encrypt([]byte(token))
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
|
|
||||||
decoded, err := c.Decrypt(encoded)
|
decoded, err := c.Decrypt(encoded)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
|
|
||||||
assert.NotEqual(t, token, encoded)
|
assert.NotEqual(t, []byte(token), encoded)
|
||||||
assert.Equal(t, token, decoded)
|
assert.Equal(t, []byte(token), decoded)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEncodeAndDecodeAccessTokenB64(t *testing.T) {
|
func TestEncodeAndDecodeAccessTokenB64(t *testing.T) {
|
||||||
@ -121,37 +32,199 @@ func TestEncodeAndDecodeAccessTokenB64(t *testing.T) {
|
|||||||
|
|
||||||
secret, err := base64.URLEncoding.DecodeString(secretBase64)
|
secret, err := base64.URLEncoding.DecodeString(secretBase64)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
c, err := NewCipher([]byte(secret))
|
c, err := NewBase64Cipher(NewCFBCipher, []byte(secret))
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
|
|
||||||
encoded, err := c.Encrypt(token)
|
encoded, err := c.Encrypt([]byte(token))
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
|
|
||||||
decoded, err := c.Decrypt(encoded)
|
decoded, err := c.Decrypt(encoded)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
|
|
||||||
assert.NotEqual(t, token, encoded)
|
assert.NotEqual(t, []byte(token), encoded)
|
||||||
assert.Equal(t, token, decoded)
|
assert.Equal(t, []byte(token), decoded)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEncodeIntoAndDecodeIntoAccessToken(t *testing.T) {
|
func TestEncryptAndDecrypt(t *testing.T) {
|
||||||
const secret = "0123456789abcdefghijklmnopqrstuv"
|
// Test our 2 cipher types
|
||||||
c, err := NewCipher([]byte(secret))
|
cipherInits := map[string]func([]byte) (Cipher, error){
|
||||||
|
"CFB": NewCFBCipher,
|
||||||
|
"GCM": NewGCMCipher,
|
||||||
|
}
|
||||||
|
for name, initCipher := range cipherInits {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
// Test all 3 valid AES sizes
|
||||||
|
for _, secretSize := range []int{16, 24, 32} {
|
||||||
|
t.Run(fmt.Sprintf("%d", secretSize), func(t *testing.T) {
|
||||||
|
secret := make([]byte, secretSize)
|
||||||
|
_, err := io.ReadFull(rand.Reader, secret)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
|
|
||||||
token := "my access token"
|
// Test Standard & Base64 wrapped
|
||||||
originalToken := token
|
cstd, err := initCipher(secret)
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
|
||||||
assert.Equal(t, nil, c.EncryptInto(&token))
|
cb64, err := NewBase64Cipher(initCipher, secret)
|
||||||
assert.NotEqual(t, originalToken, token)
|
assert.Equal(t, nil, err)
|
||||||
|
|
||||||
assert.Equal(t, nil, c.DecryptInto(&token))
|
ciphers := map[string]Cipher{
|
||||||
assert.Equal(t, originalToken, token)
|
"Standard": cstd,
|
||||||
|
"Base64": cb64,
|
||||||
// Check no errors with empty or nil strings
|
}
|
||||||
empty := ""
|
|
||||||
assert.Equal(t, nil, c.EncryptInto(&empty))
|
for cName, c := range ciphers {
|
||||||
assert.Equal(t, nil, c.DecryptInto(&empty))
|
t.Run(cName, func(t *testing.T) {
|
||||||
assert.Equal(t, nil, c.EncryptInto(nil))
|
// Test various sizes sessions might be
|
||||||
assert.Equal(t, nil, c.DecryptInto(nil))
|
for _, dataSize := range []int{10, 100, 1000, 5000, 10000} {
|
||||||
|
t.Run(fmt.Sprintf("%d", dataSize), func(t *testing.T) {
|
||||||
|
runEncryptAndDecrypt(t, c, dataSize)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func runEncryptAndDecrypt(t *testing.T, c Cipher, dataSize int) {
|
||||||
|
data := make([]byte, dataSize)
|
||||||
|
_, err := io.ReadFull(rand.Reader, data)
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
|
||||||
|
// Ensure our Encrypt function doesn't encrypt in place
|
||||||
|
immutableData := make([]byte, len(data))
|
||||||
|
copy(immutableData, data)
|
||||||
|
|
||||||
|
encrypted, err := c.Encrypt(data)
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
assert.NotEqual(t, encrypted, data)
|
||||||
|
// Encrypt didn't operate in-place on []byte
|
||||||
|
assert.Equal(t, data, immutableData)
|
||||||
|
|
||||||
|
// Ensure our Decrypt function doesn't decrypt in place
|
||||||
|
immutableEnc := make([]byte, len(encrypted))
|
||||||
|
copy(immutableEnc, encrypted)
|
||||||
|
|
||||||
|
decrypted, err := c.Decrypt(encrypted)
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
// Original data back
|
||||||
|
assert.Equal(t, data, decrypted)
|
||||||
|
// Decrypt didn't operate in-place on []byte
|
||||||
|
assert.Equal(t, encrypted, immutableEnc)
|
||||||
|
// Encrypt/Decrypt actually did something
|
||||||
|
assert.NotEqual(t, encrypted, decrypted)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecryptCFBWrongSecret(t *testing.T) {
|
||||||
|
secret1 := []byte("0123456789abcdefghijklmnopqrstuv")
|
||||||
|
secret2 := []byte("9876543210abcdefghijklmnopqrstuv")
|
||||||
|
|
||||||
|
c1, err := NewCFBCipher(secret1)
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
|
||||||
|
c2, err := NewCFBCipher(secret2)
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
|
||||||
|
data := []byte("f3928pufm982374dj02y485dsl34890u2t9nd4028s94dm58y2394087dhmsyt29h8df")
|
||||||
|
|
||||||
|
ciphertext, err := c1.Encrypt(data)
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
|
||||||
|
wrongData, err := c2.Decrypt(ciphertext)
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
assert.NotEqual(t, data, wrongData)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecryptGCMWrongSecret(t *testing.T) {
|
||||||
|
secret1 := []byte("0123456789abcdefghijklmnopqrstuv")
|
||||||
|
secret2 := []byte("9876543210abcdefghijklmnopqrstuv")
|
||||||
|
|
||||||
|
c1, err := NewGCMCipher(secret1)
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
|
||||||
|
c2, err := NewGCMCipher(secret2)
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
|
||||||
|
data := []byte("f3928pufm982374dj02y485dsl34890u2t9nd4028s94dm58y2394087dhmsyt29h8df")
|
||||||
|
|
||||||
|
ciphertext, err := c1.Encrypt(data)
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
|
||||||
|
// GCM is authenticated - this should lead to message authentication failed
|
||||||
|
_, err = c2.Decrypt(ciphertext)
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encrypt with GCM, Decrypt with CFB: Results in Garbage data
|
||||||
|
func TestGCMtoCFBErrors(t *testing.T) {
|
||||||
|
// Test all 3 valid AES sizes
|
||||||
|
for _, secretSize := range []int{16, 24, 32} {
|
||||||
|
t.Run(fmt.Sprintf("%d", secretSize), func(t *testing.T) {
|
||||||
|
secret := make([]byte, secretSize)
|
||||||
|
_, err := io.ReadFull(rand.Reader, secret)
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
|
||||||
|
gcm, err := NewGCMCipher(secret)
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
|
||||||
|
cfb, err := NewCFBCipher(secret)
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
|
||||||
|
// Test various sizes sessions might be
|
||||||
|
for _, dataSize := range []int{10, 100, 1000, 5000, 10000} {
|
||||||
|
t.Run(fmt.Sprintf("%d", dataSize), func(t *testing.T) {
|
||||||
|
data := make([]byte, dataSize)
|
||||||
|
_, err := io.ReadFull(rand.Reader, data)
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
|
||||||
|
encrypted, err := gcm.Encrypt(data)
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
assert.NotEqual(t, encrypted, data)
|
||||||
|
|
||||||
|
decrypted, err := cfb.Decrypt(encrypted)
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
// Data is mangled
|
||||||
|
assert.NotEqual(t, data, decrypted)
|
||||||
|
assert.NotEqual(t, encrypted, decrypted)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encrypt with CFB, Decrypt with GCM: Results in errors
|
||||||
|
func TestCFBtoGCMErrors(t *testing.T) {
|
||||||
|
// Test all 3 valid AES sizes
|
||||||
|
for _, secretSize := range []int{16, 24, 32} {
|
||||||
|
t.Run(fmt.Sprintf("%d", secretSize), func(t *testing.T) {
|
||||||
|
secret := make([]byte, secretSize)
|
||||||
|
_, err := io.ReadFull(rand.Reader, secret)
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
|
||||||
|
gcm, err := NewGCMCipher(secret)
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
|
||||||
|
cfb, err := NewCFBCipher(secret)
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
|
||||||
|
// Test various sizes sessions might be
|
||||||
|
for _, dataSize := range []int{10, 100, 1000, 5000, 10000} {
|
||||||
|
t.Run(fmt.Sprintf("%d", dataSize), func(t *testing.T) {
|
||||||
|
data := make([]byte, dataSize)
|
||||||
|
_, err := io.ReadFull(rand.Reader, data)
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
|
||||||
|
encrypted, err := cfb.Encrypt(data)
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
assert.NotEqual(t, encrypted, data)
|
||||||
|
|
||||||
|
// GCM is authenticated - this should lead to message authentication failed
|
||||||
|
_, err = gcm.Decrypt(encrypted)
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
106
pkg/encryption/utils.go
Normal file
106
pkg/encryption/utils.go
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
package encryption
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/sha1"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"hash"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SecretBytes attempts to base64 decode the secret, if that fails it treats the secret as binary
|
||||||
|
func SecretBytes(secret string) []byte {
|
||||||
|
b, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(secret, "="))
|
||||||
|
if err == nil {
|
||||||
|
// Only return decoded form if a valid AES length
|
||||||
|
// Don't want unintentional decoding resulting in invalid lengths confusing a user
|
||||||
|
// that thought they used a 16, 24, 32 length string
|
||||||
|
for _, i := range []int{16, 24, 32} {
|
||||||
|
if len(b) == i {
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// If decoding didn't work or resulted in non-AES compliant length,
|
||||||
|
// assume the raw string was the intended secret
|
||||||
|
return []byte(secret)
|
||||||
|
}
|
||||||
|
|
||||||
|
// cookies are stored in a 3 part (value + timestamp + signature) to enforce that the values are as originally set.
|
||||||
|
// additionally, the 'value' is encrypted so it's opaque to the browser
|
||||||
|
|
||||||
|
// Validate ensures a cookie is properly signed
|
||||||
|
func Validate(cookie *http.Cookie, seed string, expiration time.Duration) (value []byte, t time.Time, ok bool) {
|
||||||
|
// value, timestamp, sig
|
||||||
|
parts := strings.Split(cookie.Value, "|")
|
||||||
|
if len(parts) != 3 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if checkSignature(parts[2], seed, cookie.Name, parts[0], parts[1]) {
|
||||||
|
ts, err := strconv.Atoi(parts[1])
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// The expiration timestamp set when the cookie was created
|
||||||
|
// isn't sent back by the browser. Hence, we check whether the
|
||||||
|
// creation timestamp stored in the cookie falls within the
|
||||||
|
// window defined by (Now()-expiration, Now()].
|
||||||
|
t = time.Unix(int64(ts), 0)
|
||||||
|
if t.After(time.Now().Add(expiration*-1)) && t.Before(time.Now().Add(time.Minute*5)) {
|
||||||
|
// it's a valid cookie. now get the contents
|
||||||
|
rawValue, err := base64.URLEncoding.DecodeString(parts[0])
|
||||||
|
if err == nil {
|
||||||
|
value = rawValue
|
||||||
|
ok = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignedValue returns a cookie that is signed and can later be checked with Validate
|
||||||
|
func SignedValue(seed string, key string, value []byte, now time.Time) string {
|
||||||
|
encodedValue := base64.URLEncoding.EncodeToString(value)
|
||||||
|
timeStr := fmt.Sprintf("%d", now.Unix())
|
||||||
|
sig := cookieSignature(sha256.New, seed, key, encodedValue, timeStr)
|
||||||
|
cookieVal := fmt.Sprintf("%s|%s|%s", encodedValue, timeStr, sig)
|
||||||
|
return cookieVal
|
||||||
|
}
|
||||||
|
|
||||||
|
func cookieSignature(signer func() hash.Hash, args ...string) string {
|
||||||
|
h := hmac.New(signer, []byte(args[0]))
|
||||||
|
for _, arg := range args[1:] {
|
||||||
|
h.Write([]byte(arg))
|
||||||
|
}
|
||||||
|
var b []byte
|
||||||
|
b = h.Sum(b)
|
||||||
|
return base64.URLEncoding.EncodeToString(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkSignature(signature string, args ...string) bool {
|
||||||
|
checkSig := cookieSignature(sha256.New, args...)
|
||||||
|
if checkHmac(signature, checkSig) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: After appropriate rollout window, remove support for SHA1
|
||||||
|
legacySig := cookieSignature(sha1.New, args...)
|
||||||
|
return checkHmac(signature, legacySig)
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkHmac(input, expected string) bool {
|
||||||
|
inputMAC, err1 := base64.URLEncoding.DecodeString(input)
|
||||||
|
if err1 == nil {
|
||||||
|
expectedMAC, err2 := base64.URLEncoding.DecodeString(expected)
|
||||||
|
if err2 == nil {
|
||||||
|
return hmac.Equal(inputMAC, expectedMAC)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
100
pkg/encryption/utils_test.go
Normal file
100
pkg/encryption/utils_test.go
Normal file
@ -0,0 +1,100 @@
|
|||||||
|
package encryption
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha1"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSecretBytesEncoded(t *testing.T) {
|
||||||
|
for _, secretSize := range []int{16, 24, 32} {
|
||||||
|
t.Run(fmt.Sprintf("%d", secretSize), func(t *testing.T) {
|
||||||
|
secret := make([]byte, secretSize)
|
||||||
|
_, err := io.ReadFull(rand.Reader, secret)
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
|
||||||
|
// We test both padded & raw Base64 to ensure we handle both
|
||||||
|
// potential user input routes for Base64
|
||||||
|
base64Padded := base64.URLEncoding.EncodeToString(secret)
|
||||||
|
sb := SecretBytes(base64Padded)
|
||||||
|
assert.Equal(t, secret, sb)
|
||||||
|
assert.Equal(t, len(sb), secretSize)
|
||||||
|
|
||||||
|
base64Raw := base64.RawURLEncoding.EncodeToString(secret)
|
||||||
|
sb = SecretBytes(base64Raw)
|
||||||
|
assert.Equal(t, secret, sb)
|
||||||
|
assert.Equal(t, len(sb), secretSize)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// A string that isn't intended as Base64 and still decodes (but to unintended length)
|
||||||
|
// will return the original secret as bytes
|
||||||
|
func TestSecretBytesEncodedWrongSize(t *testing.T) {
|
||||||
|
for _, secretSize := range []int{15, 20, 28, 33, 44} {
|
||||||
|
t.Run(fmt.Sprintf("%d", secretSize), func(t *testing.T) {
|
||||||
|
secret := make([]byte, secretSize)
|
||||||
|
_, err := io.ReadFull(rand.Reader, secret)
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
|
||||||
|
// We test both padded & raw Base64 to ensure we handle both
|
||||||
|
// potential user input routes for Base64
|
||||||
|
base64Padded := base64.URLEncoding.EncodeToString(secret)
|
||||||
|
sb := SecretBytes(base64Padded)
|
||||||
|
assert.NotEqual(t, secret, sb)
|
||||||
|
assert.NotEqual(t, len(sb), secretSize)
|
||||||
|
// The given secret is returned as []byte
|
||||||
|
assert.Equal(t, base64Padded, string(sb))
|
||||||
|
|
||||||
|
base64Raw := base64.RawURLEncoding.EncodeToString(secret)
|
||||||
|
sb = SecretBytes(base64Raw)
|
||||||
|
assert.NotEqual(t, secret, sb)
|
||||||
|
assert.NotEqual(t, len(sb), secretSize)
|
||||||
|
// The given secret is returned as []byte
|
||||||
|
assert.Equal(t, base64Raw, string(sb))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSecretBytesNonBase64(t *testing.T) {
|
||||||
|
trailer := "equals=========="
|
||||||
|
assert.Equal(t, trailer, string(SecretBytes(trailer)))
|
||||||
|
|
||||||
|
raw16 := "asdflkjhqwer)(*&"
|
||||||
|
sb16 := SecretBytes(raw16)
|
||||||
|
assert.Equal(t, raw16, string(sb16))
|
||||||
|
assert.Equal(t, 16, len(sb16))
|
||||||
|
|
||||||
|
raw24 := "asdflkjhqwer)(*&CJEN#$%^"
|
||||||
|
sb24 := SecretBytes(raw24)
|
||||||
|
assert.Equal(t, raw24, string(sb24))
|
||||||
|
assert.Equal(t, 24, len(sb24))
|
||||||
|
|
||||||
|
raw32 := "asdflkjhqwer)(*&1234lkjhqwer)(*&"
|
||||||
|
sb32 := SecretBytes(raw32)
|
||||||
|
assert.Equal(t, raw32, string(sb32))
|
||||||
|
assert.Equal(t, 32, len(sb32))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSignAndValidate(t *testing.T) {
|
||||||
|
seed := "0123456789abcdef"
|
||||||
|
key := "cookie-name"
|
||||||
|
value := base64.URLEncoding.EncodeToString([]byte("I am soooo encoded"))
|
||||||
|
epoch := "123456789"
|
||||||
|
|
||||||
|
sha256sig := cookieSignature(sha256.New, seed, key, value, epoch)
|
||||||
|
sha1sig := cookieSignature(sha1.New, seed, key, value, epoch)
|
||||||
|
|
||||||
|
assert.True(t, checkSignature(sha256sig, seed, key, value, epoch))
|
||||||
|
// This should be switched to False after fully deprecating SHA1
|
||||||
|
assert.True(t, checkSignature(sha1sig, seed, key, value, epoch))
|
||||||
|
|
||||||
|
assert.False(t, checkSignature(sha256sig, seed, key, "tampered", epoch))
|
||||||
|
assert.False(t, checkSignature(sha1sig, seed, key, "tampered", epoch))
|
||||||
|
}
|
@ -28,7 +28,7 @@ var _ sessions.SessionStore = &SessionStore{}
|
|||||||
// interface that stores sessions in client side cookies
|
// interface that stores sessions in client side cookies
|
||||||
type SessionStore struct {
|
type SessionStore struct {
|
||||||
CookieOptions *options.CookieOptions
|
CookieOptions *options.CookieOptions
|
||||||
CookieCipher *encryption.Cipher
|
CookieCipher encryption.Cipher
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save takes a sessions.SessionState and stores the information from it
|
// Save takes a sessions.SessionState and stores the information from it
|
||||||
@ -59,7 +59,7 @@ func (s *SessionStore) Load(req *http.Request) (*sessions.SessionState, error) {
|
|||||||
return nil, errors.New("cookie signature not valid")
|
return nil, errors.New("cookie signature not valid")
|
||||||
}
|
}
|
||||||
|
|
||||||
session, err := sessionFromCookie(val, s.CookieCipher)
|
session, err := sessionFromCookie(string(val), s.CookieCipher)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -84,12 +84,12 @@ func (s *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// cookieForSession serializes a session state for storage in a cookie
|
// cookieForSession serializes a session state for storage in a cookie
|
||||||
func cookieForSession(s *sessions.SessionState, c *encryption.Cipher) (string, error) {
|
func cookieForSession(s *sessions.SessionState, c encryption.Cipher) (string, error) {
|
||||||
return s.EncodeSessionState(c)
|
return s.EncodeSessionState(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
// sessionFromCookie deserializes a session from a cookie value
|
// sessionFromCookie deserializes a session from a cookie value
|
||||||
func sessionFromCookie(v string, c *encryption.Cipher) (s *sessions.SessionState, err error) {
|
func sessionFromCookie(v string, c encryption.Cipher) (s *sessions.SessionState, err error) {
|
||||||
return sessions.DecodeSessionState(v, c)
|
return sessions.DecodeSessionState(v, c)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -104,7 +104,7 @@ func (s *SessionStore) setSessionCookie(rw http.ResponseWriter, req *http.Reques
|
|||||||
// authentication details
|
// authentication details
|
||||||
func (s *SessionStore) makeSessionCookie(req *http.Request, value string, now time.Time) []*http.Cookie {
|
func (s *SessionStore) makeSessionCookie(req *http.Request, value string, now time.Time) []*http.Cookie {
|
||||||
if value != "" {
|
if value != "" {
|
||||||
value = encryption.SignedValue(s.CookieOptions.Secret, s.CookieOptions.Name, value, now)
|
value = encryption.SignedValue(s.CookieOptions.Secret, s.CookieOptions.Name, []byte(value), now)
|
||||||
}
|
}
|
||||||
c := s.makeCookie(req, s.CookieOptions.Name, value, s.CookieOptions.Expire, now)
|
c := s.makeCookie(req, s.CookieOptions.Name, value, s.CookieOptions.Expire, now)
|
||||||
if len(c.Value) > 4096-len(s.CookieOptions.Name) {
|
if len(c.Value) > 4096-len(s.CookieOptions.Name) {
|
||||||
|
@ -32,7 +32,7 @@ type TicketData struct {
|
|||||||
// SessionStore is an implementation of the sessions.SessionStore
|
// SessionStore is an implementation of the sessions.SessionStore
|
||||||
// interface that stores sessions in redis
|
// interface that stores sessions in redis
|
||||||
type SessionStore struct {
|
type SessionStore struct {
|
||||||
CookieCipher *encryption.Cipher
|
CookieCipher encryption.Cipher
|
||||||
CookieOptions *options.CookieOptions
|
CookieOptions *options.CookieOptions
|
||||||
Client Client
|
Client Client
|
||||||
}
|
}
|
||||||
@ -175,7 +175,7 @@ func (store *SessionStore) Load(req *http.Request) (*sessions.SessionState, erro
|
|||||||
return nil, fmt.Errorf("cookie signature not valid")
|
return nil, fmt.Errorf("cookie signature not valid")
|
||||||
}
|
}
|
||||||
ctx := req.Context()
|
ctx := req.Context()
|
||||||
session, err := store.loadSessionFromString(ctx, val)
|
session, err := store.loadSessionFromString(ctx, string(val))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error loading session: %s", err)
|
return nil, fmt.Errorf("error loading session: %s", err)
|
||||||
}
|
}
|
||||||
@ -237,7 +237,7 @@ func (store *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) erro
|
|||||||
|
|
||||||
// We only return an error if we had an issue with redis
|
// We only return an error if we had an issue with redis
|
||||||
// If there's an issue decoding the ticket, ignore it
|
// If there's an issue decoding the ticket, ignore it
|
||||||
ticket, _ := decodeTicket(store.CookieOptions.Name, val)
|
ticket, _ := decodeTicket(store.CookieOptions.Name, string(val))
|
||||||
if ticket != nil {
|
if ticket != nil {
|
||||||
ctx := req.Context()
|
ctx := req.Context()
|
||||||
err := store.Client.Del(ctx, ticket.asHandle(store.CookieOptions.Name))
|
err := store.Client.Del(ctx, ticket.asHandle(store.CookieOptions.Name))
|
||||||
@ -251,7 +251,7 @@ func (store *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) erro
|
|||||||
// makeCookie makes a cookie, signing the value if present
|
// makeCookie makes a cookie, signing the value if present
|
||||||
func (store *SessionStore) makeCookie(req *http.Request, value string, expires time.Duration, now time.Time) *http.Cookie {
|
func (store *SessionStore) makeCookie(req *http.Request, value string, expires time.Duration, now time.Time) *http.Cookie {
|
||||||
if value != "" {
|
if value != "" {
|
||||||
value = encryption.SignedValue(store.CookieOptions.Secret, store.CookieOptions.Name, value, now)
|
value = encryption.SignedValue(store.CookieOptions.Secret, store.CookieOptions.Name, []byte(value), now)
|
||||||
}
|
}
|
||||||
return cookies.MakeCookieFromOptions(
|
return cookies.MakeCookieFromOptions(
|
||||||
req,
|
req,
|
||||||
@ -302,7 +302,7 @@ func (store *SessionStore) getTicket(requestCookie *http.Cookie) (*TicketData, e
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Valid cookie, decode the ticket
|
// Valid cookie, decode the ticket
|
||||||
ticket, err := decodeTicket(store.CookieOptions.Name, val)
|
ticket, err := decodeTicket(store.CookieOptions.Name, string(val))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// If we can't decode the ticket we have to create a new one
|
// If we can't decode the ticket we have to create a new one
|
||||||
return newTicket()
|
return newTicket()
|
||||||
|
@ -170,7 +170,7 @@ var _ = Describe("NewSessionStore", func() {
|
|||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
By("Using a valid cookie with a different providers session encoding")
|
By("Using a valid cookie with a different providers session encoding")
|
||||||
broken := "BrokenSessionFromADifferentSessionImplementation"
|
broken := "BrokenSessionFromADifferentSessionImplementation"
|
||||||
value := encryption.SignedValue(cookieOpts.Secret, cookieOpts.Name, broken, time.Now())
|
value := encryption.SignedValue(cookieOpts.Secret, cookieOpts.Name, []byte(broken), time.Now())
|
||||||
cookie := cookiesapi.MakeCookieFromOptions(request, cookieOpts.Name, value, cookieOpts, cookieOpts.Expire, time.Now())
|
cookie := cookiesapi.MakeCookieFromOptions(request, cookieOpts.Name, value, cookieOpts, cookieOpts.Expire, time.Now())
|
||||||
request.AddCookie(cookie)
|
request.AddCookie(cookie)
|
||||||
|
|
||||||
@ -367,7 +367,7 @@ var _ = Describe("NewSessionStore", func() {
|
|||||||
_, err := rand.Read(secret)
|
_, err := rand.Read(secret)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
cookieOpts.Secret = base64.URLEncoding.EncodeToString(secret)
|
cookieOpts.Secret = base64.URLEncoding.EncodeToString(secret)
|
||||||
cipher, err := encryption.NewCipher(encryption.SecretBytes(cookieOpts.Secret))
|
cipher, err := encryption.NewBase64Cipher(encryption.NewCFBCipher, encryption.SecretBytes(cookieOpts.Secret))
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
Expect(cipher).ToNot(BeNil())
|
Expect(cipher).ToNot(BeNil())
|
||||||
opts.Cipher = cipher
|
opts.Cipher = cipher
|
||||||
|
@ -38,7 +38,7 @@ func Validate(o *options.Options) error {
|
|||||||
|
|
||||||
msgs := make([]string, 0)
|
msgs := make([]string, 0)
|
||||||
|
|
||||||
var cipher *encryption.Cipher
|
var cipher encryption.Cipher
|
||||||
if o.Cookie.Secret == "" {
|
if o.Cookie.Secret == "" {
|
||||||
msgs = append(msgs, "missing setting: cookie-secret")
|
msgs = append(msgs, "missing setting: cookie-secret")
|
||||||
} else {
|
} else {
|
||||||
@ -62,7 +62,7 @@ func Validate(o *options.Options) error {
|
|||||||
len(encryption.SecretBytes(o.Cookie.Secret)), suffix))
|
len(encryption.SecretBytes(o.Cookie.Secret)), suffix))
|
||||||
} else {
|
} else {
|
||||||
var err error
|
var err error
|
||||||
cipher, err = encryption.NewCipher(encryption.SecretBytes(o.Cookie.Secret))
|
cipher, err = encryption.NewBase64Cipher(encryption.NewCFBCipher, encryption.SecretBytes(o.Cookie.Secret))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
msgs = append(msgs, fmt.Sprintf("cookie-secret error: %v", err))
|
msgs = append(msgs, fmt.Sprintf("cookie-secret error: %v", err))
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user