diff --git a/config.go b/config.go index 1da34b0..658a41d 100644 --- a/config.go +++ b/config.go @@ -1,7 +1,6 @@ package authboss import ( - "io" "time" ) @@ -108,8 +107,10 @@ type Config struct { // Mailer is the mailer being used to send e-mails out via smtp Mailer Mailer - // LogWriter is written to when errors occur - LogWriter io.Writer + // Logger implies just a few log levels for use, can optionally + // also implement the ContextLogger to be able to upgrade to a + // request specific logger. + Logger Logger } } diff --git a/defaults/error_handler.go b/defaults/error_handler.go new file mode 100644 index 0000000..bd6e4fa --- /dev/null +++ b/defaults/error_handler.go @@ -0,0 +1,39 @@ +package defaults + +import ( + "fmt" + "io" + "net/http" +) + +// ErrorHandler wraps http handlers with errors with itself +// to provide error handling. +// +// The pieces provided to this struct must be thread-safe +// since they will be handed to many pointers to themselves. +type ErrorHandler struct { + LogWriter io.Writer +} + +// Wrap an http handler with an error +func (e ErrorHandler) Wrap(handler func(w http.ResponseWriter, r *http.Request) error) http.Handler { + return errorHandler{ + Handler: handler, + LogWriter: e.LogWriter, + } +} + +type errorHandler struct { + Handler func(w http.ResponseWriter, r *http.Request) error + LogWriter io.Writer +} + +// ServeHTTP handles errors +func (e errorHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + err := e.Handler(w, r) + if err == nil { + return + } + + fmt.Fprintf(e.LogWriter, "error at %s: %+v", r.URL.String(), err) +} diff --git a/defaults/error_handler_test.go b/defaults/error_handler_test.go new file mode 100644 index 0000000..c184559 --- /dev/null +++ b/defaults/error_handler_test.go @@ -0,0 +1,31 @@ +package defaults + +import ( + "bytes" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/pkg/errors" +) + +func TestErrorHandler(t *testing.T) { + t.Parallel() + + b := &bytes.Buffer{} + + eh := ErrorHandler{LogWriter: b} + + handler := eh.Wrap(func(w http.ResponseWriter, r *http.Request) error { + return errors.New("error occurred") + }) + // Assert that it's the right type + var _ http.Handler = handler + + handler.ServeHTTP(nil, httptest.NewRequest("GET", "/target", nil)) + + if !strings.Contains(b.String(), "error at /target: error occurred") { + t.Error("output was wrong:", b.String()) + } +} diff --git a/defaults/logger.go b/defaults/logger.go new file mode 100644 index 0000000..dfb2173 --- /dev/null +++ b/defaults/logger.go @@ -0,0 +1,29 @@ +package defaults + +import ( + "fmt" + "io" + "time" +) + +// Logger writes exactly once for each log line to underlying io.Writer +// that's passed in and ends each message with a newline. +// It has RFC3339 as a date format, and emits a log level. +type Logger struct { + Writer io.Writer +} + +// NewLogger creates a new logger from an io.Writer +func NewLogger(writer io.Writer) Logger { + return Logger{Writer: writer} +} + +// Info logs go here +func (l Logger) Info(s string) { + fmt.Fprintf(l.Writer, "%s [INFO]: %s\n", time.Now().UTC().Format(time.RFC3339), s) +} + +// Error logs go here +func (l Logger) Error(s string) { + fmt.Fprintf(l.Writer, "%s [EROR]: %s\n", time.Now().UTC().Format(time.RFC3339), s) +} diff --git a/defaults/logger_test.go b/defaults/logger_test.go new file mode 100644 index 0000000..9925c80 --- /dev/null +++ b/defaults/logger_test.go @@ -0,0 +1,26 @@ +package defaults + +import ( + "bytes" + "regexp" + "testing" + + "github.com/davecgh/go-spew/spew" +) + +func TestLogger(t *testing.T) { + t.Parallel() + + b := &bytes.Buffer{} + logger := NewLogger(b) + + logger.Info("hello") + logger.Error("world") + + rgxTimestamp := `[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}Z` + rgx := regexp.MustCompile(rgxTimestamp + ` \[INFO\]: hello\n` + rgxTimestamp + ` \[EROR\]: world\n`) + if !rgx.Match(b.Bytes()) { + t.Errorf("output from log file did not match regex:\n%s\n%v", b.String(), b.Bytes()) + spew.Dump(b.Bytes()) + } +} diff --git a/errors.go b/errors.go index 798eb64..58d1d1e 100644 --- a/errors.go +++ b/errors.go @@ -1,25 +1,11 @@ package authboss -import "fmt" +import ( + "net/http" +) -// ClientDataErr represents a failure to retrieve a critical -// piece of client information such as a cookie or session value. -type ClientDataErr struct { - Name string -} - -func (c ClientDataErr) Error() string { - return fmt.Sprintf("Failed to retrieve client attribute: %s", c.Name) -} - -// RenderErr represents an error that occured during rendering -// of a template. -type RenderErr struct { - TemplateName string - Data interface{} - Err error -} - -func (r RenderErr) Error() string { - return fmt.Sprintf("error rendering response %q: %v, data: %#v", r.TemplateName, r.Err, r.Data) +// ErrorHandler allows routing to http.HandlerFunc's that additionally +// return an error for a higher level error handling mechanism. +type ErrorHandler interface { + Wrap(func(w http.ResponseWriter, r *http.Request) error) http.Handler } diff --git a/errors_test.go b/errors_test.go deleted file mode 100644 index 8a5b7f9..0000000 --- a/errors_test.go +++ /dev/null @@ -1,27 +0,0 @@ -package authboss - -import ( - "testing" - - "github.com/pkg/errors" -) - -func TestClientDataErr(t *testing.T) { - t.Parallel() - - estr := "Failed to retrieve client attribute: lol" - err := ClientDataErr{"lol"} - if str := err.Error(); str != estr { - t.Error("Error was wrong:", str) - } -} - -func TestRenderErr(t *testing.T) { - t.Parallel() - - estr := `error rendering response "lol": cause, data: authboss.HTMLData{"a":5}` - err := RenderErr{"lol", NewHTMLData("a", 5), errors.New("cause")} - if str := err.Error(); str != estr { - t.Error("Error was wrong:", str) - } -} diff --git a/logger.go b/logger.go index f060c65..52c0f83 100644 --- a/logger.go +++ b/logger.go @@ -1,26 +1,36 @@ package authboss import ( - "io" - "log" - "net/http" - "os" + "context" ) -// DefaultLogger is a basic logger. -type DefaultLogger log.Logger - -// LogWriteMaker is used to create a logger from an http request. -// TODO(aarondl): decide what to do with this, should we keep it? -type LogWriteMaker func(http.ResponseWriter, *http.Request) io.Writer - -// NewDefaultLogger creates a logger to stdout. -func NewDefaultLogger() *DefaultLogger { - return ((*DefaultLogger)(log.New(os.Stdout, "", log.LstdFlags))) +// Logger is the basic logging structure that's required +type Logger interface { + Info(string) + Error(string) } -// Write writes to the internal logger. -func (d *DefaultLogger) Write(b []byte) (int, error) { - ((*log.Logger)(d)).Printf("%s", b) - return len(b), nil +// ContextLogger creates a logger from a request context +type ContextLogger interface { + FromContext(ctx context.Context) Logger +} + +// Logger returns an appopriate logger for the context: +// If context is nil, then it simply returns the configured +// logger. +// If context is not nil, then it will attempt to upgrade +// the configured logger to a ContextLogger, and create +// a context-specific logger for use. +func (a *Authboss) Logger(ctx context.Context) Logger { + logger := a.Config.Core.Logger + if ctx == nil { + return logger + } + + ctxLogger, ok := logger.(ContextLogger) + if !ok { + return logger + } + + return ctxLogger.FromContext(ctx) } diff --git a/logger_test.go b/logger_test.go index 0e29951..6432d63 100644 --- a/logger_test.go +++ b/logger_test.go @@ -1,29 +1,35 @@ package authboss import ( - "bytes" - "io" - "log" - "strings" + "context" "testing" ) -func TestDefaultLogger(t *testing.T) { +type ( + testLogger struct{} + testCtxLogger struct{} +) + +func (t testLogger) Info(string) {} +func (t testLogger) Error(string) {} + +func (t testLogger) FromContext(ctx context.Context) Logger { return testCtxLogger{} } + +func (t testCtxLogger) Info(string) {} +func (t testCtxLogger) Error(string) {} + +func TestLogger(t *testing.T) { t.Parallel() - logger := NewDefaultLogger() - if logger == nil { - t.Error("Logger was not created.") - } -} - -func TestDefaultLoggerOutput(t *testing.T) { - t.Parallel() - - buffer := &bytes.Buffer{} - logger := (*DefaultLogger)(log.New(buffer, "", log.LstdFlags)) - io.WriteString(logger, "hello world") - if s := buffer.String(); !strings.HasSuffix(s, "hello world\n") { - t.Error("Output was wrong:", s) + ab := New() + logger := testLogger{} + ab.Config.Core.Logger = logger + + if logger != ab.Logger(nil).(testLogger) { + t.Error("wanted our logger back") + } + + if _, ok := ab.Logger(context.Background()).(testCtxLogger); !ok { + t.Error("wanted ctx logger back") } }