diff --git a/auth/auth.go b/auth/auth.go index bead0a3..23b481b 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -33,7 +33,7 @@ func (a *Auth) Init(ab *authboss.Authboss) (err error) { } var logoutRouteMethod func(string, http.Handler) - switch a.Authboss.Config.Modules.LogoutMethod { + switch a.Authboss.Config.Modules.AuthLogoutMethod { case "GET": logoutRouteMethod = a.Authboss.Config.Core.Router.Get case "POST": @@ -41,7 +41,7 @@ func (a *Auth) Init(ab *authboss.Authboss) (err error) { case "DELETE": logoutRouteMethod = a.Authboss.Config.Core.Router.Delete default: - return errors.Errorf("auth wants to register a logout route but is given an invalid method: %s", a.Authboss.Config.Modules.LogoutMethod) + return errors.Errorf("auth wants to register a logout route but is given an invalid method: %s", a.Authboss.Config.Modules.AuthLogoutMethod) } a.Authboss.Config.Core.Router.Get("/login", a.Authboss.Core.ErrorHandler.Wrap(a.LoginGet)) diff --git a/client_state.go b/client_state.go index a487595..f4931c1 100644 --- a/client_state.go +++ b/client_state.go @@ -52,7 +52,11 @@ type ClientStateEvent struct { // There's two major uses for this. To create session storage, and remember me // cookies. type ClientStateReadWriter interface { + // ReadState should return a map like structure allowing it to look up + // any values in the current session, or any cookie in the request ReadState(http.ResponseWriter, *http.Request) (ClientState, error) + // WriteState can sometimes be called with a nil ClientState in the event + // that no ClientState was recovered from the request context. WriteState(http.ResponseWriter, ClientState, []ClientStateEvent) error } @@ -75,20 +79,37 @@ type ClientState interface { // ClientStateResponseWriter is used to write out the client state at the last // moment before the response code is written. type ClientStateResponseWriter struct { - ab *Authboss http.ResponseWriter + cookieState ClientStateReadWriter + sessionState ClientStateReadWriter + hasWritten bool ctx context.Context sessionStateEvents []ClientStateEvent cookieStateEvents []ClientStateEvent } +// ClientStateMiddleware wraps all requests with the ClientStateResponseWriter +func (a *Authboss) ClientStateMiddleware(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + request, err := a.LoadClientState(w, r) + if err != nil { + panic(fmt.Sprintf("failed to load client state: %+v", err)) + } + + writer := a.NewResponse(w, request) + + h.ServeHTTP(writer, request) + }) +} + // NewResponse wraps the ResponseWriter with a ClientStateResponseWriter func (a *Authboss) NewResponse(w http.ResponseWriter, r *http.Request) *ClientStateResponseWriter { return &ClientStateResponseWriter{ - ab: a, ResponseWriter: w, + cookieState: a.Config.Storage.CookieState, + sessionState: a.Config.Storage.SessionState, ctx: r.Context(), } } @@ -174,24 +195,32 @@ func (c *ClientStateResponseWriter) putClientState() error { } c.hasWritten = true - sessionStateIntf := c.ctx.Value(ctxKeySessionState) - cookieStateIntf := c.ctx.Value(ctxKeyCookieState) - var session, cookie ClientState - if sessionStateIntf != nil { - session = sessionStateIntf.(ClientState) - } - if cookieStateIntf != nil { - cookie = cookieStateIntf.(ClientState) + if len(c.cookieStateEvents) == 0 && len(c.sessionStateEvents) == 0 { + return nil } - if c.ab.Storage.SessionState != nil { - err := c.ab.Storage.SessionState.WriteState(c, session, c.sessionStateEvents) + if c.sessionState != nil && len(c.sessionStateEvents) > 0 { + sessionStateIntf := c.ctx.Value(ctxKeySessionState) + + var session ClientState + if sessionStateIntf != nil { + session = sessionStateIntf.(ClientState) + } + + err := c.sessionState.WriteState(c, session, c.sessionStateEvents) if err != nil { return err } } - if c.ab.Storage.CookieState != nil { - err := c.ab.Storage.CookieState.WriteState(c, cookie, c.cookieStateEvents) + if c.cookieState != nil && len(c.cookieStateEvents) > 0 { + cookieStateIntf := c.ctx.Value(ctxKeyCookieState) + + var cookie ClientState + if cookieStateIntf != nil { + cookie = cookieStateIntf.(ClientState) + } + + err := c.cookieState.WriteState(c, cookie, c.cookieStateEvents) if err != nil { return err } diff --git a/config.go b/config.go index 52be36a..0488204 100644 --- a/config.go +++ b/config.go @@ -2,6 +2,8 @@ package authboss import ( "time" + + "golang.org/x/crypto/bcrypt" ) // Config holds all the configuration for both authboss and it's modules. @@ -29,31 +31,31 @@ type Config struct { // BCryptCost is the cost of the bcrypt password hashing function. AuthBCryptCost int - // LogoutMethod is the method the logout route should use (default should be DELETE) - LogoutMethod string - - // OAuth2Providers lists all providers that can be used. See - // OAuthProvider documentation for more details. - OAuth2Providers map[string]OAuth2Provider - - // PreserveFields are fields used with registration that are to be rendered when - // post fails. - PreserveFields []string + // AuthLogoutMethod is the method the logout route should use (default should be DELETE) + AuthLogoutMethod string // ExpireAfter controls the time an account is idle before being logged out // by the ExpireMiddleware. ExpireAfter time.Duration - // RecoverTokenDuration controls how long a token sent via email for password - // recovery is valid for. - RecoverTokenDuration time.Duration - // LockAfter this many tries. LockAfter int // LockWindow is the waiting time before the number of attemps are reset. LockWindow time.Duration // LockDuration is how long an account is locked for. LockDuration time.Duration + + // RegisterPreserveFields are fields used with registration that are to be rendered when + // post fails. + RegisterPreserveFields []string + + // RecoverTokenDuration controls how long a token sent via email for password + // recovery is valid for. + RecoverTokenDuration time.Duration + + // OAuth2Providers lists all providers that can be used. See + // OAuthProvider documentation for more details. + OAuth2Providers map[string]OAuth2Provider } Mail struct { @@ -117,49 +119,18 @@ type Config struct { // Defaults sets the configuration's default values. func (c *Config) Defaults() { - /*c.MountPath = "/" - c.ViewsPath = "./" - c.RootURL = "http://localhost:8080" - c.BCryptCost = bcrypt.DefaultCost + c.Paths.Mount = "/" + c.Paths.RootURL = "http://localhost:8080" + c.Paths.AuthLoginOK = "/" + c.Paths.AuthLogoutOK = "/" + c.Paths.RecoverOK = "/" + c.Paths.RegisterOK = "/" - c.PrimaryID = StoreEmail - - c.AuthLoginOKPath = "/" - c.AuthLoginFailPath = "/" - c.AuthLogoutOKPath = "/" - - c.RecoverOKPath = "/" - c.RecoverTokenDuration = time.Duration(24) * time.Hour - - c.RegisterOKPath = "/" - - c.Policies = []Validator{ - Rules{ - FieldName: "email", - Required: true, - AllowWhitespace: false, - }, - Rules{ - FieldName: "password", - Required: true, - MinLength: 4, - MaxLength: 8, - AllowWhitespace: false, - }, - } - c.ConfirmFields = []string{ - StorePassword, ConfirmPrefix + StorePassword, - } - - c.ExpireAfter = 60 * time.Minute - - c.LockAfter = 3 - c.LockWindow = 5 * time.Minute - c.LockDuration = 5 * time.Hour - - c.LogWriter = NewDefaultLogger() - c.Mailer = LogMailer(ioutil.Discard) - c.ContextProvider = func(req *http.Request) context.Context { - return context.TODO() - }*/ + c.Modules.AuthBCryptCost = bcrypt.DefaultCost + c.Modules.AuthLogoutMethod = "DELETE" + c.Modules.ExpireAfter = 60 * time.Minute + c.Modules.LockAfter = 3 + c.Modules.LockWindow = 5 * time.Minute + c.Modules.LockDuration = 5 * time.Hour + c.Modules.RecoverTokenDuration = time.Duration(24) * time.Hour }