From 4465b2654dd2c0ca439c9e4b7e7aff6a61b36c25 Mon Sep 17 00:00:00 2001 From: Nurahmadie Date: Sat, 15 Feb 2014 22:17:22 +0700 Subject: [PATCH] Fix migration step not checked against current version. Add tests for DropColumns. --- pkg/database/migrate/migrate.go | 7 ++- pkg/database/migrate/sqlite.go | 11 ++--- pkg/database/migrate/sqlite_test.go | 68 +++++++++++++++++++++++++++++ pkg/database/migrate/util.go | 4 +- 4 files changed, 79 insertions(+), 11 deletions(-) diff --git a/pkg/database/migrate/migrate.go b/pkg/database/migrate/migrate.go index 66bc96095..cc96f7ed7 100644 --- a/pkg/database/migrate/migrate.go +++ b/pkg/database/migrate/migrate.go @@ -53,7 +53,6 @@ DELETE FROM migration where revision = ? // Implementation details is specific for each database, // see migrate/sqlite.go for implementation reference. type Operation interface { - CreateTable(tableName string, args []string) (sql.Result, error) RenameTable(tableName, newName string) (sql.Result, error) @@ -147,7 +146,7 @@ func (m *Migration) up(target, current int64) error { // loop through and execute revisions for _, rev := range m.revs { - if rev.Revision() >= target { + if rev.Revision() > current { current = rev.Revision() // execute the revision Upgrade. if err := rev.Up(op); err != nil { @@ -191,7 +190,7 @@ func (m *Migration) down(target, current int64) error { current = rev.Revision() // execute the revision Upgrade. if err := rev.Down(op); err != nil { - log.Printf("Failed to downgrade to Revision Number %v\n", current) + log.Printf("Failed to downgrade from Revision Number %v\n", current) log.Println(err) return tx.Rollback() } @@ -202,7 +201,7 @@ func (m *Migration) down(target, current int64) error { return tx.Rollback() } - log.Printf("Successfully downgraded to Revision %v\n", current) + log.Printf("Successfully downgraded from Revision %v\n", current) } } diff --git a/pkg/database/migrate/sqlite.go b/pkg/database/migrate/sqlite.go index a2074b08e..d79e76deb 100644 --- a/pkg/database/migrate/sqlite.go +++ b/pkg/database/migrate/sqlite.go @@ -5,8 +5,8 @@ import ( "fmt" "strings" - _ "github.com/mattn/go-sqlite3" "github.com/dchest/uniuri" + _ "github.com/mattn/go-sqlite3" ) type SQLiteDriver MigrationDriver @@ -48,7 +48,8 @@ func (s *SQLiteDriver) DropColumns(tableName string, columnsToDrop []string) (sq } columnNames := selectName(columns) - preparedColumns := make([]string, len(columnNames)-len(columnsToDrop)) + + var preparedColumns []string for k, column := range columnNames { listed := false for _, dropped := range columnsToDrop { @@ -98,8 +99,8 @@ func (s *SQLiteDriver) RenameColumns(tableName string, columnChanges map[string] return nil, err } - oldColumns := make([]string, len(columnChanges)) - newColumns := make([]string, len(columnChanges)) + var oldColumns []string + var newColumns []string for k, column := range selectName(columns) { for Old, New := range columnChanges { if column == Old { @@ -126,7 +127,7 @@ func (s *SQLiteDriver) RenameColumns(tableName string, columnChanges map[string] func (s *SQLiteDriver) getDDLFromTable(tableName string) (string, error) { var sql string - query := `SELECT sql FROM sqlite_master WHERE type='table' and name='?';` + query := `SELECT sql FROM sqlite_master WHERE type='table' and name=?;` err := s.Tx.QueryRow(query, tableName).Scan(&sql) if err != nil { return "", err diff --git a/pkg/database/migrate/sqlite_test.go b/pkg/database/migrate/sqlite_test.go index ab7d8492b..b9b4d9e11 100644 --- a/pkg/database/migrate/sqlite_test.go +++ b/pkg/database/migrate/sqlite_test.go @@ -76,6 +76,29 @@ func (r *revision2) Revision() int64 { // ---------- end of revision 2 +// ---------- revision 3 + +type revision3 struct{} + +func (r *revision3) Up(op Operation) error { + if _, err := op.AddColumn("samples", "url VARCHAR(255)"); err != nil { + return err + } + _, err := op.AddColumn("samples", "likes INTEGER") + return err +} + +func (r *revision3) Down(op Operation) error { + _, err := op.DropColumns("samples", []string{"likes", "url"}) + return err +} + +func (r *revision3) Revision() int64 { + return 3 +} + +// ---------- end of revision 3 + var db *sql.DB var testSchema = ` @@ -144,6 +167,51 @@ func TestMigrateRenameTable(t *testing.T) { } } +type TableInfo struct { + CID int64 `meddler:"cid,pk"` + Name string `meddler:"name"` + Type string `meddler:"type"` + Notnull bool `meddler:"notnull"` + DfltValue interface{} `meddler:"dflt_value"` + PK bool `meddler:"pk"` +} + +func TestMigrateAddRemoveColumns(t *testing.T) { + defer tearDown() + if err := setUp(); err != nil { + t.Fatalf("Error preparing database: %q", err) + } + + Driver = SQLite + + mgr := New(db) + if err := mgr.Add(&revision1{}).Add(&revision3{}).Migrate(); err != nil { + t.Errorf("Can not migrate: %q", err) + } + + var columns []*TableInfo + if err := meddler.QueryAll(db, &columns, `PRAGMA table_info(samples);`); err != nil { + t.Errorf("Can not access table info: %q", err) + } + + if len(columns) < 5 { + t.Errorf("Expect length columns: %d\nGot: %d", 5, len(columns)) + } + + if err := mgr.MigrateTo(1); err != nil { + t.Errorf("Can not migrate: %q", err) + } + + var another_columns []*TableInfo + if err := meddler.QueryAll(db, &another_columns, `PRAGMA table_info(samples);`); err != nil { + t.Errorf("Can not access table info: %q", err) + } + + if len(another_columns) != 3 { + t.Errorf("Expect length columns: %d\nGot: %d", 3, len(columns)) + } +} + func setUp() error { var err error db, err = sql.Open("sqlite3", "migration_tests.sqlite") diff --git a/pkg/database/migrate/util.go b/pkg/database/migrate/util.go index 1dfec95d7..a0f6bfb59 100644 --- a/pkg/database/migrate/util.go +++ b/pkg/database/migrate/util.go @@ -15,7 +15,7 @@ func fetchColumns(sql string) ([]string, error) { } func selectName(columns []string) []string { - results := make([]string, len(columns)) + var results []string for _, column := range columns { col := strings.SplitN(strings.Trim(column, " \n\t"), " ", 2) results = append(results, col[0]) @@ -24,7 +24,7 @@ func selectName(columns []string) []string { } func setForUpdate(left []string, right []string) string { - results := make([]string, len(left)) + var results []string for k, str := range left { results = append(results, fmt.Sprintf("%s = %s", str, right[k])) }