1
0
mirror of https://github.com/woodpecker-ci/woodpecker.git synced 2025-01-23 17:53:23 +02:00

114 lines
3.1 KiB
Go
Raw Normal View History

package migrate
import (
"database/sql"
"fmt"
"strings"
)
// mysqlDriver carries the transaction on which the MySQL migration
// statements defined by its methods are executed.
type mysqlDriver struct {
	// Tx is the transaction every method uses to run its SQL.
	Tx *sql.Tx
}
// MySQL builds a MigrationDriver bound to tx. Its Operation is a mysqlDriver
// running on the same transaction, and its column type maps the AUTOINCREMENT
// attribute to the "AUTO_INCREMENT" keyword.
func MySQL(tx *sql.Tx) *MigrationDriver {
	attrs := map[int]string{AUTOINCREMENT: "AUTO_INCREMENT"}
	driver := &MigrationDriver{
		Tx:        tx,
		Operation: &mysqlDriver{Tx: tx},
		T:         &columnType{AttrMap: attrs},
	}
	return driver
}
func (m *mysqlDriver) CreateTable(tableName string, args []string) (sql.Result, error) {
2014-03-27 17:43:07 +07:00
return m.Tx.Exec(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (%s) ROW_FORMAT=DYNAMIC",
tableName, strings.Join(args, ", ")))
}
// RenameTable executes ALTER TABLE <tableName> RENAME TO <newName> on the
// driver's transaction and returns Exec's result.
func (m *mysqlDriver) RenameTable(tableName, newName string) (sql.Result, error) {
	query := "ALTER TABLE " + tableName + " RENAME TO " + newName
	return m.Tx.Exec(query)
}
// DropTable executes DROP TABLE IF EXISTS <tableName> on the driver's
// transaction and returns Exec's result.
func (m *mysqlDriver) DropTable(tableName string) (sql.Result, error) {
	query := "DROP TABLE IF EXISTS " + tableName
	return m.Tx.Exec(query)
}
// AddColumn executes ALTER TABLE <tableName> ADD COLUMN (<columnSpec>) on the
// driver's transaction. columnSpec is inserted verbatim between the
// parentheses.
func (m *mysqlDriver) AddColumn(tableName, columnSpec string) (sql.Result, error) {
	query := "ALTER TABLE " + tableName + " ADD COLUMN (" + columnSpec + ")"
	return m.Tx.Exec(query)
}
// ChangeColumn executes ALTER TABLE <tableName> MODIFY <columnName> <newSpecs>
// on the driver's transaction and returns Exec's result.
func (m *mysqlDriver) ChangeColumn(tableName, columnName, newSpecs string) (sql.Result, error) {
	words := []string{"ALTER TABLE", tableName, "MODIFY", columnName, newSpecs}
	return m.Tx.Exec(strings.Join(words, " "))
}
func (m *mysqlDriver) DropColumns(tableName string, columnsToDrop ...string) (sql.Result, error) {
2014-03-27 17:43:07 +07:00
if len(columnsToDrop) == 0 {
return nil, fmt.Errorf("No columns to drop.")
}
for k, v := range columnsToDrop {
columnsToDrop[k] = fmt.Sprintf("DROP %s", v)
}
return m.Tx.Exec(fmt.Sprintf("ALTER TABLE %s %s", tableName, strings.Join(columnsToDrop, ", ")))
}
// RenameColumns renames columns of tableName. columnChanges maps each current
// column name to its new name.
//
// The table definition is read via getTableDefinition and split into column
// entries by fetchColumns. For every requested rename, the first entry whose
// leading word equals the quoted old name supplies the rest of the column
// definition, which is reused in a "CHANGE <old> <new> <definition>" clause.
// Renames with no matching entry contribute no clause. All clauses are sent
// in one ALTER TABLE statement.
//
// Errors from reading or parsing the table definition are returned as-is with
// a nil result.
func (m *mysqlDriver) RenameColumns(tableName string, columnChanges map[string]string) (sql.Result, error) {
	definition, err := m.getTableDefinition(tableName)
	if err != nil {
		return nil, err
	}
	existing, err := fetchColumns(definition)
	if err != nil {
		return nil, err
	}

	var clauses []string
	for oldName, newName := range columnChanges {
		wanted := quote(oldName)
		for _, entry := range existing {
			// Split "<name> <rest of definition>" after trimming spaces and newlines.
			parts := strings.SplitN(strings.Trim(entry, " \n"), " ", 2)
			if parts[0] != wanted {
				continue
			}
			clauses = append(clauses, fmt.Sprintf("CHANGE %s %s %s", oldName, newName, parts[1]))
			break
		}
	}

	query := fmt.Sprintf("ALTER TABLE %s %s", tableName, strings.Join(clauses, ", "))
	return m.Tx.Exec(query)
}
// AddIndex creates an index on the given columns of tableName. The index name
// comes from indexName(tableName, columns).
//
// Only the first entry of flags is considered. It is placed in front of the
// INDEX keyword, spelled as the caller passed it, when it equals UNIQUE,
// FULLTEXT or SPATIAL ignoring case. Any other value is ignored.
func (m *mysqlDriver) AddIndex(tableName string, columns []string, flags ...string) (sql.Result, error) {
	var kind string
	if len(flags) > 0 {
		switch strings.ToUpper(flags[0]) {
		case "UNIQUE", "FULLTEXT", "SPATIAL":
			kind = flags[0]
		}
	}

	name := indexName(tableName, columns)
	columnList := strings.Join(columns, ", ")
	query := fmt.Sprintf("CREATE %s INDEX %s ON %s (%s)", kind, name, tableName, columnList)
	return m.Tx.Exec(query)
}
// DropIndex drops the index that indexName(tableName, columns) names from
// tableName, using the driver's transaction.
func (m *mysqlDriver) DropIndex(tableName string, columns []string) (sql.Result, error) {
	name := indexName(tableName, columns)
	query := fmt.Sprintf("DROP INDEX %s on %s", name, tableName)
	return m.Tx.Exec(query)
}
// getTableDefinition runs SHOW CREATE TABLE for tableName on the driver's
// transaction and returns the second column of the resulting row. The first
// column is scanned and discarded. On a scan error it returns an empty string
// and that error.
func (m *mysqlDriver) getTableDefinition(tableName string) (string, error) {
	var (
		table      string
		definition string
	)
	query := fmt.Sprintf("SHOW CREATE TABLE %s", tableName)
	row := m.Tx.QueryRow(query)
	err := row.Scan(&table, &definition)
	if err != nil {
		return "", err
	}
	return definition, nil
}
// quote wraps name in backticks so it can be used, or compared, as a quoted
// MySQL identifier. A backtick inside name is doubled, which is how MySQL
// represents the quote character within a quoted identifier; without this a
// name containing a backtick would produce a malformed identifier.
func quote(name string) string {
	return "`" + strings.Replace(name, "`", "``", -1) + "`"
}