diff --git a/CHANGELOG.md b/CHANGELOG.md index f892f6b0..9da760bb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,9 @@ - [#111](https://github.com/pusher/oauth2_proxy/pull/111) Add option for telling where to find a login.gov JWT key file (@timothy-spencer) +- [#141](https://github.com/pusher/oauth2_proxy/pull/141) Check google group membership based on email address (@bchess) + - Google Group membership is additionally checked via email address, allowing users outside a GSuite domain to be authorized. + # v3.2.0 ## Release highlights diff --git a/providers/google.go b/providers/google.go index f031bfc0..7b8815f4 100644 --- a/providers/google.go +++ b/providers/google.go @@ -187,10 +187,8 @@ func userInGroup(service *admin.Service, groups []string, email string) bool { user, err := fetchUser(service, email) if err != nil { logger.Printf("error fetching user: %v", err) - return false + user = nil } - id := user.Id - custID := user.CustomerId for _, group := range groups { members, err := fetchGroupMembers(service, group) @@ -204,13 +202,19 @@ func userInGroup(service *admin.Service, groups []string, email string) bool { } for _, member := range members { + if member.Email == email { + return true + } + if user == nil { + continue + } switch member.Type { case "CUSTOMER": - if member.Id == custID { + if member.Id == user.CustomerId { return true } case "USER": - if member.Id == id { + if member.Id == user.Id { return true } } diff --git a/providers/google_test.go b/providers/google_test.go index 25b375a5..0c1725bf 100644 --- a/providers/google_test.go +++ b/providers/google_test.go @@ -3,12 +3,15 @@ package providers import ( "encoding/base64" "encoding/json" + "fmt" "net/http" "net/http/httptest" "net/url" "testing" "github.com/stretchr/testify/assert" + + admin "google.golang.org/api/admin/directory/v1" ) func newRedeemServer(body []byte) (*url.URL, *httptest.Server) { @@ -179,3 +182,37 @@ func TestGoogleProviderGetEmailAddressEmailMissing(t *testing.T) { } } + +func TestGoogleProviderUserInGroup(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/users/member-by-email@example.com" { + fmt.Fprintln(w, "{}") + } else if r.URL.Path == "/users/non-member-by-email@example.com" { + fmt.Fprintln(w, "{}") + } else if r.URL.Path == "/users/member-by-id@example.com" { + fmt.Fprintln(w, "{\"id\": \"member-id\"}") + } else if r.URL.Path == "/users/non-member-by-id@example.com" { + fmt.Fprintln(w, "{\"id\": \"non-member-id\"}") + } else if r.URL.Path == "/groups/group@example.com/members" { + fmt.Fprintln(w, "{\"members\": [{\"email\": \"member-by-email@example.com\"}, {\"id\": \"member-id\", \"type\": \"USER\"}]}") + } + })) + defer ts.Close() + + client := ts.Client() + service, err := admin.New(client) + service.BasePath = ts.URL + assert.Equal(t, nil, err) + + result := userInGroup(service, []string{"group@example.com"}, "member-by-email@example.com") + assert.True(t, result) + + result = userInGroup(service, []string{"group@example.com"}, "member-by-id@example.com") + assert.True(t, result) + + result = userInGroup(service, []string{"group@example.com"}, "non-member-by-id@example.com") + assert.False(t, result) + + result = userInGroup(service, []string{"group@example.com"}, "non-member-by-email@example.com") + assert.False(t, result) +}