mirror of
https://github.com/alexedwards/scs.git
synced 2025-07-11 00:50:14 +02:00
630 lines
17 KiB
Go
630 lines
17 KiB
Go
package scs
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"sort"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
)
|
|
|
|
// Status describes what has happened to the session data over the course of
// the current request cycle.
type Status int

const (
	// Unmodified is the state of session data that has not been changed
	// during the current request cycle.
	Unmodified Status = iota

	// Modified is the state of session data that has been changed during the
	// current request cycle.
	Modified

	// Destroyed is the state of session data that has been destroyed during
	// the current request cycle.
	Destroyed
)
|
|
|
|
type sessionData struct {
|
|
deadline time.Time
|
|
status Status
|
|
token string
|
|
values map[string]interface{}
|
|
mu sync.Mutex
|
|
}
|
|
|
|
func newSessionData(lifetime time.Duration) *sessionData {
|
|
return &sessionData{
|
|
deadline: time.Now().Add(lifetime).UTC(),
|
|
status: Unmodified,
|
|
values: make(map[string]interface{}),
|
|
}
|
|
}
|
|
|
|
// Load retrieves the session data for the given token from the session store,
|
|
// and returns a new context.Context containing the session data. If no matching
|
|
// token is found then this will create a new session.
|
|
//
|
|
// Most applications will use the LoadAndSave() middleware and will not need to
|
|
// use this method.
|
|
func (s *SessionManager) Load(ctx context.Context, token string) (context.Context, error) {
|
|
if _, ok := ctx.Value(s.contextKey).(*sessionData); ok {
|
|
return ctx, nil
|
|
}
|
|
|
|
if token == "" {
|
|
return s.addSessionDataToContext(ctx, newSessionData(s.Lifetime)), nil
|
|
}
|
|
|
|
b, found, err := s.doStoreFind(ctx, token)
|
|
if err != nil {
|
|
return nil, err
|
|
} else if !found {
|
|
return s.addSessionDataToContext(ctx, newSessionData(s.Lifetime)), nil
|
|
}
|
|
|
|
sd := &sessionData{
|
|
status: Unmodified,
|
|
token: token,
|
|
}
|
|
if sd.deadline, sd.values, err = s.Codec.Decode(b); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Mark the session data as modified if an idle timeout is being used. This
|
|
// will force the session data to be re-committed to the session store with
|
|
// a new expiry time.
|
|
if s.IdleTimeout > 0 {
|
|
sd.status = Modified
|
|
}
|
|
|
|
return s.addSessionDataToContext(ctx, sd), nil
|
|
}
|
|
|
|
// Commit saves the session data to the session store and returns the session
|
|
// token and expiry time.
|
|
//
|
|
// Most applications will use the LoadAndSave() middleware and will not need to
|
|
// use this method.
|
|
func (s *SessionManager) Commit(ctx context.Context) (string, time.Time, error) {
|
|
sd := s.getSessionDataFromContext(ctx)
|
|
|
|
sd.mu.Lock()
|
|
defer sd.mu.Unlock()
|
|
|
|
if sd.token == "" {
|
|
var err error
|
|
if sd.token, err = generateToken(); err != nil {
|
|
return "", time.Time{}, err
|
|
}
|
|
}
|
|
|
|
b, err := s.Codec.Encode(sd.deadline, sd.values)
|
|
if err != nil {
|
|
return "", time.Time{}, err
|
|
}
|
|
|
|
expiry := sd.deadline
|
|
if s.IdleTimeout > 0 {
|
|
ie := time.Now().Add(s.IdleTimeout).UTC()
|
|
if ie.Before(expiry) {
|
|
expiry = ie
|
|
}
|
|
}
|
|
|
|
if err := s.doStoreCommit(ctx, sd.token, b, expiry); err != nil {
|
|
return "", time.Time{}, err
|
|
}
|
|
|
|
return sd.token, expiry, nil
|
|
}
|
|
|
|
// Destroy deletes the session data from the session store and sets the session
|
|
// status to Destroyed. Any further operations in the same request cycle will
|
|
// result in a new session being created.
|
|
func (s *SessionManager) Destroy(ctx context.Context) error {
|
|
sd := s.getSessionDataFromContext(ctx)
|
|
|
|
sd.mu.Lock()
|
|
defer sd.mu.Unlock()
|
|
|
|
err := s.doStoreDelete(ctx, sd.token)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
sd.status = Destroyed
|
|
|
|
// Reset everything else to defaults.
|
|
sd.token = ""
|
|
sd.deadline = time.Now().Add(s.Lifetime).UTC()
|
|
for key := range sd.values {
|
|
delete(sd.values, key)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Put adds a key and corresponding value to the session data. Any existing
|
|
// value for the key will be replaced. The session data status will be set to
|
|
// Modified.
|
|
func (s *SessionManager) Put(ctx context.Context, key string, val interface{}) {
|
|
sd := s.getSessionDataFromContext(ctx)
|
|
|
|
sd.mu.Lock()
|
|
sd.values[key] = val
|
|
sd.status = Modified
|
|
sd.mu.Unlock()
|
|
}
|
|
|
|
// Get returns the value for a given key from the session data. The return
|
|
// value has the type interface{} so will usually need to be type asserted
|
|
// before you can use it. For example:
|
|
//
|
|
// foo, ok := session.Get(r, "foo").(string)
|
|
// if !ok {
|
|
// return errors.New("type assertion to string failed")
|
|
// }
|
|
//
|
|
// Also see the GetString(), GetInt(), GetBytes() and other helper methods which
|
|
// wrap the type conversion for common types.
|
|
func (s *SessionManager) Get(ctx context.Context, key string) interface{} {
|
|
sd := s.getSessionDataFromContext(ctx)
|
|
|
|
sd.mu.Lock()
|
|
defer sd.mu.Unlock()
|
|
|
|
return sd.values[key]
|
|
}
|
|
|
|
// Pop acts like a one-time Get. It returns the value for a given key from the
|
|
// session data and deletes the key and value from the session data. The
|
|
// session data status will be set to Modified. The return value has the type
|
|
// interface{} so will usually need to be type asserted before you can use it.
|
|
func (s *SessionManager) Pop(ctx context.Context, key string) interface{} {
|
|
sd := s.getSessionDataFromContext(ctx)
|
|
|
|
sd.mu.Lock()
|
|
defer sd.mu.Unlock()
|
|
|
|
val, exists := sd.values[key]
|
|
if !exists {
|
|
return nil
|
|
}
|
|
delete(sd.values, key)
|
|
sd.status = Modified
|
|
|
|
return val
|
|
}
|
|
|
|
// Remove deletes the given key and corresponding value from the session data.
|
|
// The session data status will be set to Modified. If the key is not present
|
|
// this operation is a no-op.
|
|
func (s *SessionManager) Remove(ctx context.Context, key string) {
|
|
sd := s.getSessionDataFromContext(ctx)
|
|
|
|
sd.mu.Lock()
|
|
defer sd.mu.Unlock()
|
|
|
|
_, exists := sd.values[key]
|
|
if !exists {
|
|
return
|
|
}
|
|
|
|
delete(sd.values, key)
|
|
sd.status = Modified
|
|
}
|
|
|
|
// Clear removes all data for the current session. The session token and
|
|
// lifetime are unaffected. If there is no data in the current session this is
|
|
// a no-op.
|
|
func (s *SessionManager) Clear(ctx context.Context) error {
|
|
sd := s.getSessionDataFromContext(ctx)
|
|
|
|
sd.mu.Lock()
|
|
defer sd.mu.Unlock()
|
|
|
|
if len(sd.values) == 0 {
|
|
return nil
|
|
}
|
|
|
|
for key := range sd.values {
|
|
delete(sd.values, key)
|
|
}
|
|
sd.status = Modified
|
|
return nil
|
|
}
|
|
|
|
// Exists returns true if the given key is present in the session data.
|
|
func (s *SessionManager) Exists(ctx context.Context, key string) bool {
|
|
sd := s.getSessionDataFromContext(ctx)
|
|
|
|
sd.mu.Lock()
|
|
_, exists := sd.values[key]
|
|
sd.mu.Unlock()
|
|
|
|
return exists
|
|
}
|
|
|
|
// Keys returns a slice of all key names present in the session data, sorted
|
|
// alphabetically. If the data contains no data then an empty slice will be
|
|
// returned.
|
|
func (s *SessionManager) Keys(ctx context.Context) []string {
|
|
sd := s.getSessionDataFromContext(ctx)
|
|
|
|
sd.mu.Lock()
|
|
keys := make([]string, len(sd.values))
|
|
i := 0
|
|
for key := range sd.values {
|
|
keys[i] = key
|
|
i++
|
|
}
|
|
sd.mu.Unlock()
|
|
|
|
sort.Strings(keys)
|
|
return keys
|
|
}
|
|
|
|
// RenewToken updates the session data to have a new session token while
|
|
// retaining the current session data. The session lifetime is also reset and
|
|
// the session data status will be set to Modified.
|
|
//
|
|
// The old session token and accompanying data are deleted from the session store.
|
|
//
|
|
// To mitigate the risk of session fixation attacks, it's important that you call
|
|
// RenewToken before making any changes to privilege levels (e.g. login and
|
|
// logout operations). See https://github.com/OWASP/CheatSheetSeries/blob/master/cheatsheets/Session_Management_Cheat_Sheet.md#renew-the-session-id-after-any-privilege-level-change
|
|
// for additional information.
|
|
func (s *SessionManager) RenewToken(ctx context.Context) error {
|
|
sd := s.getSessionDataFromContext(ctx)
|
|
|
|
sd.mu.Lock()
|
|
defer sd.mu.Unlock()
|
|
|
|
err := s.doStoreDelete(ctx, sd.token)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
newToken, err := generateToken()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
sd.token = newToken
|
|
sd.deadline = time.Now().Add(s.Lifetime).UTC()
|
|
sd.status = Modified
|
|
|
|
return nil
|
|
}
|
|
|
|
// Status returns the current status of the session data.
|
|
func (s *SessionManager) Status(ctx context.Context) Status {
|
|
sd := s.getSessionDataFromContext(ctx)
|
|
|
|
sd.mu.Lock()
|
|
defer sd.mu.Unlock()
|
|
|
|
return sd.status
|
|
}
|
|
|
|
// GetString returns the string value for a given key from the session data.
|
|
// The zero value for a string ("") is returned if the key does not exist or the
|
|
// value could not be type asserted to a string.
|
|
func (s *SessionManager) GetString(ctx context.Context, key string) string {
|
|
val := s.Get(ctx, key)
|
|
str, ok := val.(string)
|
|
if !ok {
|
|
return ""
|
|
}
|
|
return str
|
|
}
|
|
|
|
// GetBool returns the bool value for a given key from the session data. The
|
|
// zero value for a bool (false) is returned if the key does not exist or the
|
|
// value could not be type asserted to a bool.
|
|
func (s *SessionManager) GetBool(ctx context.Context, key string) bool {
|
|
val := s.Get(ctx, key)
|
|
b, ok := val.(bool)
|
|
if !ok {
|
|
return false
|
|
}
|
|
return b
|
|
}
|
|
|
|
// GetInt returns the int value for a given key from the session data. The
|
|
// zero value for an int (0) is returned if the key does not exist or the
|
|
// value could not be type asserted to an int.
|
|
func (s *SessionManager) GetInt(ctx context.Context, key string) int {
|
|
val := s.Get(ctx, key)
|
|
i, ok := val.(int)
|
|
if !ok {
|
|
return 0
|
|
}
|
|
return i
|
|
}
|
|
|
|
// GetInt64 returns the int64 value for a given key from the session data. The
|
|
// zero value for an int64 (0) is returned if the key does not exist or the
|
|
// value could not be type asserted to an int64.
|
|
func (s *SessionManager) GetInt64(ctx context.Context, key string) int64 {
|
|
val := s.Get(ctx, key)
|
|
i, ok := val.(int64)
|
|
if !ok {
|
|
return 0
|
|
}
|
|
return i
|
|
}
|
|
|
|
// GetInt32 returns the int value for a given key from the session data. The
|
|
// zero value for an int32 (0) is returned if the key does not exist or the
|
|
// value could not be type asserted to an int32.
|
|
func (s *SessionManager) GetInt32(ctx context.Context, key string) int32 {
|
|
val := s.Get(ctx, key)
|
|
i, ok := val.(int32)
|
|
if !ok {
|
|
return 0
|
|
}
|
|
return i
|
|
}
|
|
|
|
// GetFloat returns the float64 value for a given key from the session data. The
|
|
// zero value for an float64 (0) is returned if the key does not exist or the
|
|
// value could not be type asserted to a float64.
|
|
func (s *SessionManager) GetFloat(ctx context.Context, key string) float64 {
|
|
val := s.Get(ctx, key)
|
|
f, ok := val.(float64)
|
|
if !ok {
|
|
return 0
|
|
}
|
|
return f
|
|
}
|
|
|
|
// GetBytes returns the byte slice ([]byte) value for a given key from the session
|
|
// data. The zero value for a slice (nil) is returned if the key does not exist
|
|
// or could not be type asserted to []byte.
|
|
func (s *SessionManager) GetBytes(ctx context.Context, key string) []byte {
|
|
val := s.Get(ctx, key)
|
|
b, ok := val.([]byte)
|
|
if !ok {
|
|
return nil
|
|
}
|
|
return b
|
|
}
|
|
|
|
// GetTime returns the time.Time value for a given key from the session data. The
|
|
// zero value for a time.Time object is returned if the key does not exist or the
|
|
// value could not be type asserted to a time.Time. This can be tested with the
|
|
// time.IsZero() method.
|
|
func (s *SessionManager) GetTime(ctx context.Context, key string) time.Time {
|
|
val := s.Get(ctx, key)
|
|
t, ok := val.(time.Time)
|
|
if !ok {
|
|
return time.Time{}
|
|
}
|
|
return t
|
|
}
|
|
|
|
// PopString returns the string value for a given key and then deletes it from the
|
|
// session data. The session data status will be set to Modified. The zero
|
|
// value for a string ("") is returned if the key does not exist or the value
|
|
// could not be type asserted to a string.
|
|
func (s *SessionManager) PopString(ctx context.Context, key string) string {
|
|
val := s.Pop(ctx, key)
|
|
str, ok := val.(string)
|
|
if !ok {
|
|
return ""
|
|
}
|
|
return str
|
|
}
|
|
|
|
// PopBool returns the bool value for a given key and then deletes it from the
|
|
// session data. The session data status will be set to Modified. The zero
|
|
// value for a bool (false) is returned if the key does not exist or the value
|
|
// could not be type asserted to a bool.
|
|
func (s *SessionManager) PopBool(ctx context.Context, key string) bool {
|
|
val := s.Pop(ctx, key)
|
|
b, ok := val.(bool)
|
|
if !ok {
|
|
return false
|
|
}
|
|
return b
|
|
}
|
|
|
|
// PopInt returns the int value for a given key and then deletes it from the
|
|
// session data. The session data status will be set to Modified. The zero
|
|
// value for an int (0) is returned if the key does not exist or the value could
|
|
// not be type asserted to an int.
|
|
func (s *SessionManager) PopInt(ctx context.Context, key string) int {
|
|
val := s.Pop(ctx, key)
|
|
i, ok := val.(int)
|
|
if !ok {
|
|
return 0
|
|
}
|
|
return i
|
|
}
|
|
|
|
// PopFloat returns the float64 value for a given key and then deletes it from the
|
|
// session data. The session data status will be set to Modified. The zero
|
|
// value for an float64 (0) is returned if the key does not exist or the value
|
|
// could not be type asserted to a float64.
|
|
func (s *SessionManager) PopFloat(ctx context.Context, key string) float64 {
|
|
val := s.Pop(ctx, key)
|
|
f, ok := val.(float64)
|
|
if !ok {
|
|
return 0
|
|
}
|
|
return f
|
|
}
|
|
|
|
// PopBytes returns the byte slice ([]byte) value for a given key and then
|
|
// deletes it from the from the session data. The session data status will be
|
|
// set to Modified. The zero value for a slice (nil) is returned if the key does
|
|
// not exist or could not be type asserted to []byte.
|
|
func (s *SessionManager) PopBytes(ctx context.Context, key string) []byte {
|
|
val := s.Pop(ctx, key)
|
|
b, ok := val.([]byte)
|
|
if !ok {
|
|
return nil
|
|
}
|
|
return b
|
|
}
|
|
|
|
// PopTime returns the time.Time value for a given key and then deletes it from
|
|
// the session data. The session data status will be set to Modified. The zero
|
|
// value for a time.Time object is returned if the key does not exist or the
|
|
// value could not be type asserted to a time.Time.
|
|
func (s *SessionManager) PopTime(ctx context.Context, key string) time.Time {
|
|
val := s.Pop(ctx, key)
|
|
t, ok := val.(time.Time)
|
|
if !ok {
|
|
return time.Time{}
|
|
}
|
|
return t
|
|
}
|
|
|
|
// RememberMe controls whether the session cookie is persistent (i.e whether it
|
|
// is retained after a user closes their browser). RememberMe only has an effect
|
|
// if you have set SessionManager.Cookie.Persist = false (the default is true) and
|
|
// you are using the standard LoadAndSave() middleware.
|
|
func (s *SessionManager) RememberMe(ctx context.Context, val bool) {
|
|
s.Put(ctx, "__rememberMe", val)
|
|
}
|
|
|
|
// Iterate retrieves all active (i.e. not expired) sessions from the store and
|
|
// executes the provided function fn for each session. If the session store
|
|
// being used does not support iteration then Iterate will panic.
|
|
func (s *SessionManager) Iterate(ctx context.Context, fn func(context.Context) error) error {
|
|
allSessions, err := s.doStoreAll(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for token, b := range allSessions {
|
|
sd := &sessionData{
|
|
status: Unmodified,
|
|
token: token,
|
|
}
|
|
|
|
sd.deadline, sd.values, err = s.Codec.Decode(b)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
ctx = s.addSessionDataToContext(ctx, sd)
|
|
|
|
err = fn(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Deadline returns the 'absolute' expiry time for the session. Please note
|
|
// that if you are using an idle timeout, it is possible that a session will
|
|
// expire due to non-use before the returned deadline.
|
|
func (s *SessionManager) Deadline(ctx context.Context) time.Time {
|
|
sd := s.getSessionDataFromContext(ctx)
|
|
|
|
sd.mu.Lock()
|
|
defer sd.mu.Unlock()
|
|
|
|
return sd.deadline
|
|
}
|
|
|
|
// Token returns the session token. Please note that this will return the
|
|
// empty string "" if it is called before the session has been committed to
|
|
// the store.
|
|
func (s *SessionManager) Token(ctx context.Context) string {
|
|
sd := s.getSessionDataFromContext(ctx)
|
|
|
|
sd.mu.Lock()
|
|
defer sd.mu.Unlock()
|
|
|
|
return sd.token
|
|
}
|
|
|
|
func (s *SessionManager) addSessionDataToContext(ctx context.Context, sd *sessionData) context.Context {
|
|
return context.WithValue(ctx, s.contextKey, sd)
|
|
}
|
|
|
|
func (s *SessionManager) getSessionDataFromContext(ctx context.Context) *sessionData {
|
|
c, ok := ctx.Value(s.contextKey).(*sessionData)
|
|
if !ok {
|
|
panic("scs: no session data in context")
|
|
}
|
|
return c
|
|
}
|
|
|
|
// generateToken returns a new session token: 32 bytes read from crypto/rand,
// encoded with unpadded URL-safe base64. An error is returned if the random
// bytes cannot be read.
func generateToken() (string, error) {
	var raw [32]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", err
	}
	return base64.RawURLEncoding.EncodeToString(raw[:]), nil
}
|
|
|
|
// contextKey is the type of the key under which a SessionManager stores its
// session data in a context.Context.
type contextKey string

var (
	// contextKeyID is the number used for the most recently generated context
	// key. It is only ever changed with atomic operations.
	contextKeyID uint64

	// contextKeyIDMutex is no longer used by generateContextKey, which relies
	// on the atomic counter alone. It stays declared so that any reference to
	// it elsewhere in the package keeps compiling.
	contextKeyIDMutex = &sync.Mutex{}
)

// generateContextKey returns a context key of the form "session.N", where N
// is a process-wide counter that starts at 1. Every call gets a different key
// and the function is safe for concurrent use.
func generateContextKey() contextKey {
	// Use the value returned by the atomic add instead of re-reading the
	// variable; the add alone gives each caller a distinct number.
	id := atomic.AddUint64(&contextKeyID, 1)
	return contextKey(fmt.Sprintf("session.%d", id))
}
|
|
|
|
func (s *SessionManager) doStoreDelete(ctx context.Context, token string) (err error) {
|
|
c, ok := s.Store.(interface {
|
|
DeleteCtx(context.Context, string) error
|
|
})
|
|
if ok {
|
|
return c.DeleteCtx(ctx, token)
|
|
}
|
|
return s.Store.Delete(token)
|
|
}
|
|
|
|
func (s *SessionManager) doStoreFind(ctx context.Context, token string) (b []byte, found bool, err error) {
|
|
c, ok := s.Store.(interface {
|
|
FindCtx(context.Context, string) ([]byte, bool, error)
|
|
})
|
|
if ok {
|
|
return c.FindCtx(ctx, token)
|
|
}
|
|
return s.Store.Find(token)
|
|
}
|
|
|
|
func (s *SessionManager) doStoreCommit(ctx context.Context, token string, b []byte, expiry time.Time) (err error) {
|
|
c, ok := s.Store.(interface {
|
|
CommitCtx(context.Context, string, []byte, time.Time) error
|
|
})
|
|
if ok {
|
|
return c.CommitCtx(ctx, token, b, expiry)
|
|
}
|
|
return s.Store.Commit(token, b, expiry)
|
|
}
|
|
|
|
func (s *SessionManager) doStoreAll(ctx context.Context) (map[string][]byte, error) {
|
|
cs, ok := s.Store.(CtxStore)
|
|
if ok {
|
|
return cs.AllCtx(ctx)
|
|
}
|
|
|
|
is, ok := s.Store.(IterableStore)
|
|
if ok {
|
|
return is.All()
|
|
}
|
|
|
|
panic(fmt.Sprintf("type %T does not support iteration", s.Store))
|
|
}
|