// Copyright (c) 2023-2024 Vasiliy Vasilyuk. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package testingpg import ( "context" "crypto/rand" "database/sql" "encoding/base64" "fmt" "net/url" "os" "strings" "sync" "time" "unicode" _ "github.com/jackc/pgx/v5/stdlib" "github.com/stretchr/testify/require" ) type TestingT interface { require.TestingT Cleanup(f func()) Log(args ...any) Logf(format string, args ...any) Name() string Failed() bool } func NewWithIsolatedDatabase(t TestingT) *Postgres { return newPostgres(t).cloneFromReference() } type Postgres struct { t TestingT url string ref string sqlDB *sql.DB sqlDBOnce sync.Once } func newPostgres(t TestingT) *Postgres { urlStr := os.Getenv("TESTING_DB_URL") if urlStr == "" { urlStr = "postgresql://postgres:postgres@localhost:32260/postgres?sslmode=disable" const format = "env TESTING_DB_URL is empty, used default value: %s" t.Logf(format, urlStr) } refDatabase := os.Getenv("TESTING_DB_REF") if refDatabase == "" { refDatabase = "reference" } return &Postgres{ t: t, url: urlStr, ref: refDatabase, } } func (p *Postgres) URL() string { return p.url } func (p *Postgres) DB() *sql.DB { p.sqlDBOnce.Do(func() { p.sqlDB = open(p.t, p.URL()) }) return p.sqlDB } func (p *Postgres) cloneFromReference() *Postgres { newDBName := newUniqueHumanReadableDatabaseName(p.t) p.t.Log("database name for this test:", newDBName) sql := fmt.Sprintf( `CREATE DATABASE %q WITH TEMPLATE %q;`, newDBName, p.ref, ) _, err := p.DB().ExecContext(context.Background(), sql) require.NoError(p.t, err) // Automatically drop database copy after the test is completed. p.t.Cleanup(func() { sql := fmt.Sprintf(`DROP DATABASE %q WITH (FORCE);`, newDBName) ctx, done := context.WithTimeout(context.Background(), time.Minute) defer done() _, err := p.DB().ExecContext(ctx, sql) require.NoError(p.t, err) }) return &Postgres{ t: p.t, url: replaceDBName(p.t, p.URL(), newDBName), ref: newDBName, } } func newUniqueHumanReadableDatabaseName(t TestingT) string { output := strings.Builder{} // Reports the maximum identifier length. It is determined as one less // than the value of NAMEDATALEN when building the server. The default // value of NAMEDATALEN is 64; therefore the default max_identifier_length // is 63 bytes, which can be less than 63 characters when using multibyte // encodings. // See https://www.postgresql.org/docs/15/runtime-config-preset.html const maxIdentifierLengthBytes = 63 uid := genUnique8BytesID(t) maxHumanReadableLenBytes := maxIdentifierLengthBytes - len(uid) lastSymbolIsHyphen := false for _, r := range t.Name() { if unicode.IsLetter(r) || unicode.IsNumber(r) { output.WriteRune(r) lastSymbolIsHyphen = false } else { if !lastSymbolIsHyphen { output.WriteRune('-') } lastSymbolIsHyphen = true } if output.Len() >= maxHumanReadableLenBytes { break } } output.WriteString(uid) return output.String() } func genUnique8BytesID(t TestingT) string { bs := make([]byte, 6) _, err := rand.Read(bs) require.NoError(t, err) return base64.RawURLEncoding.EncodeToString(bs) } func replaceDBName(t TestingT, dataSourceURL, dbname string) string { r, err := url.Parse(dataSourceURL) require.NoError(t, err) r.Path = dbname return r.String() } func open(t TestingT, dataSourceURL string) *sql.DB { db, err := sql.Open("pgx/v5", dataSourceURL) require.NoError(t, err) // Automatically close connection after the test is completed. t.Cleanup(func() { db.Close() }) return db }