1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-01-24 05:26:55 +02:00

Merge pull request #147 from pusher/session-store

Add initial session-store interface and implementation
This commit is contained in:
Joel Speed 2019-05-20 10:18:47 +01:00 committed by GitHub
commit 17e97ab884
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
38 changed files with 985 additions and 137 deletions

View File

@ -10,7 +10,11 @@
## Changes since v3.2.0 ## Changes since v3.2.0
- [#114](https://github.com/pusher/oauth2_proxy/pull/114), [#154](https://github.com/pusher/oauth2_proxy/pull/154) Documentation is now available live at our [docs website](https://pusher.github.io/oauth2_proxy/) (@JoelSpeed, @icelynjennings) - [#147](https://github.com/pusher/outh2_proxy/pull/147) Add SessionStore interfaces and initial implementation (@JoelSpeed)
- Allows for multiple different session storage implementations including client and server side
- Adds tests suite for interface to ensure consistency across implementations
- Refactor some configuration options (around cookies) into packages
- [#114](https://github.com/pusher/oauth2_proxy/pull/114), [#154](https://github.com/pusher/oauth2_proxy/pull/154) Documentation is now available live at our [docs website](https://pusher.github.io/oauth2_proxy/) (@JoelSpeed, @icelynjennings)
- [#146](https://github.com/pusher/oauth2_proxy/pull/146) Use full email address as `User` if the auth response did not contain a `User` field (@gargath) - [#146](https://github.com/pusher/oauth2_proxy/pull/146) Use full email address as `User` if the auth response did not contain a `User` field (@gargath)
- [#144](https://github.com/pusher/oauth2_proxy/pull/144) Use GO 1.12 for ARM builds (@kskewes) - [#144](https://github.com/pusher/oauth2_proxy/pull/144) Use GO 1.12 for ARM builds (@kskewes)
- [#142](https://github.com/pusher/oauth2_proxy/pull/142) ARM Docker USER fix (@kskewes) - [#142](https://github.com/pusher/oauth2_proxy/pull/142) ARM Docker USER fix (@kskewes)

129
Gopkg.lock generated
View File

@ -57,6 +57,20 @@
pruneopts = "" pruneopts = ""
revision = "1e59b77b52bf8e4b449a57e6f79f21226d571845" revision = "1e59b77b52bf8e4b449a57e6f79f21226d571845"
[[projects]]
digest = "1:b3c5b95e56c06f5aa72cb2500e6ee5f44fcd122872d4fec2023a488e561218bc"
name = "github.com/hpcloud/tail"
packages = [
".",
"ratelimiter",
"util",
"watch",
"winfile",
]
pruneopts = ""
revision = "a30252cb686a21eb2d0b98132633053ec2f7f1e5"
version = "v1.0.0"
[[projects]] [[projects]]
digest = "1:af67386ca553c04c6222f7b5b2f17bc97a5dfb3b81b706882c7fd8c72c30cf8f" digest = "1:af67386ca553c04c6222f7b5b2f17bc97a5dfb3b81b706882c7fd8c72c30cf8f"
name = "github.com/mbland/hmacauth" name = "github.com/mbland/hmacauth"
@ -73,6 +87,54 @@
pruneopts = "" pruneopts = ""
revision = "20ba7d382d05facb01e02eb777af0c5f229c5c95" revision = "20ba7d382d05facb01e02eb777af0c5f229c5c95"
[[projects]]
digest = "1:a3735b0978a8b53fc2ac97a6f46ca1189f0712a00df86d6ec4cf26c1a25e6d77"
name = "github.com/onsi/ginkgo"
packages = [
".",
"config",
"internal/codelocation",
"internal/containernode",
"internal/failer",
"internal/leafnodes",
"internal/remote",
"internal/spec",
"internal/spec_iterator",
"internal/specrunner",
"internal/suite",
"internal/testingtproxy",
"internal/writer",
"reporters",
"reporters/stenographer",
"reporters/stenographer/support/go-colorable",
"reporters/stenographer/support/go-isatty",
"types",
]
pruneopts = ""
revision = "eea6ad008b96acdaa524f5b409513bf062b500ad"
version = "v1.8.0"
[[projects]]
digest = "1:dbafce2fddb1ca331646fe2ac9c9413980368b19a60a4406a6e5861680bd73be"
name = "github.com/onsi/gomega"
packages = [
".",
"format",
"internal/assertion",
"internal/asyncassertion",
"internal/oraclematcher",
"internal/testingtsupport",
"matchers",
"matchers/support/goraph/bipartitegraph",
"matchers/support/goraph/edge",
"matchers/support/goraph/node",
"matchers/support/goraph/util",
"types",
]
pruneopts = ""
revision = "90e289841c1ed79b7a598a7cd9959750cb5e89e2"
version = "v1.5.0"
[[projects]] [[projects]]
digest = "1:256484dbbcd271f9ecebc6795b2df8cad4c458dd0f5fd82a8c2fa0c29f233411" digest = "1:256484dbbcd271f9ecebc6795b2df8cad4c458dd0f5fd82a8c2fa0c29f233411"
name = "github.com/pmezard/go-difflib" name = "github.com/pmezard/go-difflib"
@ -131,6 +193,9 @@
packages = [ packages = [
"context", "context",
"context/ctxhttp", "context/ctxhttp",
"html",
"html/atom",
"html/charset",
"websocket", "websocket",
] ]
pruneopts = "" pruneopts = ""
@ -150,6 +215,42 @@
pruneopts = "" pruneopts = ""
revision = "9ff8ebcc8e241d46f52ecc5bff0e5a2f2dbef402" revision = "9ff8ebcc8e241d46f52ecc5bff0e5a2f2dbef402"
[[projects]]
branch = "master"
digest = "1:67a6e61e60283fd7dce50eba228080bff8805d9d69b2f121d7ec2260d120c4a8"
name = "golang.org/x/sys"
packages = ["unix"]
pruneopts = ""
revision = "ca7f33d4116e3a1f9425755d6a44e7ed9b4c97df"
[[projects]]
digest = "1:740b51a55815493a8d0f2b1e0d0ae48fe48953bf7eaf3fcc4198823bf67768c0"
name = "golang.org/x/text"
packages = [
"encoding",
"encoding/charmap",
"encoding/htmlindex",
"encoding/internal",
"encoding/internal/identifier",
"encoding/japanese",
"encoding/korean",
"encoding/simplifiedchinese",
"encoding/traditionalchinese",
"encoding/unicode",
"internal/gen",
"internal/language",
"internal/language/compact",
"internal/tag",
"internal/utf8internal",
"language",
"runes",
"transform",
"unicode/cldr",
]
pruneopts = ""
revision = "342b2e1fbaa52c93f31447ad2c6abc048c63e475"
version = "v0.3.2"
[[projects]] [[projects]]
branch = "master" branch = "master"
digest = "1:dc1fb726dbbe79c86369941eae1e3b431b8fc6f11dbd37f7899dc758a43cc3ed" digest = "1:dc1fb726dbbe79c86369941eae1e3b431b8fc6f11dbd37f7899dc758a43cc3ed"
@ -182,6 +283,15 @@
revision = "150dc57a1b433e64154302bdc40b6bb8aefa313a" revision = "150dc57a1b433e64154302bdc40b6bb8aefa313a"
version = "v1.0.0" version = "v1.0.0"
[[projects]]
digest = "1:eb53021a8aa3f599d29c7102e65026242bdedce998a54837dc67f14b6a97c5fd"
name = "gopkg.in/fsnotify.v1"
packages = ["."]
pruneopts = ""
revision = "c2828203cd70a50dcccfb2761f8b1f8ceef9a8e9"
source = "https://github.com/fsnotify/fsnotify.git"
version = "v1.4.7"
[[projects]] [[projects]]
digest = "1:cb5b2a45a3dd41c01ff779c54ae4c8aab0271d6d3b3f734c8a8bd2c890299ef2" digest = "1:cb5b2a45a3dd41c01ff779c54ae4c8aab0271d6d3b3f734c8a8bd2c890299ef2"
name = "gopkg.in/fsnotify/fsnotify.v1" name = "gopkg.in/fsnotify/fsnotify.v1"
@ -210,6 +320,22 @@
revision = "f8f38de21b4dcd69d0413faf231983f5fd6634b1" revision = "f8f38de21b4dcd69d0413faf231983f5fd6634b1"
version = "v2.1.3" version = "v2.1.3"
[[projects]]
branch = "v1"
digest = "1:a96d16bd088460f2e0685d46c39bcf1208ba46e0a977be2df49864ec7da447dd"
name = "gopkg.in/tomb.v1"
packages = ["."]
pruneopts = ""
revision = "dd632973f1e7218eb1089048e0798ec9ae7dceb8"
[[projects]]
digest = "1:cedccf16b71e86db87a24f8d4c70b0a855872eb967cb906a66b95de56aefbd0d"
name = "gopkg.in/yaml.v2"
packages = ["."]
pruneopts = ""
revision = "51d6538a90f86fe93ac480b35f37b2be17fef232"
version = "v2.2.2"
[solve-meta] [solve-meta]
analyzer-name = "dep" analyzer-name = "dep"
analyzer-version = 1 analyzer-version = 1
@ -220,6 +346,8 @@
"github.com/dgrijalva/jwt-go", "github.com/dgrijalva/jwt-go",
"github.com/mbland/hmacauth", "github.com/mbland/hmacauth",
"github.com/mreiferson/go-options", "github.com/mreiferson/go-options",
"github.com/onsi/ginkgo",
"github.com/onsi/gomega",
"github.com/stretchr/testify/assert", "github.com/stretchr/testify/assert",
"github.com/stretchr/testify/require", "github.com/stretchr/testify/require",
"github.com/yhat/wsutil", "github.com/yhat/wsutil",
@ -231,6 +359,7 @@
"google.golang.org/api/googleapi", "google.golang.org/api/googleapi",
"gopkg.in/fsnotify/fsnotify.v1", "gopkg.in/fsnotify/fsnotify.v1",
"gopkg.in/natefinch/lumberjack.v2", "gopkg.in/natefinch/lumberjack.v2",
"gopkg.in/square/go-jose.v2",
] ]
solver-name = "gps-cdcl" solver-name = "gps-cdcl"
solver-version = 1 solver-version = 1

View File

@ -35,6 +35,10 @@
name = "gopkg.in/fsnotify/fsnotify.v1" name = "gopkg.in/fsnotify/fsnotify.v1"
version = "~1.2.0" version = "~1.2.0"
[[override]]
name = "gopkg.in/fsnotify.v1"
source = "https://github.com/fsnotify/fsnotify.git"
[[constraint]] [[constraint]]
branch = "master" branch = "master"
name = "golang.org/x/crypto" name = "golang.org/x/crypto"

View File

@ -33,6 +33,7 @@ lint: $(GOMETALINTER)
--enable=deadcode \ --enable=deadcode \
--enable=gofmt \ --enable=gofmt \
--enable=goimports \ --enable=goimports \
--deadline=120s \
--tests ./... --tests ./...
.PHONY: dep .PHONY: dep

View File

@ -1,7 +1,8 @@
--- ---
layout: default layout: default
title: Configuration title: Configuration
permalink: /configuration permalink: /docs/configuration
has_children: true
nav_order: 3 nav_order: 3
--- ---
@ -78,6 +79,7 @@ Usage of oauth2_proxy:
-request-logging-format: Template for request log lines (see "Logging Configuration" paragraph below) -request-logging-format: Template for request log lines (see "Logging Configuration" paragraph below)
-resource string: The resource that is protected (Azure AD only) -resource string: The resource that is protected (Azure AD only)
-scope string: OAuth scope specification -scope string: OAuth scope specification
-session-store-type: Session data storage backend (default: cookie)
-set-xauthrequest: set X-Auth-Request-User and X-Auth-Request-Email response headers (useful in Nginx auth_request mode) -set-xauthrequest: set X-Auth-Request-User and X-Auth-Request-Email response headers (useful in Nginx auth_request mode)
-set-authorization-header: set Authorization Bearer response header (useful in Nginx auth_request mode) -set-authorization-header: set Authorization Bearer response header (useful in Nginx auth_request mode)
-signature-key string: GAP-Signature request signature key (algorithm:secretkey) -signature-key string: GAP-Signature request signature key (algorithm:secretkey)

View File

@ -0,0 +1,34 @@
---
layout: default
title: Sessions
permalink: /configuration
parent: Configuration
nav_order: 3
---
## Sessions
Sessions allow a user's authentication to be tracked between multiple HTTP
requests to a service.
The OAuth2 Proxy uses a Cookie to track user sessions and will store the session
data in one of the available session storage backends.
At present the available backends are (as passed to `--session-store-type`):
- [cookie](cookie-storage) (deafult)
### Cookie Storage
The Cookie storage backend is the default backend implementation and has
been used in the OAuth2 Proxy historically.
With the Cookie storage backend, all session information is stored in client
side cookies and transferred with each and every request.
The following should be known when using this implementation:
- Since all state is stored client side, this storage backend means that the OAuth2 Proxy is completely stateless
- Cookies are signed server side to prevent modification client-side
- It is recommended to set a `cookie-secret` which will ensure data is encrypted within the cookie data.
- Since multiple requests can be made concurrently to the OAuth2 Proxy, this session implementation
cannot lock sessions and while updating and refreshing sessions, there can be conflicts which force
users to re-authenticate

View File

@ -15,14 +15,27 @@ type EnvOptions map[string]interface{}
// Fields in the options struct must have an `env` and `cfg` tag to be read // Fields in the options struct must have an `env` and `cfg` tag to be read
// from the environment // from the environment
func (cfg EnvOptions) LoadEnvForStruct(options interface{}) { func (cfg EnvOptions) LoadEnvForStruct(options interface{}) {
val := reflect.ValueOf(options).Elem() val := reflect.ValueOf(options)
typ := val.Type() var typ reflect.Type
if val.Kind() == reflect.Ptr {
typ = val.Elem().Type()
} else {
typ = val.Type()
}
for i := 0; i < typ.NumField(); i++ { for i := 0; i < typ.NumField(); i++ {
// pull out the struct tags: // pull out the struct tags:
// flag - the name of the command line flag // flag - the name of the command line flag
// deprecated - (optional) the name of the deprecated command line flag // deprecated - (optional) the name of the deprecated command line flag
// cfg - (optional, defaults to underscored flag) the name of the config file option // cfg - (optional, defaults to underscored flag) the name of the config file option
field := typ.Field(i) field := typ.Field(i)
fieldV := reflect.Indirect(val).Field(i)
if field.Type.Kind() == reflect.Struct && field.Anonymous {
cfg.LoadEnvForStruct(fieldV.Interface())
continue
}
flagName := field.Tag.Get("flag") flagName := field.Tag.Get("flag")
envName := field.Tag.Get("env") envName := field.Tag.Get("env")
cfgName := field.Tag.Get("cfg") cfgName := field.Tag.Get("cfg")

View File

@ -1,26 +1,46 @@
package main package main_test
import ( import (
"os" "os"
"testing" "testing"
proxy "github.com/pusher/oauth2_proxy"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
type envTest struct { type EnvTest struct {
testField string `cfg:"target_field" env:"TEST_ENV_FIELD"` TestField string `cfg:"target_field" env:"TEST_ENV_FIELD"`
EnvTestEmbed
}
type EnvTestEmbed struct {
TestFieldEmbed string `cfg:"target_field_embed" env:"TEST_ENV_FIELD_EMBED"`
} }
func TestLoadEnvForStruct(t *testing.T) { func TestLoadEnvForStruct(t *testing.T) {
cfg := make(EnvOptions) cfg := make(proxy.EnvOptions)
cfg.LoadEnvForStruct(&envTest{}) cfg.LoadEnvForStruct(&EnvTest{})
_, ok := cfg["target_field"] _, ok := cfg["target_field"]
assert.Equal(t, ok, false) assert.Equal(t, ok, false)
os.Setenv("TEST_ENV_FIELD", "1234abcd") os.Setenv("TEST_ENV_FIELD", "1234abcd")
cfg.LoadEnvForStruct(&envTest{}) cfg.LoadEnvForStruct(&EnvTest{})
v := cfg["target_field"] v := cfg["target_field"]
assert.Equal(t, v, "1234abcd") assert.Equal(t, v, "1234abcd")
} }
func TestLoadEnvForStructWithEmbeddedFields(t *testing.T) {
cfg := make(proxy.EnvOptions)
cfg.LoadEnvForStruct(&EnvTest{})
_, ok := cfg["target_field_embed"]
assert.Equal(t, ok, false)
os.Setenv("TEST_ENV_FIELD_EMBED", "1234abcd")
cfg.LoadEnvForStruct(&EnvTest{})
v := cfg["target_field_embed"]
assert.Equal(t, v, "1234abcd")
}

View File

@ -75,6 +75,8 @@ func main() {
flagSet.Bool("cookie-secure", true, "set secure (HTTPS) cookie flag") flagSet.Bool("cookie-secure", true, "set secure (HTTPS) cookie flag")
flagSet.Bool("cookie-httponly", true, "set HttpOnly cookie flag") flagSet.Bool("cookie-httponly", true, "set HttpOnly cookie flag")
flagSet.String("session-store-type", "cookie", "the session storage provider to use")
flagSet.String("logging-filename", "", "File to log requests to, empty for stdout") flagSet.String("logging-filename", "", "File to log requests to, empty for stdout")
flagSet.Int("logging-max-size", 100, "Maximum size in megabytes of the log file before rotation") flagSet.Int("logging-max-size", 100, "Maximum size in megabytes of the log file before rotation")
flagSet.Int("logging-max-age", 7, "Maximum number of days to retain old log files") flagSet.Int("logging-max-age", 7, "Maximum number of days to retain old log files")

View File

@ -16,6 +16,7 @@ import (
"github.com/mbland/hmacauth" "github.com/mbland/hmacauth"
"github.com/pusher/oauth2_proxy/cookie" "github.com/pusher/oauth2_proxy/cookie"
"github.com/pusher/oauth2_proxy/logger" "github.com/pusher/oauth2_proxy/logger"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
"github.com/pusher/oauth2_proxy/providers" "github.com/pusher/oauth2_proxy/providers"
"github.com/yhat/wsutil" "github.com/yhat/wsutil"
) )
@ -292,7 +293,7 @@ func (p *OAuthProxy) displayCustomLoginForm() bool {
return p.HtpasswdFile != nil && p.DisplayHtpasswdForm return p.HtpasswdFile != nil && p.DisplayHtpasswdForm
} }
func (p *OAuthProxy) redeemCode(host, code string) (s *providers.SessionState, err error) { func (p *OAuthProxy) redeemCode(host, code string) (s *sessions.SessionState, err error) {
if code == "" { if code == "" {
return nil, errors.New("missing code") return nil, errors.New("missing code")
} }
@ -484,7 +485,7 @@ func (p *OAuthProxy) SetSessionCookie(rw http.ResponseWriter, req *http.Request,
} }
// LoadCookiedSession reads the user's authentication details from the request // LoadCookiedSession reads the user's authentication details from the request
func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionState, time.Duration, error) { func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*sessions.SessionState, time.Duration, error) {
var age time.Duration var age time.Duration
c, err := loadCookie(req, p.CookieName) c, err := loadCookie(req, p.CookieName)
if err != nil { if err != nil {
@ -506,7 +507,7 @@ func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionSt
} }
// SaveSession creates a new session cookie value and sets this on the response // SaveSession creates a new session cookie value and sets this on the response
func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *providers.SessionState) error { func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *sessions.SessionState) error {
value, err := p.provider.CookieForSession(s, p.CookieCipher) value, err := p.provider.CookieForSession(s, p.CookieCipher)
if err != nil { if err != nil {
return err return err
@ -693,7 +694,7 @@ func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) {
user, ok := p.ManualSignIn(rw, req) user, ok := p.ManualSignIn(rw, req)
if ok { if ok {
session := &providers.SessionState{User: user} session := &sessions.SessionState{User: user}
p.SaveSession(rw, req, session) p.SaveSession(rw, req, session)
http.Redirect(rw, req, redirect, 302) http.Redirect(rw, req, redirect, 302)
} else { } else {
@ -944,7 +945,7 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int
// CheckBasicAuth checks the requests Authorization header for basic auth // CheckBasicAuth checks the requests Authorization header for basic auth
// credentials and authenticates these against the proxies HtpasswdFile // credentials and authenticates these against the proxies HtpasswdFile
func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*providers.SessionState, error) { func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*sessions.SessionState, error) {
if p.HtpasswdFile == nil { if p.HtpasswdFile == nil {
return nil, nil return nil, nil
} }
@ -966,7 +967,7 @@ func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*providers.SessionState,
} }
if p.HtpasswdFile.Validate(pair[0], pair[1]) { if p.HtpasswdFile.Validate(pair[0], pair[1]) {
logger.PrintAuthf(pair[0], req, logger.AuthSuccess, "Authenticated via basic auth and HTpasswd File") logger.PrintAuthf(pair[0], req, logger.AuthSuccess, "Authenticated via basic auth and HTpasswd File")
return &providers.SessionState{User: pair[0]}, nil return &sessions.SessionState{User: pair[0]}, nil
} }
logger.PrintAuthf(pair[0], req, logger.AuthFailure, "Invalid authentication via basic auth: not in Htpasswd File") logger.PrintAuthf(pair[0], req, logger.AuthFailure, "Invalid authentication via basic auth: not in Htpasswd File")
return nil, nil return nil, nil

View File

@ -16,6 +16,7 @@ import (
"github.com/mbland/hmacauth" "github.com/mbland/hmacauth"
"github.com/pusher/oauth2_proxy/logger" "github.com/pusher/oauth2_proxy/logger"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
"github.com/pusher/oauth2_proxy/providers" "github.com/pusher/oauth2_proxy/providers"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -253,11 +254,11 @@ func NewTestProvider(providerURL *url.URL, emailAddress string) *TestProvider {
} }
} }
func (tp *TestProvider) GetEmailAddress(session *providers.SessionState) (string, error) { func (tp *TestProvider) GetEmailAddress(session *sessions.SessionState) (string, error) {
return tp.EmailAddress, nil return tp.EmailAddress, nil
} }
func (tp *TestProvider) ValidateSessionState(session *providers.SessionState) bool { func (tp *TestProvider) ValidateSessionState(session *sessions.SessionState) bool {
return tp.ValidToken return tp.ValidToken
} }
@ -637,7 +638,7 @@ func (p *ProcessCookieTest) MakeCookie(value string, ref time.Time) []*http.Cook
return p.proxy.MakeSessionCookie(p.req, value, p.opts.CookieExpire, ref) return p.proxy.MakeSessionCookie(p.req, value, p.opts.CookieExpire, ref)
} }
func (p *ProcessCookieTest) SaveSession(s *providers.SessionState, ref time.Time) error { func (p *ProcessCookieTest) SaveSession(s *sessions.SessionState, ref time.Time) error {
value, err := p.proxy.provider.CookieForSession(s, p.proxy.CookieCipher) value, err := p.proxy.provider.CookieForSession(s, p.proxy.CookieCipher)
if err != nil { if err != nil {
return err return err
@ -648,14 +649,14 @@ func (p *ProcessCookieTest) SaveSession(s *providers.SessionState, ref time.Time
return nil return nil
} }
func (p *ProcessCookieTest) LoadCookiedSession() (*providers.SessionState, time.Duration, error) { func (p *ProcessCookieTest) LoadCookiedSession() (*sessions.SessionState, time.Duration, error) {
return p.proxy.LoadCookiedSession(p.req) return p.proxy.LoadCookiedSession(p.req)
} }
func TestLoadCookiedSession(t *testing.T) { func TestLoadCookiedSession(t *testing.T) {
pcTest := NewProcessCookieTestWithDefaults() pcTest := NewProcessCookieTestWithDefaults()
startSession := &providers.SessionState{Email: "john.doe@example.com", AccessToken: "my_access_token"} startSession := &sessions.SessionState{Email: "john.doe@example.com", AccessToken: "my_access_token"}
pcTest.SaveSession(startSession, time.Now()) pcTest.SaveSession(startSession, time.Now())
session, _, err := pcTest.LoadCookiedSession() session, _, err := pcTest.LoadCookiedSession()
@ -680,7 +681,7 @@ func TestProcessCookieRefreshNotSet(t *testing.T) {
pcTest.proxy.CookieExpire = time.Duration(23) * time.Hour pcTest.proxy.CookieExpire = time.Duration(23) * time.Hour
reference := time.Now().Add(time.Duration(-2) * time.Hour) reference := time.Now().Add(time.Duration(-2) * time.Hour)
startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
pcTest.SaveSession(startSession, reference) pcTest.SaveSession(startSession, reference)
session, age, err := pcTest.LoadCookiedSession() session, age, err := pcTest.LoadCookiedSession()
@ -695,7 +696,7 @@ func TestProcessCookieFailIfCookieExpired(t *testing.T) {
pcTest := NewProcessCookieTestWithDefaults() pcTest := NewProcessCookieTestWithDefaults()
pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour
reference := time.Now().Add(time.Duration(25) * time.Hour * -1) reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
pcTest.SaveSession(startSession, reference) pcTest.SaveSession(startSession, reference)
session, _, err := pcTest.LoadCookiedSession() session, _, err := pcTest.LoadCookiedSession()
@ -709,7 +710,7 @@ func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) {
pcTest := NewProcessCookieTestWithDefaults() pcTest := NewProcessCookieTestWithDefaults()
pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour
reference := time.Now().Add(time.Duration(25) * time.Hour * -1) reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
pcTest.SaveSession(startSession, reference) pcTest.SaveSession(startSession, reference)
pcTest.proxy.CookieRefresh = time.Hour pcTest.proxy.CookieRefresh = time.Hour
@ -729,7 +730,7 @@ func NewAuthOnlyEndpointTest() *ProcessCookieTest {
func TestAuthOnlyEndpointAccepted(t *testing.T) { func TestAuthOnlyEndpointAccepted(t *testing.T) {
test := NewAuthOnlyEndpointTest() test := NewAuthOnlyEndpointTest()
startSession := &providers.SessionState{ startSession := &sessions.SessionState{
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
test.SaveSession(startSession, time.Now()) test.SaveSession(startSession, time.Now())
@ -752,7 +753,7 @@ func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) {
test := NewAuthOnlyEndpointTest() test := NewAuthOnlyEndpointTest()
test.proxy.CookieExpire = time.Duration(24) * time.Hour test.proxy.CookieExpire = time.Duration(24) * time.Hour
reference := time.Now().Add(time.Duration(25) * time.Hour * -1) reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
startSession := &providers.SessionState{ startSession := &sessions.SessionState{
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
test.SaveSession(startSession, reference) test.SaveSession(startSession, reference)
@ -764,7 +765,7 @@ func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) {
func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) { func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) {
test := NewAuthOnlyEndpointTest() test := NewAuthOnlyEndpointTest()
startSession := &providers.SessionState{ startSession := &sessions.SessionState{
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
test.SaveSession(startSession, time.Now()) test.SaveSession(startSession, time.Now())
test.validateUser = false test.validateUser = false
@ -795,7 +796,7 @@ func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) {
pcTest.req, _ = http.NewRequest("GET", pcTest.req, _ = http.NewRequest("GET",
pcTest.opts.ProxyPrefix+"/auth", nil) pcTest.opts.ProxyPrefix+"/auth", nil)
startSession := &providers.SessionState{ startSession := &sessions.SessionState{
User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token"} User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token"}
pcTest.SaveSession(startSession, time.Now()) pcTest.SaveSession(startSession, time.Now())
@ -927,7 +928,7 @@ func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) {
req := httptest.NewRequest(method, "/foo/bar", bodyBuf) req := httptest.NewRequest(method, "/foo/bar", bodyBuf)
req.Header = st.header req.Header = st.header
state := &providers.SessionState{ state := &sessions.SessionState{
Email: "mbland@acm.org", AccessToken: "my_access_token"} Email: "mbland@acm.org", AccessToken: "my_access_token"}
value, err := proxy.provider.CookieForSession(state, proxy.CookieCipher) value, err := proxy.provider.CookieForSession(state, proxy.CookieCipher)
if err != nil { if err != nil {

View File

@ -18,6 +18,7 @@ import (
"github.com/dgrijalva/jwt-go" "github.com/dgrijalva/jwt-go"
"github.com/mbland/hmacauth" "github.com/mbland/hmacauth"
"github.com/pusher/oauth2_proxy/logger" "github.com/pusher/oauth2_proxy/logger"
"github.com/pusher/oauth2_proxy/pkg/apis/options"
"github.com/pusher/oauth2_proxy/providers" "github.com/pusher/oauth2_proxy/providers"
"gopkg.in/natefinch/lumberjack.v2" "gopkg.in/natefinch/lumberjack.v2"
) )
@ -49,14 +50,11 @@ type Options struct {
CustomTemplatesDir string `flag:"custom-templates-dir" cfg:"custom_templates_dir" env:"OAUTH2_PROXY_CUSTOM_TEMPLATES_DIR"` CustomTemplatesDir string `flag:"custom-templates-dir" cfg:"custom_templates_dir" env:"OAUTH2_PROXY_CUSTOM_TEMPLATES_DIR"`
Footer string `flag:"footer" cfg:"footer" env:"OAUTH2_PROXY_FOOTER"` Footer string `flag:"footer" cfg:"footer" env:"OAUTH2_PROXY_FOOTER"`
CookieName string `flag:"cookie-name" cfg:"cookie_name" env:"OAUTH2_PROXY_COOKIE_NAME"` // Embed CookieOptions
CookieSecret string `flag:"cookie-secret" cfg:"cookie_secret" env:"OAUTH2_PROXY_COOKIE_SECRET"` options.CookieOptions
CookieDomain string `flag:"cookie-domain" cfg:"cookie_domain" env:"OAUTH2_PROXY_COOKIE_DOMAIN"`
CookiePath string `flag:"cookie-path" cfg:"cookie_path" env:"OAUTH2_PROXY_COOKIE_PATH"` // Embed SessionOptions
CookieExpire time.Duration `flag:"cookie-expire" cfg:"cookie_expire" env:"OAUTH2_PROXY_COOKIE_EXPIRE"` options.SessionOptions
CookieRefresh time.Duration `flag:"cookie-refresh" cfg:"cookie_refresh" env:"OAUTH2_PROXY_COOKIE_REFRESH"`
CookieSecure bool `flag:"cookie-secure" cfg:"cookie_secure" env:"OAUTH2_PROXY_COOKIE_SECURE"`
CookieHTTPOnly bool `flag:"cookie-httponly" cfg:"cookie_httponly" env:"OAUTH2_PROXY_COOKIE_HTTPONLY"`
Upstreams []string `flag:"upstream" cfg:"upstreams" env:"OAUTH2_PROXY_UPSTREAMS"` Upstreams []string `flag:"upstream" cfg:"upstreams" env:"OAUTH2_PROXY_UPSTREAMS"`
SkipAuthRegex []string `flag:"skip-auth-regex" cfg:"skip_auth_regex" env:"OAUTH2_PROXY_SKIP_AUTH_REGEX"` SkipAuthRegex []string `flag:"skip-auth-regex" cfg:"skip_auth_regex" env:"OAUTH2_PROXY_SKIP_AUTH_REGEX"`
@ -126,16 +124,18 @@ type SignatureData struct {
// NewOptions constructs a new Options with defaulted values // NewOptions constructs a new Options with defaulted values
func NewOptions() *Options { func NewOptions() *Options {
return &Options{ return &Options{
ProxyPrefix: "/oauth2", ProxyPrefix: "/oauth2",
ProxyWebSockets: true, ProxyWebSockets: true,
HTTPAddress: "127.0.0.1:4180", HTTPAddress: "127.0.0.1:4180",
HTTPSAddress: ":443", HTTPSAddress: ":443",
DisplayHtpasswdForm: true, DisplayHtpasswdForm: true,
CookieName: "_oauth2_proxy", CookieOptions: options.CookieOptions{
CookieSecure: true, CookieName: "_oauth2_proxy",
CookieHTTPOnly: true, CookieSecure: true,
CookieExpire: time.Duration(168) * time.Hour, CookieHTTPOnly: true,
CookieRefresh: time.Duration(0), CookieExpire: time.Duration(168) * time.Hour,
CookieRefresh: time.Duration(0),
},
SetXAuthRequest: false, SetXAuthRequest: false,
SkipAuthPreflight: false, SkipAuthPreflight: false,
PassBasicAuth: true, PassBasicAuth: true,

View File

@ -0,0 +1,15 @@
package options
import "time"
// CookieOptions contains configuration options relating to Cookie configuration
type CookieOptions struct {
CookieName string `flag:"cookie-name" cfg:"cookie_name" env:"OAUTH2_PROXY_COOKIE_NAME"`
CookieSecret string `flag:"cookie-secret" cfg:"cookie_secret" env:"OAUTH2_PROXY_COOKIE_SECRET"`
CookieDomain string `flag:"cookie-domain" cfg:"cookie_domain" env:"OAUTH2_PROXY_COOKIE_DOMAIN"`
CookiePath string `flag:"cookie-path" cfg:"cookie_path" env:"OAUTH2_PROXY_COOKIE_PATH"`
CookieExpire time.Duration `flag:"cookie-expire" cfg:"cookie_expire" env:"OAUTH2_PROXY_COOKIE_EXPIRE"`
CookieRefresh time.Duration `flag:"cookie-refresh" cfg:"cookie_refresh" env:"OAUTH2_PROXY_COOKIE_REFRESH"`
CookieSecure bool `flag:"cookie-secure" cfg:"cookie_secure" env:"OAUTH2_PROXY_COOKIE_SECURE"`
CookieHTTPOnly bool `flag:"cookie-httponly" cfg:"cookie_httponly" env:"OAUTH2_PROXY_COOKIE_HTTPONLY"`
}

View File

@ -0,0 +1,14 @@
package options
// SessionOptions contains configuration options for the SessionStore providers.
type SessionOptions struct {
Type string `flag:"session-store-type" cfg:"session_store_type" env:"OAUTH2_PROXY_SESSION_STORE_TYPE"`
CookieStoreOptions
}
// CookieSessionStoreType is used to indicate the CookieSessionStore should be
// used for storing sessions.
var CookieSessionStoreType = "cookie"
// CookieStoreOptions contains configuration options for the CookieSessionStore.
type CookieStoreOptions struct{}

View File

@ -0,0 +1,12 @@
package sessions
import (
"net/http"
)
// SessionStore is an interface to storing user sessions in the proxy
type SessionStore interface {
Save(rw http.ResponseWriter, req *http.Request, s *SessionState) error
Load(req *http.Request) (*SessionState, error)
Clear(rw http.ResponseWriter, req *http.Request) error
}

View File

@ -1,4 +1,4 @@
package providers package sessions
import ( import (
"encoding/json" "encoding/json"

View File

@ -1,4 +1,4 @@
package providers package sessions_test
import ( import (
"fmt" "fmt"
@ -6,6 +6,7 @@ import (
"time" "time"
"github.com/pusher/oauth2_proxy/cookie" "github.com/pusher/oauth2_proxy/cookie"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -17,7 +18,7 @@ func TestSessionStateSerialization(t *testing.T) {
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
c2, err := cookie.NewCipher([]byte(altSecret)) c2, err := cookie.NewCipher([]byte(altSecret))
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
s := &SessionState{ s := &sessions.SessionState{
Email: "user@domain.com", Email: "user@domain.com",
AccessToken: "token1234", AccessToken: "token1234",
IDToken: "rawtoken1234", IDToken: "rawtoken1234",
@ -27,7 +28,7 @@ func TestSessionStateSerialization(t *testing.T) {
encoded, err := s.EncodeSessionState(c) encoded, err := s.EncodeSessionState(c)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
ss, err := DecodeSessionState(encoded, c) ss, err := sessions.DecodeSessionState(encoded, c)
t.Logf("%#v", ss) t.Logf("%#v", ss)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "user@domain.com", ss.User) assert.Equal(t, "user@domain.com", ss.User)
@ -38,7 +39,7 @@ func TestSessionStateSerialization(t *testing.T) {
assert.Equal(t, s.RefreshToken, ss.RefreshToken) assert.Equal(t, s.RefreshToken, ss.RefreshToken)
// ensure a different cipher can't decode properly (ie: it gets gibberish) // ensure a different cipher can't decode properly (ie: it gets gibberish)
ss, err = DecodeSessionState(encoded, c2) ss, err = sessions.DecodeSessionState(encoded, c2)
t.Logf("%#v", ss) t.Logf("%#v", ss)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.NotEqual(t, "user@domain.com", ss.User) assert.NotEqual(t, "user@domain.com", ss.User)
@ -54,7 +55,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
c2, err := cookie.NewCipher([]byte(altSecret)) c2, err := cookie.NewCipher([]byte(altSecret))
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
s := &SessionState{ s := &sessions.SessionState{
User: "just-user", User: "just-user",
Email: "user@domain.com", Email: "user@domain.com",
AccessToken: "token1234", AccessToken: "token1234",
@ -64,7 +65,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
encoded, err := s.EncodeSessionState(c) encoded, err := s.EncodeSessionState(c)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
ss, err := DecodeSessionState(encoded, c) ss, err := sessions.DecodeSessionState(encoded, c)
t.Logf("%#v", ss) t.Logf("%#v", ss)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, s.User, ss.User) assert.Equal(t, s.User, ss.User)
@ -74,7 +75,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
assert.Equal(t, s.RefreshToken, ss.RefreshToken) assert.Equal(t, s.RefreshToken, ss.RefreshToken)
// ensure a different cipher can't decode properly (ie: it gets gibberish) // ensure a different cipher can't decode properly (ie: it gets gibberish)
ss, err = DecodeSessionState(encoded, c2) ss, err = sessions.DecodeSessionState(encoded, c2)
t.Logf("%#v", ss) t.Logf("%#v", ss)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.NotEqual(t, s.User, ss.User) assert.NotEqual(t, s.User, ss.User)
@ -85,7 +86,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
} }
func TestSessionStateSerializationNoCipher(t *testing.T) { func TestSessionStateSerializationNoCipher(t *testing.T) {
s := &SessionState{ s := &sessions.SessionState{
Email: "user@domain.com", Email: "user@domain.com",
AccessToken: "token1234", AccessToken: "token1234",
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour), ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
@ -95,7 +96,7 @@ func TestSessionStateSerializationNoCipher(t *testing.T) {
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
// only email should have been serialized // only email should have been serialized
ss, err := DecodeSessionState(encoded, nil) ss, err := sessions.DecodeSessionState(encoded, nil)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "user@domain.com", ss.User) assert.Equal(t, "user@domain.com", ss.User)
assert.Equal(t, s.Email, ss.Email) assert.Equal(t, s.Email, ss.Email)
@ -104,7 +105,7 @@ func TestSessionStateSerializationNoCipher(t *testing.T) {
} }
func TestSessionStateSerializationNoCipherWithUser(t *testing.T) { func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
s := &SessionState{ s := &sessions.SessionState{
User: "just-user", User: "just-user",
Email: "user@domain.com", Email: "user@domain.com",
AccessToken: "token1234", AccessToken: "token1234",
@ -115,7 +116,7 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
// only email should have been serialized // only email should have been serialized
ss, err := DecodeSessionState(encoded, nil) ss, err := sessions.DecodeSessionState(encoded, nil)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, s.User, ss.User) assert.Equal(t, s.User, ss.User)
assert.Equal(t, s.Email, ss.Email) assert.Equal(t, s.Email, ss.Email)
@ -124,18 +125,18 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
} }
func TestExpired(t *testing.T) { func TestExpired(t *testing.T) {
s := &SessionState{ExpiresOn: time.Now().Add(time.Duration(-1) * time.Minute)} s := &sessions.SessionState{ExpiresOn: time.Now().Add(time.Duration(-1) * time.Minute)}
assert.Equal(t, true, s.IsExpired()) assert.Equal(t, true, s.IsExpired())
s = &SessionState{ExpiresOn: time.Now().Add(time.Duration(1) * time.Minute)} s = &sessions.SessionState{ExpiresOn: time.Now().Add(time.Duration(1) * time.Minute)}
assert.Equal(t, false, s.IsExpired()) assert.Equal(t, false, s.IsExpired())
s = &SessionState{} s = &sessions.SessionState{}
assert.Equal(t, false, s.IsExpired()) assert.Equal(t, false, s.IsExpired())
} }
type testCase struct { type testCase struct {
SessionState sessions.SessionState
Encoded string Encoded string
Cipher *cookie.Cipher Cipher *cookie.Cipher
Error bool Error bool
@ -150,14 +151,14 @@ func TestEncodeSessionState(t *testing.T) {
testCases := []testCase{ testCases := []testCase{
{ {
SessionState: SessionState{ SessionState: sessions.SessionState{
Email: "user@domain.com", Email: "user@domain.com",
User: "just-user", User: "just-user",
}, },
Encoded: `{"Email":"user@domain.com","User":"just-user"}`, Encoded: `{"Email":"user@domain.com","User":"just-user"}`,
}, },
{ {
SessionState: SessionState{ SessionState: sessions.SessionState{
Email: "user@domain.com", Email: "user@domain.com",
User: "just-user", User: "just-user",
AccessToken: "token1234", AccessToken: "token1234",
@ -171,7 +172,7 @@ func TestEncodeSessionState(t *testing.T) {
for i, tc := range testCases { for i, tc := range testCases {
encoded, err := tc.EncodeSessionState(tc.Cipher) encoded, err := tc.EncodeSessionState(tc.Cipher)
t.Logf("i:%d Encoded:%#v SessionState:%#v Error:%#v", i, encoded, tc.SessionState, err) t.Logf("i:%d Encoded:%#vsessions.SessionState:%#v Error:%#v", i, encoded, tc.SessionState, err)
if tc.Error { if tc.Error {
assert.Error(t, err) assert.Error(t, err)
assert.Empty(t, encoded) assert.Empty(t, encoded)
@ -182,7 +183,7 @@ func TestEncodeSessionState(t *testing.T) {
} }
} }
// TestDecodeSessionState tests DecodeSessionState with the test vector // TestDecodeSessionState testssessions.DecodeSessionState with the test vector
func TestDecodeSessionState(t *testing.T) { func TestDecodeSessionState(t *testing.T) {
e := time.Now().Add(time.Duration(1) * time.Hour) e := time.Now().Add(time.Duration(1) * time.Hour)
eJSON, _ := e.MarshalJSON() eJSON, _ := e.MarshalJSON()
@ -194,34 +195,34 @@ func TestDecodeSessionState(t *testing.T) {
testCases := []testCase{ testCases := []testCase{
{ {
SessionState: SessionState{ SessionState: sessions.SessionState{
Email: "user@domain.com", Email: "user@domain.com",
User: "just-user", User: "just-user",
}, },
Encoded: `{"Email":"user@domain.com","User":"just-user"}`, Encoded: `{"Email":"user@domain.com","User":"just-user"}`,
}, },
{ {
SessionState: SessionState{ SessionState: sessions.SessionState{
Email: "user@domain.com", Email: "user@domain.com",
User: "user@domain.com", User: "user@domain.com",
}, },
Encoded: `{"Email":"user@domain.com"}`, Encoded: `{"Email":"user@domain.com"}`,
}, },
{ {
SessionState: SessionState{ SessionState: sessions.SessionState{
User: "just-user", User: "just-user",
}, },
Encoded: `{"User":"just-user"}`, Encoded: `{"User":"just-user"}`,
}, },
{ {
SessionState: SessionState{ SessionState: sessions.SessionState{
Email: "user@domain.com", Email: "user@domain.com",
User: "just-user", User: "just-user",
}, },
Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","ExpiresOn":%s}`, eString), Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","ExpiresOn":%s}`, eString),
}, },
{ {
SessionState: SessionState{ SessionState: sessions.SessionState{
Email: "user@domain.com", Email: "user@domain.com",
User: "just-user", User: "just-user",
AccessToken: "token1234", AccessToken: "token1234",
@ -233,7 +234,7 @@ func TestDecodeSessionState(t *testing.T) {
Cipher: c, Cipher: c,
}, },
{ {
SessionState: SessionState{ SessionState: sessions.SessionState{
Email: "user@domain.com", Email: "user@domain.com",
User: "just-user", User: "just-user",
}, },
@ -251,7 +252,7 @@ func TestDecodeSessionState(t *testing.T) {
Error: true, Error: true,
}, },
{ {
SessionState: SessionState{ SessionState: sessions.SessionState{
User: "just-user", User: "just-user",
Email: "user@domain.com", Email: "user@domain.com",
}, },
@ -272,7 +273,7 @@ func TestDecodeSessionState(t *testing.T) {
Error: true, Error: true,
}, },
{ {
SessionState: SessionState{ SessionState: sessions.SessionState{
Email: "user@domain.com", Email: "user@domain.com",
User: "just-user", User: "just-user",
AccessToken: "token1234", AccessToken: "token1234",
@ -283,7 +284,7 @@ func TestDecodeSessionState(t *testing.T) {
Cipher: c, Cipher: c,
}, },
{ {
SessionState: SessionState{ SessionState: sessions.SessionState{
Email: "user@domain.com", Email: "user@domain.com",
User: "just-user", User: "just-user",
AccessToken: "token1234", AccessToken: "token1234",
@ -297,8 +298,8 @@ func TestDecodeSessionState(t *testing.T) {
} }
for i, tc := range testCases { for i, tc := range testCases {
ss, err := DecodeSessionState(tc.Encoded, tc.Cipher) ss, err := sessions.DecodeSessionState(tc.Encoded, tc.Cipher)
t.Logf("i:%d Encoded:%#v SessionState:%#v Error:%#v", i, tc.Encoded, ss, err) t.Logf("i:%d Encoded:%#vsessions.SessionState:%#v Error:%#v", i, tc.Encoded, ss, err)
if tc.Error { if tc.Error {
assert.Error(t, err) assert.Error(t, err)
assert.Nil(t, ss) assert.Nil(t, ss)

34
pkg/cookies/cookies.go Normal file
View File

@ -0,0 +1,34 @@
package cookies
import (
"net"
"net/http"
"strings"
"time"
"github.com/pusher/oauth2_proxy/logger"
)
// MakeCookie constructs a cookie from the given parameters,
// discovering the domain from the request if not specified.
func MakeCookie(req *http.Request, name string, value string, path string, domain string, httpOnly bool, secure bool, expiration time.Duration, now time.Time) *http.Cookie {
if domain != "" {
host := req.Host
if h, _, err := net.SplitHostPort(host); err == nil {
host = h
}
if !strings.HasSuffix(host, domain) {
logger.Printf("Warning: request host is %q but using configured cookie domain of %q", host, domain)
}
}
return &http.Cookie{
Name: name,
Value: value,
Path: path,
Domain: domain,
HttpOnly: httpOnly,
Secure: secure,
Expires: now.Add(expiration),
}
}

View File

@ -0,0 +1,232 @@
package cookie
import (
"errors"
"fmt"
"net/http"
"regexp"
"strings"
"time"
"github.com/pusher/oauth2_proxy/cookie"
"github.com/pusher/oauth2_proxy/pkg/apis/options"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
"github.com/pusher/oauth2_proxy/pkg/cookies"
"github.com/pusher/oauth2_proxy/pkg/sessions/utils"
)
const (
// Cookies are limited to 4kb including the length of the cookie name,
// the cookie name can be up to 256 bytes
maxCookieLength = 3840
)
// Ensure CookieSessionStore implements the interface
var _ sessions.SessionStore = &SessionStore{}
// SessionStore is an implementation of the sessions.SessionStore
// interface that stores sessions in client side cookies
type SessionStore struct {
CookieCipher *cookie.Cipher
CookieDomain string
CookieExpire time.Duration
CookieHTTPOnly bool
CookieName string
CookiePath string
CookieSecret string
CookieSecure bool
}
// Save takes a sessions.SessionState and stores the information from it
// within Cookies set on the HTTP response writer
func (s *SessionStore) Save(rw http.ResponseWriter, req *http.Request, ss *sessions.SessionState) error {
value, err := utils.CookieForSession(ss, s.CookieCipher)
if err != nil {
return err
}
s.setSessionCookie(rw, req, value)
return nil
}
// Load reads sessions.SessionState information from Cookies within the
// HTTP request object
func (s *SessionStore) Load(req *http.Request) (*sessions.SessionState, error) {
c, err := loadCookie(req, s.CookieName)
if err != nil {
// always http.ErrNoCookie
return nil, fmt.Errorf("Cookie %q not present", s.CookieName)
}
val, _, ok := cookie.Validate(c, s.CookieSecret, s.CookieExpire)
if !ok {
return nil, errors.New("Cookie Signature not valid")
}
session, err := utils.SessionFromCookie(val, s.CookieCipher)
if err != nil {
return nil, err
}
return session, nil
}
// Clear clears any saved session information by writing a cookie to
// clear the session
func (s *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) error {
var cookies []*http.Cookie
// matches CookieName, CookieName_<number>
var cookieNameRegex = regexp.MustCompile(fmt.Sprintf("^%s(_\\d+)?$", s.CookieName))
for _, c := range req.Cookies() {
if cookieNameRegex.MatchString(c.Name) {
clearCookie := s.makeCookie(req, c.Name, "", time.Hour*-1)
http.SetCookie(rw, clearCookie)
cookies = append(cookies, clearCookie)
}
}
return nil
}
// setSessionCookie adds the user's session cookie to the response
func (s *SessionStore) setSessionCookie(rw http.ResponseWriter, req *http.Request, val string) {
for _, c := range s.makeSessionCookie(req, val, s.CookieExpire, time.Now()) {
http.SetCookie(rw, c)
}
}
// makeSessionCookie creates an http.Cookie containing the authenticated user's
// authentication details
func (s *SessionStore) makeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) []*http.Cookie {
if value != "" {
value = cookie.SignedValue(s.CookieSecret, s.CookieName, value, now)
}
c := s.makeCookie(req, s.CookieName, value, expiration)
if len(c.Value) > 4096-len(s.CookieName) {
return splitCookie(c)
}
return []*http.Cookie{c}
}
func (s *SessionStore) makeCookie(req *http.Request, name string, value string, expiration time.Duration) *http.Cookie {
return cookies.MakeCookie(
req,
name,
value,
s.CookiePath,
s.CookieDomain,
s.CookieHTTPOnly,
s.CookieSecure,
expiration,
time.Now(),
)
}
// NewCookieSessionStore initialises a new instance of the SessionStore from
// the configuration given
func NewCookieSessionStore(opts options.CookieStoreOptions, cookieOpts *options.CookieOptions) (sessions.SessionStore, error) {
var cipher *cookie.Cipher
if len(cookieOpts.CookieSecret) > 0 {
var err error
cipher, err = cookie.NewCipher(utils.SecretBytes(cookieOpts.CookieSecret))
if err != nil {
return nil, fmt.Errorf("unable to create cipher: %v", err)
}
}
return &SessionStore{
CookieCipher: cipher,
CookieDomain: cookieOpts.CookieDomain,
CookieExpire: cookieOpts.CookieExpire,
CookieHTTPOnly: cookieOpts.CookieHTTPOnly,
CookieName: cookieOpts.CookieName,
CookiePath: cookieOpts.CookiePath,
CookieSecret: cookieOpts.CookieSecret,
CookieSecure: cookieOpts.CookieSecure,
}, nil
}
// splitCookie reads the full cookie generated to store the session and splits
// it into a slice of cookies which fit within the 4kb cookie limit indexing
// the cookies from 0
func splitCookie(c *http.Cookie) []*http.Cookie {
if len(c.Value) < maxCookieLength {
return []*http.Cookie{c}
}
cookies := []*http.Cookie{}
valueBytes := []byte(c.Value)
count := 0
for len(valueBytes) > 0 {
new := copyCookie(c)
new.Name = fmt.Sprintf("%s_%d", c.Name, count)
count++
if len(valueBytes) < maxCookieLength {
new.Value = string(valueBytes)
valueBytes = []byte{}
} else {
newValue := valueBytes[:maxCookieLength]
valueBytes = valueBytes[maxCookieLength:]
new.Value = string(newValue)
}
cookies = append(cookies, new)
}
return cookies
}
// loadCookie retreieves the sessions state cookie from the http request.
// If a single cookie is present this will be returned, otherwise it attempts
// to reconstruct a cookie split up by splitCookie
func loadCookie(req *http.Request, cookieName string) (*http.Cookie, error) {
c, err := req.Cookie(cookieName)
if err == nil {
return c, nil
}
cookies := []*http.Cookie{}
err = nil
count := 0
for err == nil {
var c *http.Cookie
c, err = req.Cookie(fmt.Sprintf("%s_%d", cookieName, count))
if err == nil {
cookies = append(cookies, c)
count++
}
}
if len(cookies) == 0 {
return nil, fmt.Errorf("Could not find cookie %s", cookieName)
}
return joinCookies(cookies)
}
// joinCookies takes a slice of cookies from the request and reconstructs the
// full session cookie
func joinCookies(cookies []*http.Cookie) (*http.Cookie, error) {
if len(cookies) == 0 {
return nil, fmt.Errorf("list of cookies must be > 0")
}
if len(cookies) == 1 {
return cookies[0], nil
}
c := copyCookie(cookies[0])
for i := 1; i < len(cookies); i++ {
c.Value += cookies[i].Value
}
c.Name = strings.TrimRight(c.Name, "_0")
return c, nil
}
func copyCookie(c *http.Cookie) *http.Cookie {
return &http.Cookie{
Name: c.Name,
Value: c.Value,
Path: c.Path,
Domain: c.Domain,
Expires: c.Expires,
RawExpires: c.RawExpires,
MaxAge: c.MaxAge,
Secure: c.Secure,
HttpOnly: c.HttpOnly,
Raw: c.Raw,
Unparsed: c.Unparsed,
}
}

View File

@ -0,0 +1,19 @@
package sessions
import (
"fmt"
"github.com/pusher/oauth2_proxy/pkg/apis/options"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
"github.com/pusher/oauth2_proxy/pkg/sessions/cookie"
)
// NewSessionStore creates a SessionStore from the provided configuration
func NewSessionStore(opts *options.SessionOptions, cookieOpts *options.CookieOptions) (sessions.SessionStore, error) {
switch opts.Type {
case options.CookieSessionStoreType:
return cookie.NewCookieSessionStore(opts.CookieStoreOptions, cookieOpts)
default:
return nil, fmt.Errorf("unknown session store type '%s'", opts.Type)
}
}

View File

@ -0,0 +1,254 @@
package sessions_test
import (
"crypto/rand"
"encoding/base64"
"net/http"
"net/http/httptest"
"testing"
"time"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"github.com/pusher/oauth2_proxy/pkg/apis/options"
sessionsapi "github.com/pusher/oauth2_proxy/pkg/apis/sessions"
"github.com/pusher/oauth2_proxy/pkg/cookies"
"github.com/pusher/oauth2_proxy/pkg/sessions"
"github.com/pusher/oauth2_proxy/pkg/sessions/cookie"
)
func TestSessionStore(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "SessionStore")
}
var _ = Describe("NewSessionStore", func() {
var opts *options.SessionOptions
var cookieOpts *options.CookieOptions
var request *http.Request
var response *httptest.ResponseRecorder
var session *sessionsapi.SessionState
var ss sessionsapi.SessionStore
CheckCookieOptions := func() {
Context("the cookies returned", func() {
var cookies []*http.Cookie
BeforeEach(func() {
cookies = response.Result().Cookies()
})
It("have the correct name set", func() {
if len(cookies) == 1 {
Expect(cookies[0].Name).To(Equal(cookieOpts.CookieName))
} else {
for _, cookie := range cookies {
Expect(cookie.Name).To(ContainSubstring(cookieOpts.CookieName))
}
}
})
It("have the correct path set", func() {
for _, cookie := range cookies {
Expect(cookie.Path).To(Equal(cookieOpts.CookiePath))
}
})
It("have the correct domain set", func() {
for _, cookie := range cookies {
Expect(cookie.Domain).To(Equal(cookieOpts.CookieDomain))
}
})
It("have the correct HTTPOnly set", func() {
for _, cookie := range cookies {
Expect(cookie.HttpOnly).To(Equal(cookieOpts.CookieHTTPOnly))
}
})
It("have the correct secure set", func() {
for _, cookie := range cookies {
Expect(cookie.Secure).To(Equal(cookieOpts.CookieSecure))
}
})
})
}
SessionStoreInterfaceTests := func() {
Context("when Save is called", func() {
BeforeEach(func() {
err := ss.Save(response, request, session)
Expect(err).ToNot(HaveOccurred())
})
It("sets a `set-cookie` header in the response", func() {
Expect(response.Header().Get("set-cookie")).ToNot(BeEmpty())
})
CheckCookieOptions()
})
Context("when Clear is called", func() {
BeforeEach(func() {
cookie := cookies.MakeCookie(request,
cookieOpts.CookieName,
"foo",
cookieOpts.CookiePath,
cookieOpts.CookieDomain,
cookieOpts.CookieHTTPOnly,
cookieOpts.CookieSecure,
cookieOpts.CookieExpire,
time.Now(),
)
request.AddCookie(cookie)
err := ss.Clear(response, request)
Expect(err).ToNot(HaveOccurred())
})
It("sets a `set-cookie` header in the response", func() {
Expect(response.Header().Get("Set-Cookie")).ToNot(BeEmpty())
})
CheckCookieOptions()
})
Context("when Load is called", func() {
var loadedSession *sessionsapi.SessionState
BeforeEach(func() {
req := httptest.NewRequest("GET", "http://example.com/", nil)
resp := httptest.NewRecorder()
err := ss.Save(resp, req, session)
Expect(err).ToNot(HaveOccurred())
for _, cookie := range resp.Result().Cookies() {
request.AddCookie(cookie)
}
loadedSession, err = ss.Load(request)
Expect(err).ToNot(HaveOccurred())
})
It("loads a session equal to the original session", func() {
if cookieOpts.CookieSecret == "" {
// Only Email and User stored in session when encrypted
Expect(loadedSession.Email).To(Equal(session.Email))
Expect(loadedSession.User).To(Equal(session.User))
} else {
// All fields stored in session if encrypted
// Can't compare time.Time using Equal() so remove ExpiresOn from sessions
l := *loadedSession
l.ExpiresOn = time.Time{}
s := *session
s.ExpiresOn = time.Time{}
Expect(l).To(Equal(s))
// Compare time.Time separately
Expect(loadedSession.ExpiresOn.Equal(session.ExpiresOn)).To(BeTrue())
}
})
})
}
RunSessionTests := func() {
Context("with default options", func() {
BeforeEach(func() {
var err error
ss, err = sessions.NewSessionStore(opts, cookieOpts)
Expect(err).ToNot(HaveOccurred())
})
SessionStoreInterfaceTests()
})
Context("with non-default options", func() {
BeforeEach(func() {
cookieOpts = &options.CookieOptions{
CookieName: "_cookie_name",
CookiePath: "/path",
CookieExpire: time.Duration(72) * time.Hour,
CookieRefresh: time.Duration(3600),
CookieSecure: false,
CookieHTTPOnly: false,
CookieDomain: "example.com",
}
var err error
ss, err = sessions.NewSessionStore(opts, cookieOpts)
Expect(err).ToNot(HaveOccurred())
})
SessionStoreInterfaceTests()
})
Context("with a cookie-secret set", func() {
BeforeEach(func() {
secret := make([]byte, 32)
_, err := rand.Read(secret)
Expect(err).ToNot(HaveOccurred())
cookieOpts.CookieSecret = base64.URLEncoding.EncodeToString(secret)
ss, err = sessions.NewSessionStore(opts, cookieOpts)
Expect(err).ToNot(HaveOccurred())
})
SessionStoreInterfaceTests()
})
}
BeforeEach(func() {
ss = nil
opts = &options.SessionOptions{}
// Set default options in CookieOptions
cookieOpts = &options.CookieOptions{
CookieName: "_oauth2_proxy",
CookiePath: "/",
CookieExpire: time.Duration(168) * time.Hour,
CookieRefresh: time.Duration(0),
CookieSecure: true,
CookieHTTPOnly: true,
}
session = &sessionsapi.SessionState{
AccessToken: "AccessToken",
IDToken: "IDToken",
ExpiresOn: time.Now().Add(1 * time.Hour),
RefreshToken: "RefreshToken",
Email: "john.doe@example.com",
User: "john.doe",
}
request = httptest.NewRequest("GET", "http://example.com/", nil)
response = httptest.NewRecorder()
})
Context("with type 'cookie'", func() {
BeforeEach(func() {
opts.Type = options.CookieSessionStoreType
})
It("creates a cookie.SessionStore", func() {
ss, err := sessions.NewSessionStore(opts, cookieOpts)
Expect(err).NotTo(HaveOccurred())
Expect(ss).To(BeAssignableToTypeOf(&cookie.SessionStore{}))
})
Context("the cookie.SessionStore", func() {
RunSessionTests()
})
})
Context("with an invalid type", func() {
BeforeEach(func() {
opts.Type = "invalid-type"
})
It("returns an error", func() {
ss, err := sessions.NewSessionStore(opts, cookieOpts)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(Equal("unknown session store type 'invalid-type'"))
Expect(ss).To(BeNil())
})
})
})

View File

@ -0,0 +1,41 @@
package utils
import (
"encoding/base64"
"github.com/pusher/oauth2_proxy/cookie"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
)
// CookieForSession serializes a session state for storage in a cookie
func CookieForSession(s *sessions.SessionState, c *cookie.Cipher) (string, error) {
return s.EncodeSessionState(c)
}
// SessionFromCookie deserializes a session from a cookie value
func SessionFromCookie(v string, c *cookie.Cipher) (s *sessions.SessionState, err error) {
return sessions.DecodeSessionState(v, c)
}
// SecretBytes attempts to base64 decode the secret, if that fails it treats the secret as binary
func SecretBytes(secret string) []byte {
b, err := base64.URLEncoding.DecodeString(addPadding(secret))
if err == nil {
return []byte(addPadding(string(b)))
}
return []byte(secret)
}
func addPadding(secret string) string {
padding := len(secret) % 4
switch padding {
case 1:
return secret + "==="
case 2:
return secret + "=="
case 3:
return secret + "="
default:
return secret
}
}

View File

@ -9,6 +9,7 @@ import (
"github.com/bitly/go-simplejson" "github.com/bitly/go-simplejson"
"github.com/pusher/oauth2_proxy/api" "github.com/pusher/oauth2_proxy/api"
"github.com/pusher/oauth2_proxy/logger" "github.com/pusher/oauth2_proxy/logger"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
) )
// AzureProvider represents an Azure based Identity Provider // AzureProvider represents an Azure based Identity Provider
@ -88,7 +89,7 @@ func getEmailFromJSON(json *simplejson.Json) (string, error) {
} }
// GetEmailAddress returns the Account email address // GetEmailAddress returns the Account email address
func (p *AzureProvider) GetEmailAddress(s *SessionState) (string, error) { func (p *AzureProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
var email string var email string
var err error var err error

View File

@ -6,6 +6,7 @@ import (
"net/url" "net/url"
"testing" "testing"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -128,7 +129,7 @@ func TestAzureProviderGetEmailAddress(t *testing.T) {
bURL, _ := url.Parse(b.URL) bURL, _ := url.Parse(b.URL)
p := testAzureProvider(bURL.Host) p := testAzureProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"} session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "user@windows.net", email) assert.Equal(t, "user@windows.net", email)
@ -141,7 +142,7 @@ func TestAzureProviderGetEmailAddressMailNull(t *testing.T) {
bURL, _ := url.Parse(b.URL) bURL, _ := url.Parse(b.URL)
p := testAzureProvider(bURL.Host) p := testAzureProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"} session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "user@windows.net", email) assert.Equal(t, "user@windows.net", email)
@ -154,7 +155,7 @@ func TestAzureProviderGetEmailAddressGetUserPrincipalName(t *testing.T) {
bURL, _ := url.Parse(b.URL) bURL, _ := url.Parse(b.URL)
p := testAzureProvider(bURL.Host) p := testAzureProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"} session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "user@windows.net", email) assert.Equal(t, "user@windows.net", email)
@ -167,7 +168,7 @@ func TestAzureProviderGetEmailAddressFailToGetEmailAddress(t *testing.T) {
bURL, _ := url.Parse(b.URL) bURL, _ := url.Parse(b.URL)
p := testAzureProvider(bURL.Host) p := testAzureProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"} session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(session)
assert.Equal(t, "type assertion to string failed", err.Error()) assert.Equal(t, "type assertion to string failed", err.Error())
assert.Equal(t, "", email) assert.Equal(t, "", email)
@ -180,7 +181,7 @@ func TestAzureProviderGetEmailAddressEmptyUserPrincipalName(t *testing.T) {
bURL, _ := url.Parse(b.URL) bURL, _ := url.Parse(b.URL)
p := testAzureProvider(bURL.Host) p := testAzureProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"} session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "", email) assert.Equal(t, "", email)
@ -193,7 +194,7 @@ func TestAzureProviderGetEmailAddressIncorrectOtherMails(t *testing.T) {
bURL, _ := url.Parse(b.URL) bURL, _ := url.Parse(b.URL)
p := testAzureProvider(bURL.Host) p := testAzureProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"} session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(session)
assert.Equal(t, "type assertion to string failed", err.Error()) assert.Equal(t, "type assertion to string failed", err.Error())
assert.Equal(t, "", email) assert.Equal(t, "", email)

View File

@ -7,6 +7,7 @@ import (
"net/url" "net/url"
"github.com/pusher/oauth2_proxy/api" "github.com/pusher/oauth2_proxy/api"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
) )
// FacebookProvider represents an Facebook based Identity Provider // FacebookProvider represents an Facebook based Identity Provider
@ -54,7 +55,7 @@ func getFacebookHeader(accessToken string) http.Header {
} }
// GetEmailAddress returns the Account email address // GetEmailAddress returns the Account email address
func (p *FacebookProvider) GetEmailAddress(s *SessionState) (string, error) { func (p *FacebookProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
if s.AccessToken == "" { if s.AccessToken == "" {
return "", errors.New("missing access token") return "", errors.New("missing access token")
} }
@ -79,6 +80,6 @@ func (p *FacebookProvider) GetEmailAddress(s *SessionState) (string, error) {
} }
// ValidateSessionState validates the AccessToken // ValidateSessionState validates the AccessToken
func (p *FacebookProvider) ValidateSessionState(s *SessionState) bool { func (p *FacebookProvider) ValidateSessionState(s *sessions.SessionState) bool {
return validateToken(p, s.AccessToken, getFacebookHeader(s.AccessToken)) return validateToken(p, s.AccessToken, getFacebookHeader(s.AccessToken))
} }

View File

@ -11,6 +11,7 @@ import (
"strings" "strings"
"github.com/pusher/oauth2_proxy/logger" "github.com/pusher/oauth2_proxy/logger"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
) )
// GitHubProvider represents an GitHub based Identity Provider // GitHubProvider represents an GitHub based Identity Provider
@ -200,7 +201,7 @@ func (p *GitHubProvider) hasOrgAndTeam(accessToken string) (bool, error) {
} }
// GetEmailAddress returns the Account email address // GetEmailAddress returns the Account email address
func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) { func (p *GitHubProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
var emails []struct { var emails []struct {
Email string `json:"email"` Email string `json:"email"`
@ -259,7 +260,7 @@ func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) {
} }
// GetUserName returns the Account user name // GetUserName returns the Account user name
func (p *GitHubProvider) GetUserName(s *SessionState) (string, error) { func (p *GitHubProvider) GetUserName(s *sessions.SessionState) (string, error) {
var user struct { var user struct {
Login string `json:"login"` Login string `json:"login"`
Email string `json:"email"` Email string `json:"email"`

View File

@ -6,6 +6,7 @@ import (
"net/url" "net/url"
"testing" "testing"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -103,7 +104,7 @@ func TestGitHubProviderGetEmailAddress(t *testing.T) {
bURL, _ := url.Parse(b.URL) bURL, _ := url.Parse(b.URL)
p := testGitHubProvider(bURL.Host) p := testGitHubProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"} session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "michael.bland@gsa.gov", email) assert.Equal(t, "michael.bland@gsa.gov", email)
@ -116,7 +117,7 @@ func TestGitHubProviderGetEmailAddressNotVerified(t *testing.T) {
bURL, _ := url.Parse(b.URL) bURL, _ := url.Parse(b.URL)
p := testGitHubProvider(bURL.Host) p := testGitHubProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"} session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Empty(t, "", email) assert.Empty(t, "", email)
@ -134,7 +135,7 @@ func TestGitHubProviderGetEmailAddressWithOrg(t *testing.T) {
p := testGitHubProvider(bURL.Host) p := testGitHubProvider(bURL.Host)
p.Org = "testorg1" p.Org = "testorg1"
session := &SessionState{AccessToken: "imaginary_access_token"} session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "michael.bland@gsa.gov", email) assert.Equal(t, "michael.bland@gsa.gov", email)
@ -152,7 +153,7 @@ func TestGitHubProviderGetEmailAddressFailedRequest(t *testing.T) {
// We'll trigger a request failure by using an unexpected access // We'll trigger a request failure by using an unexpected access
// token. Alternatively, we could allow the parsing of the payload as // token. Alternatively, we could allow the parsing of the payload as
// JSON to fail. // JSON to fail.
session := &SessionState{AccessToken: "unexpected_access_token"} session := &sessions.SessionState{AccessToken: "unexpected_access_token"}
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(session)
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
assert.Equal(t, "", email) assert.Equal(t, "", email)
@ -165,7 +166,7 @@ func TestGitHubProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
bURL, _ := url.Parse(b.URL) bURL, _ := url.Parse(b.URL)
p := testGitHubProvider(bURL.Host) p := testGitHubProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"} session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(session)
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
assert.Equal(t, "", email) assert.Equal(t, "", email)
@ -178,7 +179,7 @@ func TestGitHubProviderGetUserName(t *testing.T) {
bURL, _ := url.Parse(b.URL) bURL, _ := url.Parse(b.URL)
p := testGitHubProvider(bURL.Host) p := testGitHubProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"} session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetUserName(session) email, err := p.GetUserName(session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "mbland", email) assert.Equal(t, "mbland", email)

View File

@ -6,6 +6,7 @@ import (
"github.com/pusher/oauth2_proxy/api" "github.com/pusher/oauth2_proxy/api"
"github.com/pusher/oauth2_proxy/logger" "github.com/pusher/oauth2_proxy/logger"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
) )
// GitLabProvider represents an GitLab based Identity Provider // GitLabProvider represents an GitLab based Identity Provider
@ -44,7 +45,7 @@ func NewGitLabProvider(p *ProviderData) *GitLabProvider {
} }
// GetEmailAddress returns the Account email address // GetEmailAddress returns the Account email address
func (p *GitLabProvider) GetEmailAddress(s *SessionState) (string, error) { func (p *GitLabProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
req, err := http.NewRequest("GET", req, err := http.NewRequest("GET",
p.ValidateURL.String()+"?access_token="+s.AccessToken, nil) p.ValidateURL.String()+"?access_token="+s.AccessToken, nil)

View File

@ -6,6 +6,7 @@ import (
"net/url" "net/url"
"testing" "testing"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -89,7 +90,7 @@ func TestGitLabProviderGetEmailAddress(t *testing.T) {
bURL, _ := url.Parse(b.URL) bURL, _ := url.Parse(b.URL)
p := testGitLabProvider(bURL.Host) p := testGitLabProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"} session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "michael.bland@gsa.gov", email) assert.Equal(t, "michael.bland@gsa.gov", email)
@ -107,7 +108,7 @@ func TestGitLabProviderGetEmailAddressFailedRequest(t *testing.T) {
// We'll trigger a request failure by using an unexpected access // We'll trigger a request failure by using an unexpected access
// token. Alternatively, we could allow the parsing of the payload as // token. Alternatively, we could allow the parsing of the payload as
// JSON to fail. // JSON to fail.
session := &SessionState{AccessToken: "unexpected_access_token"} session := &sessions.SessionState{AccessToken: "unexpected_access_token"}
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(session)
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
assert.Equal(t, "", email) assert.Equal(t, "", email)
@ -120,7 +121,7 @@ func TestGitLabProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
bURL, _ := url.Parse(b.URL) bURL, _ := url.Parse(b.URL)
p := testGitLabProvider(bURL.Host) p := testGitLabProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"} session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(session)
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
assert.Equal(t, "", email) assert.Equal(t, "", email)

View File

@ -14,6 +14,7 @@ import (
"time" "time"
"github.com/pusher/oauth2_proxy/logger" "github.com/pusher/oauth2_proxy/logger"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"golang.org/x/oauth2/google" "golang.org/x/oauth2/google"
admin "google.golang.org/api/admin/directory/v1" admin "google.golang.org/api/admin/directory/v1"
@ -96,7 +97,7 @@ func claimsFromIDToken(idToken string) (*claims, error) {
} }
// Redeem exchanges the OAuth2 authentication token for an ID token // Redeem exchanges the OAuth2 authentication token for an ID token
func (p *GoogleProvider) Redeem(redirectURL, code string) (s *SessionState, err error) { func (p *GoogleProvider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) {
if code == "" { if code == "" {
err = errors.New("missing code") err = errors.New("missing code")
return return
@ -145,7 +146,7 @@ func (p *GoogleProvider) Redeem(redirectURL, code string) (s *SessionState, err
if err != nil { if err != nil {
return return
} }
s = &SessionState{ s = &sessions.SessionState{
AccessToken: jsonResponse.AccessToken, AccessToken: jsonResponse.AccessToken,
IDToken: jsonResponse.IDToken, IDToken: jsonResponse.IDToken,
ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second), ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second),
@ -258,7 +259,7 @@ func (p *GoogleProvider) ValidateGroup(email string) bool {
// RefreshSessionIfNeeded checks if the session has expired and uses the // RefreshSessionIfNeeded checks if the session has expired and uses the
// RefreshToken to fetch a new ID token if required // RefreshToken to fetch a new ID token if required
func (p *GoogleProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) { func (p *GoogleProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) {
if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
return false, nil return false, nil
} }

View File

@ -7,6 +7,7 @@ import (
"net/url" "net/url"
"testing" "testing"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -19,13 +20,13 @@ type ValidateSessionStateTestProvider struct {
*ProviderData *ProviderData
} }
func (tp *ValidateSessionStateTestProvider) GetEmailAddress(s *SessionState) (string, error) { func (tp *ValidateSessionStateTestProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
return "", errors.New("not implemented") return "", errors.New("not implemented")
} }
// Note that we're testing the internal validateToken() used to implement // Note that we're testing the internal validateToken() used to implement
// several Provider's ValidateSessionState() implementations // several Provider's ValidateSessionState() implementations
func (tp *ValidateSessionStateTestProvider) ValidateSessionState(s *SessionState) bool { func (tp *ValidateSessionStateTestProvider) ValidateSessionState(s *sessions.SessionState) bool {
return false return false
} }

View File

@ -7,6 +7,7 @@ import (
"net/url" "net/url"
"github.com/pusher/oauth2_proxy/api" "github.com/pusher/oauth2_proxy/api"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
) )
// LinkedInProvider represents an LinkedIn based Identity Provider // LinkedInProvider represents an LinkedIn based Identity Provider
@ -50,7 +51,7 @@ func getLinkedInHeader(accessToken string) http.Header {
} }
// GetEmailAddress returns the Account email address // GetEmailAddress returns the Account email address
func (p *LinkedInProvider) GetEmailAddress(s *SessionState) (string, error) { func (p *LinkedInProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
if s.AccessToken == "" { if s.AccessToken == "" {
return "", errors.New("missing access token") return "", errors.New("missing access token")
} }
@ -73,6 +74,6 @@ func (p *LinkedInProvider) GetEmailAddress(s *SessionState) (string, error) {
} }
// ValidateSessionState validates the AccessToken // ValidateSessionState validates the AccessToken
func (p *LinkedInProvider) ValidateSessionState(s *SessionState) bool { func (p *LinkedInProvider) ValidateSessionState(s *sessions.SessionState) bool {
return validateToken(p, s.AccessToken, getLinkedInHeader(s.AccessToken)) return validateToken(p, s.AccessToken, getLinkedInHeader(s.AccessToken))
} }

View File

@ -6,6 +6,7 @@ import (
"net/url" "net/url"
"testing" "testing"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -97,7 +98,7 @@ func TestLinkedInProviderGetEmailAddress(t *testing.T) {
bURL, _ := url.Parse(b.URL) bURL, _ := url.Parse(b.URL)
p := testLinkedInProvider(bURL.Host) p := testLinkedInProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"} session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "user@linkedin.com", email) assert.Equal(t, "user@linkedin.com", email)
@ -113,7 +114,7 @@ func TestLinkedInProviderGetEmailAddressFailedRequest(t *testing.T) {
// We'll trigger a request failure by using an unexpected access // We'll trigger a request failure by using an unexpected access
// token. Alternatively, we could allow the parsing of the payload as // token. Alternatively, we could allow the parsing of the payload as
// JSON to fail. // JSON to fail.
session := &SessionState{AccessToken: "unexpected_access_token"} session := &sessions.SessionState{AccessToken: "unexpected_access_token"}
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(session)
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
assert.Equal(t, "", email) assert.Equal(t, "", email)
@ -126,7 +127,7 @@ func TestLinkedInProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
bURL, _ := url.Parse(b.URL) bURL, _ := url.Parse(b.URL)
p := testLinkedInProvider(bURL.Host) p := testLinkedInProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"} session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(session)
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
assert.Equal(t, "", email) assert.Equal(t, "", email)

View File

@ -13,6 +13,7 @@ import (
"time" "time"
"github.com/dgrijalva/jwt-go" "github.com/dgrijalva/jwt-go"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
"gopkg.in/square/go-jose.v2" "gopkg.in/square/go-jose.v2"
) )
@ -173,7 +174,7 @@ func emailFromUserInfo(accessToken string, userInfoEndpoint string) (email strin
} }
// Redeem exchanges the OAuth2 authentication token for an ID token // Redeem exchanges the OAuth2 authentication token for an ID token
func (p *LoginGovProvider) Redeem(redirectURL, code string) (s *SessionState, err error) { func (p *LoginGovProvider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) {
if code == "" { if code == "" {
err = errors.New("missing code") err = errors.New("missing code")
return return
@ -248,7 +249,7 @@ func (p *LoginGovProvider) Redeem(redirectURL, code string) (s *SessionState, er
} }
// Store the data that we found in the session state // Store the data that we found in the session state
s = &SessionState{ s = &sessions.SessionState{
AccessToken: jsonResponse.AccessToken, AccessToken: jsonResponse.AccessToken,
IDToken: jsonResponse.IDToken, IDToken: jsonResponse.IDToken,
ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second), ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second),

View File

@ -5,9 +5,9 @@ import (
"fmt" "fmt"
"time" "time"
"golang.org/x/oauth2"
oidc "github.com/coreos/go-oidc" oidc "github.com/coreos/go-oidc"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
"golang.org/x/oauth2"
) )
// OIDCProvider represents an OIDC based Identity Provider // OIDCProvider represents an OIDC based Identity Provider
@ -24,7 +24,7 @@ func NewOIDCProvider(p *ProviderData) *OIDCProvider {
} }
// Redeem exchanges the OAuth2 authentication token for an ID token // Redeem exchanges the OAuth2 authentication token for an ID token
func (p *OIDCProvider) Redeem(redirectURL, code string) (s *SessionState, err error) { func (p *OIDCProvider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) {
ctx := context.Background() ctx := context.Background()
c := oauth2.Config{ c := oauth2.Config{
ClientID: p.ClientID, ClientID: p.ClientID,
@ -47,7 +47,7 @@ func (p *OIDCProvider) Redeem(redirectURL, code string) (s *SessionState, err er
// RefreshSessionIfNeeded checks if the session has expired and uses the // RefreshSessionIfNeeded checks if the session has expired and uses the
// RefreshToken to fetch a new ID token if required // RefreshToken to fetch a new ID token if required
func (p *OIDCProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) { func (p *OIDCProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) {
if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
return false, nil return false, nil
} }
@ -63,7 +63,7 @@ func (p *OIDCProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) {
return true, nil return true, nil
} }
func (p *OIDCProvider) redeemRefreshToken(s *SessionState) (err error) { func (p *OIDCProvider) redeemRefreshToken(s *sessions.SessionState) (err error) {
c := oauth2.Config{ c := oauth2.Config{
ClientID: p.ClientID, ClientID: p.ClientID,
ClientSecret: p.ClientSecret, ClientSecret: p.ClientSecret,
@ -92,7 +92,7 @@ func (p *OIDCProvider) redeemRefreshToken(s *SessionState) (err error) {
return return
} }
func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Token) (*SessionState, error) { func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Token) (*sessions.SessionState, error) {
rawIDToken, ok := token.Extra("id_token").(string) rawIDToken, ok := token.Extra("id_token").(string)
if !ok { if !ok {
return nil, fmt.Errorf("token response did not contain an id_token") return nil, fmt.Errorf("token response did not contain an id_token")
@ -122,7 +122,7 @@ func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Tok
return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email) return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email)
} }
return &SessionState{ return &sessions.SessionState{
AccessToken: token.AccessToken, AccessToken: token.AccessToken,
IDToken: rawIDToken, IDToken: rawIDToken,
RefreshToken: token.RefreshToken, RefreshToken: token.RefreshToken,
@ -133,7 +133,7 @@ func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Tok
} }
// ValidateSessionState checks that the session's IDToken is still valid // ValidateSessionState checks that the session's IDToken is still valid
func (p *OIDCProvider) ValidateSessionState(s *SessionState) bool { func (p *OIDCProvider) ValidateSessionState(s *sessions.SessionState) bool {
ctx := context.Background() ctx := context.Background()
_, err := p.Verifier.Verify(ctx, s.IDToken) _, err := p.Verifier.Verify(ctx, s.IDToken)
if err != nil { if err != nil {

View File

@ -10,10 +10,11 @@ import (
"net/url" "net/url"
"github.com/pusher/oauth2_proxy/cookie" "github.com/pusher/oauth2_proxy/cookie"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
) )
// Redeem provides a default implementation of the OAuth2 token redemption process // Redeem provides a default implementation of the OAuth2 token redemption process
func (p *ProviderData) Redeem(redirectURL, code string) (s *SessionState, err error) { func (p *ProviderData) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) {
if code == "" { if code == "" {
err = errors.New("missing code") err = errors.New("missing code")
return return
@ -59,7 +60,7 @@ func (p *ProviderData) Redeem(redirectURL, code string) (s *SessionState, err er
} }
err = json.Unmarshal(body, &jsonResponse) err = json.Unmarshal(body, &jsonResponse)
if err == nil { if err == nil {
s = &SessionState{ s = &sessions.SessionState{
AccessToken: jsonResponse.AccessToken, AccessToken: jsonResponse.AccessToken,
} }
return return
@ -71,7 +72,7 @@ func (p *ProviderData) Redeem(redirectURL, code string) (s *SessionState, err er
return return
} }
if a := v.Get("access_token"); a != "" { if a := v.Get("access_token"); a != "" {
s = &SessionState{AccessToken: a} s = &sessions.SessionState{AccessToken: a}
} else { } else {
err = fmt.Errorf("no access token found %s", body) err = fmt.Errorf("no access token found %s", body)
} }
@ -94,22 +95,22 @@ func (p *ProviderData) GetLoginURL(redirectURI, state string) string {
} }
// CookieForSession serializes a session state for storage in a cookie // CookieForSession serializes a session state for storage in a cookie
func (p *ProviderData) CookieForSession(s *SessionState, c *cookie.Cipher) (string, error) { func (p *ProviderData) CookieForSession(s *sessions.SessionState, c *cookie.Cipher) (string, error) {
return s.EncodeSessionState(c) return s.EncodeSessionState(c)
} }
// SessionFromCookie deserializes a session from a cookie value // SessionFromCookie deserializes a session from a cookie value
func (p *ProviderData) SessionFromCookie(v string, c *cookie.Cipher) (s *SessionState, err error) { func (p *ProviderData) SessionFromCookie(v string, c *cookie.Cipher) (s *sessions.SessionState, err error) {
return DecodeSessionState(v, c) return sessions.DecodeSessionState(v, c)
} }
// GetEmailAddress returns the Account email address // GetEmailAddress returns the Account email address
func (p *ProviderData) GetEmailAddress(s *SessionState) (string, error) { func (p *ProviderData) GetEmailAddress(s *sessions.SessionState) (string, error) {
return "", errors.New("not implemented") return "", errors.New("not implemented")
} }
// GetUserName returns the Account username // GetUserName returns the Account username
func (p *ProviderData) GetUserName(s *SessionState) (string, error) { func (p *ProviderData) GetUserName(s *sessions.SessionState) (string, error) {
return "", errors.New("not implemented") return "", errors.New("not implemented")
} }
@ -120,12 +121,12 @@ func (p *ProviderData) ValidateGroup(email string) bool {
} }
// ValidateSessionState validates the AccessToken // ValidateSessionState validates the AccessToken
func (p *ProviderData) ValidateSessionState(s *SessionState) bool { func (p *ProviderData) ValidateSessionState(s *sessions.SessionState) bool {
return validateToken(p, s.AccessToken, nil) return validateToken(p, s.AccessToken, nil)
} }
// RefreshSessionIfNeeded should refresh the user's session if required and // RefreshSessionIfNeeded should refresh the user's session if required and
// do nothing if a refresh is not required // do nothing if a refresh is not required
func (p *ProviderData) RefreshSessionIfNeeded(s *SessionState) (bool, error) { func (p *ProviderData) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) {
return false, nil return false, nil
} }

View File

@ -4,12 +4,13 @@ import (
"testing" "testing"
"time" "time"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestRefresh(t *testing.T) { func TestRefresh(t *testing.T) {
p := &ProviderData{} p := &ProviderData{}
refreshed, err := p.RefreshSessionIfNeeded(&SessionState{ refreshed, err := p.RefreshSessionIfNeeded(&sessions.SessionState{
ExpiresOn: time.Now().Add(time.Duration(-11) * time.Minute), ExpiresOn: time.Now().Add(time.Duration(-11) * time.Minute),
}) })
assert.Equal(t, false, refreshed) assert.Equal(t, false, refreshed)

View File

@ -2,20 +2,21 @@ package providers
import ( import (
"github.com/pusher/oauth2_proxy/cookie" "github.com/pusher/oauth2_proxy/cookie"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
) )
// Provider represents an upstream identity provider implementation // Provider represents an upstream identity provider implementation
type Provider interface { type Provider interface {
Data() *ProviderData Data() *ProviderData
GetEmailAddress(*SessionState) (string, error) GetEmailAddress(*sessions.SessionState) (string, error)
GetUserName(*SessionState) (string, error) GetUserName(*sessions.SessionState) (string, error)
Redeem(string, string) (*SessionState, error) Redeem(string, string) (*sessions.SessionState, error)
ValidateGroup(string) bool ValidateGroup(string) bool
ValidateSessionState(*SessionState) bool ValidateSessionState(*sessions.SessionState) bool
GetLoginURL(redirectURI, finalRedirect string) string GetLoginURL(redirectURI, finalRedirect string) string
RefreshSessionIfNeeded(*SessionState) (bool, error) RefreshSessionIfNeeded(*sessions.SessionState) (bool, error)
SessionFromCookie(string, *cookie.Cipher) (*SessionState, error) SessionFromCookie(string, *cookie.Cipher) (*sessions.SessionState, error)
CookieForSession(*SessionState, *cookie.Cipher) (string, error) CookieForSession(*sessions.SessionState, *cookie.Cipher) (string, error)
} }
// New provides a new Provider based on the configured provider string // New provides a new Provider based on the configured provider string