diff --git a/ch/query_insert.go b/ch/query_insert.go index ba258f1..d8e1f79 100644 --- a/ch/query_insert.go +++ b/ch/query_insert.go @@ -43,6 +43,11 @@ func (q *InsertQuery) TableExpr(query string, args ...any) *InsertQuery { return q } +func (q *InsertQuery) ModelTable(table string) *InsertQuery { + q.modelTableName = chschema.UnsafeIdent(table) + return q +} + func (q *InsertQuery) ModelTableExpr(query string, args ...any) *InsertQuery { q.modelTableName = chschema.SafeQuery(query, args) return q diff --git a/ch/query_select.go b/ch/query_select.go index e3cd4e9..d01738e 100644 --- a/ch/query_select.go +++ b/ch/query_select.go @@ -128,6 +128,11 @@ func (q *SelectQuery) TableExpr(query string, args ...any) *SelectQuery { return q } +func (q *SelectQuery) ModelTable(table string) *SelectQuery { + q.modelTableName = chschema.UnsafeIdent(table) + return q +} + func (q *SelectQuery) ModelTableExpr(query string, args ...any) *SelectQuery { q.modelTableName = chschema.SafeQuery(query, args) return q diff --git a/ch/query_table_create.go b/ch/query_table_create.go index 3f2d14c..b57b8a7 100644 --- a/ch/query_table_create.go +++ b/ch/query_table_create.go @@ -12,6 +12,7 @@ type CreateTableQuery struct { baseQuery ifNotExists bool + as chschema.QueryWithArgs onCluster chschema.QueryWithArgs engine chschema.QueryWithArgs ttl chschema.QueryWithArgs @@ -52,11 +53,21 @@ func (q *CreateTableQuery) TableExpr(query string, args ...any) *CreateTableQuer return q } +func (q *CreateTableQuery) ModelTable(table string) *CreateTableQuery { + q.modelTableName = chschema.UnsafeIdent(table) + return q +} + func (q *CreateTableQuery) ModelTableExpr(query string, args ...any) *CreateTableQuery { q.modelTableName = chschema.SafeQuery(query, args) return q } +func (q *CreateTableQuery) As(table string) *CreateTableQuery { + q.as = chschema.UnsafeIdent(table) + return q +} + func (q *CreateTableQuery) ColumnExpr(query string, args ...any) *CreateTableQuery { q.addColumn(chschema.SafeQuery(query, args)) return q @@ -111,10 +122,6 @@ func (q *CreateTableQuery) AppendQuery(fmter chschema.Formatter, b []byte) (_ [] if q.err != nil { return nil, q.err } - if q.table == nil { - return nil, errNilModel - } - b = append(b, "CREATE TABLE "...) if q.ifNotExists { b = append(b, "IF NOT EXISTS "...) @@ -133,36 +140,46 @@ func (q *CreateTableQuery) AppendQuery(fmter chschema.Formatter, b []byte) (_ [] } } - b = append(b, " ("...) - - for i, field := range q.table.Fields { - if i > 0 { - b = append(b, ", "...) - } - - b = append(b, field.CHName...) - b = append(b, " "...) - b = append(b, field.CHType...) - if field.NotNull { - b = append(b, " NOT NULL"...) - } - if field.CHDefault != "" { - b = append(b, " DEFAULT "...) - b = append(b, field.CHDefault...) - } - } - - for i, col := range q.columns { - if i > 0 || len(q.table.Fields) > 0 { - b = append(b, ", "...) - } - b, err = col.AppendQuery(fmter, b) + if !q.as.IsEmpty() { + b = append(b, " AS "...) + b, err = q.as.AppendQuery(fmter, b) if err != nil { return nil, err } } - b = append(b, ")"...) + if q.table != nil { + b = append(b, " ("...) + + for i, field := range q.table.Fields { + if i > 0 { + b = append(b, ", "...) + } + + b = append(b, field.CHName...) + b = append(b, " "...) + b = append(b, field.CHType...) + if field.NotNull { + b = append(b, " NOT NULL"...) + } + if field.CHDefault != "" { + b = append(b, " DEFAULT "...) + b = append(b, field.CHDefault...) + } + } + + for i, col := range q.columns { + if i > 0 || len(q.table.Fields) > 0 { + b = append(b, ", "...) + } + b, err = col.AppendQuery(fmter, b) + if err != nil { + return nil, err + } + } + + b = append(b, ")"...) + } b = append(b, " Engine = "...) @@ -189,17 +206,19 @@ func (q *CreateTableQuery) AppendQuery(fmter chschema.Formatter, b []byte) (_ [] return nil, err } b = append(b, ')') - } else if len(q.table.PKs) > 0 { - b = append(b, " ORDER BY ("...) - for i, pk := range q.table.PKs { - if i > 0 { - b = append(b, ", "...) + } else if q.table != nil { + if len(q.table.PKs) > 0 { + b = append(b, " ORDER BY ("...) + for i, pk := range q.table.PKs { + if i > 0 { + b = append(b, ", "...) + } + b = append(b, pk.CHName...) } - b = append(b, pk.CHName...) + b = append(b, ')') + } else if q.table.CHEngine == "" { + b = append(b, " ORDER BY tuple()"...) } - b = append(b, ')') - } else if q.table.CHEngine == "" { - b = append(b, " ORDER BY tuple()"...) } if !q.ttl.IsZero() { @@ -219,7 +238,7 @@ func (q *CreateTableQuery) AppendQuery(fmter chschema.Formatter, b []byte) (_ [] } func (q *CreateTableQuery) appendPartition(fmter chschema.Formatter, b []byte) ([]byte, error) { - if q.partition.IsZero() && q.table.CHPartition == "" { + if q.partition.IsZero() && (q.table == nil || q.table.CHPartition == "") { return b, nil } diff --git a/ch/query_test.go b/ch/query_test.go index 772ffcd..49c69e8 100644 --- a/ch/query_test.go +++ b/ch/query_test.go @@ -125,6 +125,15 @@ func TestQuery(t *testing.T) { q2 := db.NewSelect().Model(new(Model)) return q1.UnionAll(q2) }, + func(db *ch.DB) chschema.QueryAppender { + return db.NewCreateTable(). + Table("my-table_dist"). + As("my-table"). + Engine("Distributed(?, currentDatabase(), ?, rand())", + ch.Ident("my-cluster"), ch.Ident("my-table")). + OnCluster("my-cluster"). + IfNotExists() + }, } db := chDB() diff --git a/ch/testdata/snapshots/TestQuery-19 b/ch/testdata/snapshots/TestQuery-19 new file mode 100644 index 0000000..ca7a15c --- /dev/null +++ b/ch/testdata/snapshots/TestQuery-19 @@ -0,0 +1 @@ +CREATE TABLE IF NOT EXISTS "my-table_dist" AS "my-table" ON CLUSTER "my-cluster" Engine = Distributed("my-cluster", currentDatabase(), "my-table", rand()) diff --git a/chmigrate/migrator.go b/chmigrate/migrator.go index 659f7c7..442ee05 100644 --- a/chmigrate/migrator.go +++ b/chmigrate/migrator.go @@ -17,7 +17,7 @@ type MigratorOption func(m *Migrator) func WithTableName(table string) MigratorOption { return func(m *Migrator) { - m.table = table + m.migrationsTable = table } } @@ -33,9 +33,15 @@ func WithReplicated(on bool) MigratorOption { } } +func WithDistributed(on bool) MigratorOption { + return func(m *Migrator) { + m.distributed = on + } +} + func WithOnCluster(cluster string) MigratorOption { return func(m *Migrator) { - m.onCluster = cluster + m.cluster = cluster } } @@ -53,10 +59,11 @@ type Migrator struct { ms MigrationSlice - table string + migrationsTable string locksTable string replicated bool - onCluster string + distributed bool + cluster string markAppliedOnSuccess bool } @@ -67,8 +74,8 @@ func NewMigrator(db *ch.DB, migrations *Migrations, opts ...MigratorOption) *Mig ms: migrations.ms, - table: "ch_migrations", - locksTable: "ch_migration_locks", + migrationsTable: "ch_migrations", + locksTable: "ch_migration_locks", } for _, opt := range opts { opt(m) @@ -107,6 +114,12 @@ func (m *Migrator) migrationsWithStatus(ctx context.Context) (MigrationSlice, in } func (m *Migrator) Init(ctx context.Context) error { + if m.distributed { + if m.cluster == "" { + return errors.New("chmigrate: distributed requires a cluster name") + } + } + if _, err := m.db.NewCreateTable(). Model((*Migration)(nil)). Apply(func(q *ch.CreateTableQuery) *ch.CreateTableQuery { @@ -115,12 +128,13 @@ func (m *Migrator) Init(ctx context.Context) error { } return q.Engine("CollapsingMergeTree(sign)") }). - ModelTableExpr(m.table). - OnCluster(m.onCluster). + ModelTable(m.migrationsTable). + OnCluster(m.cluster). IfNotExists(). Exec(ctx); err != nil { return err } + if _, err := m.db.NewCreateTable(). Model((*migrationLock)(nil)). Apply(func(q *ch.CreateTableQuery) *ch.CreateTableQuery { @@ -129,31 +143,47 @@ func (m *Migrator) Init(ctx context.Context) error { } return q.Engine("MergeTree") }). - ModelTableExpr(m.locksTable). - OnCluster(m.onCluster). + ModelTable(m.locksTable). + OnCluster(m.cluster). IfNotExists(). Exec(ctx); err != nil { return err } + + if m.distributed { + if _, err := m.db.NewCreateTable(). + Table(m.distTable(m.migrationsTable)). + As(m.migrationsTable). + Engine("Distributed(?, currentDatabase(), ?, rand())", + ch.Ident(m.cluster), ch.Ident(m.migrationsTable)). + OnCluster(m.cluster). + IfNotExists(). + Exec(ctx); err != nil { + return err + } + } + return nil } func (m *Migrator) Reset(ctx context.Context) error { - if _, err := m.db.NewDropTable(). - Model((*Migration)(nil)). - ModelTableExpr(m.table). - OnCluster(m.onCluster). - IfExists(). - Exec(ctx); err != nil { - return err + tables := []string{ + m.migrationsTable, + m.locksTable, } - if _, err := m.db.NewDropTable(). - Model((*migrationLock)(nil)). - ModelTableExpr(m.locksTable). - OnCluster(m.onCluster). - IfExists(). - Exec(ctx); err != nil { - return err + if m.distributed { + tables = append(tables, + m.distTable(m.migrationsTable), + ) + } + for _, tableName := range tables { + if _, err := m.db.NewDropTable(). + Table(tableName). + OnCluster(m.cluster). + IfExists(). + Exec(ctx); err != nil { + return err + } } return m.Init(ctx) } @@ -363,7 +393,7 @@ func (m *Migrator) MarkApplied(ctx context.Context, migration *Migration) error migration.MigratedAt = time.Now() _, err := m.db.NewInsert(). Model(migration). - ModelTableExpr(m.table). + ModelTable(m.distTable(m.migrationsTable)). Exec(ctx) return err } @@ -373,13 +403,13 @@ func (m *Migrator) MarkUnapplied(ctx context.Context, migration *Migration) erro migration.Sign = -1 _, err := m.db.NewInsert(). Model(migration). - ModelTableExpr(m.table). + ModelTable(m.distTable(m.migrationsTable)). Exec(ctx) return err } func (m *Migrator) TruncateTable(ctx context.Context) error { - _, err := m.db.Exec("TRUNCATE TABLE ?", ch.Ident(m.table)) + _, err := m.db.Exec("TRUNCATE TABLE ?", ch.Ident(m.distTable(m.migrationsTable))) return err } @@ -407,7 +437,7 @@ func (m *Migrator) AppliedMigrations(ctx context.Context) (MigrationSlice, error if err := m.db.NewSelect(). ColumnExpr("*"). Model(&ms). - ModelTableExpr(m.table). + ModelTable(m.distTable(m.migrationsTable)). Final(). Scan(ctx); err != nil { return nil, err @@ -415,10 +445,6 @@ func (m *Migrator) AppliedMigrations(ctx context.Context) (MigrationSlice, error return ms, nil } -func (m *Migrator) formattedTableName(db *ch.DB) string { - return db.Formatter().FormatQuery(m.table) -} - func (m *Migrator) validate() error { if len(m.ms) == 0 { return errors.New("chmigrate: there are no any migrations") @@ -426,6 +452,13 @@ func (m *Migrator) validate() error { return nil } +func (m *Migrator) distTable(table string) string { + if m.distributed { + return table + "_dist" + } + return table +} + //------------------------------------------------------------------------------ type migrationLock struct {