1
0
mirror of https://github.com/alexedwards/scs.git synced 2025-07-13 01:00:17 +02:00

Merge v2 changes

This commit is contained in:
Alex Edwards
2019-04-28 07:30:35 +02:00
parent cfcbf41460
commit 47d01d5475
42 changed files with 1997 additions and 4203 deletions

350
README.md
View File

@ -1,25 +1,28 @@
# SCS: A HTTP Session Manager
[![godoc](https://godoc.org/github.com/alexedwards/scs?status.png)](https://godoc.org/github.com/alexedwards/scs) [![go report card](https://goreportcard.com/badge/github.com/alexedwards/scs)](https://goreportcard.com/report/github.com/alexedwards/scs)
# SCS: HTTP Session Management for Go
SCS is a fast and lightweight HTTP session manager for Go. It features:
* [Installation](#installation)
* [The Basics](#the-basics)
* [Configuring Session Behavior](#configuring-session-behavior)
* [Working with Session Data](#working-with-session-data)
* [Loading and Saving Sessions](#loading-and-saving-sessions)
* [Configuring the Session Store](#configuring-the-session-store)
* [Using with PostgreSQL](#using-with-postgresql)
* [Using with MySQL](#using-with-mysql)
* [Using Custom Session Stores](#using-custom-session-stores)
* [Preventing Session Fixation](#preventing-session-fixation)
* [Multiple Sessions per Request](#multiple-sessions-per-request)
* [Compatibility](#compatibility)
* Built-in PostgreSQL, MySQL, Redis, Memcached, encrypted cookie and in-memory storage engines. Custom storage engines are also supported.
* Supports OWASP good-practices, including absolute and idle session timeouts and easy regeneration of session tokens.
* Fast and very memory-efficient performance.
* Type-safe and sensible API for managing session data. Safe for concurrent use.
* Automatic saving of session data.
**Recent changes:** Release v1.0.0 made breaking changes to the package layout and API. If you need the old version please vendor [release v0.1.1](https://github.com/alexedwards/scs/releases/tag/v0.1.1).
## Installation
## Installation & Usage
Install with `go get`:
```sh
$ go get github.com/alexedwards/scs
```
$ go get github.com/alexedwards/scs/v2
```
### Basic use
## The Basics
SCS implements a session management pattern following the [OWASP security guidelines](https://github.com/OWASP/CheatSheetSeries/blob/master/cheatsheets/Session_Management_Cheat_Sheet.md). Session data is stored on the server, and a randomly-generated unique session token (or *session ID*) is communicated to and from the client in a session cookie.
```go
package main
@ -28,105 +31,284 @@ import (
"io"
"net/http"
"github.com/alexedwards/scs"
"github.com/alexedwards/scs/v2"
)
// Initialize a new encrypted-cookie based session manager and store it in a global
// variable. In a real application, you might inject the session manager as a
// dependency to your handlers instead. The parameter to the NewCookieManager()
// function is a 32 character long random key, which is used to encrypt and
// authenticate the session cookies.
var sessionManager = scs.NewCookieManager("u46IpCV9y5Vlur8YvODJEhgOY8m9JVE4")
var session *scs.Session
func main() {
// Set up your HTTP handlers in the normal way.
// Initialize the session manager.
session = scs.NewSession()
mux := http.NewServeMux()
mux.HandleFunc("/put", putHandler)
mux.HandleFunc("/get", getHandler)
// Wrap your handlers with the session manager middleware.
http.ListenAndServe(":4000", sessionManager.Use(mux))
// Wrap your handlers with the LoadAndSave() middleware.
http.ListenAndServe(":4000", session.LoadAndSave(mux))
}
func putHandler(w http.ResponseWriter, r *http.Request) {
// Load the session data for the current request. Any errors are deferred
// until you actually use the session data.
session := sessionManager.Load(r)
// Use the PutString() method to add a new key and associated string value
// to the session data. Methods for many other common data types are also
// provided. The session data is automatically saved.
err := session.PutString(w, "message", "Hello world!")
if err != nil {
http.Error(w, err.Error(), 500)
}
// Store a new key and value in the session data.
session.Put(r.Context(), "message", "Hello from a session!")
}
func getHandler(w http.ResponseWriter, r *http.Request) {
// Load the session data for the current request.
session := sessionManager.Load(r)
// Use the GetString() helper to retrieve the string value for the "message"
// key from the session data. The zero value for a string is returned if the
// key does not exist.
message, err := session.GetString("message")
if err != nil {
http.Error(w, err.Error(), 500)
}
io.WriteString(w, message)
// Use the GetString helper to retrieve the string value associated with a
// key. The zero value is returned if the key does not exist.
msg := session.GetString(r.Context(), "message")
io.WriteString(w, msg)
}
```
SCS provides a wide range of functions for working with session data.
```
$ curl -i --cookie-jar cj --cookie cj localhost:4000/put
HTTP/1.1 200 OK
Cache-Control: no-cache="Set-Cookie"
Set-Cookie: session=lHqcPNiQp_5diPxumzOklsSdE-MJ7zyU6kjch1Ee0UM; Path=/; Expires=Sat, 27 Apr 2019 10:28:20 GMT; Max-Age=86400; HttpOnly; SameSite=Lax
Vary: Cookie
Date: Fri, 26 Apr 2019 10:28:19 GMT
Content-Length: 0
* `Put…` and `Get…` methods for storing and retrieving a variety of common data types and custom objects.
* `Pop…` methods for one-time retrieval of common data types (and custom objects) from the session data.
* `Keys` returns an alphabetically-sorted slice of all keys in the session data.
* `Exists` returns whether a specific key exists in the session data.
* `Remove` removes an individual key and value from the session data.
* `Clear` removes all data for the current session.
* `RenewToken` creates a new session token. This should be used before privilege changes to help avoid session fixation.
* `Destroy` deletes the current session and instructs the browser to delete the session cookie.
$ curl -i --cookie-jar cj --cookie cj localhost:4000/get
HTTP/1.1 200 OK
Date: Fri, 26 Apr 2019 10:28:24 GMT
Content-Length: 21
Content-Type: text/plain; charset=utf-8
A full list of available functions can be found in [the GoDoc](https://godoc.org/github.com/alexedwards/scs/#pkg-index).
Hello from a session!
```
### Customizing the session manager
## Configuring Session Behavior
The session manager can be configured to customize its behavior. For example:
Session behavior can be configured via the `Session` fields. For example:
```go
sessionManager = scs.NewCookieManager("u46IpCV9y5Vlur8YvODJEhgOY8m9JVE4")
sessionManager.Lifetime(time.Hour) // Set the maximum session lifetime to 1 hour.
sessionManager.Persist(true) // Persist the session after a user has closed their browser.
sessionManager.Secure(true) // Set the Secure flag on the session cookie.
session = scs.NewSession()
session.Lifetime = 3 * time.Hour
session.IdleTimeout = 20 * time.Minute
session.Cookie.Persist = false
session.Cookie.SameSite = http.SameSiteStrictMode
session.Cookie.Secure = true
```
A full list of available settings can be found in [the GoDoc](https://godoc.org/github.com/alexedwards/scs/#pkg-index).
Documentation for all available settings and their default values can be [found here](https://godoc.org/github.com/alexedwards/scs#Session).
### Using a different session store
## Working with Session Data
The above examples use encrypted cookies to store session data, but SCS also supports a range of server-side stores.
Data can be set using the [`Put()`](https://godoc.org/github.com/alexedwards/scs#Session.Put) method and retrieved with the [`Get()`](https://godoc.org/github.com/alexedwards/scs#Session.Get) method. A variety of helper methods like [`GetString()`](https://godoc.org/github.com/alexedwards/scs#Session.GetString), [`GetInt()`](https://godoc.org/github.com/alexedwards/scs#Session.GetInt) and [`GetBytes()`](https://godoc.org/github.com/alexedwards/scs#Session.GetBytes) are included for common data types. Please see [the documentation](https://godoc.org/github.com/alexedwards/scs#pkg-index) for a full list of helper methods.
The [`Pop()`](https://godoc.org/github.com/alexedwards/scs#Session.Pop) method (and accompanying helpers for common data types) act like a one-time `Get()`, retrieving the data and removing it from the session in one step. These are useful if you want to implement 'flash' message functionality in your application, where messages are displayed to the user once only.
Some other useful functions are [`Exists()`](https://godoc.org/github.com/alexedwards/scs#Session.Exists) (which returns a `bool` indicating whether or not a given key exists in the session data) and [`Keys()`](https://godoc.org/github.com/alexedwards/scs#Session.Keys) (which returns a sorted slice of keys in the session data).
Individual data items can be deleted from the session using the [`Remove()`](https://godoc.org/github.com/alexedwards/scs#Session.Remove) method. Alternatively, all session data can de deleted by using the [`Destroy()`](https://godoc.org/github.com/alexedwards/scs#Session.Destroy) method. After calling `Destroy()`, any further operations in the same request cycle will result in a new session being created --- with a new session token and a new lifetime.
## Loading and Saving Sessions
Most applications will use the [`LoadAndSave()`](https://godoc.org/github.com/alexedwards/scs#Session.LoadAndSave) middleware. This middleware takes care of loading and committing session data to the session store, and communicating the session token to/from the client in a cookie as necessary.
If you want to communicate the session token to/from the client in a different way (for example in a different HTTP header) you are encouraged to create your own alternative middleware using the code in [`LoadAndSave()`](https://godoc.org/github.com/alexedwards/scs#Session.LoadAndSave) as a template. An example is [given here](https://gist.github.com/alexedwards/cc6190195acfa466bf27f05aa5023f50).
Or for more fine-grained control you can load and save sessions within your individual handlers (or from anywhere in your application). [See here](https://gist.github.com/alexedwards/0570e5a59677e278e13acb8ea53a3b30) for an example.
## Configuring the Session Store
By default SCS uses an in-memory store for session data. This is convenient (no setup!) and very fast, but all session data will be lost when your application is stopped or restarted. Therefore it's useful for applications where data loss is an acceptable trade off for fast performance, or for prototyping and testing purposes. In most production applications you will want to use a persistent session store like PostgreSQL or MySQL instead.
The session stores currently included are:
| Package | |
|:------------------------------------------------------------------------------------- |-----------------------------------------------------------------------------------|
| [stores/boltstore](https://godoc.org/github.com/alexedwards/scs/stores/boltstore) | BoltDB-based session store |
| [stores/buntstore](https://godoc.org/github.com/alexedwards/scs/stores/buntstore) | BuntDB based session store |
| [stores/cookiestore](https://godoc.org/github.com/alexedwards/scs/stores/cookiestore) | Encrypted-cookie session store |
| [stores/dynamostore](https://godoc.org/github.com/alexedwards/scs/stores/dynamostore) | DynamoDB-based session store |
| [stores/memstore](https://godoc.org/github.com/alexedwards/scs/stores/memstore) | In-memory session store |
| [stores/mysqlstore](https://godoc.org/github.com/alexedwards/scs/stores/mysqlstore) | MySQL-based session store |
| [stores/pgstore](https://godoc.org/github.com/alexedwards/scs/stores/pgstore) | PostgreSQL-based storage eninge |
| [stores/qlstore](https://godoc.org/github.com/alexedwards/scs/stores/qlstore) | QL-based session store |
| [stores/redisstore](https://godoc.org/github.com/alexedwards/scs/stores/redisstore) | Redis-based session store |
| [stores/memcached](https://godoc.org/github.com/alexedwards/scs/stores/memcachedstore)| Memcached-based session store |
|:------------------------------------------------------------------------------------- |----------------------------------------------------------------------------------|
| [memstore](https://github.com/alexedwards/scs/tree/master/memstore) | In-memory session store (default) |
| [mysqlstore](https://github.com/alexedwards/scs/tree/master/mysqlstore) | MySQL based session store |
| [postgresstore](https://github.com/alexedwards/scs/tree/master/postgresstore) | PostgreSQL based session store |
### Compatibility
Custom session stores are also supported. Please [see here](#using-custom-session-stores) for more information.
SCS is designed to be compatible with Go's `net/http` package and the `http.Handler` interface.
### Using with PostgreSQL
If you're using the [Echo](https://echo.labstack.com/) framework, the [official session middleware](https://echo.labstack.com/middleware/session) for Echo is likely to be a better fit for your application.
Please see the `postgresstore` [package documentation](https://github.com/alexedwards/scs/tree/master/postgresstore) for full information and sample code. But in summary...
### Examples
You'll need to create a `sessions` table:
* [RequireLogin middleware](https://gist.github.com/alexedwards/6eac2f19b9b5c064ca90f756c32f94cc)
```sql
CREATE TABLE sessions (
token TEXT PRIMARY KEY,
data BYTEA NOT NULL,
expiry TIMESTAMPTZ NOT NULL
);
CREATE INDEX sessions_expiry_idx ON sessions (expiry);
```
And then you can then use it like this:
```go
package main
import (
"database/sql"
"io"
"log"
"net/http"
"github.com/alexedwards/scs/v2"
"github.com/alexedwards/scs/v2/postgresstore"
_ "github.com/lib/pq"
)
var session *scs.Session
func main() {
db, err := sql.Open("postgres", "postgres://user:pass@localhost/db")
if err != nil {
log.Fatal(err)
}
defer db.Close()
// Initialize a new session manager and configure it to use PostgreSQL as
// the session store.
session = scs.NewSession()
session.Store = postgresstore.New(db)
mux := http.NewServeMux()
mux.HandleFunc("/put", putHandler)
mux.HandleFunc("/get", getHandler)
http.ListenAndServe(":4000", session.LoadAndSave(mux))
}
func putHandler(w http.ResponseWriter, r *http.Request) {
session.Put(r.Context(), "message", "Hello from a session!")
}
func getHandler(w http.ResponseWriter, r *http.Request) {
msg := session.GetString(r.Context(), "message")
io.WriteString(w, msg)
}
```
A background 'cleanup' goroutine is automatically run to delete expired session data. This stops the database table from holding on to invalid sessions indefinitely and growing unnecessarily large. By default the cleanup will run every 5 minutes.
### Using with MySQL
Please see the `mysqlstore` [package documentation](https://github.com/alexedwards/scs/tree/master/mysqlstore) for full information and sample code. But in summary...
You'll need to create a `sessions` table:
```sql
CREATE TABLE sessions (
token CHAR(43) PRIMARY KEY,
data BLOB NOT NULL,
expiry TIMESTAMP(6) NOT NULL
);
CREATE INDEX sessions_expiry_idx ON sessions (expiry);
```
And then you can then use it like this:
```go
package main
import (
"database/sql"
"io"
"log"
"net/http"
"github.com/alexedwards/scs/v2"
"github.com/alexedwards/scs/v2/mysqlstore"
_ "github.com/go-sql-driver/mysql"
)
var session *scs.Session
func main() {
db, err := sql.Open("mysql", "user:pass@/db?parseTime=true")
if err != nil {
log.Fatal(err)
}
defer db.Close()
// Initialize a new session manager and configure it to use PostgreSQL as
// the session store.
session = scs.NewSession()
session.Store = mysqlstore.New(db)
mux := http.NewServeMux()
mux.HandleFunc("/put", putHandler)
mux.HandleFunc("/get", getHandler)
http.ListenAndServe(":4000", session.LoadAndSave(mux))
}
func putHandler(w http.ResponseWriter, r *http.Request) {
session.Put(r.Context(), "message", "Hello from a session!")
}
func getHandler(w http.ResponseWriter, r *http.Request) {
msg := session.GetString(r.Context(), "message")
io.WriteString(w, msg)
}
```
A background 'cleanup' goroutine is automatically run to delete expired session data. This stops the database table from holding on to invalid sessions indefinitely and growing unnecessarily large. By default the cleanup will run every 5 minutes.
### Using Custom Session Stores
[`scs.Store`](https://godoc.org/github.com/alexedwards/scs#Store) defines the interface for custom session stores. Any object that implements this interface can be set as the store when configuring the session.
```go
type Store interface {
// Delete should remove the session token and corresponding data from the
// session store. If the token does not exist then Delete should be a no-op
// and return nil (not an error).
Delete(token string) (err error)
// Find should return the data for a session token from the store. If the
// session token is not found or is expired, the found return value should
// be false (and the err return value should be nil). Similarly, tampered
// or malformed tokens should result in a found return value of false and a
// nil err value. The err return value should be used for system errors only.
Find(token string) (b []byte, found bool, err error)
// Commit should add the session token and data to the store, with the given
// expiry time. If the session token already exists, then the data and
// expiry time should be overwritten.
Commit(token string, b []byte, expiry time.Time) (err error)
}
```
## Preventing Session Fixation
To help prevent session fixation attacks you should [renew the session token after any privilege level change](https://github.com/OWASP/CheatSheetSeries/blob/master/cheatsheets/Session_Management_Cheat_Sheet.md#renew-the-session-id-after-any-privilege-level-change). Commonly, this means that the session token must to be changed when a user logs in or out of your application. You can do this using the [`RenewToken()`](https://godoc.org/github.com/alexedwards/scs#Session.RenewToken) method like so:
```go
func loginHandler(w http.ResponseWriter, r *http.Request) {
userID := 123
// First renew the session token...
err := session.RenewToken(r.Context())
if err != nil {
http.Error(w, err.Error(), 500)
return
}
// Then make the privilege-level change.
session.Put(r.Context(), "userID", userID)
}
```
## Multiple Sessions per Request
It is possible for an application to support multiple sessions per request, with different lifetime lengths and even different stores. Please [see here for an example](https://gist.github.com/alexedwards/22535f758356bfaf96038fffad154824).
## Compatibility
This package requires Go 1.11 or newer.
It is not compatible with the [Echo](https://echo.labstack.com/) framework. Please consider using the [Echo session manager](https://echo.labstack.com/middleware/session) instead.

491
data.go Normal file
View File

@ -0,0 +1,491 @@
package scs
import (
"bytes"
"context"
"crypto/rand"
"encoding/base64"
"encoding/gob"
"fmt"
"sort"
"sync"
"time"
)
// Status represents the state of the session data during a request cycle.
type Status int
const (
// Unmodified indicates that the session data hasn't been changed in the
// current request cycle.
Unmodified Status = iota
// Modified indicates that the session data has been changed in the current
// request cycle.
Modified
// Destroyed indicates that the session data has been destroyed in the
// current request cycle.
Destroyed
)
type sessionData struct {
Deadline time.Time // Exported for gob encoding.
status Status
token string
Values map[string]interface{} // Exported for gob encoding.
mu sync.Mutex
}
func newSessionData(lifetime time.Duration) *sessionData {
return &sessionData{
Deadline: time.Now().Add(lifetime).UTC(),
status: Unmodified,
Values: make(map[string]interface{}),
}
}
// Load retrieves the session data for the given token from the session store,
// and returns a new context.Context containing the session data. If no matching
// token is found then this will create a new session.
//
// Most applications will use the LoadAndSave() middleware and will not need to
// use this method.
func (s *Session) Load(ctx context.Context, token string) (context.Context, error) {
if _, ok := ctx.Value(s.contextKey).(*sessionData); ok {
return ctx, nil
}
if token == "" {
return s.addSessionDataToContext(ctx, newSessionData(s.Lifetime)), nil
}
b, found, err := s.Store.Find(token)
if err != nil {
return nil, err
} else if !found {
return s.addSessionDataToContext(ctx, newSessionData(s.Lifetime)), nil
}
sd := &sessionData{
status: Unmodified,
token: token,
}
err = sd.decode(b)
if err != nil {
return nil, err
}
// Mark the session data as modified if an idle timeout is being used. This
// will force the session data to be re-committed to the session store with
// a new expiry time.
if s.IdleTimeout > 0 {
sd.status = Modified
}
return s.addSessionDataToContext(ctx, sd), nil
}
// Commit saves the session data to the session store and returns the session
// token and expiry time.
//
// Most applications will use the LoadAndSave() middleware and will not need to
// use this method.
func (s *Session) Commit(ctx context.Context) (string, time.Time, error) {
sd := s.getSessionDataFromContext(ctx)
sd.mu.Lock()
defer sd.mu.Unlock()
if sd.token == "" {
var err error
sd.token, err = generateToken()
if err != nil {
return "", time.Time{}, err
}
}
b, err := sd.encode()
if err != nil {
return "", time.Time{}, err
}
expiry := sd.Deadline
if s.IdleTimeout > 0 {
ie := time.Now().Add(s.IdleTimeout)
if ie.Before(expiry) {
expiry = ie
}
}
err = s.Store.Commit(sd.token, b, expiry)
if err != nil {
return "", time.Time{}, err
}
return sd.token, expiry, nil
}
// Destroy deletes the session data from the session store and sets the session
// status to Destroyed. Any futher operations in the same request cycle will
// result in a new session being created.
func (s *Session) Destroy(ctx context.Context) error {
sd := s.getSessionDataFromContext(ctx)
sd.mu.Lock()
defer sd.mu.Unlock()
err := s.Store.Delete(sd.token)
if err != nil {
return err
}
sd.status = Destroyed
// Reset everything else to defaults.
sd.token = ""
sd.Deadline = time.Now().Add(s.Lifetime).UTC()
for key := range sd.Values {
delete(sd.Values, key)
}
return nil
}
// Put adds a key and corresponding value to the session data. Any existing
// value for the key will be replaced. The session data status will be set to
// Modified.
func (s *Session) Put(ctx context.Context, key string, val interface{}) {
sd := s.getSessionDataFromContext(ctx)
sd.mu.Lock()
sd.Values[key] = val
sd.status = Modified
sd.mu.Unlock()
}
// Get returns the value for a given key from the session data. The return
// value has the type interface{} so will usually need to be type asserted
// before you can use it. For example:
//
// foo, ok := session.Get(r, "foo").(string)
// if !ok {
// return errors.New("type assertion to string failed")
// }
//
// Also see the GetString(), GetInt(), GetBytes() and other helper methods which
// wrap the type conversion for common types.
func (s *Session) Get(ctx context.Context, key string) interface{} {
sd := s.getSessionDataFromContext(ctx)
sd.mu.Lock()
defer sd.mu.Unlock()
return sd.Values[key]
}
// Pop acts like a one-time Get. It returns the value for a given key from the
// session data and deletes the key and value from the session data. The
// session data status will be set to Modified. The return value has the type
// interface{} so will usually need to be type asserted before you can use it.
func (s *Session) Pop(ctx context.Context, key string) interface{} {
sd := s.getSessionDataFromContext(ctx)
sd.mu.Lock()
defer sd.mu.Unlock()
val, exists := sd.Values[key]
if !exists {
return nil
}
delete(sd.Values, key)
sd.status = Modified
return val
}
// Remove deletes the given key and corresponding value from the session data.
// The session data status will be set to Modified. If the key is not present
// this operation is a no-op.
func (s *Session) Remove(ctx context.Context, key string) {
sd := s.getSessionDataFromContext(ctx)
sd.mu.Lock()
defer sd.mu.Unlock()
_, exists := sd.Values[key]
if !exists {
return
}
delete(sd.Values, key)
sd.status = Modified
}
// Exists returns true if the given key is present in the session data.
func (s *Session) Exists(ctx context.Context, key string) bool {
sd := s.getSessionDataFromContext(ctx)
sd.mu.Lock()
_, exists := sd.Values[key]
sd.mu.Unlock()
return exists
}
// Keys returns a slice of all key names present in the session data, sorted
// alphabetically. If the data contains no data then an empty slice will be
// returned.
func (s *Session) Keys(ctx context.Context) []string {
sd := s.getSessionDataFromContext(ctx)
sd.mu.Lock()
keys := make([]string, len(sd.Values))
i := 0
for key := range sd.Values {
keys[i] = key
i++
}
sd.mu.Unlock()
sort.Strings(keys)
return keys
}
// RenewToken updates the session data to have a new session token while
// retaining the current session data. The session lifetime is also reset and
// the session data status will be set to Modified.
//
// The old session token and accompanying data are deleted from the session store.
//
// To mitigate the risk of session fixation attacks, it's important that you call
// RenewToken before making any changes to privilege levels (e.g. login and
// logout operations). See https://github.com/OWASP/CheatSheetSeries/blob/master/cheatsheets/Session_Management_Cheat_Sheet.md#renew-the-session-id-after-any-privilege-level-change
// for additional information.
func (s *Session) RenewToken(ctx context.Context) error {
sd := s.getSessionDataFromContext(ctx)
sd.mu.Lock()
defer sd.mu.Unlock()
err := s.Store.Delete(sd.token)
if err != nil {
return err
}
newToken, err := generateToken()
if err != nil {
return err
}
sd.token = newToken
sd.Deadline = time.Now().Add(s.Lifetime).UTC()
sd.status = Modified
return nil
}
// Status returns the current status of the session data.
func (s *Session) Status(ctx context.Context) Status {
sd := s.getSessionDataFromContext(ctx)
sd.mu.Lock()
defer sd.mu.Unlock()
return sd.status
}
// GetString returns the string value for a given key from the session data.
// The zero value for a string ("") is returned if the key does not exist or the
// value could not be type asserted to a string.
func (s *Session) GetString(ctx context.Context, key string) string {
val := s.Get(ctx, key)
str, ok := val.(string)
if !ok {
return ""
}
return str
}
// GetBool returns the bool value for a given key from the session data. The
// zero value for a bool (false) is returned if the key does not exist or the
// value could not be type asserted to a bool.
func (s *Session) GetBool(ctx context.Context, key string) bool {
val := s.Get(ctx, key)
b, ok := val.(bool)
if !ok {
return false
}
return b
}
// GetInt returns the int value for a given key from the session data. The
// zero value for an int (0) is returned if the key does not exist or the
// value could not be type asserted to an int.
func (s *Session) GetInt(ctx context.Context, key string) int {
val := s.Get(ctx, key)
i, ok := val.(int)
if !ok {
return 0
}
return i
}
// GetFloat returns the float64 value for a given key from the session data. The
// zero value for an float64 (0) is returned if the key does not exist or the
// value could not be type asserted to a float64.
func (s *Session) GetFloat(ctx context.Context, key string) float64 {
val := s.Get(ctx, key)
f, ok := val.(float64)
if !ok {
return 0
}
return f
}
// GetBytes returns the byte slice ([]byte) value for a given key from the session
// data. The zero value for a slice (nil) is returned if the key does not exist
// or could not be type asserted to []byte.
func (s *Session) GetBytes(ctx context.Context, key string) []byte {
val := s.Get(ctx, key)
b, ok := val.([]byte)
if !ok {
return nil
}
return b
}
// GetTime returns the time.Time value for a given key from the session data. The
// zero value for a time.Time object is returned if the key does not exist or the
// value could not be type asserted to a time.Time. This can be tested with the
// time.IsZero() method.
func (s *Session) GetTime(ctx context.Context, key string) time.Time {
val := s.Get(ctx, key)
t, ok := val.(time.Time)
if !ok {
return time.Time{}
}
return t
}
// PopString returns the string value for a given key and then deletes it from the
// session data. The session data status will be set to Modified. The zero
// value for a string ("") is returned if the key does not exist or the value
// could not be type asserted to a string.
func (s *Session) PopString(ctx context.Context, key string) string {
val := s.Pop(ctx, key)
str, ok := val.(string)
if !ok {
return ""
}
return str
}
// PopBool returns the bool value for a given key and then deletes it from the
// session data. The session data status will be set to Modified. The zero
// value for a bool (false) is returned if the key does not exist or the value
// could not be type asserted to a bool.
func (s *Session) PopBool(ctx context.Context, key string) bool {
val := s.Pop(ctx, key)
b, ok := val.(bool)
if !ok {
return false
}
return b
}
// PopInt returns the int value for a given key and then deletes it from the
// session data. The session data status will be set to Modified. The zero
// value for an int (0) is returned if the key does not exist or the value could
// not be type asserted to an int.
func (s *Session) PopInt(ctx context.Context, key string) int {
val := s.Pop(ctx, key)
i, ok := val.(int)
if !ok {
return 0
}
return i
}
// PopFloat returns the float64 value for a given key and then deletes it from the
// session data. The session data status will be set to Modified. The zero
// value for an float64 (0) is returned if the key does not exist or the value
// could not be type asserted to a float64.
func (s *Session) PopFloat(ctx context.Context, key string) float64 {
val := s.Pop(ctx, key)
f, ok := val.(float64)
if !ok {
return 0
}
return f
}
// PopBytes returns the byte slice ([]byte) value for a given key and then
// deletes it from the from the session data. The session data status will be
// set to Modified. The zero value for a slice (nil) is returned if the key does
// not exist or could not be type asserted to []byte.
func (s *Session) PopBytes(ctx context.Context, key string) []byte {
val := s.Pop(ctx, key)
b, ok := val.([]byte)
if !ok {
return nil
}
return b
}
// PopTime returns the time.Time value for a given key and then deletes it from
// the session data. The session data status will be set to Modified. The zero
// value for a time.Time object is returned if the key does not exist or the
// value could not be type asserted to a time.Time.
func (s *Session) PopTime(ctx context.Context, key string) time.Time {
val := s.Pop(ctx, key)
t, ok := val.(time.Time)
if !ok {
return time.Time{}
}
return t
}
func (s *Session) addSessionDataToContext(ctx context.Context, sd *sessionData) context.Context {
return context.WithValue(ctx, s.contextKey, sd)
}
func (s *Session) getSessionDataFromContext(ctx context.Context) *sessionData {
c, ok := ctx.Value(s.contextKey).(*sessionData)
if !ok {
panic("scs: no session data in context")
}
return c
}
func (sd *sessionData) encode() ([]byte, error) {
var b bytes.Buffer
err := gob.NewEncoder(&b).Encode(sd)
if err != nil {
return nil, err
}
return b.Bytes(), nil
}
func (sd *sessionData) decode(b []byte) error {
r := bytes.NewReader(b)
return gob.NewDecoder(r).Decode(sd)
}
func generateToken() (string, error) {
b := make([]byte, 32)
_, err := rand.Read(b)
if err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(b), nil
}
type contextKey string
var contextKeyID int
func generateContextKey() contextKey {
contextKeyID = contextKeyID + 1
return contextKey(fmt.Sprintf("session.%d", contextKeyID))
}

275
data_test.go Normal file
View File

@ -0,0 +1,275 @@
package scs
import (
"bytes"
"context"
"reflect"
"testing"
"time"
)
func TestSessionDataFromContext(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Errorf("the code did not panic")
}
}()
s := NewSession()
s.getSessionDataFromContext(context.Background())
}
func TestPut(t *testing.T) {
s := NewSession()
sd := newSessionData(time.Hour)
ctx := s.addSessionDataToContext(context.Background(), sd)
s.Put(ctx, "foo", "bar")
if sd.Values["foo"] != "bar" {
t.Errorf("got %q: expected %q", sd.Values["foo"], "bar")
}
if sd.status != Modified {
t.Errorf("got %v: expected %v", sd.status, "modified")
}
}
func TestGet(t *testing.T) {
s := NewSession()
sd := newSessionData(time.Hour)
sd.Values["foo"] = "bar"
ctx := s.addSessionDataToContext(context.Background(), sd)
str, ok := s.Get(ctx, "foo").(string)
if !ok {
t.Errorf("could not convert %T to string", s.Get(ctx, "foo"))
}
if str != "bar" {
t.Errorf("got %q: expected %q", str, "bar")
}
}
func TestPop(t *testing.T) {
s := NewSession()
sd := newSessionData(time.Hour)
sd.Values["foo"] = "bar"
ctx := s.addSessionDataToContext(context.Background(), sd)
str, ok := s.Pop(ctx, "foo").(string)
if !ok {
t.Errorf("could not convert %T to string", s.Get(ctx, "foo"))
}
if str != "bar" {
t.Errorf("got %q: expected %q", str, "bar")
}
_, ok = sd.Values["foo"]
if ok {
t.Errorf("got %v: expected %v", ok, false)
}
if sd.status != Modified {
t.Errorf("got %v: expected %v", sd.status, "modified")
}
}
func TestRemove(t *testing.T) {
s := NewSession()
sd := newSessionData(time.Hour)
sd.Values["foo"] = "bar"
ctx := s.addSessionDataToContext(context.Background(), sd)
s.Remove(ctx, "foo")
if sd.Values["foo"] != nil {
t.Errorf("got %v: expected %v", sd.Values["foo"], nil)
}
if sd.status != Modified {
t.Errorf("got %v: expected %v", sd.status, "modified")
}
}
func TestExists(t *testing.T) {
s := NewSession()
sd := newSessionData(time.Hour)
sd.Values["foo"] = "bar"
ctx := s.addSessionDataToContext(context.Background(), sd)
if !s.Exists(ctx, "foo") {
t.Errorf("got %v: expected %v", s.Exists(ctx, "foo"), true)
}
if s.Exists(ctx, "baz") {
t.Errorf("got %v: expected %v", s.Exists(ctx, "baz"), false)
}
}
func TestKeys(t *testing.T) {
s := NewSession()
sd := newSessionData(time.Hour)
sd.Values["foo"] = "bar"
sd.Values["woo"] = "waa"
ctx := s.addSessionDataToContext(context.Background(), sd)
keys := s.Keys(ctx)
if !reflect.DeepEqual(keys, []string{"foo", "woo"}) {
t.Errorf("got %v: expected %v", keys, []string{"foo", "woo"})
}
}
func TestGetString(t *testing.T) {
s := NewSession()
sd := newSessionData(time.Hour)
sd.Values["foo"] = "bar"
ctx := s.addSessionDataToContext(context.Background(), sd)
str := s.GetString(ctx, "foo")
if str != "bar" {
t.Errorf("got %q: expected %q", str, "bar")
}
str = s.GetString(ctx, "baz")
if str != "" {
t.Errorf("got %q: expected %q", str, "")
}
}
func TestGetBool(t *testing.T) {
s := NewSession()
sd := newSessionData(time.Hour)
sd.Values["foo"] = true
ctx := s.addSessionDataToContext(context.Background(), sd)
b := s.GetBool(ctx, "foo")
if b != true {
t.Errorf("got %v: expected %v", b, true)
}
b = s.GetBool(ctx, "baz")
if b != false {
t.Errorf("got %v: expected %v", b, false)
}
}
func TestGetInt(t *testing.T) {
s := NewSession()
sd := newSessionData(time.Hour)
sd.Values["foo"] = 123
ctx := s.addSessionDataToContext(context.Background(), sd)
i := s.GetInt(ctx, "foo")
if i != 123 {
t.Errorf("got %v: expected %d", i, 123)
}
i = s.GetInt(ctx, "baz")
if i != 0 {
t.Errorf("got %v: expected %d", i, 0)
}
}
func TestGetFloat(t *testing.T) {
s := NewSession()
sd := newSessionData(time.Hour)
sd.Values["foo"] = 123.456
ctx := s.addSessionDataToContext(context.Background(), sd)
f := s.GetFloat(ctx, "foo")
if f != 123.456 {
t.Errorf("got %v: expected %f", f, 123.456)
}
f = s.GetFloat(ctx, "baz")
if f != 0 {
t.Errorf("got %v: expected %f", f, 0.00)
}
}
func TestGetBytes(t *testing.T) {
s := NewSession()
sd := newSessionData(time.Hour)
sd.Values["foo"] = []byte("bar")
ctx := s.addSessionDataToContext(context.Background(), sd)
b := s.GetBytes(ctx, "foo")
if !bytes.Equal(b, []byte("bar")) {
t.Errorf("got %v: expected %v", b, []byte("bar"))
}
b = s.GetBytes(ctx, "baz")
if b != nil {
t.Errorf("got %v: expected %v", b, nil)
}
}
func TestGetTime(t *testing.T) {
now := time.Now()
s := NewSession()
sd := newSessionData(time.Hour)
sd.Values["foo"] = now
ctx := s.addSessionDataToContext(context.Background(), sd)
tm := s.GetTime(ctx, "foo")
if tm != now {
t.Errorf("got %v: expected %v", tm, now)
}
tm = s.GetTime(ctx, "baz")
if !tm.IsZero() {
t.Errorf("got %v: expected %v", tm, time.Time{})
}
}
func TestPopString(t *testing.T) {
s := NewSession()
sd := newSessionData(time.Hour)
sd.Values["foo"] = "bar"
ctx := s.addSessionDataToContext(context.Background(), sd)
str := s.PopString(ctx, "foo")
if str != "bar" {
t.Errorf("got %q: expected %q", str, "bar")
}
_, ok := sd.Values["foo"]
if ok {
t.Errorf("got %v: expected %v", ok, false)
}
if sd.status != Modified {
t.Errorf("got %v: expected %v", sd.status, "modified")
}
str = s.PopString(ctx, "bar")
if str != "" {
t.Errorf("got %q: expected %q", str, "")
}
}
func TestStatus(t *testing.T) {
s := NewSession()
sd := newSessionData(time.Hour)
ctx := s.addSessionDataToContext(context.Background(), sd)
status := s.Status(ctx)
if status != Unmodified {
t.Errorf("got %d: expected %d", status, Unmodified)
}
s.Put(ctx, "foo", "bar")
status = s.Status(ctx)
if status != Modified {
t.Errorf("got %d: expected %d", status, Modified)
}
s.Destroy(ctx)
status = s.Status(ctx)
if status != Destroyed {
t.Errorf("got %d: expected %d", status, Destroyed)
}
}

3
go.mod Normal file
View File

@ -0,0 +1,3 @@
module github.com/alexedwards/scs/v2
go 1.12

0
go.sum Normal file
View File

View File

@ -1,158 +0,0 @@
package scs
import (
"context"
"fmt"
"log"
"net/http"
"time"
"github.com/alexedwards/scs/stores/cookiestore"
)
// Manager is a session manager.
type Manager struct {
store Store
opts *options
}
// NewManager returns a pointer to a new session manager.
func NewManager(store Store) *Manager {
defaultOptions := &options{
domain: "",
httpOnly: true,
idleTimeout: 0,
lifetime: 24 * time.Hour,
name: "session",
path: "/",
persist: false,
secure: false,
sameSite: "",
}
return &Manager{
store: store,
opts: defaultOptions,
}
}
// Domain sets the 'Domain' attribute on the session cookie. By default it will
// be set to the domain name that the cookie was issued from.
func (m *Manager) Domain(s string) {
m.opts.domain = s
}
// HttpOnly sets the 'HttpOnly' attribute on the session cookie. The default value
// is true.
func (m *Manager) HttpOnly(b bool) {
m.opts.httpOnly = b
}
// IdleTimeout sets the maximum length of time a session can be inactive before it
// expires. For example, some applications may wish to set this so there is a timeout
// after 20 minutes of inactivity. The inactivity period is reset whenever the
// session data is changed (but not read).
//
// By default IdleTimeout is not set and there is no inactivity timeout.
func (m *Manager) IdleTimeout(t time.Duration) {
m.opts.idleTimeout = t
}
// Lifetime sets the maximum length of time that a session is valid for before
// it expires. The lifetime is an 'absolute expiry' which is set when the session
// is first created and does not change.
//
// The default value is 24 hours.
func (m *Manager) Lifetime(t time.Duration) {
m.opts.lifetime = t
}
// Name sets the name of the session cookie. This name should not contain whitespace,
// commas, semicolons, backslashes, the equals sign or control characters as per
// RFC6265.
func (m *Manager) Name(s string) {
m.opts.name = s
}
// Path sets the 'Path' attribute on the session cookie. The default value is "/".
// Passing the empty string "" will result in it being set to the path that the
// cookie was issued from.
func (m *Manager) Path(s string) {
m.opts.path = s
}
// Persist sets whether the session cookie should be persistent or not (i.e. whether
// it should be retained after a user closes their browser).
//
// The default value is false, which means that the session cookie will be destroyed
// when the user closes their browser. If set to true, explicit 'Expires' and
// 'MaxAge' values will be added to the cookie and it will be retained by the
// user's browser until the given expiry time is reached.
func (m *Manager) Persist(b bool) {
m.opts.persist = b
}
// Secure sets the 'Secure' attribute on the session cookie. The default value
// is false. It's recommended that you set this to true and serve all requests
// over HTTPS in production environments.
func (m *Manager) Secure(b bool) {
m.opts.secure = b
}
// SameSite sets the 'SameSite' attribute on the session cookie. The default value
// is nil; setting no SameSite attribute. Allowed values are "Lax" and "Strict".
// Note that "" (empty-string) causes SameSite to NOT be set -- don't confuse this
// with the cookie's 'SameSite' attribute (without Lax/Strict), which would default
// to "Strict".
func (m *Manager) SameSite(s string) {
m.opts.sameSite = s
}
// Load returns the session data for the current request.
func (m *Manager) Load(r *http.Request) *Session {
return load(r, m.store, m.opts)
}
// LoadFromContext returns session data from a given context.Context object.
func (m *Manager) LoadFromContext(ctx context.Context) *Session {
val := ctx.Value(sessionName(m.opts.name))
if val == nil {
return &Session{loadErr: fmt.Errorf("scs: value %s not in context", m.opts.name)}
}
s, ok := val.(*Session)
if !ok {
return &Session{loadErr: fmt.Errorf("scs: can not assert %T to *Session", val)}
}
return s
}
// AddToContext adds session data to a given context.Context object.
func (m *Manager) AddToContext(ctx context.Context, session *Session) context.Context {
return context.WithValue(ctx, sessionName(m.opts.name), session)
}
func NewCookieManager(key string) *Manager {
store := cookiestore.New([]byte(key))
return NewManager(store)
}
func (m *Manager) Multi(next http.Handler) http.Handler {
return m.Use(next)
}
func (m *Manager) Use(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
session := m.Load(r)
err := session.Touch(w)
if err != nil {
log.Println(err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
ctx := m.AddToContext(r.Context(), session)
next.ServeHTTP(w, r.WithContext(ctx))
})
}

74
memstore/README.md Normal file
View File

@ -0,0 +1,74 @@
# memstore
An in-memory session store for [SCS](https://github.com/alexedwards/scs).
Because memstore uses in-memory storage only, all session data will be lost when your application is stopped or restarted. Therefore it should only be used in applications where data loss is an acceptable trade off for fast performance, or for prototyping and testing purposes.
## Example
```go
package main
import (
"io"
"net/http"
"github.com/alexedwards/scs/v2"
"github.com/alexedwards/scs/v2/memstore"
)
var session *scs.Session
func main() {
// Initialize a new session manager and configure it to use memstore as
// the session store.
session = scs.NewSession()
session.Store = memstore.New()
mux := http.NewServeMux()
mux.HandleFunc("/put", putHandler)
mux.HandleFunc("/get", getHandler)
http.ListenAndServe(":4000", session.LoadAndSave(mux))
}
func putHandler(w http.ResponseWriter, r *http.Request) {
session.Put(r.Context(), "message", "Hello from a session!")
}
func getHandler(w http.ResponseWriter, r *http.Request) {
msg := session.GetString(r.Context(), "message")
io.WriteString(w, msg)
}
```
## Expired Session Cleanup
This package provides a background 'cleanup' goroutine to delete expired session data. This stops the database table from holding on to invalid sessions indefinitely and growing unnecessarily large. By default the cleanup runs once every minute. You can change this by using the `NewWithCleanupInterval()` function to initialize your session store. For example:
```go
// Run a cleanup every 30 seconds.
memstore.NewWithCleanupInterval(db, 30*time.Second)
// Disable the cleanup goroutine by setting the cleanup interval to zero.
memstore.NewWithCleanupInterval(db, 0)
```
### Terminating the Cleanup Goroutine
It's rare that the cleanup goroutine needs to be terminated --- it is generally intended to be long-lived and run for the lifetime of your application.
However, there may be occasions when your use of a session store instance is transient. A common example would be using it in a short-lived test function. In this scenario, the cleanup goroutine (which will run forever) will prevent the session store instance from being garbage collected even after the test function has finished. You can prevent this by either disabling the cleanup goroutine altogether (as described above) or by stopping it using the `StopCleanup()` method. For example:
```go
func TestExample(t *testing.T) {
store := memstore.New()
defer store.StopCleanup()
session = scs.NewSession()
session.Store = store
// Run test...
}
```

131
memstore/memstore.go Normal file
View File

@ -0,0 +1,131 @@
package memstore
import (
"errors"
"sync"
"time"
)
var errTypeAssertionFailed = errors.New("type assertion failed: could not convert interface{} to []byte")
type item struct {
object interface{}
expiration int64
}
// MemStore represents the session store.
type MemStore struct {
items map[string]item
mu sync.RWMutex
stopCleanup chan bool
}
// New returns a new MemStore instance, with a background cleanup goroutine that
// runs every minute to remove expired session data.
func New() *MemStore {
return NewWithCleanupInterval(time.Minute)
}
// NewWithCleanupInterval returns a new MemStore instance. The cleanupInterval
// parameter controls how frequently expired session data is removed by the
// background cleanup goroutine. Setting it to 0 prevents the cleanup goroutine
// from running (i.e. expired sessions will not be removed).
func NewWithCleanupInterval(cleanupInterval time.Duration) *MemStore {
m := &MemStore{
items: make(map[string]item),
}
if cleanupInterval > 0 {
go m.startCleanup(cleanupInterval)
}
return m
}
// Find returns the data for a given session token from the MemStore instance.
// If the session token is not found or is expired, the returned exists flag will
// be set to false.
func (m *MemStore) Find(token string) ([]byte, bool, error) {
m.mu.RLock()
defer m.mu.RUnlock()
item, found := m.items[token]
if !found {
return nil, false, nil
}
if time.Now().UnixNano() > item.expiration {
return nil, false, nil
}
b, ok := item.object.([]byte)
if !ok {
return nil, true, errTypeAssertionFailed
}
return b, true, nil
}
// Commit adds a session token and data to the MemStore instance with the given
// expiry time. If the session token already exists, then the data and expiry
// time are updated.
func (m *MemStore) Commit(token string, b []byte, expiry time.Time) error {
m.mu.Lock()
m.items[token] = item{
object: b,
expiration: expiry.UnixNano(),
}
m.mu.Unlock()
return nil
}
// Delete removes a session token and corresponding data from the MemStore
// instance.
func (m *MemStore) Delete(token string) error {
m.mu.Lock()
delete(m.items, token)
m.mu.Unlock()
return nil
}
func (m *MemStore) startCleanup(interval time.Duration) {
m.stopCleanup = make(chan bool)
ticker := time.NewTicker(interval)
for {
select {
case <-ticker.C:
m.deleteExpired()
case <-m.stopCleanup:
ticker.Stop()
return
}
}
}
// StopCleanup terminates the background cleanup goroutine for the MemStore
// instance. It's rare to terminate this; generally MemStore instances and
// their cleanup goroutines are intended to be long-lived and run for the lifetime
// of your application.
//
// There may be occasions though when your use of the MemStore is transient.
// An example is creating a new MemStore instance in a test function. In this
// scenario, the cleanup goroutine (which will run forever) will prevent the
// MemStore object from being garbage collected even after the test function
// has finished. You can prevent this by manually calling StopCleanup.
func (m *MemStore) StopCleanup() {
if m.stopCleanup != nil {
m.stopCleanup <- true
}
}
func (m *MemStore) deleteExpired() {
now := time.Now().UnixNano()
m.mu.Lock()
for token, item := range m.items {
if now > item.expiration {
delete(m.items, token)
}
}
m.mu.Unlock()
}

View File

@ -8,8 +8,8 @@ import (
)
func TestFind(t *testing.T) {
m := New(time.Minute)
m.cache.Set("session_token", []byte("encoded_data"), 0)
m := NewWithCleanupInterval(0)
m.items["session_token"] = item{object: []byte("encoded_data"), expiration: time.Now().Add(time.Second).UnixNano()}
b, found, err := m.Find("session_token")
if err != nil {
@ -24,7 +24,7 @@ func TestFind(t *testing.T) {
}
func TestFindMissing(t *testing.T) {
m := New(time.Minute)
m := NewWithCleanupInterval(0)
_, found, err := m.Find("missing_session_token")
if err != nil {
@ -36,8 +36,8 @@ func TestFindMissing(t *testing.T) {
}
func TestFindBadData(t *testing.T) {
m := New(time.Minute)
m.cache.Set("session_token", "not_a_byte_slice", 0)
m := NewWithCleanupInterval(0)
m.items["session_token"] = item{object: "not_a_byte_slice", expiration: time.Now().Add(time.Second).UnixNano()}
_, _, err := m.Find("session_token")
if err != errTypeAssertionFailed {
@ -45,19 +45,19 @@ func TestFindBadData(t *testing.T) {
}
}
func TestSaveNew(t *testing.T) {
m := New(time.Minute)
func TestCommitNew(t *testing.T) {
m := NewWithCleanupInterval(0)
err := m.Save("session_token", []byte("encoded_data"), time.Now().Add(time.Minute))
err := m.Commit("session_token", []byte("encoded_data"), time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("got %v: expected %v", err, nil)
}
v, found := m.cache.Get("session_token")
v, found := m.items["session_token"]
if found != true {
t.Fatalf("got %v: expected %v", found, true)
}
b, ok := v.([]byte)
b, ok := v.object.([]byte)
if ok == false {
t.Fatal("could not convert to []byte")
}
@ -66,21 +66,20 @@ func TestSaveNew(t *testing.T) {
}
}
func TestSaveUpdated(t *testing.T) {
m := New(time.Minute)
m.cache.Set("session_token", []byte("encoded_data"), 0)
func TestCommitUpdated(t *testing.T) {
m := NewWithCleanupInterval(0)
err := m.Save("session_token", []byte("encoded_data"), time.Now().Add(time.Minute))
err := m.Commit("session_token", []byte("encoded_data"), time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("got %v: expected %v", err, nil)
}
err = m.Save("session_token", []byte("new_encoded_data"), time.Now().Add(time.Minute))
err = m.Commit("session_token", []byte("new_encoded_data"), time.Now().Add(time.Minute))
if err != nil {
t.Fatalf("got %v: expected %v", err, nil)
}
v, _ := m.cache.Get("session_token")
v := m.items["session_token"].object
b, ok := v.([]byte)
if ok == false {
t.Fatal("could not convert to []byte")
@ -91,9 +90,9 @@ func TestSaveUpdated(t *testing.T) {
}
func TestExpiry(t *testing.T) {
m := New(time.Minute)
m := NewWithCleanupInterval(0)
err := m.Save("session_token", []byte("encoded_data"), time.Now().Add(100*time.Millisecond))
err := m.Commit("session_token", []byte("encoded_data"), time.Now().Add(100*time.Millisecond))
if err != nil {
t.Fatalf("got %v: expected %v", err, nil)
}
@ -103,7 +102,7 @@ func TestExpiry(t *testing.T) {
t.Fatalf("got %v: expected %v", found, true)
}
time.Sleep(100 * time.Millisecond)
time.Sleep(101 * time.Millisecond)
_, found, _ = m.Find("session_token")
if found != false {
t.Fatalf("got %v: expected %v", found, false)
@ -111,15 +110,15 @@ func TestExpiry(t *testing.T) {
}
func TestDelete(t *testing.T) {
m := New(time.Minute)
m.cache.Set("session_token", []byte("encoded_data"), 0)
m := NewWithCleanupInterval(0)
m.items["session_token"] = item{object: []byte("encoded_data"), expiration: time.Now().Add(time.Second).UnixNano()}
err := m.Delete("session_token")
if err != nil {
t.Fatalf("got %v: expected %v", err, nil)
}
_, found := m.cache.Get("session_token")
_, found := m.items["session_token"]
if found != false {
t.Fatalf("got %v: expected %v", found, false)
}

103
mysqlstore/README.md Normal file
View File

@ -0,0 +1,103 @@
# mysqlstore
A MySQL-based session store supporting the [go-sql-driver/mysql](https://github.com/go-sql-driver/mysql) driver.
## Setup
You should have a working MySQL database containing a `sessions` table with the definition:
```sql
CREATE TABLE sessions (
token CHAR(43) PRIMARY KEY,
data BLOB NOT NULL,
expiry TIMESTAMP(6) NOT NULL
);
CREATE INDEX sessions_expiry_idx ON sessions (expiry);
```
The database user for your application must have `SELECT`, `INSERT`, `UPDATE` and `DELETE` permissions on this table.
## Example
```go
package main
import (
"database/sql"
"io"
"log"
"net/http"
"github.com/alexedwards/scs/v2"
"github.com/alexedwards/scs/v2/mysqlstore"
_ "github.com/go-sql-driver/mysql"
)
var session *scs.Session
func main() {
db, err := sql.Open("mysql", "user:pass@/db?parseTime=true")
if err != nil {
log.Fatal(err)
}
defer db.Close()
// Initialize a new session manager and configure it to use MySQL as
// the session store.
session = scs.NewSession()
session.Store = mysqlstore.New(db)
mux := http.NewServeMux()
mux.HandleFunc("/put", putHandler)
mux.HandleFunc("/get", getHandler)
http.ListenAndServe(":4000", session.LoadAndSave(mux))
}
func putHandler(w http.ResponseWriter, r *http.Request) {
session.Put(r.Context(), "message", "Hello from a session!")
}
func getHandler(w http.ResponseWriter, r *http.Request) {
msg := session.GetString(r.Context(), "message")
io.WriteString(w, msg)
}
```
## Expired Session Cleanup
This package provides a background 'cleanup' goroutine to delete expired session data. This stops the database table from holding on to invalid sessions indefinitely and growing unnecessarily large. By default the cleanup runs every 5 minutes. You can change this by using the `NewWithCleanupInterval()` function to initialize your session store. For example:
```go
// Run a cleanup every 30 minutes.
mysqlstore.NewWithCleanupInterval(db, 30*time.Minute)
// Disable the cleanup goroutine by setting the cleanup interval to zero.
mysqlstore.NewWithCleanupInterval(db, 0)
```
### Terminating the Cleanup Goroutine
It's rare that the cleanup goroutine needs to be terminated --- it is generally intended to be long-lived and run for the lifetime of your application.
However, there may be occasions when your use of a session store instance is transient. A common example would be using it in a short-lived test function. In this scenario, the cleanup goroutine (which will run forever) will prevent the session store instance from being garbage collected even after the test function has finished. You can prevent this by either disabling the cleanup goroutine altogether (as described above) or by stopping it using the `StopCleanup()` method. For example:
```go
func TestExample(t *testing.T) {
db, err := sql.Open("mysql", "user:pass@/db?parseTime=true")
if err != nil {
t.Fatal(err)
}
defer db.Close()
store := mysqlstore.New(db)
defer store.StopCleanup()
session = scs.NewSession()
session.Store = store
// Run test...
}
```

8
mysqlstore/go.mod Normal file
View File

@ -0,0 +1,8 @@
module github.com/alexedwards/scs/v2/mysqlstore
go 1.12
require (
github.com/go-sql-driver/mysql v1.4.1
google.golang.org/appengine v1.5.0 // indirect
)

7
mysqlstore/go.sum Normal file
View File

@ -0,0 +1,7 @@
github.com/go-sql-driver/mysql v1.4.1 h1:g24URVg0OFbNUTx9qqY1IRZ9D9z3iPyi5zKhQZpNwpA=
github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
google.golang.org/appengine v1.5.0 h1:KxkO13IPW4Lslp2bz+KHP2E3gtFlrIGNThxkZQ3g+4c=
google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=

View File

@ -1,18 +1,3 @@
// Package mysqlstore is a MySQL-based session store for the SCS session package.
//
// A working MySQL database is required, containing a sessions table with
// the definition:
//
// CREATE TABLE sessions (
// token CHAR(43) PRIMARY KEY,
// data BLOB NOT NULL,
// expiry TIMESTAMP(6) NOT NULL
// );
// CREATE INDEX sessions_expiry_idx ON sessions (expiry);
//
// The mysqlstore package provides a background 'cleanup' goroutine to delete expired
// session data. This stops the database table from holding on to invalid sessions
// forever and growing unnecessarily large.
package mysqlstore
import (
@ -21,24 +6,26 @@ import (
"strconv"
"strings"
"time"
// Register go-sql-driver/mysql with database/sql
_ "github.com/go-sql-driver/mysql"
)
// MySQLStore represents the currently configured session session store.
// MySQLStore represents the session store.
type MySQLStore struct {
*sql.DB
version string
stopCleanup chan bool
}
// New returns a new MySQLStore instance.
//
// The cleanupInterval parameter controls how frequently expired session data
// is removed by the background cleanup goroutine. Setting it to 0 prevents
// the cleanup goroutine from running (i.e. expired sessions will not be removed).
func New(db *sql.DB, cleanupInterval time.Duration) *MySQLStore {
// New returns a new MySQLStore instance, with a background cleanup goroutine
// that runs every 5 minutes to remove expired session data.
func New(db *sql.DB) *MySQLStore {
return NewWithCleanupInterval(db, 5*time.Minute)
}
// NewWithCleanupInterval returns a new MySQLStore instance. The cleanupInterval
// parameter controls how frequently expired session data is removed by the
// background cleanup goroutine. Setting it to 0 prevents the cleanup goroutine
// from running (i.e. expired sessions will not be removed).
func NewWithCleanupInterval(db *sql.DB, cleanupInterval time.Duration) *MySQLStore {
m := &MySQLStore{
DB: db,
version: getVersion(db),
@ -51,9 +38,9 @@ func New(db *sql.DB, cleanupInterval time.Duration) *MySQLStore {
return m
}
// Find returns the data for a given session token from the MySQLStore instance. If
// the session token is not found or is expired, the returned exists flag will be
// set to false.
// Find returns the data for a given session token from the MySQLStore instance.
// If the session token is not found or is expired, the returned exists flag will
// be set to false.
func (m *MySQLStore) Find(token string) ([]byte, bool, error) {
var b []byte
var stmt string
@ -74,9 +61,10 @@ func (m *MySQLStore) Find(token string) ([]byte, bool, error) {
return b, true, nil
}
// Save adds a session token and data to the MySQLStore instance with the given expiry
// time. If the session token already exists then the data and expiry time are updated.
func (m *MySQLStore) Save(token string, b []byte, expiry time.Time) error {
// Commit adds a session token and data to the MySQLStore instance with the given
// expiry time. If the session token already exists, then the data and expiry
// time are updated.
func (m *MySQLStore) Commit(token string, b []byte, expiry time.Time) error {
_, err := m.DB.Exec("INSERT INTO sessions (token, data, expiry) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE data = VALUES(data), expiry = VALUES(expiry)", token, b, expiry.UTC())
if err != nil {
return err
@ -84,7 +72,8 @@ func (m *MySQLStore) Save(token string, b []byte, expiry time.Time) error {
return nil
}
// Delete removes a session token and corresponding data from the MySQLStore instance.
// Delete removes a session token and corresponding data from the MySQLStore
// instance.
func (m *MySQLStore) Delete(token string) error {
_, err := m.DB.Exec("DELETE FROM sessions WHERE token = ?", token)
return err
@ -107,16 +96,16 @@ func (m *MySQLStore) startCleanup(interval time.Duration) {
}
}
// StopCleanup terminates the background cleanup goroutine for the MySQLStore instance.
// It's rare to terminate this; generally MySQLStore instances and their cleanup
// goroutines are intended to be long-lived and run for the lifetime of your
// application.
// StopCleanup terminates the background cleanup goroutine for the MySQLStore
// instance. It's rare to terminate this; generally MySQLStore instances and
// their cleanup goroutines are intended to be long-lived and run for the lifetime
// of your application.
//
// There may be occasions though when your use of the MySQLStore is transient. An
// example is creating a new MySQLStore instance in a test function. In this scenario,
// the cleanup goroutine (which will run forever) will prevent the MySQLStore object
// from being garbage collected even after the test function has finished. You
// can prevent this by manually calling StopCleanup.
// There may be occasions though when your use of the MySQLStore is transient.
// An example is creating a new MySQLStore instance in a test function. In this
// scenario, the cleanup goroutine (which will run forever) will prevent the
// MySQLStore object from being garbage collected even after the test function
// has finished. You can prevent this by manually calling StopCleanup.
func (m *MySQLStore) StopCleanup() {
if m.stopCleanup != nil {
m.stopCleanup <- true

View File

@ -7,10 +7,12 @@ import (
"reflect"
"testing"
"time"
_ "github.com/go-sql-driver/mysql"
)
func TestFind(t *testing.T) {
dsn := os.Getenv("SESSION_MYSQL_TEST_DSN")
dsn := os.Getenv("SCS_MYSQL_TEST_DSN")
db, err := sql.Open("mysql", dsn)
if err != nil {
t.Fatal(err)
@ -28,7 +30,7 @@ func TestFind(t *testing.T) {
t.Fatal(err)
}
m := New(db, 0)
m := NewWithCleanupInterval(db, 0)
b, found, err := m.Find("session_token")
if err != nil {
@ -43,7 +45,7 @@ func TestFind(t *testing.T) {
}
func TestFindMissing(t *testing.T) {
dsn := os.Getenv("SESSION_MYSQL_TEST_DSN")
dsn := os.Getenv("SCS_MYSQL_TEST_DSN")
db, err := sql.Open("mysql", dsn)
if err != nil {
t.Fatal(err)
@ -57,7 +59,7 @@ func TestFindMissing(t *testing.T) {
t.Fatal(err)
}
m := New(db, 0)
m := NewWithCleanupInterval(db, 0)
_, found, err := m.Find("missing_session_token")
if err != nil {
@ -69,7 +71,7 @@ func TestFindMissing(t *testing.T) {
}
func TestSaveNew(t *testing.T) {
dsn := os.Getenv("SESSION_MYSQL_TEST_DSN")
dsn := os.Getenv("SCS_MYSQL_TEST_DSN")
db, err := sql.Open("mysql", dsn)
if err != nil {
t.Fatal(err)
@ -83,9 +85,9 @@ func TestSaveNew(t *testing.T) {
t.Fatal(err)
}
m := New(db, 0)
m := NewWithCleanupInterval(db, 0)
err = m.Save("session_token", []byte("encoded_data"), time.Now().Add(time.Minute))
err = m.Commit("session_token", []byte("encoded_data"), time.Now().Add(time.Minute))
if err != nil {
t.Fatal(err)
}
@ -102,7 +104,7 @@ func TestSaveNew(t *testing.T) {
}
func TestSaveUpdated(t *testing.T) {
dsn := os.Getenv("SESSION_MYSQL_TEST_DSN")
dsn := os.Getenv("SCS_MYSQL_TEST_DSN")
db, err := sql.Open("mysql", dsn)
if err != nil {
t.Fatal(err)
@ -120,9 +122,9 @@ func TestSaveUpdated(t *testing.T) {
t.Fatal(err)
}
m := New(db, 0)
m := NewWithCleanupInterval(db, 0)
err = m.Save("session_token", []byte("new_encoded_data"), time.Now().Add(time.Minute))
err = m.Commit("session_token", []byte("new_encoded_data"), time.Now().Add(time.Minute))
if err != nil {
t.Fatal(err)
}
@ -139,7 +141,7 @@ func TestSaveUpdated(t *testing.T) {
}
func TestExpiry(t *testing.T) {
dsn := os.Getenv("SESSION_MYSQL_TEST_DSN")
dsn := os.Getenv("SCS_MYSQL_TEST_DSN")
db, err := sql.Open("mysql", dsn)
if err != nil {
t.Fatal(err)
@ -153,9 +155,9 @@ func TestExpiry(t *testing.T) {
t.Fatal(err)
}
m := New(db, 0)
m := NewWithCleanupInterval(db, 0)
err = m.Save("session_token", []byte("encoded_data"), time.Now().Add(100*time.Millisecond))
err = m.Commit("session_token", []byte("encoded_data"), time.Now().Add(100*time.Millisecond))
if err != nil {
t.Fatal(err)
}
@ -173,7 +175,7 @@ func TestExpiry(t *testing.T) {
}
func TestDelete(t *testing.T) {
dsn := os.Getenv("SESSION_MYSQL_TEST_DSN")
dsn := os.Getenv("SCS_MYSQL_TEST_DSN")
db, err := sql.Open("mysql", dsn)
if err != nil {
t.Fatal(err)
@ -191,7 +193,7 @@ func TestDelete(t *testing.T) {
t.Fatal(err)
}
m := New(db, 0)
m := NewWithCleanupInterval(db, 0)
err = m.Delete("session_token")
if err != nil {
@ -210,7 +212,7 @@ func TestDelete(t *testing.T) {
}
func TestCleanup(t *testing.T) {
dsn := os.Getenv("SESSION_MYSQL_TEST_DSN")
dsn := os.Getenv("SCS_MYSQL_TEST_DSN")
db, err := sql.Open("mysql", dsn)
if err != nil {
t.Fatal(err)
@ -224,10 +226,10 @@ func TestCleanup(t *testing.T) {
t.Fatal(err)
}
m := New(db, 200*time.Millisecond)
m := NewWithCleanupInterval(db, 200*time.Millisecond)
defer m.StopCleanup()
err = m.Save("session_token", []byte("encoded_data"), time.Now().Add(100*time.Millisecond))
err = m.Commit("session_token", []byte("encoded_data"), time.Now().Add(100*time.Millisecond))
if err != nil {
t.Fatal(err)
}
@ -254,7 +256,7 @@ func TestCleanup(t *testing.T) {
}
func TestStopNilCleanup(t *testing.T) {
dsn := os.Getenv("SESSION_MYSQL_TEST_DSN")
dsn := os.Getenv("SCS_MYSQL_TEST_DSN")
db, err := sql.Open("mysql", dsn)
if err != nil {
t.Fatal(err)
@ -264,7 +266,7 @@ func TestStopNilCleanup(t *testing.T) {
t.Fatal(err)
}
m := New(db, 0)
m := NewWithCleanupInterval(db, 0)
time.Sleep(100 * time.Millisecond)
// A send to a nil channel will block forever
m.StopCleanup()

View File

@ -1,21 +0,0 @@
package scs
import (
"time"
)
// Deprecated: Please use the Manager.Name() method to change the name of the
// session cookie.
var CookieName = "session"
type options struct {
domain string
httpOnly bool
idleTimeout time.Duration
lifetime time.Duration
name string
path string
persist bool
secure bool
sameSite string
}

View File

@ -1,171 +0,0 @@
package scs
import (
"strings"
"testing"
"time"
)
func TestCookieOptions(t *testing.T) {
manager := NewManager(newMockStore())
_, _, cookie := testRequest(t, testPutString(manager), "")
if strings.Contains(cookie, "Path=/") == false {
t.Errorf("got %q: expected to contain %q", cookie, "Path=/")
}
if strings.Contains(cookie, "Domain=") == true {
t.Errorf("got %q: expected to not contain %q", cookie, "Domain=")
}
if strings.Contains(cookie, "Secure") == true {
t.Errorf("got %q: expected to not contain %q", cookie, "Secure")
}
if strings.Contains(cookie, "HttpOnly") == false {
t.Errorf("got %q: expected to contain %q", cookie, "HttpOnly")
}
if strings.Contains(cookie, "SameSite") == true {
t.Errorf("got %q: expected to not contain %q", cookie, "SameSite")
}
manager = NewManager(newMockStore())
manager.Path("/foo")
manager.Domain("example.org")
manager.Secure(true)
manager.HttpOnly(false)
manager.Lifetime(time.Hour)
manager.Persist(true)
manager.SameSite("Lax")
_, _, cookie = testRequest(t, testPutString(manager), "")
if strings.Contains(cookie, "Path=/foo") == false {
t.Errorf("got %q: expected to contain %q", cookie, "Path=/foo")
}
if strings.Contains(cookie, "Domain=example.org") == false {
t.Errorf("got %q: expected to contain %q", cookie, "Domain=example.org")
}
if strings.Contains(cookie, "Secure") == false {
t.Errorf("got %q: expected to contain %q", cookie, "Secure")
}
if strings.Contains(cookie, "HttpOnly") == true {
t.Errorf("got %q: expected to not contain %q", cookie, "HttpOnly")
}
if strings.Contains(cookie, "Max-Age=3600") == false {
t.Errorf("got %q: expected to contain %q:", cookie, "Max-Age=86400")
}
if strings.Contains(cookie, "Expires=") == false {
t.Errorf("got %q: expected to contain %q:", cookie, "Expires")
}
if strings.Contains(cookie, "SameSite=Lax") == false {
t.Errorf("got %q: expected to contain %q", cookie, "SameSite=Lax")
}
manager = NewManager(newMockStore())
manager.Lifetime(time.Hour)
_, _, cookie = testRequest(t, testPutString(manager), "")
if strings.Contains(cookie, "Max-Age=") == true {
t.Errorf("got %q: expected not to contain %q:", cookie, "Max-Age=")
}
if strings.Contains(cookie, "Expires=") == true {
t.Errorf("got %q: expected not to contain %q:", cookie, "Expires")
}
manager = NewManager(newMockStore())
manager.SameSite("Strict")
_, _, cookie = testRequest(t, testPutString(manager), "")
if strings.Contains(cookie, "SameSite=Strict") == false {
t.Errorf("got %q: expected to contain %q", cookie, "SameSite=Strict")
}
manager = NewManager(newMockStore())
// empty string disables
manager.SameSite("")
_, _, cookie = testRequest(t, testPutString(manager), "")
if strings.Contains(cookie, "SameSite") == true {
t.Errorf("got %q: expected to not contain %q", cookie, "SameSite")
}
}
func TestLifetime(t *testing.T) {
manager := NewManager(newMockStore())
manager.Lifetime(200 * time.Millisecond)
_, _, cookie := testRequest(t, testPutString(manager), "")
oldToken := extractTokenFromCookie(cookie)
time.Sleep(100 * time.Millisecond)
_, _, cookie = testRequest(t, testPutString(manager), cookie)
time.Sleep(100 * time.Millisecond)
_, body, _ := testRequest(t, testGetString(manager), cookie)
if body != "" {
t.Fatalf("got %q: expected %q", body, "")
}
_, _, cookie = testRequest(t, testPutString(manager), cookie)
newToken := extractTokenFromCookie(cookie)
if newToken == oldToken {
t.Fatalf("expected a difference")
}
}
func TestIdleTimeout(t *testing.T) {
manager := NewManager(newMockStore())
manager.IdleTimeout(100 * time.Millisecond)
manager.Lifetime(500 * time.Millisecond)
_, _, cookie := testRequest(t, testPutString(manager), "")
oldToken := extractTokenFromCookie(cookie)
time.Sleep(150 * time.Millisecond)
_, body, _ := testRequest(t, testGetString(manager), cookie)
if body != "" {
t.Fatalf("got %q: expected %q", body, "")
}
_, _, cookie = testRequest(t, testPutString(manager), cookie)
newToken := extractTokenFromCookie(cookie)
if newToken == oldToken {
t.Fatalf("expected a difference")
}
_, _, cookie = testRequest(t, testPutString(manager), "")
oldToken = extractTokenFromCookie(cookie)
time.Sleep(75 * time.Millisecond)
_, _, cookie = testRequest(t, testPutString(manager), cookie)
time.Sleep(75 * time.Millisecond)
_, body, _ = testRequest(t, testGetString(manager), cookie)
if body != "lorem ipsum" {
t.Fatalf("got %q: expected %q", body, "lorem ipsum")
}
_, _, cookie = testRequest(t, testPutString(manager), cookie)
newToken = extractTokenFromCookie(cookie)
if newToken != oldToken {
t.Fatalf("expected the same")
}
}
func TestPersist(t *testing.T) {
manager := NewManager(newMockStore())
manager.IdleTimeout(5 * time.Minute)
manager.Persist(true)
_, _, cookie := testRequest(t, testPutString(manager), "")
if strings.Contains(cookie, "Max-Age=300") == false {
t.Fatalf("got %q: expected to contain %q:", cookie, "Max-Age=300")
}
}
func TestName(t *testing.T) {
manager := NewManager(newMockStore())
manager.Name("foo")
_, _, cookie := testRequest(t, testPutString(manager), "")
if strings.HasPrefix(cookie, "foo=") == false {
t.Fatalf("got %q: expected prefix %q", cookie, "foo=")
}
_, body, _ := testRequest(t, testGetString(manager), cookie)
if body != "lorem ipsum" {
t.Fatalf("got %q: expected %q", body, "lorem ipsum")
}
}

103
postgresstore/README.md Normal file
View File

@ -0,0 +1,103 @@
# postgresstore
A PostgreSQL-based session store supporting the [pq](https://github.com/lib/pq) driver.
## Setup
You should have a working PostgreSQL database containing a `sessions` table with the definition:
```sql
CREATE TABLE sessions (
token TEXT PRIMARY KEY,
data BYTEA NOT NULL,
expiry TIMESTAMPTZ NOT NULL
);
CREATE INDEX sessions_expiry_idx ON sessions (expiry);
```
The database user for your application must have `SELECT`, `INSERT`, `UPDATE` and `DELETE` permissions on this table.
## Example
```go
package main
import (
"database/sql"
"io"
"log"
"net/http"
"github.com/alexedwards/scs/v2"
"github.com/alexedwards/scs/v2/postgresstore"
_ "github.com/lib/pq"
)
var session *scs.Session
func main() {
db, err := sql.Open("postgres", "postgres://user:pass@localhost/db")
if err != nil {
log.Fatal(err)
}
defer db.Close()
// Initialize a new session manager and configure it to use PostgreSQL as
// the session store.
session = scs.NewSession()
session.Store = postgresstore.New(db)
mux := http.NewServeMux()
mux.HandleFunc("/put", putHandler)
mux.HandleFunc("/get", getHandler)
http.ListenAndServe(":4000", session.LoadAndSave(mux))
}
func putHandler(w http.ResponseWriter, r *http.Request) {
session.Put(r.Context(), "message", "Hello from a session!")
}
func getHandler(w http.ResponseWriter, r *http.Request) {
msg := session.GetString(r.Context(), "message")
io.WriteString(w, msg)
}
```
## Expired Session Cleanup
This package provides a background 'cleanup' goroutine to delete expired session data. This stops the database table from holding on to invalid sessions indefinitely and growing unnecessarily large. By default the cleanup runs every 5 minutes. You can change this by using the `NewWithCleanupInterval()` function to initialize your session store. For example:
```go
// Run a cleanup every 30 minutes.
postgresstore.NewWithCleanupInterval(db, 30*time.Minute)
// Disable the cleanup goroutine by setting the cleanup interval to zero.
postgresstore.NewWithCleanupInterval(db, 0)
```
### Terminating the Cleanup Goroutine
It's rare that the cleanup goroutine needs to be terminated --- it is generally intended to be long-lived and run for the lifetime of your application.
However, there may be occasions when your use of a session store instance is transient. A common example would be using it in a short-lived test function. In this scenario, the cleanup goroutine (which will run forever) will prevent the session store instance from being garbage collected even after the test function has finished. You can prevent this by either disabling the cleanup goroutine altogether (as described above) or by stopping it using the `StopCleanup()` method. For example:
```go
func TestExample(t *testing.T) {
db, err := sql.Open("postgres", "postgres://user:pass@localhost/db")
if err != nil {
t.Fatal(err)
}
defer db.Close()
store := postgresstore.New(db)
defer store.StopCleanup()
session = scs.NewSession()
session.Store = store
// Run test...
}
```

5
postgresstore/go.mod Normal file
View File

@ -0,0 +1,5 @@
module github.com/alexedwards/scs/v2/postgresstore
go 1.12
require github.com/lib/pq v1.1.0

2
postgresstore/go.sum Normal file
View File

@ -0,0 +1,2 @@
github.com/lib/pq v1.1.0 h1:/5u4a+KGJptBRqGzPvYQL9p0d/tPR4S31+Tnzj9lEO4=
github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=

View File

@ -0,0 +1,101 @@
package postgresstore
import (
"database/sql"
"log"
"time"
)
// PostgresStore represents the session store.
type PostgresStore struct {
db *sql.DB
stopCleanup chan bool
}
// New returns a new PostgresStore instance, with a background cleanup goroutine
// that runs every 5 minutes to remove expired session data.
func New(db *sql.DB) *PostgresStore {
return NewWithCleanupInterval(db, 5*time.Minute)
}
// NewWithCleanupInterval returns a new PostgresStore instance. The cleanupInterval
// parameter controls how frequently expired session data is removed by the
// background cleanup goroutine. Setting it to 0 prevents the cleanup goroutine
// from running (i.e. expired sessions will not be removed).
func NewWithCleanupInterval(db *sql.DB, cleanupInterval time.Duration) *PostgresStore {
p := &PostgresStore{db: db}
if cleanupInterval > 0 {
go p.startCleanup(cleanupInterval)
}
return p
}
// Find returns the data for a given session token from the PostgresStore instance.
// If the session token is not found or is expired, the returned exists flag will
// be set to false.
func (p *PostgresStore) Find(token string) (b []byte, exists bool, err error) {
row := p.db.QueryRow("SELECT data FROM sessions WHERE token = $1 AND current_timestamp < expiry", token)
err = row.Scan(&b)
if err == sql.ErrNoRows {
return nil, false, nil
} else if err != nil {
return nil, false, err
}
return b, true, nil
}
// Commit adds a session token and data to the PostgresStore instance with the
// given expiry time. If the session token already exists, then the data and expiry
// time are updated.
func (p *PostgresStore) Commit(token string, b []byte, expiry time.Time) error {
_, err := p.db.Exec("INSERT INTO sessions (token, data, expiry) VALUES ($1, $2, $3) ON CONFLICT (token) DO UPDATE SET data = EXCLUDED.data, expiry = EXCLUDED.expiry", token, b, expiry)
if err != nil {
return err
}
return nil
}
// Delete removes a session token and corresponding data from the PostgresStore
// instance.
func (p *PostgresStore) Delete(token string) error {
_, err := p.db.Exec("DELETE FROM sessions WHERE token = $1", token)
return err
}
func (p *PostgresStore) startCleanup(interval time.Duration) {
p.stopCleanup = make(chan bool)
ticker := time.NewTicker(interval)
for {
select {
case <-ticker.C:
err := p.deleteExpired()
if err != nil {
log.Println(err)
}
case <-p.stopCleanup:
ticker.Stop()
return
}
}
}
// StopCleanup terminates the background cleanup goroutine for the PostgresStore
// instance. It's rare to terminate this; generally PostgresStore instances and
// their cleanup goroutines are intended to be long-lived and run for the lifetime
// of your application.
//
// There may be occasions though when your use of the PostgresStore is transient.
// An example is creating a new PostgresStore instance in a test function. In this
// scenario, the cleanup goroutine (which will run forever) will prevent the
// PostgresStore object from being garbage collected even after the test function
// has finished. You can prevent this by manually calling StopCleanup.
func (p *PostgresStore) StopCleanup() {
if p.stopCleanup != nil {
p.stopCleanup <- true
}
}
func (p *PostgresStore) deleteExpired() error {
_, err := p.db.Exec("DELETE FROM sessions WHERE expiry < current_timestamp")
return err
}

View File

@ -1,4 +1,4 @@
package pgstore
package postgresstore
import (
"bytes"
@ -7,10 +7,12 @@ import (
"reflect"
"testing"
"time"
_ "github.com/lib/pq"
)
func TestFind(t *testing.T) {
dsn := os.Getenv("SESSION_PG_TEST_DSN")
dsn := os.Getenv("SCS_POSTGRES_TEST_DSN")
db, err := sql.Open("postgres", dsn)
if err != nil {
t.Fatal(err)
@ -28,7 +30,7 @@ func TestFind(t *testing.T) {
t.Fatal(err)
}
p := New(db, 0)
p := NewWithCleanupInterval(db, 0)
b, found, err := p.Find("session_token")
if err != nil {
@ -43,7 +45,7 @@ func TestFind(t *testing.T) {
}
func TestFindMissing(t *testing.T) {
dsn := os.Getenv("SESSION_PG_TEST_DSN")
dsn := os.Getenv("SCS_POSTGRES_TEST_DSN")
db, err := sql.Open("postgres", dsn)
if err != nil {
t.Fatal(err)
@ -57,7 +59,7 @@ func TestFindMissing(t *testing.T) {
t.Fatal(err)
}
p := New(db, 0)
p := NewWithCleanupInterval(db, 0)
_, found, err := p.Find("missing_session_token")
if err != nil {
@ -69,7 +71,7 @@ func TestFindMissing(t *testing.T) {
}
func TestSaveNew(t *testing.T) {
dsn := os.Getenv("SESSION_PG_TEST_DSN")
dsn := os.Getenv("SCS_POSTGRES_TEST_DSN")
db, err := sql.Open("postgres", dsn)
if err != nil {
t.Fatal(err)
@ -83,9 +85,9 @@ func TestSaveNew(t *testing.T) {
t.Fatal(err)
}
p := New(db, 0)
p := NewWithCleanupInterval(db, 0)
err = p.Save("session_token", []byte("encoded_data"), time.Now().Add(time.Minute))
err = p.Commit("session_token", []byte("encoded_data"), time.Now().Add(time.Minute))
if err != nil {
t.Fatal(err)
}
@ -102,7 +104,7 @@ func TestSaveNew(t *testing.T) {
}
func TestSaveUpdated(t *testing.T) {
dsn := os.Getenv("SESSION_PG_TEST_DSN")
dsn := os.Getenv("SCS_POSTGRES_TEST_DSN")
db, err := sql.Open("postgres", dsn)
if err != nil {
t.Fatal(err)
@ -120,9 +122,9 @@ func TestSaveUpdated(t *testing.T) {
t.Fatal(err)
}
p := New(db, 0)
p := NewWithCleanupInterval(db, 0)
err = p.Save("session_token", []byte("new_encoded_data"), time.Now().Add(time.Minute))
err = p.Commit("session_token", []byte("new_encoded_data"), time.Now().Add(time.Minute))
if err != nil {
t.Fatal(err)
}
@ -139,7 +141,7 @@ func TestSaveUpdated(t *testing.T) {
}
func TestExpiry(t *testing.T) {
dsn := os.Getenv("SESSION_PG_TEST_DSN")
dsn := os.Getenv("SCS_POSTGRES_TEST_DSN")
db, err := sql.Open("postgres", dsn)
if err != nil {
t.Fatal(err)
@ -153,9 +155,9 @@ func TestExpiry(t *testing.T) {
t.Fatal(err)
}
p := New(db, 0)
p := NewWithCleanupInterval(db, 0)
err = p.Save("session_token", []byte("encoded_data"), time.Now().Add(100*time.Millisecond))
err = p.Commit("session_token", []byte("encoded_data"), time.Now().Add(100*time.Millisecond))
if err != nil {
t.Fatal(err)
}
@ -173,7 +175,7 @@ func TestExpiry(t *testing.T) {
}
func TestDelete(t *testing.T) {
dsn := os.Getenv("SESSION_PG_TEST_DSN")
dsn := os.Getenv("SCS_POSTGRES_TEST_DSN")
db, err := sql.Open("postgres", dsn)
if err != nil {
t.Fatal(err)
@ -191,7 +193,7 @@ func TestDelete(t *testing.T) {
t.Fatal(err)
}
p := New(db, 0)
p := NewWithCleanupInterval(db, 0)
err = p.Delete("session_token")
if err != nil {
@ -210,7 +212,7 @@ func TestDelete(t *testing.T) {
}
func TestCleanup(t *testing.T) {
dsn := os.Getenv("SESSION_PG_TEST_DSN")
dsn := os.Getenv("SCS_POSTGRES_TEST_DSN")
db, err := sql.Open("postgres", dsn)
if err != nil {
t.Fatal(err)
@ -224,10 +226,10 @@ func TestCleanup(t *testing.T) {
t.Fatal(err)
}
p := New(db, 200*time.Millisecond)
p := NewWithCleanupInterval(db, 200*time.Millisecond)
defer p.StopCleanup()
err = p.Save("session_token", []byte("encoded_data"), time.Now().Add(100*time.Millisecond))
err = p.Commit("session_token", []byte("encoded_data"), time.Now().Add(100*time.Millisecond))
if err != nil {
t.Fatal(err)
}
@ -254,7 +256,7 @@ func TestCleanup(t *testing.T) {
}
func TestStopNilCleanup(t *testing.T) {
dsn := os.Getenv("SESSION_PG_TEST_DSN")
dsn := os.Getenv("SCS_POSTGRES_TEST_DSN")
db, err := sql.Open("postgres", dsn)
if err != nil {
t.Fatal(err)
@ -264,7 +266,7 @@ func TestStopNilCleanup(t *testing.T) {
t.Fatal(err)
}
p := New(db, 0)
p := NewWithCleanupInterval(db, 0)
time.Sleep(100 * time.Millisecond)
// A send to a nil channel will block forever
p.StopCleanup()

View File

@ -1,181 +0,0 @@
package scs
import (
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func testRequest(t *testing.T, h http.Handler, cookie string) (int, string, string) {
rr := httptest.NewRecorder()
r, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatal(err)
}
if cookie != "" {
r.Header.Add("Cookie", cookie)
}
h.ServeHTTP(rr, r)
code := rr.Code
body := string(rr.Body.Bytes())
cookie = rr.Header().Get("Set-Cookie")
return code, body, cookie
}
func extractTokenFromCookie(c string) string {
parts := strings.Split(c, ";")
return strings.SplitN(parts[0], "=", 2)[1]
}
// Test Handlers
func testPutString(manager *Manager) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
session := manager.Load(r)
err := session.PutString(w, "test_string", "lorem ipsum")
if err != nil {
http.Error(w, err.Error(), 500)
return
}
io.WriteString(w, "OK")
}
}
func testGetString(manager *Manager) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
session := manager.Load(r)
s, err := session.GetString("test_string")
if err != nil {
http.Error(w, err.Error(), 500)
return
}
io.WriteString(w, s)
}
}
func testPopString(manager *Manager) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
session := manager.Load(r)
s, err := session.PopString(w, "test_string")
if err != nil {
http.Error(w, err.Error(), 500)
return
}
io.WriteString(w, s)
}
}
func testPutBool(manager *Manager) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
session := manager.Load(r)
err := session.PutBool(w, "test_bool", true)
if err != nil {
http.Error(w, err.Error(), 500)
return
}
io.WriteString(w, "OK")
}
}
func testGetBool(manager *Manager) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
session := manager.Load(r)
b, err := session.GetBool("test_bool")
if err != nil {
http.Error(w, err.Error(), 500)
return
}
fmt.Fprintf(w, "%v", b)
}
}
func testPutObject(manager *Manager) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
session := manager.Load(r)
u := &testUser{"alice", 21}
err := session.PutObject(w, "test_object", u)
if err != nil {
http.Error(w, err.Error(), 500)
return
}
io.WriteString(w, "OK")
}
}
func testGetObject(manager *Manager) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
session := manager.Load(r)
u := new(testUser)
err := session.GetObject("test_object", u)
if err != nil {
http.Error(w, err.Error(), 500)
return
}
fmt.Fprintf(w, "%s: %d", u.Name, u.Age)
}
}
func testPopObject(manager *Manager) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
session := manager.Load(r)
u := new(testUser)
err := session.PopObject(w, "test_object", u)
if err != nil {
http.Error(w, err.Error(), 500)
return
}
fmt.Fprintf(w, "%s: %d", u.Name, u.Age)
}
}
func testDestroy(manager *Manager) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
session := manager.Load(r)
err := session.Destroy(w)
if err != nil {
http.Error(w, err.Error(), 500)
return
}
io.WriteString(w, "OK")
}
}
func testRenewToken(manager *Manager) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
session := manager.Load(r)
err := session.RenewToken(w)
if err != nil {
http.Error(w, err.Error(), 500)
return
}
io.WriteString(w, "OK")
}
}
func testClear(manager *Manager) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
session := manager.Load(r)
err := session.Clear(w)
if err != nil {
http.Error(w, err.Error(), 500)
return
}
io.WriteString(w, "OK")
}
}
func testKeys(manager *Manager) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
session := manager.Load(r)
keys, err := session.Keys()
if err != nil {
http.Error(w, err.Error(), 500)
return
}
fmt.Fprintf(w, "%v", keys)
}
}

File diff suppressed because it is too large Load Diff

View File

@ -1,100 +1,194 @@
package scs
import (
"encoding/gob"
"fmt"
"io/ioutil"
"net/http"
"regexp"
"net/http/cookiejar"
"net/http/httptest"
"strings"
"testing"
"time"
)
type testUser struct {
Name string
Age int
type testServer struct {
*httptest.Server
}
func init() {
gob.Register(new(testUser))
func newTestServer(t *testing.T, h http.Handler) *testServer {
ts := httptest.NewTLSServer(h)
jar, err := cookiejar.New(nil)
if err != nil {
t.Fatal(err)
}
ts.Client().Jar = jar
ts.Client().CheckRedirect = func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
}
func TestGenerateToken(t *testing.T) {
id, err := generateToken()
return &testServer{ts}
}
func (ts *testServer) execute(t *testing.T, urlPath string) (http.Header, string) {
rs, err := ts.Client().Get(ts.URL + urlPath)
if err != nil {
t.Fatal(err)
}
match, err := regexp.MatchString("^[0-9a-zA-Z_\\-]{43}$", id)
defer rs.Body.Close()
body, err := ioutil.ReadAll(rs.Body)
if err != nil {
t.Fatal(err)
}
if match == false {
t.Errorf("got %q: should match %q", id, "^[0-9a-zA-Z_\\-]{43}$")
return rs.Header, string(body)
}
func extractTokenFromCookie(c string) string {
parts := strings.Split(c, ";")
return strings.SplitN(parts[0], "=", 2)[1]
}
func TestEnable(t *testing.T) {
session := NewSession()
mux := http.NewServeMux()
mux.HandleFunc("/put", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
session.Put(r.Context(), "foo", "bar")
}))
mux.HandleFunc("/get", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s := session.Get(r.Context(), "foo").(string)
w.Write([]byte(s))
}))
ts := newTestServer(t, session.LoadAndSave(mux))
defer ts.Close()
header, _ := ts.execute(t, "/put")
token1 := extractTokenFromCookie(header.Get("Set-Cookie"))
header, body := ts.execute(t, "/get")
if body != "bar" {
t.Errorf("want %q; got %q", "bar", body)
}
if header.Get("Set-Cookie") != "" {
t.Errorf("want %q; got %q", "", header.Get("Set-Cookie"))
}
header, _ = ts.execute(t, "/put")
token2 := extractTokenFromCookie(header.Get("Set-Cookie"))
if token1 != token2 {
t.Error("want tokens to be the same")
}
}
func TestString(t *testing.T) {
manager := NewManager(newMockStore())
func TestLifetime(t *testing.T) {
session := NewSession()
session.Lifetime = 500 * time.Millisecond
_, body, cookie := testRequest(t, testPutString(manager), "")
if body != "OK" {
t.Fatalf("got %q: expected %q", body, "OK")
mux := http.NewServeMux()
mux.HandleFunc("/put", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
session.Put(r.Context(), "foo", "bar")
}))
mux.HandleFunc("/get", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
v := session.Get(r.Context(), "foo")
if v == nil {
http.Error(w, "foo does not exist in session", 500)
return
}
w.Write([]byte(v.(string)))
}))
_, body, _ = testRequest(t, testGetString(manager), cookie)
if body != "lorem ipsum" {
t.Fatalf("got %q: expected %q", body, "lorem ipsum")
ts := newTestServer(t, session.LoadAndSave(mux))
defer ts.Close()
ts.execute(t, "/put")
_, body := ts.execute(t, "/get")
if body != "bar" {
t.Errorf("want %q; got %q", "bar", body)
}
time.Sleep(time.Second)
_, body, cookie = testRequest(t, testPopString(manager), cookie)
if body != "lorem ipsum" {
t.Fatalf("got %q: expected %q", body, "lorem ipsum")
}
_, body, _ = testRequest(t, testGetString(manager), cookie)
if body != "" {
t.Fatalf("got %q: expected %q", body, "")
_, body = ts.execute(t, "/get")
if body != "foo does not exist in session\n" {
t.Errorf("want %q; got %q", "foo does not exist in session\n", body)
}
}
func TestObject(t *testing.T) {
manager := NewManager(newMockStore())
func TestIdleTimeout(t *testing.T) {
session := NewSession()
session.IdleTimeout = 200 * time.Millisecond
session.Lifetime = time.Second
_, body, cookie := testRequest(t, testPutObject(manager), "")
if body != "OK" {
t.Fatalf("got %q: expected %q", body, "OK")
mux := http.NewServeMux()
mux.HandleFunc("/put", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
session.Put(r.Context(), "foo", "bar")
}))
mux.HandleFunc("/get", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
v := session.Get(r.Context(), "foo")
if v == nil {
http.Error(w, "foo does not exist in session", 500)
return
}
w.Write([]byte(v.(string)))
}))
ts := newTestServer(t, session.LoadAndSave(mux))
defer ts.Close()
ts.execute(t, "/put")
time.Sleep(100 * time.Millisecond)
ts.execute(t, "/get")
time.Sleep(150 * time.Millisecond)
_, body := ts.execute(t, "/get")
if body != "bar" {
t.Errorf("want %q; got %q", "bar", body)
}
_, body, _ = testRequest(t, testGetObject(manager), cookie)
if body != "alice: 21" {
t.Fatalf("got %q: expected %q", body, "alice: 21")
}
_, body, cookie = testRequest(t, testPopObject(manager), cookie)
if body != "alice: 21" {
t.Fatalf("got %q: expected %q", body, "alice: 21")
}
_, body, _ = testRequest(t, testGetObject(manager), cookie)
if body != ": 0" {
t.Fatalf("got %q: expected %q", body, ": 0")
time.Sleep(200 * time.Millisecond)
_, body = ts.execute(t, "/get")
if body != "foo does not exist in session\n" {
t.Errorf("want %q; got %q", "foo does not exist in session\n", body)
}
}
func TestDestroy(t *testing.T) {
store := newMockStore()
manager := NewManager(store)
session := NewSession()
_, _, cookie := testRequest(t, testPutString(manager), "")
oldToken := extractTokenFromCookie(cookie)
_, body, cookie := testRequest(t, testDestroy(manager), cookie)
if body != "OK" {
t.Fatalf("got %q: expected %q", body, "OK")
mux := http.NewServeMux()
mux.HandleFunc("/put", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
session.Put(r.Context(), "foo", "bar")
}))
mux.HandleFunc("/destroy", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
err := session.Destroy(r.Context())
if err != nil {
http.Error(w, err.Error(), 500)
return
}
if strings.HasPrefix(cookie, fmt.Sprintf("%s=;", manager.opts.name)) == false {
t.Fatalf("got %q: expected prefix %q", cookie, fmt.Sprintf("%s=;", manager.opts.name))
}))
mux.HandleFunc("/get", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
v := session.Get(r.Context(), "foo")
if v == nil {
http.Error(w, "foo does not exist in session", 500)
return
}
w.Write([]byte(v.(string)))
}))
ts := newTestServer(t, session.LoadAndSave(mux))
defer ts.Close()
ts.execute(t, "/put")
header, _ := ts.execute(t, "/destroy")
cookie := header.Get("Set-Cookie")
if strings.HasPrefix(cookie, fmt.Sprintf("%s=;", session.Cookie.Name)) == false {
t.Fatalf("got %q: expected prefix %q", cookie, fmt.Sprintf("%s=;", session.Cookie.Name))
}
if strings.Contains(cookie, "Expires=Thu, 01 Jan 1970 00:00:01 GMT") == false {
t.Fatalf("got %q: expected to contain %q", cookie, "Expires=Thu, 01 Jan 1970 00:00:01 GMT")
@ -102,126 +196,9 @@ func TestDestroy(t *testing.T) {
if strings.Contains(cookie, "Max-Age=0") == false {
t.Fatalf("got %q: expected to contain %q", cookie, "Max-Age=0")
}
_, found, _ := store.Find(oldToken)
if found != false {
t.Fatalf("got %v: expected %v", found, false)
}
}
func TestRenewToken(t *testing.T) {
store := newMockStore()
manager := NewManager(store)
_, _, cookie := testRequest(t, testPutString(manager), "")
oldToken := extractTokenFromCookie(cookie)
_, body, cookie := testRequest(t, testRenewToken(manager), cookie)
if body != "OK" {
t.Fatalf("got %q: expected %q", body, "OK")
}
newToken := extractTokenFromCookie(cookie)
if newToken == oldToken {
t.Fatal("expected a difference")
}
_, found, _ := store.Find(oldToken)
if found != false {
t.Fatalf("got %v: expected %v", found, false)
}
_, body, _ = testRequest(t, testGetString(manager), cookie)
if body != "lorem ipsum" {
t.Fatalf("got %q: expected %q", body, "lorem ipsum")
}
}
func TestClear(t *testing.T) {
manager := NewManager(newMockStore())
_, _, cookie := testRequest(t, testPutString(manager), "")
_, _, cookie = testRequest(t, testPutBool(manager), cookie)
_, body, cookie := testRequest(t, testClear(manager), cookie)
if body != "OK" {
t.Fatalf("got %q: expected %q", body, "OK")
}
_, body, _ = testRequest(t, testGetString(manager), cookie)
if body != "" {
t.Fatalf("got %q: expected %q", body, "")
}
_, body, _ = testRequest(t, testGetBool(manager), cookie)
if body != "false" {
t.Fatalf("got %q: expected %q", body, "false")
}
// Check that it's a no-op if there is no data in the session
_, _, cookie = testRequest(t, testClear(manager), cookie)
if cookie != "" {
t.Fatalf("got %q: expected %q", cookie, "")
}
}
func TestKeys(t *testing.T) {
manager := NewManager(newMockStore())
_, _, cookie := testRequest(t, testPutString(manager), "")
_, _, _ = testRequest(t, testPutBool(manager), cookie)
_, body, _ := testRequest(t, testKeys(manager), cookie)
if body != "[test_bool test_string]" {
t.Fatalf("got %q: expected %q", body, "[test_bool test_string]")
}
_, _, _ = testRequest(t, testClear(manager), cookie)
_, body, _ = testRequest(t, testKeys(manager), cookie)
if body != "[]" {
t.Fatalf("got %q: expected %q", body, "[]")
}
}
func TestLoadFailure(t *testing.T) {
manager := NewManager(newMockStore())
cookie := http.Cookie{
Name: "session",
Value: "force-error",
}
_, body, _ := testRequest(t, testPutString(manager), cookie.String())
if body != "forced-error\n" {
t.Fatalf("got %q: expected %q", body, "forced-error\n")
}
}
func TestMultipleSessions(t *testing.T) {
manager1 := NewManager(newMockStore())
manager1.Name("foo")
_, _, cookie1 := testRequest(t, testPutString(manager1), "")
manager2 := NewManager(newMockStore())
manager2.Name("bar")
_, _, cookie2 := testRequest(t, testPutBool(manager2), "")
_, body, _ := testRequest(t, testGetString(manager1), cookie1)
if body != "lorem ipsum" {
t.Fatalf("got %q: expected %q", body, "lorem ipsum")
}
_, body, _ = testRequest(t, testGetBool(manager2), cookie2)
if body != "true" {
t.Fatalf("got %q: expected %q", body, "true")
}
_, body, _ = testRequest(t, testGetString(manager2), cookie2)
if body != "" {
t.Fatalf("got %q: expected %q", body, "")
}
_, body, _ = testRequest(t, testGetBool(manager1), cookie1)
if body != "false" {
t.Fatalf("got %q: expected %q", body, "false")
_, body := ts.execute(t, "/get")
if body != "foo does not exist in session\n" {
t.Errorf("want %q; got %q", "foo does not exist in session\n", body)
}
}

View File

@ -1,6 +1,8 @@
package scs
import "time"
import (
"time"
)
// Store is the interface for session stores.
type Store interface {
@ -9,19 +11,15 @@ type Store interface {
// and return nil (not an error).
Delete(token string) (err error)
// Find should return the data for a session token from the session store.
// If the session token is not found or is expired, the found return value
// should be false (and the err return value should be nil). Similarly, tampered
// Find should return the data for a session token from the store. If the
// session token is not found or is expired, the found return value should
// be false (and the err return value should be nil). Similarly, tampered
// or malformed tokens should result in a found return value of false and a
// nil err value. The err return value should be used for system errors only.
Find(token string) (b []byte, found bool, err error)
// Save should add the session token and data to the session store, with
// the given expiry time. If the session token already exists, then the data
// and expiry time should be overwritten.
Save(token string, b []byte, expiry time.Time) (err error)
}
type cookieStore interface {
MakeToken(b []byte, expiry time.Time) (token string, err error)
// Commit should add the session token and data to the store, with the given
// expiry time. If the session token already exists, then the data and
// expiry time should be overwritten.
Commit(token string, b []byte, expiry time.Time) (err error)
}

View File

@ -1,41 +0,0 @@
package scs
import (
"errors"
"time"
)
type mockStore struct {
m map[string]*mockEntry
}
type mockEntry struct {
b []byte
expiry time.Time
}
func newMockStore() *mockStore {
m := make(map[string]*mockEntry)
return &mockStore{m}
}
func (s *mockStore) Delete(token string) error {
delete(s.m, token)
return nil
}
func (s *mockStore) Find(token string) (b []byte, found bool, err error) {
if token == "force-error" {
return nil, false, errors.New("forced-error")
}
entry, exists := s.m[token]
if !exists || entry.expiry.UnixNano() < time.Now().UnixNano() {
return nil, false, nil
}
return entry.b, true, nil
}
func (s *mockStore) Save(token string, b []byte, expiry time.Time) error {
s.m[token] = &mockEntry{b, expiry}
return nil
}

View File

@ -1,188 +0,0 @@
// Package boltstore is a boltdb based session store for the SCS session package.
package boltstore
import (
"log"
"time"
"github.com/boltdb/bolt"
)
var (
dataBucketName = []byte("scs_data_bucket")
expiryBucketName = []byte("scs_expiry_bucket")
)
// BoltStore is a SCS session store backed by a boltdb file.
type BoltStore struct {
db *bolt.DB
stopCleanup chan bool
}
// New creates a BoltStore instance.
//
// The cleanupInterval parameter controls how frequently expired session data
// is removed by the background cleanup goroutine. Setting it to 0 prevents
// the cleanup goroutine from running (i.e. expired sessions will not be removed).
func New(db *bolt.DB, cleanupInterval time.Duration) *BoltStore {
db.Update(func(tx *bolt.Tx) error {
_, err := tx.CreateBucketIfNotExists(dataBucketName)
if err != nil {
return err
}
_, err = tx.CreateBucketIfNotExists(expiryBucketName)
return err
})
bs := &BoltStore{
db: db,
}
if cleanupInterval > 0 {
go bs.startCleanup(cleanupInterval)
}
return bs
}
// Save updates data for a given session token with a given expiry.
// Any existing data + expiry will be over-written.
func (bs *BoltStore) Save(token string, b []byte, expiry time.Time) error {
return bs.db.Update(func(tx *bolt.Tx) error {
tokenBytes := []byte(token)
bucket := tx.Bucket(dataBucketName)
err := bucket.Put(tokenBytes, b)
if err != nil {
return err
}
expiryBucket := tx.Bucket(expiryBucketName)
expBytes, err := expiry.MarshalText()
if err != nil {
return err
}
return expiryBucket.Put(tokenBytes, expBytes)
})
}
// Find returns the data for a session token.
// If the session token is not found or is expired,
// the exists flag will be false.
func (bs *BoltStore) Find(token string) (b []byte, exists bool, err error) {
var value []byte
err = bs.db.View(func(tx *bolt.Tx) error {
tokenBytes := []byte(token)
bucket := tx.Bucket(dataBucketName)
value = bucket.Get(tokenBytes)
if value == nil {
return nil
}
expiryBucket := tx.Bucket(expiryBucketName)
expiryBytes := expiryBucket.Get(tokenBytes)
if isExpired(expiryBytes) {
value = nil
}
return nil
})
return value, value != nil, err
}
// Delete removes session token and corresponding data.
func (bs *BoltStore) Delete(token string) error {
return bs.db.Update(func(tx *bolt.Tx) error {
tokenBytes := []byte(token)
return txDelete(tx, tokenBytes)
})
}
// startCleanup is a helper func to periodically call deleteExpired.
// It will stop if/when it recieves a message on stopCleanup channel.
func (bs *BoltStore) startCleanup(cleanupInterval time.Duration) {
bs.stopCleanup = make(chan bool)
ticker := time.NewTicker(cleanupInterval)
for {
select {
case <-ticker.C:
err := bs.deleteExpired()
if err != nil {
log.Println(err)
}
case <-bs.stopCleanup:
ticker.Stop()
return
}
}
}
// StopCleanup terminates the background cleanup goroutine for the BoltStore instance.
// It's rare to terminate this; generally BoltStore instances and their cleanup
// goroutines are intended to be long-lived and run for the lifetime of your
// application.
//
// There may be occasions though when your use of the BoltStore is transient. An
// example is creating a new BoltStore instance in a test function. In this scenario,
// the cleanup goroutine (which will run forever) will prevent the BoltStore object
// from being garbage collected even after the test function has finished. You
// can prevent this by manually calling StopCleanup.
func (bs *BoltStore) StopCleanup() {
if bs.stopCleanup != nil {
bs.stopCleanup <- true
}
}
// deleteExpired runs at in a separate goroutine at cleanupInterval
// as specified in the New constructor.
//
// iterate over keys in the expiry bucket,
// and delete keys that are exipred.
func (bs *BoltStore) deleteExpired() error {
var expiredKeys [][]byte
bs.db.View(func(tx *bolt.Tx) error {
b := tx.Bucket(expiryBucketName)
b.ForEach(func(k, v []byte) error {
if isExpired(v) {
expiredKeys = append(expiredKeys, k)
}
return nil
})
return nil
})
if len(expiredKeys) > 0 {
return bs.db.Update(func(tx *bolt.Tx) error {
for _, k := range expiredKeys {
if err := txDelete(tx, k); err != nil {
return err
}
}
return nil
})
}
return nil
}
// txDelete is a helper to delete a key
// from both the data + expiry bucket
// inside a transaction.
func txDelete(tx *bolt.Tx, tokenBytes []byte) error {
expiryBucket := tx.Bucket(expiryBucketName)
expiryBucket.Delete(tokenBytes)
bucket := tx.Bucket(dataBucketName)
return bucket.Delete(tokenBytes)
}
// isExpired is a helper func to unmarshal a expiry date
// and determine if it is after Now.
func isExpired(expiryBytes []byte) bool {
expiry := &time.Time{}
err := expiry.UnmarshalText(expiryBytes)
if err != nil {
return true
}
return time.Now().After(*expiry)
}

View File

@ -1,188 +0,0 @@
package boltstore
import (
"bytes"
"log"
"testing"
"time"
"github.com/boltdb/bolt"
)
func TestSave(t *testing.T) {
db, err := bolt.Open("/tmp/testing.db", 0600, nil)
if err != nil {
log.Fatal(err)
}
defer db.Close()
bs := New(db, time.Minute)
bs.Save("key1", []byte("value1"), time.Now().Add(time.Minute))
db.View(func(tx *bolt.Tx) error {
bucket := tx.Bucket(dataBucketName)
v := bucket.Get([]byte("key1"))
if !bytes.Equal(v, []byte("value1")) {
t.Fatalf("expected bytes `value1`, got %s", v)
}
return nil
})
}
func TestFind(t *testing.T) {
db, err := bolt.Open("/tmp/testing.db", 0600, nil)
if err != nil {
log.Fatal(err)
}
defer db.Close()
bs := New(db, time.Minute)
bs.Save("key1", []byte("value1"), time.Now().Add(time.Minute))
{
v, found, err := bs.Find("key1")
if err != nil {
t.Fatal(err)
}
if found != true {
t.Fatalf("got %v: expected %v", found, false)
}
if !bytes.Equal(v, []byte("value1")) {
t.Fatalf("got %v: expected %v", v, []byte("value1"))
}
}
{
v, found, err := bs.Find("key2")
if err != nil {
t.Fatal(err)
}
if found != false {
t.Fatalf("got %v: expected %v", found, true)
}
if v != nil {
t.Fatalf("got %v, expected %v", v, nil)
}
}
}
func TestDelete(t *testing.T) {
db, err := bolt.Open("/tmp/testing.db", 0600, nil)
if err != nil {
log.Fatal(err)
}
defer db.Close()
bs := New(db, time.Minute)
{
err := bs.Save("key1", []byte("value1"), time.Now().Add(time.Minute))
if err != nil {
t.Fatal(err)
}
}
{
_, found, err := bs.Find("key1")
if err != nil {
t.Fatal(err)
}
if found != true {
t.Fatalf("got %v, expected %v", found, true)
}
}
{
err := bs.Delete("key1")
if err != nil {
t.Fatal(err)
}
}
{
_, found, err := bs.Find("key1")
if err != nil {
t.Fatal(err)
}
if found != false {
t.Fatalf("got %v, expected %v", found, false)
}
}
}
func TestExpire(t *testing.T) {
db, err := bolt.Open("/tmp/testing.db", 0600, nil)
if err != nil {
log.Fatal(err)
}
defer db.Close()
bs := New(db, time.Minute)
err = bs.Save("session_token", []byte("encoded_data"), time.Now().Add(100*time.Millisecond))
if err != nil {
t.Fatal(err)
}
_, found, _ := bs.Find("session_token")
if found != true {
t.Fatalf("got %v: expected %v", found, true)
}
time.Sleep(100 * time.Millisecond)
_, found, _ = bs.Find("session_token")
if found != false {
t.Fatalf("got %v: expected %v", found, false)
}
}
func TestCleanup(t *testing.T) {
db, err := bolt.Open("/tmp/testing.db", 0600, nil)
if err != nil {
log.Fatal(err)
}
defer db.Close()
bs := New(db, time.Millisecond*10)
err = bs.Save("session_token", []byte("encoded_data"), time.Now().Add(100*time.Millisecond))
if err != nil {
t.Fatal(err)
}
time.Sleep(200 * time.Millisecond)
{
err := db.View(func(tx *bolt.Tx) error {
dataBucket := tx.Bucket(dataBucketName)
expiryBucket := tx.Bucket(expiryBucketName)
data := dataBucket.Get([]byte("session_token"))
if data != nil {
t.Fatalf("expected nil, got %v", data)
}
exp := expiryBucket.Get([]byte("session_token"))
if exp != nil {
t.Fatalf("expected nil, got %v", exp)
}
return nil
})
if err != nil {
t.Fatal(err)
}
}
bs.StopCleanup()
}
func TestStopNilCleanup(t *testing.T) {
db, err := bolt.Open("/tmp/testing.db", 0600, nil)
if err != nil {
log.Fatal(err)
}
defer db.Close()
m := New(db, 0)
time.Sleep(100 * time.Millisecond)
// A send to a nil channel will block forever
m.StopCleanup()
}

View File

@ -1,56 +0,0 @@
// Package buntstore is a buntdb based session store for the SCS session package.
package buntstore
import (
"time"
"github.com/tidwall/buntdb"
)
// BuntStore is a SCS session store backed by a buntdb file.
type BuntStore struct {
db *buntdb.DB
}
// New creates a BuntStore instance.
func New(db *buntdb.DB) *BuntStore {
store := &BuntStore{
db: db,
}
return store
}
// Save updates data for a given session token with a given expiry.
// Any existing data + expiry will be over-written.
func (bs *BuntStore) Save(token string, b []byte, expiry time.Time) error {
return bs.db.Update(func(tx *buntdb.Tx) error {
_, _, err := tx.Set(token, string(b), &buntdb.SetOptions{Expires: true, TTL: expiry.Sub(time.Now())})
return err
})
}
// Find returns the data for a session token.
// If the session token is not found or is expired,
// the exists flag will be false.
func (bs *BuntStore) Find(token string) (b []byte, exists bool, err error) {
var value string
err = bs.db.View(func(tx *buntdb.Tx) error {
value, err = tx.Get(token)
return err
})
if err != nil {
if err == buntdb.ErrNotFound {
return nil, false, nil
}
return nil, false, err
}
return []byte(value), value != "", err
}
// Delete removes session token and corresponding data.
func (bs *BuntStore) Delete(token string) error {
return bs.db.Update(func(tx *buntdb.Tx) error {
_, err := tx.Delete(token)
return err
})
}

View File

@ -1,178 +0,0 @@
package buntstore
import (
"bytes"
"os"
"testing"
"time"
"strings"
"github.com/tidwall/buntdb"
)
// remove old test DB if it exists and create a new one
func getTestDatabase() *buntdb.DB {
err := os.Remove("/tmp/testing.db")
if err != nil {
panic(err)
}
db, err := buntdb.Open("/tmp/testing.db")
if err != nil {
panic(err)
}
return db
}
func TestSave(t *testing.T) {
db := getTestDatabase()
defer db.Close()
bs := New(db)
bs.Save("key1", []byte("value1"), time.Now().Add(time.Minute))
db.View(func(tx *buntdb.Tx) error {
v, err := tx.Get("key1")
if err != nil {
t.Fatalf("expected no error, got %s", err.Error())
}
if !strings.EqualFold(v, "value1") {
t.Fatalf("expected string `value1`, got %s", v)
}
return nil
})
}
func TestFind(t *testing.T) {
db := getTestDatabase()
defer db.Close()
bs := New(db)
bs.Save("key1", []byte("value1"), time.Now().Add(time.Minute))
{
v, found, err := bs.Find("key1")
if err != nil {
t.Fatal(err)
}
if !found {
t.Fatalf("got %v: expected %v (%s)", found, true, v)
}
if !bytes.Equal(v, []byte("value1")) {
t.Fatalf("got %v: expected %v", v, []byte("value1"))
}
}
{
v, found, err := bs.Find("key2")
if err != nil {
t.Fatal(err)
}
if found {
t.Fatalf("got %v: expected %v", found, false)
}
if v != nil {
t.Fatalf("got %v, expected %v", v, nil)
}
}
}
func TestDelete(t *testing.T) {
db := getTestDatabase()
defer db.Close()
bs := New(db)
{
err := bs.Save("key1", []byte("value1"), time.Now().Add(time.Minute))
if err != nil {
t.Fatal(err)
}
}
{
_, found, err := bs.Find("key1")
if err != nil {
t.Fatal(err)
}
if !found {
t.Fatalf("got %v, expected %v", found, true)
}
}
{
err := bs.Delete("key1")
if err != nil {
t.Fatal(err)
}
}
{
_, found, err := bs.Find("key1")
if err != nil {
t.Fatal(err)
}
if found {
t.Fatalf("got %v, expected %v", found, false)
}
}
}
func TestExpire(t *testing.T) {
db := getTestDatabase()
defer db.Close()
bs := New(db)
err := bs.Save("session_token", []byte("encoded_data"), time.Now().Add(100*time.Millisecond))
if err != nil {
t.Fatal(err)
}
_, found, _ := bs.Find("session_token")
if !found {
t.Fatalf("got %v: expected %v", found, true)
}
time.Sleep(10 * time.Millisecond)
_, found, _ = bs.Find("session_token")
if !found {
t.Fatalf("got %v: expected %v", found, false)
}
time.Sleep(100 * time.Millisecond)
_, found, _ = bs.Find("session_token")
if found {
t.Fatalf("got %v: expected %v", found, false)
}
}
func TestCleanup(t *testing.T) {
db := getTestDatabase()
defer db.Close()
bs := New(db)
err := bs.Save("session_token", []byte("encoded_data"), time.Now().Add(100*time.Millisecond))
if err != nil {
t.Fatal(err)
}
time.Sleep(200 * time.Millisecond)
{
err := db.View(func(tx *buntdb.Tx) error {
data, err := tx.Get("session_token")
if err == nil {
t.Fatalf("expected not found, got %s", err.Error())
}
if data != "" {
t.Fatalf("expected empty, got %v", data)
}
return nil
})
if err != nil {
t.Fatal(err)
}
}
}

View File

@ -1,134 +0,0 @@
package cookiestore
import (
"crypto/rand"
"encoding/base64"
"errors"
"strconv"
"time"
"golang.org/x/crypto/nacl/secretbox"
)
var (
errTokenTooLong = errors.New("cookiestore: encoded token length exceeded 4096 characters")
errInvalidToken = errors.New("cookiestore: token is invalid")
errInvalidExpiry = errors.New("cookiestore: expiry time is invalid")
)
// CookieStore represents the currently configured session store.
type CookieStore struct {
keys [][32]byte
}
// New returns a new CookieStore instance.
//
// The key parameter should contain the secret you want to use to authenticate and
// encrypt session cookies. This should be exactly 32 bytes long.
//
// Optionally, the variadic oldKeys parameter can be used to provide an arbitrary
// number of old Keys. This should be used to ensure that valid cookies continue
// to work correctly after key rotation.
func New(key []byte, oldKeys ...[]byte) *CookieStore {
keys := make([][32]byte, 1)
copy(keys[0][:], key)
for _, key := range oldKeys {
var newKey [32]byte
copy(newKey[:], key)
keys = append(keys, newKey)
}
return &CookieStore{
keys: keys,
}
}
// MakeToken creates a signed, optionally encrypted, cookie token for the provided
// session data. The returned token is limited to 4096 characters in length. An
// error will be returned if this is exceeded.
func (c *CookieStore) MakeToken(b []byte, expiry time.Time) (token string, err error) {
return encodeToken(c.keys[0], b, expiry)
}
// Find returns the session data for given cookie token. It loops through all
// available keys (including old keys) to try to decode the cookie. If
// the cookie could not be decoded, or has expired, the returned exists flag
// will be set to false.
func (c *CookieStore) Find(token string) (b []byte, exists bool, error error) {
for _, key := range c.keys {
b, err := decodeToken(key, token)
switch err {
case nil:
return b, true, nil
case errInvalidToken:
continue
default:
return nil, false, err
}
}
return nil, false, nil
}
// Save is a no-op. The function exists only to ensure that a CookieStore instance
// satisfies the scs.Store interface.
func (c *CookieStore) Save(token string, b []byte, expiry time.Time) error {
return nil
}
// Delete is a no-op. The function exists only to ensure that a CookieStore instance
// satisfies the scs.Store interface.
func (c *CookieStore) Delete(token string) error {
return nil
}
func encodeToken(key [32]byte, b []byte, expiry time.Time) (string, error) {
expiryTimestamp := []byte(strconv.FormatInt(expiry.UnixNano(), 10))
if len(expiryTimestamp) != 19 {
return "", errInvalidExpiry
}
message := append(expiryTimestamp, b...)
var nonce [24]byte
_, err := rand.Read(nonce[:])
if err != nil {
return "", err
}
box := secretbox.Seal(nonce[:], message, &nonce, &key)
token := base64.RawURLEncoding.EncodeToString(box)
if len(token) > 4096 {
return "", errTokenTooLong
}
return token, nil
}
func decodeToken(key [32]byte, token string) ([]byte, error) {
box, err := base64.RawURLEncoding.DecodeString(token)
if err != nil {
return nil, errInvalidToken
}
if len(box) < 24 {
return nil, errInvalidToken
}
var nonce [24]byte
copy(nonce[:], box[:24])
message, ok := secretbox.Open(nil, box[24:], &nonce, &key)
if !ok {
return nil, errInvalidToken
}
expiryTimestamp, err := strconv.ParseInt(string(message[:19]), 10, 64)
if err != nil {
return nil, errInvalidToken
}
if expiryTimestamp < time.Now().UnixNano() {
return nil, errInvalidToken
}
return message[19:], nil
}

File diff suppressed because one or more lines are too long

View File

@ -1,163 +0,0 @@
// Package dynamostore is a DynamoDB-based session store for the SCS session package.
//
// The dynamostore package relies on the aws-sdk-go client.
// (https://godoc.org/github.com/aws/aws-sdk-go/service/dynamodb)
package dynamostore
import (
"strconv"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/dynamodb"
)
// DynamoStore represents the currently configured session session store. It is essentially
// a wrapper around a DynamoDB client. And table is a table name session stored. token, data,
// expiry are key names.
type DynamoStore struct {
DB *dynamodb.DynamoDB
table string
token string
data string
expiry string
ttl string
}
const (
defaultTable = "scs_session"
defaultToken = "token"
defaultData = "data"
defaultExpiry = "expiry"
defaultTTL = "ttl"
)
// New returns a new DynamoStore instance. The client parameter shoud be a pointer to a
// aws-sdk-go DynamoDB client. See https://godoc.org/github.com/aws/aws-sdk-go/service/dynamodb#DynamoDB.
func New(dynamo *dynamodb.DynamoDB) *DynamoStore {
return NewWithOption(dynamo, defaultTable, defaultToken, defaultData, defaultExpiry, defaultTTL)
}
// NewWithOption returns a new DynamoStore instance. The client parameter shoud be a pointer to a
// aws-sdk-go DynamoDB client. See https://godoc.org/github.com/aws/aws-sdk-go/service/dynamodb#DynamoDB.
// The parameter table is DynamoDB tabel name, and token/data/expiry are key names.
func NewWithOption(dynamo *dynamodb.DynamoDB, table string, token string, data string, expiry string, ttl string) *DynamoStore {
return &DynamoStore{
DB: dynamo,
table: table,
token: token,
data: data,
expiry: expiry,
ttl: ttl,
}
}
// Find returns the data for a given session token from the DynamoStore instance. If the session
// token is not found or is expired, the returned exists flag will be set to false.
func (d *DynamoStore) Find(token string) (b []byte, found bool, err error) {
params := &dynamodb.GetItemInput{
TableName: aws.String(d.TableName()),
Key: map[string]*dynamodb.AttributeValue{
d.TokenName(): {
S: aws.String(token),
},
},
ConsistentRead: aws.Bool(true),
}
resp, err := d.DB.GetItem(params)
if err != nil {
return nil, false, err
}
if resp.Item == nil {
return nil, false, nil
}
expiry, err := strconv.ParseInt(aws.StringValue(resp.Item[d.ExpiryName()].N), 10, 64)
if err != nil {
return nil, false, err
}
if expiry < time.Now().UnixNano() {
return nil, false, d.Delete(token)
}
return resp.Item[d.DataName()].B, true, nil
}
// Save adds a session token and data to the DynamoStore instance with the given expiry time.
// If the session token already exists then the data and expiry time are updated.
func (d *DynamoStore) Save(token string, b []byte, expiry time.Time) error {
params := &dynamodb.PutItemInput{
TableName: aws.String(d.TableName()),
Item: map[string]*dynamodb.AttributeValue{
d.TokenName(): {
S: aws.String(token),
},
d.DataName(): {
B: b,
},
d.ExpiryName(): {
N: aws.String(strconv.FormatInt(expiry.UnixNano(), 10)),
},
d.TTLName(): {
// TTL is used by DynamoDB Time To Live. It must be Unix Epoch format.
// TTL cannot handle under second like milliseocnd and nanosecond, but
// Expiry can.
N: aws.String(strconv.FormatInt(expiry.Add(1*time.Second).Unix(), 10)),
},
},
}
_, err := d.DB.PutItem(params)
return err
}
// Delete removes a session token and corresponding data from the DynamoStore instance.
func (d *DynamoStore) Delete(token string) error {
params := &dynamodb.DeleteItemInput{
TableName: aws.String(d.TableName()),
Key: map[string]*dynamodb.AttributeValue{
d.TokenName(): {
S: aws.String(token),
},
},
}
_, err := d.DB.DeleteItem(params)
return err
}
// Ping checks to exisit session table in DynamoDB.
func (d *DynamoStore) Ping() error {
params := &dynamodb.DescribeTableInput{
TableName: aws.String(d.TableName()),
}
_, err := d.DB.DescribeTable(params)
return err
}
// TableName returns session table name.
func (d *DynamoStore) TableName() string {
return d.table
}
// TokenName returns session token key name.
func (d *DynamoStore) TokenName() string {
return d.token
}
// DataName returns session data key name.
func (d *DynamoStore) DataName() string {
return d.data
}
// ExpiryName returns session expiry key name.
func (d *DynamoStore) ExpiryName() string {
return d.expiry
}
// TTLName returns session expiry key name.
func (d *DynamoStore) TTLName() string {
return d.ttl
}

View File

@ -1,169 +0,0 @@
package dynamostore
import (
"bytes"
"reflect"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/endpoints"
awsSession "github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/dynamodb"
)
const (
defaultRegion = endpoints.ApNortheast1RegionID
token = "session_token"
data = "encoded_data"
dataUpdated = "encoded_data_updated"
)
func getTestDynamoDB(t *testing.T) *dynamodb.DynamoDB {
conf := &aws.Config{Region: aws.String(defaultRegion)}
sess, err := awsSession.NewSession()
if err != nil {
t.Fatal(err)
}
dy := dynamodb.New(sess, conf)
if dy == nil {
t.Fatal("failed to create dynamodb client")
}
d := New(dy)
err = d.Ping()
if err != nil {
t.Fatal(err)
}
return dy
}
func clearTestDynamoDB(t *testing.T, dy *dynamodb.DynamoDB) {
d := New(dy)
_, found, err := d.Find(token)
if err != nil {
t.Fatal(err)
}
if !found {
return
}
err = d.Delete(token)
if err != nil {
t.Fatal(err)
}
}
func TestFind(t *testing.T) {
dy := getTestDynamoDB(t)
clearTestDynamoDB(t, dy)
d := New(dy)
err := d.Save(token, []byte(data), time.Now().Add(time.Minute))
if err != nil {
t.Fatal(err)
}
b, found, err := d.Find(token)
if err != nil {
t.Fatal(err)
}
if found != true {
t.Fatalf("got %v: expected %v", found, true)
}
if bytes.Equal(b, []byte(data)) == false {
t.Fatalf("got %v: expected %v", b, []byte(data))
}
}
func TestFindMissing(t *testing.T) {
dy := getTestDynamoDB(t)
clearTestDynamoDB(t, dy)
d := New(dy)
_, found, err := d.Find(token)
if err != nil {
t.Fatal(err)
}
if err != nil {
t.Fatalf("got %v: expected %v", err, nil)
}
if found != false {
t.Fatalf("got %v: expected %v", found, false)
}
}
func TestSaveNew(t *testing.T) {
dy := getTestDynamoDB(t)
clearTestDynamoDB(t, dy)
d := New(dy)
err := d.Save(token, []byte(data), time.Now().Add(time.Minute))
if err != nil {
t.Fatal(err)
}
b, found, err := d.Find(token)
if err != nil {
t.Fatal(err)
}
if found != true {
t.Fatalf("got %v: expected %v", found, true)
}
if reflect.DeepEqual(b, []byte(data)) != true {
t.Fatalf("got %v: expected %v", b, []byte(data))
}
}
func TestSaveUpdated(t *testing.T) {
dy := getTestDynamoDB(t)
clearTestDynamoDB(t, dy)
d := New(dy)
err := d.Save(token, []byte(data), time.Now().Add(time.Minute))
if err != nil {
t.Fatal(err)
}
_, _, err = d.Find(token)
if err != nil {
t.Fatal(err)
}
err = d.Save(token, []byte(dataUpdated), time.Now().Add(time.Minute))
if err != nil {
t.Fatal(err)
}
b, found, err := d.Find(token)
if err != nil {
t.Fatal(err)
}
if found != true {
t.Fatalf("got %v: expected %v", found, true)
}
if reflect.DeepEqual(b, []byte(dataUpdated)) != true {
t.Fatalf("got %v: expected %v", b, []byte(dataUpdated))
}
}
func TestExpiry(t *testing.T) {
dy := getTestDynamoDB(t)
clearTestDynamoDB(t, dy)
d := New(dy)
err := d.Save(token, []byte(dataUpdated), time.Now().Add(100*time.Millisecond))
if err != nil {
t.Fatal(err)
}
time.Sleep(100 * time.Millisecond)
_, found, _ := d.Find(token)
if found != false {
t.Fatalf("got %v: expected %v", found, false)
}
}

View File

@ -1,62 +0,0 @@
package memcachedstore
import (
"time"
"github.com/bradfitz/gomemcache/memcache"
)
// Prefix controls the Memcached key prefix. You should only need to change this if there is
// a naming clash.
var Prefix = "scs:session:"
type MemcachedStore struct {
client *memcache.Client
}
// New returns a new MemcachedStore instance.
// The conn parameter should be a pointer to a gomemcache connection pool.
func New(client *memcache.Client) *MemcachedStore {
return &MemcachedStore{client}
}
// Find return the data for a session token from the MemcachedStore instance.
// If the session token is not found or is expired, the found return value
// is false (and the err return value is nil).
func (m *MemcachedStore) Find(token string) (b []byte, found bool, err error) {
item, err := m.client.Get(Prefix + token)
if err != nil {
if err == memcache.ErrCacheMiss {
return nil, false, nil
}
return nil, false, err
}
return item.Value, true, nil
}
// Save adds a session token and data to the MemcachedStore instance with the given expiry time.
// If the session token already exists then the data and expiry time are updated.
func (m *MemcachedStore) Save(token string, b []byte, expiry time.Time) error {
return m.client.Set(&memcache.Item{
Key: Prefix + token,
Value: b,
Expiration: createOffset(expiry),
})
}
// Delete removes a session token and corresponding data from the MemcachedStore instance.
func (m *MemcachedStore) Delete(token string) error {
return m.client.Delete(Prefix + token)
}
// createOffset calculates how expiration dates should be stored
// Memcached stores dates either as seconds since the Unix epoch OR as a relative offset from now
// It decides this by whether the offset is greater than the number of seconds in 30 days
func createOffset(expiry time.Time) int32 {
if expiry.After(time.Now().AddDate(0, 0, 30)) { // more than 30 days away
return int32(expiry.Unix()) // uh oh! https://en.wikipedia.org/wiki/Year_2038_problem
}
return int32(time.Until(expiry).Seconds())
}

View File

@ -1,168 +0,0 @@
package memcachedstore
import (
"bytes"
"os"
"reflect"
"testing"
"time"
"github.com/bradfitz/gomemcache/memcache"
)
func TestFind(t *testing.T) {
mc := memcache.New(os.Getenv("SESSION_MEMCACHED_TEST_ADDR"))
err := mc.DeleteAll()
if err != nil {
t.Fatal(err)
}
mc.Set(&memcache.Item{
Key: Prefix + "session_token",
Value: []byte("encoded_data"),
Expiration: 60,
})
m := New(mc)
b, found, err := m.Find("session_token")
if err != nil {
t.Fatal(err)
}
if found != true {
t.Fatalf("got %v: expected %v", found, true)
}
if bytes.Equal(b, []byte("encoded_data")) == false {
t.Fatalf("got %v: expected %v", b, []byte("encoded_data"))
}
}
func TestFindMissing(t *testing.T) {
mc := memcache.New(os.Getenv("SESSION_MEMCACHED_TEST_ADDR"))
err := mc.DeleteAll()
if err != nil {
t.Fatal(err)
}
m := New(mc)
_, found, err := m.Find("missing_session_token")
if err != nil {
t.Fatalf("got %v: expected %v", err, nil)
}
if found != false {
t.Fatalf("got %v: expected %v", found, false)
}
}
func TestSaveNew(t *testing.T) {
mc := memcache.New(os.Getenv("SESSION_MEMCACHED_TEST_ADDR"))
err := mc.DeleteAll()
if err != nil {
t.Fatal(err)
}
m := New(mc)
err = m.Save("session_token", []byte("encoded_data"), time.Now().Add(time.Minute))
if err != nil {
t.Fatal(err)
}
item, err := mc.Get(Prefix + "session_token")
if err != nil {
t.Fatal(err)
}
if reflect.DeepEqual(item.Value, []byte("encoded_data")) == false {
t.Fatalf("got %v: expected %v", item.Value, []byte("encoded_data"))
}
}
func TestSaveUpdated(t *testing.T) {
mc := memcache.New(os.Getenv("SESSION_MEMCACHED_TEST_ADDR"))
err := mc.DeleteAll()
if err != nil {
t.Fatal(err)
}
mc.Set(&memcache.Item{
Key: Prefix + "session_token",
Value: []byte("encoded_data"),
Expiration: 60,
})
m := New(mc)
err = m.Save("session_token", []byte("new_encoded_data"), time.Now().Add(time.Minute))
if err != nil {
t.Fatal(err)
}
item, err := mc.Get(Prefix + "session_token")
if err != nil {
t.Fatal(err)
}
if reflect.DeepEqual(item.Value, []byte("new_encoded_data")) == false {
t.Fatalf("got %v: expected %v", item.Value, []byte("new_encoded_data"))
}
}
func TestExpiry(t *testing.T) {
mc := memcache.New(os.Getenv("SESSION_MEMCACHED_TEST_ADDR"))
err := mc.DeleteAll()
if err != nil {
t.Fatal(err)
}
m := New(mc)
err = m.Save("session_token", []byte("encoded_data"), time.Now().Add(2*time.Second))
if err != nil {
t.Fatal(err)
}
_, found, _ := m.Find("session_token")
if found != true {
t.Fatalf("got %v: expected %v", found, true)
}
time.Sleep(2 * time.Second)
_, found, _ = m.Find("session_token")
if found != false {
t.Fatalf("got %v: expected %v", found, false)
}
}
func TestDelete(t *testing.T) {
mc := memcache.New(os.Getenv("SESSION_MEMCACHED_TEST_ADDR"))
err := mc.DeleteAll()
if err != nil {
t.Fatal(err)
}
mc.Set(&memcache.Item{
Key: Prefix + "session_token",
Value: []byte("encoded_data"),
Expiration: 60,
})
m := New(mc)
err = m.Delete("session_token")
if err != nil {
t.Fatal(err)
}
_, err = mc.Get(Prefix + "session_token")
if err != memcache.ErrCacheMiss {
t.Fatalf("got %v: expected %v", err, memcache.ErrCacheMiss)
}
}

View File

@ -1,69 +0,0 @@
// Package memstore is a in-memory session store for the SCS session package.
//
// Warning: Because memstore uses in-memory storage only, all session data will
// be lost when your Go program is stopped or restarted. On the upside though,
// it is blazingly fast.
//
// In production, memstore should only be used where this volatility is an acceptable
// trade-off for the high performance, and where lost session data will have a
// negligible impact on users.
//
// The memstore package provides a background 'cleanup' goroutine to delete
// expired session data. This stops the underlying cache from holding on to invalid
// sessions forever and taking up unnecessary memory.
package memstore
import (
"errors"
"time"
"github.com/patrickmn/go-cache"
)
var errTypeAssertionFailed = errors.New("type assertion failed: could not convert interface{} to []byte")
// MemStore represents the currently configured session session store. It is essentially
// a wrapper around a go-cache instance (see https://github.com/patrickmn/go-cache).
type MemStore struct {
cache *cache.Cache
}
// New returns a new MemStore instance.
//
// The cleanupInterval parameter controls how frequently expired session data
// is removed by the background 'cleanup' goroutine. Setting it to 0 prevents
// the cleanup goroutine from running (i.e. expired sessions will not be removed).
func New(cleanupInterval time.Duration) *MemStore {
return &MemStore{
cache.New(cache.DefaultExpiration, cleanupInterval),
}
}
// Find returns the data for a given session token from the MemStore instance. If the session
// token is not found or is expired, the returned exists flag will be set to false.
func (m *MemStore) Find(token string) (b []byte, exists bool, err error) {
v, exists := m.cache.Get(token)
if exists == false {
return nil, exists, nil
}
b, ok := v.([]byte)
if ok == false {
return nil, exists, errTypeAssertionFailed
}
return b, exists, nil
}
// Save adds a session token and data to the MemStore instance with the given expiry time.
// If the session token already exists then the data and expiry time are updated.
func (m *MemStore) Save(token string, b []byte, expiry time.Time) error {
m.cache.Set(token, b, expiry.Sub(time.Now()))
return nil
}
// Delete removes a session token and corresponding data from the MemStore instance.
func (m *MemStore) Delete(token string) error {
m.cache.Delete(token)
return nil
}

View File

@ -1,112 +0,0 @@
// Package pgstore is a PostgreSQL-based session store for the SCS session package.
//
// A working PostgreSQL database is required, containing a sessions table with
// the definition:
//
// CREATE TABLE sessions (
// token TEXT PRIMARY KEY,
// data BYTEA NOT NULL,
// expiry TIMESTAMPTZ NOT NULL
// );
// CREATE INDEX sessions_expiry_idx ON sessions (expiry);
//
// The pgstore package provides a background 'cleanup' goroutine to delete expired
// session data. This stops the database table from holding on to invalid sessions
// indefinitely and growing unnecessarily large.
package pgstore
import (
"database/sql"
"log"
"time"
// Register lib/pq with database/sql
_ "github.com/lib/pq"
)
// PGStore represents the currently configured session session store.
type PGStore struct {
db *sql.DB
stopCleanup chan bool
}
// New returns a new PGStore instance.
//
// The cleanupInterval parameter controls how frequently expired session data
// is removed by the background cleanup goroutine. Setting it to 0 prevents
// the cleanup goroutine from running (i.e. expired sessions will not be removed).
func New(db *sql.DB, cleanupInterval time.Duration) *PGStore {
p := &PGStore{db: db}
if cleanupInterval > 0 {
go p.startCleanup(cleanupInterval)
}
return p
}
// Find returns the data for a given session token from the PGStore instance. If
// the session token is not found or is expired, the returned exists flag will
// be set to false.
func (p *PGStore) Find(token string) (b []byte, exists bool, err error) {
row := p.db.QueryRow("SELECT data FROM sessions WHERE token = $1 AND current_timestamp < expiry", token)
err = row.Scan(&b)
if err == sql.ErrNoRows {
return nil, false, nil
} else if err != nil {
return nil, false, err
}
return b, true, nil
}
// Save adds a session token and data to the PGStore instance with the given expiry time.
// If the session token already exists then the data and expiry time are updated.
func (p *PGStore) Save(token string, b []byte, expiry time.Time) error {
_, err := p.db.Exec("INSERT INTO sessions (token, data, expiry) VALUES ($1, $2, $3) ON CONFLICT (token) DO UPDATE SET data = EXCLUDED.data, expiry = EXCLUDED.expiry", token, b, expiry)
if err != nil {
return err
}
return nil
}
// Delete removes a session token and corresponding data from the PGStore instance.
func (p *PGStore) Delete(token string) error {
_, err := p.db.Exec("DELETE FROM sessions WHERE token = $1", token)
return err
}
func (p *PGStore) startCleanup(interval time.Duration) {
p.stopCleanup = make(chan bool)
ticker := time.NewTicker(interval)
for {
select {
case <-ticker.C:
err := p.deleteExpired()
if err != nil {
log.Println(err)
}
case <-p.stopCleanup:
ticker.Stop()
return
}
}
}
// StopCleanup terminates the background cleanup goroutine for the PGStore instance.
// It's rare to terminate this; generally PGStore instances and their cleanup
// goroutines are intended to be long-lived and run for the lifetime of your
// application.
//
// There may be occasions though when your use of the PGStore is transient. An
// example is creating a new PGStore instance in a test function. In this scenario,
// the cleanup goroutine (which will run forever) will prevent the PGStore object
// from being garbage collected even after the test function has finished. You
// can prevent this by manually calling StopCleanup.
func (p *PGStore) StopCleanup() {
if p.stopCleanup != nil {
p.stopCleanup <- true
}
}
func (p *PGStore) deleteExpired() error {
_, err := p.db.Exec("DELETE FROM sessions WHERE expiry < current_timestamp")
return err
}

View File

@ -1,143 +0,0 @@
// Package qlstore is a ql-based session store for the SCS session package.
//
// A working ql database is required, containing a sessions table with
// the definition:
//
// CREATE TABLE sessions (
// token string,
// data blob,
// expiry time
// )
// CREATE INDEX sessions_expiry_idx ON sessions (expiry);
//
// The qlstore package provides a background 'cleanup' goroutine to delete expired
// session data. This stops the database table from holding on to invalid sessions
// indefinitely and growing unnecessarily large.
package qlstore
import (
"database/sql"
"log"
"time"
// Register ql driver with database/sql
_ "github.com/cznic/ql/driver"
)
// QLStore represents the currently configured session session store.
type QLStore struct {
*sql.DB
stopCleanup chan bool
}
// New returns a new QLStore instance.
//
// The cleanupInterval parameter controls how frequently expired session data
// is removed by the background cleanup goroutine. Setting it to 0 prevents
// the cleanup goroutine from running (i.e. expired sessions will not be removed).
func New(db *sql.DB, cleanupInterval time.Duration) *QLStore {
q := &QLStore{DB: db}
if cleanupInterval > 0 {
go q.startCleanup(cleanupInterval)
}
return q
}
func (q *QLStore) startCleanup(interval time.Duration) {
q.stopCleanup = make(chan bool)
ticker := time.NewTicker(interval)
for {
select {
case <-ticker.C:
err := q.deleteExpired()
if err != nil {
log.Println(err)
}
case <-q.stopCleanup:
ticker.Stop()
return
}
}
}
// Delete removes a session token and corresponding data from the QLStore instance.
func (q *QLStore) Delete(token string) error {
_, err := execTx(q.DB, "DELETE FROM sessions where token=$1", token)
return err
}
func (q *QLStore) deleteExpired() error {
_, err := execTx(q.DB, "DELETE FROM sessions WHERE expiry < now()")
return err
}
// Find returns the data for a given session token from the QLStore instance. If
// the session token is not found or is expired, the returned exists flag will
// be set to false.
func (q *QLStore) Find(token string) ([]byte, bool, error) {
var data []byte
query := "SELECT data FROM sessions WHERE token=$1 AND now()<expiry"
err := q.QueryRow(query, token).Scan(&data)
if err != nil {
if err == sql.ErrNoRows {
return nil, false, nil
}
return nil, false, err
}
return data, true, nil
}
// Save adds a session token and data to the QLStore instance with the given expiry time.
// If the session token already exists then the data and expiry time are updated.
func (q *QLStore) Save(token string, b []byte, expiry time.Time) error {
_, ok, _ := q.Find(token)
if ok {
_, err := execTx(q.DB, `
UPDATE sessions data=$2,expiry=$3 WHERE token=$1
`, token, b, expiry)
return err
}
_, err := execTx(q.DB, `
INSERT INTO sessions (token , data, expiry) VALUES ($1,$2,$3)
`, token, b, expiry)
return err
}
func execTx(db *sql.DB, query string, args ...interface{}) (sql.Result, error) {
tx, err := db.Begin()
if err != nil {
return nil, err
}
defer func() {
_ = tx.Commit()
}()
r, err := tx.Exec(query, args...)
return r, err
}
// StopCleanup terminates the background cleanup goroutine for the QLStore instance.
// It's rare to terminate this; generally QLStore instances and their cleanup
// goroutines are intended to be long-lived and run for the lifetime of your
// application.
//
// There may be occasions though when your use of the QLStore is transient. An
// example is creating a new QLStore instance in a test function. In this scenario,
// the cleanup goroutine (which will run forever) will prevent the QLStore object
// from being garbage collected even after the test function has finished. You
// can prevent this by manually calling StopCleanup.
func (q *QLStore) StopCleanup() {
if q.stopCleanup != nil {
q.stopCleanup <- true
}
}
//Table provides SQL for creating a session table in ql database
func Table() string {
return `
CREATE TABLE sessions (
token string,
data blob,
expiry time
)
`
}

View File

@ -1,315 +0,0 @@
package qlstore
import (
"bytes"
"database/sql"
"os"
"reflect"
"testing"
"time"
)
func TestFind(t *testing.T) {
dsn := os.Getenv("SESSION_QL_TEST_DSN")
if dsn == "" {
dsn = "test.db"
}
db, err := sql.Open("ql-mem", dsn)
if err != nil {
t.Fatal(err)
}
defer func() {
_ = db.Close()
}()
if err = db.Ping(); err != nil {
t.Fatal(err)
}
migrate(t, db)
_, err = execTx(db, "TRUNCATE TABLE sessions")
if err != nil {
t.Fatal(err)
}
ex := time.Now().Add(time.Minute)
_, err = execTx(db,
`INSERT INTO sessions VALUES("session_token", $1,$2 )`,
[]byte("encoded_data"), ex)
if err != nil {
t.Fatal(err)
}
p := New(db, 0)
b, found, err := p.Find("session_token")
if err != nil {
t.Fatal(err)
}
if found != true {
t.Fatalf("got %v: expected %v", found, true)
}
if bytes.Equal(b, []byte("encoded_data")) == false {
t.Fatalf("got %v: expected %v", b, []byte("encoded_data"))
}
}
func TestFindMissing(t *testing.T) {
dsn := os.Getenv("SESSION_QL_TEST_DSN")
if dsn == "" {
dsn = "test.db"
}
db, err := sql.Open("ql-mem", dsn)
if err != nil {
t.Fatal(err)
}
defer func() {
_ = db.Close()
}()
if err = db.Ping(); err != nil {
t.Fatal(err)
}
migrate(t, db)
p := New(db, 0)
_, found, err := p.Find("missing_session_token")
if err != nil {
t.Fatalf("got %v: expected %v", err, nil)
}
if found != false {
t.Fatalf("got %v: expected %v", found, false)
}
}
func TestSaveNew(t *testing.T) {
dsn := os.Getenv("SESSION_QL_TEST_DSN")
if dsn == "" {
dsn = "test.db"
}
db, err := sql.Open("ql-mem", dsn)
if err != nil {
t.Fatal(err)
}
defer func() {
_ = db.Close()
}()
if err = db.Ping(); err != nil {
t.Fatal(err)
}
migrate(t, db)
p := New(db, 0)
err = p.Save("session_token", []byte("encoded_data"), time.Now().Add(time.Minute))
if err != nil {
t.Fatal(err)
}
row := db.QueryRow(`SELECT data FROM sessions WHERE token = "session_token"`)
var data []byte
err = row.Scan(&data)
if err != nil {
t.Fatal(err)
}
if reflect.DeepEqual(data, []byte("encoded_data")) == false {
t.Fatalf("got %v: expected %v", data, []byte("encoded_data"))
}
}
func TestSaveUpdated(t *testing.T) {
dsn := os.Getenv("SESSION_QL_TEST_DSN")
if dsn == "" {
dsn = "test.db"
}
db, err := sql.Open("ql-mem", dsn)
if err != nil {
t.Fatal(err)
}
defer func() {
_ = db.Close()
}()
if err = db.Ping(); err != nil {
t.Fatal(err)
}
migrate(t, db)
ex := time.Now().Add(time.Minute)
_, err = execTx(db,
`INSERT INTO sessions VALUES("session_token", $1,$2 )`,
[]byte("encoded_data"), ex)
if err != nil {
t.Fatal(err)
}
p := New(db, 0)
err = p.Save("session_token", []byte("new_encoded_data"), time.Now().Add(time.Minute))
if err != nil {
t.Fatal(err)
}
row := db.QueryRow(`SELECT data FROM sessions WHERE token = "session_token"`)
var data []byte
err = row.Scan(&data)
if err != nil {
t.Fatal(err)
}
if reflect.DeepEqual(data, []byte("new_encoded_data")) == false {
t.Fatalf("got %v: expected %v", data, []byte("new_encoded_data"))
}
}
func TestExpiry(t *testing.T) {
dsn := os.Getenv("SESSION_QL_TEST_DSN")
if dsn == "" {
dsn = "test.db"
}
db, err := sql.Open("ql-mem", dsn)
if err != nil {
t.Fatal(err)
}
defer func() {
_ = db.Close()
}()
if err = db.Ping(); err != nil {
t.Fatal(err)
}
migrate(t, db)
p := New(db, 0)
err = p.Save("session_token", []byte("encoded_data"), time.Now().Add(100*time.Millisecond))
if err != nil {
t.Fatal(err)
}
_, found, _ := p.Find("session_token")
if found != true {
t.Fatalf("got %v: expected %v", found, true)
}
time.Sleep(100 * time.Millisecond)
_, found, _ = p.Find("session_token")
if found != false {
t.Fatalf("got %v: expected %v", found, false)
}
}
func TestDelete(t *testing.T) {
dsn := os.Getenv("SESSION_QL_TEST_DSN")
if dsn == "" {
dsn = "test.db"
}
db, err := sql.Open("ql-mem", dsn)
if err != nil {
t.Fatal(err)
}
defer func() {
_ = db.Close()
}()
if err = db.Ping(); err != nil {
t.Fatal(err)
}
migrate(t, db)
ex := time.Now().Add(time.Minute)
_, err = execTx(db,
`INSERT INTO sessions VALUES("session_token", $1,$2 )`,
[]byte("encoded_data"), ex)
if err != nil {
t.Fatal(err)
}
p := New(db, 0)
err = p.Delete("session_token")
if err != nil {
t.Fatal(err)
}
row := db.QueryRow(`SELECT count(*) FROM sessions WHERE token = "session_token"`)
var count int
err = row.Scan(&count)
if err != nil {
t.Fatal(err)
}
if count != 0 {
t.Fatalf("got %d: expected %d", count, 0)
}
}
func TestCleanup(t *testing.T) {
dsn := os.Getenv("SESSION_QL_TEST_DSN")
if dsn == "" {
dsn = "test.db"
}
db, err := sql.Open("ql-mem", dsn)
if err != nil {
t.Fatal(err)
}
defer func() {
_ = db.Close()
}()
if err = db.Ping(); err != nil {
t.Fatal(err)
}
migrate(t, db)
_, err = execTx(db, "TRUNCATE TABLE sessions")
if err != nil {
t.Fatal(err)
}
p := New(db, 200*time.Millisecond)
defer p.StopCleanup()
err = p.Save("session_token", []byte("encoded_data"), time.Now().Add(100*time.Millisecond))
if err != nil {
t.Fatal(err)
}
row := db.QueryRow(`SELECT count(*) FROM sessions WHERE token = "session_token"`)
var count int
err = row.Scan(&count)
if err != nil {
t.Fatal(err)
}
if count != 1 {
t.Fatalf("got %d: expected %d", count, 1)
}
time.Sleep(300 * time.Millisecond)
row = db.QueryRow(`SELECT count(*) FROM sessions WHERE token = "session_token"`)
err = row.Scan(&count)
if err != nil {
t.Fatal(err)
}
if count != 0 {
t.Fatalf("got %d: expected %d", count, 0)
}
}
func TestStopNilCleanup(t *testing.T) {
dsn := os.Getenv("SESSION_QL_TEST_DSN")
if dsn == "" {
dsn = "test.db"
}
db, err := sql.Open("ql-mem", dsn)
if err != nil {
t.Fatal(err)
}
defer func() {
_ = db.Close()
}()
if err = db.Ping(); err != nil {
t.Fatal(err)
}
migrate(t, db)
p := New(db, 0)
time.Sleep(100 * time.Millisecond)
// A send to a nil channel will block forever
p.StopCleanup()
}
func migrate(t *testing.T, db *sql.DB) {
_, err := execTx(db, Table())
if err != nil {
t.Error(err)
}
}

View File

@ -1,81 +0,0 @@
// Package redisstore is a Redis-based session store for the SCS session package.
//
// Warning: The redisstore API is not finalized and may change, possibly significantly.
// The package is fine to use as-is, but it is strongly recommended that you vendor
// the package to avoid compatibility problems in the future.
//
// The redisstore package relies on the the popular Redigo Redis client
// (github.com/gomodule/redigo/redis).
package redisstore
import (
"time"
"github.com/gomodule/redigo/redis"
)
// Prefix controls the Redis key prefix. You should only need to change this if there is
// a naming clash.
var Prefix = "scs:session:"
// RedisStore represents the currently configured session session store. It is essentially
// a wrapper around a Redigo connection pool.
type RedisStore struct {
pool *redis.Pool
}
// New returns a new RedisStore instance. The pool parameter should be a pointer to a
// Redigo connection pool. See https://godoc.org/github.com/garyburd/redigo/redis#Pool.
func New(pool *redis.Pool) *RedisStore {
return &RedisStore{pool}
}
// Find returns the data for a given session token from the RedisStore instance. If the session
// token is not found or is expired, the returned exists flag will be set to false.
func (r *RedisStore) Find(token string) (b []byte, exists bool, err error) {
conn := r.pool.Get()
defer conn.Close()
b, err = redis.Bytes(conn.Do("GET", Prefix+token))
if err == redis.ErrNil {
return nil, false, nil
} else if err != nil {
return nil, false, err
}
return b, true, nil
}
// Save adds a session token and data to the RedisStore instance with the given expiry time.
// If the session token already exists then the data and expiry time are updated.
func (r *RedisStore) Save(token string, b []byte, expiry time.Time) error {
conn := r.pool.Get()
defer conn.Close()
err := conn.Send("MULTI")
if err != nil {
return err
}
err = conn.Send("SET", Prefix+token, b)
if err != nil {
return err
}
err = conn.Send("PEXPIREAT", Prefix+token, makeMillisecondTimestamp(expiry))
if err != nil {
return err
}
_, err = conn.Do("EXEC")
return err
}
// Delete removes a session token and corresponding data from the ResisStore instance.
func (r *RedisStore) Delete(token string) error {
conn := r.pool.Get()
defer conn.Close()
_, err := conn.Do("DEL", Prefix+token)
return err
}
func makeMillisecondTimestamp(t time.Time) int64 {
return t.UnixNano() / (int64(time.Millisecond) / int64(time.Nanosecond))
}

View File

@ -1,225 +0,0 @@
package redisstore
import (
"bytes"
"os"
"reflect"
"testing"
"time"
"github.com/gomodule/redigo/redis"
)
func TestFind(t *testing.T) {
redisPool := redis.NewPool(func() (redis.Conn, error) {
addr := os.Getenv("SESSION_REDIS_TEST_ADDR")
conn, err := redis.Dial("tcp", addr)
if err != nil {
return nil, err
}
return conn, err
}, 1)
defer redisPool.Close()
conn := redisPool.Get()
defer conn.Close()
_, err := conn.Do("FLUSHDB")
if err != nil {
t.Fatal(err)
}
_, err = conn.Do("SET", Prefix+"session_token", "encoded_data")
if err != nil {
t.Fatal(err)
}
r := New(redisPool)
b, found, err := r.Find("session_token")
if err != nil {
t.Fatal(err)
}
if found != true {
t.Fatalf("got %v: expected %v", found, true)
}
if bytes.Equal(b, []byte("encoded_data")) == false {
t.Fatalf("got %v: expected %v", b, []byte("encoded_data"))
}
}
func TestSaveNew(t *testing.T) {
redisPool := redis.NewPool(func() (redis.Conn, error) {
addr := os.Getenv("SESSION_REDIS_TEST_ADDR")
conn, err := redis.Dial("tcp", addr)
if err != nil {
return nil, err
}
return conn, err
}, 1)
defer redisPool.Close()
conn := redisPool.Get()
defer conn.Close()
_, err := conn.Do("FLUSHDB")
if err != nil {
t.Fatal(err)
}
r := New(redisPool)
err = r.Save("session_token", []byte("encoded_data"), time.Now().Add(time.Minute))
if err != nil {
t.Fatal(err)
}
data, err := redis.Bytes(conn.Do("GET", Prefix+"session_token"))
if err != nil {
t.Fatal(err)
}
if reflect.DeepEqual(data, []byte("encoded_data")) == false {
t.Fatalf("got %v: expected %v", data, []byte("encoded_data"))
}
}
func TestFindMissing(t *testing.T) {
redisPool := redis.NewPool(func() (redis.Conn, error) {
addr := os.Getenv("SESSION_REDIS_TEST_ADDR")
conn, err := redis.Dial("tcp", addr)
if err != nil {
return nil, err
}
return conn, err
}, 1)
defer redisPool.Close()
conn := redisPool.Get()
defer conn.Close()
_, err := conn.Do("FLUSHDB")
if err != nil {
t.Fatal(err)
}
r := New(redisPool)
_, found, err := r.Find("missing_session_token")
if err != nil {
t.Fatalf("got %v: expected %v", err, nil)
}
if found != false {
t.Fatalf("got %v: expected %v", found, false)
}
}
func TestSaveUpdated(t *testing.T) {
redisPool := redis.NewPool(func() (redis.Conn, error) {
addr := os.Getenv("SESSION_REDIS_TEST_ADDR")
conn, err := redis.Dial("tcp", addr)
if err != nil {
return nil, err
}
return conn, err
}, 1)
defer redisPool.Close()
conn := redisPool.Get()
defer conn.Close()
_, err := conn.Do("FLUSHDB")
if err != nil {
t.Fatal(err)
}
_, err = conn.Do("SET", Prefix+"session_token", "encoded_data")
if err != nil {
t.Fatal(err)
}
r := New(redisPool)
err = r.Save("session_token", []byte("new_encoded_data"), time.Now().Add(time.Minute))
if err != nil {
t.Fatal(err)
}
data, err := redis.Bytes(conn.Do("GET", Prefix+"session_token"))
if err != nil {
t.Fatal(err)
}
if reflect.DeepEqual(data, []byte("new_encoded_data")) == false {
t.Fatalf("got %v: expected %v", data, []byte("new_encoded_data"))
}
}
func TestExpiry(t *testing.T) {
redisPool := redis.NewPool(func() (redis.Conn, error) {
addr := os.Getenv("SESSION_REDIS_TEST_ADDR")
conn, err := redis.Dial("tcp", addr)
if err != nil {
return nil, err
}
return conn, err
}, 1)
defer redisPool.Close()
conn := redisPool.Get()
defer conn.Close()
_, err := conn.Do("FLUSHDB")
if err != nil {
t.Fatal(err)
}
r := New(redisPool)
err = r.Save("session_token", []byte("encoded_data"), time.Now().Add(100*time.Millisecond))
if err != nil {
t.Fatal(err)
}
_, found, _ := r.Find("session_token")
if found != true {
t.Fatalf("got %v: expected %v", found, true)
}
time.Sleep(200 * time.Millisecond)
_, found, _ = r.Find("session_token")
if found != false {
t.Fatalf("got %v: expected %v", found, false)
}
}
func TestDelete(t *testing.T) {
redisPool := redis.NewPool(func() (redis.Conn, error) {
addr := os.Getenv("SESSION_REDIS_TEST_ADDR")
conn, err := redis.Dial("tcp", addr)
if err != nil {
return nil, err
}
return conn, err
}, 1)
defer redisPool.Close()
conn := redisPool.Get()
defer conn.Close()
_, err := conn.Do("FLUSHDB")
if err != nil {
t.Fatal(err)
}
_, err = conn.Do("SET", Prefix+"session_token", "encoded_data")
if err != nil {
t.Fatal(err)
}
r := New(redisPool)
err = r.Delete("session_token")
if err != nil {
t.Fatal(err)
}
data, err := conn.Do("GET", Prefix+"session_token")
if err != nil {
t.Fatal(err)
}
if data != nil {
t.Fatalf("got %v: expected %v", data, nil)
}
}