2014-02-10 03:03:22 -07:00
|
|
|
package migrate
|
|
|
|
|
|
|
|
import (
|
|
|
|
"database/sql"
|
|
|
|
"log"
|
|
|
|
)
|
|
|
|
|
|
|
|
// SQL statements used to keep track of which revisions have been applied.
// Each applied revision is stored as one row in the migration table.
//
// NOTE(review): the `?` bind placeholders assume a driver that accepts
// positional `?` parameters (e.g. SQLite or MySQL); drivers that use
// `$1`-style placeholders would need different text — confirm against the
// drivers this package targets.
const (
	// migrationTableStmt creates the bookkeeping table if it is missing.
	migrationTableStmt = `
CREATE TABLE IF NOT EXISTS migration (
  revision BIGINT PRIMARY KEY
)
`

	// migrationSelectStmt looks up a single recorded revision.
	migrationSelectStmt = `
SELECT revision FROM migration
WHERE revision = ?
`

	// migrationSelectMaxStmt returns the highest recorded revision
	// (NULL when the table is empty).
	migrationSelectMaxStmt = `
SELECT max(revision) FROM migration
`

	// insertRevisionStmt records a revision as applied.
	insertRevisionStmt = `
INSERT INTO migration (revision) VALUES (?)
`

	// deleteRevisionStmt removes the record of an applied revision.
	deleteRevisionStmt = `
DELETE FROM migration WHERE revision = ?
`
)
|
2014-03-10 11:30:39 +07:00
|
|
|
|
2014-02-10 03:03:22 -07:00
|
|
|
type Revision interface {
|
2014-03-08 12:19:28 +07:00
|
|
|
Up(mg *MigrationDriver) error
|
|
|
|
Down(mg *MigrationDriver) error
|
2014-02-10 03:03:22 -07:00
|
|
|
Revision() int64
|
|
|
|
}
|
|
|
|
|
|
|
|
type Migration struct {
|
|
|
|
db *sql.DB
|
|
|
|
revs []Revision
|
|
|
|
}
|
|
|
|
|
2014-03-10 07:08:58 +07:00
|
|
|
var Driver DriverBuilder
|
2014-02-15 20:16:54 +07:00
|
|
|
|
2014-02-10 03:03:22 -07:00
|
|
|
func New(db *sql.DB) *Migration {
|
|
|
|
return &Migration{db: db}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Add the Revision to the list of migrations.
|
|
|
|
func (m *Migration) Add(rev ...Revision) *Migration {
|
|
|
|
m.revs = append(m.revs, rev...)
|
|
|
|
return m
|
|
|
|
}
|
|
|
|
|
2014-03-16 10:52:28 +07:00
|
|
|
// Migrate executes the full list of migrations.
|
2014-02-10 03:03:22 -07:00
|
|
|
func (m *Migration) Migrate() error {
|
|
|
|
var target int64
|
|
|
|
if len(m.revs) > 0 {
|
|
|
|
// get the last revision number in
|
|
|
|
// the list. This is what we'll
|
|
|
|
// migrate toward.
|
|
|
|
target = m.revs[len(m.revs)-1].Revision()
|
|
|
|
}
|
|
|
|
return m.MigrateTo(target)
|
|
|
|
}
|
|
|
|
|
2014-03-16 10:52:28 +07:00
|
|
|
// MigrateTo executes all database migration until
|
2014-02-10 03:03:22 -07:00
|
|
|
// you are at the specified revision number.
|
|
|
|
// If the revision number is less than the
|
|
|
|
// current revision, then we will downgrade.
|
|
|
|
func (m *Migration) MigrateTo(target int64) error {
|
|
|
|
|
|
|
|
// make sure the migration table is created.
|
|
|
|
if _, err := m.db.Exec(migrationTableStmt); err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
// get the current revision
|
|
|
|
var current int64
|
|
|
|
m.db.QueryRow(migrationSelectMaxStmt).Scan(¤t)
|
|
|
|
|
|
|
|
// already up to date
|
|
|
|
if current == target {
|
|
|
|
log.Println("Database already up-to-date.")
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// should we downgrade?
|
|
|
|
if target < current {
|
|
|
|
return m.down(target, current)
|
|
|
|
}
|
|
|
|
|
|
|
|
// else upgrade
|
|
|
|
return m.up(target, current)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (m *Migration) up(target, current int64) error {
|
|
|
|
// create the database transaction
|
|
|
|
tx, err := m.db.Begin()
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
2014-03-08 12:19:28 +07:00
|
|
|
mg := Driver(tx)
|
2014-02-15 20:16:54 +07:00
|
|
|
|
2014-02-10 03:03:22 -07:00
|
|
|
// loop through and execute revisions
|
|
|
|
for _, rev := range m.revs {
|
2014-02-16 00:56:03 +07:00
|
|
|
if rev.Revision() > current && rev.Revision() <= target {
|
2014-02-10 03:03:22 -07:00
|
|
|
current = rev.Revision()
|
|
|
|
// execute the revision Upgrade.
|
2014-03-08 12:19:28 +07:00
|
|
|
if err := rev.Up(mg); err != nil {
|
2014-02-10 03:03:22 -07:00
|
|
|
log.Printf("Failed to upgrade to Revision Number %v\n", current)
|
|
|
|
log.Println(err)
|
|
|
|
return tx.Rollback()
|
|
|
|
}
|
|
|
|
// update the revision number in the database
|
|
|
|
if _, err := tx.Exec(insertRevisionStmt, current); err != nil {
|
|
|
|
log.Printf("Failed to register Revision Number %v\n", current)
|
|
|
|
log.Println(err)
|
|
|
|
return tx.Rollback()
|
|
|
|
}
|
|
|
|
|
|
|
|
log.Printf("Successfully upgraded to Revision %v\n", current)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return tx.Commit()
|
|
|
|
}
|
|
|
|
|
|
|
|
func (m *Migration) down(target, current int64) error {
|
|
|
|
// create the database transaction
|
|
|
|
tx, err := m.db.Begin()
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
2014-03-08 12:19:28 +07:00
|
|
|
mg := Driver(tx)
|
2014-02-15 20:16:54 +07:00
|
|
|
|
2014-02-10 03:03:22 -07:00
|
|
|
// reverse the list of revisions
|
|
|
|
revs := []Revision{}
|
|
|
|
for _, rev := range m.revs {
|
|
|
|
revs = append([]Revision{rev}, revs...)
|
|
|
|
}
|
|
|
|
|
|
|
|
// loop through the (reversed) list of
|
|
|
|
// revisions and execute.
|
|
|
|
for _, rev := range revs {
|
|
|
|
if rev.Revision() > target {
|
|
|
|
current = rev.Revision()
|
|
|
|
// execute the revision Upgrade.
|
2014-03-08 12:19:28 +07:00
|
|
|
if err := rev.Down(mg); err != nil {
|
2014-02-15 22:17:22 +07:00
|
|
|
log.Printf("Failed to downgrade from Revision Number %v\n", current)
|
2014-02-10 03:03:22 -07:00
|
|
|
log.Println(err)
|
|
|
|
return tx.Rollback()
|
|
|
|
}
|
|
|
|
// update the revision number in the database
|
|
|
|
if _, err := tx.Exec(deleteRevisionStmt, current); err != nil {
|
|
|
|
log.Printf("Failed to unregistser Revision Number %v\n", current)
|
|
|
|
log.Println(err)
|
|
|
|
return tx.Rollback()
|
|
|
|
}
|
|
|
|
|
2014-02-15 22:17:22 +07:00
|
|
|
log.Printf("Successfully downgraded from Revision %v\n", current)
|
2014-02-10 03:03:22 -07:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return tx.Commit()
|
|
|
|
}
|