mirror of
https://github.com/volatiletech/authboss.git
synced 2025-01-10 04:17:59 +02:00
2137c827d3
- Remove a test that was obsoleted by optimizations. Not 100% sure this is correct, but it seems like if nothing has changed since the previous session/cookie read then we shouldn't need to write any new headers for them. This is especially true in the typical "I use cookies for everything" use case, but may not be true of other use cases... Remains to be seen. Since they're optimizations they should be able to removed "safely" later.
322 lines
9.2 KiB
Go
322 lines
9.2 KiB
Go
package authboss
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net/http"
|
|
)
|
|
|
|
const (
|
|
// SessionKey is the primarily used key by authboss.
|
|
SessionKey = "uid"
|
|
// SessionHalfAuthKey is used for sessions that have been authenticated by
|
|
// the remember module. This serves as a way to force full authentication
|
|
// by denying half-authed users acccess to sensitive areas.
|
|
SessionHalfAuthKey = "halfauth"
|
|
// SessionLastAction is the session key to retrieve the last action of a user.
|
|
SessionLastAction = "last_action"
|
|
// SessionOAuth2State is the xsrf protection key for oauth.
|
|
SessionOAuth2State = "oauth2_state"
|
|
// SessionOAuth2Params is the additional settings for oauth like redirection/remember.
|
|
SessionOAuth2Params = "oauth2_params"
|
|
|
|
// CookieRemember is used for cookies and form input names.
|
|
CookieRemember = "rm"
|
|
|
|
// FlashSuccessKey is used for storing sucess flash messages on the session
|
|
FlashSuccessKey = "flash_success"
|
|
// FlashErrorKey is used for storing sucess flash messages on the session
|
|
FlashErrorKey = "flash_error"
|
|
)
|
|
|
|
// ClientStateEventKind is an enum.
|
|
type ClientStateEventKind int
|
|
|
|
// ClientStateEvent kinds
|
|
const (
|
|
ClientStateEventPut ClientStateEventKind = iota
|
|
ClientStateEventDel
|
|
)
|
|
|
|
// ClientStateEvent are the different events that can be recorded during
|
|
type ClientStateEvent struct {
|
|
Kind ClientStateEventKind
|
|
Key string
|
|
Value string
|
|
}
|
|
|
|
// ClientStateReadWriter is used to create a cookie storer from an http request.
|
|
// Keep in mind security considerations for your implementation, Secure,
|
|
// HTTP-Only, etc flags.
|
|
//
|
|
// There's two major uses for this. To create session storage, and remember me
|
|
// cookies.
|
|
type ClientStateReadWriter interface {
|
|
// ReadState should return a map like structure allowing it to look up
|
|
// any values in the current session, or any cookie in the request
|
|
ReadState(http.ResponseWriter, *http.Request) (ClientState, error)
|
|
// WriteState can sometimes be called with a nil ClientState in the event
|
|
// that no ClientState was recovered from the request context.
|
|
WriteState(http.ResponseWriter, ClientState, []ClientStateEvent) error
|
|
}
|
|
|
|
// UnderlyingResponseWriter retrieves the response
|
|
// writer underneath the current one. This allows us
|
|
// to wrap and later discover the particular one that we want.
|
|
// Keep in mind this should not be used to call the normal methods
|
|
// of a responsewriter, just additional ones particular to that type
|
|
// because it's possible to introduce subtle bugs otherwise.
|
|
type UnderlyingResponseWriter interface {
|
|
UnderlyingResponseWriter() http.ResponseWriter
|
|
}
|
|
|
|
// ClientState represents the client's current state and can answer queries
|
|
// about it.
|
|
type ClientState interface {
|
|
Get(key string) (string, bool)
|
|
}
|
|
|
|
// ClientStateResponseWriter is used to write out the client state at the last
|
|
// moment before the response code is written.
|
|
type ClientStateResponseWriter struct {
|
|
http.ResponseWriter
|
|
|
|
cookieState ClientStateReadWriter
|
|
sessionState ClientStateReadWriter
|
|
|
|
hasWritten bool
|
|
ctx context.Context
|
|
sessionStateEvents []ClientStateEvent
|
|
cookieStateEvents []ClientStateEvent
|
|
}
|
|
|
|
// LoadClientStateMiddleware wraps all requests with the ClientStateResponseWriter
|
|
// as well as loading the current client state into the context for use.
|
|
func (a *Authboss) LoadClientStateMiddleware(h http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
request, err := a.LoadClientState(w, r)
|
|
if err != nil {
|
|
panic(fmt.Sprintf("failed to load client state: %+v", err))
|
|
}
|
|
|
|
writer := a.NewResponse(w, request)
|
|
|
|
h.ServeHTTP(writer, request)
|
|
})
|
|
}
|
|
|
|
// NewResponse wraps the ResponseWriter with a ClientStateResponseWriter
|
|
func (a *Authboss) NewResponse(w http.ResponseWriter, r *http.Request) *ClientStateResponseWriter {
|
|
return &ClientStateResponseWriter{
|
|
ResponseWriter: w,
|
|
cookieState: a.Config.Storage.CookieState,
|
|
sessionState: a.Config.Storage.SessionState,
|
|
ctx: r.Context(),
|
|
}
|
|
}
|
|
|
|
// LoadClientState loads the state from sessions and cookies into the request context
|
|
func (a *Authboss) LoadClientState(w http.ResponseWriter, r *http.Request) (*http.Request, error) {
|
|
if a.Storage.SessionState != nil {
|
|
state, err := a.Storage.SessionState.ReadState(w, r)
|
|
if err != nil {
|
|
return nil, err
|
|
} else if state == nil {
|
|
return r, nil
|
|
}
|
|
|
|
ctx := context.WithValue(r.Context(), CTXKeySessionState, state)
|
|
r = r.WithContext(ctx)
|
|
}
|
|
if a.Storage.CookieState != nil {
|
|
state, err := a.Storage.CookieState.ReadState(w, r)
|
|
if err != nil {
|
|
return nil, err
|
|
} else if state == nil {
|
|
return r, nil
|
|
}
|
|
ctx := context.WithValue(r.Context(), CTXKeyCookieState, state)
|
|
r = r.WithContext(ctx)
|
|
}
|
|
|
|
return r, nil
|
|
}
|
|
|
|
// MustClientStateResponseWriter tries to find a csrw inside the response
|
|
// writer by using the UnderlyingResponseWriter interface.
|
|
func MustClientStateResponseWriter(w http.ResponseWriter) *ClientStateResponseWriter {
|
|
for {
|
|
if c, ok := w.(*ClientStateResponseWriter); ok {
|
|
return c
|
|
}
|
|
|
|
if u, ok := w.(UnderlyingResponseWriter); ok {
|
|
w = u.UnderlyingResponseWriter()
|
|
continue
|
|
}
|
|
|
|
panic(fmt.Sprintf("failed to find a ClientStateResponseWriter or UnderlyingResponseWriter in: %T", w))
|
|
}
|
|
}
|
|
|
|
// WriteHeader writes the header, but in order to handle errors from the
|
|
// underlying ClientStateReadWriter, it has to panic.
|
|
func (c *ClientStateResponseWriter) WriteHeader(code int) {
|
|
if !c.hasWritten {
|
|
if err := c.putClientState(); err != nil {
|
|
panic(err)
|
|
}
|
|
}
|
|
c.ResponseWriter.WriteHeader(code)
|
|
}
|
|
|
|
// Header retrieves the underlying headers
|
|
func (c ClientStateResponseWriter) Header() http.Header {
|
|
return c.ResponseWriter.Header()
|
|
}
|
|
|
|
// Write ensures that the client state is written before any writes
|
|
// to the body occur (before header flush to http client)
|
|
func (c *ClientStateResponseWriter) Write(b []byte) (int, error) {
|
|
if !c.hasWritten {
|
|
if err := c.putClientState(); err != nil {
|
|
return 0, err
|
|
}
|
|
}
|
|
return c.ResponseWriter.Write(b)
|
|
}
|
|
|
|
// UnderlyingResponseWriter for this instance
|
|
func (c *ClientStateResponseWriter) UnderlyingResponseWriter() http.ResponseWriter {
|
|
return c.ResponseWriter
|
|
}
|
|
|
|
func (c *ClientStateResponseWriter) putClientState() error {
|
|
if c.hasWritten {
|
|
panic("should not call putClientState twice")
|
|
}
|
|
c.hasWritten = true
|
|
|
|
if len(c.cookieStateEvents) == 0 && len(c.sessionStateEvents) == 0 {
|
|
return nil
|
|
}
|
|
|
|
if c.sessionState != nil && len(c.sessionStateEvents) > 0 {
|
|
sessionStateIntf := c.ctx.Value(CTXKeySessionState)
|
|
|
|
var session ClientState
|
|
if sessionStateIntf != nil {
|
|
session = sessionStateIntf.(ClientState)
|
|
}
|
|
|
|
err := c.sessionState.WriteState(c, session, c.sessionStateEvents)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
if c.cookieState != nil && len(c.cookieStateEvents) > 0 {
|
|
cookieStateIntf := c.ctx.Value(CTXKeyCookieState)
|
|
|
|
var cookie ClientState
|
|
if cookieStateIntf != nil {
|
|
cookie = cookieStateIntf.(ClientState)
|
|
}
|
|
|
|
err := c.cookieState.WriteState(c, cookie, c.cookieStateEvents)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// PutSession puts a value into the session
|
|
func PutSession(w http.ResponseWriter, key, val string) {
|
|
putState(w, CTXKeySessionState, key, val)
|
|
}
|
|
|
|
// DelSession deletes a key-value from the session.
|
|
func DelSession(w http.ResponseWriter, key string) {
|
|
delState(w, CTXKeySessionState, key)
|
|
}
|
|
|
|
// GetSession fetches a value from the session
|
|
func GetSession(r *http.Request, key string) (string, bool) {
|
|
return getState(r, CTXKeySessionState, key)
|
|
}
|
|
|
|
// PutCookie puts a value into the session
|
|
func PutCookie(w http.ResponseWriter, key, val string) {
|
|
putState(w, CTXKeyCookieState, key, val)
|
|
}
|
|
|
|
// DelCookie deletes a key-value from the session.
|
|
func DelCookie(w http.ResponseWriter, key string) {
|
|
delState(w, CTXKeyCookieState, key)
|
|
}
|
|
|
|
// GetCookie fetches a value from the session
|
|
func GetCookie(r *http.Request, key string) (string, bool) {
|
|
return getState(r, CTXKeyCookieState, key)
|
|
}
|
|
|
|
func putState(w http.ResponseWriter, CTXKey contextKey, key, val string) {
|
|
setState(w, CTXKey, ClientStateEventPut, key, val)
|
|
}
|
|
|
|
func delState(w http.ResponseWriter, CTXKey contextKey, key string) {
|
|
setState(w, CTXKey, ClientStateEventDel, key, "")
|
|
}
|
|
|
|
func setState(w http.ResponseWriter, CTXKey contextKey, op ClientStateEventKind, key, val string) {
|
|
csrw := MustClientStateResponseWriter(w)
|
|
ev := ClientStateEvent{
|
|
Kind: op,
|
|
Key: key,
|
|
}
|
|
|
|
if op == ClientStateEventPut {
|
|
ev.Value = val
|
|
}
|
|
|
|
switch CTXKey {
|
|
case CTXKeySessionState:
|
|
csrw.sessionStateEvents = append(csrw.sessionStateEvents, ev)
|
|
case CTXKeyCookieState:
|
|
csrw.cookieStateEvents = append(csrw.cookieStateEvents, ev)
|
|
}
|
|
}
|
|
|
|
func getState(r *http.Request, CTXKey contextKey, key string) (string, bool) {
|
|
val := r.Context().Value(CTXKey)
|
|
if val == nil {
|
|
return "", false
|
|
}
|
|
|
|
state := val.(ClientState)
|
|
return state.Get(key)
|
|
}
|
|
|
|
// FlashSuccess returns FlashSuccessKey from the session and removes it.
|
|
func FlashSuccess(w http.ResponseWriter, r *http.Request) string {
|
|
str, ok := GetSession(r, FlashSuccessKey)
|
|
if !ok {
|
|
return ""
|
|
}
|
|
|
|
DelSession(w, FlashSuccessKey)
|
|
return str
|
|
}
|
|
|
|
// FlashError returns FlashError from the session and removes it.
|
|
func FlashError(w http.ResponseWriter, r *http.Request) string {
|
|
str, ok := GetSession(r, FlashErrorKey)
|
|
if !ok {
|
|
return ""
|
|
}
|
|
|
|
DelSession(w, FlashErrorKey)
|
|
return str
|
|
}
|