1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2025-06-15 00:15:00 +02:00

Integrate new provider constructor in main

This commit is contained in:
Joel Speed
2022-02-15 12:00:06 +00:00
parent 2e15f57b70
commit 0791aef8cc
2 changed files with 29 additions and 17 deletions

View File

@ -114,6 +114,11 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
} }
} }
provider, err := providers.NewProvider(opts.Providers[0])
if err != nil {
return nil, fmt.Errorf("error intiailising provider: %v", err)
}
pageWriter, err := pagewriter.NewWriter(pagewriter.Opts{ pageWriter, err := pagewriter.NewWriter(pagewriter.Opts{
TemplatesPath: opts.Templates.Path, TemplatesPath: opts.Templates.Path,
CustomLogo: opts.Templates.CustomLogo, CustomLogo: opts.Templates.CustomLogo,
@ -121,7 +126,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
Footer: opts.Templates.Footer, Footer: opts.Templates.Footer,
Version: VERSION, Version: VERSION,
Debug: opts.Templates.Debug, Debug: opts.Templates.Debug,
ProviderName: buildProviderName(opts.GetProvider(), opts.Providers[0].Name), ProviderName: buildProviderName(provider, opts.Providers[0].Name),
SignInMessage: buildSignInMessage(opts), SignInMessage: buildSignInMessage(opts),
DisplayLoginForm: basicAuthValidator != nil && opts.Templates.DisplayLoginForm, DisplayLoginForm: basicAuthValidator != nil && opts.Templates.DisplayLoginForm,
}) })
@ -145,7 +150,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
redirectURL.Path = fmt.Sprintf("%s/callback", opts.ProxyPrefix) redirectURL.Path = fmt.Sprintf("%s/callback", opts.ProxyPrefix)
} }
logger.Printf("OAuthProxy configured for %s Client ID: %s", opts.GetProvider().Data().ProviderName, opts.Providers[0].ClientID) logger.Printf("OAuthProxy configured for %s Client ID: %s", provider.Data().ProviderName, opts.Providers[0].ClientID)
refresh := "disabled" refresh := "disabled"
if opts.Cookie.Refresh != time.Duration(0) { if opts.Cookie.Refresh != time.Duration(0) {
refresh = fmt.Sprintf("after %s", opts.Cookie.Refresh) refresh = fmt.Sprintf("after %s", opts.Cookie.Refresh)
@ -171,7 +176,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
if err != nil { if err != nil {
return nil, fmt.Errorf("could not build pre-auth chain: %v", err) return nil, fmt.Errorf("could not build pre-auth chain: %v", err)
} }
sessionChain := buildSessionChain(opts, sessionStore, basicAuthValidator) sessionChain := buildSessionChain(opts, provider, sessionStore, basicAuthValidator)
headersChain, err := buildHeadersChain(opts) headersChain, err := buildHeadersChain(opts)
if err != nil { if err != nil {
return nil, fmt.Errorf("could not build headers chain: %v", err) return nil, fmt.Errorf("could not build headers chain: %v", err)
@ -190,7 +195,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
SignInPath: fmt.Sprintf("%s/sign_in", opts.ProxyPrefix), SignInPath: fmt.Sprintf("%s/sign_in", opts.ProxyPrefix),
ProxyPrefix: opts.ProxyPrefix, ProxyPrefix: opts.ProxyPrefix,
provider: opts.GetProvider(), provider: provider,
sessionStore: sessionStore, sessionStore: sessionStore,
redirectURL: redirectURL, redirectURL: redirectURL,
allowedRoutes: allowedRoutes, allowedRoutes: allowedRoutes,
@ -346,12 +351,12 @@ func buildPreAuthChain(opts *options.Options) (alice.Chain, error) {
return chain, nil return chain, nil
} }
func buildSessionChain(opts *options.Options, sessionStore sessionsapi.SessionStore, validator basic.Validator) alice.Chain { func buildSessionChain(opts *options.Options, provider providers.Provider, sessionStore sessionsapi.SessionStore, validator basic.Validator) alice.Chain {
chain := alice.New() chain := alice.New()
if opts.SkipJwtBearerTokens { if opts.SkipJwtBearerTokens {
sessionLoaders := []middlewareapi.TokenToSessionFunc{ sessionLoaders := []middlewareapi.TokenToSessionFunc{
opts.GetProvider().CreateSessionFromToken, provider.CreateSessionFromToken,
} }
for _, verifier := range opts.GetJWTBearerVerifiers() { for _, verifier := range opts.GetJWTBearerVerifiers() {
@ -369,8 +374,8 @@ func buildSessionChain(opts *options.Options, sessionStore sessionsapi.SessionSt
chain = chain.Append(middleware.NewStoredSessionLoader(&middleware.StoredSessionLoaderOptions{ chain = chain.Append(middleware.NewStoredSessionLoader(&middleware.StoredSessionLoaderOptions{
SessionStore: sessionStore, SessionStore: sessionStore,
RefreshPeriod: opts.Cookie.Refresh, RefreshPeriod: opts.Cookie.Refresh,
RefreshSession: opts.GetProvider().RefreshSession, RefreshSession: provider.RefreshSession,
ValidateSession: opts.GetProvider().ValidateSession, ValidateSession: provider.ValidateSession,
})) }))
return chain return chain

View File

@ -161,13 +161,11 @@ func Test_enrichSession(t *testing.T) {
err := validation.Validate(opts) err := validation.Validate(opts)
assert.NoError(t, err) assert.NoError(t, err)
// intentionally set after validation.Validate(opts) since it will clobber
// our TestProvider and call `providers.New` defaulting to `providers.GoogleProvider`
opts.SetProvider(NewTestProvider(&url.URL{Host: "www.example.com"}, providerEmail))
proxy, err := NewOAuthProxy(opts, func(string) bool { return true }) proxy, err := NewOAuthProxy(opts, func(string) bool { return true })
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
proxy.provider = NewTestProvider(&url.URL{Host: "www.example.com"}, providerEmail)
err = proxy.enrichSessionState(context.Background(), tc.session) err = proxy.enrichSessionState(context.Background(), tc.session)
assert.NoError(t, err) assert.NoError(t, err)
@ -232,13 +230,13 @@ func TestBasicAuthPassword(t *testing.T) {
providerURL, _ := url.Parse(providerServer.URL) providerURL, _ := url.Parse(providerServer.URL)
const emailAddress = "john.doe@example.com" const emailAddress = "john.doe@example.com"
opts.SetProvider(NewTestProvider(providerURL, emailAddress))
proxy, err := NewOAuthProxy(opts, func(email string) bool { proxy, err := NewOAuthProxy(opts, func(email string) bool {
return email == emailAddress return email == emailAddress
}) })
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
proxy.provider = NewTestProvider(providerURL, emailAddress)
// Save the required session // Save the required session
rw := httptest.NewRecorder() rw := httptest.NewRecorder()
@ -390,10 +388,10 @@ func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) (*PassAccessTokenTe
testProvider := NewTestProvider(providerURL, emailAddress) testProvider := NewTestProvider(providerURL, emailAddress)
testProvider.ValidToken = opts.ValidToken testProvider.ValidToken = opts.ValidToken
patt.opts.SetProvider(testProvider)
patt.proxy, err = NewOAuthProxy(patt.opts, func(email string) bool { patt.proxy, err = NewOAuthProxy(patt.opts, func(email string) bool {
return email == emailAddress return email == emailAddress
}) })
patt.proxy.provider = testProvider
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -769,11 +767,17 @@ func NewProcessCookieTest(opts ProcessCookieTestOpts, modifiers ...OptionsModifi
if err != nil { if err != nil {
return nil, err return nil, err
} }
pcTest.proxy.provider = &TestProvider{ testProvider := &TestProvider{
ProviderData: &providers.ProviderData{}, ProviderData: &providers.ProviderData{},
ValidToken: opts.providerValidateCookieResponse, ValidToken: opts.providerValidateCookieResponse,
} }
pcTest.proxy.provider.(*TestProvider).SetAllowedGroups(pcTest.opts.Providers[0].AllowedGroups)
groups := pcTest.opts.Providers[0].AllowedGroups
testProvider.AllowedGroups = make(map[string]struct{}, len(groups))
for _, group := range groups {
testProvider.AllowedGroups[group] = struct{}{}
}
pcTest.proxy.provider = testProvider
// Now, zero-out proxy.CookieRefresh for the cases that don't involve // Now, zero-out proxy.CookieRefresh for the cases that don't involve
// access_token validation. // access_token validation.
@ -1359,12 +1363,12 @@ func TestAuthSkippedForPreflightRequests(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
upstreamURL, _ := url.Parse(upstreamServer.URL) upstreamURL, _ := url.Parse(upstreamServer.URL)
opts.SetProvider(NewTestProvider(upstreamURL, ""))
proxy, err := NewOAuthProxy(opts, func(string) bool { return false }) proxy, err := NewOAuthProxy(opts, func(string) bool { return false })
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
proxy.provider = NewTestProvider(upstreamURL, "")
rw := httptest.NewRecorder() rw := httptest.NewRecorder()
req, _ := http.NewRequest("OPTIONS", "/preflight-request", nil) req, _ := http.NewRequest("OPTIONS", "/preflight-request", nil)
proxy.ServeHTTP(rw, req) proxy.ServeHTTP(rw, req)
@ -1409,6 +1413,7 @@ type SignatureTest struct {
header http.Header header http.Header
rw *httptest.ResponseRecorder rw *httptest.ResponseRecorder
authenticator *SignatureAuthenticator authenticator *SignatureAuthenticator
authProvider providers.Provider
} }
func NewSignatureTest() (*SignatureTest, error) { func NewSignatureTest() (*SignatureTest, error) {
@ -1443,7 +1448,7 @@ func NewSignatureTest() (*SignatureTest, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
opts.SetProvider(NewTestProvider(providerURL, "mbland@acm.org")) testProvider := NewTestProvider(providerURL, "mbland@acm.org")
return &SignatureTest{ return &SignatureTest{
opts, opts,
@ -1453,6 +1458,7 @@ func NewSignatureTest() (*SignatureTest, error) {
make(http.Header), make(http.Header),
httptest.NewRecorder(), httptest.NewRecorder(),
authenticator, authenticator,
testProvider,
}, nil }, nil
} }
@ -1486,6 +1492,7 @@ func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) er
if err != nil { if err != nil {
return err return err
} }
proxy.provider = st.authProvider
var bodyBuf io.ReadCloser var bodyBuf io.ReadCloser
if body != "" { if body != "" {