1
0
mirror of https://github.com/alexedwards/scs.git synced 2025-07-15 01:04:36 +02:00

[refactor] Deprecate Middleware

This commit is contained in:
Alex Edwards
2017-08-31 17:20:58 +02:00
parent bd7dc7a5c9
commit 688d5d4cd0
17 changed files with 1028 additions and 2070 deletions

98
manager.go Normal file
View File

@ -0,0 +1,98 @@
package scs
import (
"net/http"
"time"
"github.com/alexedwards/scs/stores/cookiestore"
)
// Manager is a session manager.
type Manager struct {
store Store
opts *options
}
// NewManager returns a pointer to a new session manager.
func NewManager(store Store) *Manager {
defaultOptions := &options{
domain: "",
httpOnly: true,
idleTimeout: 0,
lifetime: 24 * time.Hour,
path: "/",
persist: false,
secure: false,
}
return &Manager{
store: store,
opts: defaultOptions,
}
}
// Domain sets the 'Domain' attribute on the session cookie. By default it will
// be set to the domain name that the cookie was issued from.
func (m *Manager) Domain(s string) {
m.opts.domain = s
}
// HttpOnly sets the 'HttpOnly' attribute on the session cookie. The default value
// is true.
func (m *Manager) HttpOnly(b bool) {
m.opts.httpOnly = b
}
// IdleTimeout sets the maximum length of time a session can be inactive before it
// expires. For example, some applications may wish to set this so there is a timeout
// after 20 minutes of inactivity. The inactivity period is reset whenever the
// session data is changed (but not read).
//
// By default IdleTimeout is not set and there is no inactivity timeout.
func (m *Manager) IdleTimeout(t time.Duration) {
m.opts.idleTimeout = t
}
// Lifetime sets the maximum length of time that a session is valid for before
// it expires. The lifetime is an 'absolute expiry' which is set when the session
// is first created and does not change.
//
// The default value is 24 hours.
func (m *Manager) Lifetime(t time.Duration) {
m.opts.lifetime = t
}
// Path sets the 'Path' attribute on the session cookie. The default value is "/".
// Passing the empty string "" will result in it being set to the path that the
// cookie was issued from.
func (m *Manager) Path(s string) {
m.opts.path = s
}
// Persist sets whether the session cookie should be persistent or not (i.e. whether
// it should be retained after a user closes their browser).
//
// The default value is false, which means that the session cookie will be destroyed
// when the user closes their browser. If set to true, explicit 'Expires' and
// 'MaxAge' values will be added to the cookie and it will be retained by the
// user's browser until the given expiry time is reached.
func (m *Manager) Persist(b bool) {
m.opts.persist = b
}
// Secure sets the 'Secure' attribute on the session cookie. The default value
// is false. It's recommended that you set this to true and serve all requests
// over HTTPS in production environments.
func (m *Manager) Secure(b bool) {
m.opts.secure = b
}
// Load returns the session data for the current request.
func (m *Manager) Load(r *http.Request) *Session {
return load(r, m.store, m.opts)
}
func NewCookieManager(key string) *Manager {
store := cookiestore.New([]byte(key))
return NewManager(store)
}

20
options.go Normal file
View File

@ -0,0 +1,20 @@
package scs
import (
"time"
)
// CookieName changes the name of the session cookie issued to clients. Note that
// cookie names should not contain whitespace, commas, semicolons, backslashes
// or control characters as per RFC6265.
var CookieName = "session"
type options struct {
domain string
httpOnly bool
idleTimeout time.Duration
lifetime time.Duration
path string
persist bool
secure bool
}

153
options_test.go Normal file
View File

@ -0,0 +1,153 @@
package scs
import (
"strings"
"testing"
"time"
)
func TestCookieOptions(t *testing.T) {
manager := NewManager(newMockStore())
_, _, cookie := testRequest(t, testPutString(manager), "")
if strings.Contains(cookie, "Path=/") == false {
t.Fatalf("got %q: expected to contain %q", cookie, "Path=/")
}
if strings.Contains(cookie, "Domain=") == true {
t.Fatalf("got %q: expected to not contain %q", cookie, "Domain=")
}
if strings.Contains(cookie, "Secure") == true {
t.Fatalf("got %q: expected to not contain %q", cookie, "Secure")
}
if strings.Contains(cookie, "HttpOnly") == false {
t.Fatalf("got %q: expected to contain %q", cookie, "HttpOnly")
}
manager = NewManager(newMockStore())
manager.Path("/foo")
manager.Domain("example.org")
manager.Secure(true)
manager.HttpOnly(false)
manager.Lifetime(time.Hour)
manager.Persist(true)
_, _, cookie = testRequest(t, testPutString(manager), "")
if strings.Contains(cookie, "Path=/foo") == false {
t.Fatalf("got %q: expected to contain %q", cookie, "Path=/foo")
}
if strings.Contains(cookie, "Domain=example.org") == false {
t.Fatalf("got %q: expected to contain %q", cookie, "Domain=example.org")
}
if strings.Contains(cookie, "Secure") == false {
t.Fatalf("got %q: expected to contain %q", cookie, "Secure")
}
if strings.Contains(cookie, "HttpOnly") == true {
t.Fatalf("got %q: expected to not contain %q", cookie, "HttpOnly")
}
if strings.Contains(cookie, "Max-Age=3600") == false {
t.Fatalf("got %q: expected to contain %q:", cookie, "Max-Age=86400")
}
if strings.Contains(cookie, "Expires=") == false {
t.Fatalf("got %q: expected to contain %q:", cookie, "Expires")
}
manager = NewManager(newMockStore())
manager.Lifetime(time.Hour)
_, _, cookie = testRequest(t, testPutString(manager), "")
if strings.Contains(cookie, "Max-Age=") == true {
t.Fatalf("got %q: expected not to contain %q:", cookie, "Max-Age=")
}
if strings.Contains(cookie, "Expires=") == true {
t.Fatalf("got %q: expected not to contain %q:", cookie, "Expires")
}
}
func TestLifetime(t *testing.T) {
manager := NewManager(newMockStore())
manager.Lifetime(200 * time.Millisecond)
_, _, cookie := testRequest(t, testPutString(manager), "")
oldToken := extractTokenFromCookie(cookie)
time.Sleep(100 * time.Millisecond)
_, _, cookie = testRequest(t, testPutString(manager), cookie)
time.Sleep(100 * time.Millisecond)
_, body, _ := testRequest(t, testGetString(manager), cookie)
if body != "" {
t.Fatalf("got %q: expected %q", body, "")
}
_, _, cookie = testRequest(t, testPutString(manager), cookie)
newToken := extractTokenFromCookie(cookie)
if newToken == oldToken {
t.Fatalf("expected a difference")
}
}
func TestIdleTimeout(t *testing.T) {
manager := NewManager(newMockStore())
manager.IdleTimeout(100 * time.Millisecond)
manager.Lifetime(500 * time.Millisecond)
_, _, cookie := testRequest(t, testPutString(manager), "")
oldToken := extractTokenFromCookie(cookie)
time.Sleep(150 * time.Millisecond)
_, body, _ := testRequest(t, testGetString(manager), cookie)
if body != "" {
t.Fatalf("got %q: expected %q", body, "")
}
_, _, cookie = testRequest(t, testPutString(manager), cookie)
newToken := extractTokenFromCookie(cookie)
if newToken == oldToken {
t.Fatalf("expected a difference")
}
_, _, cookie = testRequest(t, testPutString(manager), "")
oldToken = extractTokenFromCookie(cookie)
time.Sleep(75 * time.Millisecond)
_, _, cookie = testRequest(t, testPutString(manager), cookie)
time.Sleep(75 * time.Millisecond)
_, body, _ = testRequest(t, testGetString(manager), cookie)
if body != "lorem ipsum" {
t.Fatalf("got %q: expected %q", body, "lorem ipsum")
}
_, _, cookie = testRequest(t, testPutString(manager), cookie)
newToken = extractTokenFromCookie(cookie)
if newToken != oldToken {
t.Fatalf("expected the same")
}
}
func TestPersist(t *testing.T) {
manager := NewManager(newMockStore())
manager.IdleTimeout(5 * time.Minute)
manager.Persist(true)
_, _, cookie := testRequest(t, testPutString(manager), "")
if strings.Contains(cookie, "Max-Age=300") == false {
t.Fatalf("got %q: expected to contain %q:", cookie, "Max-Age=300")
}
}
func TestCookieName(t *testing.T) {
oldCookieName := CookieName
CookieName = "custom_cookie_name"
manager := NewManager(newMockStore())
_, _, cookie := testRequest(t, testPutString(manager), "")
if strings.HasPrefix(cookie, "custom_cookie_name=") == false {
t.Fatalf("got %q: expected prefix %q", cookie, "custom_cookie_name=")
}
_, body, _ := testRequest(t, testGetString(manager), cookie)
if body != "lorem ipsum" {
t.Fatalf("got %q: expected %q", body, "lorem ipsum")
}
CookieName = oldCookieName
}

16
scs.go
View File

@ -1,16 +0,0 @@
/*
Package scs is a session manager for Go 1.7+.
It features:
* Automatic loading and saving of session data via middleware.
* Fast and very memory-efficient performance.
* Choice of PostgreSQL, MySQL, Redis, encrypted cookie and in-memory storage engines. Custom storage engines are also supported.
* Type-safe and sensible API. Designed to be safe for concurrent use.
* Supports OWASP good-practices, including absolute and idle session timeouts and easy regeneration of session tokens.
This top-level package is a wrapper for its sub-packages and doesn't actually
contain any code. You probably want to start by looking at the documentation
for the session sub-package.
*/
package scs

181
scs_test.go Normal file
View File

@ -0,0 +1,181 @@
package scs
import (
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func testRequest(t *testing.T, h http.Handler, cookie string) (int, string, string) {
rr := httptest.NewRecorder()
r, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatal(err)
}
if cookie != "" {
r.Header.Add("Cookie", cookie)
}
h.ServeHTTP(rr, r)
code := rr.Code
body := string(rr.Body.Bytes())
cookie = rr.Header().Get("Set-Cookie")
return code, body, cookie
}
func extractTokenFromCookie(c string) string {
parts := strings.Split(c, ";")
return strings.TrimPrefix(parts[0], fmt.Sprintf("%s=", CookieName))
}
// Test Handlers
func testPutString(manager *Manager) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
session := manager.Load(r)
err := session.PutString(w, "test_string", "lorem ipsum")
if err != nil {
http.Error(w, err.Error(), 500)
return
}
io.WriteString(w, "OK")
}
}
func testGetString(manager *Manager) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
session := manager.Load(r)
s, err := session.GetString("test_string")
if err != nil {
http.Error(w, err.Error(), 500)
return
}
io.WriteString(w, s)
}
}
func testPopString(manager *Manager) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
session := manager.Load(r)
s, err := session.PopString(w, "test_string")
if err != nil {
http.Error(w, err.Error(), 500)
return
}
io.WriteString(w, s)
}
}
func testPutBool(manager *Manager) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
session := manager.Load(r)
err := session.PutBool(w, "test_bool", true)
if err != nil {
http.Error(w, err.Error(), 500)
return
}
io.WriteString(w, "OK")
}
}
func testGetBool(manager *Manager) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
session := manager.Load(r)
b, err := session.GetBool("test_bool")
if err != nil {
http.Error(w, err.Error(), 500)
return
}
fmt.Fprintf(w, "%v", b)
}
}
func testPutObject(manager *Manager) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
session := manager.Load(r)
u := &testUser{"alice", 21}
err := session.PutObject(w, "test_object", u)
if err != nil {
http.Error(w, err.Error(), 500)
return
}
io.WriteString(w, "OK")
}
}
func testGetObject(manager *Manager) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
session := manager.Load(r)
u := new(testUser)
err := session.GetObject("test_object", u)
if err != nil {
http.Error(w, err.Error(), 500)
return
}
fmt.Fprintf(w, "%s: %d", u.Name, u.Age)
}
}
func testPopObject(manager *Manager) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
session := manager.Load(r)
u := new(testUser)
err := session.PopObject(w, "test_object", u)
if err != nil {
http.Error(w, err.Error(), 500)
return
}
fmt.Fprintf(w, "%s: %d", u.Name, u.Age)
}
}
func testDestroy(manager *Manager) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
session := manager.Load(r)
err := session.Destroy(w)
if err != nil {
http.Error(w, err.Error(), 500)
return
}
io.WriteString(w, "OK")
}
}
func testRenewToken(manager *Manager) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
session := manager.Load(r)
err := session.RenewToken(w)
if err != nil {
http.Error(w, err.Error(), 500)
return
}
io.WriteString(w, "OK")
}
}
func testClear(manager *Manager) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
session := manager.Load(r)
err := session.Clear(w)
if err != nil {
http.Error(w, err.Error(), 500)
return
}
io.WriteString(w, "OK")
}
}
func testKeys(manager *Manager) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
session := manager.Load(r)
keys, err := session.Keys()
if err != nil {
http.Error(w, err.Error(), 500)
return
}
fmt.Fprintf(w, "%v", keys)
}
}

View File

@ -1,7 +1,8 @@
package session
package scs
import (
"bytes"
"crypto/rand"
"encoding/base64"
"encoding/gob"
"encoding/json"
@ -9,6 +10,7 @@ import (
"net/http"
"sort"
"strconv"
"sync"
"time"
)
@ -16,12 +18,122 @@ import (
// received value could not be type asserted or converted into the required type.
var ErrTypeAssertionFailed = errors.New("type assertion failed")
// Session contains data for the current session.
type Session struct {
token string
data map[string]interface{}
deadline time.Time
store Store
opts *options
loadErr error
mu sync.Mutex
}
func newSession(store Store, opts *options) *Session {
return &Session{
data: make(map[string]interface{}),
deadline: time.Now().Add(opts.lifetime),
store: store,
opts: opts,
}
}
func load(r *http.Request, store Store, opts *options) *Session {
cookie, err := r.Cookie(CookieName)
if err == http.ErrNoCookie {
return newSession(store, opts)
} else if err != nil {
return &Session{loadErr: err}
}
if cookie.Value == "" {
return newSession(store, opts)
}
token := cookie.Value
j, found, err := store.Find(token)
if err != nil {
return &Session{loadErr: err}
}
if found == false {
return newSession(store, opts)
}
data, deadline, err := decodeFromJSON(j)
if err != nil {
return &Session{loadErr: err}
}
s := &Session{
token: token,
data: data,
deadline: deadline,
store: store,
opts: opts,
}
return s
}
func (s *Session) write(w http.ResponseWriter) error {
s.mu.Lock()
defer s.mu.Unlock()
j, err := encodeToJSON(s.data, s.deadline)
if err != nil {
return err
}
expiry := s.deadline
if s.opts.idleTimeout > 0 {
ie := time.Now().Add(s.opts.idleTimeout)
if ie.Before(expiry) {
expiry = ie
}
}
if ce, ok := s.store.(cookieStore); ok {
s.token, err = ce.MakeToken(j, expiry)
if err != nil {
return err
}
} else {
if s.token == "" {
s.token, err = generateToken()
if err != nil {
return err
}
}
err = s.store.Save(s.token, j, expiry)
if err != nil {
return err
}
}
cookie := &http.Cookie{
Name: CookieName,
Value: s.token,
Path: s.opts.path,
Domain: s.opts.domain,
Secure: s.opts.secure,
HttpOnly: s.opts.httpOnly,
}
if s.opts.persist == true {
// Round up expiry time to the nearest second.
cookie.Expires = time.Unix(expiry.Unix()+1, 0)
cookie.MaxAge = int(expiry.Sub(time.Now()).Seconds() + 1)
}
http.SetCookie(w, cookie)
return nil
}
// GetString returns the string value for a given key from the session data. The
// zero value for a string ("") is returned if the key does not exist. An ErrTypeAssertionFailed
// error is returned if the value could not be type asserted or converted to a
// string.
func GetString(r *http.Request, key string) (string, error) {
v, exists, err := get(r, key)
func (s *Session) GetString(key string) (string, error) {
v, exists, err := s.get(key)
if err != nil {
return "", err
}
@ -38,16 +150,16 @@ func GetString(r *http.Request, key string) (string, error) {
// PutString adds a string value and corresponding key to the the session data.
// Any existing value for the key will be replaced.
func PutString(r *http.Request, key string, val string) error {
return put(r, key, val)
func (s *Session) PutString(w http.ResponseWriter, key string, val string) error {
return s.put(w, key, val)
}
// PopString removes the string value for a given key from the session data
// and returns it. The zero value for a string ("") is returned if the key does
// not exist. An ErrTypeAssertionFailed error is returned if the value could not
// be type asserted to a string.
func PopString(r *http.Request, key string) (string, error) {
v, exists, err := pop(r, key)
func (s *Session) PopString(w http.ResponseWriter, key string) (string, error) {
v, exists, err := s.pop(w, key)
if err != nil {
return "", err
}
@ -65,8 +177,8 @@ func PopString(r *http.Request, key string) (string, error) {
// GetBool returns the bool value for a given key from the session data. The
// zero value for a bool (false) is returned if the key does not exist. An ErrTypeAssertionFailed
// error is returned if the value could not be type asserted to a bool.
func GetBool(r *http.Request, key string) (bool, error) {
v, exists, err := get(r, key)
func (s *Session) GetBool(key string) (bool, error) {
v, exists, err := s.get(key)
if err != nil {
return false, err
}
@ -83,16 +195,16 @@ func GetBool(r *http.Request, key string) (bool, error) {
// PutBool adds a bool value and corresponding key to the session data. Any existing
// value for the key will be replaced.
func PutBool(r *http.Request, key string, val bool) error {
return put(r, key, val)
func (s *Session) PutBool(w http.ResponseWriter, key string, val bool) error {
return s.put(w, key, val)
}
// PopBool removes the bool value for a given key from the session data and returns
// it. The zero value for a bool (false) is returned if the key does not exist.
// An ErrTypeAssertionFailed error is returned if the value could not be type
// asserted to a bool.
func PopBool(r *http.Request, key string) (bool, error) {
v, exists, err := pop(r, key)
func (s *Session) PopBool(w http.ResponseWriter, key string) (bool, error) {
v, exists, err := s.pop(w, key)
if err != nil {
return false, err
}
@ -110,8 +222,8 @@ func PopBool(r *http.Request, key string) (bool, error) {
// GetInt returns the int value for a given key from the session data. The zero
// value for an int (0) is returned if the key does not exist. An ErrTypeAssertionFailed
// error is returned if the value could not be type asserted or converted to a int.
func GetInt(r *http.Request, key string) (int, error) {
v, exists, err := get(r, key)
func (s *Session) GetInt(key string) (int, error) {
v, exists, err := s.get(key)
if err != nil {
return 0, err
}
@ -130,16 +242,16 @@ func GetInt(r *http.Request, key string) (int, error) {
// PutInt adds an int value and corresponding key to the session data. Any existing
// value for the key will be replaced.
func PutInt(r *http.Request, key string, val int) error {
return put(r, key, val)
func (s *Session) PutInt(w http.ResponseWriter, key string, val int) error {
return s.put(w, key, val)
}
// PopInt removes the int value for a given key from the session data and returns
// it. The zero value for an int (0) is returned if the key does not exist. An
// ErrTypeAssertionFailed error is returned if the value could not be type asserted
// or converted to a int.
func PopInt(r *http.Request, key string) (int, error) {
v, exists, err := pop(r, key)
func (s *Session) PopInt(w http.ResponseWriter, key string) (int, error) {
v, exists, err := s.pop(w, key)
if err != nil {
return 0, err
}
@ -156,13 +268,11 @@ func PopInt(r *http.Request, key string) (int, error) {
return 0, ErrTypeAssertionFailed
}
//
// GetInt64 returns the int64 value for a given key from the session data. The
// zero value for an int (0) is returned if the key does not exist. An ErrTypeAssertionFailed
// error is returned if the value could not be type asserted or converted to a int64.
func GetInt64(r *http.Request, key string) (int64, error) {
v, exists, err := get(r, key)
func (s *Session) GetInt64(key string) (int64, error) {
v, exists, err := s.get(key)
if err != nil {
return 0, err
}
@ -181,16 +291,16 @@ func GetInt64(r *http.Request, key string) (int64, error) {
// PutInt64 adds an int64 value and corresponding key to the session data. Any existing
// value for the key will be replaced.
func PutInt64(r *http.Request, key string, val int64) error {
return put(r, key, val)
func (s *Session) PutInt64(w http.ResponseWriter, key string, val int64) error {
return s.put(w, key, val)
}
// PopInt64 remvoes the int64 value for a given key from the session data
// and returns it. The zero value for an int (0) is returned if the key does not
// exist. An ErrTypeAssertionFailed error is returned if the value could not be
// type asserted or converted to a int64.
func PopInt64(r *http.Request, key string) (int64, error) {
v, exists, err := pop(r, key)
func (s *Session) PopInt64(w http.ResponseWriter, key string) (int64, error) {
v, exists, err := s.pop(w, key)
if err != nil {
return 0, err
}
@ -211,8 +321,8 @@ func PopInt64(r *http.Request, key string) (int64, error) {
// zero value for an float (0) is returned if the key does not exist. An ErrTypeAssertionFailed
// error is returned if the value could not be type asserted or converted to a
// float64.
func GetFloat(r *http.Request, key string) (float64, error) {
v, exists, err := get(r, key)
func (s *Session) GetFloat(key string) (float64, error) {
v, exists, err := s.get(key)
if err != nil {
return 0, err
}
@ -231,16 +341,16 @@ func GetFloat(r *http.Request, key string) (float64, error) {
// PutFloat adds an float64 value and corresponding key to the session data. Any
// existing value for the key will be replaced.
func PutFloat(r *http.Request, key string, val float64) error {
return put(r, key, val)
func (s *Session) PutFloat(w http.ResponseWriter, key string, val float64) error {
return s.put(w, key, val)
}
// PopFloat removes the float64 value for a given key from the session data
// and returns it. The zero value for an float (0) is returned if the key does
// not exist. An ErrTypeAssertionFailed error is returned if the value could not
// be type asserted or converted to a float64.
func PopFloat(r *http.Request, key string) (float64, error) {
v, exists, err := pop(r, key)
func (s *Session) PopFloat(w http.ResponseWriter, key string) (float64, error) {
v, exists, err := s.pop(w, key)
if err != nil {
return 0, err
}
@ -262,8 +372,8 @@ func PopFloat(r *http.Request, key string) (float64, error) {
// can be checked for with the time.IsZero method). An ErrTypeAssertionFailed
// error is returned if the value could not be type asserted or converted to a
// time.Time.
func GetTime(r *http.Request, key string) (time.Time, error) {
v, exists, err := get(r, key)
func (s *Session) GetTime(key string) (time.Time, error) {
v, exists, err := s.get(key)
if err != nil {
return time.Time{}, err
}
@ -282,8 +392,8 @@ func GetTime(r *http.Request, key string) (time.Time, error) {
// PutTime adds an time.Time value and corresponding key to the session data. Any
// existing value for the key will be replaced.
func PutTime(r *http.Request, key string, val time.Time) error {
return put(r, key, val)
func (s *Session) PutTime(w http.ResponseWriter, key string, val time.Time) error {
return s.put(w, key, val)
}
// PopTime removes the time.Time value for a given key from the session data
@ -291,8 +401,8 @@ func PutTime(r *http.Request, key string, val time.Time) error {
// does not exist (this can be checked for with the time.IsZero method). An ErrTypeAssertionFailed
// error is returned if the value could not be type asserted or converted to a
// time.Time.
func PopTime(r *http.Request, key string) (time.Time, error) {
v, exists, err := pop(r, key)
func (s *Session) PopTime(w http.ResponseWriter, key string) (time.Time, error) {
v, exists, err := s.pop(w, key)
if err != nil {
return time.Time{}, err
}
@ -313,8 +423,8 @@ func PopTime(r *http.Request, key string) (time.Time, error) {
// data. The zero value for a slice (nil) is returned if the key does not exist.
// An ErrTypeAssertionFailed error is returned if the value could not be type
// asserted or converted to []byte.
func GetBytes(r *http.Request, key string) ([]byte, error) {
v, exists, err := get(r, key)
func (s *Session) GetBytes(key string) ([]byte, error) {
v, exists, err := s.get(key)
if err != nil {
return nil, err
}
@ -333,20 +443,20 @@ func GetBytes(r *http.Request, key string) ([]byte, error) {
// PutBytes adds a byte slice ([]byte) value and corresponding key to the the
// session data. Any existing value for the key will be replaced.
func PutBytes(r *http.Request, key string, val []byte) error {
func (s *Session) PutBytes(w http.ResponseWriter, key string, val []byte) error {
if val == nil {
return errors.New("value must not be nil")
}
return put(r, key, val)
return s.put(w, key, val)
}
// PopBytes removes the byte slice ([]byte) value for a given key from the session
// data and returns it. The zero value for a slice (nil) is returned if the key
// does not exist. An ErrTypeAssertionFailed error is returned if the value could
// not be type asserted or converted to a []byte.
func PopBytes(r *http.Request, key string) ([]byte, error) {
v, exists, err := pop(r, key)
func (s *Session) PopBytes(w http.ResponseWriter, key string) ([]byte, error) {
v, exists, err := s.pop(w, key)
if err != nil {
return nil, err
}
@ -363,41 +473,14 @@ func PopBytes(r *http.Request, key string) ([]byte, error) {
return nil, ErrTypeAssertionFailed
}
/*
GetObject reads the data for a given session key into an arbitrary object
(represented by the dst parameter). It should only be used to retrieve custom
data types that have been stored using PutObject. The object represented by dst
will remain unchanged if the key does not exist.
The dst parameter must be a pointer.
Usage:
// Note that the fields on the custom type are all exported.
type User struct {
Name string
Email string
}
func getHandler(w http.ResponseWriter, r *http.Request) {
// Register the type with the encoding/gob package. Usually this would be
// done in an init() function.
gob.Register(User{})
// Initialise a pointer to a new, empty, custom object.
user := &User{}
// Read the custom object data from the session into the pointer.
err := session.GetObject(r, "user", user)
if err != nil {
http.Error(w, err.Error(), 500)
return
}
fmt.Fprintf(w, "Name: %s, Email: %s", user.Name, user.Email)
}
*/
func GetObject(r *http.Request, key string, dst interface{}) error {
b, err := GetBytes(r, key)
// GetObject reads the data for a given session key into an arbitrary object
// (represented by the dst parameter). It should only be used to retrieve custom
// data types that have been stored using PutObject. The object represented by dst
// will remain unchanged if the key does not exist.
//
// The dst parameter must be a pointer.
func (s *Session) GetObject(key string, dst interface{}) error {
b, err := s.GetBytes(key)
if err != nil {
return err
}
@ -408,45 +491,20 @@ func GetObject(r *http.Request, key string, dst interface{}) error {
return gobDecode(b, dst)
}
/*
PutObject adds an arbitrary object and corresponding key to the the session data.
Any existing value for the key will be replaced.
The val parameter must be a pointer to your object.
PutObject is typically used to store custom data types. It encodes the object
into a gob and then into a base64-encoded string which is persisted by the
storage engine. This makes PutObject (and the accompanying GetObject and PopObject
functions) comparatively expensive operations.
Because gob encoding is used, the fields on custom types must be exported in
order to be persisted correctly. Custom data types must also be registered with
gob.Register before PutObject is called (see https://golang.org/pkg/encoding/gob/#Register).
Usage:
type User struct {
Name string
Email string
}
func putHandler(w http.ResponseWriter, r *http.Request) {
// Register the type with the encoding/gob package. Usually this would be
// done in an init() function.
gob.Register(User{})
// Initialise a pointer to a new custom object.
user := &User{"Alice", "alice@example.com"}
// Store the custom object in the session data. Important: you should pass in
// a pointer to your object, not the value.
err := session.PutObject(r, "user", user)
if err != nil {
http.Error(w, err.Error(), 500)
}
}
*/
func PutObject(r *http.Request, key string, val interface{}) error {
// PutObject adds an arbitrary object and corresponding key to the the session data.
// Any existing value for the key will be replaced.
//
// The val parameter must be a pointer to your object.
//
// PutObject is typically used to store custom data types. It encodes the object
// into a gob and then into a base64-encoded string which is persisted by the
// session store. This makes PutObject (and the accompanying GetObject and PopObject
// functions) comparatively expensive operations.
//
// Because gob encoding is used, the fields on custom types must be exported in
// order to be persisted correctly. Custom data types must also be registered with
// gob.Register before PutObject is called (see https://golang.org/pkg/encoding/gob/#Register).
func (s *Session) PutObject(w http.ResponseWriter, key string, val interface{}) error {
if val == nil {
return errors.New("value must not be nil")
}
@ -456,7 +514,7 @@ func PutObject(r *http.Request, key string, val interface{}) error {
return err
}
return PutBytes(r, key, b)
return s.PutBytes(w, key, b)
}
// PopObject removes the data for a given session key and reads it into a custom
@ -465,8 +523,8 @@ func PutObject(r *http.Request, key string, val interface{}) error {
// by dst will remain unchanged if the key does not exist.
//
// The dst parameter must be a pointer.
func PopObject(r *http.Request, key string, dst interface{}) error {
b, err := PopBytes(r, key)
func (s *Session) PopObject(w http.ResponseWriter, key string, dst interface{}) error {
b, err := s.PopBytes(w, key)
if err != nil {
return err
}
@ -480,143 +538,216 @@ func PopObject(r *http.Request, key string, dst interface{}) error {
// Keys returns a slice of all key names present in the session data, sorted
// alphabetically. If the session contains no data then an empty slice will be
// returned.
func Keys(r *http.Request) ([]string, error) {
s, err := sessionFromContext(r)
if err != nil {
return nil, err
func (s *Session) Keys() ([]string, error) {
if s.loadErr != nil {
return nil, s.loadErr
}
s.mu.Lock()
defer s.mu.Unlock()
keys := make([]string, len(s.data))
i := 0
for k := range s.data {
keys[i] = k
i++
}
s.mu.Unlock()
sort.Strings(keys)
return keys, nil
}
// Exists returns true if the given key is present in the session data.
func Exists(r *http.Request, key string) (bool, error) {
s, err := sessionFromContext(r)
if err != nil {
return false, err
func (s *Session) Exists(key string) (bool, error) {
if s.loadErr != nil {
return false, s.loadErr
}
s.mu.Lock()
_, exists := s.data[key]
s.mu.Unlock()
defer s.mu.Unlock()
_, exists := s.data[key]
return exists, nil
}
// Remove deletes the given key and corresponding value from the session data.
// If the key is not present this operation is a no-op.
func Remove(r *http.Request, key string) error {
s, err := sessionFromContext(r)
if err != nil {
return err
func (s *Session) Remove(w http.ResponseWriter, key string) error {
if s.loadErr != nil {
return s.loadErr
}
s.mu.Lock()
defer s.mu.Unlock()
if s.written == true {
return ErrAlreadyWritten
}
_, exists := s.data[key]
if exists == false {
s.mu.Unlock()
return nil
}
delete(s.data, key)
s.modified = true
return nil
s.mu.Unlock()
return s.write(w)
}
// Clear removes all data for the current session. The session token and lifetime
// are unaffected. If there is no data in the current session this operation is
// a no-op.
func Clear(r *http.Request) error {
s, err := sessionFromContext(r)
if err != nil {
return err
func (s *Session) Clear(w http.ResponseWriter) error {
if s.loadErr != nil {
return s.loadErr
}
s.mu.Lock()
defer s.mu.Unlock()
if s.written == true {
return ErrAlreadyWritten
}
if len(s.data) == 0 {
s.mu.Unlock()
return nil
}
for key := range s.data {
delete(s.data, key)
}
s.modified = true
return nil
s.mu.Unlock()
return s.write(w)
}
func get(r *http.Request, key string) (interface{}, bool, error) {
s, err := sessionFromContext(r)
if err != nil {
return nil, false, err
// RenewToken creates a new session token while retaining the current session
// data. The session lifetime is also reset.
//
// The old session token and accompanying data are deleted from the session store.
//
// To mitigate the risk of session fixation attacks, it's important that you call
// RenewToken before making any changes to privilege levels (e.g. login and
// logout operations). See https://www.owasp.org/index.php/Session_fixation for
// additional information.
func (s *Session) RenewToken(w http.ResponseWriter) error {
if s.loadErr != nil {
return s.loadErr
}
s.mu.Lock()
v, exists := s.data[key]
s.mu.Unlock()
return v, exists, nil
err := s.store.Delete(s.token)
if err != nil {
s.mu.Unlock()
return err
}
func put(r *http.Request, key string, val interface{}) error {
s, err := sessionFromContext(r)
token, err := generateToken()
if err != nil {
s.mu.Unlock()
return err
}
s.token = token
s.deadline = time.Now().Add(s.opts.lifetime)
s.mu.Unlock()
return s.write(w)
}
// Destroy deletes the current session. The session token and accompanying
// data are deleted from the session store, and the client is instructed to
// delete the session cookie.
//
// Any further operations on the session in the same request cycle will result in a
// new session being created.
//
// A new empty session will be created for any client that subsequently tries
// to use the destroyed session token.
func (s *Session) Destroy(w http.ResponseWriter) error {
if s.loadErr != nil {
return s.loadErr
}
s.mu.Lock()
defer s.mu.Unlock()
err := s.store.Delete(s.token)
if err != nil {
return err
}
s.mu.Lock()
defer s.mu.Unlock()
if s.written == true {
return ErrAlreadyWritten
s.token = ""
for key := range s.data {
delete(s.data, key)
}
s.data[key] = val
s.modified = true
cookie := &http.Cookie{
Name: CookieName,
Value: "",
Path: s.opts.path,
Domain: s.opts.domain,
Secure: s.opts.secure,
HttpOnly: s.opts.httpOnly,
Expires: time.Unix(1, 0),
MaxAge: -1,
}
http.SetCookie(w, cookie)
return nil
}
func pop(r *http.Request, key string) (interface{}, bool, error) {
s, err := sessionFromContext(r)
if err != nil {
return "", false, err
func (s *Session) get(key string) (interface{}, bool, error) {
if s.loadErr != nil {
return nil, false, s.loadErr
}
s.mu.Lock()
defer s.mu.Unlock()
if s.written == true {
return "", false, ErrAlreadyWritten
v, exists := s.data[key]
return v, exists, nil
}
func (s *Session) put(w http.ResponseWriter, key string, val interface{}) error {
if s.loadErr != nil {
return s.loadErr
}
s.mu.Lock()
s.data[key] = val
s.mu.Unlock()
return s.write(w)
}
func (s *Session) pop(w http.ResponseWriter, key string) (interface{}, bool, error) {
if s.loadErr != nil {
return nil, false, s.loadErr
}
s.mu.Lock()
v, exists := s.data[key]
if exists == false {
s.mu.Unlock()
return nil, false, nil
}
delete(s.data, key)
s.modified = true
s.mu.Unlock()
err := s.write(w)
if err != nil {
return nil, false, err
}
return v, true, nil
}
func generateToken() (string, error) {
b := make([]byte, 32)
_, err := rand.Read(b)
if err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(b), nil
}
func gobEncode(v interface{}) ([]byte, error) {
buf := new(bytes.Buffer)
err := gob.NewEncoder(buf).Encode(v)
@ -630,3 +761,28 @@ func gobDecode(b []byte, dst interface{}) error {
buf := bytes.NewBuffer(b)
return gob.NewDecoder(buf).Decode(dst)
}
func encodeToJSON(data map[string]interface{}, deadline time.Time) ([]byte, error) {
return json.Marshal(&struct {
Data map[string]interface{} `json:"data"`
Deadline int64 `json:"deadline"`
}{
Data: data,
Deadline: deadline.UnixNano(),
})
}
func decodeFromJSON(j []byte) (map[string]interface{}, time.Time, error) {
aux := struct {
Data map[string]interface{} `json:"data"`
Deadline int64 `json:"deadline"`
}{}
dec := json.NewDecoder(bytes.NewReader(j))
dec.UseNumber()
err := dec.Decode(&aux)
if err != nil {
return nil, time.Time{}, err
}
return aux.Data, time.Unix(0, aux.Deadline), nil
}

View File

@ -1,359 +0,0 @@
package session
import (
"bytes"
"net/http"
"reflect"
"testing"
"time"
)
func TestString(t *testing.T) {
m := Manage(testEngine)
h := m(testServeMux)
_, body, cookie := testRequest(t, h, "/PutString", "")
if body != "OK" {
t.Fatalf("got %q: expected %q", body, "OK")
}
_, body, _ = testRequest(t, h, "/GetString", cookie)
if body != "lorem ipsum" {
t.Fatalf("got %q: expected %q", body, "lorem ipsum")
}
_, body, cookie = testRequest(t, h, "/PopString", cookie)
if body != "lorem ipsum" {
t.Fatalf("got %q: expected %q", body, "lorem ipsum")
}
_, body, _ = testRequest(t, h, "/GetString", cookie)
if body != "" {
t.Fatalf("got %q: expected %q", body, "")
}
}
func TestBool(t *testing.T) {
m := Manage(testEngine)
h := m(testServeMux)
_, body, cookie := testRequest(t, h, "/PutBool", "")
if body != "OK" {
t.Fatalf("got %q: expected %q", body, "OK")
}
_, body, _ = testRequest(t, h, "/GetBool", cookie)
if body != "true" {
t.Fatalf("got %q: expected %q", body, "true")
}
_, body, cookie = testRequest(t, h, "/PopBool", cookie)
if body != "true" {
t.Fatalf("got %q: expected %q", body, "true")
}
_, body, _ = testRequest(t, h, "/GetBool", cookie)
if body != "false" {
t.Fatalf("got %q: expected %q", body, "false")
}
}
func TestInt(t *testing.T) {
m := Manage(testEngine)
h := m(testServeMux)
_, body, cookie := testRequest(t, h, "/PutInt", "")
if body != "OK" {
t.Fatalf("got %q: expected %q", body, "OK")
}
_, body, _ = testRequest(t, h, "/GetInt", cookie)
if body != "12345" {
t.Fatalf("got %q: expected %q", body, "12345")
}
_, body, cookie = testRequest(t, h, "/PopInt", cookie)
if body != "12345" {
t.Fatalf("got %q: expected %q", body, "12345")
}
_, body, _ = testRequest(t, h, "/GetInt", cookie)
if body != "0" {
t.Fatalf("got %q: expected %q", body, "0")
}
r := requestWithSession(new(http.Request), &session{data: make(map[string]interface{})})
_ = PutInt(r, "test_int", 12345)
i, _ := GetInt(r, "test_int")
if i != 12345 {
t.Fatalf("got %d: expected %d", i, 12345)
}
}
func TestInt64(t *testing.T) {
m := Manage(testEngine)
h := m(testServeMux)
_, body, cookie := testRequest(t, h, "/PutInt64", "")
if body != "OK" {
t.Fatalf("got %q: expected %q", body, "OK")
}
_, body, _ = testRequest(t, h, "/GetInt64", cookie)
if body != "9223372036854775807" {
t.Fatalf("got %q: expected %q", body, "9223372036854775807")
}
_, body, cookie = testRequest(t, h, "/PopInt64", cookie)
if body != "9223372036854775807" {
t.Fatalf("got %q: expected %q", body, "9223372036854775807")
}
_, body, _ = testRequest(t, h, "/GetInt64", cookie)
if body != "0" {
t.Fatalf("got %q: expected %q", body, "0")
}
r := requestWithSession(new(http.Request), &session{data: make(map[string]interface{})})
_ = PutInt64(r, "test_int64", 9223372036854775807)
i, _ := GetInt64(r, "test_int64")
if i != 9223372036854775807 {
t.Fatalf("got %d: expected %d", i, 9223372036854775807)
}
}
func TestFloat(t *testing.T) {
m := Manage(testEngine)
h := m(testServeMux)
_, body, cookie := testRequest(t, h, "/PutFloat", "")
if body != "OK" {
t.Fatalf("got %q: expected %q", body, "OK")
}
_, body, _ = testRequest(t, h, "/GetFloat", cookie)
if body != "12.345" {
t.Fatalf("got %q: expected %q", body, "12.345")
}
_, body, cookie = testRequest(t, h, "/PopFloat", cookie)
if body != "12.345" {
t.Fatalf("got %q: expected %q", body, "12.345")
}
_, body, _ = testRequest(t, h, "/GetFloat", cookie)
if body != "0.000" {
t.Fatalf("got %q: expected %q", body, "0.000")
}
r := requestWithSession(new(http.Request), &session{data: make(map[string]interface{})})
_ = PutFloat(r, "test_float", 12.345)
i, _ := GetFloat(r, "test_float")
if i != 12.345 {
t.Fatalf("got %f: expected %f", i, 12.345)
}
}
func TestTime(t *testing.T) {
m := Manage(testEngine)
h := m(testServeMux)
_, body, cookie := testRequest(t, h, "/PutTime", "")
if body != "OK" {
t.Fatalf("got %q: expected %q", body, "OK")
}
_, body, _ = testRequest(t, h, "/GetTime", cookie)
if body != "10 Nov 09 23:00 UTC" {
t.Fatalf("got %q: expected %q", body, "10 Nov 09 23:00 UTC")
}
_, body, cookie = testRequest(t, h, "/PopTime", cookie)
if body != "10 Nov 09 23:00 UTC" {
t.Fatalf("got %q: expected %q", body, "10 Nov 09 23:00 UTC")
}
_, body, _ = testRequest(t, h, "/GetTime", cookie)
if body != "01 Jan 01 00:00 UTC" {
t.Fatalf("got %q: expected %q", body, "01 Jan 01 00:00 UTC")
}
r := requestWithSession(new(http.Request), &session{data: make(map[string]interface{})})
tt := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC)
_ = PutTime(r, "test_time", tt)
tm, _ := GetTime(r, "test_time")
if tm != tt {
t.Fatalf("got %v: expected %v", t, tt)
}
}
func TestBytes(t *testing.T) {
m := Manage(testEngine)
h := m(testServeMux)
_, body, cookie := testRequest(t, h, "/PutBytes", "")
if body != "OK" {
t.Fatalf("got %q: expected %q", body, "OK")
}
_, body, _ = testRequest(t, h, "/GetBytes", cookie)
if body != "lorem ipsum" {
t.Fatalf("got %q: expected %q", body, "lorem ipsum")
}
_, body, cookie = testRequest(t, h, "/PopBytes", cookie)
if body != "lorem ipsum" {
t.Fatalf("got %q: expected %q", body, "lorem ipsum")
}
_, body, _ = testRequest(t, h, "/GetBytes", cookie)
if body != "" {
t.Fatalf("got %q: expected %q", body, "")
}
r := requestWithSession(new(http.Request), &session{data: make(map[string]interface{})})
_ = PutBytes(r, "test_bytes", []byte("lorem ipsum"))
b, _ := GetBytes(r, "test_bytes")
if bytes.Equal(b, []byte("lorem ipsum")) == false {
t.Fatalf("got %v: expected %v", b, []byte("lorem ipsum"))
}
err := PutBytes(r, "test_bytes", nil)
if err == nil {
t.Fatalf("expected an error")
}
}
func TestObject(t *testing.T) {
m := Manage(testEngine)
h := m(testServeMux)
_, body, cookie := testRequest(t, h, "/PutObject", "")
if body != "OK" {
t.Fatalf("got %q: expected %q", body, "OK")
}
_, body, _ = testRequest(t, h, "/GetObject", cookie)
if body != "alice: 21" {
t.Fatalf("got %q: expected %q", body, "alice: 21")
}
_, body, cookie = testRequest(t, h, "/PopObject", cookie)
if body != "alice: 21" {
t.Fatalf("got %q: expected %q", body, "alice: 21")
}
_, body, _ = testRequest(t, h, "/GetObject", cookie)
if body != ": 0" {
t.Fatalf("got %q: expected %q", body, ": 0")
}
r := requestWithSession(new(http.Request), &session{data: make(map[string]interface{})})
u := &testUser{"bob", 65}
_ = PutObject(r, "test_object", u)
o := &testUser{}
_ = GetObject(r, "test_object", o)
if reflect.DeepEqual(u, o) == false {
t.Fatalf("got %v: expected %v", reflect.DeepEqual(u, o), false)
}
}
func TestKeys(t *testing.T) {
m := Manage(testEngine)
h := m(testServeMux)
_, _, cookie := testRequest(t, h, "/PutString", "")
_, _, _ = testRequest(t, h, "/PutBool", cookie)
_, body, _ := testRequest(t, h, "/Keys", cookie)
if body != "[test_bool test_string]" {
t.Fatalf("got %q: expected %q", body, "[test_bool test_string]")
}
_, _, _ = testRequest(t, h, "/Clear", cookie)
_, body, _ = testRequest(t, h, "/Keys", cookie)
if body != "[]" {
t.Fatalf("got %q: expected %q", body, "[test_bool test_string]")
}
}
func TestExists(t *testing.T) {
m := Manage(testEngine)
h := m(testServeMux)
_, _, cookie := testRequest(t, h, "/PutString", "")
_, body, _ := testRequest(t, h, "/Exists", cookie)
if body != "true" {
t.Fatalf("got %q: expected %q", body, "true")
}
_, body, _ = testRequest(t, h, "/PopString", cookie)
_, body, _ = testRequest(t, h, "/Exists", cookie)
if body != "false" {
t.Fatalf("got %q: expected %q", body, "false")
}
}
func TestRemove(t *testing.T) {
m := Manage(testEngine)
h := m(testServeMux)
_, _, cookie := testRequest(t, h, "/PutString", "")
_, _, cookie = testRequest(t, h, "/PutBool", cookie)
_, body, cookie := testRequest(t, h, "/RemoveString", cookie)
if body != "OK" {
t.Fatalf("got %q: expected %q", body, "OK")
}
_, body, _ = testRequest(t, h, "/GetString", cookie)
if body != "" {
t.Fatalf("got %q: expected %q", body, "")
}
_, body, _ = testRequest(t, h, "/GetBool", cookie)
if body != "true" {
t.Fatalf("got %q: expected %q", body, "true")
}
_, _, cookie = testRequest(t, h, "/RemoveString", cookie)
if cookie != "" {
t.Fatalf("got %q: expected %q", cookie, "")
}
}
func TestClear(t *testing.T) {
m := Manage(testEngine)
h := m(testServeMux)
_, _, cookie := testRequest(t, h, "/PutString", "")
_, _, cookie = testRequest(t, h, "/PutBool", cookie)
_, body, cookie := testRequest(t, h, "/Clear", cookie)
if body != "OK" {
t.Fatalf("got %q: expected %q", body, "OK")
}
_, body, _ = testRequest(t, h, "/GetString", cookie)
if body != "" {
t.Fatalf("got %q: expected %q", body, "")
}
_, body, _ = testRequest(t, h, "/GetBool", cookie)
if body != "false" {
t.Fatalf("got %q: expected %q", body, "false")
}
_, _, cookie = testRequest(t, h, "/Clear", cookie)
if cookie != "" {
t.Fatalf("got %q: expected %q", cookie, "")
}
}

View File

@ -1,115 +0,0 @@
package session
import (
"bufio"
"bytes"
"log"
"net"
"net/http"
)
// Deprecated: Middleware previously defined the signature for the session management
// middleware returned by Manage. Manage now returns a func(h http.Handler) http.Handler
// directly instead, so it's easier to use with middleware chaining packages like Alice.
type Middleware func(h http.Handler) http.Handler
/*
Manage returns a new session manager middleware instance. The first parameter
should be a valid storage engine, followed by zero or more functional options.
For example:
session.Manage(memstore.New(0))
session.Manage(memstore.New(0), session.Lifetime(14*24*time.Hour))
session.Manage(memstore.New(0),
session.Secure(true),
session.Persist(true),
session.Lifetime(14*24*time.Hour),
)
The returned session manager can be used to wrap any http.Handler. It automatically
loads sessions based on the HTTP request and saves session data as and when necessary.
*/
func Manage(engine Engine, opts ...Option) func(h http.Handler) http.Handler {
return func(h http.Handler) http.Handler {
do := *defaultOptions
m := &manager{
h: h,
engine: engine,
opts: &do,
}
for _, option := range opts {
option(m.opts)
}
return m
}
}
type manager struct {
h http.Handler
engine Engine
opts *options
}
func (m *manager) ServeHTTP(w http.ResponseWriter, r *http.Request) {
sr, err := load(r, m.engine, m.opts)
if err != nil {
m.opts.errorFunc(w, r, err)
return
}
bw := &bufferedResponseWriter{ResponseWriter: w}
m.h.ServeHTTP(bw, sr)
err = write(w, sr)
if err != nil {
m.opts.errorFunc(w, r, err)
return
}
if bw.code != 0 {
w.WriteHeader(bw.code)
}
_, err = w.Write(bw.buf.Bytes())
if err != nil {
m.opts.errorFunc(w, r, err)
}
}
type bufferedResponseWriter struct {
http.ResponseWriter
buf bytes.Buffer
code int
}
func (bw *bufferedResponseWriter) Write(b []byte) (int, error) {
return bw.buf.Write(b)
}
func (bw *bufferedResponseWriter) WriteHeader(code int) {
bw.code = code
}
func (bw *bufferedResponseWriter) Flush() {
f, ok := bw.ResponseWriter.(http.Flusher)
if ok == true {
bw.ResponseWriter.Write(bw.buf.Bytes())
f.Flush()
bw.buf.Reset()
}
}
func (bw *bufferedResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
hj := bw.ResponseWriter.(http.Hijacker)
return hj.Hijack()
}
func defaultErrorFunc(w http.ResponseWriter, r *http.Request, err error) {
log.Output(2, err.Error())
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}

View File

@ -1,61 +0,0 @@
package session
import (
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestWriteResponse(t *testing.T) {
m := Manage(testEngine)
h := m(testServeMux)
code, _, _ := testRequest(t, h, "/WriteHeader", "")
if code != 418 {
t.Fatalf("got %d: expected %d", code, 418)
}
}
func TestManagerOptionsLeak(t *testing.T) {
_ = Manage(testEngine, Domain("example.org"))
m := Manage(testEngine)
h := m(testServeMux)
_, _, cookie := testRequest(t, h, "/PutString", "")
if strings.Contains(cookie, "example.org") == true {
t.Fatalf("got %q: expected to not contain %q", cookie, "example.org")
}
}
func TestFlusher(t *testing.T) {
e := testEngine
m := Manage(e)
h := m(testServeMux)
rr := httptest.NewRecorder()
r, err := http.NewRequest("GET", "/Flush", nil)
if err != nil {
t.Fatal(err)
}
h.ServeHTTP(rr, r)
body := string(rr.Body.Bytes())
cookie := rr.Header().Get("Set-Cookie")
token := extractTokenFromCookie(cookie)
if body != "This is some…flushed data" {
t.Fatalf("got %q: expected %q", body, "This is some…flushed data")
}
if len(rr.Header()["Set-Cookie"]) != 1 {
t.Fatalf("got %d: expected %d", len(rr.Header()["Set-Cookie"]), 1)
}
if strings.HasPrefix(cookie, fmt.Sprintf("%s=", CookieName)) == false {
t.Fatalf("got %q: expected prefix %q", cookie, fmt.Sprintf("%s=", CookieName))
}
_, found, _ := e.Find(token)
if found != true {
t.Fatalf("got %v: expected %v", found, true)
}
}

View File

@ -1,32 +0,0 @@
package session
import (
"net/http"
"time"
)
func NewMockRequest(r *http.Request) *http.Request {
do := *defaultOptions
s := &session{
token: "",
data: make(map[string]interface{}),
deadline: time.Now().Add(do.lifetime),
engine: &mockEngine{},
opts: &do,
}
return requestWithSession(r, s)
}
type mockEngine struct{}
func (me *mockEngine) Find(token string) (b []byte, exists bool, err error) {
return nil, false, nil
}
func (me *mockEngine) Save(token string, b []byte, expiry time.Time) error {
return nil
}
func (me *mockEngine) Delete(token string) error {
return nil
}

View File

@ -1,123 +0,0 @@
package session
import (
"net/http"
"time"
)
// ContextName changes the value of the (string) key used to store the session
// information in Request.Context. You should only need to change this if there is
// a naming clash.
var ContextName = "scs.session"
// CookieName changes the name of the session cookie issued to clients. Note that
// cookie names should not contain whitespace, commas, semicolons, backslashes
// or control characters as per RFC6265.
var CookieName = "scs.session.token"
var defaultOptions = &options{
domain: "",
errorFunc: defaultErrorFunc,
httpOnly: true,
idleTimeout: 0,
lifetime: 24 * time.Hour,
path: "/",
persist: false,
secure: false,
}
type options struct {
domain string
errorFunc func(http.ResponseWriter, *http.Request, error)
httpOnly bool
idleTimeout time.Duration
lifetime time.Duration
path string
persist bool
secure bool
}
// Option defines the functional arguments for configuring the session manager.
type Option func(*options)
// Domain sets the 'Domain' attribute on the session cookie. By default it will
// be set to the domain name that the cookie was issued from.
func Domain(s string) Option {
return func(opts *options) {
opts.domain = s
}
}
// ErrorFunc allows you to control behavior when an error is encountered loading
// or writing a session. The default behavior is for a HTTP 500 status code to
// be written to the ResponseWriter along with the plain-text error string. If
// a custom error function is set, then control will be passed to this instead.
// A typical use would be to provide a function which logs the error and returns
// a customized HTML error page.
func ErrorFunc(f func(http.ResponseWriter, *http.Request, error)) Option {
return func(opts *options) {
opts.errorFunc = f
}
}
// HttpOnly sets the 'HttpOnly' attribute on the session cookie. The default value
// is true.
func HttpOnly(b bool) Option {
return func(opts *options) {
opts.httpOnly = b
}
}
// IdleTimeout sets the maximum length of time a session can be inactive before it
// expires. For example, some applications may wish to set this so there is a timeout after
// 20 minutes of inactivity. Any client request which includes the
// session cookie and is handled by the session middleware is classed as activity.s
//
// By default IdleTimeout is not set and there is no inactivity timeout.
func IdleTimeout(t time.Duration) Option {
return func(opts *options) {
opts.idleTimeout = t
}
}
// Lifetime sets the maximum length of time that a session is valid for before
// it expires. The lifetime is an 'absolute expiry' which is set when the session
// is first created and does not change.
//
// The default value is 24 hours.
func Lifetime(t time.Duration) Option {
return func(opts *options) {
opts.lifetime = t
}
}
// Path sets the 'Path' attribute on the session cookie. The default value is "/".
// Passing the empty string "" will result in it being set to the path that the
// cookie was issued from.
func Path(s string) Option {
return func(opts *options) {
opts.path = s
}
}
// Persist sets whether the session cookie should be persistent or not (i.e. whether
// it should be retained after a user closes their browser).
//
// The default value is false, which means that the session cookie will be destroyed
// when the user closes their browser. If set to true, explicit 'Expires' and
// 'MaxAge' values will be added to the cookie and it will be retained by the
// user's browser until the given expiry time is reached.
func Persist(b bool) Option {
return func(opts *options) {
opts.persist = b
}
}
// Secure sets the 'Secure' attribute on the session cookie. The default value
// is false. It's recommended that you set this to true and serve all requests
// over HTTPS in production environments.
func Secure(b bool) Option {
return func(opts *options) {
opts.secure = b
}
}

View File

@ -1,208 +0,0 @@
package session
import (
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
)
func TestCookieOptions(t *testing.T) {
m := Manage(testEngine)
h := m(testServeMux)
_, _, cookie := testRequest(t, h, "/PutString", "")
if strings.Contains(cookie, "Path=/") == false {
t.Fatalf("got %q: expected to contain %q", cookie, "Path=/")
}
if strings.Contains(cookie, "Domain=") == true {
t.Fatalf("got %q: expected to not contain %q", cookie, "Domain=")
}
if strings.Contains(cookie, "Secure") == true {
t.Fatalf("got %q: expected to not contain %q", cookie, "Secure")
}
if strings.Contains(cookie, "HttpOnly") == false {
t.Fatalf("got %q: expected to contain %q", cookie, "HttpOnly")
}
m = Manage(testEngine, Path("/foo"), Domain("example.org"), Secure(true), HttpOnly(false), Lifetime(time.Hour), Persist(true))
h = m(testServeMux)
_, _, cookie = testRequest(t, h, "/PutString", "")
if strings.Contains(cookie, "Path=/foo") == false {
t.Fatalf("got %q: expected to contain %q", cookie, "Path=/foo")
}
if strings.Contains(cookie, "Domain=example.org") == false {
t.Fatalf("got %q: expected to contain %q", cookie, "Domain=example.org")
}
if strings.Contains(cookie, "Secure") == false {
t.Fatalf("got %q: expected to contain %q", cookie, "Secure")
}
if strings.Contains(cookie, "HttpOnly") == true {
t.Fatalf("got %q: expected to not contain %q", cookie, "HttpOnly")
}
if strings.Contains(cookie, "Max-Age=3600") == false {
t.Fatalf("got %q: expected to contain %q:", cookie, "Max-Age=86400")
}
if strings.Contains(cookie, "Expires=") == false {
t.Fatalf("got %q: expected to contain %q:", cookie, "Expires")
}
m = Manage(testEngine, Lifetime(time.Hour))
h = m(testServeMux)
_, _, cookie = testRequest(t, h, "/PutString", "")
if strings.Contains(cookie, "Max-Age=") == true {
t.Fatalf("got %q: expected not to contain %q:", cookie, "Max-Age=")
}
if strings.Contains(cookie, "Expires=") == true {
t.Fatalf("got %q: expected not to contain %q:", cookie, "Expires")
}
}
func TestLifetime(t *testing.T) {
m := Manage(testEngine, Lifetime(200*time.Millisecond))
h := m(testServeMux)
_, _, cookie := testRequest(t, h, "/PutString", "")
oldToken := extractTokenFromCookie(cookie)
time.Sleep(100 * time.Millisecond)
_, _, cookie = testRequest(t, h, "/PutString", cookie)
time.Sleep(100 * time.Millisecond)
_, body, _ := testRequest(t, h, "/GetString", cookie)
if body != "" {
t.Fatalf("got %q: expected %q", body, "")
}
_, _, cookie = testRequest(t, h, "/PutString", cookie)
newToken := extractTokenFromCookie(cookie)
if newToken == oldToken {
t.Fatalf("expected a difference")
}
}
func TestIdleTimeout(t *testing.T) {
m := Manage(testEngine, IdleTimeout(100*time.Millisecond), Lifetime(500*time.Millisecond))
h := m(testServeMux)
_, _, cookie := testRequest(t, h, "/PutString", "")
oldToken := extractTokenFromCookie(cookie)
time.Sleep(150 * time.Millisecond)
_, body, _ := testRequest(t, h, "/GetString", cookie)
if body != "" {
t.Fatalf("got %q: expected %q", body, "")
}
_, _, cookie = testRequest(t, h, "/PutString", cookie)
newToken := extractTokenFromCookie(cookie)
if newToken == oldToken {
t.Fatalf("expected a difference")
}
_, _, cookie = testRequest(t, h, "/PutString", "")
oldToken = extractTokenFromCookie(cookie)
time.Sleep(75 * time.Millisecond)
_, _, cookie = testRequest(t, h, "/GetString", cookie)
time.Sleep(75 * time.Millisecond)
_, body, cookie = testRequest(t, h, "/GetString", cookie)
if body != "lorem ipsum" {
t.Fatalf("got %q: expected %q", body, "lorem ipsum")
}
_, _, cookie = testRequest(t, h, "/PutString", cookie)
newToken = extractTokenFromCookie(cookie)
if newToken != oldToken {
t.Fatalf("expected the same")
}
}
func TestErrorFunc(t *testing.T) {
m := Manage(testEngine)
man, ok := m(nil).(*manager)
if ok == false {
t.Fatal("type assertion failed")
}
rr := httptest.NewRecorder()
man.opts.errorFunc(rr, new(http.Request), errors.New("testing error log..."))
if rr.Code != http.StatusInternalServerError {
t.Fatalf("got %d: expected %d", rr.Code, http.StatusInternalServerError)
}
if string(rr.Body.Bytes()) != fmt.Sprintf("%s\n", http.StatusText(http.StatusInternalServerError)) {
t.Fatalf("got %q: expected %q", string(rr.Body.Bytes()), fmt.Sprintf("%s\n", http.StatusText(http.StatusInternalServerError)))
}
m = Manage(testEngine, ErrorFunc(func(w http.ResponseWriter, r *http.Request, err error) {
w.WriteHeader(418)
io.WriteString(w, http.StatusText(418))
}))
man, ok = m(nil).(*manager)
if ok == false {
t.Fatal("type assertion failed")
}
rr = httptest.NewRecorder()
man.opts.errorFunc(rr, new(http.Request), errors.New("testing error log..."))
if rr.Code != 418 {
t.Fatalf("got %d: expected %d", rr.Code, 418)
}
if string(rr.Body.Bytes()) != http.StatusText(418) {
t.Fatalf("got %q: expected %q", string(rr.Body.Bytes()), http.StatusText(418))
}
}
func TestPersist(t *testing.T) {
m := Manage(testEngine, IdleTimeout(5*time.Minute), Persist(true))
h := m(testServeMux)
_, _, cookie := testRequest(t, h, "/PutString", "")
if strings.Contains(cookie, "Max-Age=300") == false {
t.Fatalf("got %q: expected to contain %q:", cookie, "Max-Age=300")
}
}
func TestCookieName(t *testing.T) {
oldCookieName := CookieName
CookieName = "custom_cookie_name"
m := Manage(testEngine)
h := m(testServeMux)
_, _, cookie := testRequest(t, h, "/PutString", "")
if strings.HasPrefix(cookie, "custom_cookie_name=") == false {
t.Fatalf("got %q: expected prefix %q", cookie, "custom_cookie_name=")
}
_, body, _ := testRequest(t, h, "/GetString", cookie)
if body != "lorem ipsum" {
t.Fatalf("got %q: expected %q", body, "lorem ipsum")
}
CookieName = oldCookieName
}
func TestContextDataName(t *testing.T) {
oldContextName := ContextName
ContextName = "custom_context_name"
m := Manage(testEngine)
h := m(testServeMux)
_, _, cookie := testRequest(t, h, "/PutString", "")
_, body, _ := testRequest(t, h, "/DumpContext", cookie)
if strings.Contains(body, "custom_context_name") == false {
t.Fatalf("got %q: expected to contain %q", body, "custom_context_name")
}
_, body, _ = testRequest(t, h, "/GetString", cookie)
if body != "lorem ipsum" {
t.Fatalf("got %q: expected %q", body, "lorem ipsum")
}
ContextName = oldContextName
}

View File

@ -1,456 +0,0 @@
/*
Package session provides session management middleware and helpers for
manipulating session data.
It should be installed alongside one of the storage engines from https://godoc.org/github.com/alexedwards/scs/engine.
For example:
$ go get github.com/alexedwards/scs/session
$ go get github.com/alexedwards/scs/engine/memstore
Basic use:
package main
import (
"io"
"net/http"
"github.com/alexedwards/scs/engine/memstore"
"github.com/alexedwards/scs/session"
)
func main() {
// Initialise a new storage engine. Here we use the memstore package, but the principles
// are the same no matter which back-end store you choose.
engine := memstore.New(0)
// Initialise the session manager middleware, passing in the storage engine as
// the first parameter. This middleware will automatically handle loading and
// saving of session data for you.
sessionManager := session.Manage(engine)
// Set up your HTTP handlers in the normal way.
mux := http.NewServeMux()
mux.HandleFunc("/put", putHandler)
mux.HandleFunc("/get", getHandler)
// Wrap your handlers with the session manager middleware.
http.ListenAndServe(":4000", sessionManager(mux))
}
func putHandler(w http.ResponseWriter, r *http.Request) {
// Use the PutString helper to store a new key and associated string value in
// the session data. Helpers are also available for bool, int, int64, float,
// time.Time and []byte data types.
err := session.PutString(r, "message", "Hello from a session!")
if err != nil {
http.Error(w, err.Error(), 500)
}
}
func getHandler(w http.ResponseWriter, r *http.Request) {
// Use the GetString helper to retrieve the string value associated with a key.
msg, err := session.GetString(r, "message")
if err != nil {
http.Error(w, err.Error(), 500)
return
}
io.WriteString(w, msg)
}
*/
package session
import (
"bytes"
"context"
"crypto/rand"
"encoding/base64"
"encoding/json"
"errors"
"net/http"
"sync"
"time"
)
// ErrAlreadyWritten is returned when an attempt is made to modify the session
// data after it has already been sent to the storage engine and client.
var ErrAlreadyWritten = errors.New("session already written to the engine and http.ResponseWriter")
type session struct {
token string
data map[string]interface{}
deadline time.Time
engine Engine
opts *options
modified bool
written bool
mu sync.Mutex
}
func generateToken() (string, error) {
b := make([]byte, 32)
_, err := rand.Read(b)
if err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(b), nil
}
func newSession(r *http.Request, engine Engine, opts *options) (*http.Request, error) {
token, err := generateToken()
if err != nil {
return nil, err
}
s := &session{
token: token,
data: make(map[string]interface{}),
deadline: time.Now().Add(opts.lifetime),
engine: engine,
opts: opts,
}
return requestWithSession(r, s), nil
}
func load(r *http.Request, engine Engine, opts *options) (*http.Request, error) {
cookie, err := r.Cookie(CookieName)
if err == http.ErrNoCookie {
return newSession(r, engine, opts)
} else if err != nil {
return nil, err
}
if cookie.Value == "" {
return newSession(r, engine, opts)
}
token := cookie.Value
j, found, err := engine.Find(token)
if err != nil {
return nil, err
}
if found == false {
return newSession(r, engine, opts)
}
data, deadline, err := decodeFromJSON(j)
if err != nil {
return nil, err
}
s := &session{
token: token,
data: data,
deadline: deadline,
engine: engine,
opts: opts,
}
return requestWithSession(r, s), nil
}
func write(w http.ResponseWriter, r *http.Request) error {
s, err := sessionFromContext(r)
if err != nil {
return err
}
s.mu.Lock()
defer s.mu.Unlock()
if s.written == true {
return nil
}
if s.modified == false && s.opts.idleTimeout == 0 {
return nil
}
j, err := encodeToJSON(s.data, s.deadline)
if err != nil {
return err
}
expiry := s.deadline
if s.opts.idleTimeout > 0 {
ie := time.Now().Add(s.opts.idleTimeout)
if ie.Before(expiry) {
expiry = ie
}
}
if ce, ok := s.engine.(cookieEngine); ok {
s.token, err = ce.MakeToken(j, expiry)
if err != nil {
return err
}
} else {
err = s.engine.Save(s.token, j, expiry)
if err != nil {
return err
}
}
cookie := &http.Cookie{
Name: CookieName,
Value: s.token,
Path: s.opts.path,
Domain: s.opts.domain,
Secure: s.opts.secure,
HttpOnly: s.opts.httpOnly,
}
if s.opts.persist == true {
cookie.Expires = expiry
// The addition of 0.5 means MaxAge is correctly rounded to the nearest
// second instead of being floored.
cookie.MaxAge = int(expiry.Sub(time.Now()).Seconds() + 0.5)
}
http.SetCookie(w, cookie)
s.written = true
return nil
}
/*
RegenerateToken creates a new session token while retaining the current session
data. The session lifetime is also reset.
The old session token and accompanying data are deleted from the storage engine.
To mitigate the risk of session fixation attacks, it's important that you call
RegenerateToken before making any changes to privilege levels (e.g. login and
logout operations). See https://www.owasp.org/index.php/Session_fixation for
additional information.
Usage:
func loginHandler(w http.ResponseWriter, r *http.Request) {
userID := 123
// First regenerate the session token…
err := session.RegenerateToken(r)
if err != nil {
http.Error(w, err.Error(), 500)
return
}
// Then make the privilege-level change.
err = session.PutInt(r, "userID", userID)
if err != nil {
http.Error(w, err.Error(), 500)
return
}
}
*/
func RegenerateToken(r *http.Request) error {
s, err := sessionFromContext(r)
if err != nil {
return err
}
s.mu.Lock()
defer s.mu.Unlock()
if s.written == true {
return ErrAlreadyWritten
}
err = s.engine.Delete(s.token)
if err != nil {
return err
}
token, err := generateToken()
if err != nil {
return err
}
s.token = token
s.deadline = time.Now().Add(s.opts.lifetime)
s.modified = true
return nil
}
// Renew creates a new session token and removes all data for the session. The
// session lifetime is also reset.
//
// The old session token and accompanying data are deleted from the storage engine.
//
// The Renew function is essentially a concurrency-safe amalgamation of the
// RegenerateToken and Clear functions.
func Renew(r *http.Request) error {
s, err := sessionFromContext(r)
if err != nil {
return err
}
s.mu.Lock()
defer s.mu.Unlock()
if s.written == true {
return ErrAlreadyWritten
}
err = s.engine.Delete(s.token)
if err != nil {
return err
}
token, err := generateToken()
if err != nil {
return err
}
s.token = token
for key := range s.data {
delete(s.data, key)
}
s.deadline = time.Now().Add(s.opts.lifetime)
s.modified = true
return nil
}
// Destroy deletes the current session. The session token and accompanying
// data are deleted from the storage engine, and the client is instructed to
// delete the session cookie.
//
// Destroy operations are effective immediately, and any further operations on
// the session in the same request cycle will return an ErrAlreadyWritten error.
// If you see this error you probably want to use the Renew function instead.
//
// A new empty session will be created for any client that subsequently tries
// to use the destroyed session token.
func Destroy(w http.ResponseWriter, r *http.Request) error {
s, err := sessionFromContext(r)
if err != nil {
return err
}
s.mu.Lock()
defer s.mu.Unlock()
if s.written == true {
return ErrAlreadyWritten
}
err = s.engine.Delete(s.token)
if err != nil {
return err
}
s.token = ""
for key := range s.data {
delete(s.data, key)
}
s.modified = true
cookie := &http.Cookie{
Name: CookieName,
Value: "",
Path: s.opts.path,
Domain: s.opts.domain,
Secure: s.opts.secure,
HttpOnly: s.opts.httpOnly,
Expires: time.Unix(1, 0),
MaxAge: -1,
}
http.SetCookie(w, cookie)
s.written = true
return nil
}
// Save immediately writes the session cookie header to the ResponseWriter and
// saves the session data to the storage engine, if needed.
//
// Using Save is not normally necessary. The session middleware (which buffers
// all writes to the underlying connection) will automatically handle setting the
// cookie header and storing the data for you.
//
// However there may be instances where you wish to break out of this normal
// operation and (one way or another) write to the underlying connection before
// control is passed back to the session middleware. In these instances, where
// response headers have already been written, the middleware will be too late
// to set the cookie header. The solution is to manually call Save before performing
// any writes.
//
// An example is flushing data using the http.Flusher interface:
//
// func flushingHandler(w http.ResponseWriter, r *http.Request) {
// err := session.PutString(r, "foo", "bar")
// if err != nil {
// http.Error(w, err.Error(), 500)
// return
// }
// err = session.Save(w, r)
// if err != nil {
// http.Error(w, err.Error(), 500)
// return
// }
//
// fw, ok := w.(http.Flusher)
// if !ok {
// http.Error(w, "could not assert to http.Flusher", 500)
// return
// }
// w.Write([]byte("This is some…"))
// fw.Flush()
// w.Write([]byte("flushed data"))
// }
func Save(w http.ResponseWriter, r *http.Request) error {
s, err := sessionFromContext(r)
if err != nil {
return err
}
s.mu.Lock()
wr := s.written
s.mu.Unlock()
if wr == true {
return ErrAlreadyWritten
}
return write(w, r)
}
func sessionFromContext(r *http.Request) (*session, error) {
s, ok := r.Context().Value(ContextName).(*session)
if ok == false {
return nil, errors.New("request.Context does not contain a *session value")
}
return s, nil
}
func requestWithSession(r *http.Request, s *session) *http.Request {
ctx := context.WithValue(r.Context(), ContextName, s)
return r.WithContext(ctx)
}
func encodeToJSON(data map[string]interface{}, deadline time.Time) ([]byte, error) {
return json.Marshal(&struct {
Data map[string]interface{} `json:"data"`
Deadline int64 `json:"deadline"`
}{
Data: data,
Deadline: deadline.UnixNano(),
})
}
func decodeFromJSON(j []byte) (map[string]interface{}, time.Time, error) {
aux := struct {
Data map[string]interface{} `json:"data"`
Deadline int64 `json:"deadline"`
}{}
dec := json.NewDecoder(bytes.NewReader(j))
dec.UseNumber()
err := dec.Decode(&aux)
if err != nil {
return nil, time.Time{}, err
}
return aux.Data, time.Unix(0, aux.Deadline), nil
}

View File

@ -1,516 +0,0 @@
package session
import (
"encoding/gob"
"fmt"
"io"
"net/http"
"net/http/httptest"
"regexp"
"strings"
"testing"
"time"
"github.com/alexedwards/scs/engine/memstore"
)
var testEngine Engine
var testServeMux *http.ServeMux
type testUser struct {
Name string
Age int
}
func init() {
gob.Register(new(testUser))
testEngine = memstore.New(time.Minute)
testServeMux = http.NewServeMux()
testServeMux.HandleFunc("/PutString", func(w http.ResponseWriter, r *http.Request) {
err := PutString(r, "test_string", "lorem ipsum")
if err != nil {
io.WriteString(w, err.Error())
return
}
io.WriteString(w, "OK")
})
testServeMux.HandleFunc("/GetString", func(w http.ResponseWriter, r *http.Request) {
s, err := GetString(r, "test_string")
if err != nil {
io.WriteString(w, err.Error())
return
}
io.WriteString(w, s)
})
testServeMux.HandleFunc("/PopString", func(w http.ResponseWriter, r *http.Request) {
s, err := PopString(r, "test_string")
if err != nil {
io.WriteString(w, err.Error())
return
}
io.WriteString(w, s)
})
testServeMux.HandleFunc("/PutBool", func(w http.ResponseWriter, r *http.Request) {
err := PutBool(r, "test_bool", true)
if err != nil {
io.WriteString(w, err.Error())
return
}
io.WriteString(w, "OK")
})
testServeMux.HandleFunc("/GetBool", func(w http.ResponseWriter, r *http.Request) {
b, err := GetBool(r, "test_bool")
if err != nil {
io.WriteString(w, err.Error())
return
}
fmt.Fprintf(w, "%v", b)
})
testServeMux.HandleFunc("/PopBool", func(w http.ResponseWriter, r *http.Request) {
b, err := PopBool(r, "test_bool")
if err != nil {
io.WriteString(w, err.Error())
return
}
fmt.Fprintf(w, "%v", b)
})
testServeMux.HandleFunc("/PutInt", func(w http.ResponseWriter, r *http.Request) {
err := PutInt(r, "test_int", 12345)
if err != nil {
io.WriteString(w, err.Error())
return
}
io.WriteString(w, "OK")
})
testServeMux.HandleFunc("/GetInt", func(w http.ResponseWriter, r *http.Request) {
i, err := GetInt(r, "test_int")
if err != nil {
io.WriteString(w, err.Error())
return
}
fmt.Fprintf(w, "%d", i)
})
testServeMux.HandleFunc("/PopInt", func(w http.ResponseWriter, r *http.Request) {
i, err := PopInt(r, "test_int")
if err != nil {
io.WriteString(w, err.Error())
return
}
fmt.Fprintf(w, "%d", i)
})
testServeMux.HandleFunc("/PutInt64", func(w http.ResponseWriter, r *http.Request) {
err := PutInt64(r, "test_int", 9223372036854775807)
if err != nil {
io.WriteString(w, err.Error())
return
}
io.WriteString(w, "OK")
})
testServeMux.HandleFunc("/GetInt64", func(w http.ResponseWriter, r *http.Request) {
i, err := GetInt64(r, "test_int")
if err != nil {
io.WriteString(w, err.Error())
return
}
fmt.Fprintf(w, "%d", i)
})
testServeMux.HandleFunc("/PopInt64", func(w http.ResponseWriter, r *http.Request) {
i, err := PopInt64(r, "test_int")
if err != nil {
io.WriteString(w, err.Error())
return
}
fmt.Fprintf(w, "%d", i)
})
testServeMux.HandleFunc("/PutFloat", func(w http.ResponseWriter, r *http.Request) {
err := PutFloat(r, "test_float", 12.345)
if err != nil {
io.WriteString(w, err.Error())
return
}
io.WriteString(w, "OK")
})
testServeMux.HandleFunc("/GetFloat", func(w http.ResponseWriter, r *http.Request) {
f, err := GetFloat(r, "test_float")
if err != nil {
io.WriteString(w, err.Error())
return
}
fmt.Fprintf(w, "%.3f", f)
})
testServeMux.HandleFunc("/PopFloat", func(w http.ResponseWriter, r *http.Request) {
f, err := PopFloat(r, "test_float")
if err != nil {
io.WriteString(w, err.Error())
return
}
fmt.Fprintf(w, "%.3f", f)
})
testServeMux.HandleFunc("/PutTime", func(w http.ResponseWriter, r *http.Request) {
err := PutTime(r, "test_time", time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC))
if err != nil {
io.WriteString(w, err.Error())
return
}
io.WriteString(w, "OK")
})
testServeMux.HandleFunc("/GetTime", func(w http.ResponseWriter, r *http.Request) {
t, err := GetTime(r, "test_time")
if err != nil {
io.WriteString(w, err.Error())
return
}
io.WriteString(w, t.Format(time.RFC822))
})
testServeMux.HandleFunc("/PopTime", func(w http.ResponseWriter, r *http.Request) {
t, err := PopTime(r, "test_time")
if err != nil {
io.WriteString(w, err.Error())
return
}
io.WriteString(w, t.Format(time.RFC822))
})
testServeMux.HandleFunc("/PutBytes", func(w http.ResponseWriter, r *http.Request) {
err := PutBytes(r, "test_bytes", []byte("lorem ipsum"))
if err != nil {
io.WriteString(w, err.Error())
return
}
io.WriteString(w, "OK")
})
testServeMux.HandleFunc("/GetBytes", func(w http.ResponseWriter, r *http.Request) {
b, err := GetBytes(r, "test_bytes")
if err != nil {
io.WriteString(w, err.Error())
return
}
fmt.Fprintf(w, "%s", b)
})
testServeMux.HandleFunc("/PopBytes", func(w http.ResponseWriter, r *http.Request) {
b, err := PopBytes(r, "test_bytes")
if err != nil {
io.WriteString(w, err.Error())
return
}
fmt.Fprintf(w, "%s", b)
})
testServeMux.HandleFunc("/PutObject", func(w http.ResponseWriter, r *http.Request) {
u := &testUser{"alice", 21}
err := PutObject(r, "test_object", u)
if err != nil {
io.WriteString(w, err.Error())
return
}
io.WriteString(w, "OK")
})
testServeMux.HandleFunc("/GetObject", func(w http.ResponseWriter, r *http.Request) {
u := new(testUser)
err := GetObject(r, "test_object", u)
if err != nil {
io.WriteString(w, err.Error())
return
}
fmt.Fprintf(w, "%s: %d", u.Name, u.Age)
})
testServeMux.HandleFunc("/PopObject", func(w http.ResponseWriter, r *http.Request) {
u := new(testUser)
err := PopObject(r, "test_object", u)
if err != nil {
io.WriteString(w, err.Error())
return
}
fmt.Fprintf(w, "%s: %d", u.Name, u.Age)
})
testServeMux.HandleFunc("/Keys", func(w http.ResponseWriter, r *http.Request) {
keys, err := Keys(r)
if err != nil {
io.WriteString(w, err.Error())
return
}
fmt.Fprintf(w, "%v", keys)
})
testServeMux.HandleFunc("/Exists", func(w http.ResponseWriter, r *http.Request) {
exists, err := Exists(r, "test_string")
if err != nil {
io.WriteString(w, err.Error())
return
}
fmt.Fprintf(w, "%v", exists)
})
testServeMux.HandleFunc("/RemoveString", func(w http.ResponseWriter, r *http.Request) {
err := Remove(r, "test_string")
if err != nil {
io.WriteString(w, err.Error())
return
}
io.WriteString(w, "OK")
})
testServeMux.HandleFunc("/Clear", func(w http.ResponseWriter, r *http.Request) {
err := Clear(r)
if err != nil {
io.WriteString(w, err.Error())
return
}
io.WriteString(w, "OK")
})
testServeMux.HandleFunc("/Destroy", func(w http.ResponseWriter, r *http.Request) {
err := Destroy(w, r)
if err != nil {
io.WriteString(w, err.Error())
return
}
io.WriteString(w, "OK")
})
testServeMux.HandleFunc("/RegenerateToken", func(w http.ResponseWriter, r *http.Request) {
err := RegenerateToken(r)
if err != nil {
io.WriteString(w, err.Error())
return
}
io.WriteString(w, "OK")
})
testServeMux.HandleFunc("/Renew", func(w http.ResponseWriter, r *http.Request) {
err := Renew(r)
if err != nil {
io.WriteString(w, err.Error())
return
}
io.WriteString(w, "OK")
})
testServeMux.HandleFunc("/Save", func(w http.ResponseWriter, r *http.Request) {
err := PutString(r, "test_string", "lorem ipsum")
if err != nil {
io.WriteString(w, err.Error())
return
}
err = Save(w, r)
if err != nil {
io.WriteString(w, err.Error())
return
}
io.WriteString(w, "OK")
})
testServeMux.HandleFunc("/Flush", func(w http.ResponseWriter, r *http.Request) {
err := PutString(r, "test_string", "lorem ipsum")
if err != nil {
io.WriteString(w, err.Error())
return
}
err = Save(w, r)
if err != nil {
io.WriteString(w, err.Error())
return
}
fw, ok := w.(http.Flusher)
if !ok {
http.Error(w, "could not assert to Flusher", 500)
return
}
w.Write([]byte("This is some…"))
fw.Flush()
w.Write([]byte("flushed data"))
})
testServeMux.HandleFunc("/WriteHeader", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusTeapot)
io.WriteString(w, http.StatusText(http.StatusTeapot))
})
testServeMux.HandleFunc("/DumpContext", func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "%v", r.Context())
})
}
func testRequest(t *testing.T, h http.Handler, path string, cookie string) (int, string, string) {
rr := httptest.NewRecorder()
r, err := http.NewRequest("GET", path, nil)
if err != nil {
t.Fatal(err)
}
if cookie != "" {
r.Header.Add("Cookie", cookie)
}
h.ServeHTTP(rr, r)
code := rr.Code
body := string(rr.Body.Bytes())
cookie = rr.Header().Get("Set-Cookie")
return code, body, cookie
}
func extractTokenFromCookie(c string) string {
parts := strings.Split(c, ";")
return strings.TrimPrefix(parts[0], fmt.Sprintf("%s=", CookieName))
}
func TestGenerateToken(t *testing.T) {
id, err := generateToken()
if err != nil {
t.Fatal(err)
}
match, err := regexp.MatchString("^[0-9a-zA-Z_\\-]{43}$", id)
if err != nil {
t.Fatal(err)
}
if match == false {
t.Errorf("got %q: should match %q", id, "^[0-9a-zA-Z_\\-]{43}$")
}
}
func TestDestroy(t *testing.T) {
e := testEngine
m := Manage(e)
h := m(testServeMux)
_, _, cookie := testRequest(t, h, "/PutString", "")
oldToken := extractTokenFromCookie(cookie)
rr := httptest.NewRecorder()
r, err := http.NewRequest("GET", "/Destroy", nil)
if err != nil {
t.Fatal(err)
}
r.Header.Add("Cookie", cookie)
h.ServeHTTP(rr, r)
body := string(rr.Body.Bytes())
cookie = rr.Header().Get("Set-Cookie")
if body != "OK" {
t.Fatalf("got %q: expected %q", body, "OK")
}
if len(rr.Header()["Set-Cookie"]) != 1 {
t.Fatalf("got %d: expected %d", len(rr.Header()["Set-Cookie"]), 1)
}
if strings.HasPrefix(cookie, fmt.Sprintf("%s=;", CookieName)) == false {
t.Fatalf("got %q: expected prefix %q", cookie, fmt.Sprintf("%s=;", CookieName))
}
if strings.Contains(cookie, "Expires=Thu, 01 Jan 1970 00:00:01 GMT") == false {
t.Fatalf("got %q: expected to contain %q", cookie, "Expires=Thu, 01 Jan 1970 00:00:01 GMT")
}
if strings.Contains(cookie, "Max-Age=0") == false {
t.Fatalf("got %q: expected to contain %q", cookie, "Max-Age=0")
}
_, found, _ := e.Find(oldToken)
if found != false {
t.Fatalf("got %v: expected %v", found, false)
}
}
func TestRegenerateToken(t *testing.T) {
e := testEngine
m := Manage(e)
h := m(testServeMux)
_, _, cookie := testRequest(t, h, "/PutString", "")
oldToken := extractTokenFromCookie(cookie)
_, body, cookie := testRequest(t, h, "/RegenerateToken", cookie)
if body != "OK" {
t.Fatalf("got %q: expected %q", body, "OK")
}
newToken := extractTokenFromCookie(cookie)
if newToken == oldToken {
t.Fatal("expected a difference")
}
_, found, _ := e.Find(oldToken)
if found != false {
t.Fatalf("got %v: expected %v", found, false)
}
_, body, _ = testRequest(t, h, "/GetString", cookie)
if body != "lorem ipsum" {
t.Fatalf("got %q: expected %q", body, "lorem ipsum")
}
}
func TestRenew(t *testing.T) {
e := testEngine
m := Manage(e)
h := m(testServeMux)
_, _, cookie := testRequest(t, h, "/PutString", "")
oldToken := extractTokenFromCookie(cookie)
_, body, cookie := testRequest(t, h, "/Renew", cookie)
if body != "OK" {
t.Fatalf("got %q: expected %q", body, "OK")
}
newToken := extractTokenFromCookie(cookie)
if newToken == oldToken {
t.Fatal("expected a difference")
}
_, found, _ := e.Find(oldToken)
if found != false {
t.Fatalf("got %v: expected %v", found, false)
}
_, body, _ = testRequest(t, h, "/GetString", cookie)
if body != "" {
t.Fatalf("got %q: expected %q", body, "")
}
}
func TestSave(t *testing.T) {
e := testEngine
m := Manage(e)
h := m(testServeMux)
rr := httptest.NewRecorder()
r, err := http.NewRequest("GET", "/Save", nil)
if err != nil {
t.Fatal(err)
}
h.ServeHTTP(rr, r)
body := string(rr.Body.Bytes())
cookie := rr.Header().Get("Set-Cookie")
token := extractTokenFromCookie(cookie)
if body != "OK" {
t.Fatalf("got %q: expected %q", body, "OK")
}
if len(rr.Header()["Set-Cookie"]) != 1 {
t.Fatalf("got %d: expected %d", len(rr.Header()["Set-Cookie"]), 1)
}
if strings.HasPrefix(cookie, fmt.Sprintf("%s=", CookieName)) == false {
t.Fatalf("got %q: expected prefix %q", cookie, fmt.Sprintf("%s=", CookieName))
}
_, found, _ := e.Find(token)
if found != true {
t.Fatalf("got %v: expected %v", found, true)
}
}

195
session_test.go Normal file
View File

@ -0,0 +1,195 @@
package scs
import (
"encoding/gob"
"fmt"
"net/http"
"regexp"
"strings"
"testing"
)
type testUser struct {
Name string
Age int
}
func init() {
gob.Register(new(testUser))
}
func TestGenerateToken(t *testing.T) {
id, err := generateToken()
if err != nil {
t.Fatal(err)
}
match, err := regexp.MatchString("^[0-9a-zA-Z_\\-]{43}$", id)
if err != nil {
t.Fatal(err)
}
if match == false {
t.Errorf("got %q: should match %q", id, "^[0-9a-zA-Z_\\-]{43}$")
}
}
func TestString(t *testing.T) {
manager := NewManager(newMockStore())
_, body, cookie := testRequest(t, testPutString(manager), "")
if body != "OK" {
t.Fatalf("got %q: expected %q", body, "OK")
}
_, body, _ = testRequest(t, testGetString(manager), cookie)
if body != "lorem ipsum" {
t.Fatalf("got %q: expected %q", body, "lorem ipsum")
}
_, body, cookie = testRequest(t, testPopString(manager), cookie)
if body != "lorem ipsum" {
t.Fatalf("got %q: expected %q", body, "lorem ipsum")
}
_, body, _ = testRequest(t, testGetString(manager), cookie)
if body != "" {
t.Fatalf("got %q: expected %q", body, "")
}
}
func TestObject(t *testing.T) {
manager := NewManager(newMockStore())
_, body, cookie := testRequest(t, testPutObject(manager), "")
if body != "OK" {
t.Fatalf("got %q: expected %q", body, "OK")
}
_, body, _ = testRequest(t, testGetObject(manager), cookie)
if body != "alice: 21" {
t.Fatalf("got %q: expected %q", body, "alice: 21")
}
_, body, cookie = testRequest(t, testPopObject(manager), cookie)
if body != "alice: 21" {
t.Fatalf("got %q: expected %q", body, "alice: 21")
}
_, body, _ = testRequest(t, testGetObject(manager), cookie)
if body != ": 0" {
t.Fatalf("got %q: expected %q", body, ": 0")
}
}
func TestDestroy(t *testing.T) {
store := newMockStore()
manager := NewManager(store)
_, _, cookie := testRequest(t, testPutString(manager), "")
oldToken := extractTokenFromCookie(cookie)
_, body, cookie := testRequest(t, testDestroy(manager), cookie)
if body != "OK" {
t.Fatalf("got %q: expected %q", body, "OK")
}
if strings.HasPrefix(cookie, fmt.Sprintf("%s=;", CookieName)) == false {
t.Fatalf("got %q: expected prefix %q", cookie, fmt.Sprintf("%s=;", CookieName))
}
if strings.Contains(cookie, "Expires=Thu, 01 Jan 1970 00:00:01 GMT") == false {
t.Fatalf("got %q: expected to contain %q", cookie, "Expires=Thu, 01 Jan 1970 00:00:01 GMT")
}
if strings.Contains(cookie, "Max-Age=0") == false {
t.Fatalf("got %q: expected to contain %q", cookie, "Max-Age=0")
}
_, found, _ := store.Find(oldToken)
if found != false {
t.Fatalf("got %v: expected %v", found, false)
}
}
func TestRenewToken(t *testing.T) {
store := newMockStore()
manager := NewManager(store)
_, _, cookie := testRequest(t, testPutString(manager), "")
oldToken := extractTokenFromCookie(cookie)
_, body, cookie := testRequest(t, testRenewToken(manager), cookie)
if body != "OK" {
t.Fatalf("got %q: expected %q", body, "OK")
}
newToken := extractTokenFromCookie(cookie)
if newToken == oldToken {
t.Fatal("expected a difference")
}
_, found, _ := store.Find(oldToken)
if found != false {
t.Fatalf("got %v: expected %v", found, false)
}
_, body, _ = testRequest(t, testGetString(manager), cookie)
if body != "lorem ipsum" {
t.Fatalf("got %q: expected %q", body, "lorem ipsum")
}
}
func TestClear(t *testing.T) {
manager := NewManager(newMockStore())
_, _, cookie := testRequest(t, testPutString(manager), "")
_, _, cookie = testRequest(t, testPutBool(manager), cookie)
_, body, cookie := testRequest(t, testClear(manager), cookie)
if body != "OK" {
t.Fatalf("got %q: expected %q", body, "OK")
}
_, body, _ = testRequest(t, testGetString(manager), cookie)
if body != "" {
t.Fatalf("got %q: expected %q", body, "")
}
_, body, _ = testRequest(t, testGetBool(manager), cookie)
if body != "false" {
t.Fatalf("got %q: expected %q", body, "false")
}
// Check that it's a no-op if there is no data in the session
_, _, cookie = testRequest(t, testClear(manager), cookie)
if cookie != "" {
t.Fatalf("got %q: expected %q", cookie, "")
}
}
func TestKeys(t *testing.T) {
manager := NewManager(newMockStore())
_, _, cookie := testRequest(t, testPutString(manager), "")
_, _, _ = testRequest(t, testPutBool(manager), cookie)
_, body, _ := testRequest(t, testKeys(manager), cookie)
if body != "[test_bool test_string]" {
t.Fatalf("got %q: expected %q", body, "[test_bool test_string]")
}
_, _, _ = testRequest(t, testClear(manager), cookie)
_, body, _ = testRequest(t, testKeys(manager), cookie)
if body != "[]" {
t.Fatalf("got %q: expected %q", body, "[]")
}
}
func TestLoadFailure(t *testing.T) {
manager := NewManager(newMockStore())
cookie := http.Cookie{
Name: "session",
Value: "force-error",
}
_, body, _ := testRequest(t, testPutString(manager), cookie.String())
if body != "forced-error\n" {
t.Fatalf("got %q: expected %q", body, "forced-error\n")
}
}

View File

@ -1,27 +1,27 @@
package session
package scs
import "time"
// Engine is the interface for storage engines.
type Engine interface {
// Store is the interface for session stores.
type Store interface {
// Delete should remove the session token and corresponding data from the
// session engine. If the token does not exist then Delete should be a no-op
// session store. If the token does not exist then Delete should be a no-op
// and return nil (not an error).
Delete(token string) (err error)
// Find should return the data for a session token from the storage engine.
// Find should return the data for a session token from the session store.
// If the session token is not found or is expired, the found return value
// should be false (and the err return value should be nil). Similarly, tampered
// or malformed tokens should result in a found return value of false and a
// nil err value. The err return value should be used for system errors only.
Find(token string) (b []byte, found bool, err error)
// Save should add the session token and data to the storage engine, with
// Save should add the session token and data to the session store, with
// the given expiry time. If the session token already exists, then the data
// and expiry time should be overwritten.
Save(token string, b []byte, expiry time.Time) (err error)
}
type cookieEngine interface {
type cookieStore interface {
MakeToken(b []byte, expiry time.Time) (token string, err error)
}

41
store_test.go Normal file
View File

@ -0,0 +1,41 @@
package scs
import (
"errors"
"time"
)
type mockStore struct {
m map[string]*mockEntry
}
type mockEntry struct {
b []byte
expiry time.Time
}
func newMockStore() *mockStore {
m := make(map[string]*mockEntry)
return &mockStore{m}
}
func (s *mockStore) Delete(token string) error {
delete(s.m, token)
return nil
}
func (s *mockStore) Find(token string) (b []byte, found bool, err error) {
if token == "force-error" {
return nil, false, errors.New("forced-error")
}
entry, exists := s.m[token]
if !exists || entry.expiry.UnixNano() < time.Now().UnixNano() {
return nil, false, nil
}
return entry.b, true, nil
}
func (s *mockStore) Save(token string, b []byte, expiry time.Time) error {
s.m[token] = &mockEntry{b, expiry}
return nil
}