mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2024-11-28 09:08:44 +02:00
Merge pull request #89 from 18F/watch-email-file
Reload authenticated-emails-file upon update
This commit is contained in:
commit
254b26d4a0
1
Godeps
1
Godeps
@ -2,3 +2,4 @@ github.com/BurntSushi/toml 3883ac1ce943878302255f538fce319d23226223
|
||||
github.com/bitly/go-simplejson 3378bdcb5cebedcbf8b5750edee28010f128fe24
|
||||
github.com/mreiferson/go-options ee94b57f2fbf116075426f853e5abbcdfeca8b3d
|
||||
github.com/bmizerany/assert e17e99893cb6509f428e1728281c2ad60a6b31e3
|
||||
gopkg.in/fsnotify.v1 v1.2.0
|
||||
|
@ -40,8 +40,8 @@ func TestRequestFailure(t *testing.T) {
|
||||
resp, err := Request(req)
|
||||
assert.Equal(t, (*simplejson.Json)(nil), resp)
|
||||
assert.NotEqual(t, nil, err)
|
||||
if !strings.HasSuffix(err.Error(), "connection refused") {
|
||||
t.Error("expected error when a connection fails")
|
||||
if !strings.Contains(err.Error(), "refused") {
|
||||
t.Error("expected error when a connection fails: ", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -21,6 +21,7 @@ func NewHtpasswdFromFile(path string) (*HtpasswdFile, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer r.Close()
|
||||
return NewHtpasswd(r)
|
||||
}
|
||||
|
||||
|
75
validator.go
75
validator.go
@ -6,43 +6,84 @@ import (
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
func NewValidator(domains []string, usersFile string) func(string) bool {
|
||||
validUsers := make(map[string]bool)
|
||||
type UserMap struct {
|
||||
usersFile string
|
||||
m unsafe.Pointer
|
||||
}
|
||||
|
||||
func NewUserMap(usersFile string, onUpdate func()) *UserMap {
|
||||
um := &UserMap{usersFile: usersFile}
|
||||
m := make(map[string]bool)
|
||||
atomic.StorePointer(&um.m, unsafe.Pointer(&m))
|
||||
if usersFile != "" {
|
||||
log.Printf("using authenticated emails file %s", usersFile)
|
||||
r, err := os.Open(usersFile)
|
||||
if err != nil {
|
||||
log.Fatalf("failed opening authenticated-emails-file=%q, %s", usersFile, err)
|
||||
}
|
||||
csv_reader := csv.NewReader(r)
|
||||
csv_reader.Comma = ','
|
||||
csv_reader.Comment = '#'
|
||||
csv_reader.TrimLeadingSpace = true
|
||||
records, err := csv_reader.ReadAll()
|
||||
for _, r := range records {
|
||||
validUsers[strings.ToLower(r[0])] = true
|
||||
started := WatchForUpdates(usersFile, func() {
|
||||
um.LoadAuthenticatedEmailsFile()
|
||||
onUpdate()
|
||||
})
|
||||
if started {
|
||||
log.Printf("watching %s for updates", usersFile)
|
||||
}
|
||||
um.LoadAuthenticatedEmailsFile()
|
||||
}
|
||||
return um
|
||||
}
|
||||
|
||||
func (um *UserMap) IsValid(email string) (result bool) {
|
||||
m := *(*map[string]bool)(atomic.LoadPointer(&um.m))
|
||||
_, result = m[email]
|
||||
return
|
||||
}
|
||||
|
||||
func (um *UserMap) LoadAuthenticatedEmailsFile() {
|
||||
r, err := os.Open(um.usersFile)
|
||||
if err != nil {
|
||||
log.Fatalf("failed opening authenticated-emails-file=%q, %s", um.usersFile, err)
|
||||
}
|
||||
defer r.Close()
|
||||
csv_reader := csv.NewReader(r)
|
||||
csv_reader.Comma = ','
|
||||
csv_reader.Comment = '#'
|
||||
csv_reader.TrimLeadingSpace = true
|
||||
records, err := csv_reader.ReadAll()
|
||||
if err != nil {
|
||||
log.Printf("error reading authenticated-emails-file=%q, %s", um.usersFile, err)
|
||||
return
|
||||
}
|
||||
updated := make(map[string]bool)
|
||||
for _, r := range records {
|
||||
updated[strings.ToLower(r[0])] = true
|
||||
}
|
||||
atomic.StorePointer(&um.m, unsafe.Pointer(&updated))
|
||||
}
|
||||
|
||||
func newValidatorImpl(domains []string, usersFile string,
|
||||
onUpdate func()) func(string) bool {
|
||||
validUsers := NewUserMap(usersFile, onUpdate)
|
||||
|
||||
for i, domain := range domains {
|
||||
domains[i] = strings.ToLower(domain)
|
||||
domains[i] = fmt.Sprintf("@%s", strings.ToLower(domain))
|
||||
}
|
||||
|
||||
validator := func(email string) bool {
|
||||
email = strings.ToLower(email)
|
||||
valid := false
|
||||
for _, domain := range domains {
|
||||
emailSuffix := fmt.Sprintf("@%s", domain)
|
||||
valid = valid || strings.HasSuffix(email, emailSuffix)
|
||||
valid = valid || strings.HasSuffix(email, domain)
|
||||
}
|
||||
if !valid {
|
||||
_, valid = validUsers[email]
|
||||
valid = validUsers.IsValid(email)
|
||||
}
|
||||
log.Printf("validating: is %s valid? %v", email, valid)
|
||||
return valid
|
||||
}
|
||||
return validator
|
||||
}
|
||||
|
||||
func NewValidator(domains []string, usersFile string) func(string) bool {
|
||||
return newValidatorImpl(domains, usersFile, func() {})
|
||||
}
|
||||
|
@ -7,23 +7,117 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestValidatorComparisonsAreCaseInsensitive(t *testing.T) {
|
||||
auth_email_file, err := ioutil.TempFile("", "test_auth_emails_")
|
||||
type ValidatorTest struct {
|
||||
auth_email_file *os.File
|
||||
}
|
||||
|
||||
func NewValidatorTest(t *testing.T) *ValidatorTest {
|
||||
vt := &ValidatorTest{}
|
||||
var err error
|
||||
vt.auth_email_file, err = ioutil.TempFile("", "test_auth_emails_")
|
||||
if err != nil {
|
||||
t.Fatal("failed to create temp file: " + err.Error())
|
||||
}
|
||||
defer os.Remove(auth_email_file.Name())
|
||||
return vt
|
||||
}
|
||||
|
||||
auth_email_file.WriteString(
|
||||
strings.Join([]string{"Foo.Bar@Example.Com"}, "\n"))
|
||||
err = auth_email_file.Close()
|
||||
if err != nil {
|
||||
t.Fatal("failed to close temp file " + auth_email_file.Name() +
|
||||
": " + err.Error())
|
||||
func (vt *ValidatorTest) TearDown() {
|
||||
os.Remove(vt.auth_email_file.Name())
|
||||
}
|
||||
|
||||
// This will close vt.auth_email_file.
|
||||
func (vt *ValidatorTest) WriteEmails(t *testing.T, emails []string) {
|
||||
defer vt.auth_email_file.Close()
|
||||
vt.auth_email_file.WriteString(strings.Join(emails, "\n"))
|
||||
if err := vt.auth_email_file.Close(); err != nil {
|
||||
t.Fatal("failed to close temp file " +
|
||||
vt.auth_email_file.Name() + ": " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidatorEmpty(t *testing.T) {
|
||||
vt := NewValidatorTest(t)
|
||||
defer vt.TearDown()
|
||||
|
||||
vt.WriteEmails(t, []string(nil))
|
||||
domains := []string(nil)
|
||||
validator := NewValidator(domains, vt.auth_email_file.Name())
|
||||
|
||||
if validator("foo.bar@example.com") {
|
||||
t.Error("nothing should validate when the email and " +
|
||||
"domain lists are empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidatorSingleEmail(t *testing.T) {
|
||||
vt := NewValidatorTest(t)
|
||||
defer vt.TearDown()
|
||||
|
||||
vt.WriteEmails(t, []string{"foo.bar@example.com"})
|
||||
domains := []string(nil)
|
||||
validator := NewValidator(domains, vt.auth_email_file.Name())
|
||||
|
||||
if !validator("foo.bar@example.com") {
|
||||
t.Error("email should validate")
|
||||
}
|
||||
if validator("baz.quux@example.com") {
|
||||
t.Error("email from same domain but not in list " +
|
||||
"should not validate when domain list is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidatorSingleDomain(t *testing.T) {
|
||||
vt := NewValidatorTest(t)
|
||||
defer vt.TearDown()
|
||||
|
||||
vt.WriteEmails(t, []string(nil))
|
||||
domains := []string{"example.com"}
|
||||
validator := NewValidator(domains, vt.auth_email_file.Name())
|
||||
|
||||
if !validator("foo.bar@example.com") {
|
||||
t.Error("email should validate")
|
||||
}
|
||||
if !validator("baz.quux@example.com") {
|
||||
t.Error("email from same domain should validate")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidatorMultipleEmailsMultipleDomains(t *testing.T) {
|
||||
vt := NewValidatorTest(t)
|
||||
defer vt.TearDown()
|
||||
|
||||
vt.WriteEmails(t, []string{
|
||||
"xyzzy@example.com",
|
||||
"plugh@example.com",
|
||||
})
|
||||
domains := []string{"example0.com", "example1.com"}
|
||||
validator := NewValidator(domains, vt.auth_email_file.Name())
|
||||
|
||||
if !validator("foo.bar@example0.com") {
|
||||
t.Error("email from first domain should validate")
|
||||
}
|
||||
if !validator("baz.quux@example1.com") {
|
||||
t.Error("email from second domain should validate")
|
||||
}
|
||||
if !validator("xyzzy@example.com") {
|
||||
t.Error("first email in list should validate")
|
||||
}
|
||||
if !validator("plugh@example.com") {
|
||||
t.Error("second email in list should validate")
|
||||
}
|
||||
if validator("xyzzy.plugh@example.com") {
|
||||
t.Error("email not in list that matches no domains " +
|
||||
"should not validate")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidatorComparisonsAreCaseInsensitive(t *testing.T) {
|
||||
vt := NewValidatorTest(t)
|
||||
defer vt.TearDown()
|
||||
|
||||
vt.WriteEmails(t, []string{"Foo.Bar@Example.Com"})
|
||||
domains := []string{"Frobozz.Com"}
|
||||
validator := NewValidator(domains, auth_email_file.Name())
|
||||
validator := NewValidator(domains, vt.auth_email_file.Name())
|
||||
|
||||
if !validator("foo.bar@example.com") {
|
||||
t.Error("loaded email addresses are not lower-cased")
|
||||
|
50
validator_watcher_copy_test.go
Normal file
50
validator_watcher_copy_test.go
Normal file
@ -0,0 +1,50 @@
|
||||
// +build go1.3
|
||||
// +build !plan9,!solaris,!windows
|
||||
|
||||
// Turns out you can't copy over an existing file on Windows.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func (vt *ValidatorTest) UpdateEmailFileViaCopyingOver(
|
||||
t *testing.T, emails []string) {
|
||||
orig_file := vt.auth_email_file
|
||||
var err error
|
||||
vt.auth_email_file, err = ioutil.TempFile("", "test_auth_emails_")
|
||||
if err != nil {
|
||||
t.Fatal("failed to create temp file for copy: " + err.Error())
|
||||
}
|
||||
vt.WriteEmails(t, emails)
|
||||
err = os.Rename(vt.auth_email_file.Name(), orig_file.Name())
|
||||
if err != nil {
|
||||
t.Fatal("failed to copy over temp file: " + err.Error())
|
||||
}
|
||||
vt.auth_email_file = orig_file
|
||||
}
|
||||
|
||||
func TestValidatorOverwriteEmailListViaCopyingOver(t *testing.T) {
|
||||
vt := NewValidatorTest(t)
|
||||
defer vt.TearDown()
|
||||
|
||||
vt.WriteEmails(t, []string{"xyzzy@example.com"})
|
||||
domains := []string(nil)
|
||||
updated := make(chan bool)
|
||||
validator := newValidatorImpl(domains, vt.auth_email_file.Name(),
|
||||
func() { updated <- true })
|
||||
|
||||
if !validator("xyzzy@example.com") {
|
||||
t.Error("email in list should validate")
|
||||
}
|
||||
|
||||
vt.UpdateEmailFileViaCopyingOver(t, []string{"plugh@example.com"})
|
||||
<-updated
|
||||
|
||||
if validator("xyzzy@example.com") {
|
||||
t.Error("email removed from list should not validate")
|
||||
}
|
||||
}
|
105
validator_watcher_test.go
Normal file
105
validator_watcher_test.go
Normal file
@ -0,0 +1,105 @@
|
||||
// +build go1.3
|
||||
// +build !plan9,!solaris
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func (vt *ValidatorTest) UpdateEmailFile(t *testing.T, emails []string) {
|
||||
var err error
|
||||
vt.auth_email_file, err = os.OpenFile(
|
||||
vt.auth_email_file.Name(), os.O_WRONLY|os.O_CREATE, 0600)
|
||||
if err != nil {
|
||||
t.Fatal("failed to re-open temp file for updates")
|
||||
}
|
||||
vt.WriteEmails(t, emails)
|
||||
}
|
||||
|
||||
func (vt *ValidatorTest) UpdateEmailFileViaRenameAndReplace(
|
||||
t *testing.T, emails []string) {
|
||||
orig_file := vt.auth_email_file
|
||||
var err error
|
||||
vt.auth_email_file, err = ioutil.TempFile("", "test_auth_emails_")
|
||||
if err != nil {
|
||||
t.Fatal("failed to create temp file for rename and replace: " +
|
||||
err.Error())
|
||||
}
|
||||
vt.WriteEmails(t, emails)
|
||||
|
||||
moved_name := orig_file.Name() + "-moved"
|
||||
err = os.Rename(orig_file.Name(), moved_name)
|
||||
err = os.Rename(vt.auth_email_file.Name(), orig_file.Name())
|
||||
if err != nil {
|
||||
t.Fatal("failed to rename and replace temp file: " +
|
||||
err.Error())
|
||||
}
|
||||
vt.auth_email_file = orig_file
|
||||
os.Remove(moved_name)
|
||||
}
|
||||
|
||||
func TestValidatorOverwriteEmailListDirectly(t *testing.T) {
|
||||
vt := NewValidatorTest(t)
|
||||
defer vt.TearDown()
|
||||
|
||||
vt.WriteEmails(t, []string{
|
||||
"xyzzy@example.com",
|
||||
"plugh@example.com",
|
||||
})
|
||||
domains := []string(nil)
|
||||
updated := make(chan bool)
|
||||
validator := newValidatorImpl(domains, vt.auth_email_file.Name(),
|
||||
func() { updated <- true })
|
||||
|
||||
if !validator("xyzzy@example.com") {
|
||||
t.Error("first email in list should validate")
|
||||
}
|
||||
if !validator("plugh@example.com") {
|
||||
t.Error("second email in list should validate")
|
||||
}
|
||||
if validator("xyzzy.plugh@example.com") {
|
||||
t.Error("email not in list that matches no domains " +
|
||||
"should not validate")
|
||||
}
|
||||
|
||||
vt.UpdateEmailFile(t, []string{
|
||||
"xyzzy.plugh@example.com",
|
||||
"plugh@example.com",
|
||||
})
|
||||
<-updated
|
||||
|
||||
if validator("xyzzy@example.com") {
|
||||
t.Error("email removed from list should not validate")
|
||||
}
|
||||
if !validator("plugh@example.com") {
|
||||
t.Error("email retained in list should validate")
|
||||
}
|
||||
if !validator("xyzzy.plugh@example.com") {
|
||||
t.Error("email added to list should validate")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidatorOverwriteEmailListViaRenameAndReplace(t *testing.T) {
|
||||
vt := NewValidatorTest(t)
|
||||
defer vt.TearDown()
|
||||
|
||||
vt.WriteEmails(t, []string{"xyzzy@example.com"})
|
||||
domains := []string(nil)
|
||||
updated := make(chan bool)
|
||||
validator := newValidatorImpl(domains, vt.auth_email_file.Name(),
|
||||
func() { updated <- true })
|
||||
|
||||
if !validator("xyzzy@example.com") {
|
||||
t.Error("email in list should validate")
|
||||
}
|
||||
|
||||
vt.UpdateEmailFileViaRenameAndReplace(t, []string{"plugh@example.com"})
|
||||
<-updated
|
||||
|
||||
if validator("xyzzy@example.com") {
|
||||
t.Error("email removed from list should not validate")
|
||||
}
|
||||
}
|
64
watcher.go
Normal file
64
watcher.go
Normal file
@ -0,0 +1,64 @@
|
||||
// +build go1.3
|
||||
// +build !plan9,!solaris
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"gopkg.in/fsnotify.v1"
|
||||
)
|
||||
|
||||
func WaitForReplacement(event fsnotify.Event, watcher *fsnotify.Watcher) {
|
||||
const sleep_interval = 50 * time.Millisecond
|
||||
|
||||
// Avoid a race when fsnofity.Remove is preceded by fsnotify.Chmod.
|
||||
if event.Op&fsnotify.Chmod != 0 {
|
||||
time.Sleep(sleep_interval)
|
||||
}
|
||||
for {
|
||||
if _, err := os.Stat(event.Name); err == nil {
|
||||
if err := watcher.Add(event.Name); err == nil {
|
||||
log.Printf("watching resumed for %s", event.Name)
|
||||
return
|
||||
}
|
||||
}
|
||||
time.Sleep(sleep_interval)
|
||||
}
|
||||
}
|
||||
|
||||
func WatchForUpdates(filename string, action func()) bool {
|
||||
filename = filepath.Clean(filename)
|
||||
watcher, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
log.Fatal("failed to create watcher for ", filename, ": ", err)
|
||||
}
|
||||
go func() {
|
||||
defer watcher.Close()
|
||||
for {
|
||||
select {
|
||||
case event := <-watcher.Events:
|
||||
// On Arch Linux, it appears Chmod events precede Remove events,
|
||||
// which causes a race between action() and the coming Remove event.
|
||||
// If the Remove wins, the action() (which calls
|
||||
// UserMap.LoadAuthenticatedEmailsFile()) crashes when the file
|
||||
// can't be opened.
|
||||
if event.Op&(fsnotify.Remove|fsnotify.Rename|fsnotify.Chmod) != 0 {
|
||||
log.Printf("watching interrupted on event: %s", event)
|
||||
WaitForReplacement(event, watcher)
|
||||
}
|
||||
log.Printf("reloading after event: %s", event)
|
||||
action()
|
||||
case err := <-watcher.Errors:
|
||||
log.Printf("error watching %s: %s", filename, err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
if err = watcher.Add(filename); err != nil {
|
||||
log.Fatal("failed to add ", filename, " to watcher: ", err)
|
||||
}
|
||||
return true
|
||||
}
|
13
watcher_unsupported.go
Normal file
13
watcher_unsupported.go
Normal file
@ -0,0 +1,13 @@
|
||||
// +build go1.1
|
||||
// +build plan9,solaris
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
)
|
||||
|
||||
func WatchForUpdates(filename string, action func()) bool {
|
||||
log.Printf("file watching not implemented on this platform")
|
||||
return false
|
||||
}
|
Loading…
Reference in New Issue
Block a user