mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2024-11-24 08:52:25 +02:00
Provide graceful shutdown of file watcher in tests
This test failure from #92 inspired this change: https://travis-ci.org/bitly/google_auth_proxy/jobs/62425336 2015/05/13 16:27:33 using authenticated emails file /tmp/test_auth_emails_952353477 2015/05/13 16:27:33 watching /tmp/test_auth_emails_952353477 for updates 2015/05/13 16:27:33 validating: is xyzzy@example.com valid? true 2015/05/13 16:27:33 watching interrupted on event: "/tmp/test_auth_emails_952353477": CHMOD 2015/05/13 16:27:33 watching resumed for /tmp/test_auth_emails_952353477 2015/05/13 16:27:33 reloading after event: "/tmp/test_auth_emails_952353477": CHMOD 2015/05/13 16:27:33 watching interrupted on event: "/tmp/test_auth_emails_952353477": REMOVE 2015/05/13 16:27:33 validating: is xyzzy@example.com valid? false 2015/05/13 16:27:33 watching resumed for /tmp/test_auth_emails_952353477 2015/05/13 16:27:33 reloading after event: "/tmp/test_auth_emails_952353477": REMOVE 2015/05/13 16:27:33 failed opening authenticated-emails-file="/tmp/test_auth_emails_952353477", open /tmp/test_auth_emails_952353477: no such file or directory I believe that what happened was that the call to reload the file after the second "reloading after event" lost the race when the test shut down and the file was removed. This change introduces a `done` channel that ensures outstanding actions complete and the watcher exits before the test removes the file.
This commit is contained in:
parent
254b26d4a0
commit
6a0f119fc2
10
validator.go
10
validator.go
@ -15,13 +15,13 @@ type UserMap struct {
|
|||||||
m unsafe.Pointer
|
m unsafe.Pointer
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUserMap(usersFile string, onUpdate func()) *UserMap {
|
func NewUserMap(usersFile string, done <-chan bool, onUpdate func()) *UserMap {
|
||||||
um := &UserMap{usersFile: usersFile}
|
um := &UserMap{usersFile: usersFile}
|
||||||
m := make(map[string]bool)
|
m := make(map[string]bool)
|
||||||
atomic.StorePointer(&um.m, unsafe.Pointer(&m))
|
atomic.StorePointer(&um.m, unsafe.Pointer(&m))
|
||||||
if usersFile != "" {
|
if usersFile != "" {
|
||||||
log.Printf("using authenticated emails file %s", usersFile)
|
log.Printf("using authenticated emails file %s", usersFile)
|
||||||
started := WatchForUpdates(usersFile, func() {
|
started := WatchForUpdates(usersFile, done, func() {
|
||||||
um.LoadAuthenticatedEmailsFile()
|
um.LoadAuthenticatedEmailsFile()
|
||||||
onUpdate()
|
onUpdate()
|
||||||
})
|
})
|
||||||
@ -62,8 +62,8 @@ func (um *UserMap) LoadAuthenticatedEmailsFile() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func newValidatorImpl(domains []string, usersFile string,
|
func newValidatorImpl(domains []string, usersFile string,
|
||||||
onUpdate func()) func(string) bool {
|
done <-chan bool, onUpdate func()) func(string) bool {
|
||||||
validUsers := NewUserMap(usersFile, onUpdate)
|
validUsers := NewUserMap(usersFile, done, onUpdate)
|
||||||
|
|
||||||
for i, domain := range domains {
|
for i, domain := range domains {
|
||||||
domains[i] = fmt.Sprintf("@%s", strings.ToLower(domain))
|
domains[i] = fmt.Sprintf("@%s", strings.ToLower(domain))
|
||||||
@ -85,5 +85,5 @@ func newValidatorImpl(domains []string, usersFile string,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewValidator(domains []string, usersFile string) func(string) bool {
|
func NewValidator(domains []string, usersFile string) func(string) bool {
|
||||||
return newValidatorImpl(domains, usersFile, func() {})
|
return newValidatorImpl(domains, usersFile, nil, func() {})
|
||||||
}
|
}
|
||||||
|
@ -9,6 +9,7 @@ import (
|
|||||||
|
|
||||||
type ValidatorTest struct {
|
type ValidatorTest struct {
|
||||||
auth_email_file *os.File
|
auth_email_file *os.File
|
||||||
|
done chan bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewValidatorTest(t *testing.T) *ValidatorTest {
|
func NewValidatorTest(t *testing.T) *ValidatorTest {
|
||||||
@ -18,13 +19,21 @@ func NewValidatorTest(t *testing.T) *ValidatorTest {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("failed to create temp file: " + err.Error())
|
t.Fatal("failed to create temp file: " + err.Error())
|
||||||
}
|
}
|
||||||
|
vt.done = make(chan bool)
|
||||||
return vt
|
return vt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (vt *ValidatorTest) TearDown() {
|
func (vt *ValidatorTest) TearDown() {
|
||||||
|
vt.done <- true
|
||||||
os.Remove(vt.auth_email_file.Name())
|
os.Remove(vt.auth_email_file.Name())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (vt *ValidatorTest) NewValidator(domains []string,
|
||||||
|
updated chan<- bool) func(string) bool {
|
||||||
|
return newValidatorImpl(domains, vt.auth_email_file.Name(),
|
||||||
|
vt.done, func() { updated <- true })
|
||||||
|
}
|
||||||
|
|
||||||
// This will close vt.auth_email_file.
|
// This will close vt.auth_email_file.
|
||||||
func (vt *ValidatorTest) WriteEmails(t *testing.T, emails []string) {
|
func (vt *ValidatorTest) WriteEmails(t *testing.T, emails []string) {
|
||||||
defer vt.auth_email_file.Close()
|
defer vt.auth_email_file.Close()
|
||||||
@ -41,7 +50,7 @@ func TestValidatorEmpty(t *testing.T) {
|
|||||||
|
|
||||||
vt.WriteEmails(t, []string(nil))
|
vt.WriteEmails(t, []string(nil))
|
||||||
domains := []string(nil)
|
domains := []string(nil)
|
||||||
validator := NewValidator(domains, vt.auth_email_file.Name())
|
validator := vt.NewValidator(domains, nil)
|
||||||
|
|
||||||
if validator("foo.bar@example.com") {
|
if validator("foo.bar@example.com") {
|
||||||
t.Error("nothing should validate when the email and " +
|
t.Error("nothing should validate when the email and " +
|
||||||
@ -55,7 +64,7 @@ func TestValidatorSingleEmail(t *testing.T) {
|
|||||||
|
|
||||||
vt.WriteEmails(t, []string{"foo.bar@example.com"})
|
vt.WriteEmails(t, []string{"foo.bar@example.com"})
|
||||||
domains := []string(nil)
|
domains := []string(nil)
|
||||||
validator := NewValidator(domains, vt.auth_email_file.Name())
|
validator := vt.NewValidator(domains, nil)
|
||||||
|
|
||||||
if !validator("foo.bar@example.com") {
|
if !validator("foo.bar@example.com") {
|
||||||
t.Error("email should validate")
|
t.Error("email should validate")
|
||||||
@ -72,7 +81,7 @@ func TestValidatorSingleDomain(t *testing.T) {
|
|||||||
|
|
||||||
vt.WriteEmails(t, []string(nil))
|
vt.WriteEmails(t, []string(nil))
|
||||||
domains := []string{"example.com"}
|
domains := []string{"example.com"}
|
||||||
validator := NewValidator(domains, vt.auth_email_file.Name())
|
validator := vt.NewValidator(domains, nil)
|
||||||
|
|
||||||
if !validator("foo.bar@example.com") {
|
if !validator("foo.bar@example.com") {
|
||||||
t.Error("email should validate")
|
t.Error("email should validate")
|
||||||
@ -91,7 +100,7 @@ func TestValidatorMultipleEmailsMultipleDomains(t *testing.T) {
|
|||||||
"plugh@example.com",
|
"plugh@example.com",
|
||||||
})
|
})
|
||||||
domains := []string{"example0.com", "example1.com"}
|
domains := []string{"example0.com", "example1.com"}
|
||||||
validator := NewValidator(domains, vt.auth_email_file.Name())
|
validator := vt.NewValidator(domains, nil)
|
||||||
|
|
||||||
if !validator("foo.bar@example0.com") {
|
if !validator("foo.bar@example0.com") {
|
||||||
t.Error("email from first domain should validate")
|
t.Error("email from first domain should validate")
|
||||||
@ -117,7 +126,7 @@ func TestValidatorComparisonsAreCaseInsensitive(t *testing.T) {
|
|||||||
|
|
||||||
vt.WriteEmails(t, []string{"Foo.Bar@Example.Com"})
|
vt.WriteEmails(t, []string{"Foo.Bar@Example.Com"})
|
||||||
domains := []string{"Frobozz.Com"}
|
domains := []string{"Frobozz.Com"}
|
||||||
validator := NewValidator(domains, vt.auth_email_file.Name())
|
validator := vt.NewValidator(domains, nil)
|
||||||
|
|
||||||
if !validator("foo.bar@example.com") {
|
if !validator("foo.bar@example.com") {
|
||||||
t.Error("loaded email addresses are not lower-cased")
|
t.Error("loaded email addresses are not lower-cased")
|
||||||
|
@ -34,8 +34,7 @@ func TestValidatorOverwriteEmailListViaCopyingOver(t *testing.T) {
|
|||||||
vt.WriteEmails(t, []string{"xyzzy@example.com"})
|
vt.WriteEmails(t, []string{"xyzzy@example.com"})
|
||||||
domains := []string(nil)
|
domains := []string(nil)
|
||||||
updated := make(chan bool)
|
updated := make(chan bool)
|
||||||
validator := newValidatorImpl(domains, vt.auth_email_file.Name(),
|
validator := vt.NewValidator(domains, updated)
|
||||||
func() { updated <- true })
|
|
||||||
|
|
||||||
if !validator("xyzzy@example.com") {
|
if !validator("xyzzy@example.com") {
|
||||||
t.Error("email in list should validate")
|
t.Error("email in list should validate")
|
||||||
|
@ -51,8 +51,7 @@ func TestValidatorOverwriteEmailListDirectly(t *testing.T) {
|
|||||||
})
|
})
|
||||||
domains := []string(nil)
|
domains := []string(nil)
|
||||||
updated := make(chan bool)
|
updated := make(chan bool)
|
||||||
validator := newValidatorImpl(domains, vt.auth_email_file.Name(),
|
validator := vt.NewValidator(domains, updated)
|
||||||
func() { updated <- true })
|
|
||||||
|
|
||||||
if !validator("xyzzy@example.com") {
|
if !validator("xyzzy@example.com") {
|
||||||
t.Error("first email in list should validate")
|
t.Error("first email in list should validate")
|
||||||
@ -89,8 +88,7 @@ func TestValidatorOverwriteEmailListViaRenameAndReplace(t *testing.T) {
|
|||||||
vt.WriteEmails(t, []string{"xyzzy@example.com"})
|
vt.WriteEmails(t, []string{"xyzzy@example.com"})
|
||||||
domains := []string(nil)
|
domains := []string(nil)
|
||||||
updated := make(chan bool)
|
updated := make(chan bool)
|
||||||
validator := newValidatorImpl(domains, vt.auth_email_file.Name(),
|
validator := vt.NewValidator(domains, updated)
|
||||||
func() { updated <- true })
|
|
||||||
|
|
||||||
if !validator("xyzzy@example.com") {
|
if !validator("xyzzy@example.com") {
|
||||||
t.Error("email in list should validate")
|
t.Error("email in list should validate")
|
||||||
|
@ -30,7 +30,7 @@ func WaitForReplacement(event fsnotify.Event, watcher *fsnotify.Watcher) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WatchForUpdates(filename string, action func()) bool {
|
func WatchForUpdates(filename string, done <-chan bool, action func()) bool {
|
||||||
filename = filepath.Clean(filename)
|
filename = filepath.Clean(filename)
|
||||||
watcher, err := fsnotify.NewWatcher()
|
watcher, err := fsnotify.NewWatcher()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -40,6 +40,10 @@ func WatchForUpdates(filename string, action func()) bool {
|
|||||||
defer watcher.Close()
|
defer watcher.Close()
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
|
case _ = <-done:
|
||||||
|
log.Printf("Shutting down watcher for: %s",
|
||||||
|
filename)
|
||||||
|
return
|
||||||
case event := <-watcher.Events:
|
case event := <-watcher.Events:
|
||||||
// On Arch Linux, it appears Chmod events precede Remove events,
|
// On Arch Linux, it appears Chmod events precede Remove events,
|
||||||
// which causes a race between action() and the coming Remove event.
|
// which causes a race between action() and the coming Remove event.
|
||||||
|
Loading…
Reference in New Issue
Block a user