1
0
mirror of https://github.com/oauth2-proxy/oauth2-proxy.git synced 2024-11-30 09:16:52 +02:00

Merge pull request #89 from 18F/watch-email-file

Reload authenticated-emails-file upon update
This commit is contained in:
Jehiah Czebotar 2015-05-12 11:08:38 -04:00
commit 254b26d4a0
9 changed files with 398 additions and 29 deletions

1
Godeps
View File

@ -2,3 +2,4 @@ github.com/BurntSushi/toml 3883ac1ce943878302255f538fce319d23226223
github.com/bitly/go-simplejson 3378bdcb5cebedcbf8b5750edee28010f128fe24 github.com/bitly/go-simplejson 3378bdcb5cebedcbf8b5750edee28010f128fe24
github.com/mreiferson/go-options ee94b57f2fbf116075426f853e5abbcdfeca8b3d github.com/mreiferson/go-options ee94b57f2fbf116075426f853e5abbcdfeca8b3d
github.com/bmizerany/assert e17e99893cb6509f428e1728281c2ad60a6b31e3 github.com/bmizerany/assert e17e99893cb6509f428e1728281c2ad60a6b31e3
gopkg.in/fsnotify.v1 v1.2.0

View File

@ -40,8 +40,8 @@ func TestRequestFailure(t *testing.T) {
resp, err := Request(req) resp, err := Request(req)
assert.Equal(t, (*simplejson.Json)(nil), resp) assert.Equal(t, (*simplejson.Json)(nil), resp)
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
if !strings.HasSuffix(err.Error(), "connection refused") { if !strings.Contains(err.Error(), "refused") {
t.Error("expected error when a connection fails") t.Error("expected error when a connection fails: ", err)
} }
} }

View File

@ -21,6 +21,7 @@ func NewHtpasswdFromFile(path string) (*HtpasswdFile, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer r.Close()
return NewHtpasswd(r) return NewHtpasswd(r)
} }

View File

@ -6,43 +6,84 @@ import (
"log" "log"
"os" "os"
"strings" "strings"
"sync/atomic"
"unsafe"
) )
func NewValidator(domains []string, usersFile string) func(string) bool { type UserMap struct {
validUsers := make(map[string]bool) usersFile string
m unsafe.Pointer
}
func NewUserMap(usersFile string, onUpdate func()) *UserMap {
um := &UserMap{usersFile: usersFile}
m := make(map[string]bool)
atomic.StorePointer(&um.m, unsafe.Pointer(&m))
if usersFile != "" { if usersFile != "" {
log.Printf("using authenticated emails file %s", usersFile) log.Printf("using authenticated emails file %s", usersFile)
r, err := os.Open(usersFile) started := WatchForUpdates(usersFile, func() {
if err != nil { um.LoadAuthenticatedEmailsFile()
log.Fatalf("failed opening authenticated-emails-file=%q, %s", usersFile, err) onUpdate()
} })
csv_reader := csv.NewReader(r) if started {
csv_reader.Comma = ',' log.Printf("watching %s for updates", usersFile)
csv_reader.Comment = '#'
csv_reader.TrimLeadingSpace = true
records, err := csv_reader.ReadAll()
for _, r := range records {
validUsers[strings.ToLower(r[0])] = true
} }
um.LoadAuthenticatedEmailsFile()
} }
return um
}
func (um *UserMap) IsValid(email string) (result bool) {
m := *(*map[string]bool)(atomic.LoadPointer(&um.m))
_, result = m[email]
return
}
func (um *UserMap) LoadAuthenticatedEmailsFile() {
r, err := os.Open(um.usersFile)
if err != nil {
log.Fatalf("failed opening authenticated-emails-file=%q, %s", um.usersFile, err)
}
defer r.Close()
csv_reader := csv.NewReader(r)
csv_reader.Comma = ','
csv_reader.Comment = '#'
csv_reader.TrimLeadingSpace = true
records, err := csv_reader.ReadAll()
if err != nil {
log.Printf("error reading authenticated-emails-file=%q, %s", um.usersFile, err)
return
}
updated := make(map[string]bool)
for _, r := range records {
updated[strings.ToLower(r[0])] = true
}
atomic.StorePointer(&um.m, unsafe.Pointer(&updated))
}
func newValidatorImpl(domains []string, usersFile string,
onUpdate func()) func(string) bool {
validUsers := NewUserMap(usersFile, onUpdate)
for i, domain := range domains { for i, domain := range domains {
domains[i] = strings.ToLower(domain) domains[i] = fmt.Sprintf("@%s", strings.ToLower(domain))
} }
validator := func(email string) bool { validator := func(email string) bool {
email = strings.ToLower(email) email = strings.ToLower(email)
valid := false valid := false
for _, domain := range domains { for _, domain := range domains {
emailSuffix := fmt.Sprintf("@%s", domain) valid = valid || strings.HasSuffix(email, domain)
valid = valid || strings.HasSuffix(email, emailSuffix)
} }
if !valid { if !valid {
_, valid = validUsers[email] valid = validUsers.IsValid(email)
} }
log.Printf("validating: is %s valid? %v", email, valid) log.Printf("validating: is %s valid? %v", email, valid)
return valid return valid
} }
return validator return validator
} }
func NewValidator(domains []string, usersFile string) func(string) bool {
return newValidatorImpl(domains, usersFile, func() {})
}

View File

@ -7,23 +7,117 @@ import (
"testing" "testing"
) )
func TestValidatorComparisonsAreCaseInsensitive(t *testing.T) { type ValidatorTest struct {
auth_email_file, err := ioutil.TempFile("", "test_auth_emails_") auth_email_file *os.File
}
func NewValidatorTest(t *testing.T) *ValidatorTest {
vt := &ValidatorTest{}
var err error
vt.auth_email_file, err = ioutil.TempFile("", "test_auth_emails_")
if err != nil { if err != nil {
t.Fatal("failed to create temp file: " + err.Error()) t.Fatal("failed to create temp file: " + err.Error())
} }
defer os.Remove(auth_email_file.Name()) return vt
}
auth_email_file.WriteString( func (vt *ValidatorTest) TearDown() {
strings.Join([]string{"Foo.Bar@Example.Com"}, "\n")) os.Remove(vt.auth_email_file.Name())
err = auth_email_file.Close() }
if err != nil {
t.Fatal("failed to close temp file " + auth_email_file.Name() + // This will close vt.auth_email_file.
": " + err.Error()) func (vt *ValidatorTest) WriteEmails(t *testing.T, emails []string) {
defer vt.auth_email_file.Close()
vt.auth_email_file.WriteString(strings.Join(emails, "\n"))
if err := vt.auth_email_file.Close(); err != nil {
t.Fatal("failed to close temp file " +
vt.auth_email_file.Name() + ": " + err.Error())
} }
}
func TestValidatorEmpty(t *testing.T) {
vt := NewValidatorTest(t)
defer vt.TearDown()
vt.WriteEmails(t, []string(nil))
domains := []string(nil)
validator := NewValidator(domains, vt.auth_email_file.Name())
if validator("foo.bar@example.com") {
t.Error("nothing should validate when the email and " +
"domain lists are empty")
}
}
func TestValidatorSingleEmail(t *testing.T) {
vt := NewValidatorTest(t)
defer vt.TearDown()
vt.WriteEmails(t, []string{"foo.bar@example.com"})
domains := []string(nil)
validator := NewValidator(domains, vt.auth_email_file.Name())
if !validator("foo.bar@example.com") {
t.Error("email should validate")
}
if validator("baz.quux@example.com") {
t.Error("email from same domain but not in list " +
"should not validate when domain list is empty")
}
}
func TestValidatorSingleDomain(t *testing.T) {
vt := NewValidatorTest(t)
defer vt.TearDown()
vt.WriteEmails(t, []string(nil))
domains := []string{"example.com"}
validator := NewValidator(domains, vt.auth_email_file.Name())
if !validator("foo.bar@example.com") {
t.Error("email should validate")
}
if !validator("baz.quux@example.com") {
t.Error("email from same domain should validate")
}
}
func TestValidatorMultipleEmailsMultipleDomains(t *testing.T) {
vt := NewValidatorTest(t)
defer vt.TearDown()
vt.WriteEmails(t, []string{
"xyzzy@example.com",
"plugh@example.com",
})
domains := []string{"example0.com", "example1.com"}
validator := NewValidator(domains, vt.auth_email_file.Name())
if !validator("foo.bar@example0.com") {
t.Error("email from first domain should validate")
}
if !validator("baz.quux@example1.com") {
t.Error("email from second domain should validate")
}
if !validator("xyzzy@example.com") {
t.Error("first email in list should validate")
}
if !validator("plugh@example.com") {
t.Error("second email in list should validate")
}
if validator("xyzzy.plugh@example.com") {
t.Error("email not in list that matches no domains " +
"should not validate")
}
}
func TestValidatorComparisonsAreCaseInsensitive(t *testing.T) {
vt := NewValidatorTest(t)
defer vt.TearDown()
vt.WriteEmails(t, []string{"Foo.Bar@Example.Com"})
domains := []string{"Frobozz.Com"} domains := []string{"Frobozz.Com"}
validator := NewValidator(domains, auth_email_file.Name()) validator := NewValidator(domains, vt.auth_email_file.Name())
if !validator("foo.bar@example.com") { if !validator("foo.bar@example.com") {
t.Error("loaded email addresses are not lower-cased") t.Error("loaded email addresses are not lower-cased")

View File

@ -0,0 +1,50 @@
// +build go1.3
// +build !plan9,!solaris,!windows
// Turns out you can't copy over an existing file on Windows.
package main
import (
"io/ioutil"
"os"
"testing"
)
func (vt *ValidatorTest) UpdateEmailFileViaCopyingOver(
t *testing.T, emails []string) {
orig_file := vt.auth_email_file
var err error
vt.auth_email_file, err = ioutil.TempFile("", "test_auth_emails_")
if err != nil {
t.Fatal("failed to create temp file for copy: " + err.Error())
}
vt.WriteEmails(t, emails)
err = os.Rename(vt.auth_email_file.Name(), orig_file.Name())
if err != nil {
t.Fatal("failed to copy over temp file: " + err.Error())
}
vt.auth_email_file = orig_file
}
func TestValidatorOverwriteEmailListViaCopyingOver(t *testing.T) {
vt := NewValidatorTest(t)
defer vt.TearDown()
vt.WriteEmails(t, []string{"xyzzy@example.com"})
domains := []string(nil)
updated := make(chan bool)
validator := newValidatorImpl(domains, vt.auth_email_file.Name(),
func() { updated <- true })
if !validator("xyzzy@example.com") {
t.Error("email in list should validate")
}
vt.UpdateEmailFileViaCopyingOver(t, []string{"plugh@example.com"})
<-updated
if validator("xyzzy@example.com") {
t.Error("email removed from list should not validate")
}
}

105
validator_watcher_test.go Normal file
View File

@ -0,0 +1,105 @@
// +build go1.3
// +build !plan9,!solaris
package main
import (
"io/ioutil"
"os"
"testing"
)
func (vt *ValidatorTest) UpdateEmailFile(t *testing.T, emails []string) {
var err error
vt.auth_email_file, err = os.OpenFile(
vt.auth_email_file.Name(), os.O_WRONLY|os.O_CREATE, 0600)
if err != nil {
t.Fatal("failed to re-open temp file for updates")
}
vt.WriteEmails(t, emails)
}
func (vt *ValidatorTest) UpdateEmailFileViaRenameAndReplace(
t *testing.T, emails []string) {
orig_file := vt.auth_email_file
var err error
vt.auth_email_file, err = ioutil.TempFile("", "test_auth_emails_")
if err != nil {
t.Fatal("failed to create temp file for rename and replace: " +
err.Error())
}
vt.WriteEmails(t, emails)
moved_name := orig_file.Name() + "-moved"
err = os.Rename(orig_file.Name(), moved_name)
err = os.Rename(vt.auth_email_file.Name(), orig_file.Name())
if err != nil {
t.Fatal("failed to rename and replace temp file: " +
err.Error())
}
vt.auth_email_file = orig_file
os.Remove(moved_name)
}
func TestValidatorOverwriteEmailListDirectly(t *testing.T) {
vt := NewValidatorTest(t)
defer vt.TearDown()
vt.WriteEmails(t, []string{
"xyzzy@example.com",
"plugh@example.com",
})
domains := []string(nil)
updated := make(chan bool)
validator := newValidatorImpl(domains, vt.auth_email_file.Name(),
func() { updated <- true })
if !validator("xyzzy@example.com") {
t.Error("first email in list should validate")
}
if !validator("plugh@example.com") {
t.Error("second email in list should validate")
}
if validator("xyzzy.plugh@example.com") {
t.Error("email not in list that matches no domains " +
"should not validate")
}
vt.UpdateEmailFile(t, []string{
"xyzzy.plugh@example.com",
"plugh@example.com",
})
<-updated
if validator("xyzzy@example.com") {
t.Error("email removed from list should not validate")
}
if !validator("plugh@example.com") {
t.Error("email retained in list should validate")
}
if !validator("xyzzy.plugh@example.com") {
t.Error("email added to list should validate")
}
}
func TestValidatorOverwriteEmailListViaRenameAndReplace(t *testing.T) {
vt := NewValidatorTest(t)
defer vt.TearDown()
vt.WriteEmails(t, []string{"xyzzy@example.com"})
domains := []string(nil)
updated := make(chan bool)
validator := newValidatorImpl(domains, vt.auth_email_file.Name(),
func() { updated <- true })
if !validator("xyzzy@example.com") {
t.Error("email in list should validate")
}
vt.UpdateEmailFileViaRenameAndReplace(t, []string{"plugh@example.com"})
<-updated
if validator("xyzzy@example.com") {
t.Error("email removed from list should not validate")
}
}

64
watcher.go Normal file
View File

@ -0,0 +1,64 @@
// +build go1.3
// +build !plan9,!solaris
package main
import (
"log"
"os"
"path/filepath"
"time"
"gopkg.in/fsnotify.v1"
)
func WaitForReplacement(event fsnotify.Event, watcher *fsnotify.Watcher) {
const sleep_interval = 50 * time.Millisecond
// Avoid a race when fsnofity.Remove is preceded by fsnotify.Chmod.
if event.Op&fsnotify.Chmod != 0 {
time.Sleep(sleep_interval)
}
for {
if _, err := os.Stat(event.Name); err == nil {
if err := watcher.Add(event.Name); err == nil {
log.Printf("watching resumed for %s", event.Name)
return
}
}
time.Sleep(sleep_interval)
}
}
func WatchForUpdates(filename string, action func()) bool {
filename = filepath.Clean(filename)
watcher, err := fsnotify.NewWatcher()
if err != nil {
log.Fatal("failed to create watcher for ", filename, ": ", err)
}
go func() {
defer watcher.Close()
for {
select {
case event := <-watcher.Events:
// On Arch Linux, it appears Chmod events precede Remove events,
// which causes a race between action() and the coming Remove event.
// If the Remove wins, the action() (which calls
// UserMap.LoadAuthenticatedEmailsFile()) crashes when the file
// can't be opened.
if event.Op&(fsnotify.Remove|fsnotify.Rename|fsnotify.Chmod) != 0 {
log.Printf("watching interrupted on event: %s", event)
WaitForReplacement(event, watcher)
}
log.Printf("reloading after event: %s", event)
action()
case err := <-watcher.Errors:
log.Printf("error watching %s: %s", filename, err)
}
}
}()
if err = watcher.Add(filename); err != nil {
log.Fatal("failed to add ", filename, " to watcher: ", err)
}
return true
}

13
watcher_unsupported.go Normal file
View File

@ -0,0 +1,13 @@
// +build go1.1
// +build plan9,solaris
package main
import (
"log"
)
func WatchForUpdates(filename string, action func()) bool {
log.Printf("file watching not implemented on this platform")
return false
}