diff --git a/main.go b/main.go index 7cf57771..a163c9ba 100644 --- a/main.go +++ b/main.go @@ -23,12 +23,13 @@ var ( htpasswdFile = flag.String("htpasswd-file", "", "additionally authenticate against a htpasswd file. Entries must be created with \"htpasswd -s\" for SHA encryption") cookieSecret = flag.String("cookie-secret", "", "the seed string for secure cookies") cookieDomain = flag.String("cookie-domain", "", "an optional cookie domain to force cookies to") - googleAppsDomain = flag.String("google-apps-domain", "", "authenticate against the given google apps domain") authenticatedEmailsFile = flag.String("authenticated-emails-file", "", "authenticate against emails via file (one per line)") + googleAppsDomains = StringArray{} upstreams = StringArray{} ) func init() { + flag.Var(&googleAppsDomains, "google-apps-domain", "authenticate against the given google apps domain (may be given multiple times)") flag.Var(&upstreams, "upstream", "the http url(s) of the upstream endpoint. If multiple, routing is based on path") } @@ -78,11 +79,11 @@ func main() { log.Fatalf("error parsing --redirect-url %s", err.Error()) } - validator := NewValidator(*googleAppsDomain, *authenticatedEmailsFile) + validator := NewValidator(googleAppsDomains, *authenticatedEmailsFile) oauthproxy := NewOauthProxy(upstreamUrls, *clientID, *clientSecret, validator) oauthproxy.SetRedirectUrl(redirectUrl) - if *googleAppsDomain != "" && *authenticatedEmailsFile == "" { - oauthproxy.SignInMessage = fmt.Sprintf("using a %s email address", *googleAppsDomain) + if len(googleAppsDomains) != 0 && *authenticatedEmailsFile == "" { + oauthproxy.SignInMessage = fmt.Sprintf("using a email address from the following domains: %v", strings.Join(googleAppsDomains, ", ")) } if *htpasswdFile != "" { oauthproxy.HtpasswdFile, err = NewHtpasswdFromFile(*htpasswdFile) diff --git a/validator.go b/validator.go index bf07ea5b..4caa5ec4 100644 --- a/validator.go +++ b/validator.go @@ -8,13 +8,8 @@ import ( "strings" ) -func NewValidator(domain string, usersFile string) func(string) bool { - +func NewValidator(domains []string, usersFile string) func(string) bool { validUsers := make(map[string]bool) - emailSuffix := "" - if domain != "" { - emailSuffix = fmt.Sprintf("@%s", domain) - } if usersFile != "" { r, err := os.Open(usersFile) @@ -32,9 +27,10 @@ func NewValidator(domain string, usersFile string) func(string) bool { } validator := func(email string) bool { - var valid bool - if emailSuffix != "" { - valid = strings.HasSuffix(email, emailSuffix) + valid := false + for _, domain := range domains { + emailSuffix := fmt.Sprintf("@%s", domain) + valid = valid || strings.HasSuffix(email, emailSuffix) } if !valid { _, valid = validUsers[email]