diff --git a/CHANGELOG.md b/CHANGELOG.md index 2f16ce2f..ace35074 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ ## Changes since v6.0.0 +- [#562](https://github.com/oauth2-proxy/oauth2-proxy/pull/562) Create generic Authorization Header constructor (@JoelSpeed) - [#715](https://github.com/oauth2-proxy/oauth2-proxy/pull/715) Ensure session times are not nil before printing them (@JoelSpeed) - [#714](https://github.com/oauth2-proxy/oauth2-proxy/pull/714) Support passwords with Redis session stores (@NickMeves) - [#719](https://github.com/oauth2-proxy/oauth2-proxy/pull/719) Add Gosec fixes to areas that are intermittently flagged on PRs (@NickMeves) diff --git a/providers/azure.go b/providers/azure.go index 12ef13d0..0ae0cba6 100644 --- a/providers/azure.go +++ b/providers/azure.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "errors" - "fmt" "net/http" "net/url" "time" @@ -154,10 +153,8 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (s return } -func getAzureHeader(accessToken string) http.Header { - header := make(http.Header) - header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) - return header +func makeAzureHeader(accessToken string) http.Header { + return makeAuthorizationHeader(tokenTypeBearer, accessToken, nil) } func getEmailFromJSON(json *simplejson.Json) (string, error) { @@ -188,7 +185,7 @@ func (p *AzureProvider) GetEmailAddress(ctx context.Context, s *sessions.Session json, err := requests.New(p.ProfileURL.String()). WithContext(ctx). - WithHeaders(getAzureHeader(s.AccessToken)). + WithHeaders(makeAzureHeader(s.AccessToken)). Do(). UnmarshalJSON() if err != nil { diff --git a/providers/digitalocean.go b/providers/digitalocean.go index a5314892..c88533e8 100644 --- a/providers/digitalocean.go +++ b/providers/digitalocean.go @@ -3,8 +3,6 @@ package providers import ( "context" "errors" - "fmt" - "net/http" "net/url" "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" @@ -62,13 +60,6 @@ func NewDigitalOceanProvider(p *ProviderData) *DigitalOceanProvider { return &DigitalOceanProvider{ProviderData: p} } -func getDigitalOceanHeader(accessToken string) http.Header { - header := make(http.Header) - header.Set("Content-Type", "application/json") - header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) - return header -} - // GetEmailAddress returns the Account email address func (p *DigitalOceanProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) { if s.AccessToken == "" { @@ -77,7 +68,7 @@ func (p *DigitalOceanProvider) GetEmailAddress(ctx context.Context, s *sessions. json, err := requests.New(p.ProfileURL.String()). WithContext(ctx). - WithHeaders(getDigitalOceanHeader(s.AccessToken)). + WithHeaders(makeOIDCHeader(s.AccessToken)). Do(). UnmarshalJSON() if err != nil { @@ -93,5 +84,5 @@ func (p *DigitalOceanProvider) GetEmailAddress(ctx context.Context, s *sessions. // ValidateSessionState validates the AccessToken func (p *DigitalOceanProvider) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool { - return validateToken(ctx, p, s.AccessToken, getDigitalOceanHeader(s.AccessToken)) + return validateToken(ctx, p, s.AccessToken, makeOIDCHeader(s.AccessToken)) } diff --git a/providers/facebook.go b/providers/facebook.go index 00a5b55b..7bbc0b45 100644 --- a/providers/facebook.go +++ b/providers/facebook.go @@ -3,8 +3,6 @@ package providers import ( "context" "errors" - "fmt" - "net/http" "net/url" "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" @@ -63,14 +61,6 @@ func NewFacebookProvider(p *ProviderData) *FacebookProvider { return &FacebookProvider{ProviderData: p} } -func getFacebookHeader(accessToken string) http.Header { - header := make(http.Header) - header.Set("Accept", "application/json") - header.Set("x-li-format", "json") - header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) - return header -} - // GetEmailAddress returns the Account email address func (p *FacebookProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) { if s.AccessToken == "" { @@ -85,7 +75,7 @@ func (p *FacebookProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess requestURL := p.ProfileURL.String() + "?fields=name,email" err := requests.New(requestURL). WithContext(ctx). - WithHeaders(getFacebookHeader(s.AccessToken)). + WithHeaders(makeOIDCHeader(s.AccessToken)). Do(). UnmarshalInto(&r) if err != nil { @@ -100,5 +90,5 @@ func (p *FacebookProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess // ValidateSessionState validates the AccessToken func (p *FacebookProvider) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool { - return validateToken(ctx, p, s.AccessToken, getFacebookHeader(s.AccessToken)) + return validateToken(ctx, p, s.AccessToken, makeOIDCHeader(s.AccessToken)) } diff --git a/providers/github.go b/providers/github.go index 7d08ac0e..4004a881 100644 --- a/providers/github.go +++ b/providers/github.go @@ -74,11 +74,12 @@ func NewGitHubProvider(p *ProviderData) *GitHubProvider { return &GitHubProvider{ProviderData: p} } -func getGitHubHeader(accessToken string) http.Header { - header := make(http.Header) - header.Set("Accept", "application/vnd.github.v3+json") - header.Set("Authorization", fmt.Sprintf("token %s", accessToken)) - return header +func makeGitHubHeader(accessToken string) http.Header { + // extra headers required by the GitHub API when making authenticated requests + extraHeaders := map[string]string{ + acceptHeader: "application/vnd.github.v3+json", + } + return makeAuthorizationHeader(tokenTypeToken, accessToken, extraHeaders) } // SetOrgTeam adds GitHub org reading parameters to the OAuth2 scope @@ -129,7 +130,7 @@ func (p *GitHubProvider) hasOrg(ctx context.Context, accessToken string) (bool, var op orgsPage err := requests.New(endpoint.String()). WithContext(ctx). - WithHeaders(getGitHubHeader(accessToken)). + WithHeaders(makeGitHubHeader(accessToken)). Do(). UnmarshalInto(&op) if err != nil { @@ -196,7 +197,7 @@ func (p *GitHubProvider) hasOrgAndTeam(ctx context.Context, accessToken string) // nolint:bodyclose result := requests.New(endpoint.String()). WithContext(ctx). - WithHeaders(getGitHubHeader(accessToken)). + WithHeaders(makeGitHubHeader(accessToken)). Do() if result.Error() != nil { return false, result.Error() @@ -296,7 +297,7 @@ func (p *GitHubProvider) hasRepo(ctx context.Context, accessToken string) (bool, var repo repository err := requests.New(endpoint.String()). WithContext(ctx). - WithHeaders(getGitHubHeader(accessToken)). + WithHeaders(makeGitHubHeader(accessToken)). Do(). UnmarshalInto(&repo) if err != nil { @@ -324,7 +325,7 @@ func (p *GitHubProvider) hasUser(ctx context.Context, accessToken string) (bool, err := requests.New(endpoint.String()). WithContext(ctx). - WithHeaders(getGitHubHeader(accessToken)). + WithHeaders(makeGitHubHeader(accessToken)). Do(). UnmarshalInto(&user) if err != nil { @@ -347,7 +348,7 @@ func (p *GitHubProvider) isCollaborator(ctx context.Context, username, accessTok } result := requests.New(endpoint.String()). WithContext(ctx). - WithHeaders(getGitHubHeader(accessToken)). + WithHeaders(makeGitHubHeader(accessToken)). Do() if result.Error() != nil { return false, result.Error() @@ -411,7 +412,7 @@ func (p *GitHubProvider) GetEmailAddress(ctx context.Context, s *sessions.Sessio } err := requests.New(endpoint.String()). WithContext(ctx). - WithHeaders(getGitHubHeader(s.AccessToken)). + WithHeaders(makeGitHubHeader(s.AccessToken)). Do(). UnmarshalInto(&emails) if err != nil { @@ -446,7 +447,7 @@ func (p *GitHubProvider) GetUserName(ctx context.Context, s *sessions.SessionSta err := requests.New(endpoint.String()). WithContext(ctx). - WithHeaders(getGitHubHeader(s.AccessToken)). + WithHeaders(makeGitHubHeader(s.AccessToken)). Do(). UnmarshalInto(&user) if err != nil { @@ -465,7 +466,7 @@ func (p *GitHubProvider) GetUserName(ctx context.Context, s *sessions.SessionSta // ValidateSessionState validates the AccessToken func (p *GitHubProvider) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool { - return validateToken(ctx, p, s.AccessToken, getGitHubHeader(s.AccessToken)) + return validateToken(ctx, p, s.AccessToken, makeGitHubHeader(s.AccessToken)) } // isVerifiedUser diff --git a/providers/linkedin.go b/providers/linkedin.go index 67e015e6..99613e43 100644 --- a/providers/linkedin.go +++ b/providers/linkedin.go @@ -3,7 +3,6 @@ package providers import ( "context" "errors" - "fmt" "net/http" "net/url" @@ -62,12 +61,13 @@ func NewLinkedInProvider(p *ProviderData) *LinkedInProvider { return &LinkedInProvider{ProviderData: p} } -func getLinkedInHeader(accessToken string) http.Header { - header := make(http.Header) - header.Set("Accept", "application/json") - header.Set("x-li-format", "json") - header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) - return header +func makeLinkedInHeader(accessToken string) http.Header { + // extra headers required by the LinkedIn API when making authenticated requests + extraHeaders := map[string]string{ + acceptHeader: acceptApplicationJSON, + "x-li-format": "json", + } + return makeAuthorizationHeader(tokenTypeBearer, accessToken, extraHeaders) } // GetEmailAddress returns the Account email address @@ -79,7 +79,7 @@ func (p *LinkedInProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess requestURL := p.ProfileURL.String() + "?format=json" json, err := requests.New(requestURL). WithContext(ctx). - WithHeaders(getLinkedInHeader(s.AccessToken)). + WithHeaders(makeLinkedInHeader(s.AccessToken)). Do(). UnmarshalJSON() if err != nil { @@ -95,5 +95,5 @@ func (p *LinkedInProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess // ValidateSessionState validates the AccessToken func (p *LinkedInProvider) ValidateSessionState(ctx context.Context, s *sessions.SessionState) bool { - return validateToken(ctx, p, s.AccessToken, getLinkedInHeader(s.AccessToken)) + return validateToken(ctx, p, s.AccessToken, makeLinkedInHeader(s.AccessToken)) } diff --git a/providers/nextcloud.go b/providers/nextcloud.go index 2fea1fbe..a7498073 100644 --- a/providers/nextcloud.go +++ b/providers/nextcloud.go @@ -3,7 +3,6 @@ package providers import ( "context" "fmt" - "net/http" "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" "github.com/oauth2-proxy/oauth2-proxy/pkg/requests" @@ -24,17 +23,11 @@ func NewNextcloudProvider(p *ProviderData) *NextcloudProvider { return &NextcloudProvider{ProviderData: p} } -func getNextcloudHeader(accessToken string) http.Header { - header := make(http.Header) - header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) - return header -} - // GetEmailAddress returns the Account email address func (p *NextcloudProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) { json, err := requests.New(p.ValidateURL.String()). WithContext(ctx). - WithHeaders(getNextcloudHeader(s.AccessToken)). + WithHeaders(makeOIDCHeader(s.AccessToken)). Do(). UnmarshalJSON() if err != nil { diff --git a/providers/oidc.go b/providers/oidc.go index 8f2f02b9..b14e0b61 100644 --- a/providers/oidc.go +++ b/providers/oidc.go @@ -3,7 +3,6 @@ package providers import ( "context" "fmt" - "net/http" "strings" "time" @@ -221,13 +220,6 @@ func (p *OIDCProvider) ValidateSessionState(ctx context.Context, s *sessions.Ses return err == nil } -func getOIDCHeader(accessToken string) http.Header { - header := make(http.Header) - header.Set("Accept", "application/json") - header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) - return header -} - func (p *OIDCProvider) findClaimsFromIDToken(ctx context.Context, idToken *oidc.IDToken, token *oauth2.Token) (*OIDCClaims, error) { claims := &OIDCClaims{} // Extract default claims. @@ -263,7 +255,7 @@ func (p *OIDCProvider) findClaimsFromIDToken(ctx context.Context, idToken *oidc. // Make a query to the userinfo endpoint, and attempt to locate the email from there. respJSON, err := requests.New(profileURL). WithContext(ctx). - WithHeaders(getOIDCHeader(token.AccessToken)). + WithHeaders(makeOIDCHeader(token.AccessToken)). Do(). UnmarshalJSON() if err != nil { diff --git a/providers/util.go b/providers/util.go new file mode 100644 index 00000000..374f637e --- /dev/null +++ b/providers/util.go @@ -0,0 +1,31 @@ +package providers + +import ( + "fmt" + "net/http" +) + +const ( + tokenTypeBearer = "Bearer" + tokenTypeToken = "token" + + acceptHeader = "Accept" + acceptApplicationJSON = "application/json" +) + +func makeAuthorizationHeader(prefix, token string, extraHeaders map[string]string) http.Header { + header := make(http.Header) + for key, value := range extraHeaders { + header.Add(key, value) + } + header.Set("Authorization", fmt.Sprintf("%s %s", prefix, token)) + return header +} + +func makeOIDCHeader(accessToken string) http.Header { + // extra headers required by the IDP when making authenticated requests + extraHeaders := map[string]string{ + acceptHeader: acceptApplicationJSON, + } + return makeAuthorizationHeader(tokenTypeBearer, accessToken, extraHeaders) +} diff --git a/providers/util_test.go b/providers/util_test.go new file mode 100644 index 00000000..798df6cb --- /dev/null +++ b/providers/util_test.go @@ -0,0 +1,66 @@ +package providers + +import ( + "fmt" + "testing" + + . "github.com/onsi/gomega" +) + +func TestMakeAuhtorizationHeader(t *testing.T) { + testCases := []struct { + name string + prefix string + token string + extraHeaders map[string]string + }{ + { + name: "With an empty prefix, token and no additional headers", + prefix: "", + token: "", + extraHeaders: nil, + }, + { + name: "With a Bearer token type", + prefix: tokenTypeBearer, + token: "abcdef", + extraHeaders: nil, + }, + { + name: "With a Token token type", + prefix: tokenTypeToken, + token: "123456", + extraHeaders: nil, + }, + { + name: "With a Bearer token type and Accept application/json", + prefix: tokenTypeToken, + token: "abc", + extraHeaders: map[string]string{ + acceptHeader: acceptApplicationJSON, + }, + }, + { + name: "With a Bearer token type and multiple headers", + prefix: tokenTypeToken, + token: "123", + extraHeaders: map[string]string{ + acceptHeader: acceptApplicationJSON, + "foo": "bar", + "key": "value", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + g := NewWithT(t) + + header := makeAuthorizationHeader(tc.prefix, tc.token, tc.extraHeaders) + g.Expect(header.Get("Authorization")).To(Equal(fmt.Sprintf("%s %s", tc.prefix, tc.token))) + for k, v := range tc.extraHeaders { + g.Expect(header.Get(k)).To(Equal(v)) + } + }) + } +}