You've already forked oauth2-proxy
mirror of
https://github.com/oauth2-proxy/oauth2-proxy.git
synced 2025-08-10 22:51:31 +02:00
Watch the htpasswd
file for changes and update the htpasswdMap
(#1701)
* dynamically update the htpasswdMap based on the changes made to the htpasswd file * added tests to validate that htpasswdMap is updated after the htpasswd file is changed * refactored `htpasswd` and `watcher` to lower cognitive complexity * returned errors and refactored tests * added `CHANGELOG.md` entry for #1701 and fixed the codeclimate issue * Apply suggestions from code review Co-authored-by: Joel Speed <Joel.speed@hotmail.co.uk> * Fix lint issue from code suggestion * Wrap htpasswd load and watch errors with context * add the htpasswd wrapped error context to the test Co-authored-by: Joel Speed <Joel.speed@hotmail.co.uk>
This commit is contained in:
committed by
GitHub
parent
fcecbeb13c
commit
037cb041d3
@@ -17,6 +17,8 @@ N/A
|
|||||||
|
|
||||||
- [#1669](https://github.com/oauth2-proxy/oauth2-proxy/pull/1699) Fix method deprecated error in lint (@t-katsumura)
|
- [#1669](https://github.com/oauth2-proxy/oauth2-proxy/pull/1699) Fix method deprecated error in lint (@t-katsumura)
|
||||||
|
|
||||||
|
- [#1701](https://github.com/oauth2-proxy/oauth2-proxy/pull/1701) Watch the htpasswd file for changes and update the htpasswdMap (@aiciobanu)
|
||||||
|
|
||||||
- [#1709](https://github.com/oauth2-proxy/oauth2-proxy/pull/1709) Show an alert message when basic auth credentials are invalid (@aiciobanu)
|
- [#1709](https://github.com/oauth2-proxy/oauth2-proxy/pull/1709) Show an alert message when basic auth credentials are invalid (@aiciobanu)
|
||||||
- [#1723](https://github.com/oauth2-proxy/oauth2-proxy/pull/1723) Added ability to specify allowed TLS cipher suites. (@crbednarz)
|
- [#1723](https://github.com/oauth2-proxy/oauth2-proxy/pull/1723) Added ability to specify allowed TLS cipher suites. (@crbednarz)
|
||||||
|
|
||||||
|
@@ -113,7 +113,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
|
|||||||
var err error
|
var err error
|
||||||
basicAuthValidator, err = basic.NewHTPasswdValidator(opts.HtpasswdFile)
|
basicAuthValidator, err = basic.NewHTPasswdValidator(opts.HtpasswdFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("could not load htpasswdfile: %v", err)
|
return nil, fmt.Errorf("could not validate htpasswd: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -8,8 +8,10 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/watcher"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -17,6 +19,7 @@ import (
|
|||||||
// Passwords must be generated with -B for bcrypt or -s for SHA1.
|
// Passwords must be generated with -B for bcrypt or -s for SHA1.
|
||||||
type htpasswdMap struct {
|
type htpasswdMap struct {
|
||||||
users map[string]interface{}
|
users map[string]interface{}
|
||||||
|
rwm sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// bcryptPass is used to identify bcrypt passwords in the
|
// bcryptPass is used to identify bcrypt passwords in the
|
||||||
@@ -30,10 +33,30 @@ type sha1Pass string
|
|||||||
// NewHTPasswdValidator constructs an httpasswd based validator from the file
|
// NewHTPasswdValidator constructs an httpasswd based validator from the file
|
||||||
// at the path given.
|
// at the path given.
|
||||||
func NewHTPasswdValidator(path string) (Validator, error) {
|
func NewHTPasswdValidator(path string) (Validator, error) {
|
||||||
|
h := &htpasswdMap{users: make(map[string]interface{})}
|
||||||
|
|
||||||
|
if err := h.loadHTPasswdFile(path); err != nil {
|
||||||
|
return nil, fmt.Errorf("could not load htpasswd file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := watcher.WatchFileForUpdates(path, nil, func() {
|
||||||
|
err := h.loadHTPasswdFile(path)
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf("%v: no changes were made to the current htpasswd map", err)
|
||||||
|
}
|
||||||
|
}); err != nil {
|
||||||
|
return nil, fmt.Errorf("could not watch htpasswd file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return h, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadHTPasswdFile loads htpasswd entries from an io.Reader (an opened file) into a htpasswdMap.
|
||||||
|
func (h *htpasswdMap) loadHTPasswdFile(filename string) error {
|
||||||
// We allow HTPasswd location via config options
|
// We allow HTPasswd location via config options
|
||||||
r, err := os.Open(path) // #nosec G304
|
r, err := os.Open(filename) // #nosec G304
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("could not open htpasswd file: %v", err)
|
return fmt.Errorf("could not open htpasswd file: %v", err)
|
||||||
}
|
}
|
||||||
defer func(c io.Closer) {
|
defer func(c io.Closer) {
|
||||||
cerr := c.Close()
|
cerr := c.Close()
|
||||||
@@ -41,48 +64,81 @@ func NewHTPasswdValidator(path string) (Validator, error) {
|
|||||||
logger.Fatalf("error closing the htpasswd file: %v", cerr)
|
logger.Fatalf("error closing the htpasswd file: %v", cerr)
|
||||||
}
|
}
|
||||||
}(r)
|
}(r)
|
||||||
return newHtpasswd(r)
|
|
||||||
}
|
|
||||||
|
|
||||||
// newHtpasswd consctructs an htpasswd from an io.Reader (an opened file).
|
csvReader := csv.NewReader(r)
|
||||||
func newHtpasswd(file io.Reader) (*htpasswdMap, error) {
|
|
||||||
csvReader := csv.NewReader(file)
|
|
||||||
csvReader.Comma = ':'
|
csvReader.Comma = ':'
|
||||||
csvReader.Comment = '#'
|
csvReader.Comment = '#'
|
||||||
csvReader.TrimLeadingSpace = true
|
csvReader.TrimLeadingSpace = true
|
||||||
|
|
||||||
records, err := csvReader.ReadAll()
|
records, err := csvReader.ReadAll()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("could not read htpasswd file: %v", err)
|
return fmt.Errorf("could not read htpasswd file: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return createHtpasswdMap(records)
|
updated, err := createHtpasswdMap(records)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("htpasswd entries error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
h.rwm.Lock()
|
||||||
|
h.users = updated.users
|
||||||
|
h.rwm.Unlock()
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// createHtasswdMap constructs an htpasswdMap from the given records
|
// createHtasswdMap constructs an htpasswdMap from the given records
|
||||||
func createHtpasswdMap(records [][]string) (*htpasswdMap, error) {
|
func createHtpasswdMap(records [][]string) (*htpasswdMap, error) {
|
||||||
h := &htpasswdMap{users: make(map[string]interface{})}
|
h := &htpasswdMap{users: make(map[string]interface{})}
|
||||||
|
var invalidRecords, invalidEntries []string
|
||||||
for _, record := range records {
|
for _, record := range records {
|
||||||
user, realPassword := record[0], record[1]
|
// If a record is invalid or malformed don't panic with index out of range,
|
||||||
shaPrefix := realPassword[:5]
|
// return a formatted error.
|
||||||
if shaPrefix == "{SHA}" {
|
lr := len(record)
|
||||||
h.users[user] = sha1Pass(realPassword[5:])
|
switch {
|
||||||
continue
|
case lr == 2:
|
||||||
|
user, realPassword := record[0], record[1]
|
||||||
|
invalidEntries = passShaOrBcrypt(h, user, realPassword)
|
||||||
|
case lr == 1, lr > 2:
|
||||||
|
invalidRecords = append(invalidRecords, record[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
bcryptPrefix := realPassword[:4]
|
|
||||||
if bcryptPrefix == "$2a$" || bcryptPrefix == "$2b$" || bcryptPrefix == "$2x$" || bcryptPrefix == "$2y$" {
|
|
||||||
h.users[user] = bcryptPass(realPassword)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Password is neither sha1 or bcrypt
|
|
||||||
// TODO(JoelSpeed): In the next breaking release, make this return an error.
|
|
||||||
logger.Errorf("Invalid htpasswd entry for %s. Must be a SHA or bcrypt entry.", user)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(invalidRecords) > 0 {
|
||||||
|
return h, fmt.Errorf("invalid htpasswd record(s) %+q", invalidRecords)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(invalidEntries) > 0 {
|
||||||
|
return h, fmt.Errorf("'%+q' user(s) could not be added: invalid password, must be a SHA or bcrypt entry", invalidEntries)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(h.users) == 0 {
|
||||||
|
return nil, fmt.Errorf("could not construct htpasswdMap: htpasswd file doesn't contain a single valid user entry")
|
||||||
|
}
|
||||||
|
|
||||||
return h, nil
|
return h, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// passShaOrBcrypt checks if a htpasswd entry is valid and the password is encrypted with SHA or bcrypt.
|
||||||
|
// Valid user entries are saved in the htpasswdMap, invalid records are reurned.
|
||||||
|
func passShaOrBcrypt(h *htpasswdMap, user, password string) (invalidEntries []string) {
|
||||||
|
passLen := len(password)
|
||||||
|
switch {
|
||||||
|
case passLen > 6 && password[:5] == "{SHA}":
|
||||||
|
h.users[user] = sha1Pass(password[5:])
|
||||||
|
case passLen > 5 &&
|
||||||
|
(password[:4] == "$2b$" ||
|
||||||
|
password[:4] == "$2y$" ||
|
||||||
|
password[:4] == "$2x$" ||
|
||||||
|
password[:4] == "$2a$"):
|
||||||
|
h.users[user] = bcryptPass(password)
|
||||||
|
default:
|
||||||
|
invalidEntries = append(invalidEntries, user)
|
||||||
|
}
|
||||||
|
|
||||||
|
return invalidEntries
|
||||||
|
}
|
||||||
|
|
||||||
// Validate checks a users password against the htpasswd entries
|
// Validate checks a users password against the htpasswd entries
|
||||||
func (h *htpasswdMap) Validate(user string, password string) bool {
|
func (h *htpasswdMap) Validate(user string, password string) bool {
|
||||||
realPassword, exists := h.users[user]
|
realPassword, exists := h.users[user]
|
||||||
|
@@ -1,8 +1,11 @@
|
|||||||
package basic
|
package basic
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"os"
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
. "github.com/onsi/ginkgo"
|
||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
|
. "github.com/onsi/gomega/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -84,13 +87,89 @@ var _ = Describe("HTPasswd Suite", func() {
|
|||||||
})
|
})
|
||||||
|
|
||||||
It("returns an error", func() {
|
It("returns an error", func() {
|
||||||
Expect(err).To(MatchError("could not open htpasswd file: open ./test/htpasswd-doesnt-exist.txt: no such file or directory"))
|
Expect(err).To(MatchError("could not load htpasswd file: could not open htpasswd file: open ./test/htpasswd-doesnt-exist.txt: no such file or directory"))
|
||||||
})
|
})
|
||||||
|
|
||||||
It("returns a nil validator", func() {
|
It("returns a nil validator", func() {
|
||||||
Expect(validator).To(BeNil())
|
Expect(validator).To(BeNil())
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
Context("htpasswd file is updated", func() {
|
||||||
|
const filePathPrefix = "htpasswd-file-updated-"
|
||||||
|
const adminUserHtpasswdEntry = "admin:$2y$05$SXWrNM7ldtbRzBvUC3VXyOvUeiUcP45XPwM93P5eeGOEPIiAZmJjC"
|
||||||
|
const user1HtpasswdEntry = "user1:$2y$05$/sZYJOk8.3Etg4V6fV7puuXfCJLmV5Q7u3xvKpjBSJUka.t2YtmmG"
|
||||||
|
var fileNames []string
|
||||||
|
|
||||||
|
AfterSuite(func() {
|
||||||
|
for _, v := range fileNames {
|
||||||
|
err := os.Remove(v)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
}
|
||||||
|
|
||||||
|
})
|
||||||
|
|
||||||
|
type htpasswdUpdate struct {
|
||||||
|
testText string
|
||||||
|
remove bool
|
||||||
|
expectedLen int
|
||||||
|
expectedGomegaMatcher GomegaMatcher
|
||||||
|
}
|
||||||
|
|
||||||
|
assertHtpasswdMapUpdate := func(hu htpasswdUpdate) {
|
||||||
|
var validator Validator
|
||||||
|
var file *os.File
|
||||||
|
var err error
|
||||||
|
|
||||||
|
// Create a temporary file with at least one entry
|
||||||
|
file, err = os.CreateTemp("", filePathPrefix)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
_, err = file.WriteString(adminUserHtpasswdEntry + "\n")
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
validator, err = NewHTPasswdValidator(file.Name())
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
htpasswd, ok := validator.(*htpasswdMap)
|
||||||
|
Expect(ok).To(BeTrue())
|
||||||
|
|
||||||
|
if hu.remove {
|
||||||
|
// Overwrite the original file with another entry
|
||||||
|
err = os.WriteFile(file.Name(), []byte(user1HtpasswdEntry+"\n"), 0644)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
} else {
|
||||||
|
// Add another entry to the original file in append mode
|
||||||
|
_, err = file.WriteString(user1HtpasswdEntry + "\n")
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
}
|
||||||
|
|
||||||
|
err = file.Close()
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
fileNames = append(fileNames, file.Name())
|
||||||
|
|
||||||
|
It("has the correct number of users", func() {
|
||||||
|
Expect(len(htpasswd.users)).To(Equal(hu.expectedLen))
|
||||||
|
})
|
||||||
|
|
||||||
|
It(hu.testText, func() {
|
||||||
|
Expect(htpasswd.Validate(adminUser, adminPassword)).To(hu.expectedGomegaMatcher)
|
||||||
|
})
|
||||||
|
|
||||||
|
It("new entry is present", func() {
|
||||||
|
Expect(htpasswd.Validate(user1, user1Password)).To(BeTrue())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
Context("htpasswd entry is added", func() {
|
||||||
|
assertHtpasswdMapUpdate(htpasswdUpdate{"initial entry is present", false, 2, BeTrue()})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("htpasswd entry is removed", func() {
|
||||||
|
assertHtpasswdMapUpdate(htpasswdUpdate{"initial entry is removed", true, 1, BeFalse()})
|
||||||
|
})
|
||||||
|
|
||||||
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
81
pkg/watcher/watcher.go
Normal file
81
pkg/watcher/watcher.go
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
package watcher
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/fsnotify/fsnotify"
|
||||||
|
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// WatchFileForUpdates performs an action every time a file on disk is updated
|
||||||
|
func WatchFileForUpdates(filename string, done <-chan bool, action func()) error {
|
||||||
|
filename = filepath.Clean(filename)
|
||||||
|
watcher, err := fsnotify.NewWatcher()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create watcher for '%s': %s", filename, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer watcher.Close()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
logger.Printf("shutting down watcher for: %s", filename)
|
||||||
|
return
|
||||||
|
case event := <-watcher.Events:
|
||||||
|
filterEvent(watcher, event, filename, action)
|
||||||
|
case err = <-watcher.Errors:
|
||||||
|
logger.Errorf("error watching '%s': %s", filename, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
if err := watcher.Add(filename); err != nil {
|
||||||
|
return fmt.Errorf("failed to add '%s' to watcher: %v", filename, err)
|
||||||
|
}
|
||||||
|
logger.Printf("watching '%s' for updates", filename)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter file operations based on the events sent by the watcher.
|
||||||
|
// Execute the action() function when the following conditions are met:
|
||||||
|
// - the real path of the file was changed (Kubernetes ConfigMap/Secret)
|
||||||
|
// - the file is modified or created
|
||||||
|
func filterEvent(watcher *fsnotify.Watcher, event fsnotify.Event, filename string, action func()) {
|
||||||
|
switch filepath.Clean(event.Name) == filename {
|
||||||
|
// In Kubernetes the file path is a symlink, so we should take action
|
||||||
|
// when the ConfigMap/Secret is replaced.
|
||||||
|
case event.Op&fsnotify.Remove != 0:
|
||||||
|
logger.Printf("watching interrupted on event: %s", event)
|
||||||
|
WaitForReplacement(filename, event.Op, watcher)
|
||||||
|
action()
|
||||||
|
case event.Op&(fsnotify.Create|fsnotify.Write) != 0:
|
||||||
|
logger.Printf("reloading after event: %s", event)
|
||||||
|
action()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WaitForReplacement waits for a file to exist on disk and then starts a watch
|
||||||
|
// for the file
|
||||||
|
func WaitForReplacement(filename string, op fsnotify.Op, watcher *fsnotify.Watcher) {
|
||||||
|
const sleepInterval = 50 * time.Millisecond
|
||||||
|
|
||||||
|
// Avoid a race when fsnofity.Remove is preceded by fsnotify.Chmod.
|
||||||
|
if op&fsnotify.Chmod != 0 {
|
||||||
|
time.Sleep(sleepInterval)
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
if _, err := os.Stat(filename); err == nil {
|
||||||
|
if err := watcher.Add(filename); err == nil {
|
||||||
|
logger.Printf("watching resumed for '%s'", filename)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
time.Sleep(sleepInterval)
|
||||||
|
}
|
||||||
|
}
|
@@ -9,6 +9,7 @@ import (
|
|||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
||||||
|
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/watcher"
|
||||||
)
|
)
|
||||||
|
|
||||||
// UserMap holds information from the authenticated emails file
|
// UserMap holds information from the authenticated emails file
|
||||||
@@ -26,7 +27,7 @@ func NewUserMap(usersFile string, done <-chan bool, onUpdate func()) *UserMap {
|
|||||||
atomic.StorePointer(&um.m, unsafe.Pointer(&m)) // #nosec G103
|
atomic.StorePointer(&um.m, unsafe.Pointer(&m)) // #nosec G103
|
||||||
if usersFile != "" {
|
if usersFile != "" {
|
||||||
logger.Printf("using authenticated emails file %s", usersFile)
|
logger.Printf("using authenticated emails file %s", usersFile)
|
||||||
WatchForUpdates(usersFile, done, func() {
|
watcher.WatchFileForUpdates(usersFile, done, func() {
|
||||||
um.LoadAuthenticatedEmailsFile()
|
um.LoadAuthenticatedEmailsFile()
|
||||||
onUpdate()
|
onUpdate()
|
||||||
})
|
})
|
||||||
|
81
watcher.go
81
watcher.go
@@ -1,81 +0,0 @@
|
|||||||
//go:build go1.3 && !plan9 && !solaris
|
|
||||||
// +build go1.3,!plan9,!solaris
|
|
||||||
|
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/fsnotify/fsnotify"
|
|
||||||
|
|
||||||
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
|
||||||
)
|
|
||||||
|
|
||||||
// WaitForReplacement waits for a file to exist on disk and then starts a watch
|
|
||||||
// for the file
|
|
||||||
func WaitForReplacement(filename string, op fsnotify.Op,
|
|
||||||
watcher *fsnotify.Watcher) {
|
|
||||||
const sleepInterval = 50 * time.Millisecond
|
|
||||||
|
|
||||||
// Avoid a race when fsnofity.Remove is preceded by fsnotify.Chmod.
|
|
||||||
if op&fsnotify.Chmod != 0 {
|
|
||||||
time.Sleep(sleepInterval)
|
|
||||||
}
|
|
||||||
for {
|
|
||||||
if _, err := os.Stat(filename); err == nil {
|
|
||||||
if err := watcher.Add(filename); err == nil {
|
|
||||||
logger.Printf("watching resumed for %s", filename)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
time.Sleep(sleepInterval)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WatchForUpdates performs an action every time a file on disk is updated
|
|
||||||
func WatchForUpdates(filename string, done <-chan bool, action func()) {
|
|
||||||
filename = filepath.Clean(filename)
|
|
||||||
watcher, err := fsnotify.NewWatcher()
|
|
||||||
if err != nil {
|
|
||||||
logger.Fatal("failed to create watcher for ", filename, ": ", err)
|
|
||||||
}
|
|
||||||
go func() {
|
|
||||||
defer func(w *fsnotify.Watcher) {
|
|
||||||
cerr := w.Close()
|
|
||||||
if cerr != nil {
|
|
||||||
logger.Fatalf("error closing watcher: %v", err)
|
|
||||||
}
|
|
||||||
}(watcher)
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-done:
|
|
||||||
logger.Printf("Shutting down watcher for: %s", filename)
|
|
||||||
return
|
|
||||||
case event := <-watcher.Events:
|
|
||||||
// On Arch Linux, it appears Chmod events precede Remove events,
|
|
||||||
// which causes a race between action() and the coming Remove event.
|
|
||||||
// If the Remove wins, the action() (which calls
|
|
||||||
// UserMap.LoadAuthenticatedEmailsFile()) crashes when the file
|
|
||||||
// can't be opened.
|
|
||||||
if event.Op&(fsnotify.Remove|fsnotify.Rename|fsnotify.Chmod) != 0 {
|
|
||||||
logger.Printf("watching interrupted on event: %s", event)
|
|
||||||
err = watcher.Remove(filename)
|
|
||||||
if err != nil {
|
|
||||||
logger.Printf("error removing watcher on %s: %v", filename, err)
|
|
||||||
}
|
|
||||||
WaitForReplacement(filename, event.Op, watcher)
|
|
||||||
}
|
|
||||||
logger.Printf("reloading after event: %s", event)
|
|
||||||
action()
|
|
||||||
case err = <-watcher.Errors:
|
|
||||||
logger.Errorf("error watching %s: %s", filename, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
if err = watcher.Add(filename); err != nil {
|
|
||||||
logger.Fatal("failed to add ", filename, " to watcher: ", err)
|
|
||||||
}
|
|
||||||
logger.Printf("watching %s for updates", filename)
|
|
||||||
}
|
|
@@ -1,11 +0,0 @@
|
|||||||
//go:build !go1.3 || plan9 || solaris
|
|
||||||
// +build !go1.3 plan9 solaris
|
|
||||||
|
|
||||||
package main
|
|
||||||
|
|
||||||
import "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
|
|
||||||
|
|
||||||
func WatchForUpdates(filename string, done <-chan bool, action func()) {
|
|
||||||
logger.Errorf("file watching not implemented on this platform")
|
|
||||||
go func() { <-done }()
|
|
||||||
}
|
|
Reference in New Issue
Block a user