mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2024-11-24 08:52:25 +02:00
Merge pull request #93 from 18F/watcher-done
Provide graceful shutdown of file watcher in tests
This commit is contained in:
commit
aca1fe81f4
10
validator.go
10
validator.go
@ -15,13 +15,13 @@ type UserMap struct {
|
||||
m unsafe.Pointer
|
||||
}
|
||||
|
||||
func NewUserMap(usersFile string, onUpdate func()) *UserMap {
|
||||
func NewUserMap(usersFile string, done <-chan bool, onUpdate func()) *UserMap {
|
||||
um := &UserMap{usersFile: usersFile}
|
||||
m := make(map[string]bool)
|
||||
atomic.StorePointer(&um.m, unsafe.Pointer(&m))
|
||||
if usersFile != "" {
|
||||
log.Printf("using authenticated emails file %s", usersFile)
|
||||
started := WatchForUpdates(usersFile, func() {
|
||||
started := WatchForUpdates(usersFile, done, func() {
|
||||
um.LoadAuthenticatedEmailsFile()
|
||||
onUpdate()
|
||||
})
|
||||
@ -62,8 +62,8 @@ func (um *UserMap) LoadAuthenticatedEmailsFile() {
|
||||
}
|
||||
|
||||
func newValidatorImpl(domains []string, usersFile string,
|
||||
onUpdate func()) func(string) bool {
|
||||
validUsers := NewUserMap(usersFile, onUpdate)
|
||||
done <-chan bool, onUpdate func()) func(string) bool {
|
||||
validUsers := NewUserMap(usersFile, done, onUpdate)
|
||||
|
||||
for i, domain := range domains {
|
||||
domains[i] = fmt.Sprintf("@%s", strings.ToLower(domain))
|
||||
@ -85,5 +85,5 @@ func newValidatorImpl(domains []string, usersFile string,
|
||||
}
|
||||
|
||||
func NewValidator(domains []string, usersFile string) func(string) bool {
|
||||
return newValidatorImpl(domains, usersFile, func() {})
|
||||
return newValidatorImpl(domains, usersFile, nil, func() {})
|
||||
}
|
||||
|
@ -9,6 +9,8 @@ import (
|
||||
|
||||
type ValidatorTest struct {
|
||||
auth_email_file *os.File
|
||||
done chan bool
|
||||
update_seen bool
|
||||
}
|
||||
|
||||
func NewValidatorTest(t *testing.T) *ValidatorTest {
|
||||
@ -18,13 +20,26 @@ func NewValidatorTest(t *testing.T) *ValidatorTest {
|
||||
if err != nil {
|
||||
t.Fatal("failed to create temp file: " + err.Error())
|
||||
}
|
||||
vt.done = make(chan bool)
|
||||
return vt
|
||||
}
|
||||
|
||||
func (vt *ValidatorTest) TearDown() {
|
||||
vt.done <- true
|
||||
os.Remove(vt.auth_email_file.Name())
|
||||
}
|
||||
|
||||
func (vt *ValidatorTest) NewValidator(domains []string,
|
||||
updated chan<- bool) func(string) bool {
|
||||
return newValidatorImpl(domains, vt.auth_email_file.Name(),
|
||||
vt.done, func() {
|
||||
if vt.update_seen == false {
|
||||
updated <- true
|
||||
vt.update_seen = true
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// This will close vt.auth_email_file.
|
||||
func (vt *ValidatorTest) WriteEmails(t *testing.T, emails []string) {
|
||||
defer vt.auth_email_file.Close()
|
||||
@ -41,7 +56,7 @@ func TestValidatorEmpty(t *testing.T) {
|
||||
|
||||
vt.WriteEmails(t, []string(nil))
|
||||
domains := []string(nil)
|
||||
validator := NewValidator(domains, vt.auth_email_file.Name())
|
||||
validator := vt.NewValidator(domains, nil)
|
||||
|
||||
if validator("foo.bar@example.com") {
|
||||
t.Error("nothing should validate when the email and " +
|
||||
@ -55,7 +70,7 @@ func TestValidatorSingleEmail(t *testing.T) {
|
||||
|
||||
vt.WriteEmails(t, []string{"foo.bar@example.com"})
|
||||
domains := []string(nil)
|
||||
validator := NewValidator(domains, vt.auth_email_file.Name())
|
||||
validator := vt.NewValidator(domains, nil)
|
||||
|
||||
if !validator("foo.bar@example.com") {
|
||||
t.Error("email should validate")
|
||||
@ -72,7 +87,7 @@ func TestValidatorSingleDomain(t *testing.T) {
|
||||
|
||||
vt.WriteEmails(t, []string(nil))
|
||||
domains := []string{"example.com"}
|
||||
validator := NewValidator(domains, vt.auth_email_file.Name())
|
||||
validator := vt.NewValidator(domains, nil)
|
||||
|
||||
if !validator("foo.bar@example.com") {
|
||||
t.Error("email should validate")
|
||||
@ -91,7 +106,7 @@ func TestValidatorMultipleEmailsMultipleDomains(t *testing.T) {
|
||||
"plugh@example.com",
|
||||
})
|
||||
domains := []string{"example0.com", "example1.com"}
|
||||
validator := NewValidator(domains, vt.auth_email_file.Name())
|
||||
validator := vt.NewValidator(domains, nil)
|
||||
|
||||
if !validator("foo.bar@example0.com") {
|
||||
t.Error("email from first domain should validate")
|
||||
@ -117,7 +132,7 @@ func TestValidatorComparisonsAreCaseInsensitive(t *testing.T) {
|
||||
|
||||
vt.WriteEmails(t, []string{"Foo.Bar@Example.Com"})
|
||||
domains := []string{"Frobozz.Com"}
|
||||
validator := NewValidator(domains, vt.auth_email_file.Name())
|
||||
validator := vt.NewValidator(domains, nil)
|
||||
|
||||
if !validator("foo.bar@example.com") {
|
||||
t.Error("loaded email addresses are not lower-cased")
|
||||
|
@ -34,8 +34,7 @@ func TestValidatorOverwriteEmailListViaCopyingOver(t *testing.T) {
|
||||
vt.WriteEmails(t, []string{"xyzzy@example.com"})
|
||||
domains := []string(nil)
|
||||
updated := make(chan bool)
|
||||
validator := newValidatorImpl(domains, vt.auth_email_file.Name(),
|
||||
func() { updated <- true })
|
||||
validator := vt.NewValidator(domains, updated)
|
||||
|
||||
if !validator("xyzzy@example.com") {
|
||||
t.Error("email in list should validate")
|
||||
|
@ -51,8 +51,7 @@ func TestValidatorOverwriteEmailListDirectly(t *testing.T) {
|
||||
})
|
||||
domains := []string(nil)
|
||||
updated := make(chan bool)
|
||||
validator := newValidatorImpl(domains, vt.auth_email_file.Name(),
|
||||
func() { updated <- true })
|
||||
validator := vt.NewValidator(domains, updated)
|
||||
|
||||
if !validator("xyzzy@example.com") {
|
||||
t.Error("first email in list should validate")
|
||||
@ -89,8 +88,7 @@ func TestValidatorOverwriteEmailListViaRenameAndReplace(t *testing.T) {
|
||||
vt.WriteEmails(t, []string{"xyzzy@example.com"})
|
||||
domains := []string(nil)
|
||||
updated := make(chan bool)
|
||||
validator := newValidatorImpl(domains, vt.auth_email_file.Name(),
|
||||
func() { updated <- true })
|
||||
validator := vt.NewValidator(domains, updated)
|
||||
|
||||
if !validator("xyzzy@example.com") {
|
||||
t.Error("email in list should validate")
|
||||
|
20
watcher.go
20
watcher.go
@ -12,17 +12,18 @@ import (
|
||||
"gopkg.in/fsnotify.v1"
|
||||
)
|
||||
|
||||
func WaitForReplacement(event fsnotify.Event, watcher *fsnotify.Watcher) {
|
||||
func WaitForReplacement(filename string, op fsnotify.Op,
|
||||
watcher *fsnotify.Watcher) {
|
||||
const sleep_interval = 50 * time.Millisecond
|
||||
|
||||
// Avoid a race when fsnofity.Remove is preceded by fsnotify.Chmod.
|
||||
if event.Op&fsnotify.Chmod != 0 {
|
||||
if op&fsnotify.Chmod != 0 {
|
||||
time.Sleep(sleep_interval)
|
||||
}
|
||||
for {
|
||||
if _, err := os.Stat(event.Name); err == nil {
|
||||
if err := watcher.Add(event.Name); err == nil {
|
||||
log.Printf("watching resumed for %s", event.Name)
|
||||
if _, err := os.Stat(filename); err == nil {
|
||||
if err := watcher.Add(filename); err == nil {
|
||||
log.Printf("watching resumed for %s", filename)
|
||||
return
|
||||
}
|
||||
}
|
||||
@ -30,7 +31,7 @@ func WaitForReplacement(event fsnotify.Event, watcher *fsnotify.Watcher) {
|
||||
}
|
||||
}
|
||||
|
||||
func WatchForUpdates(filename string, action func()) bool {
|
||||
func WatchForUpdates(filename string, done <-chan bool, action func()) bool {
|
||||
filename = filepath.Clean(filename)
|
||||
watcher, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
@ -40,6 +41,10 @@ func WatchForUpdates(filename string, action func()) bool {
|
||||
defer watcher.Close()
|
||||
for {
|
||||
select {
|
||||
case _ = <-done:
|
||||
log.Printf("Shutting down watcher for: %s",
|
||||
filename)
|
||||
return
|
||||
case event := <-watcher.Events:
|
||||
// On Arch Linux, it appears Chmod events precede Remove events,
|
||||
// which causes a race between action() and the coming Remove event.
|
||||
@ -48,7 +53,8 @@ func WatchForUpdates(filename string, action func()) bool {
|
||||
// can't be opened.
|
||||
if event.Op&(fsnotify.Remove|fsnotify.Rename|fsnotify.Chmod) != 0 {
|
||||
log.Printf("watching interrupted on event: %s", event)
|
||||
WaitForReplacement(event, watcher)
|
||||
watcher.Remove(filename)
|
||||
WaitForReplacement(filename, event.Op, watcher)
|
||||
}
|
||||
log.Printf("reloading after event: %s", event)
|
||||
action()
|
||||
|
Loading…
Reference in New Issue
Block a user