diff --git a/testingpg/testingpg.go b/testingpg/testingpg.go index 1403b83..4b3a829 100644 --- a/testingpg/testingpg.go +++ b/testingpg/testingpg.go @@ -28,6 +28,8 @@ func New(t TestingT) *Postgres { } type Postgres struct { + t TestingT + url string ref string @@ -51,6 +53,8 @@ func newPostgres(t TestingT) *Postgres { require.NoError(t, err) return &Postgres{ + t: t, + url: urlStr, ref: refDatabase, @@ -63,26 +67,28 @@ func (p *Postgres) URL() string { } func (p *Postgres) PgxPool() *pgxpool.Pool { - return p.conn + pool, err := pgxpool.New(context.Background(), p.URL()) + require.NoError(p.t, err) + + // Automatically close connection after the test is completed. + p.t.Cleanup(func() { + pool.Close() + }) + + return pool } func (p *Postgres) cloneFromReference(t TestingT) *Postgres { - cfg, err := pgxpool.ParseConfig(p.url) - require.NoError(t, err) - - pool, err := pgxpool.New(context.Background(), p.url) - require.NoError(t, err) - newDatabaseName := uuid.New().String() - const sqlTemplate = `CREATE DATABASE %q WITH TEMPLATE %s OWNER %s;` + const sqlTemplate = `CREATE DATABASE %q WITH TEMPLATE %s;` sql := fmt.Sprintf( sqlTemplate, newDatabaseName, p.ref, - cfg.ConnConfig.User, ) - _, err = pool.Exec(context.Background(), sql) + + _, err := p.PgxPool().Exec(context.Background(), sql) require.NoError(t, err) // Automatically drop database copy after the test is completed. @@ -96,20 +102,17 @@ func (p *Postgres) cloneFromReference(t TestingT) *Postgres { require.NoError(t, err) }) - urlString := replaceDBName(t, cfg, newDatabaseName) - newPool, err := pgxpool.New(context.Background(), urlString) - require.NoError(t, err) + urlString := replaceDBName(t, p.URL(), newDatabaseName) return &Postgres{ + t: p.t, url: urlString, ref: newDatabaseName, - - conn: newPool, } } -func replaceDBName(t TestingT, cfg *pgxpool.Config, dbname string) string { - r, err := url.Parse(cfg.ConnString()) +func replaceDBName(t TestingT, dataSourceURL, dbname string) string { + r, err := url.Parse(dataSourceURL) require.NoError(t, err) r.Path = dbname return r.String()