mirror of
https://github.com/alexedwards/scs.git
synced 2025-07-15 01:04:36 +02:00
[refactor] Deprecate Middleware
This commit is contained in:
98
manager.go
Normal file
98
manager.go
Normal file
@ -0,0 +1,98 @@
|
||||
package scs
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/alexedwards/scs/stores/cookiestore"
|
||||
)
|
||||
|
||||
// Manager is a session manager.
|
||||
type Manager struct {
|
||||
store Store
|
||||
opts *options
|
||||
}
|
||||
|
||||
// NewManager returns a pointer to a new session manager.
|
||||
func NewManager(store Store) *Manager {
|
||||
defaultOptions := &options{
|
||||
domain: "",
|
||||
httpOnly: true,
|
||||
idleTimeout: 0,
|
||||
lifetime: 24 * time.Hour,
|
||||
path: "/",
|
||||
persist: false,
|
||||
secure: false,
|
||||
}
|
||||
|
||||
return &Manager{
|
||||
store: store,
|
||||
opts: defaultOptions,
|
||||
}
|
||||
}
|
||||
|
||||
// Domain sets the 'Domain' attribute on the session cookie. By default it will
|
||||
// be set to the domain name that the cookie was issued from.
|
||||
func (m *Manager) Domain(s string) {
|
||||
m.opts.domain = s
|
||||
}
|
||||
|
||||
// HttpOnly sets the 'HttpOnly' attribute on the session cookie. The default value
|
||||
// is true.
|
||||
func (m *Manager) HttpOnly(b bool) {
|
||||
m.opts.httpOnly = b
|
||||
}
|
||||
|
||||
// IdleTimeout sets the maximum length of time a session can be inactive before it
|
||||
// expires. For example, some applications may wish to set this so there is a timeout
|
||||
// after 20 minutes of inactivity. The inactivity period is reset whenever the
|
||||
// session data is changed (but not read).
|
||||
//
|
||||
// By default IdleTimeout is not set and there is no inactivity timeout.
|
||||
func (m *Manager) IdleTimeout(t time.Duration) {
|
||||
m.opts.idleTimeout = t
|
||||
}
|
||||
|
||||
// Lifetime sets the maximum length of time that a session is valid for before
|
||||
// it expires. The lifetime is an 'absolute expiry' which is set when the session
|
||||
// is first created and does not change.
|
||||
//
|
||||
// The default value is 24 hours.
|
||||
func (m *Manager) Lifetime(t time.Duration) {
|
||||
m.opts.lifetime = t
|
||||
}
|
||||
|
||||
// Path sets the 'Path' attribute on the session cookie. The default value is "/".
|
||||
// Passing the empty string "" will result in it being set to the path that the
|
||||
// cookie was issued from.
|
||||
func (m *Manager) Path(s string) {
|
||||
m.opts.path = s
|
||||
}
|
||||
|
||||
// Persist sets whether the session cookie should be persistent or not (i.e. whether
|
||||
// it should be retained after a user closes their browser).
|
||||
//
|
||||
// The default value is false, which means that the session cookie will be destroyed
|
||||
// when the user closes their browser. If set to true, explicit 'Expires' and
|
||||
// 'MaxAge' values will be added to the cookie and it will be retained by the
|
||||
// user's browser until the given expiry time is reached.
|
||||
func (m *Manager) Persist(b bool) {
|
||||
m.opts.persist = b
|
||||
}
|
||||
|
||||
// Secure sets the 'Secure' attribute on the session cookie. The default value
|
||||
// is false. It's recommended that you set this to true and serve all requests
|
||||
// over HTTPS in production environments.
|
||||
func (m *Manager) Secure(b bool) {
|
||||
m.opts.secure = b
|
||||
}
|
||||
|
||||
// Load returns the session data for the current request.
|
||||
func (m *Manager) Load(r *http.Request) *Session {
|
||||
return load(r, m.store, m.opts)
|
||||
}
|
||||
|
||||
func NewCookieManager(key string) *Manager {
|
||||
store := cookiestore.New([]byte(key))
|
||||
return NewManager(store)
|
||||
}
|
20
options.go
Normal file
20
options.go
Normal file
@ -0,0 +1,20 @@
|
||||
package scs
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// CookieName changes the name of the session cookie issued to clients. Note that
|
||||
// cookie names should not contain whitespace, commas, semicolons, backslashes
|
||||
// or control characters as per RFC6265.
|
||||
var CookieName = "session"
|
||||
|
||||
type options struct {
|
||||
domain string
|
||||
httpOnly bool
|
||||
idleTimeout time.Duration
|
||||
lifetime time.Duration
|
||||
path string
|
||||
persist bool
|
||||
secure bool
|
||||
}
|
153
options_test.go
Normal file
153
options_test.go
Normal file
@ -0,0 +1,153 @@
|
||||
package scs
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestCookieOptions(t *testing.T) {
|
||||
manager := NewManager(newMockStore())
|
||||
|
||||
_, _, cookie := testRequest(t, testPutString(manager), "")
|
||||
if strings.Contains(cookie, "Path=/") == false {
|
||||
t.Fatalf("got %q: expected to contain %q", cookie, "Path=/")
|
||||
}
|
||||
if strings.Contains(cookie, "Domain=") == true {
|
||||
t.Fatalf("got %q: expected to not contain %q", cookie, "Domain=")
|
||||
}
|
||||
if strings.Contains(cookie, "Secure") == true {
|
||||
t.Fatalf("got %q: expected to not contain %q", cookie, "Secure")
|
||||
}
|
||||
if strings.Contains(cookie, "HttpOnly") == false {
|
||||
t.Fatalf("got %q: expected to contain %q", cookie, "HttpOnly")
|
||||
}
|
||||
|
||||
manager = NewManager(newMockStore())
|
||||
manager.Path("/foo")
|
||||
manager.Domain("example.org")
|
||||
manager.Secure(true)
|
||||
manager.HttpOnly(false)
|
||||
manager.Lifetime(time.Hour)
|
||||
manager.Persist(true)
|
||||
|
||||
_, _, cookie = testRequest(t, testPutString(manager), "")
|
||||
if strings.Contains(cookie, "Path=/foo") == false {
|
||||
t.Fatalf("got %q: expected to contain %q", cookie, "Path=/foo")
|
||||
}
|
||||
if strings.Contains(cookie, "Domain=example.org") == false {
|
||||
t.Fatalf("got %q: expected to contain %q", cookie, "Domain=example.org")
|
||||
}
|
||||
if strings.Contains(cookie, "Secure") == false {
|
||||
t.Fatalf("got %q: expected to contain %q", cookie, "Secure")
|
||||
}
|
||||
if strings.Contains(cookie, "HttpOnly") == true {
|
||||
t.Fatalf("got %q: expected to not contain %q", cookie, "HttpOnly")
|
||||
}
|
||||
if strings.Contains(cookie, "Max-Age=3600") == false {
|
||||
t.Fatalf("got %q: expected to contain %q:", cookie, "Max-Age=86400")
|
||||
}
|
||||
if strings.Contains(cookie, "Expires=") == false {
|
||||
t.Fatalf("got %q: expected to contain %q:", cookie, "Expires")
|
||||
}
|
||||
|
||||
manager = NewManager(newMockStore())
|
||||
manager.Lifetime(time.Hour)
|
||||
|
||||
_, _, cookie = testRequest(t, testPutString(manager), "")
|
||||
if strings.Contains(cookie, "Max-Age=") == true {
|
||||
t.Fatalf("got %q: expected not to contain %q:", cookie, "Max-Age=")
|
||||
}
|
||||
if strings.Contains(cookie, "Expires=") == true {
|
||||
t.Fatalf("got %q: expected not to contain %q:", cookie, "Expires")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLifetime(t *testing.T) {
|
||||
manager := NewManager(newMockStore())
|
||||
manager.Lifetime(200 * time.Millisecond)
|
||||
|
||||
_, _, cookie := testRequest(t, testPutString(manager), "")
|
||||
oldToken := extractTokenFromCookie(cookie)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
_, _, cookie = testRequest(t, testPutString(manager), cookie)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
_, body, _ := testRequest(t, testGetString(manager), cookie)
|
||||
if body != "" {
|
||||
t.Fatalf("got %q: expected %q", body, "")
|
||||
}
|
||||
_, _, cookie = testRequest(t, testPutString(manager), cookie)
|
||||
newToken := extractTokenFromCookie(cookie)
|
||||
if newToken == oldToken {
|
||||
t.Fatalf("expected a difference")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIdleTimeout(t *testing.T) {
|
||||
manager := NewManager(newMockStore())
|
||||
manager.IdleTimeout(100 * time.Millisecond)
|
||||
manager.Lifetime(500 * time.Millisecond)
|
||||
|
||||
_, _, cookie := testRequest(t, testPutString(manager), "")
|
||||
oldToken := extractTokenFromCookie(cookie)
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
_, body, _ := testRequest(t, testGetString(manager), cookie)
|
||||
if body != "" {
|
||||
t.Fatalf("got %q: expected %q", body, "")
|
||||
}
|
||||
_, _, cookie = testRequest(t, testPutString(manager), cookie)
|
||||
newToken := extractTokenFromCookie(cookie)
|
||||
if newToken == oldToken {
|
||||
t.Fatalf("expected a difference")
|
||||
}
|
||||
|
||||
_, _, cookie = testRequest(t, testPutString(manager), "")
|
||||
oldToken = extractTokenFromCookie(cookie)
|
||||
time.Sleep(75 * time.Millisecond)
|
||||
|
||||
_, _, cookie = testRequest(t, testPutString(manager), cookie)
|
||||
time.Sleep(75 * time.Millisecond)
|
||||
|
||||
_, body, _ = testRequest(t, testGetString(manager), cookie)
|
||||
if body != "lorem ipsum" {
|
||||
t.Fatalf("got %q: expected %q", body, "lorem ipsum")
|
||||
}
|
||||
_, _, cookie = testRequest(t, testPutString(manager), cookie)
|
||||
newToken = extractTokenFromCookie(cookie)
|
||||
if newToken != oldToken {
|
||||
t.Fatalf("expected the same")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPersist(t *testing.T) {
|
||||
manager := NewManager(newMockStore())
|
||||
manager.IdleTimeout(5 * time.Minute)
|
||||
manager.Persist(true)
|
||||
|
||||
_, _, cookie := testRequest(t, testPutString(manager), "")
|
||||
if strings.Contains(cookie, "Max-Age=300") == false {
|
||||
t.Fatalf("got %q: expected to contain %q:", cookie, "Max-Age=300")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCookieName(t *testing.T) {
|
||||
oldCookieName := CookieName
|
||||
CookieName = "custom_cookie_name"
|
||||
|
||||
manager := NewManager(newMockStore())
|
||||
|
||||
_, _, cookie := testRequest(t, testPutString(manager), "")
|
||||
if strings.HasPrefix(cookie, "custom_cookie_name=") == false {
|
||||
t.Fatalf("got %q: expected prefix %q", cookie, "custom_cookie_name=")
|
||||
}
|
||||
|
||||
_, body, _ := testRequest(t, testGetString(manager), cookie)
|
||||
if body != "lorem ipsum" {
|
||||
t.Fatalf("got %q: expected %q", body, "lorem ipsum")
|
||||
}
|
||||
|
||||
CookieName = oldCookieName
|
||||
}
|
16
scs.go
16
scs.go
@ -1,16 +0,0 @@
|
||||
/*
|
||||
Package scs is a session manager for Go 1.7+.
|
||||
|
||||
It features:
|
||||
|
||||
* Automatic loading and saving of session data via middleware.
|
||||
* Fast and very memory-efficient performance.
|
||||
* Choice of PostgreSQL, MySQL, Redis, encrypted cookie and in-memory storage engines. Custom storage engines are also supported.
|
||||
* Type-safe and sensible API. Designed to be safe for concurrent use.
|
||||
* Supports OWASP good-practices, including absolute and idle session timeouts and easy regeneration of session tokens.
|
||||
|
||||
This top-level package is a wrapper for its sub-packages and doesn't actually
|
||||
contain any code. You probably want to start by looking at the documentation
|
||||
for the session sub-package.
|
||||
*/
|
||||
package scs
|
181
scs_test.go
Normal file
181
scs_test.go
Normal file
@ -0,0 +1,181 @@
|
||||
package scs
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func testRequest(t *testing.T, h http.Handler, cookie string) (int, string, string) {
|
||||
rr := httptest.NewRecorder()
|
||||
r, err := http.NewRequest("GET", "/", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if cookie != "" {
|
||||
r.Header.Add("Cookie", cookie)
|
||||
}
|
||||
h.ServeHTTP(rr, r)
|
||||
|
||||
code := rr.Code
|
||||
body := string(rr.Body.Bytes())
|
||||
cookie = rr.Header().Get("Set-Cookie")
|
||||
return code, body, cookie
|
||||
}
|
||||
|
||||
func extractTokenFromCookie(c string) string {
|
||||
parts := strings.Split(c, ";")
|
||||
return strings.TrimPrefix(parts[0], fmt.Sprintf("%s=", CookieName))
|
||||
}
|
||||
|
||||
// Test Handlers
|
||||
|
||||
func testPutString(manager *Manager) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
session := manager.Load(r)
|
||||
err := session.PutString(w, "test_string", "lorem ipsum")
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), 500)
|
||||
return
|
||||
}
|
||||
io.WriteString(w, "OK")
|
||||
}
|
||||
}
|
||||
|
||||
func testGetString(manager *Manager) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
session := manager.Load(r)
|
||||
s, err := session.GetString("test_string")
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), 500)
|
||||
return
|
||||
}
|
||||
io.WriteString(w, s)
|
||||
}
|
||||
}
|
||||
|
||||
func testPopString(manager *Manager) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
session := manager.Load(r)
|
||||
s, err := session.PopString(w, "test_string")
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), 500)
|
||||
return
|
||||
}
|
||||
io.WriteString(w, s)
|
||||
}
|
||||
}
|
||||
|
||||
func testPutBool(manager *Manager) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
session := manager.Load(r)
|
||||
err := session.PutBool(w, "test_bool", true)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), 500)
|
||||
return
|
||||
}
|
||||
io.WriteString(w, "OK")
|
||||
}
|
||||
}
|
||||
|
||||
func testGetBool(manager *Manager) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
session := manager.Load(r)
|
||||
b, err := session.GetBool("test_bool")
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), 500)
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(w, "%v", b)
|
||||
}
|
||||
}
|
||||
|
||||
func testPutObject(manager *Manager) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
session := manager.Load(r)
|
||||
u := &testUser{"alice", 21}
|
||||
err := session.PutObject(w, "test_object", u)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), 500)
|
||||
return
|
||||
}
|
||||
io.WriteString(w, "OK")
|
||||
}
|
||||
}
|
||||
|
||||
func testGetObject(manager *Manager) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
session := manager.Load(r)
|
||||
u := new(testUser)
|
||||
err := session.GetObject("test_object", u)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), 500)
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(w, "%s: %d", u.Name, u.Age)
|
||||
}
|
||||
}
|
||||
|
||||
func testPopObject(manager *Manager) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
session := manager.Load(r)
|
||||
u := new(testUser)
|
||||
err := session.PopObject(w, "test_object", u)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), 500)
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(w, "%s: %d", u.Name, u.Age)
|
||||
}
|
||||
}
|
||||
|
||||
func testDestroy(manager *Manager) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
session := manager.Load(r)
|
||||
err := session.Destroy(w)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), 500)
|
||||
return
|
||||
}
|
||||
io.WriteString(w, "OK")
|
||||
}
|
||||
}
|
||||
|
||||
func testRenewToken(manager *Manager) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
session := manager.Load(r)
|
||||
err := session.RenewToken(w)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), 500)
|
||||
return
|
||||
}
|
||||
io.WriteString(w, "OK")
|
||||
}
|
||||
}
|
||||
|
||||
func testClear(manager *Manager) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
session := manager.Load(r)
|
||||
err := session.Clear(w)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), 500)
|
||||
return
|
||||
}
|
||||
io.WriteString(w, "OK")
|
||||
}
|
||||
}
|
||||
|
||||
func testKeys(manager *Manager) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
session := manager.Load(r)
|
||||
keys, err := session.Keys()
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), 500)
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(w, "%v", keys)
|
||||
}
|
||||
}
|
@ -1,7 +1,8 @@
|
||||
package session
|
||||
package scs
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/gob"
|
||||
"encoding/json"
|
||||
@ -9,6 +10,7 @@ import (
|
||||
"net/http"
|
||||
"sort"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
@ -16,12 +18,122 @@ import (
|
||||
// received value could not be type asserted or converted into the required type.
|
||||
var ErrTypeAssertionFailed = errors.New("type assertion failed")
|
||||
|
||||
// Session contains data for the current session.
|
||||
type Session struct {
|
||||
token string
|
||||
data map[string]interface{}
|
||||
deadline time.Time
|
||||
store Store
|
||||
opts *options
|
||||
loadErr error
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func newSession(store Store, opts *options) *Session {
|
||||
return &Session{
|
||||
data: make(map[string]interface{}),
|
||||
deadline: time.Now().Add(opts.lifetime),
|
||||
store: store,
|
||||
opts: opts,
|
||||
}
|
||||
}
|
||||
|
||||
func load(r *http.Request, store Store, opts *options) *Session {
|
||||
cookie, err := r.Cookie(CookieName)
|
||||
if err == http.ErrNoCookie {
|
||||
return newSession(store, opts)
|
||||
} else if err != nil {
|
||||
return &Session{loadErr: err}
|
||||
}
|
||||
|
||||
if cookie.Value == "" {
|
||||
return newSession(store, opts)
|
||||
}
|
||||
token := cookie.Value
|
||||
|
||||
j, found, err := store.Find(token)
|
||||
if err != nil {
|
||||
return &Session{loadErr: err}
|
||||
}
|
||||
if found == false {
|
||||
return newSession(store, opts)
|
||||
}
|
||||
|
||||
data, deadline, err := decodeFromJSON(j)
|
||||
if err != nil {
|
||||
return &Session{loadErr: err}
|
||||
}
|
||||
|
||||
s := &Session{
|
||||
token: token,
|
||||
data: data,
|
||||
deadline: deadline,
|
||||
store: store,
|
||||
opts: opts,
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Session) write(w http.ResponseWriter) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
j, err := encodeToJSON(s.data, s.deadline)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
expiry := s.deadline
|
||||
if s.opts.idleTimeout > 0 {
|
||||
ie := time.Now().Add(s.opts.idleTimeout)
|
||||
if ie.Before(expiry) {
|
||||
expiry = ie
|
||||
}
|
||||
}
|
||||
|
||||
if ce, ok := s.store.(cookieStore); ok {
|
||||
s.token, err = ce.MakeToken(j, expiry)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if s.token == "" {
|
||||
s.token, err = generateToken()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
err = s.store.Save(s.token, j, expiry)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
cookie := &http.Cookie{
|
||||
Name: CookieName,
|
||||
Value: s.token,
|
||||
Path: s.opts.path,
|
||||
Domain: s.opts.domain,
|
||||
Secure: s.opts.secure,
|
||||
HttpOnly: s.opts.httpOnly,
|
||||
}
|
||||
if s.opts.persist == true {
|
||||
// Round up expiry time to the nearest second.
|
||||
cookie.Expires = time.Unix(expiry.Unix()+1, 0)
|
||||
cookie.MaxAge = int(expiry.Sub(time.Now()).Seconds() + 1)
|
||||
}
|
||||
http.SetCookie(w, cookie)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetString returns the string value for a given key from the session data. The
|
||||
// zero value for a string ("") is returned if the key does not exist. An ErrTypeAssertionFailed
|
||||
// error is returned if the value could not be type asserted or converted to a
|
||||
// string.
|
||||
func GetString(r *http.Request, key string) (string, error) {
|
||||
v, exists, err := get(r, key)
|
||||
func (s *Session) GetString(key string) (string, error) {
|
||||
v, exists, err := s.get(key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@ -38,16 +150,16 @@ func GetString(r *http.Request, key string) (string, error) {
|
||||
|
||||
// PutString adds a string value and corresponding key to the the session data.
|
||||
// Any existing value for the key will be replaced.
|
||||
func PutString(r *http.Request, key string, val string) error {
|
||||
return put(r, key, val)
|
||||
func (s *Session) PutString(w http.ResponseWriter, key string, val string) error {
|
||||
return s.put(w, key, val)
|
||||
}
|
||||
|
||||
// PopString removes the string value for a given key from the session data
|
||||
// and returns it. The zero value for a string ("") is returned if the key does
|
||||
// not exist. An ErrTypeAssertionFailed error is returned if the value could not
|
||||
// be type asserted to a string.
|
||||
func PopString(r *http.Request, key string) (string, error) {
|
||||
v, exists, err := pop(r, key)
|
||||
func (s *Session) PopString(w http.ResponseWriter, key string) (string, error) {
|
||||
v, exists, err := s.pop(w, key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@ -65,8 +177,8 @@ func PopString(r *http.Request, key string) (string, error) {
|
||||
// GetBool returns the bool value for a given key from the session data. The
|
||||
// zero value for a bool (false) is returned if the key does not exist. An ErrTypeAssertionFailed
|
||||
// error is returned if the value could not be type asserted to a bool.
|
||||
func GetBool(r *http.Request, key string) (bool, error) {
|
||||
v, exists, err := get(r, key)
|
||||
func (s *Session) GetBool(key string) (bool, error) {
|
||||
v, exists, err := s.get(key)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@ -83,16 +195,16 @@ func GetBool(r *http.Request, key string) (bool, error) {
|
||||
|
||||
// PutBool adds a bool value and corresponding key to the session data. Any existing
|
||||
// value for the key will be replaced.
|
||||
func PutBool(r *http.Request, key string, val bool) error {
|
||||
return put(r, key, val)
|
||||
func (s *Session) PutBool(w http.ResponseWriter, key string, val bool) error {
|
||||
return s.put(w, key, val)
|
||||
}
|
||||
|
||||
// PopBool removes the bool value for a given key from the session data and returns
|
||||
// it. The zero value for a bool (false) is returned if the key does not exist.
|
||||
// An ErrTypeAssertionFailed error is returned if the value could not be type
|
||||
// asserted to a bool.
|
||||
func PopBool(r *http.Request, key string) (bool, error) {
|
||||
v, exists, err := pop(r, key)
|
||||
func (s *Session) PopBool(w http.ResponseWriter, key string) (bool, error) {
|
||||
v, exists, err := s.pop(w, key)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@ -110,8 +222,8 @@ func PopBool(r *http.Request, key string) (bool, error) {
|
||||
// GetInt returns the int value for a given key from the session data. The zero
|
||||
// value for an int (0) is returned if the key does not exist. An ErrTypeAssertionFailed
|
||||
// error is returned if the value could not be type asserted or converted to a int.
|
||||
func GetInt(r *http.Request, key string) (int, error) {
|
||||
v, exists, err := get(r, key)
|
||||
func (s *Session) GetInt(key string) (int, error) {
|
||||
v, exists, err := s.get(key)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@ -130,16 +242,16 @@ func GetInt(r *http.Request, key string) (int, error) {
|
||||
|
||||
// PutInt adds an int value and corresponding key to the session data. Any existing
|
||||
// value for the key will be replaced.
|
||||
func PutInt(r *http.Request, key string, val int) error {
|
||||
return put(r, key, val)
|
||||
func (s *Session) PutInt(w http.ResponseWriter, key string, val int) error {
|
||||
return s.put(w, key, val)
|
||||
}
|
||||
|
||||
// PopInt removes the int value for a given key from the session data and returns
|
||||
// it. The zero value for an int (0) is returned if the key does not exist. An
|
||||
// ErrTypeAssertionFailed error is returned if the value could not be type asserted
|
||||
// or converted to a int.
|
||||
func PopInt(r *http.Request, key string) (int, error) {
|
||||
v, exists, err := pop(r, key)
|
||||
func (s *Session) PopInt(w http.ResponseWriter, key string) (int, error) {
|
||||
v, exists, err := s.pop(w, key)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@ -156,13 +268,11 @@ func PopInt(r *http.Request, key string) (int, error) {
|
||||
return 0, ErrTypeAssertionFailed
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
// GetInt64 returns the int64 value for a given key from the session data. The
|
||||
// zero value for an int (0) is returned if the key does not exist. An ErrTypeAssertionFailed
|
||||
// error is returned if the value could not be type asserted or converted to a int64.
|
||||
func GetInt64(r *http.Request, key string) (int64, error) {
|
||||
v, exists, err := get(r, key)
|
||||
func (s *Session) GetInt64(key string) (int64, error) {
|
||||
v, exists, err := s.get(key)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@ -181,16 +291,16 @@ func GetInt64(r *http.Request, key string) (int64, error) {
|
||||
|
||||
// PutInt64 adds an int64 value and corresponding key to the session data. Any existing
|
||||
// value for the key will be replaced.
|
||||
func PutInt64(r *http.Request, key string, val int64) error {
|
||||
return put(r, key, val)
|
||||
func (s *Session) PutInt64(w http.ResponseWriter, key string, val int64) error {
|
||||
return s.put(w, key, val)
|
||||
}
|
||||
|
||||
// PopInt64 remvoes the int64 value for a given key from the session data
|
||||
// and returns it. The zero value for an int (0) is returned if the key does not
|
||||
// exist. An ErrTypeAssertionFailed error is returned if the value could not be
|
||||
// type asserted or converted to a int64.
|
||||
func PopInt64(r *http.Request, key string) (int64, error) {
|
||||
v, exists, err := pop(r, key)
|
||||
func (s *Session) PopInt64(w http.ResponseWriter, key string) (int64, error) {
|
||||
v, exists, err := s.pop(w, key)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@ -211,8 +321,8 @@ func PopInt64(r *http.Request, key string) (int64, error) {
|
||||
// zero value for an float (0) is returned if the key does not exist. An ErrTypeAssertionFailed
|
||||
// error is returned if the value could not be type asserted or converted to a
|
||||
// float64.
|
||||
func GetFloat(r *http.Request, key string) (float64, error) {
|
||||
v, exists, err := get(r, key)
|
||||
func (s *Session) GetFloat(key string) (float64, error) {
|
||||
v, exists, err := s.get(key)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@ -231,16 +341,16 @@ func GetFloat(r *http.Request, key string) (float64, error) {
|
||||
|
||||
// PutFloat adds an float64 value and corresponding key to the session data. Any
|
||||
// existing value for the key will be replaced.
|
||||
func PutFloat(r *http.Request, key string, val float64) error {
|
||||
return put(r, key, val)
|
||||
func (s *Session) PutFloat(w http.ResponseWriter, key string, val float64) error {
|
||||
return s.put(w, key, val)
|
||||
}
|
||||
|
||||
// PopFloat removes the float64 value for a given key from the session data
|
||||
// and returns it. The zero value for an float (0) is returned if the key does
|
||||
// not exist. An ErrTypeAssertionFailed error is returned if the value could not
|
||||
// be type asserted or converted to a float64.
|
||||
func PopFloat(r *http.Request, key string) (float64, error) {
|
||||
v, exists, err := pop(r, key)
|
||||
func (s *Session) PopFloat(w http.ResponseWriter, key string) (float64, error) {
|
||||
v, exists, err := s.pop(w, key)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@ -262,8 +372,8 @@ func PopFloat(r *http.Request, key string) (float64, error) {
|
||||
// can be checked for with the time.IsZero method). An ErrTypeAssertionFailed
|
||||
// error is returned if the value could not be type asserted or converted to a
|
||||
// time.Time.
|
||||
func GetTime(r *http.Request, key string) (time.Time, error) {
|
||||
v, exists, err := get(r, key)
|
||||
func (s *Session) GetTime(key string) (time.Time, error) {
|
||||
v, exists, err := s.get(key)
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
@ -282,8 +392,8 @@ func GetTime(r *http.Request, key string) (time.Time, error) {
|
||||
|
||||
// PutTime adds an time.Time value and corresponding key to the session data. Any
|
||||
// existing value for the key will be replaced.
|
||||
func PutTime(r *http.Request, key string, val time.Time) error {
|
||||
return put(r, key, val)
|
||||
func (s *Session) PutTime(w http.ResponseWriter, key string, val time.Time) error {
|
||||
return s.put(w, key, val)
|
||||
}
|
||||
|
||||
// PopTime removes the time.Time value for a given key from the session data
|
||||
@ -291,8 +401,8 @@ func PutTime(r *http.Request, key string, val time.Time) error {
|
||||
// does not exist (this can be checked for with the time.IsZero method). An ErrTypeAssertionFailed
|
||||
// error is returned if the value could not be type asserted or converted to a
|
||||
// time.Time.
|
||||
func PopTime(r *http.Request, key string) (time.Time, error) {
|
||||
v, exists, err := pop(r, key)
|
||||
func (s *Session) PopTime(w http.ResponseWriter, key string) (time.Time, error) {
|
||||
v, exists, err := s.pop(w, key)
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
@ -313,8 +423,8 @@ func PopTime(r *http.Request, key string) (time.Time, error) {
|
||||
// data. The zero value for a slice (nil) is returned if the key does not exist.
|
||||
// An ErrTypeAssertionFailed error is returned if the value could not be type
|
||||
// asserted or converted to []byte.
|
||||
func GetBytes(r *http.Request, key string) ([]byte, error) {
|
||||
v, exists, err := get(r, key)
|
||||
func (s *Session) GetBytes(key string) ([]byte, error) {
|
||||
v, exists, err := s.get(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -333,20 +443,20 @@ func GetBytes(r *http.Request, key string) ([]byte, error) {
|
||||
|
||||
// PutBytes adds a byte slice ([]byte) value and corresponding key to the the
|
||||
// session data. Any existing value for the key will be replaced.
|
||||
func PutBytes(r *http.Request, key string, val []byte) error {
|
||||
func (s *Session) PutBytes(w http.ResponseWriter, key string, val []byte) error {
|
||||
if val == nil {
|
||||
return errors.New("value must not be nil")
|
||||
}
|
||||
|
||||
return put(r, key, val)
|
||||
return s.put(w, key, val)
|
||||
}
|
||||
|
||||
// PopBytes removes the byte slice ([]byte) value for a given key from the session
|
||||
// data and returns it. The zero value for a slice (nil) is returned if the key
|
||||
// does not exist. An ErrTypeAssertionFailed error is returned if the value could
|
||||
// not be type asserted or converted to a []byte.
|
||||
func PopBytes(r *http.Request, key string) ([]byte, error) {
|
||||
v, exists, err := pop(r, key)
|
||||
func (s *Session) PopBytes(w http.ResponseWriter, key string) ([]byte, error) {
|
||||
v, exists, err := s.pop(w, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -363,41 +473,14 @@ func PopBytes(r *http.Request, key string) ([]byte, error) {
|
||||
return nil, ErrTypeAssertionFailed
|
||||
}
|
||||
|
||||
/*
|
||||
GetObject reads the data for a given session key into an arbitrary object
|
||||
(represented by the dst parameter). It should only be used to retrieve custom
|
||||
data types that have been stored using PutObject. The object represented by dst
|
||||
will remain unchanged if the key does not exist.
|
||||
|
||||
The dst parameter must be a pointer.
|
||||
|
||||
Usage:
|
||||
|
||||
// Note that the fields on the custom type are all exported.
|
||||
type User struct {
|
||||
Name string
|
||||
Email string
|
||||
}
|
||||
|
||||
func getHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// Register the type with the encoding/gob package. Usually this would be
|
||||
// done in an init() function.
|
||||
gob.Register(User{})
|
||||
|
||||
// Initialise a pointer to a new, empty, custom object.
|
||||
user := &User{}
|
||||
|
||||
// Read the custom object data from the session into the pointer.
|
||||
err := session.GetObject(r, "user", user)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), 500)
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(w, "Name: %s, Email: %s", user.Name, user.Email)
|
||||
}
|
||||
*/
|
||||
func GetObject(r *http.Request, key string, dst interface{}) error {
|
||||
b, err := GetBytes(r, key)
|
||||
// GetObject reads the data for a given session key into an arbitrary object
|
||||
// (represented by the dst parameter). It should only be used to retrieve custom
|
||||
// data types that have been stored using PutObject. The object represented by dst
|
||||
// will remain unchanged if the key does not exist.
|
||||
//
|
||||
// The dst parameter must be a pointer.
|
||||
func (s *Session) GetObject(key string, dst interface{}) error {
|
||||
b, err := s.GetBytes(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -408,45 +491,20 @@ func GetObject(r *http.Request, key string, dst interface{}) error {
|
||||
return gobDecode(b, dst)
|
||||
}
|
||||
|
||||
/*
|
||||
PutObject adds an arbitrary object and corresponding key to the the session data.
|
||||
Any existing value for the key will be replaced.
|
||||
|
||||
The val parameter must be a pointer to your object.
|
||||
|
||||
PutObject is typically used to store custom data types. It encodes the object
|
||||
into a gob and then into a base64-encoded string which is persisted by the
|
||||
storage engine. This makes PutObject (and the accompanying GetObject and PopObject
|
||||
functions) comparatively expensive operations.
|
||||
|
||||
Because gob encoding is used, the fields on custom types must be exported in
|
||||
order to be persisted correctly. Custom data types must also be registered with
|
||||
gob.Register before PutObject is called (see https://golang.org/pkg/encoding/gob/#Register).
|
||||
|
||||
Usage:
|
||||
|
||||
type User struct {
|
||||
Name string
|
||||
Email string
|
||||
}
|
||||
|
||||
func putHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// Register the type with the encoding/gob package. Usually this would be
|
||||
// done in an init() function.
|
||||
gob.Register(User{})
|
||||
|
||||
// Initialise a pointer to a new custom object.
|
||||
user := &User{"Alice", "alice@example.com"}
|
||||
|
||||
// Store the custom object in the session data. Important: you should pass in
|
||||
// a pointer to your object, not the value.
|
||||
err := session.PutObject(r, "user", user)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), 500)
|
||||
}
|
||||
}
|
||||
*/
|
||||
func PutObject(r *http.Request, key string, val interface{}) error {
|
||||
// PutObject adds an arbitrary object and corresponding key to the the session data.
|
||||
// Any existing value for the key will be replaced.
|
||||
//
|
||||
// The val parameter must be a pointer to your object.
|
||||
//
|
||||
// PutObject is typically used to store custom data types. It encodes the object
|
||||
// into a gob and then into a base64-encoded string which is persisted by the
|
||||
// session store. This makes PutObject (and the accompanying GetObject and PopObject
|
||||
// functions) comparatively expensive operations.
|
||||
//
|
||||
// Because gob encoding is used, the fields on custom types must be exported in
|
||||
// order to be persisted correctly. Custom data types must also be registered with
|
||||
// gob.Register before PutObject is called (see https://golang.org/pkg/encoding/gob/#Register).
|
||||
func (s *Session) PutObject(w http.ResponseWriter, key string, val interface{}) error {
|
||||
if val == nil {
|
||||
return errors.New("value must not be nil")
|
||||
}
|
||||
@ -456,7 +514,7 @@ func PutObject(r *http.Request, key string, val interface{}) error {
|
||||
return err
|
||||
}
|
||||
|
||||
return PutBytes(r, key, b)
|
||||
return s.PutBytes(w, key, b)
|
||||
}
|
||||
|
||||
// PopObject removes the data for a given session key and reads it into a custom
|
||||
@ -465,8 +523,8 @@ func PutObject(r *http.Request, key string, val interface{}) error {
|
||||
// by dst will remain unchanged if the key does not exist.
|
||||
//
|
||||
// The dst parameter must be a pointer.
|
||||
func PopObject(r *http.Request, key string, dst interface{}) error {
|
||||
b, err := PopBytes(r, key)
|
||||
func (s *Session) PopObject(w http.ResponseWriter, key string, dst interface{}) error {
|
||||
b, err := s.PopBytes(w, key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -480,143 +538,216 @@ func PopObject(r *http.Request, key string, dst interface{}) error {
|
||||
// Keys returns a slice of all key names present in the session data, sorted
|
||||
// alphabetically. If the session contains no data then an empty slice will be
|
||||
// returned.
|
||||
func Keys(r *http.Request) ([]string, error) {
|
||||
s, err := sessionFromContext(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
func (s *Session) Keys() ([]string, error) {
|
||||
if s.loadErr != nil {
|
||||
return nil, s.loadErr
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
keys := make([]string, len(s.data))
|
||||
i := 0
|
||||
for k := range s.data {
|
||||
keys[i] = k
|
||||
i++
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
sort.Strings(keys)
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// Exists returns true if the given key is present in the session data.
|
||||
func Exists(r *http.Request, key string) (bool, error) {
|
||||
s, err := sessionFromContext(r)
|
||||
if err != nil {
|
||||
return false, err
|
||||
func (s *Session) Exists(key string) (bool, error) {
|
||||
if s.loadErr != nil {
|
||||
return false, s.loadErr
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
_, exists := s.data[key]
|
||||
s.mu.Unlock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
_, exists := s.data[key]
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
// Remove deletes the given key and corresponding value from the session data.
|
||||
// If the key is not present this operation is a no-op.
|
||||
func Remove(r *http.Request, key string) error {
|
||||
s, err := sessionFromContext(r)
|
||||
if err != nil {
|
||||
return err
|
||||
func (s *Session) Remove(w http.ResponseWriter, key string) error {
|
||||
if s.loadErr != nil {
|
||||
return s.loadErr
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.written == true {
|
||||
return ErrAlreadyWritten
|
||||
}
|
||||
|
||||
_, exists := s.data[key]
|
||||
if exists == false {
|
||||
s.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
delete(s.data, key)
|
||||
s.modified = true
|
||||
return nil
|
||||
s.mu.Unlock()
|
||||
|
||||
return s.write(w)
|
||||
}
|
||||
|
||||
// Clear removes all data for the current session. The session token and lifetime
|
||||
// are unaffected. If there is no data in the current session this operation is
|
||||
// a no-op.
|
||||
func Clear(r *http.Request) error {
|
||||
s, err := sessionFromContext(r)
|
||||
if err != nil {
|
||||
return err
|
||||
func (s *Session) Clear(w http.ResponseWriter) error {
|
||||
if s.loadErr != nil {
|
||||
return s.loadErr
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.written == true {
|
||||
return ErrAlreadyWritten
|
||||
}
|
||||
|
||||
if len(s.data) == 0 {
|
||||
s.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
for key := range s.data {
|
||||
delete(s.data, key)
|
||||
}
|
||||
s.modified = true
|
||||
return nil
|
||||
s.mu.Unlock()
|
||||
|
||||
return s.write(w)
|
||||
}
|
||||
|
||||
func get(r *http.Request, key string) (interface{}, bool, error) {
|
||||
s, err := sessionFromContext(r)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
// RenewToken creates a new session token while retaining the current session
|
||||
// data. The session lifetime is also reset.
|
||||
//
|
||||
// The old session token and accompanying data are deleted from the session store.
|
||||
//
|
||||
// To mitigate the risk of session fixation attacks, it's important that you call
|
||||
// RenewToken before making any changes to privilege levels (e.g. login and
|
||||
// logout operations). See https://www.owasp.org/index.php/Session_fixation for
|
||||
// additional information.
|
||||
func (s *Session) RenewToken(w http.ResponseWriter) error {
|
||||
if s.loadErr != nil {
|
||||
return s.loadErr
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
v, exists := s.data[key]
|
||||
s.mu.Unlock()
|
||||
|
||||
return v, exists, nil
|
||||
err := s.store.Delete(s.token)
|
||||
if err != nil {
|
||||
s.mu.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
func put(r *http.Request, key string, val interface{}) error {
|
||||
s, err := sessionFromContext(r)
|
||||
token, err := generateToken()
|
||||
if err != nil {
|
||||
s.mu.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
s.token = token
|
||||
s.deadline = time.Now().Add(s.opts.lifetime)
|
||||
s.mu.Unlock()
|
||||
|
||||
return s.write(w)
|
||||
}
|
||||
|
||||
// Destroy deletes the current session. The session token and accompanying
|
||||
// data are deleted from the session store, and the client is instructed to
|
||||
// delete the session cookie.
|
||||
//
|
||||
// Any further operations on the session in the same request cycle will result in a
|
||||
// new session being created.
|
||||
//
|
||||
// A new empty session will be created for any client that subsequently tries
|
||||
// to use the destroyed session token.
|
||||
func (s *Session) Destroy(w http.ResponseWriter) error {
|
||||
if s.loadErr != nil {
|
||||
return s.loadErr
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
err := s.store.Delete(s.token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.written == true {
|
||||
return ErrAlreadyWritten
|
||||
s.token = ""
|
||||
for key := range s.data {
|
||||
delete(s.data, key)
|
||||
}
|
||||
s.data[key] = val
|
||||
s.modified = true
|
||||
|
||||
cookie := &http.Cookie{
|
||||
Name: CookieName,
|
||||
Value: "",
|
||||
Path: s.opts.path,
|
||||
Domain: s.opts.domain,
|
||||
Secure: s.opts.secure,
|
||||
HttpOnly: s.opts.httpOnly,
|
||||
Expires: time.Unix(1, 0),
|
||||
MaxAge: -1,
|
||||
}
|
||||
http.SetCookie(w, cookie)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func pop(r *http.Request, key string) (interface{}, bool, error) {
|
||||
s, err := sessionFromContext(r)
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
func (s *Session) get(key string) (interface{}, bool, error) {
|
||||
if s.loadErr != nil {
|
||||
return nil, false, s.loadErr
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.written == true {
|
||||
return "", false, ErrAlreadyWritten
|
||||
v, exists := s.data[key]
|
||||
return v, exists, nil
|
||||
}
|
||||
|
||||
func (s *Session) put(w http.ResponseWriter, key string, val interface{}) error {
|
||||
if s.loadErr != nil {
|
||||
return s.loadErr
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.data[key] = val
|
||||
s.mu.Unlock()
|
||||
|
||||
return s.write(w)
|
||||
}
|
||||
|
||||
func (s *Session) pop(w http.ResponseWriter, key string) (interface{}, bool, error) {
|
||||
if s.loadErr != nil {
|
||||
return nil, false, s.loadErr
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
|
||||
v, exists := s.data[key]
|
||||
if exists == false {
|
||||
s.mu.Unlock()
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
delete(s.data, key)
|
||||
s.modified = true
|
||||
s.mu.Unlock()
|
||||
|
||||
err := s.write(w)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
return v, true, nil
|
||||
}
|
||||
|
||||
func generateToken() (string, error) {
|
||||
b := make([]byte, 32)
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
func gobEncode(v interface{}) ([]byte, error) {
|
||||
buf := new(bytes.Buffer)
|
||||
err := gob.NewEncoder(buf).Encode(v)
|
||||
@ -630,3 +761,28 @@ func gobDecode(b []byte, dst interface{}) error {
|
||||
buf := bytes.NewBuffer(b)
|
||||
return gob.NewDecoder(buf).Decode(dst)
|
||||
}
|
||||
|
||||
func encodeToJSON(data map[string]interface{}, deadline time.Time) ([]byte, error) {
|
||||
return json.Marshal(&struct {
|
||||
Data map[string]interface{} `json:"data"`
|
||||
Deadline int64 `json:"deadline"`
|
||||
}{
|
||||
Data: data,
|
||||
Deadline: deadline.UnixNano(),
|
||||
})
|
||||
}
|
||||
|
||||
func decodeFromJSON(j []byte) (map[string]interface{}, time.Time, error) {
|
||||
aux := struct {
|
||||
Data map[string]interface{} `json:"data"`
|
||||
Deadline int64 `json:"deadline"`
|
||||
}{}
|
||||
|
||||
dec := json.NewDecoder(bytes.NewReader(j))
|
||||
dec.UseNumber()
|
||||
err := dec.Decode(&aux)
|
||||
if err != nil {
|
||||
return nil, time.Time{}, err
|
||||
}
|
||||
return aux.Data, time.Unix(0, aux.Deadline), nil
|
||||
}
|
@ -1,359 +0,0 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestString(t *testing.T) {
|
||||
m := Manage(testEngine)
|
||||
h := m(testServeMux)
|
||||
|
||||
_, body, cookie := testRequest(t, h, "/PutString", "")
|
||||
if body != "OK" {
|
||||
t.Fatalf("got %q: expected %q", body, "OK")
|
||||
}
|
||||
|
||||
_, body, _ = testRequest(t, h, "/GetString", cookie)
|
||||
if body != "lorem ipsum" {
|
||||
t.Fatalf("got %q: expected %q", body, "lorem ipsum")
|
||||
}
|
||||
|
||||
_, body, cookie = testRequest(t, h, "/PopString", cookie)
|
||||
if body != "lorem ipsum" {
|
||||
t.Fatalf("got %q: expected %q", body, "lorem ipsum")
|
||||
}
|
||||
|
||||
_, body, _ = testRequest(t, h, "/GetString", cookie)
|
||||
if body != "" {
|
||||
t.Fatalf("got %q: expected %q", body, "")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBool(t *testing.T) {
|
||||
m := Manage(testEngine)
|
||||
h := m(testServeMux)
|
||||
|
||||
_, body, cookie := testRequest(t, h, "/PutBool", "")
|
||||
if body != "OK" {
|
||||
t.Fatalf("got %q: expected %q", body, "OK")
|
||||
}
|
||||
|
||||
_, body, _ = testRequest(t, h, "/GetBool", cookie)
|
||||
if body != "true" {
|
||||
t.Fatalf("got %q: expected %q", body, "true")
|
||||
}
|
||||
|
||||
_, body, cookie = testRequest(t, h, "/PopBool", cookie)
|
||||
if body != "true" {
|
||||
t.Fatalf("got %q: expected %q", body, "true")
|
||||
}
|
||||
|
||||
_, body, _ = testRequest(t, h, "/GetBool", cookie)
|
||||
if body != "false" {
|
||||
t.Fatalf("got %q: expected %q", body, "false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInt(t *testing.T) {
|
||||
m := Manage(testEngine)
|
||||
h := m(testServeMux)
|
||||
|
||||
_, body, cookie := testRequest(t, h, "/PutInt", "")
|
||||
if body != "OK" {
|
||||
t.Fatalf("got %q: expected %q", body, "OK")
|
||||
}
|
||||
|
||||
_, body, _ = testRequest(t, h, "/GetInt", cookie)
|
||||
if body != "12345" {
|
||||
t.Fatalf("got %q: expected %q", body, "12345")
|
||||
}
|
||||
|
||||
_, body, cookie = testRequest(t, h, "/PopInt", cookie)
|
||||
if body != "12345" {
|
||||
t.Fatalf("got %q: expected %q", body, "12345")
|
||||
}
|
||||
|
||||
_, body, _ = testRequest(t, h, "/GetInt", cookie)
|
||||
if body != "0" {
|
||||
t.Fatalf("got %q: expected %q", body, "0")
|
||||
}
|
||||
|
||||
r := requestWithSession(new(http.Request), &session{data: make(map[string]interface{})})
|
||||
|
||||
_ = PutInt(r, "test_int", 12345)
|
||||
i, _ := GetInt(r, "test_int")
|
||||
if i != 12345 {
|
||||
t.Fatalf("got %d: expected %d", i, 12345)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInt64(t *testing.T) {
|
||||
m := Manage(testEngine)
|
||||
h := m(testServeMux)
|
||||
|
||||
_, body, cookie := testRequest(t, h, "/PutInt64", "")
|
||||
if body != "OK" {
|
||||
t.Fatalf("got %q: expected %q", body, "OK")
|
||||
}
|
||||
|
||||
_, body, _ = testRequest(t, h, "/GetInt64", cookie)
|
||||
if body != "9223372036854775807" {
|
||||
t.Fatalf("got %q: expected %q", body, "9223372036854775807")
|
||||
}
|
||||
|
||||
_, body, cookie = testRequest(t, h, "/PopInt64", cookie)
|
||||
if body != "9223372036854775807" {
|
||||
t.Fatalf("got %q: expected %q", body, "9223372036854775807")
|
||||
}
|
||||
|
||||
_, body, _ = testRequest(t, h, "/GetInt64", cookie)
|
||||
if body != "0" {
|
||||
t.Fatalf("got %q: expected %q", body, "0")
|
||||
}
|
||||
|
||||
r := requestWithSession(new(http.Request), &session{data: make(map[string]interface{})})
|
||||
|
||||
_ = PutInt64(r, "test_int64", 9223372036854775807)
|
||||
i, _ := GetInt64(r, "test_int64")
|
||||
if i != 9223372036854775807 {
|
||||
t.Fatalf("got %d: expected %d", i, 9223372036854775807)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFloat(t *testing.T) {
|
||||
m := Manage(testEngine)
|
||||
h := m(testServeMux)
|
||||
|
||||
_, body, cookie := testRequest(t, h, "/PutFloat", "")
|
||||
if body != "OK" {
|
||||
t.Fatalf("got %q: expected %q", body, "OK")
|
||||
}
|
||||
|
||||
_, body, _ = testRequest(t, h, "/GetFloat", cookie)
|
||||
if body != "12.345" {
|
||||
t.Fatalf("got %q: expected %q", body, "12.345")
|
||||
}
|
||||
|
||||
_, body, cookie = testRequest(t, h, "/PopFloat", cookie)
|
||||
if body != "12.345" {
|
||||
t.Fatalf("got %q: expected %q", body, "12.345")
|
||||
}
|
||||
|
||||
_, body, _ = testRequest(t, h, "/GetFloat", cookie)
|
||||
if body != "0.000" {
|
||||
t.Fatalf("got %q: expected %q", body, "0.000")
|
||||
}
|
||||
|
||||
r := requestWithSession(new(http.Request), &session{data: make(map[string]interface{})})
|
||||
|
||||
_ = PutFloat(r, "test_float", 12.345)
|
||||
i, _ := GetFloat(r, "test_float")
|
||||
if i != 12.345 {
|
||||
t.Fatalf("got %f: expected %f", i, 12.345)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTime(t *testing.T) {
|
||||
m := Manage(testEngine)
|
||||
h := m(testServeMux)
|
||||
|
||||
_, body, cookie := testRequest(t, h, "/PutTime", "")
|
||||
if body != "OK" {
|
||||
t.Fatalf("got %q: expected %q", body, "OK")
|
||||
}
|
||||
|
||||
_, body, _ = testRequest(t, h, "/GetTime", cookie)
|
||||
if body != "10 Nov 09 23:00 UTC" {
|
||||
t.Fatalf("got %q: expected %q", body, "10 Nov 09 23:00 UTC")
|
||||
}
|
||||
|
||||
_, body, cookie = testRequest(t, h, "/PopTime", cookie)
|
||||
if body != "10 Nov 09 23:00 UTC" {
|
||||
t.Fatalf("got %q: expected %q", body, "10 Nov 09 23:00 UTC")
|
||||
}
|
||||
|
||||
_, body, _ = testRequest(t, h, "/GetTime", cookie)
|
||||
if body != "01 Jan 01 00:00 UTC" {
|
||||
t.Fatalf("got %q: expected %q", body, "01 Jan 01 00:00 UTC")
|
||||
}
|
||||
|
||||
r := requestWithSession(new(http.Request), &session{data: make(map[string]interface{})})
|
||||
|
||||
tt := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC)
|
||||
_ = PutTime(r, "test_time", tt)
|
||||
tm, _ := GetTime(r, "test_time")
|
||||
if tm != tt {
|
||||
t.Fatalf("got %v: expected %v", t, tt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBytes(t *testing.T) {
|
||||
m := Manage(testEngine)
|
||||
h := m(testServeMux)
|
||||
|
||||
_, body, cookie := testRequest(t, h, "/PutBytes", "")
|
||||
if body != "OK" {
|
||||
t.Fatalf("got %q: expected %q", body, "OK")
|
||||
}
|
||||
|
||||
_, body, _ = testRequest(t, h, "/GetBytes", cookie)
|
||||
if body != "lorem ipsum" {
|
||||
t.Fatalf("got %q: expected %q", body, "lorem ipsum")
|
||||
}
|
||||
|
||||
_, body, cookie = testRequest(t, h, "/PopBytes", cookie)
|
||||
if body != "lorem ipsum" {
|
||||
t.Fatalf("got %q: expected %q", body, "lorem ipsum")
|
||||
}
|
||||
|
||||
_, body, _ = testRequest(t, h, "/GetBytes", cookie)
|
||||
if body != "" {
|
||||
t.Fatalf("got %q: expected %q", body, "")
|
||||
}
|
||||
|
||||
r := requestWithSession(new(http.Request), &session{data: make(map[string]interface{})})
|
||||
|
||||
_ = PutBytes(r, "test_bytes", []byte("lorem ipsum"))
|
||||
b, _ := GetBytes(r, "test_bytes")
|
||||
if bytes.Equal(b, []byte("lorem ipsum")) == false {
|
||||
t.Fatalf("got %v: expected %v", b, []byte("lorem ipsum"))
|
||||
}
|
||||
|
||||
err := PutBytes(r, "test_bytes", nil)
|
||||
if err == nil {
|
||||
t.Fatalf("expected an error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestObject(t *testing.T) {
|
||||
m := Manage(testEngine)
|
||||
h := m(testServeMux)
|
||||
|
||||
_, body, cookie := testRequest(t, h, "/PutObject", "")
|
||||
if body != "OK" {
|
||||
t.Fatalf("got %q: expected %q", body, "OK")
|
||||
}
|
||||
|
||||
_, body, _ = testRequest(t, h, "/GetObject", cookie)
|
||||
if body != "alice: 21" {
|
||||
t.Fatalf("got %q: expected %q", body, "alice: 21")
|
||||
}
|
||||
|
||||
_, body, cookie = testRequest(t, h, "/PopObject", cookie)
|
||||
if body != "alice: 21" {
|
||||
t.Fatalf("got %q: expected %q", body, "alice: 21")
|
||||
}
|
||||
|
||||
_, body, _ = testRequest(t, h, "/GetObject", cookie)
|
||||
if body != ": 0" {
|
||||
t.Fatalf("got %q: expected %q", body, ": 0")
|
||||
}
|
||||
|
||||
r := requestWithSession(new(http.Request), &session{data: make(map[string]interface{})})
|
||||
|
||||
u := &testUser{"bob", 65}
|
||||
_ = PutObject(r, "test_object", u)
|
||||
o := &testUser{}
|
||||
_ = GetObject(r, "test_object", o)
|
||||
if reflect.DeepEqual(u, o) == false {
|
||||
t.Fatalf("got %v: expected %v", reflect.DeepEqual(u, o), false)
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeys(t *testing.T) {
|
||||
m := Manage(testEngine)
|
||||
h := m(testServeMux)
|
||||
|
||||
_, _, cookie := testRequest(t, h, "/PutString", "")
|
||||
_, _, _ = testRequest(t, h, "/PutBool", cookie)
|
||||
|
||||
_, body, _ := testRequest(t, h, "/Keys", cookie)
|
||||
if body != "[test_bool test_string]" {
|
||||
t.Fatalf("got %q: expected %q", body, "[test_bool test_string]")
|
||||
}
|
||||
|
||||
_, _, _ = testRequest(t, h, "/Clear", cookie)
|
||||
_, body, _ = testRequest(t, h, "/Keys", cookie)
|
||||
if body != "[]" {
|
||||
t.Fatalf("got %q: expected %q", body, "[test_bool test_string]")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExists(t *testing.T) {
|
||||
m := Manage(testEngine)
|
||||
h := m(testServeMux)
|
||||
|
||||
_, _, cookie := testRequest(t, h, "/PutString", "")
|
||||
|
||||
_, body, _ := testRequest(t, h, "/Exists", cookie)
|
||||
if body != "true" {
|
||||
t.Fatalf("got %q: expected %q", body, "true")
|
||||
}
|
||||
|
||||
_, body, _ = testRequest(t, h, "/PopString", cookie)
|
||||
|
||||
_, body, _ = testRequest(t, h, "/Exists", cookie)
|
||||
if body != "false" {
|
||||
t.Fatalf("got %q: expected %q", body, "false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemove(t *testing.T) {
|
||||
m := Manage(testEngine)
|
||||
h := m(testServeMux)
|
||||
|
||||
_, _, cookie := testRequest(t, h, "/PutString", "")
|
||||
_, _, cookie = testRequest(t, h, "/PutBool", cookie)
|
||||
|
||||
_, body, cookie := testRequest(t, h, "/RemoveString", cookie)
|
||||
if body != "OK" {
|
||||
t.Fatalf("got %q: expected %q", body, "OK")
|
||||
}
|
||||
|
||||
_, body, _ = testRequest(t, h, "/GetString", cookie)
|
||||
if body != "" {
|
||||
t.Fatalf("got %q: expected %q", body, "")
|
||||
}
|
||||
|
||||
_, body, _ = testRequest(t, h, "/GetBool", cookie)
|
||||
if body != "true" {
|
||||
t.Fatalf("got %q: expected %q", body, "true")
|
||||
}
|
||||
|
||||
_, _, cookie = testRequest(t, h, "/RemoveString", cookie)
|
||||
if cookie != "" {
|
||||
t.Fatalf("got %q: expected %q", cookie, "")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClear(t *testing.T) {
|
||||
m := Manage(testEngine)
|
||||
h := m(testServeMux)
|
||||
|
||||
_, _, cookie := testRequest(t, h, "/PutString", "")
|
||||
_, _, cookie = testRequest(t, h, "/PutBool", cookie)
|
||||
|
||||
_, body, cookie := testRequest(t, h, "/Clear", cookie)
|
||||
if body != "OK" {
|
||||
t.Fatalf("got %q: expected %q", body, "OK")
|
||||
}
|
||||
|
||||
_, body, _ = testRequest(t, h, "/GetString", cookie)
|
||||
if body != "" {
|
||||
t.Fatalf("got %q: expected %q", body, "")
|
||||
}
|
||||
|
||||
_, body, _ = testRequest(t, h, "/GetBool", cookie)
|
||||
if body != "false" {
|
||||
t.Fatalf("got %q: expected %q", body, "false")
|
||||
}
|
||||
|
||||
_, _, cookie = testRequest(t, h, "/Clear", cookie)
|
||||
if cookie != "" {
|
||||
t.Fatalf("got %q: expected %q", cookie, "")
|
||||
}
|
||||
}
|
@ -1,115 +0,0 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// Deprecated: Middleware previously defined the signature for the session management
|
||||
// middleware returned by Manage. Manage now returns a func(h http.Handler) http.Handler
|
||||
// directly instead, so it's easier to use with middleware chaining packages like Alice.
|
||||
type Middleware func(h http.Handler) http.Handler
|
||||
|
||||
/*
|
||||
Manage returns a new session manager middleware instance. The first parameter
|
||||
should be a valid storage engine, followed by zero or more functional options.
|
||||
|
||||
For example:
|
||||
|
||||
session.Manage(memstore.New(0))
|
||||
|
||||
session.Manage(memstore.New(0), session.Lifetime(14*24*time.Hour))
|
||||
|
||||
session.Manage(memstore.New(0),
|
||||
session.Secure(true),
|
||||
session.Persist(true),
|
||||
session.Lifetime(14*24*time.Hour),
|
||||
)
|
||||
|
||||
The returned session manager can be used to wrap any http.Handler. It automatically
|
||||
loads sessions based on the HTTP request and saves session data as and when necessary.
|
||||
*/
|
||||
func Manage(engine Engine, opts ...Option) func(h http.Handler) http.Handler {
|
||||
return func(h http.Handler) http.Handler {
|
||||
do := *defaultOptions
|
||||
|
||||
m := &manager{
|
||||
h: h,
|
||||
engine: engine,
|
||||
opts: &do,
|
||||
}
|
||||
|
||||
for _, option := range opts {
|
||||
option(m.opts)
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
}
|
||||
|
||||
type manager struct {
|
||||
h http.Handler
|
||||
engine Engine
|
||||
opts *options
|
||||
}
|
||||
|
||||
func (m *manager) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
sr, err := load(r, m.engine, m.opts)
|
||||
if err != nil {
|
||||
m.opts.errorFunc(w, r, err)
|
||||
return
|
||||
}
|
||||
bw := &bufferedResponseWriter{ResponseWriter: w}
|
||||
m.h.ServeHTTP(bw, sr)
|
||||
|
||||
err = write(w, sr)
|
||||
if err != nil {
|
||||
m.opts.errorFunc(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
if bw.code != 0 {
|
||||
w.WriteHeader(bw.code)
|
||||
}
|
||||
|
||||
_, err = w.Write(bw.buf.Bytes())
|
||||
if err != nil {
|
||||
m.opts.errorFunc(w, r, err)
|
||||
}
|
||||
}
|
||||
|
||||
type bufferedResponseWriter struct {
|
||||
http.ResponseWriter
|
||||
buf bytes.Buffer
|
||||
code int
|
||||
}
|
||||
|
||||
func (bw *bufferedResponseWriter) Write(b []byte) (int, error) {
|
||||
return bw.buf.Write(b)
|
||||
}
|
||||
|
||||
func (bw *bufferedResponseWriter) WriteHeader(code int) {
|
||||
bw.code = code
|
||||
}
|
||||
|
||||
func (bw *bufferedResponseWriter) Flush() {
|
||||
f, ok := bw.ResponseWriter.(http.Flusher)
|
||||
if ok == true {
|
||||
bw.ResponseWriter.Write(bw.buf.Bytes())
|
||||
f.Flush()
|
||||
bw.buf.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
func (bw *bufferedResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
hj := bw.ResponseWriter.(http.Hijacker)
|
||||
return hj.Hijack()
|
||||
}
|
||||
|
||||
func defaultErrorFunc(w http.ResponseWriter, r *http.Request, err error) {
|
||||
log.Output(2, err.Error())
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
}
|
@ -1,61 +0,0 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestWriteResponse(t *testing.T) {
|
||||
m := Manage(testEngine)
|
||||
h := m(testServeMux)
|
||||
|
||||
code, _, _ := testRequest(t, h, "/WriteHeader", "")
|
||||
if code != 418 {
|
||||
t.Fatalf("got %d: expected %d", code, 418)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerOptionsLeak(t *testing.T) {
|
||||
_ = Manage(testEngine, Domain("example.org"))
|
||||
|
||||
m := Manage(testEngine)
|
||||
h := m(testServeMux)
|
||||
_, _, cookie := testRequest(t, h, "/PutString", "")
|
||||
if strings.Contains(cookie, "example.org") == true {
|
||||
t.Fatalf("got %q: expected to not contain %q", cookie, "example.org")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlusher(t *testing.T) {
|
||||
e := testEngine
|
||||
m := Manage(e)
|
||||
h := m(testServeMux)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
r, err := http.NewRequest("GET", "/Flush", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
h.ServeHTTP(rr, r)
|
||||
|
||||
body := string(rr.Body.Bytes())
|
||||
cookie := rr.Header().Get("Set-Cookie")
|
||||
token := extractTokenFromCookie(cookie)
|
||||
|
||||
if body != "This is some…flushed data" {
|
||||
t.Fatalf("got %q: expected %q", body, "This is some…flushed data")
|
||||
}
|
||||
if len(rr.Header()["Set-Cookie"]) != 1 {
|
||||
t.Fatalf("got %d: expected %d", len(rr.Header()["Set-Cookie"]), 1)
|
||||
}
|
||||
if strings.HasPrefix(cookie, fmt.Sprintf("%s=", CookieName)) == false {
|
||||
t.Fatalf("got %q: expected prefix %q", cookie, fmt.Sprintf("%s=", CookieName))
|
||||
}
|
||||
_, found, _ := e.Find(token)
|
||||
if found != true {
|
||||
t.Fatalf("got %v: expected %v", found, true)
|
||||
}
|
||||
}
|
@ -1,32 +0,0 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
func NewMockRequest(r *http.Request) *http.Request {
|
||||
do := *defaultOptions
|
||||
s := &session{
|
||||
token: "",
|
||||
data: make(map[string]interface{}),
|
||||
deadline: time.Now().Add(do.lifetime),
|
||||
engine: &mockEngine{},
|
||||
opts: &do,
|
||||
}
|
||||
return requestWithSession(r, s)
|
||||
}
|
||||
|
||||
type mockEngine struct{}
|
||||
|
||||
func (me *mockEngine) Find(token string) (b []byte, exists bool, err error) {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
func (me *mockEngine) Save(token string, b []byte, expiry time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (me *mockEngine) Delete(token string) error {
|
||||
return nil
|
||||
}
|
@ -1,123 +0,0 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ContextName changes the value of the (string) key used to store the session
|
||||
// information in Request.Context. You should only need to change this if there is
|
||||
// a naming clash.
|
||||
var ContextName = "scs.session"
|
||||
|
||||
// CookieName changes the name of the session cookie issued to clients. Note that
|
||||
// cookie names should not contain whitespace, commas, semicolons, backslashes
|
||||
// or control characters as per RFC6265.
|
||||
var CookieName = "scs.session.token"
|
||||
|
||||
var defaultOptions = &options{
|
||||
domain: "",
|
||||
errorFunc: defaultErrorFunc,
|
||||
httpOnly: true,
|
||||
idleTimeout: 0,
|
||||
lifetime: 24 * time.Hour,
|
||||
path: "/",
|
||||
persist: false,
|
||||
secure: false,
|
||||
}
|
||||
|
||||
type options struct {
|
||||
domain string
|
||||
errorFunc func(http.ResponseWriter, *http.Request, error)
|
||||
httpOnly bool
|
||||
idleTimeout time.Duration
|
||||
lifetime time.Duration
|
||||
path string
|
||||
persist bool
|
||||
secure bool
|
||||
}
|
||||
|
||||
// Option defines the functional arguments for configuring the session manager.
|
||||
type Option func(*options)
|
||||
|
||||
// Domain sets the 'Domain' attribute on the session cookie. By default it will
|
||||
// be set to the domain name that the cookie was issued from.
|
||||
func Domain(s string) Option {
|
||||
return func(opts *options) {
|
||||
opts.domain = s
|
||||
}
|
||||
}
|
||||
|
||||
// ErrorFunc allows you to control behavior when an error is encountered loading
|
||||
// or writing a session. The default behavior is for a HTTP 500 status code to
|
||||
// be written to the ResponseWriter along with the plain-text error string. If
|
||||
// a custom error function is set, then control will be passed to this instead.
|
||||
// A typical use would be to provide a function which logs the error and returns
|
||||
// a customized HTML error page.
|
||||
func ErrorFunc(f func(http.ResponseWriter, *http.Request, error)) Option {
|
||||
return func(opts *options) {
|
||||
opts.errorFunc = f
|
||||
}
|
||||
}
|
||||
|
||||
// HttpOnly sets the 'HttpOnly' attribute on the session cookie. The default value
|
||||
// is true.
|
||||
func HttpOnly(b bool) Option {
|
||||
return func(opts *options) {
|
||||
opts.httpOnly = b
|
||||
}
|
||||
}
|
||||
|
||||
// IdleTimeout sets the maximum length of time a session can be inactive before it
|
||||
// expires. For example, some applications may wish to set this so there is a timeout after
|
||||
// 20 minutes of inactivity. Any client request which includes the
|
||||
// session cookie and is handled by the session middleware is classed as activity.s
|
||||
//
|
||||
// By default IdleTimeout is not set and there is no inactivity timeout.
|
||||
func IdleTimeout(t time.Duration) Option {
|
||||
return func(opts *options) {
|
||||
opts.idleTimeout = t
|
||||
}
|
||||
}
|
||||
|
||||
// Lifetime sets the maximum length of time that a session is valid for before
|
||||
// it expires. The lifetime is an 'absolute expiry' which is set when the session
|
||||
// is first created and does not change.
|
||||
//
|
||||
// The default value is 24 hours.
|
||||
func Lifetime(t time.Duration) Option {
|
||||
return func(opts *options) {
|
||||
opts.lifetime = t
|
||||
}
|
||||
}
|
||||
|
||||
// Path sets the 'Path' attribute on the session cookie. The default value is "/".
|
||||
// Passing the empty string "" will result in it being set to the path that the
|
||||
// cookie was issued from.
|
||||
func Path(s string) Option {
|
||||
return func(opts *options) {
|
||||
opts.path = s
|
||||
}
|
||||
}
|
||||
|
||||
// Persist sets whether the session cookie should be persistent or not (i.e. whether
|
||||
// it should be retained after a user closes their browser).
|
||||
//
|
||||
// The default value is false, which means that the session cookie will be destroyed
|
||||
// when the user closes their browser. If set to true, explicit 'Expires' and
|
||||
// 'MaxAge' values will be added to the cookie and it will be retained by the
|
||||
// user's browser until the given expiry time is reached.
|
||||
func Persist(b bool) Option {
|
||||
return func(opts *options) {
|
||||
opts.persist = b
|
||||
}
|
||||
}
|
||||
|
||||
// Secure sets the 'Secure' attribute on the session cookie. The default value
|
||||
// is false. It's recommended that you set this to true and serve all requests
|
||||
// over HTTPS in production environments.
|
||||
func Secure(b bool) Option {
|
||||
return func(opts *options) {
|
||||
opts.secure = b
|
||||
}
|
||||
}
|
@ -1,208 +0,0 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestCookieOptions(t *testing.T) {
|
||||
m := Manage(testEngine)
|
||||
h := m(testServeMux)
|
||||
|
||||
_, _, cookie := testRequest(t, h, "/PutString", "")
|
||||
if strings.Contains(cookie, "Path=/") == false {
|
||||
t.Fatalf("got %q: expected to contain %q", cookie, "Path=/")
|
||||
}
|
||||
if strings.Contains(cookie, "Domain=") == true {
|
||||
t.Fatalf("got %q: expected to not contain %q", cookie, "Domain=")
|
||||
}
|
||||
if strings.Contains(cookie, "Secure") == true {
|
||||
t.Fatalf("got %q: expected to not contain %q", cookie, "Secure")
|
||||
}
|
||||
if strings.Contains(cookie, "HttpOnly") == false {
|
||||
t.Fatalf("got %q: expected to contain %q", cookie, "HttpOnly")
|
||||
}
|
||||
|
||||
m = Manage(testEngine, Path("/foo"), Domain("example.org"), Secure(true), HttpOnly(false), Lifetime(time.Hour), Persist(true))
|
||||
h = m(testServeMux)
|
||||
|
||||
_, _, cookie = testRequest(t, h, "/PutString", "")
|
||||
if strings.Contains(cookie, "Path=/foo") == false {
|
||||
t.Fatalf("got %q: expected to contain %q", cookie, "Path=/foo")
|
||||
}
|
||||
if strings.Contains(cookie, "Domain=example.org") == false {
|
||||
t.Fatalf("got %q: expected to contain %q", cookie, "Domain=example.org")
|
||||
}
|
||||
if strings.Contains(cookie, "Secure") == false {
|
||||
t.Fatalf("got %q: expected to contain %q", cookie, "Secure")
|
||||
}
|
||||
if strings.Contains(cookie, "HttpOnly") == true {
|
||||
t.Fatalf("got %q: expected to not contain %q", cookie, "HttpOnly")
|
||||
}
|
||||
if strings.Contains(cookie, "Max-Age=3600") == false {
|
||||
t.Fatalf("got %q: expected to contain %q:", cookie, "Max-Age=86400")
|
||||
}
|
||||
if strings.Contains(cookie, "Expires=") == false {
|
||||
t.Fatalf("got %q: expected to contain %q:", cookie, "Expires")
|
||||
}
|
||||
|
||||
m = Manage(testEngine, Lifetime(time.Hour))
|
||||
h = m(testServeMux)
|
||||
|
||||
_, _, cookie = testRequest(t, h, "/PutString", "")
|
||||
if strings.Contains(cookie, "Max-Age=") == true {
|
||||
t.Fatalf("got %q: expected not to contain %q:", cookie, "Max-Age=")
|
||||
}
|
||||
if strings.Contains(cookie, "Expires=") == true {
|
||||
t.Fatalf("got %q: expected not to contain %q:", cookie, "Expires")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLifetime(t *testing.T) {
|
||||
m := Manage(testEngine, Lifetime(200*time.Millisecond))
|
||||
h := m(testServeMux)
|
||||
|
||||
_, _, cookie := testRequest(t, h, "/PutString", "")
|
||||
oldToken := extractTokenFromCookie(cookie)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
_, _, cookie = testRequest(t, h, "/PutString", cookie)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
_, body, _ := testRequest(t, h, "/GetString", cookie)
|
||||
if body != "" {
|
||||
t.Fatalf("got %q: expected %q", body, "")
|
||||
}
|
||||
_, _, cookie = testRequest(t, h, "/PutString", cookie)
|
||||
newToken := extractTokenFromCookie(cookie)
|
||||
if newToken == oldToken {
|
||||
t.Fatalf("expected a difference")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIdleTimeout(t *testing.T) {
|
||||
m := Manage(testEngine, IdleTimeout(100*time.Millisecond), Lifetime(500*time.Millisecond))
|
||||
h := m(testServeMux)
|
||||
|
||||
_, _, cookie := testRequest(t, h, "/PutString", "")
|
||||
oldToken := extractTokenFromCookie(cookie)
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
_, body, _ := testRequest(t, h, "/GetString", cookie)
|
||||
if body != "" {
|
||||
t.Fatalf("got %q: expected %q", body, "")
|
||||
}
|
||||
_, _, cookie = testRequest(t, h, "/PutString", cookie)
|
||||
newToken := extractTokenFromCookie(cookie)
|
||||
if newToken == oldToken {
|
||||
t.Fatalf("expected a difference")
|
||||
}
|
||||
|
||||
_, _, cookie = testRequest(t, h, "/PutString", "")
|
||||
oldToken = extractTokenFromCookie(cookie)
|
||||
time.Sleep(75 * time.Millisecond)
|
||||
|
||||
_, _, cookie = testRequest(t, h, "/GetString", cookie)
|
||||
time.Sleep(75 * time.Millisecond)
|
||||
|
||||
_, body, cookie = testRequest(t, h, "/GetString", cookie)
|
||||
if body != "lorem ipsum" {
|
||||
t.Fatalf("got %q: expected %q", body, "lorem ipsum")
|
||||
}
|
||||
_, _, cookie = testRequest(t, h, "/PutString", cookie)
|
||||
newToken = extractTokenFromCookie(cookie)
|
||||
if newToken != oldToken {
|
||||
t.Fatalf("expected the same")
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorFunc(t *testing.T) {
|
||||
m := Manage(testEngine)
|
||||
man, ok := m(nil).(*manager)
|
||||
if ok == false {
|
||||
t.Fatal("type assertion failed")
|
||||
}
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
man.opts.errorFunc(rr, new(http.Request), errors.New("testing error log..."))
|
||||
if rr.Code != http.StatusInternalServerError {
|
||||
t.Fatalf("got %d: expected %d", rr.Code, http.StatusInternalServerError)
|
||||
}
|
||||
if string(rr.Body.Bytes()) != fmt.Sprintf("%s\n", http.StatusText(http.StatusInternalServerError)) {
|
||||
t.Fatalf("got %q: expected %q", string(rr.Body.Bytes()), fmt.Sprintf("%s\n", http.StatusText(http.StatusInternalServerError)))
|
||||
}
|
||||
|
||||
m = Manage(testEngine, ErrorFunc(func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
w.WriteHeader(418)
|
||||
io.WriteString(w, http.StatusText(418))
|
||||
}))
|
||||
man, ok = m(nil).(*manager)
|
||||
if ok == false {
|
||||
t.Fatal("type assertion failed")
|
||||
}
|
||||
|
||||
rr = httptest.NewRecorder()
|
||||
man.opts.errorFunc(rr, new(http.Request), errors.New("testing error log..."))
|
||||
if rr.Code != 418 {
|
||||
t.Fatalf("got %d: expected %d", rr.Code, 418)
|
||||
}
|
||||
if string(rr.Body.Bytes()) != http.StatusText(418) {
|
||||
t.Fatalf("got %q: expected %q", string(rr.Body.Bytes()), http.StatusText(418))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPersist(t *testing.T) {
|
||||
m := Manage(testEngine, IdleTimeout(5*time.Minute), Persist(true))
|
||||
h := m(testServeMux)
|
||||
|
||||
_, _, cookie := testRequest(t, h, "/PutString", "")
|
||||
if strings.Contains(cookie, "Max-Age=300") == false {
|
||||
t.Fatalf("got %q: expected to contain %q:", cookie, "Max-Age=300")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCookieName(t *testing.T) {
|
||||
oldCookieName := CookieName
|
||||
CookieName = "custom_cookie_name"
|
||||
|
||||
m := Manage(testEngine)
|
||||
h := m(testServeMux)
|
||||
|
||||
_, _, cookie := testRequest(t, h, "/PutString", "")
|
||||
if strings.HasPrefix(cookie, "custom_cookie_name=") == false {
|
||||
t.Fatalf("got %q: expected prefix %q", cookie, "custom_cookie_name=")
|
||||
}
|
||||
|
||||
_, body, _ := testRequest(t, h, "/GetString", cookie)
|
||||
if body != "lorem ipsum" {
|
||||
t.Fatalf("got %q: expected %q", body, "lorem ipsum")
|
||||
}
|
||||
|
||||
CookieName = oldCookieName
|
||||
}
|
||||
|
||||
func TestContextDataName(t *testing.T) {
|
||||
oldContextName := ContextName
|
||||
ContextName = "custom_context_name"
|
||||
|
||||
m := Manage(testEngine)
|
||||
h := m(testServeMux)
|
||||
|
||||
_, _, cookie := testRequest(t, h, "/PutString", "")
|
||||
_, body, _ := testRequest(t, h, "/DumpContext", cookie)
|
||||
if strings.Contains(body, "custom_context_name") == false {
|
||||
t.Fatalf("got %q: expected to contain %q", body, "custom_context_name")
|
||||
}
|
||||
_, body, _ = testRequest(t, h, "/GetString", cookie)
|
||||
if body != "lorem ipsum" {
|
||||
t.Fatalf("got %q: expected %q", body, "lorem ipsum")
|
||||
}
|
||||
|
||||
ContextName = oldContextName
|
||||
}
|
@ -1,456 +0,0 @@
|
||||
/*
|
||||
Package session provides session management middleware and helpers for
|
||||
manipulating session data.
|
||||
|
||||
It should be installed alongside one of the storage engines from https://godoc.org/github.com/alexedwards/scs/engine.
|
||||
|
||||
For example:
|
||||
|
||||
$ go get github.com/alexedwards/scs/session
|
||||
$ go get github.com/alexedwards/scs/engine/memstore
|
||||
|
||||
Basic use:
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/alexedwards/scs/engine/memstore"
|
||||
"github.com/alexedwards/scs/session"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Initialise a new storage engine. Here we use the memstore package, but the principles
|
||||
// are the same no matter which back-end store you choose.
|
||||
engine := memstore.New(0)
|
||||
|
||||
// Initialise the session manager middleware, passing in the storage engine as
|
||||
// the first parameter. This middleware will automatically handle loading and
|
||||
// saving of session data for you.
|
||||
sessionManager := session.Manage(engine)
|
||||
|
||||
// Set up your HTTP handlers in the normal way.
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/put", putHandler)
|
||||
mux.HandleFunc("/get", getHandler)
|
||||
|
||||
// Wrap your handlers with the session manager middleware.
|
||||
http.ListenAndServe(":4000", sessionManager(mux))
|
||||
}
|
||||
|
||||
func putHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// Use the PutString helper to store a new key and associated string value in
|
||||
// the session data. Helpers are also available for bool, int, int64, float,
|
||||
// time.Time and []byte data types.
|
||||
err := session.PutString(r, "message", "Hello from a session!")
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), 500)
|
||||
}
|
||||
}
|
||||
|
||||
func getHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// Use the GetString helper to retrieve the string value associated with a key.
|
||||
msg, err := session.GetString(r, "message")
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), 500)
|
||||
return
|
||||
}
|
||||
io.WriteString(w, msg)
|
||||
}
|
||||
*/
|
||||
package session
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ErrAlreadyWritten is returned when an attempt is made to modify the session
|
||||
// data after it has already been sent to the storage engine and client.
|
||||
var ErrAlreadyWritten = errors.New("session already written to the engine and http.ResponseWriter")
|
||||
|
||||
type session struct {
|
||||
token string
|
||||
data map[string]interface{}
|
||||
deadline time.Time
|
||||
engine Engine
|
||||
opts *options
|
||||
modified bool
|
||||
written bool
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func generateToken() (string, error) {
|
||||
b := make([]byte, 32)
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
func newSession(r *http.Request, engine Engine, opts *options) (*http.Request, error) {
|
||||
token, err := generateToken()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s := &session{
|
||||
token: token,
|
||||
data: make(map[string]interface{}),
|
||||
deadline: time.Now().Add(opts.lifetime),
|
||||
engine: engine,
|
||||
opts: opts,
|
||||
}
|
||||
return requestWithSession(r, s), nil
|
||||
}
|
||||
|
||||
func load(r *http.Request, engine Engine, opts *options) (*http.Request, error) {
|
||||
cookie, err := r.Cookie(CookieName)
|
||||
if err == http.ErrNoCookie {
|
||||
return newSession(r, engine, opts)
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if cookie.Value == "" {
|
||||
return newSession(r, engine, opts)
|
||||
}
|
||||
token := cookie.Value
|
||||
|
||||
j, found, err := engine.Find(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if found == false {
|
||||
return newSession(r, engine, opts)
|
||||
}
|
||||
|
||||
data, deadline, err := decodeFromJSON(j)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s := &session{
|
||||
token: token,
|
||||
data: data,
|
||||
deadline: deadline,
|
||||
engine: engine,
|
||||
opts: opts,
|
||||
}
|
||||
|
||||
return requestWithSession(r, s), nil
|
||||
}
|
||||
|
||||
func write(w http.ResponseWriter, r *http.Request) error {
|
||||
s, err := sessionFromContext(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.written == true {
|
||||
return nil
|
||||
}
|
||||
|
||||
if s.modified == false && s.opts.idleTimeout == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
j, err := encodeToJSON(s.data, s.deadline)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
expiry := s.deadline
|
||||
if s.opts.idleTimeout > 0 {
|
||||
ie := time.Now().Add(s.opts.idleTimeout)
|
||||
if ie.Before(expiry) {
|
||||
expiry = ie
|
||||
}
|
||||
}
|
||||
|
||||
if ce, ok := s.engine.(cookieEngine); ok {
|
||||
s.token, err = ce.MakeToken(j, expiry)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
err = s.engine.Save(s.token, j, expiry)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
cookie := &http.Cookie{
|
||||
Name: CookieName,
|
||||
Value: s.token,
|
||||
Path: s.opts.path,
|
||||
Domain: s.opts.domain,
|
||||
Secure: s.opts.secure,
|
||||
HttpOnly: s.opts.httpOnly,
|
||||
}
|
||||
if s.opts.persist == true {
|
||||
cookie.Expires = expiry
|
||||
// The addition of 0.5 means MaxAge is correctly rounded to the nearest
|
||||
// second instead of being floored.
|
||||
cookie.MaxAge = int(expiry.Sub(time.Now()).Seconds() + 0.5)
|
||||
}
|
||||
http.SetCookie(w, cookie)
|
||||
s.written = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
/*
|
||||
RegenerateToken creates a new session token while retaining the current session
|
||||
data. The session lifetime is also reset.
|
||||
|
||||
The old session token and accompanying data are deleted from the storage engine.
|
||||
|
||||
To mitigate the risk of session fixation attacks, it's important that you call
|
||||
RegenerateToken before making any changes to privilege levels (e.g. login and
|
||||
logout operations). See https://www.owasp.org/index.php/Session_fixation for
|
||||
additional information.
|
||||
|
||||
Usage:
|
||||
|
||||
func loginHandler(w http.ResponseWriter, r *http.Request) {
|
||||
userID := 123
|
||||
|
||||
// First regenerate the session token…
|
||||
err := session.RegenerateToken(r)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), 500)
|
||||
return
|
||||
}
|
||||
|
||||
// Then make the privilege-level change.
|
||||
err = session.PutInt(r, "userID", userID)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), 500)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
*/
|
||||
func RegenerateToken(r *http.Request) error {
|
||||
s, err := sessionFromContext(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.written == true {
|
||||
return ErrAlreadyWritten
|
||||
}
|
||||
|
||||
err = s.engine.Delete(s.token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
token, err := generateToken()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.token = token
|
||||
s.deadline = time.Now().Add(s.opts.lifetime)
|
||||
s.modified = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Renew creates a new session token and removes all data for the session. The
|
||||
// session lifetime is also reset.
|
||||
//
|
||||
// The old session token and accompanying data are deleted from the storage engine.
|
||||
//
|
||||
// The Renew function is essentially a concurrency-safe amalgamation of the
|
||||
// RegenerateToken and Clear functions.
|
||||
func Renew(r *http.Request) error {
|
||||
s, err := sessionFromContext(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.written == true {
|
||||
return ErrAlreadyWritten
|
||||
}
|
||||
|
||||
err = s.engine.Delete(s.token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
token, err := generateToken()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.token = token
|
||||
for key := range s.data {
|
||||
delete(s.data, key)
|
||||
}
|
||||
s.deadline = time.Now().Add(s.opts.lifetime)
|
||||
s.modified = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Destroy deletes the current session. The session token and accompanying
|
||||
// data are deleted from the storage engine, and the client is instructed to
|
||||
// delete the session cookie.
|
||||
//
|
||||
// Destroy operations are effective immediately, and any further operations on
|
||||
// the session in the same request cycle will return an ErrAlreadyWritten error.
|
||||
// If you see this error you probably want to use the Renew function instead.
|
||||
//
|
||||
// A new empty session will be created for any client that subsequently tries
|
||||
// to use the destroyed session token.
|
||||
func Destroy(w http.ResponseWriter, r *http.Request) error {
|
||||
s, err := sessionFromContext(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.written == true {
|
||||
return ErrAlreadyWritten
|
||||
}
|
||||
|
||||
err = s.engine.Delete(s.token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.token = ""
|
||||
for key := range s.data {
|
||||
delete(s.data, key)
|
||||
}
|
||||
s.modified = true
|
||||
|
||||
cookie := &http.Cookie{
|
||||
Name: CookieName,
|
||||
Value: "",
|
||||
Path: s.opts.path,
|
||||
Domain: s.opts.domain,
|
||||
Secure: s.opts.secure,
|
||||
HttpOnly: s.opts.httpOnly,
|
||||
Expires: time.Unix(1, 0),
|
||||
MaxAge: -1,
|
||||
}
|
||||
http.SetCookie(w, cookie)
|
||||
s.written = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Save immediately writes the session cookie header to the ResponseWriter and
|
||||
// saves the session data to the storage engine, if needed.
|
||||
//
|
||||
// Using Save is not normally necessary. The session middleware (which buffers
|
||||
// all writes to the underlying connection) will automatically handle setting the
|
||||
// cookie header and storing the data for you.
|
||||
//
|
||||
// However there may be instances where you wish to break out of this normal
|
||||
// operation and (one way or another) write to the underlying connection before
|
||||
// control is passed back to the session middleware. In these instances, where
|
||||
// response headers have already been written, the middleware will be too late
|
||||
// to set the cookie header. The solution is to manually call Save before performing
|
||||
// any writes.
|
||||
//
|
||||
// An example is flushing data using the http.Flusher interface:
|
||||
//
|
||||
// func flushingHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// err := session.PutString(r, "foo", "bar")
|
||||
// if err != nil {
|
||||
// http.Error(w, err.Error(), 500)
|
||||
// return
|
||||
// }
|
||||
// err = session.Save(w, r)
|
||||
// if err != nil {
|
||||
// http.Error(w, err.Error(), 500)
|
||||
// return
|
||||
// }
|
||||
//
|
||||
// fw, ok := w.(http.Flusher)
|
||||
// if !ok {
|
||||
// http.Error(w, "could not assert to http.Flusher", 500)
|
||||
// return
|
||||
// }
|
||||
// w.Write([]byte("This is some…"))
|
||||
// fw.Flush()
|
||||
// w.Write([]byte("flushed data"))
|
||||
// }
|
||||
func Save(w http.ResponseWriter, r *http.Request) error {
|
||||
s, err := sessionFromContext(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
wr := s.written
|
||||
s.mu.Unlock()
|
||||
if wr == true {
|
||||
return ErrAlreadyWritten
|
||||
}
|
||||
|
||||
return write(w, r)
|
||||
}
|
||||
|
||||
func sessionFromContext(r *http.Request) (*session, error) {
|
||||
s, ok := r.Context().Value(ContextName).(*session)
|
||||
if ok == false {
|
||||
return nil, errors.New("request.Context does not contain a *session value")
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func requestWithSession(r *http.Request, s *session) *http.Request {
|
||||
ctx := context.WithValue(r.Context(), ContextName, s)
|
||||
return r.WithContext(ctx)
|
||||
}
|
||||
|
||||
func encodeToJSON(data map[string]interface{}, deadline time.Time) ([]byte, error) {
|
||||
return json.Marshal(&struct {
|
||||
Data map[string]interface{} `json:"data"`
|
||||
Deadline int64 `json:"deadline"`
|
||||
}{
|
||||
Data: data,
|
||||
Deadline: deadline.UnixNano(),
|
||||
})
|
||||
}
|
||||
|
||||
func decodeFromJSON(j []byte) (map[string]interface{}, time.Time, error) {
|
||||
aux := struct {
|
||||
Data map[string]interface{} `json:"data"`
|
||||
Deadline int64 `json:"deadline"`
|
||||
}{}
|
||||
|
||||
dec := json.NewDecoder(bytes.NewReader(j))
|
||||
dec.UseNumber()
|
||||
err := dec.Decode(&aux)
|
||||
if err != nil {
|
||||
return nil, time.Time{}, err
|
||||
}
|
||||
return aux.Data, time.Unix(0, aux.Deadline), nil
|
||||
}
|
@ -1,516 +0,0 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"encoding/gob"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/alexedwards/scs/engine/memstore"
|
||||
)
|
||||
|
||||
var testEngine Engine
|
||||
var testServeMux *http.ServeMux
|
||||
|
||||
type testUser struct {
|
||||
Name string
|
||||
Age int
|
||||
}
|
||||
|
||||
func init() {
|
||||
gob.Register(new(testUser))
|
||||
|
||||
testEngine = memstore.New(time.Minute)
|
||||
testServeMux = http.NewServeMux()
|
||||
|
||||
testServeMux.HandleFunc("/PutString", func(w http.ResponseWriter, r *http.Request) {
|
||||
err := PutString(r, "test_string", "lorem ipsum")
|
||||
if err != nil {
|
||||
io.WriteString(w, err.Error())
|
||||
return
|
||||
}
|
||||
io.WriteString(w, "OK")
|
||||
})
|
||||
|
||||
testServeMux.HandleFunc("/GetString", func(w http.ResponseWriter, r *http.Request) {
|
||||
s, err := GetString(r, "test_string")
|
||||
if err != nil {
|
||||
io.WriteString(w, err.Error())
|
||||
return
|
||||
}
|
||||
io.WriteString(w, s)
|
||||
})
|
||||
|
||||
testServeMux.HandleFunc("/PopString", func(w http.ResponseWriter, r *http.Request) {
|
||||
s, err := PopString(r, "test_string")
|
||||
if err != nil {
|
||||
io.WriteString(w, err.Error())
|
||||
return
|
||||
}
|
||||
io.WriteString(w, s)
|
||||
})
|
||||
|
||||
testServeMux.HandleFunc("/PutBool", func(w http.ResponseWriter, r *http.Request) {
|
||||
err := PutBool(r, "test_bool", true)
|
||||
if err != nil {
|
||||
io.WriteString(w, err.Error())
|
||||
return
|
||||
}
|
||||
io.WriteString(w, "OK")
|
||||
})
|
||||
|
||||
testServeMux.HandleFunc("/GetBool", func(w http.ResponseWriter, r *http.Request) {
|
||||
b, err := GetBool(r, "test_bool")
|
||||
if err != nil {
|
||||
io.WriteString(w, err.Error())
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(w, "%v", b)
|
||||
})
|
||||
|
||||
testServeMux.HandleFunc("/PopBool", func(w http.ResponseWriter, r *http.Request) {
|
||||
b, err := PopBool(r, "test_bool")
|
||||
if err != nil {
|
||||
io.WriteString(w, err.Error())
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(w, "%v", b)
|
||||
})
|
||||
|
||||
testServeMux.HandleFunc("/PutInt", func(w http.ResponseWriter, r *http.Request) {
|
||||
err := PutInt(r, "test_int", 12345)
|
||||
if err != nil {
|
||||
io.WriteString(w, err.Error())
|
||||
return
|
||||
}
|
||||
io.WriteString(w, "OK")
|
||||
})
|
||||
|
||||
testServeMux.HandleFunc("/GetInt", func(w http.ResponseWriter, r *http.Request) {
|
||||
i, err := GetInt(r, "test_int")
|
||||
if err != nil {
|
||||
io.WriteString(w, err.Error())
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(w, "%d", i)
|
||||
})
|
||||
|
||||
testServeMux.HandleFunc("/PopInt", func(w http.ResponseWriter, r *http.Request) {
|
||||
i, err := PopInt(r, "test_int")
|
||||
if err != nil {
|
||||
io.WriteString(w, err.Error())
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(w, "%d", i)
|
||||
})
|
||||
|
||||
testServeMux.HandleFunc("/PutInt64", func(w http.ResponseWriter, r *http.Request) {
|
||||
err := PutInt64(r, "test_int", 9223372036854775807)
|
||||
if err != nil {
|
||||
io.WriteString(w, err.Error())
|
||||
return
|
||||
}
|
||||
io.WriteString(w, "OK")
|
||||
})
|
||||
|
||||
testServeMux.HandleFunc("/GetInt64", func(w http.ResponseWriter, r *http.Request) {
|
||||
i, err := GetInt64(r, "test_int")
|
||||
if err != nil {
|
||||
io.WriteString(w, err.Error())
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(w, "%d", i)
|
||||
})
|
||||
|
||||
testServeMux.HandleFunc("/PopInt64", func(w http.ResponseWriter, r *http.Request) {
|
||||
i, err := PopInt64(r, "test_int")
|
||||
if err != nil {
|
||||
io.WriteString(w, err.Error())
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(w, "%d", i)
|
||||
})
|
||||
|
||||
testServeMux.HandleFunc("/PutFloat", func(w http.ResponseWriter, r *http.Request) {
|
||||
err := PutFloat(r, "test_float", 12.345)
|
||||
if err != nil {
|
||||
io.WriteString(w, err.Error())
|
||||
return
|
||||
}
|
||||
io.WriteString(w, "OK")
|
||||
})
|
||||
|
||||
testServeMux.HandleFunc("/GetFloat", func(w http.ResponseWriter, r *http.Request) {
|
||||
f, err := GetFloat(r, "test_float")
|
||||
if err != nil {
|
||||
io.WriteString(w, err.Error())
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(w, "%.3f", f)
|
||||
})
|
||||
|
||||
testServeMux.HandleFunc("/PopFloat", func(w http.ResponseWriter, r *http.Request) {
|
||||
f, err := PopFloat(r, "test_float")
|
||||
if err != nil {
|
||||
io.WriteString(w, err.Error())
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(w, "%.3f", f)
|
||||
})
|
||||
|
||||
testServeMux.HandleFunc("/PutTime", func(w http.ResponseWriter, r *http.Request) {
|
||||
err := PutTime(r, "test_time", time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC))
|
||||
if err != nil {
|
||||
io.WriteString(w, err.Error())
|
||||
return
|
||||
}
|
||||
io.WriteString(w, "OK")
|
||||
})
|
||||
|
||||
testServeMux.HandleFunc("/GetTime", func(w http.ResponseWriter, r *http.Request) {
|
||||
t, err := GetTime(r, "test_time")
|
||||
if err != nil {
|
||||
io.WriteString(w, err.Error())
|
||||
return
|
||||
}
|
||||
io.WriteString(w, t.Format(time.RFC822))
|
||||
})
|
||||
|
||||
testServeMux.HandleFunc("/PopTime", func(w http.ResponseWriter, r *http.Request) {
|
||||
t, err := PopTime(r, "test_time")
|
||||
if err != nil {
|
||||
io.WriteString(w, err.Error())
|
||||
return
|
||||
}
|
||||
io.WriteString(w, t.Format(time.RFC822))
|
||||
})
|
||||
|
||||
testServeMux.HandleFunc("/PutBytes", func(w http.ResponseWriter, r *http.Request) {
|
||||
err := PutBytes(r, "test_bytes", []byte("lorem ipsum"))
|
||||
if err != nil {
|
||||
io.WriteString(w, err.Error())
|
||||
return
|
||||
}
|
||||
io.WriteString(w, "OK")
|
||||
})
|
||||
|
||||
testServeMux.HandleFunc("/GetBytes", func(w http.ResponseWriter, r *http.Request) {
|
||||
b, err := GetBytes(r, "test_bytes")
|
||||
if err != nil {
|
||||
io.WriteString(w, err.Error())
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(w, "%s", b)
|
||||
})
|
||||
|
||||
testServeMux.HandleFunc("/PopBytes", func(w http.ResponseWriter, r *http.Request) {
|
||||
b, err := PopBytes(r, "test_bytes")
|
||||
if err != nil {
|
||||
io.WriteString(w, err.Error())
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(w, "%s", b)
|
||||
})
|
||||
|
||||
testServeMux.HandleFunc("/PutObject", func(w http.ResponseWriter, r *http.Request) {
|
||||
u := &testUser{"alice", 21}
|
||||
err := PutObject(r, "test_object", u)
|
||||
if err != nil {
|
||||
io.WriteString(w, err.Error())
|
||||
return
|
||||
}
|
||||
io.WriteString(w, "OK")
|
||||
})
|
||||
|
||||
testServeMux.HandleFunc("/GetObject", func(w http.ResponseWriter, r *http.Request) {
|
||||
u := new(testUser)
|
||||
err := GetObject(r, "test_object", u)
|
||||
if err != nil {
|
||||
io.WriteString(w, err.Error())
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(w, "%s: %d", u.Name, u.Age)
|
||||
})
|
||||
|
||||
testServeMux.HandleFunc("/PopObject", func(w http.ResponseWriter, r *http.Request) {
|
||||
u := new(testUser)
|
||||
err := PopObject(r, "test_object", u)
|
||||
if err != nil {
|
||||
io.WriteString(w, err.Error())
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(w, "%s: %d", u.Name, u.Age)
|
||||
})
|
||||
|
||||
testServeMux.HandleFunc("/Keys", func(w http.ResponseWriter, r *http.Request) {
|
||||
keys, err := Keys(r)
|
||||
if err != nil {
|
||||
io.WriteString(w, err.Error())
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(w, "%v", keys)
|
||||
})
|
||||
|
||||
testServeMux.HandleFunc("/Exists", func(w http.ResponseWriter, r *http.Request) {
|
||||
exists, err := Exists(r, "test_string")
|
||||
if err != nil {
|
||||
io.WriteString(w, err.Error())
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(w, "%v", exists)
|
||||
})
|
||||
|
||||
testServeMux.HandleFunc("/RemoveString", func(w http.ResponseWriter, r *http.Request) {
|
||||
err := Remove(r, "test_string")
|
||||
if err != nil {
|
||||
io.WriteString(w, err.Error())
|
||||
return
|
||||
}
|
||||
io.WriteString(w, "OK")
|
||||
})
|
||||
|
||||
testServeMux.HandleFunc("/Clear", func(w http.ResponseWriter, r *http.Request) {
|
||||
err := Clear(r)
|
||||
if err != nil {
|
||||
io.WriteString(w, err.Error())
|
||||
return
|
||||
}
|
||||
io.WriteString(w, "OK")
|
||||
})
|
||||
|
||||
testServeMux.HandleFunc("/Destroy", func(w http.ResponseWriter, r *http.Request) {
|
||||
err := Destroy(w, r)
|
||||
if err != nil {
|
||||
io.WriteString(w, err.Error())
|
||||
return
|
||||
}
|
||||
io.WriteString(w, "OK")
|
||||
})
|
||||
|
||||
testServeMux.HandleFunc("/RegenerateToken", func(w http.ResponseWriter, r *http.Request) {
|
||||
err := RegenerateToken(r)
|
||||
if err != nil {
|
||||
io.WriteString(w, err.Error())
|
||||
return
|
||||
}
|
||||
io.WriteString(w, "OK")
|
||||
})
|
||||
|
||||
testServeMux.HandleFunc("/Renew", func(w http.ResponseWriter, r *http.Request) {
|
||||
err := Renew(r)
|
||||
if err != nil {
|
||||
io.WriteString(w, err.Error())
|
||||
return
|
||||
}
|
||||
io.WriteString(w, "OK")
|
||||
})
|
||||
|
||||
testServeMux.HandleFunc("/Save", func(w http.ResponseWriter, r *http.Request) {
|
||||
err := PutString(r, "test_string", "lorem ipsum")
|
||||
if err != nil {
|
||||
io.WriteString(w, err.Error())
|
||||
return
|
||||
}
|
||||
err = Save(w, r)
|
||||
if err != nil {
|
||||
io.WriteString(w, err.Error())
|
||||
return
|
||||
}
|
||||
io.WriteString(w, "OK")
|
||||
})
|
||||
|
||||
testServeMux.HandleFunc("/Flush", func(w http.ResponseWriter, r *http.Request) {
|
||||
err := PutString(r, "test_string", "lorem ipsum")
|
||||
if err != nil {
|
||||
io.WriteString(w, err.Error())
|
||||
return
|
||||
}
|
||||
err = Save(w, r)
|
||||
if err != nil {
|
||||
io.WriteString(w, err.Error())
|
||||
return
|
||||
}
|
||||
fw, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
http.Error(w, "could not assert to Flusher", 500)
|
||||
return
|
||||
}
|
||||
w.Write([]byte("This is some…"))
|
||||
fw.Flush()
|
||||
w.Write([]byte("flushed data"))
|
||||
})
|
||||
|
||||
testServeMux.HandleFunc("/WriteHeader", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusTeapot)
|
||||
io.WriteString(w, http.StatusText(http.StatusTeapot))
|
||||
})
|
||||
|
||||
testServeMux.HandleFunc("/DumpContext", func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintf(w, "%v", r.Context())
|
||||
})
|
||||
}
|
||||
|
||||
func testRequest(t *testing.T, h http.Handler, path string, cookie string) (int, string, string) {
|
||||
rr := httptest.NewRecorder()
|
||||
r, err := http.NewRequest("GET", path, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if cookie != "" {
|
||||
r.Header.Add("Cookie", cookie)
|
||||
}
|
||||
h.ServeHTTP(rr, r)
|
||||
|
||||
code := rr.Code
|
||||
body := string(rr.Body.Bytes())
|
||||
cookie = rr.Header().Get("Set-Cookie")
|
||||
return code, body, cookie
|
||||
}
|
||||
|
||||
func extractTokenFromCookie(c string) string {
|
||||
parts := strings.Split(c, ";")
|
||||
return strings.TrimPrefix(parts[0], fmt.Sprintf("%s=", CookieName))
|
||||
}
|
||||
|
||||
func TestGenerateToken(t *testing.T) {
|
||||
id, err := generateToken()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
match, err := regexp.MatchString("^[0-9a-zA-Z_\\-]{43}$", id)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if match == false {
|
||||
t.Errorf("got %q: should match %q", id, "^[0-9a-zA-Z_\\-]{43}$")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDestroy(t *testing.T) {
|
||||
e := testEngine
|
||||
m := Manage(e)
|
||||
h := m(testServeMux)
|
||||
|
||||
_, _, cookie := testRequest(t, h, "/PutString", "")
|
||||
oldToken := extractTokenFromCookie(cookie)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
r, err := http.NewRequest("GET", "/Destroy", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
r.Header.Add("Cookie", cookie)
|
||||
h.ServeHTTP(rr, r)
|
||||
body := string(rr.Body.Bytes())
|
||||
cookie = rr.Header().Get("Set-Cookie")
|
||||
if body != "OK" {
|
||||
t.Fatalf("got %q: expected %q", body, "OK")
|
||||
}
|
||||
if len(rr.Header()["Set-Cookie"]) != 1 {
|
||||
t.Fatalf("got %d: expected %d", len(rr.Header()["Set-Cookie"]), 1)
|
||||
}
|
||||
if strings.HasPrefix(cookie, fmt.Sprintf("%s=;", CookieName)) == false {
|
||||
t.Fatalf("got %q: expected prefix %q", cookie, fmt.Sprintf("%s=;", CookieName))
|
||||
}
|
||||
if strings.Contains(cookie, "Expires=Thu, 01 Jan 1970 00:00:01 GMT") == false {
|
||||
t.Fatalf("got %q: expected to contain %q", cookie, "Expires=Thu, 01 Jan 1970 00:00:01 GMT")
|
||||
}
|
||||
if strings.Contains(cookie, "Max-Age=0") == false {
|
||||
t.Fatalf("got %q: expected to contain %q", cookie, "Max-Age=0")
|
||||
}
|
||||
_, found, _ := e.Find(oldToken)
|
||||
if found != false {
|
||||
t.Fatalf("got %v: expected %v", found, false)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegenerateToken(t *testing.T) {
|
||||
e := testEngine
|
||||
m := Manage(e)
|
||||
h := m(testServeMux)
|
||||
|
||||
_, _, cookie := testRequest(t, h, "/PutString", "")
|
||||
oldToken := extractTokenFromCookie(cookie)
|
||||
|
||||
_, body, cookie := testRequest(t, h, "/RegenerateToken", cookie)
|
||||
if body != "OK" {
|
||||
t.Fatalf("got %q: expected %q", body, "OK")
|
||||
}
|
||||
newToken := extractTokenFromCookie(cookie)
|
||||
if newToken == oldToken {
|
||||
t.Fatal("expected a difference")
|
||||
}
|
||||
_, found, _ := e.Find(oldToken)
|
||||
if found != false {
|
||||
t.Fatalf("got %v: expected %v", found, false)
|
||||
}
|
||||
|
||||
_, body, _ = testRequest(t, h, "/GetString", cookie)
|
||||
if body != "lorem ipsum" {
|
||||
t.Fatalf("got %q: expected %q", body, "lorem ipsum")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenew(t *testing.T) {
|
||||
e := testEngine
|
||||
m := Manage(e)
|
||||
h := m(testServeMux)
|
||||
|
||||
_, _, cookie := testRequest(t, h, "/PutString", "")
|
||||
oldToken := extractTokenFromCookie(cookie)
|
||||
|
||||
_, body, cookie := testRequest(t, h, "/Renew", cookie)
|
||||
if body != "OK" {
|
||||
t.Fatalf("got %q: expected %q", body, "OK")
|
||||
}
|
||||
newToken := extractTokenFromCookie(cookie)
|
||||
if newToken == oldToken {
|
||||
t.Fatal("expected a difference")
|
||||
}
|
||||
_, found, _ := e.Find(oldToken)
|
||||
if found != false {
|
||||
t.Fatalf("got %v: expected %v", found, false)
|
||||
}
|
||||
|
||||
_, body, _ = testRequest(t, h, "/GetString", cookie)
|
||||
if body != "" {
|
||||
t.Fatalf("got %q: expected %q", body, "")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSave(t *testing.T) {
|
||||
e := testEngine
|
||||
m := Manage(e)
|
||||
h := m(testServeMux)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
r, err := http.NewRequest("GET", "/Save", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
h.ServeHTTP(rr, r)
|
||||
|
||||
body := string(rr.Body.Bytes())
|
||||
cookie := rr.Header().Get("Set-Cookie")
|
||||
token := extractTokenFromCookie(cookie)
|
||||
|
||||
if body != "OK" {
|
||||
t.Fatalf("got %q: expected %q", body, "OK")
|
||||
}
|
||||
if len(rr.Header()["Set-Cookie"]) != 1 {
|
||||
t.Fatalf("got %d: expected %d", len(rr.Header()["Set-Cookie"]), 1)
|
||||
}
|
||||
if strings.HasPrefix(cookie, fmt.Sprintf("%s=", CookieName)) == false {
|
||||
t.Fatalf("got %q: expected prefix %q", cookie, fmt.Sprintf("%s=", CookieName))
|
||||
}
|
||||
_, found, _ := e.Find(token)
|
||||
if found != true {
|
||||
t.Fatalf("got %v: expected %v", found, true)
|
||||
}
|
||||
}
|
195
session_test.go
Normal file
195
session_test.go
Normal file
@ -0,0 +1,195 @@
|
||||
package scs
|
||||
|
||||
import (
|
||||
"encoding/gob"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type testUser struct {
|
||||
Name string
|
||||
Age int
|
||||
}
|
||||
|
||||
func init() {
|
||||
gob.Register(new(testUser))
|
||||
}
|
||||
|
||||
func TestGenerateToken(t *testing.T) {
|
||||
id, err := generateToken()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
match, err := regexp.MatchString("^[0-9a-zA-Z_\\-]{43}$", id)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if match == false {
|
||||
t.Errorf("got %q: should match %q", id, "^[0-9a-zA-Z_\\-]{43}$")
|
||||
}
|
||||
}
|
||||
|
||||
func TestString(t *testing.T) {
|
||||
manager := NewManager(newMockStore())
|
||||
|
||||
_, body, cookie := testRequest(t, testPutString(manager), "")
|
||||
if body != "OK" {
|
||||
t.Fatalf("got %q: expected %q", body, "OK")
|
||||
}
|
||||
|
||||
_, body, _ = testRequest(t, testGetString(manager), cookie)
|
||||
if body != "lorem ipsum" {
|
||||
t.Fatalf("got %q: expected %q", body, "lorem ipsum")
|
||||
}
|
||||
|
||||
_, body, cookie = testRequest(t, testPopString(manager), cookie)
|
||||
if body != "lorem ipsum" {
|
||||
t.Fatalf("got %q: expected %q", body, "lorem ipsum")
|
||||
}
|
||||
|
||||
_, body, _ = testRequest(t, testGetString(manager), cookie)
|
||||
if body != "" {
|
||||
t.Fatalf("got %q: expected %q", body, "")
|
||||
}
|
||||
}
|
||||
|
||||
func TestObject(t *testing.T) {
|
||||
manager := NewManager(newMockStore())
|
||||
|
||||
_, body, cookie := testRequest(t, testPutObject(manager), "")
|
||||
if body != "OK" {
|
||||
t.Fatalf("got %q: expected %q", body, "OK")
|
||||
}
|
||||
|
||||
_, body, _ = testRequest(t, testGetObject(manager), cookie)
|
||||
if body != "alice: 21" {
|
||||
t.Fatalf("got %q: expected %q", body, "alice: 21")
|
||||
}
|
||||
|
||||
_, body, cookie = testRequest(t, testPopObject(manager), cookie)
|
||||
if body != "alice: 21" {
|
||||
t.Fatalf("got %q: expected %q", body, "alice: 21")
|
||||
}
|
||||
|
||||
_, body, _ = testRequest(t, testGetObject(manager), cookie)
|
||||
if body != ": 0" {
|
||||
t.Fatalf("got %q: expected %q", body, ": 0")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDestroy(t *testing.T) {
|
||||
store := newMockStore()
|
||||
manager := NewManager(store)
|
||||
|
||||
_, _, cookie := testRequest(t, testPutString(manager), "")
|
||||
oldToken := extractTokenFromCookie(cookie)
|
||||
|
||||
_, body, cookie := testRequest(t, testDestroy(manager), cookie)
|
||||
|
||||
if body != "OK" {
|
||||
t.Fatalf("got %q: expected %q", body, "OK")
|
||||
}
|
||||
if strings.HasPrefix(cookie, fmt.Sprintf("%s=;", CookieName)) == false {
|
||||
t.Fatalf("got %q: expected prefix %q", cookie, fmt.Sprintf("%s=;", CookieName))
|
||||
}
|
||||
if strings.Contains(cookie, "Expires=Thu, 01 Jan 1970 00:00:01 GMT") == false {
|
||||
t.Fatalf("got %q: expected to contain %q", cookie, "Expires=Thu, 01 Jan 1970 00:00:01 GMT")
|
||||
}
|
||||
if strings.Contains(cookie, "Max-Age=0") == false {
|
||||
t.Fatalf("got %q: expected to contain %q", cookie, "Max-Age=0")
|
||||
}
|
||||
_, found, _ := store.Find(oldToken)
|
||||
if found != false {
|
||||
t.Fatalf("got %v: expected %v", found, false)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenewToken(t *testing.T) {
|
||||
store := newMockStore()
|
||||
manager := NewManager(store)
|
||||
|
||||
_, _, cookie := testRequest(t, testPutString(manager), "")
|
||||
oldToken := extractTokenFromCookie(cookie)
|
||||
|
||||
_, body, cookie := testRequest(t, testRenewToken(manager), cookie)
|
||||
if body != "OK" {
|
||||
t.Fatalf("got %q: expected %q", body, "OK")
|
||||
}
|
||||
newToken := extractTokenFromCookie(cookie)
|
||||
if newToken == oldToken {
|
||||
t.Fatal("expected a difference")
|
||||
}
|
||||
_, found, _ := store.Find(oldToken)
|
||||
if found != false {
|
||||
t.Fatalf("got %v: expected %v", found, false)
|
||||
}
|
||||
|
||||
_, body, _ = testRequest(t, testGetString(manager), cookie)
|
||||
if body != "lorem ipsum" {
|
||||
t.Fatalf("got %q: expected %q", body, "lorem ipsum")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClear(t *testing.T) {
|
||||
manager := NewManager(newMockStore())
|
||||
|
||||
_, _, cookie := testRequest(t, testPutString(manager), "")
|
||||
_, _, cookie = testRequest(t, testPutBool(manager), cookie)
|
||||
|
||||
_, body, cookie := testRequest(t, testClear(manager), cookie)
|
||||
if body != "OK" {
|
||||
t.Fatalf("got %q: expected %q", body, "OK")
|
||||
}
|
||||
|
||||
_, body, _ = testRequest(t, testGetString(manager), cookie)
|
||||
if body != "" {
|
||||
t.Fatalf("got %q: expected %q", body, "")
|
||||
}
|
||||
|
||||
_, body, _ = testRequest(t, testGetBool(manager), cookie)
|
||||
if body != "false" {
|
||||
t.Fatalf("got %q: expected %q", body, "false")
|
||||
}
|
||||
|
||||
// Check that it's a no-op if there is no data in the session
|
||||
_, _, cookie = testRequest(t, testClear(manager), cookie)
|
||||
if cookie != "" {
|
||||
t.Fatalf("got %q: expected %q", cookie, "")
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeys(t *testing.T) {
|
||||
manager := NewManager(newMockStore())
|
||||
|
||||
_, _, cookie := testRequest(t, testPutString(manager), "")
|
||||
_, _, _ = testRequest(t, testPutBool(manager), cookie)
|
||||
|
||||
_, body, _ := testRequest(t, testKeys(manager), cookie)
|
||||
if body != "[test_bool test_string]" {
|
||||
t.Fatalf("got %q: expected %q", body, "[test_bool test_string]")
|
||||
}
|
||||
|
||||
_, _, _ = testRequest(t, testClear(manager), cookie)
|
||||
_, body, _ = testRequest(t, testKeys(manager), cookie)
|
||||
if body != "[]" {
|
||||
t.Fatalf("got %q: expected %q", body, "[]")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadFailure(t *testing.T) {
|
||||
manager := NewManager(newMockStore())
|
||||
|
||||
cookie := http.Cookie{
|
||||
Name: "session",
|
||||
Value: "force-error",
|
||||
}
|
||||
|
||||
_, body, _ := testRequest(t, testPutString(manager), cookie.String())
|
||||
if body != "forced-error\n" {
|
||||
t.Fatalf("got %q: expected %q", body, "forced-error\n")
|
||||
}
|
||||
}
|
@ -1,27 +1,27 @@
|
||||
package session
|
||||
package scs
|
||||
|
||||
import "time"
|
||||
|
||||
// Engine is the interface for storage engines.
|
||||
type Engine interface {
|
||||
// Store is the interface for session stores.
|
||||
type Store interface {
|
||||
// Delete should remove the session token and corresponding data from the
|
||||
// session engine. If the token does not exist then Delete should be a no-op
|
||||
// session store. If the token does not exist then Delete should be a no-op
|
||||
// and return nil (not an error).
|
||||
Delete(token string) (err error)
|
||||
|
||||
// Find should return the data for a session token from the storage engine.
|
||||
// Find should return the data for a session token from the session store.
|
||||
// If the session token is not found or is expired, the found return value
|
||||
// should be false (and the err return value should be nil). Similarly, tampered
|
||||
// or malformed tokens should result in a found return value of false and a
|
||||
// nil err value. The err return value should be used for system errors only.
|
||||
Find(token string) (b []byte, found bool, err error)
|
||||
|
||||
// Save should add the session token and data to the storage engine, with
|
||||
// Save should add the session token and data to the session store, with
|
||||
// the given expiry time. If the session token already exists, then the data
|
||||
// and expiry time should be overwritten.
|
||||
Save(token string, b []byte, expiry time.Time) (err error)
|
||||
}
|
||||
|
||||
type cookieEngine interface {
|
||||
type cookieStore interface {
|
||||
MakeToken(b []byte, expiry time.Time) (token string, err error)
|
||||
}
|
41
store_test.go
Normal file
41
store_test.go
Normal file
@ -0,0 +1,41 @@
|
||||
package scs
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
)
|
||||
|
||||
type mockStore struct {
|
||||
m map[string]*mockEntry
|
||||
}
|
||||
|
||||
type mockEntry struct {
|
||||
b []byte
|
||||
expiry time.Time
|
||||
}
|
||||
|
||||
func newMockStore() *mockStore {
|
||||
m := make(map[string]*mockEntry)
|
||||
return &mockStore{m}
|
||||
}
|
||||
|
||||
func (s *mockStore) Delete(token string) error {
|
||||
delete(s.m, token)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *mockStore) Find(token string) (b []byte, found bool, err error) {
|
||||
if token == "force-error" {
|
||||
return nil, false, errors.New("forced-error")
|
||||
}
|
||||
entry, exists := s.m[token]
|
||||
if !exists || entry.expiry.UnixNano() < time.Now().UnixNano() {
|
||||
return nil, false, nil
|
||||
}
|
||||
return entry.b, true, nil
|
||||
}
|
||||
|
||||
func (s *mockStore) Save(token string, b []byte, expiry time.Time) error {
|
||||
s.m[token] = &mockEntry{b, expiry}
|
||||
return nil
|
||||
}
|
Reference in New Issue
Block a user