diff --git a/pkg/apis/options/options.go b/pkg/apis/options/options.go index 46cdcedb..c0f91422 100644 --- a/pkg/apis/options/options.go +++ b/pkg/apis/options/options.go @@ -36,7 +36,7 @@ type Options struct { TLSKeyFile string `flag:"tls-key-file" cfg:"tls_key_file"` AuthenticatedEmailsFile string `flag:"authenticated-emails-file" cfg:"authenticated_emails_file"` - KeycloakGroup string `flag:"keycloak-group" cfg:"keycloak_group"` + KeycloakGroups []string `flag:"keycloak-group" cfg:"keycloak_groups"` AzureTenant string `flag:"azure-tenant" cfg:"azure_tenant"` BitbucketTeam string `flag:"bitbucket-team" cfg:"bitbucket_team"` BitbucketRepository string `flag:"bitbucket-repository" cfg:"bitbucket_repository"` @@ -181,7 +181,7 @@ func NewFlagSet() *pflag.FlagSet { flagSet.StringSlice("email-domain", []string{}, "authenticate emails with the specified domain (may be given multiple times). Use * to authenticate any email") flagSet.StringSlice("whitelist-domain", []string{}, "allowed domains for redirection after authentication. Prefix domain with a . to allow subdomains (eg .example.com)") - flagSet.String("keycloak-group", "", "restrict login to members of this group.") + flagSet.StringSlice("keycloak-group", []string{}, "restrict logins to members of these groups (may be given multiple times)") flagSet.String("azure-tenant", "common", "go to a tenant-specific or common (tenant-independent) endpoint.") flagSet.String("bitbucket-team", "", "restrict logins to members of this team") flagSet.String("bitbucket-repository", "", "restrict logins to user with access to this repository") diff --git a/pkg/validation/options.go b/pkg/validation/options.go index 4fc0b0a4..52b0fb69 100644 --- a/pkg/validation/options.go +++ b/pkg/validation/options.go @@ -263,7 +263,10 @@ func parseProviderInfo(o *options.Options, msgs []string) []string { p.SetRepo(o.GitHubRepo, o.GitHubToken) p.SetUsers(o.GitHubUsers) case *providers.KeycloakProvider: - p.SetGroup(o.KeycloakGroup) + // Backwards compatibility with `--keycloak-group` option + if len(o.KeycloakGroups) > 0 { + p.SetAllowedGroups(o.KeycloakGroups) + } case *providers.GoogleProvider: if o.GoogleServiceAccountJSON != "" { file, err := os.Open(o.GoogleServiceAccountJSON) diff --git a/providers/keycloak.go b/providers/keycloak.go index 60b3eaca..f6c74880 100644 --- a/providers/keycloak.go +++ b/providers/keycloak.go @@ -2,6 +2,7 @@ package providers import ( "context" + "fmt" "net/url" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" @@ -11,7 +12,6 @@ import ( type KeycloakProvider struct { *ProviderData - Group string } var _ Provider = (*KeycloakProvider)(nil) @@ -59,41 +59,33 @@ func NewKeycloakProvider(p *ProviderData) *KeycloakProvider { return &KeycloakProvider{ProviderData: p} } -func (p *KeycloakProvider) SetGroup(group string) { - p.Group = group -} - -func (p *KeycloakProvider) GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) { +func (p *KeycloakProvider) EnrichSession(ctx context.Context, s *sessions.SessionState) error { json, err := requests.New(p.ValidateURL.String()). WithContext(ctx). SetHeader("Authorization", "Bearer "+s.AccessToken). Do(). UnmarshalJSON() if err != nil { - logger.Errorf("failed making request %s", err) - return "", err + logger.Errorf("failed making request %v", err) + return err } - if p.Group != "" { - var groups, err = json.Get("groups").Array() - if err != nil { - logger.Printf("groups not found %s", err) - return "", err - } - - var found = false - for i := range groups { - if groups[i].(string) == p.Group { - found = true - break + groups, err := json.Get("groups").StringArray() + if err != nil { + logger.Errorf("Warning: unable to extract groups from userinfo endpoint: %v", err) + } else { + for _, group := range groups { + if group != "" { + s.Groups = append(s.Groups, group) } } - - if !found { - logger.Printf("group not found, access denied") - return "", nil - } } - return json.Get("email").String() + email, err := json.Get("email").String() + if err != nil { + return fmt.Errorf("unable to extract email from userinfo endpoint: %v", err) + } + s.Email = email + + return nil } diff --git a/providers/keycloak_test.go b/providers/keycloak_test.go index 3f419f2e..d7ae4391 100644 --- a/providers/keycloak_test.go +++ b/providers/keycloak_test.go @@ -2,17 +2,33 @@ package providers import ( "context" + "errors" + "fmt" "net/http" "net/http/httptest" "net/url" "testing" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" + . "github.com/onsi/ginkgo" + . "github.com/onsi/ginkgo/extensions/table" . "github.com/onsi/gomega" "github.com/stretchr/testify/assert" ) -func testKeycloakProvider(hostname, group string) *KeycloakProvider { +const ( + keycloakAccessToken = "eyJKeycloak.eyJAccess.Token" + keycloakUserinfoPath = "/api/v3/user" + + // Userinfo Test Cases + tcUIStandard = "userinfo-standard" + tcUIFail = "userinfo-fail" + tcUISingleGroup = "userinfo-single-group" + tcUIMissingEmail = "userinfo-missing-email" + tcUIMissingGroups = "userinfo-missing-groups" +) + +func testKeycloakProvider(backend *httptest.Server) (*KeycloakProvider, error) { p := NewKeycloakProvider( &ProviderData{ ProviderName: "", @@ -22,38 +38,165 @@ func testKeycloakProvider(hostname, group string) *KeycloakProvider { ValidateURL: &url.URL{}, Scope: ""}) - if group != "" { - p.SetGroup(group) - } + if backend != nil { + bURL, err := url.Parse(backend.URL) + if err != nil { + return nil, err + } + hostname := bURL.Host - if hostname != "" { updateURL(p.Data().LoginURL, hostname) updateURL(p.Data().RedeemURL, hostname) updateURL(p.Data().ProfileURL, hostname) updateURL(p.Data().ValidateURL, hostname) } - return p + + return p, nil } -func testKeycloakBackend(payload string) *httptest.Server { - path := "/api/v3/user" - +func testKeycloakBackend() *httptest.Server { return httptest.NewServer(http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { - url := r.URL - if url.Path != path { + if r.URL.Path != keycloakUserinfoPath { w.WriteHeader(404) - } else if !IsAuthorizedInHeader(r.Header) { - w.WriteHeader(403) - } else { + } + + var err error + switch r.URL.Query().Get("testcase") { + case tcUIStandard: w.WriteHeader(200) - w.Write([]byte(payload)) + _, err = w.Write([]byte(` + { + "email": "michael.bland@gsa.gov", + "groups": [ + "test-grp1", + "test-grp2" + ] + } + `)) + case tcUIFail: + w.WriteHeader(500) + case tcUISingleGroup: + w.WriteHeader(200) + _, err = w.Write([]byte(` + { + "email": "michael.bland@gsa.gov", + "groups": ["test-grp1"] + } + `)) + case tcUIMissingEmail: + w.WriteHeader(200) + _, err = w.Write([]byte(` + { + "groups": [ + "test-grp1", + "test-grp2" + ] + } + `)) + case tcUIMissingGroups: + w.WriteHeader(200) + _, err = w.Write([]byte(` + { + "email": "michael.bland@gsa.gov" + } + `)) + default: + w.WriteHeader(404) + } + if err != nil { + panic(err) } })) } +var _ = Describe("Keycloak Provider Tests", func() { + var p *KeycloakProvider + var b *httptest.Server + + BeforeEach(func() { + b = testKeycloakBackend() + + var err error + p, err = testKeycloakProvider(b) + Expect(err).To(BeNil()) + }) + + AfterEach(func() { + b.Close() + }) + + Context("EnrichSession", func() { + type enrichSessionTableInput struct { + testcase string + expectedError error + expectedEmail string + expectedGroups []string + } + + DescribeTable("should return expected results", + func(in enrichSessionTableInput) { + var err error + p.ValidateURL, err = url.Parse( + fmt.Sprintf("%s%s?testcase=%s", b.URL, keycloakUserinfoPath, in.testcase), + ) + Expect(err).To(BeNil()) + + session := &sessions.SessionState{AccessToken: keycloakAccessToken} + err = p.EnrichSession(context.Background(), session) + + if in.expectedError != nil { + Expect(err).To(Equal(in.expectedError)) + } else { + Expect(err).To(BeNil()) + } + + Expect(session.Email).To(Equal(in.expectedEmail)) + + if in.expectedGroups != nil { + Expect(session.Groups).To(Equal(in.expectedGroups)) + } else { + Expect(session.Groups).To(BeNil()) + } + }, + Entry("email and multiple groups", enrichSessionTableInput{ + testcase: tcUIStandard, + expectedError: nil, + expectedEmail: "michael.bland@gsa.gov", + expectedGroups: []string{"test-grp1", "test-grp2"}, + }), + Entry("email and single group", enrichSessionTableInput{ + testcase: tcUISingleGroup, + expectedError: nil, + expectedEmail: "michael.bland@gsa.gov", + expectedGroups: []string{"test-grp1"}, + }), + Entry("email and no groups", enrichSessionTableInput{ + testcase: tcUIMissingGroups, + expectedError: nil, + expectedEmail: "michael.bland@gsa.gov", + expectedGroups: nil, + }), + Entry("missing email", enrichSessionTableInput{ + testcase: tcUIMissingEmail, + expectedError: errors.New( + "unable to extract email from userinfo endpoint: type assertion to string failed"), + expectedEmail: "", + expectedGroups: []string{"test-grp1", "test-grp2"}, + }), + Entry("request failure", enrichSessionTableInput{ + testcase: tcUIFail, + expectedError: errors.New(`unexpected status "500": `), + expectedEmail: "", + expectedGroups: nil, + }), + ) + }) +}) + func TestKeycloakProviderDefaults(t *testing.T) { - p := testKeycloakProvider("", "") + p, err := testKeycloakProvider(nil) + assert.NoError(t, err) assert.NotEqual(t, nil, p) assert.Equal(t, "Keycloak", p.Data().ProviderName) assert.Equal(t, "https://keycloak.org/oauth/authorize", @@ -104,60 +247,3 @@ func TestKeycloakProviderOverrides(t *testing.T) { p.Data().ValidateURL.String()) assert.Equal(t, "profile", p.Data().Scope) } - -func TestKeycloakProviderGetEmailAddress(t *testing.T) { - b := testKeycloakBackend("{\"email\": \"michael.bland@gsa.gov\"}") - defer b.Close() - - bURL, _ := url.Parse(b.URL) - p := testKeycloakProvider(bURL.Host, "") - - session := CreateAuthorizedSession() - email, err := p.GetEmailAddress(context.Background(), session) - assert.Equal(t, nil, err) - assert.Equal(t, "michael.bland@gsa.gov", email) -} - -func TestKeycloakProviderGetEmailAddressAndGroup(t *testing.T) { - b := testKeycloakBackend("{\"email\": \"michael.bland@gsa.gov\", \"groups\": [\"test-grp1\", \"test-grp2\"]}") - defer b.Close() - - bURL, _ := url.Parse(b.URL) - p := testKeycloakProvider(bURL.Host, "test-grp1") - - session := CreateAuthorizedSession() - email, err := p.GetEmailAddress(context.Background(), session) - assert.Equal(t, nil, err) - assert.Equal(t, "michael.bland@gsa.gov", email) -} - -// Note that trying to trigger the "failed building request" case is not -// practical, since the only way it can fail is if the URL fails to parse. -func TestKeycloakProviderGetEmailAddressFailedRequest(t *testing.T) { - b := testKeycloakBackend("unused payload") - defer b.Close() - - bURL, _ := url.Parse(b.URL) - p := testKeycloakProvider(bURL.Host, "") - - // We'll trigger a request failure by using an unexpected access - // token. Alternatively, we could allow the parsing of the payload as - // JSON to fail. - session := &sessions.SessionState{AccessToken: "unexpected_access_token"} - email, err := p.GetEmailAddress(context.Background(), session) - assert.NotEqual(t, nil, err) - assert.Equal(t, "", email) -} - -func TestKeycloakProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) { - b := testKeycloakBackend("{\"foo\": \"bar\"}") - defer b.Close() - - bURL, _ := url.Parse(b.URL) - p := testKeycloakProvider(bURL.Host, "") - - session := CreateAuthorizedSession() - email, err := p.GetEmailAddress(context.Background(), session) - assert.NotEqual(t, nil, err) - assert.Equal(t, "", email) -}