mirror of
https://github.com/alexedwards/scs.git
synced 2025-07-13 01:00:17 +02:00
492 lines
13 KiB
Go
492 lines
13 KiB
Go
![]() |
package scs
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"context"
|
||
|
"crypto/rand"
|
||
|
"encoding/base64"
|
||
|
"encoding/gob"
|
||
|
"fmt"
|
||
|
"sort"
|
||
|
"sync"
|
||
|
"time"
|
||
|
)
|
||
|
|
||
|
// Status represents the state of the session data during a request cycle.
|
||
|
type Status int
|
||
|
|
||
|
const (
|
||
|
// Unmodified indicates that the session data hasn't been changed in the
|
||
|
// current request cycle.
|
||
|
Unmodified Status = iota
|
||
|
|
||
|
// Modified indicates that the session data has been changed in the current
|
||
|
// request cycle.
|
||
|
Modified
|
||
|
|
||
|
// Destroyed indicates that the session data has been destroyed in the
|
||
|
// current request cycle.
|
||
|
Destroyed
|
||
|
)
|
||
|
|
||
|
type sessionData struct {
|
||
|
Deadline time.Time // Exported for gob encoding.
|
||
|
status Status
|
||
|
token string
|
||
|
Values map[string]interface{} // Exported for gob encoding.
|
||
|
mu sync.Mutex
|
||
|
}
|
||
|
|
||
|
func newSessionData(lifetime time.Duration) *sessionData {
|
||
|
return &sessionData{
|
||
|
Deadline: time.Now().Add(lifetime).UTC(),
|
||
|
status: Unmodified,
|
||
|
Values: make(map[string]interface{}),
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Load retrieves the session data for the given token from the session store,
|
||
|
// and returns a new context.Context containing the session data. If no matching
|
||
|
// token is found then this will create a new session.
|
||
|
//
|
||
|
// Most applications will use the LoadAndSave() middleware and will not need to
|
||
|
// use this method.
|
||
|
func (s *Session) Load(ctx context.Context, token string) (context.Context, error) {
|
||
|
if _, ok := ctx.Value(s.contextKey).(*sessionData); ok {
|
||
|
return ctx, nil
|
||
|
}
|
||
|
|
||
|
if token == "" {
|
||
|
return s.addSessionDataToContext(ctx, newSessionData(s.Lifetime)), nil
|
||
|
}
|
||
|
|
||
|
b, found, err := s.Store.Find(token)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
} else if !found {
|
||
|
return s.addSessionDataToContext(ctx, newSessionData(s.Lifetime)), nil
|
||
|
}
|
||
|
|
||
|
sd := &sessionData{
|
||
|
status: Unmodified,
|
||
|
token: token,
|
||
|
}
|
||
|
err = sd.decode(b)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
// Mark the session data as modified if an idle timeout is being used. This
|
||
|
// will force the session data to be re-committed to the session store with
|
||
|
// a new expiry time.
|
||
|
if s.IdleTimeout > 0 {
|
||
|
sd.status = Modified
|
||
|
}
|
||
|
|
||
|
return s.addSessionDataToContext(ctx, sd), nil
|
||
|
}
|
||
|
|
||
|
// Commit saves the session data to the session store and returns the session
|
||
|
// token and expiry time.
|
||
|
//
|
||
|
// Most applications will use the LoadAndSave() middleware and will not need to
|
||
|
// use this method.
|
||
|
func (s *Session) Commit(ctx context.Context) (string, time.Time, error) {
|
||
|
sd := s.getSessionDataFromContext(ctx)
|
||
|
|
||
|
sd.mu.Lock()
|
||
|
defer sd.mu.Unlock()
|
||
|
|
||
|
if sd.token == "" {
|
||
|
var err error
|
||
|
sd.token, err = generateToken()
|
||
|
if err != nil {
|
||
|
return "", time.Time{}, err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
b, err := sd.encode()
|
||
|
if err != nil {
|
||
|
return "", time.Time{}, err
|
||
|
}
|
||
|
|
||
|
expiry := sd.Deadline
|
||
|
if s.IdleTimeout > 0 {
|
||
|
ie := time.Now().Add(s.IdleTimeout)
|
||
|
if ie.Before(expiry) {
|
||
|
expiry = ie
|
||
|
}
|
||
|
}
|
||
|
|
||
|
err = s.Store.Commit(sd.token, b, expiry)
|
||
|
if err != nil {
|
||
|
return "", time.Time{}, err
|
||
|
}
|
||
|
|
||
|
return sd.token, expiry, nil
|
||
|
}
|
||
|
|
||
|
// Destroy deletes the session data from the session store and sets the session
|
||
|
// status to Destroyed. Any futher operations in the same request cycle will
|
||
|
// result in a new session being created.
|
||
|
func (s *Session) Destroy(ctx context.Context) error {
|
||
|
sd := s.getSessionDataFromContext(ctx)
|
||
|
|
||
|
sd.mu.Lock()
|
||
|
defer sd.mu.Unlock()
|
||
|
|
||
|
err := s.Store.Delete(sd.token)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
sd.status = Destroyed
|
||
|
|
||
|
// Reset everything else to defaults.
|
||
|
sd.token = ""
|
||
|
sd.Deadline = time.Now().Add(s.Lifetime).UTC()
|
||
|
for key := range sd.Values {
|
||
|
delete(sd.Values, key)
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// Put adds a key and corresponding value to the session data. Any existing
|
||
|
// value for the key will be replaced. The session data status will be set to
|
||
|
// Modified.
|
||
|
func (s *Session) Put(ctx context.Context, key string, val interface{}) {
|
||
|
sd := s.getSessionDataFromContext(ctx)
|
||
|
|
||
|
sd.mu.Lock()
|
||
|
sd.Values[key] = val
|
||
|
sd.status = Modified
|
||
|
sd.mu.Unlock()
|
||
|
}
|
||
|
|
||
|
// Get returns the value for a given key from the session data. The return
|
||
|
// value has the type interface{} so will usually need to be type asserted
|
||
|
// before you can use it. For example:
|
||
|
//
|
||
|
// foo, ok := session.Get(r, "foo").(string)
|
||
|
// if !ok {
|
||
|
// return errors.New("type assertion to string failed")
|
||
|
// }
|
||
|
//
|
||
|
// Also see the GetString(), GetInt(), GetBytes() and other helper methods which
|
||
|
// wrap the type conversion for common types.
|
||
|
func (s *Session) Get(ctx context.Context, key string) interface{} {
|
||
|
sd := s.getSessionDataFromContext(ctx)
|
||
|
|
||
|
sd.mu.Lock()
|
||
|
defer sd.mu.Unlock()
|
||
|
|
||
|
return sd.Values[key]
|
||
|
}
|
||
|
|
||
|
// Pop acts like a one-time Get. It returns the value for a given key from the
|
||
|
// session data and deletes the key and value from the session data. The
|
||
|
// session data status will be set to Modified. The return value has the type
|
||
|
// interface{} so will usually need to be type asserted before you can use it.
|
||
|
func (s *Session) Pop(ctx context.Context, key string) interface{} {
|
||
|
sd := s.getSessionDataFromContext(ctx)
|
||
|
|
||
|
sd.mu.Lock()
|
||
|
defer sd.mu.Unlock()
|
||
|
|
||
|
val, exists := sd.Values[key]
|
||
|
if !exists {
|
||
|
return nil
|
||
|
}
|
||
|
delete(sd.Values, key)
|
||
|
sd.status = Modified
|
||
|
|
||
|
return val
|
||
|
}
|
||
|
|
||
|
// Remove deletes the given key and corresponding value from the session data.
|
||
|
// The session data status will be set to Modified. If the key is not present
|
||
|
// this operation is a no-op.
|
||
|
func (s *Session) Remove(ctx context.Context, key string) {
|
||
|
sd := s.getSessionDataFromContext(ctx)
|
||
|
|
||
|
sd.mu.Lock()
|
||
|
defer sd.mu.Unlock()
|
||
|
|
||
|
_, exists := sd.Values[key]
|
||
|
if !exists {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
delete(sd.Values, key)
|
||
|
sd.status = Modified
|
||
|
}
|
||
|
|
||
|
// Exists returns true if the given key is present in the session data.
|
||
|
func (s *Session) Exists(ctx context.Context, key string) bool {
|
||
|
sd := s.getSessionDataFromContext(ctx)
|
||
|
|
||
|
sd.mu.Lock()
|
||
|
_, exists := sd.Values[key]
|
||
|
sd.mu.Unlock()
|
||
|
|
||
|
return exists
|
||
|
}
|
||
|
|
||
|
// Keys returns a slice of all key names present in the session data, sorted
|
||
|
// alphabetically. If the data contains no data then an empty slice will be
|
||
|
// returned.
|
||
|
func (s *Session) Keys(ctx context.Context) []string {
|
||
|
sd := s.getSessionDataFromContext(ctx)
|
||
|
|
||
|
sd.mu.Lock()
|
||
|
keys := make([]string, len(sd.Values))
|
||
|
i := 0
|
||
|
for key := range sd.Values {
|
||
|
keys[i] = key
|
||
|
i++
|
||
|
}
|
||
|
sd.mu.Unlock()
|
||
|
|
||
|
sort.Strings(keys)
|
||
|
return keys
|
||
|
}
|
||
|
|
||
|
// RenewToken updates the session data to have a new session token while
|
||
|
// retaining the current session data. The session lifetime is also reset and
|
||
|
// the session data status will be set to Modified.
|
||
|
//
|
||
|
// The old session token and accompanying data are deleted from the session store.
|
||
|
//
|
||
|
// To mitigate the risk of session fixation attacks, it's important that you call
|
||
|
// RenewToken before making any changes to privilege levels (e.g. login and
|
||
|
// logout operations). See https://github.com/OWASP/CheatSheetSeries/blob/master/cheatsheets/Session_Management_Cheat_Sheet.md#renew-the-session-id-after-any-privilege-level-change
|
||
|
// for additional information.
|
||
|
func (s *Session) RenewToken(ctx context.Context) error {
|
||
|
sd := s.getSessionDataFromContext(ctx)
|
||
|
|
||
|
sd.mu.Lock()
|
||
|
defer sd.mu.Unlock()
|
||
|
|
||
|
err := s.Store.Delete(sd.token)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
newToken, err := generateToken()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
sd.token = newToken
|
||
|
sd.Deadline = time.Now().Add(s.Lifetime).UTC()
|
||
|
sd.status = Modified
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// Status returns the current status of the session data.
|
||
|
func (s *Session) Status(ctx context.Context) Status {
|
||
|
sd := s.getSessionDataFromContext(ctx)
|
||
|
|
||
|
sd.mu.Lock()
|
||
|
defer sd.mu.Unlock()
|
||
|
|
||
|
return sd.status
|
||
|
}
|
||
|
|
||
|
// GetString returns the string value for a given key from the session data.
|
||
|
// The zero value for a string ("") is returned if the key does not exist or the
|
||
|
// value could not be type asserted to a string.
|
||
|
func (s *Session) GetString(ctx context.Context, key string) string {
|
||
|
val := s.Get(ctx, key)
|
||
|
str, ok := val.(string)
|
||
|
if !ok {
|
||
|
return ""
|
||
|
}
|
||
|
return str
|
||
|
}
|
||
|
|
||
|
// GetBool returns the bool value for a given key from the session data. The
|
||
|
// zero value for a bool (false) is returned if the key does not exist or the
|
||
|
// value could not be type asserted to a bool.
|
||
|
func (s *Session) GetBool(ctx context.Context, key string) bool {
|
||
|
val := s.Get(ctx, key)
|
||
|
b, ok := val.(bool)
|
||
|
if !ok {
|
||
|
return false
|
||
|
}
|
||
|
return b
|
||
|
}
|
||
|
|
||
|
// GetInt returns the int value for a given key from the session data. The
|
||
|
// zero value for an int (0) is returned if the key does not exist or the
|
||
|
// value could not be type asserted to an int.
|
||
|
func (s *Session) GetInt(ctx context.Context, key string) int {
|
||
|
val := s.Get(ctx, key)
|
||
|
i, ok := val.(int)
|
||
|
if !ok {
|
||
|
return 0
|
||
|
}
|
||
|
return i
|
||
|
}
|
||
|
|
||
|
// GetFloat returns the float64 value for a given key from the session data. The
|
||
|
// zero value for an float64 (0) is returned if the key does not exist or the
|
||
|
// value could not be type asserted to a float64.
|
||
|
func (s *Session) GetFloat(ctx context.Context, key string) float64 {
|
||
|
val := s.Get(ctx, key)
|
||
|
f, ok := val.(float64)
|
||
|
if !ok {
|
||
|
return 0
|
||
|
}
|
||
|
return f
|
||
|
}
|
||
|
|
||
|
// GetBytes returns the byte slice ([]byte) value for a given key from the session
|
||
|
// data. The zero value for a slice (nil) is returned if the key does not exist
|
||
|
// or could not be type asserted to []byte.
|
||
|
func (s *Session) GetBytes(ctx context.Context, key string) []byte {
|
||
|
val := s.Get(ctx, key)
|
||
|
b, ok := val.([]byte)
|
||
|
if !ok {
|
||
|
return nil
|
||
|
}
|
||
|
return b
|
||
|
}
|
||
|
|
||
|
// GetTime returns the time.Time value for a given key from the session data. The
|
||
|
// zero value for a time.Time object is returned if the key does not exist or the
|
||
|
// value could not be type asserted to a time.Time. This can be tested with the
|
||
|
// time.IsZero() method.
|
||
|
func (s *Session) GetTime(ctx context.Context, key string) time.Time {
|
||
|
val := s.Get(ctx, key)
|
||
|
t, ok := val.(time.Time)
|
||
|
if !ok {
|
||
|
return time.Time{}
|
||
|
}
|
||
|
return t
|
||
|
}
|
||
|
|
||
|
// PopString returns the string value for a given key and then deletes it from the
|
||
|
// session data. The session data status will be set to Modified. The zero
|
||
|
// value for a string ("") is returned if the key does not exist or the value
|
||
|
// could not be type asserted to a string.
|
||
|
func (s *Session) PopString(ctx context.Context, key string) string {
|
||
|
val := s.Pop(ctx, key)
|
||
|
str, ok := val.(string)
|
||
|
if !ok {
|
||
|
return ""
|
||
|
}
|
||
|
return str
|
||
|
}
|
||
|
|
||
|
// PopBool returns the bool value for a given key and then deletes it from the
|
||
|
// session data. The session data status will be set to Modified. The zero
|
||
|
// value for a bool (false) is returned if the key does not exist or the value
|
||
|
// could not be type asserted to a bool.
|
||
|
func (s *Session) PopBool(ctx context.Context, key string) bool {
|
||
|
val := s.Pop(ctx, key)
|
||
|
b, ok := val.(bool)
|
||
|
if !ok {
|
||
|
return false
|
||
|
}
|
||
|
return b
|
||
|
}
|
||
|
|
||
|
// PopInt returns the int value for a given key and then deletes it from the
|
||
|
// session data. The session data status will be set to Modified. The zero
|
||
|
// value for an int (0) is returned if the key does not exist or the value could
|
||
|
// not be type asserted to an int.
|
||
|
func (s *Session) PopInt(ctx context.Context, key string) int {
|
||
|
val := s.Pop(ctx, key)
|
||
|
i, ok := val.(int)
|
||
|
if !ok {
|
||
|
return 0
|
||
|
}
|
||
|
return i
|
||
|
}
|
||
|
|
||
|
// PopFloat returns the float64 value for a given key and then deletes it from the
|
||
|
// session data. The session data status will be set to Modified. The zero
|
||
|
// value for an float64 (0) is returned if the key does not exist or the value
|
||
|
// could not be type asserted to a float64.
|
||
|
func (s *Session) PopFloat(ctx context.Context, key string) float64 {
|
||
|
val := s.Pop(ctx, key)
|
||
|
f, ok := val.(float64)
|
||
|
if !ok {
|
||
|
return 0
|
||
|
}
|
||
|
return f
|
||
|
}
|
||
|
|
||
|
// PopBytes returns the byte slice ([]byte) value for a given key and then
|
||
|
// deletes it from the from the session data. The session data status will be
|
||
|
// set to Modified. The zero value for a slice (nil) is returned if the key does
|
||
|
// not exist or could not be type asserted to []byte.
|
||
|
func (s *Session) PopBytes(ctx context.Context, key string) []byte {
|
||
|
val := s.Pop(ctx, key)
|
||
|
b, ok := val.([]byte)
|
||
|
if !ok {
|
||
|
return nil
|
||
|
}
|
||
|
return b
|
||
|
}
|
||
|
|
||
|
// PopTime returns the time.Time value for a given key and then deletes it from
|
||
|
// the session data. The session data status will be set to Modified. The zero
|
||
|
// value for a time.Time object is returned if the key does not exist or the
|
||
|
// value could not be type asserted to a time.Time.
|
||
|
func (s *Session) PopTime(ctx context.Context, key string) time.Time {
|
||
|
val := s.Pop(ctx, key)
|
||
|
t, ok := val.(time.Time)
|
||
|
if !ok {
|
||
|
return time.Time{}
|
||
|
}
|
||
|
return t
|
||
|
}
|
||
|
|
||
|
func (s *Session) addSessionDataToContext(ctx context.Context, sd *sessionData) context.Context {
|
||
|
return context.WithValue(ctx, s.contextKey, sd)
|
||
|
}
|
||
|
|
||
|
func (s *Session) getSessionDataFromContext(ctx context.Context) *sessionData {
|
||
|
c, ok := ctx.Value(s.contextKey).(*sessionData)
|
||
|
if !ok {
|
||
|
panic("scs: no session data in context")
|
||
|
}
|
||
|
return c
|
||
|
}
|
||
|
|
||
|
func (sd *sessionData) encode() ([]byte, error) {
|
||
|
var b bytes.Buffer
|
||
|
err := gob.NewEncoder(&b).Encode(sd)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
return b.Bytes(), nil
|
||
|
}
|
||
|
|
||
|
func (sd *sessionData) decode(b []byte) error {
|
||
|
r := bytes.NewReader(b)
|
||
|
return gob.NewDecoder(r).Decode(sd)
|
||
|
}
|
||
|
|
||
|
func generateToken() (string, error) {
|
||
|
b := make([]byte, 32)
|
||
|
_, err := rand.Read(b)
|
||
|
if err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
return base64.RawURLEncoding.EncodeToString(b), nil
|
||
|
}
|
||
|
|
||
|
type contextKey string
|
||
|
|
||
|
var contextKeyID int
|
||
|
|
||
|
func generateContextKey() contextKey {
|
||
|
contextKeyID = contextKeyID + 1
|
||
|
return contextKey(fmt.Sprintf("session.%d", contextKeyID))
|
||
|
}
|