1
0
mirror of https://github.com/uptrace/go-clickhouse.git synced 2025-06-27 00:21:13 +02:00

feat: add WithAutoCreateDatabase option

This commit is contained in:
Vladimir Mihailenco
2022-07-28 11:20:11 +03:00
parent d2293eb95a
commit 74e949e01d
4 changed files with 76 additions and 11 deletions

View File

@ -13,8 +13,7 @@ import (
)
const (
discardUnknownColumnsFlag = internal.Flag(1) << iota
columnarFlag
columnarFlag = internal.Flag(1) << iota
afterScanBlockHookFlag
)

View File

@ -17,6 +17,7 @@ import (
const (
discardUnknownColumnsFlag internal.Flag = 1 << iota
autoCreateDatabaseFlag
)
type Config struct {
@ -24,7 +25,6 @@ type Config struct {
Compression bool
Network string
Addr string
User string
Password string
@ -42,6 +42,11 @@ type Config struct {
MaxRetryBackoff time.Duration
}
func (cfg *Config) clone() *Config {
clone := *cfg
return &clone
}
func (cfg *Config) netDialer() *net.Dialer {
return &net.Dialer{
Timeout: cfg.DialTimeout,
@ -62,7 +67,6 @@ func defaultConfig() *Config {
Compression: true,
Network: "tcp",
Addr: "localhost:9000",
User: "default",
Database: "default",
@ -93,7 +97,13 @@ func WithCompression(enabled bool) Option {
}
}
// WithAddr configures TCP host:port or Unix socket depending on Network.
func WithAutoCreateDatabase(enabled bool) Option {
return func(db *DB) {
db.flags.Set(autoCreateDatabaseFlag)
}
}
// WithAddr configures TCP host:port.
func WithAddr(addr string) Option {
return func(db *DB) {
db.cfg.Addr = addr

View File

@ -35,15 +35,21 @@ type DB struct {
}
func Connect(opts ...Option) *DB {
db := &DB{
cfg: defaultConfig(),
db := newDB(defaultConfig(), opts...)
if db.flags.Has(autoCreateDatabaseFlag) {
db.autoCreateDatabase()
}
return db
}
func newDB(cfg *Config, opts ...Option) *DB {
db := &DB{
cfg: cfg,
}
for _, opt := range opts {
opt(db)
}
db.pool = newConnPool(db.cfg)
return db
}
@ -53,12 +59,12 @@ func newConnPool(cfg *Config) *chpool.ConnPool {
if cfg.TLSConfig != nil {
return tls.DialWithDialer(
cfg.netDialer(),
cfg.Network,
"tcp",
cfg.Addr,
cfg.TLSConfig,
)
}
return cfg.netDialer().DialContext(ctx, cfg.Network, cfg.Addr)
return cfg.netDialer().DialContext(ctx, "tcp", cfg.Addr)
}
return chpool.New(&poolcfg)
}
@ -106,6 +112,32 @@ func (db *DB) Stats() DBStats {
}
}
func (db *DB) autoCreateDatabase() {
ctx := context.Background()
switch err := db.Ping(ctx); err := err.(type) {
case nil: // all is good
return
case *Error:
if err.Code != 81 { // 81 - database does not exist
return
}
default:
// ignore the error
return
}
cfg := db.cfg.clone()
cfg.Database = ""
tmp := newDB(cfg)
defer tmp.Close()
if _, err := tmp.Exec("CREATE DATABASE IF NOT EXISTS ?", Ident(db.cfg.Database)); err != nil {
internal.Logger.Printf("create database %q failed: %s", db.cfg.Database, err)
}
}
func (db *DB) getConn(ctx context.Context) (*chpool.Conn, error) {
cn, err := db.pool.Get(ctx)
if err != nil {

View File

@ -23,7 +23,7 @@ func chDB(opts ...ch.Option) *ch.DB {
dsn = "clickhouse://localhost:9000/test?sslmode=disable"
}
opts = append(opts, ch.WithDSN(dsn))
opts = append(opts, ch.WithDSN(dsn), ch.WithAutoCreateDatabase(true))
db := ch.Connect(opts...)
db.AddQueryHook(chdebug.NewQueryHook(
chdebug.WithEnabled(false),
@ -32,6 +32,30 @@ func chDB(opts ...ch.Option) *ch.DB {
return db
}
func TestAutoCreateDatabase(t *testing.T) {
ctx := context.Background()
dbName := "auto_create_database"
{
db := ch.Connect()
defer db.Close()
_, err := db.Exec("DROP DATABASE IF EXISTS ?", ch.Ident(dbName))
require.NoError(t, err)
}
{
db := ch.Connect(
ch.WithDatabase(dbName),
ch.WithAutoCreateDatabase(true),
)
defer db.Close()
err := db.Ping(ctx)
require.NoError(t, err)
}
}
func TestCHError(t *testing.T) {
ctx := context.Background()