diff --git a/ch/chschema/table.go b/ch/chschema/table.go index 36eb91c..75ffdd7 100644 --- a/ch/chschema/table.go +++ b/ch/chschema/table.go @@ -13,8 +13,7 @@ import ( ) const ( - discardUnknownColumnsFlag = internal.Flag(1) << iota - columnarFlag + columnarFlag = internal.Flag(1) << iota afterScanBlockHookFlag ) diff --git a/ch/config.go b/ch/config.go index 6aed087..2e4e1ec 100644 --- a/ch/config.go +++ b/ch/config.go @@ -17,6 +17,7 @@ import ( const ( discardUnknownColumnsFlag internal.Flag = 1 << iota + autoCreateDatabaseFlag ) type Config struct { @@ -24,7 +25,6 @@ type Config struct { Compression bool - Network string Addr string User string Password string @@ -42,6 +42,11 @@ type Config struct { MaxRetryBackoff time.Duration } +func (cfg *Config) clone() *Config { + clone := *cfg + return &clone +} + func (cfg *Config) netDialer() *net.Dialer { return &net.Dialer{ Timeout: cfg.DialTimeout, @@ -62,7 +67,6 @@ func defaultConfig() *Config { Compression: true, - Network: "tcp", Addr: "localhost:9000", User: "default", Database: "default", @@ -93,7 +97,13 @@ func WithCompression(enabled bool) Option { } } -// WithAddr configures TCP host:port or Unix socket depending on Network. +func WithAutoCreateDatabase(enabled bool) Option { + return func(db *DB) { + db.flags.Set(autoCreateDatabaseFlag) + } +} + +// WithAddr configures TCP host:port. func WithAddr(addr string) Option { return func(db *DB) { db.cfg.Addr = addr diff --git a/ch/db.go b/ch/db.go index 3159f14..d660c8a 100644 --- a/ch/db.go +++ b/ch/db.go @@ -35,15 +35,21 @@ type DB struct { } func Connect(opts ...Option) *DB { - db := &DB{ - cfg: defaultConfig(), + db := newDB(defaultConfig(), opts...) + if db.flags.Has(autoCreateDatabaseFlag) { + db.autoCreateDatabase() } + return db +} +func newDB(cfg *Config, opts ...Option) *DB { + db := &DB{ + cfg: cfg, + } for _, opt := range opts { opt(db) } db.pool = newConnPool(db.cfg) - return db } @@ -53,12 +59,12 @@ func newConnPool(cfg *Config) *chpool.ConnPool { if cfg.TLSConfig != nil { return tls.DialWithDialer( cfg.netDialer(), - cfg.Network, + "tcp", cfg.Addr, cfg.TLSConfig, ) } - return cfg.netDialer().DialContext(ctx, cfg.Network, cfg.Addr) + return cfg.netDialer().DialContext(ctx, "tcp", cfg.Addr) } return chpool.New(&poolcfg) } @@ -106,6 +112,32 @@ func (db *DB) Stats() DBStats { } } +func (db *DB) autoCreateDatabase() { + ctx := context.Background() + + switch err := db.Ping(ctx); err := err.(type) { + case nil: // all is good + return + case *Error: + if err.Code != 81 { // 81 - database does not exist + return + } + default: + // ignore the error + return + } + + cfg := db.cfg.clone() + cfg.Database = "" + + tmp := newDB(cfg) + defer tmp.Close() + + if _, err := tmp.Exec("CREATE DATABASE IF NOT EXISTS ?", Ident(db.cfg.Database)); err != nil { + internal.Logger.Printf("create database %q failed: %s", db.cfg.Database, err) + } +} + func (db *DB) getConn(ctx context.Context) (*chpool.Conn, error) { cn, err := db.pool.Get(ctx) if err != nil { diff --git a/ch/db_test.go b/ch/db_test.go index 6ca2562..6aaed58 100644 --- a/ch/db_test.go +++ b/ch/db_test.go @@ -23,7 +23,7 @@ func chDB(opts ...ch.Option) *ch.DB { dsn = "clickhouse://localhost:9000/test?sslmode=disable" } - opts = append(opts, ch.WithDSN(dsn)) + opts = append(opts, ch.WithDSN(dsn), ch.WithAutoCreateDatabase(true)) db := ch.Connect(opts...) db.AddQueryHook(chdebug.NewQueryHook( chdebug.WithEnabled(false), @@ -32,6 +32,30 @@ func chDB(opts ...ch.Option) *ch.DB { return db } +func TestAutoCreateDatabase(t *testing.T) { + ctx := context.Background() + dbName := "auto_create_database" + + { + db := ch.Connect() + defer db.Close() + + _, err := db.Exec("DROP DATABASE IF EXISTS ?", ch.Ident(dbName)) + require.NoError(t, err) + } + + { + db := ch.Connect( + ch.WithDatabase(dbName), + ch.WithAutoCreateDatabase(true), + ) + defer db.Close() + + err := db.Ping(ctx) + require.NoError(t, err) + } +} + func TestCHError(t *testing.T) { ctx := context.Background()