package storage

import (
	"bytes"
	"context"
	"database/sql"
	"embed"
	"encoding/json"
	"fmt"
	"log"
	"path"
	"sort"
	"strings"
	"text/template"
	"time"

	"github.com/axllent/mailpit/internal/logger"
	"github.com/axllent/semver"
	"github.com/leporo/sqlf"
)

//go:embed schemas/*
var schemaScripts embed.FS

// dbApplySchemas creates the schemas version table if it does not exist,
// imports any versions recorded in the legacy darwin_migrations table, and then
// applies every embedded schemas/*.sql script whose version has not yet been
// recorded, in ascending semver order. Each script is rendered through
// text/template (with a "tenant" template function) before being executed, and
// its version is inserted into the schemas table once it has run.
//
// It returns the first database, template or file error encountered.
func dbApplySchemas() error {
	if _, err := db.Exec(`CREATE TABLE IF NOT EXISTS ` + tenant("schemas") + ` (Version TEXT PRIMARY KEY NOT NULL)`); err != nil {
		return err
	}

	if err := importLegacyDarwinMigrations(); err != nil {
		return err
	}

	schemaFiles, err := schemaScripts.ReadDir("schemas")
	if err != nil {
		// the schemas directory is embedded at build time, so failing to read it
		// indicates a broken build rather than a recoverable runtime condition
		log.Fatal(err)
	}

	type schema struct {
		Name   string
		Semver string
	}

	scripts := []schema{}
	for _, s := range schemaFiles {
		if !s.Type().IsRegular() || !strings.HasSuffix(s.Name(), ".sql") {
			continue
		}

		// TrimSuffix removes the literal extension; TrimRight would treat ".sql"
		// as a character set and also strip any trailing '.', 's', 'q' or 'l'
		// characters that belong to the version itself
		schemaID := strings.TrimSuffix(s.Name(), ".sql")

		if !semver.IsValid(schemaID) {
			logger.Log().Warnf("[db] invalid schema name: %s", s.Name())
			continue
		}

		scripts = append(scripts, schema{
			Name:   s.Name(),
			Semver: semver.MajorMinor(schemaID) + "." + semver.Patch(schemaID),
		})
	}

	// sort schemas by semver, low to high
	sort.Slice(scripts, func(i, j int) bool {
		return semver.Compare(scripts[j].Semver, scripts[i].Semver) == 1
	})

	funcs := template.FuncMap{
		"tenant": tenant,
	}

	for _, s := range scripts {
		var complete int
		err := db.QueryRow(`SELECT EXISTS(SELECT 1 FROM `+tenant("schemas")+` WHERE Version = ?)`, s.Semver).Scan(&complete)
		if err != nil {
			return err
		}

		if complete == 1 {
			// already completed, ignore
			continue
		}

		// use path.Join for Windows compatibility, see https://github.com/golang/go/issues/44305
		b, err := schemaScripts.ReadFile(path.Join("schemas", s.Name))
		if err != nil {
			return err
		}

		// parse each script into its own template so named {{define}} blocks
		// from one script cannot leak into another
		t, err := template.New(s.Name).Funcs(funcs).Parse(string(b))
		if err != nil {
			return fmt.Errorf("parsing schema %s: %w", s.Name, err)
		}

		// a failed render must abort: executing a partial script and then
		// recording its version would mark the schema as applied when it was not
		buf := new(bytes.Buffer)
		if err := t.Execute(buf, nil); err != nil {
			return fmt.Errorf("rendering schema %s: %w", s.Name, err)
		}

		if _, err := db.Exec(buf.String()); err != nil {
			return fmt.Errorf("applying schema %s: %w", s.Name, err)
		}

		if _, err := db.Exec(`INSERT INTO `+tenant("schemas")+` (Version) VALUES (?)`, s.Semver); err != nil {
			return err
		}

		logger.Log().Debugf("[db] applied schema: %s", s.Name)
	}

	return nil
}

// importLegacyDarwinMigrations copies the versions recorded in the legacy
// darwin_migrations table (when that table exists) into the schemas table,
// normalising each to "major.minor.patch" and skipping versions already
// present. From 1 October 2024 (local time) onwards the legacy table is dropped
// after its versions have been copied. It does nothing if the legacy table does
// not exist.
func importLegacyDarwinMigrations() error {
	var legacyMigrationTable int
	err := db.QueryRow(`SELECT EXISTS(SELECT 1 FROM sqlite_master WHERE type='table' AND name=?)`, tenant("darwin_migrations")).Scan(&legacyMigrationTable)
	if err != nil {
		return err
	}

	if legacyMigrationTable != 1 {
		return nil
	}

	rows, err := db.Query(`SELECT version FROM ` + tenant("darwin_migrations"))
	if err != nil {
		return err
	}

	legacySchemas := []string{}
	for rows.Next() {
		var oldID string
		if err := rows.Scan(&oldID); err != nil {
			// unreadable rows are skipped (best-effort import), but not silently
			logger.Log().Warnf("[db] skipping legacy migration row: %s", err.Error())
			continue
		}
		legacySchemas = append(legacySchemas, semver.MajorMinor(oldID)+"."+semver.Patch(oldID))
	}

	// release the rows before issuing further statements; any error that cut
	// iteration short is reported by rows.Err, which remains valid after Close
	_ = rows.Close()
	if err := rows.Err(); err != nil {
		return err
	}

	legacySchemas = semver.SortMin(legacySchemas)

	for _, v := range legacySchemas {
		var migrated int
		err := db.QueryRow(`SELECT EXISTS(SELECT 1 FROM `+tenant("schemas")+` WHERE Version = ?)`, v).Scan(&migrated)
		if err != nil {
			return err
		}

		if migrated == 0 {
			// copy to tenant("schemas")
			if _, err := db.Exec(`INSERT INTO `+tenant("schemas")+` (Version) VALUES (?)`, v); err != nil {
				return err
			}
		}
	}

	// delete legacy migration database after 01/10/2024
	if time.Now().After(time.Date(2024, 10, 1, 0, 0, 0, 0, time.Local)) {
		if _, err := db.Exec(`DROP TABLE IF EXISTS ` + tenant("darwin_migrations")); err != nil {
			return err
		}
	}

	return nil
}

// dataMigrations runs the data format/structure migrations performed on
// startup. Failures are logged rather than returned, so these migrations are
// best-effort.
func dataMigrations() {
	// ensure DeletedSize has a value if empty
	if SettingGet("DeletedSize") == "" {
		if err := SettingPut("DeletedSize", "0"); err != nil {
			logger.Log().Errorf("[migration] %s", err.Error())
		}
	}

	migrateTagsToManyMany()
}

// migrateTagsToManyMany moves tags stored as a JSON array in the mailbox Tags
// column into the many-to-many tag structure via SetMessageTags. After each
// message converts successfully, its Tags column is set to NULL. Finally, any
// remaining legacy "[]" values are set to NULL. Errors are logged and the
// migration carries on with the next message.
//
// Migration task implemented 12/2023
// TODO: Can be removed end 06/2024 and Tags column & index dropped from mailbox
func migrateTagsToManyMany() {
	toConvert := make(map[string][]string)

	q := sqlf.
		Select("ID, Tags").
		From(tenant("mailbox")).
		Where("Tags != ?", "[]").
		Where("Tags IS NOT NULL")

	if err := q.QueryAndClose(context.TODO(), db, func(row *sql.Rows) {
		var id string
		var jsonTags string
		if err := row.Scan(&id, &jsonTags); err != nil {
			logger.Log().Errorf("[migration] %s", err.Error())
			return
		}

		tags := []string{}
		if err := json.Unmarshal([]byte(jsonTags), &tags); err != nil {
			logger.Log().Errorf("[json] %s", err.Error())
			return
		}

		toConvert[id] = tags
	}); err != nil {
		logger.Log().Errorf("[migration] %s", err.Error())
	}

	if len(toConvert) > 0 {
		logger.Log().Infof("[migration] converting %d message tags", len(toConvert))

		for id, tags := range toConvert {
			if err := SetMessageTags(id, tags); err != nil {
				logger.Log().Errorf("[migration] %s", err.Error())
				// leave the legacy Tags value in place so it can be retried
				continue
			}

			if _, err := sqlf.Update(tenant("mailbox")).
				Set("Tags", nil).
				Where("ID = ?", id).
				ExecAndClose(context.TODO(), db); err != nil {
				logger.Log().Errorf("[migration] %s", err.Error())
			}
		}

		logger.Log().Info("[migration] tags conversion complete")
	}

	// set all legacy `[]` tags to NULL
	if _, err := sqlf.Update(tenant("mailbox")).
		Set("Tags", nil).
		Where("Tags = ?", "[]").
		ExecAndClose(context.TODO(), db); err != nil {
		logger.Log().Errorf("[migration] %s", err.Error())
	}
}