You've already forked oauth2-proxy
mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2025-06-15 00:15:00 +02:00
Integrate new provider constructor in main
This commit is contained in:
@ -114,6 +114,11 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
provider, err := providers.NewProvider(opts.Providers[0])
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error intiailising provider: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
pageWriter, err := pagewriter.NewWriter(pagewriter.Opts{
|
pageWriter, err := pagewriter.NewWriter(pagewriter.Opts{
|
||||||
TemplatesPath: opts.Templates.Path,
|
TemplatesPath: opts.Templates.Path,
|
||||||
CustomLogo: opts.Templates.CustomLogo,
|
CustomLogo: opts.Templates.CustomLogo,
|
||||||
@ -121,7 +126,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
|
|||||||
Footer: opts.Templates.Footer,
|
Footer: opts.Templates.Footer,
|
||||||
Version: VERSION,
|
Version: VERSION,
|
||||||
Debug: opts.Templates.Debug,
|
Debug: opts.Templates.Debug,
|
||||||
ProviderName: buildProviderName(opts.GetProvider(), opts.Providers[0].Name),
|
ProviderName: buildProviderName(provider, opts.Providers[0].Name),
|
||||||
SignInMessage: buildSignInMessage(opts),
|
SignInMessage: buildSignInMessage(opts),
|
||||||
DisplayLoginForm: basicAuthValidator != nil && opts.Templates.DisplayLoginForm,
|
DisplayLoginForm: basicAuthValidator != nil && opts.Templates.DisplayLoginForm,
|
||||||
})
|
})
|
||||||
@ -145,7 +150,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
|
|||||||
redirectURL.Path = fmt.Sprintf("%s/callback", opts.ProxyPrefix)
|
redirectURL.Path = fmt.Sprintf("%s/callback", opts.ProxyPrefix)
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Printf("OAuthProxy configured for %s Client ID: %s", opts.GetProvider().Data().ProviderName, opts.Providers[0].ClientID)
|
logger.Printf("OAuthProxy configured for %s Client ID: %s", provider.Data().ProviderName, opts.Providers[0].ClientID)
|
||||||
refresh := "disabled"
|
refresh := "disabled"
|
||||||
if opts.Cookie.Refresh != time.Duration(0) {
|
if opts.Cookie.Refresh != time.Duration(0) {
|
||||||
refresh = fmt.Sprintf("after %s", opts.Cookie.Refresh)
|
refresh = fmt.Sprintf("after %s", opts.Cookie.Refresh)
|
||||||
@ -171,7 +176,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("could not build pre-auth chain: %v", err)
|
return nil, fmt.Errorf("could not build pre-auth chain: %v", err)
|
||||||
}
|
}
|
||||||
sessionChain := buildSessionChain(opts, sessionStore, basicAuthValidator)
|
sessionChain := buildSessionChain(opts, provider, sessionStore, basicAuthValidator)
|
||||||
headersChain, err := buildHeadersChain(opts)
|
headersChain, err := buildHeadersChain(opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("could not build headers chain: %v", err)
|
return nil, fmt.Errorf("could not build headers chain: %v", err)
|
||||||
@ -190,7 +195,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
|
|||||||
SignInPath: fmt.Sprintf("%s/sign_in", opts.ProxyPrefix),
|
SignInPath: fmt.Sprintf("%s/sign_in", opts.ProxyPrefix),
|
||||||
|
|
||||||
ProxyPrefix: opts.ProxyPrefix,
|
ProxyPrefix: opts.ProxyPrefix,
|
||||||
provider: opts.GetProvider(),
|
provider: provider,
|
||||||
sessionStore: sessionStore,
|
sessionStore: sessionStore,
|
||||||
redirectURL: redirectURL,
|
redirectURL: redirectURL,
|
||||||
allowedRoutes: allowedRoutes,
|
allowedRoutes: allowedRoutes,
|
||||||
@ -346,12 +351,12 @@ func buildPreAuthChain(opts *options.Options) (alice.Chain, error) {
|
|||||||
return chain, nil
|
return chain, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildSessionChain(opts *options.Options, sessionStore sessionsapi.SessionStore, validator basic.Validator) alice.Chain {
|
func buildSessionChain(opts *options.Options, provider providers.Provider, sessionStore sessionsapi.SessionStore, validator basic.Validator) alice.Chain {
|
||||||
chain := alice.New()
|
chain := alice.New()
|
||||||
|
|
||||||
if opts.SkipJwtBearerTokens {
|
if opts.SkipJwtBearerTokens {
|
||||||
sessionLoaders := []middlewareapi.TokenToSessionFunc{
|
sessionLoaders := []middlewareapi.TokenToSessionFunc{
|
||||||
opts.GetProvider().CreateSessionFromToken,
|
provider.CreateSessionFromToken,
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, verifier := range opts.GetJWTBearerVerifiers() {
|
for _, verifier := range opts.GetJWTBearerVerifiers() {
|
||||||
@ -369,8 +374,8 @@ func buildSessionChain(opts *options.Options, sessionStore sessionsapi.SessionSt
|
|||||||
chain = chain.Append(middleware.NewStoredSessionLoader(&middleware.StoredSessionLoaderOptions{
|
chain = chain.Append(middleware.NewStoredSessionLoader(&middleware.StoredSessionLoaderOptions{
|
||||||
SessionStore: sessionStore,
|
SessionStore: sessionStore,
|
||||||
RefreshPeriod: opts.Cookie.Refresh,
|
RefreshPeriod: opts.Cookie.Refresh,
|
||||||
RefreshSession: opts.GetProvider().RefreshSession,
|
RefreshSession: provider.RefreshSession,
|
||||||
ValidateSession: opts.GetProvider().ValidateSession,
|
ValidateSession: provider.ValidateSession,
|
||||||
}))
|
}))
|
||||||
|
|
||||||
return chain
|
return chain
|
||||||
|
@ -161,13 +161,11 @@ func Test_enrichSession(t *testing.T) {
|
|||||||
err := validation.Validate(opts)
|
err := validation.Validate(opts)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
// intentionally set after validation.Validate(opts) since it will clobber
|
|
||||||
// our TestProvider and call `providers.New` defaulting to `providers.GoogleProvider`
|
|
||||||
opts.SetProvider(NewTestProvider(&url.URL{Host: "www.example.com"}, providerEmail))
|
|
||||||
proxy, err := NewOAuthProxy(opts, func(string) bool { return true })
|
proxy, err := NewOAuthProxy(opts, func(string) bool { return true })
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
proxy.provider = NewTestProvider(&url.URL{Host: "www.example.com"}, providerEmail)
|
||||||
|
|
||||||
err = proxy.enrichSessionState(context.Background(), tc.session)
|
err = proxy.enrichSessionState(context.Background(), tc.session)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
@ -232,13 +230,13 @@ func TestBasicAuthPassword(t *testing.T) {
|
|||||||
providerURL, _ := url.Parse(providerServer.URL)
|
providerURL, _ := url.Parse(providerServer.URL)
|
||||||
const emailAddress = "john.doe@example.com"
|
const emailAddress = "john.doe@example.com"
|
||||||
|
|
||||||
opts.SetProvider(NewTestProvider(providerURL, emailAddress))
|
|
||||||
proxy, err := NewOAuthProxy(opts, func(email string) bool {
|
proxy, err := NewOAuthProxy(opts, func(email string) bool {
|
||||||
return email == emailAddress
|
return email == emailAddress
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
proxy.provider = NewTestProvider(providerURL, emailAddress)
|
||||||
|
|
||||||
// Save the required session
|
// Save the required session
|
||||||
rw := httptest.NewRecorder()
|
rw := httptest.NewRecorder()
|
||||||
@ -390,10 +388,10 @@ func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) (*PassAccessTokenTe
|
|||||||
|
|
||||||
testProvider := NewTestProvider(providerURL, emailAddress)
|
testProvider := NewTestProvider(providerURL, emailAddress)
|
||||||
testProvider.ValidToken = opts.ValidToken
|
testProvider.ValidToken = opts.ValidToken
|
||||||
patt.opts.SetProvider(testProvider)
|
|
||||||
patt.proxy, err = NewOAuthProxy(patt.opts, func(email string) bool {
|
patt.proxy, err = NewOAuthProxy(patt.opts, func(email string) bool {
|
||||||
return email == emailAddress
|
return email == emailAddress
|
||||||
})
|
})
|
||||||
|
patt.proxy.provider = testProvider
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -769,11 +767,17 @@ func NewProcessCookieTest(opts ProcessCookieTestOpts, modifiers ...OptionsModifi
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
pcTest.proxy.provider = &TestProvider{
|
testProvider := &TestProvider{
|
||||||
ProviderData: &providers.ProviderData{},
|
ProviderData: &providers.ProviderData{},
|
||||||
ValidToken: opts.providerValidateCookieResponse,
|
ValidToken: opts.providerValidateCookieResponse,
|
||||||
}
|
}
|
||||||
pcTest.proxy.provider.(*TestProvider).SetAllowedGroups(pcTest.opts.Providers[0].AllowedGroups)
|
|
||||||
|
groups := pcTest.opts.Providers[0].AllowedGroups
|
||||||
|
testProvider.AllowedGroups = make(map[string]struct{}, len(groups))
|
||||||
|
for _, group := range groups {
|
||||||
|
testProvider.AllowedGroups[group] = struct{}{}
|
||||||
|
}
|
||||||
|
pcTest.proxy.provider = testProvider
|
||||||
|
|
||||||
// Now, zero-out proxy.CookieRefresh for the cases that don't involve
|
// Now, zero-out proxy.CookieRefresh for the cases that don't involve
|
||||||
// access_token validation.
|
// access_token validation.
|
||||||
@ -1359,12 +1363,12 @@ func TestAuthSkippedForPreflightRequests(t *testing.T) {
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
upstreamURL, _ := url.Parse(upstreamServer.URL)
|
upstreamURL, _ := url.Parse(upstreamServer.URL)
|
||||||
opts.SetProvider(NewTestProvider(upstreamURL, ""))
|
|
||||||
|
|
||||||
proxy, err := NewOAuthProxy(opts, func(string) bool { return false })
|
proxy, err := NewOAuthProxy(opts, func(string) bool { return false })
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
proxy.provider = NewTestProvider(upstreamURL, "")
|
||||||
rw := httptest.NewRecorder()
|
rw := httptest.NewRecorder()
|
||||||
req, _ := http.NewRequest("OPTIONS", "/preflight-request", nil)
|
req, _ := http.NewRequest("OPTIONS", "/preflight-request", nil)
|
||||||
proxy.ServeHTTP(rw, req)
|
proxy.ServeHTTP(rw, req)
|
||||||
@ -1409,6 +1413,7 @@ type SignatureTest struct {
|
|||||||
header http.Header
|
header http.Header
|
||||||
rw *httptest.ResponseRecorder
|
rw *httptest.ResponseRecorder
|
||||||
authenticator *SignatureAuthenticator
|
authenticator *SignatureAuthenticator
|
||||||
|
authProvider providers.Provider
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSignatureTest() (*SignatureTest, error) {
|
func NewSignatureTest() (*SignatureTest, error) {
|
||||||
@ -1443,7 +1448,7 @@ func NewSignatureTest() (*SignatureTest, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
opts.SetProvider(NewTestProvider(providerURL, "mbland@acm.org"))
|
testProvider := NewTestProvider(providerURL, "mbland@acm.org")
|
||||||
|
|
||||||
return &SignatureTest{
|
return &SignatureTest{
|
||||||
opts,
|
opts,
|
||||||
@ -1453,6 +1458,7 @@ func NewSignatureTest() (*SignatureTest, error) {
|
|||||||
make(http.Header),
|
make(http.Header),
|
||||||
httptest.NewRecorder(),
|
httptest.NewRecorder(),
|
||||||
authenticator,
|
authenticator,
|
||||||
|
testProvider,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1486,6 +1492,7 @@ func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) er
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
proxy.provider = st.authProvider
|
||||||
|
|
||||||
var bodyBuf io.ReadCloser
|
var bodyBuf io.ReadCloser
|
||||||
if body != "" {
|
if body != "" {
|
||||||
|
Reference in New Issue
Block a user