1
0
mirror of https://github.com/volatiletech/authboss.git synced 2025-01-24 05:17:10 +02:00
authboss/client_state.go
Aaron L 2137c827d3 Fix tests back up again.
- Remove a test that was obsoleted by optimizations. Not 100% sure this
  is correct, but it seems like if nothing has changed since the
  previous session/cookie read then we shouldn't need to write any new
  headers for them. This is especially true in the typical "I use
  cookies for everything" use case, but may not be true of other use
  cases... Remains to be seen. Since they're optimizations they should
  be able to removed "safely" later.
2018-02-14 15:16:44 -08:00

322 lines
9.2 KiB
Go

package authboss
import (
"context"
"fmt"
"net/http"
)
const (
// SessionKey is the primarily used key by authboss.
SessionKey = "uid"
// SessionHalfAuthKey is used for sessions that have been authenticated by
// the remember module. This serves as a way to force full authentication
// by denying half-authed users acccess to sensitive areas.
SessionHalfAuthKey = "halfauth"
// SessionLastAction is the session key to retrieve the last action of a user.
SessionLastAction = "last_action"
// SessionOAuth2State is the xsrf protection key for oauth.
SessionOAuth2State = "oauth2_state"
// SessionOAuth2Params is the additional settings for oauth like redirection/remember.
SessionOAuth2Params = "oauth2_params"
// CookieRemember is used for cookies and form input names.
CookieRemember = "rm"
// FlashSuccessKey is used for storing sucess flash messages on the session
FlashSuccessKey = "flash_success"
// FlashErrorKey is used for storing sucess flash messages on the session
FlashErrorKey = "flash_error"
)
// ClientStateEventKind is an enum.
type ClientStateEventKind int
// ClientStateEvent kinds
const (
ClientStateEventPut ClientStateEventKind = iota
ClientStateEventDel
)
// ClientStateEvent are the different events that can be recorded during
type ClientStateEvent struct {
Kind ClientStateEventKind
Key string
Value string
}
// ClientStateReadWriter is used to create a cookie storer from an http request.
// Keep in mind security considerations for your implementation, Secure,
// HTTP-Only, etc flags.
//
// There's two major uses for this. To create session storage, and remember me
// cookies.
type ClientStateReadWriter interface {
// ReadState should return a map like structure allowing it to look up
// any values in the current session, or any cookie in the request
ReadState(http.ResponseWriter, *http.Request) (ClientState, error)
// WriteState can sometimes be called with a nil ClientState in the event
// that no ClientState was recovered from the request context.
WriteState(http.ResponseWriter, ClientState, []ClientStateEvent) error
}
// UnderlyingResponseWriter retrieves the response
// writer underneath the current one. This allows us
// to wrap and later discover the particular one that we want.
// Keep in mind this should not be used to call the normal methods
// of a responsewriter, just additional ones particular to that type
// because it's possible to introduce subtle bugs otherwise.
type UnderlyingResponseWriter interface {
UnderlyingResponseWriter() http.ResponseWriter
}
// ClientState represents the client's current state and can answer queries
// about it.
type ClientState interface {
Get(key string) (string, bool)
}
// ClientStateResponseWriter is used to write out the client state at the last
// moment before the response code is written.
type ClientStateResponseWriter struct {
http.ResponseWriter
cookieState ClientStateReadWriter
sessionState ClientStateReadWriter
hasWritten bool
ctx context.Context
sessionStateEvents []ClientStateEvent
cookieStateEvents []ClientStateEvent
}
// LoadClientStateMiddleware wraps all requests with the ClientStateResponseWriter
// as well as loading the current client state into the context for use.
func (a *Authboss) LoadClientStateMiddleware(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
request, err := a.LoadClientState(w, r)
if err != nil {
panic(fmt.Sprintf("failed to load client state: %+v", err))
}
writer := a.NewResponse(w, request)
h.ServeHTTP(writer, request)
})
}
// NewResponse wraps the ResponseWriter with a ClientStateResponseWriter
func (a *Authboss) NewResponse(w http.ResponseWriter, r *http.Request) *ClientStateResponseWriter {
return &ClientStateResponseWriter{
ResponseWriter: w,
cookieState: a.Config.Storage.CookieState,
sessionState: a.Config.Storage.SessionState,
ctx: r.Context(),
}
}
// LoadClientState loads the state from sessions and cookies into the request context
func (a *Authboss) LoadClientState(w http.ResponseWriter, r *http.Request) (*http.Request, error) {
if a.Storage.SessionState != nil {
state, err := a.Storage.SessionState.ReadState(w, r)
if err != nil {
return nil, err
} else if state == nil {
return r, nil
}
ctx := context.WithValue(r.Context(), CTXKeySessionState, state)
r = r.WithContext(ctx)
}
if a.Storage.CookieState != nil {
state, err := a.Storage.CookieState.ReadState(w, r)
if err != nil {
return nil, err
} else if state == nil {
return r, nil
}
ctx := context.WithValue(r.Context(), CTXKeyCookieState, state)
r = r.WithContext(ctx)
}
return r, nil
}
// MustClientStateResponseWriter tries to find a csrw inside the response
// writer by using the UnderlyingResponseWriter interface.
func MustClientStateResponseWriter(w http.ResponseWriter) *ClientStateResponseWriter {
for {
if c, ok := w.(*ClientStateResponseWriter); ok {
return c
}
if u, ok := w.(UnderlyingResponseWriter); ok {
w = u.UnderlyingResponseWriter()
continue
}
panic(fmt.Sprintf("failed to find a ClientStateResponseWriter or UnderlyingResponseWriter in: %T", w))
}
}
// WriteHeader writes the header, but in order to handle errors from the
// underlying ClientStateReadWriter, it has to panic.
func (c *ClientStateResponseWriter) WriteHeader(code int) {
if !c.hasWritten {
if err := c.putClientState(); err != nil {
panic(err)
}
}
c.ResponseWriter.WriteHeader(code)
}
// Header retrieves the underlying headers
func (c ClientStateResponseWriter) Header() http.Header {
return c.ResponseWriter.Header()
}
// Write ensures that the client state is written before any writes
// to the body occur (before header flush to http client)
func (c *ClientStateResponseWriter) Write(b []byte) (int, error) {
if !c.hasWritten {
if err := c.putClientState(); err != nil {
return 0, err
}
}
return c.ResponseWriter.Write(b)
}
// UnderlyingResponseWriter for this instance
func (c *ClientStateResponseWriter) UnderlyingResponseWriter() http.ResponseWriter {
return c.ResponseWriter
}
func (c *ClientStateResponseWriter) putClientState() error {
if c.hasWritten {
panic("should not call putClientState twice")
}
c.hasWritten = true
if len(c.cookieStateEvents) == 0 && len(c.sessionStateEvents) == 0 {
return nil
}
if c.sessionState != nil && len(c.sessionStateEvents) > 0 {
sessionStateIntf := c.ctx.Value(CTXKeySessionState)
var session ClientState
if sessionStateIntf != nil {
session = sessionStateIntf.(ClientState)
}
err := c.sessionState.WriteState(c, session, c.sessionStateEvents)
if err != nil {
return err
}
}
if c.cookieState != nil && len(c.cookieStateEvents) > 0 {
cookieStateIntf := c.ctx.Value(CTXKeyCookieState)
var cookie ClientState
if cookieStateIntf != nil {
cookie = cookieStateIntf.(ClientState)
}
err := c.cookieState.WriteState(c, cookie, c.cookieStateEvents)
if err != nil {
return err
}
}
return nil
}
// PutSession puts a value into the session
func PutSession(w http.ResponseWriter, key, val string) {
putState(w, CTXKeySessionState, key, val)
}
// DelSession deletes a key-value from the session.
func DelSession(w http.ResponseWriter, key string) {
delState(w, CTXKeySessionState, key)
}
// GetSession fetches a value from the session
func GetSession(r *http.Request, key string) (string, bool) {
return getState(r, CTXKeySessionState, key)
}
// PutCookie puts a value into the session
func PutCookie(w http.ResponseWriter, key, val string) {
putState(w, CTXKeyCookieState, key, val)
}
// DelCookie deletes a key-value from the session.
func DelCookie(w http.ResponseWriter, key string) {
delState(w, CTXKeyCookieState, key)
}
// GetCookie fetches a value from the session
func GetCookie(r *http.Request, key string) (string, bool) {
return getState(r, CTXKeyCookieState, key)
}
func putState(w http.ResponseWriter, CTXKey contextKey, key, val string) {
setState(w, CTXKey, ClientStateEventPut, key, val)
}
func delState(w http.ResponseWriter, CTXKey contextKey, key string) {
setState(w, CTXKey, ClientStateEventDel, key, "")
}
func setState(w http.ResponseWriter, CTXKey contextKey, op ClientStateEventKind, key, val string) {
csrw := MustClientStateResponseWriter(w)
ev := ClientStateEvent{
Kind: op,
Key: key,
}
if op == ClientStateEventPut {
ev.Value = val
}
switch CTXKey {
case CTXKeySessionState:
csrw.sessionStateEvents = append(csrw.sessionStateEvents, ev)
case CTXKeyCookieState:
csrw.cookieStateEvents = append(csrw.cookieStateEvents, ev)
}
}
func getState(r *http.Request, CTXKey contextKey, key string) (string, bool) {
val := r.Context().Value(CTXKey)
if val == nil {
return "", false
}
state := val.(ClientState)
return state.Get(key)
}
// FlashSuccess returns FlashSuccessKey from the session and removes it.
func FlashSuccess(w http.ResponseWriter, r *http.Request) string {
str, ok := GetSession(r, FlashSuccessKey)
if !ok {
return ""
}
DelSession(w, FlashSuccessKey)
return str
}
// FlashError returns FlashError from the session and removes it.
func FlashError(w http.ResponseWriter, r *http.Request) string {
str, ok := GetSession(r, FlashErrorKey)
if !ok {
return ""
}
DelSession(w, FlashErrorKey)
return str
}