mirror of
https://github.com/alexedwards/scs.git
synced 2025-07-15 01:04:36 +02:00
256 lines
6.2 KiB
Go
256 lines
6.2 KiB
Go
package bunstore
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"database/sql"
|
|
"os"
|
|
"reflect"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/uptrace/bun"
|
|
"github.com/uptrace/bun/dialect/mysqldialect"
|
|
"github.com/uptrace/bun/dialect/pgdialect"
|
|
"github.com/uptrace/bun/dialect/sqlitedialect"
|
|
"github.com/uptrace/bun/driver/sqliteshim"
|
|
|
|
_ "github.com/go-sql-driver/mysql"
|
|
_ "github.com/uptrace/bun/driver/pgdriver"
|
|
)
|
|
|
|
func initWithCleanupInterval(t *testing.T, cleanupInterval time.Duration) *BunStore {
|
|
var db *bun.DB
|
|
var err error
|
|
|
|
dialect := os.Getenv("SCS_BUN_TEST_DIALECT")
|
|
switch dialect {
|
|
default:
|
|
dialect = "sqlite3"
|
|
fallthrough
|
|
case "postgres":
|
|
dsn := os.Getenv("SCS_POSTGRES_TEST_DSN")
|
|
sqldb, err := sql.Open("pg", dsn)
|
|
if err != nil {
|
|
sqldb.Close()
|
|
t.Fatal(err)
|
|
}
|
|
|
|
sqldb.Exec(`DROP TABLE IF EXISTS sessions`)
|
|
sqldb.Exec(`CREATE TABLE sessions (token TEXT PRIMARY KEY,data BYTEA NOT NULL,expiry TIMESTAMPTZ NOT NULL);`)
|
|
sqldb.Exec(`CREATE INDEX sessions_expiry_idx ON sessions (expiry);`)
|
|
|
|
db = bun.NewDB(sqldb, pgdialect.New())
|
|
case "mysql":
|
|
dsn := os.Getenv("SCS_MYSQL_TEST_DSN")
|
|
sqldb, err := sql.Open("mysql", dsn)
|
|
if err != nil {
|
|
sqldb.Close()
|
|
t.Fatal(err)
|
|
}
|
|
|
|
sqldb.Exec(`DROP TABLE IF EXISTS sessions`)
|
|
sqldb.Exec(`CREATE TABLE sessions (token CHAR(43) PRIMARY KEY,data BLOB NOT NULL,expiry TIMESTAMP(6) NOT NULL);`)
|
|
sqldb.Exec(`CREATE INDEX sessions_expiry_idx ON sessions (expiry);`)
|
|
|
|
db = bun.NewDB(sqldb, mysqldialect.New())
|
|
case "sqlite3":
|
|
dsn := os.Getenv("./testSQL3lite.db")
|
|
sqldb, err := sql.Open(sqliteshim.ShimName, dsn)
|
|
if err != nil {
|
|
sqldb.Close()
|
|
t.Fatal(err)
|
|
}
|
|
|
|
sqldb.Exec(`DROP TABLE IF EXISTS sessions`)
|
|
sqldb.Exec(`CREATE TABLE sessions (token TEXT PRIMARY KEY,data BLOB NOT NULL,expiry REAL NOT NULL);`)
|
|
sqldb.Exec(`CREATE INDEX sessions_expiry_idx ON sessions(expiry);`)
|
|
|
|
db = bun.NewDB(sqldb, sqlitedialect.New())
|
|
}
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if db.Ping(); err != nil {
|
|
db.Close()
|
|
t.Fatal(err)
|
|
}
|
|
|
|
db.SetMaxOpenConns(1)
|
|
db.SetMaxIdleConns(1000)
|
|
db.SetConnMaxLifetime(0)
|
|
|
|
b, err := NewWithCleanupInterval(db, cleanupInterval)
|
|
if err != nil {
|
|
db.Close()
|
|
t.Fatal(err)
|
|
}
|
|
|
|
return b
|
|
}
|
|
|
|
func TestFind(t *testing.T) {
|
|
b := initWithCleanupInterval(t, 0)
|
|
ctx := context.Background()
|
|
|
|
values := &map[string]interface{}{"token": "session_token", "data": []byte("encoded_data"), "expiry": time.Now().Add(1 * time.Minute)}
|
|
if _, err := b.db.NewInsert().Model(values).TableExpr("sessions").Exec(ctx); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
bb, found, err := b.FindCtx(ctx, "session_token")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if found != true {
|
|
t.Fatalf("got %v: expected %v", found, true)
|
|
}
|
|
if bytes.Equal(bb, []byte("encoded_data")) == false {
|
|
t.Fatalf("got %v: expected %v", b, []byte("encoded_data"))
|
|
}
|
|
}
|
|
|
|
func TestFindMissing(t *testing.T) {
|
|
b := initWithCleanupInterval(t, 0)
|
|
ctx := context.Background()
|
|
|
|
_, found, err := b.FindCtx(ctx, "missing_session_token")
|
|
if err != nil {
|
|
t.Fatalf("got %v: expected %v", err, nil)
|
|
}
|
|
if found != false {
|
|
t.Fatalf("got %v: expected %v", found, false)
|
|
}
|
|
}
|
|
|
|
func TestSaveNew(t *testing.T) {
|
|
b := initWithCleanupInterval(t, 0)
|
|
ctx := context.Background()
|
|
|
|
err := b.CommitCtx(ctx, "session_token", []byte("encoded_data"), time.Now().Add(time.Minute))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
row := b.db.QueryRow("SELECT data FROM sessions WHERE token = 'session_token'")
|
|
var data []byte
|
|
err = row.Scan(&data)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if reflect.DeepEqual(data, []byte("encoded_data")) == false {
|
|
t.Fatalf("got %v: expected %v", data, []byte("encoded_data"))
|
|
}
|
|
}
|
|
|
|
func TestSaveUpdated(t *testing.T) {
|
|
b := initWithCleanupInterval(t, 0)
|
|
ctx := context.Background()
|
|
|
|
values := &map[string]interface{}{"token": "session_token", "data": []byte("encoded_data"), "expiry": time.Now().Add(1 * time.Minute)}
|
|
if _, err := b.db.NewInsert().Model(values).TableExpr("sessions").Exec(ctx); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
err := b.CommitCtx(ctx, "session_token", []byte("new_encoded_data"), time.Now().Add(time.Minute))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
row := b.db.QueryRow("SELECT data FROM sessions WHERE token = 'session_token'")
|
|
var data []byte
|
|
err = row.Scan(&data)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if reflect.DeepEqual(data, []byte("new_encoded_data")) == false {
|
|
t.Fatalf("got %v: expected %v", data, []byte("new_encoded_data"))
|
|
}
|
|
}
|
|
|
|
func TestExpiry(t *testing.T) {
|
|
b := initWithCleanupInterval(t, 0)
|
|
ctx := context.Background()
|
|
|
|
err := b.CommitCtx(ctx, "session_token", []byte("encoded_data"), time.Now().Add(1*time.Second))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
_, found, _ := b.FindCtx(ctx, "session_token")
|
|
if found != true {
|
|
t.Fatalf("got %v: expected %v", found, true)
|
|
}
|
|
|
|
time.Sleep(2 * time.Second)
|
|
_, found, _ = b.FindCtx(ctx, "session_token")
|
|
if found != false {
|
|
t.Fatalf("got %v: expected %v", found, false)
|
|
}
|
|
}
|
|
|
|
func TestDelete(t *testing.T) {
|
|
b := initWithCleanupInterval(t, 0)
|
|
ctx := context.Background()
|
|
|
|
err := b.CommitCtx(ctx, "session_token", []byte("encoded_data"), time.Now().Add(1*time.Minute))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
err = b.DeleteCtx(ctx, "session_token")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
row := b.db.QueryRow("SELECT COUNT(*) FROM sessions WHERE token = 'session_token'")
|
|
var count int
|
|
err = row.Scan(&count)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if count != 0 {
|
|
t.Fatalf("got %d: expected %d", count, 0)
|
|
}
|
|
}
|
|
|
|
func TestCleanup(t *testing.T) {
|
|
b := initWithCleanupInterval(t, 2*time.Second)
|
|
defer b.StopCleanup()
|
|
ctx := context.Background()
|
|
|
|
err := b.CommitCtx(ctx, "session_token", []byte("encoded_data"), time.Now().Add(1*time.Second))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
row := b.db.QueryRow("SELECT COUNT(*) FROM sessions WHERE token = 'session_token'")
|
|
var count int
|
|
err = row.Scan(&count)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if count != 1 {
|
|
t.Fatalf("got %d: expected %d", count, 1)
|
|
}
|
|
|
|
time.Sleep(3 * time.Second)
|
|
row = b.db.QueryRow("SELECT COUNT(*) FROM sessions WHERE token = 'session_token'")
|
|
err = row.Scan(&count)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if count != 0 {
|
|
t.Fatalf("got %d: expected %d", count, 0)
|
|
}
|
|
}
|
|
|
|
func TestStopNilCleanup(t *testing.T) {
|
|
b := initWithCleanupInterval(t, 0)
|
|
|
|
time.Sleep(100 * time.Millisecond)
|
|
// A send to a nil channel will block forever
|
|
b.StopCleanup()
|
|
}
|