// Package storage handles all database actions package storage import ( "context" "database/sql" "fmt" "os" "os/signal" "path" "path/filepath" "syscall" "time" "github.com/axllent/mailpit/config" "github.com/axllent/mailpit/internal/logger" "github.com/klauspost/compress/zstd" "github.com/leporo/sqlf" // sqlite (native) - https://gitlab.com/cznic/sqlite _ "modernc.org/sqlite" ) var ( db *sql.DB dbFile string dbIsTemp bool dbLastAction time.Time // zstd compression encoder & decoder dbEncoder, _ = zstd.NewWriter(nil) dbDecoder, _ = zstd.NewReader(nil) ) // InitDB will initialise the database func InitDB() error { p := config.DataFile if p == "" { // when no path is provided then we create a temporary file // which will get deleted on Close(), SIGINT or SIGTERM p = fmt.Sprintf("%s-%d.db", path.Join(os.TempDir(), "mailpit"), time.Now().UnixNano()) dbIsTemp = true logger.Log().Debugf("[db] using temporary database: %s", p) } else { p = filepath.Clean(p) } config.DataFile = p logger.Log().Debugf("[db] opening database %s", p) var err error dsn := fmt.Sprintf("file:%s?cache=shared", p) db, err = sql.Open("sqlite", dsn) if err != nil { return err } // prevent "database locked" errors // @see https://github.com/mattn/go-sqlite3#faq db.SetMaxOpenConns(1) // SQLite performance tuning (https://phiresky.github.io/blog/2020/sqlite-performance-tuning/) _, err = db.Exec("PRAGMA journal_mode = WAL; PRAGMA synchronous = normal;") if err != nil { return err } // create tables if necessary & apply migrations if err := dbApplyMigrations(); err != nil { return err } dbFile = p dbLastAction = time.Now() sigs := make(chan os.Signal, 1) // catch all signals since not explicitly listing // Program that will listen to the SIGINT and SIGTERM // SIGINT will listen to CTRL-C. // SIGTERM will be caught if kill command executed signal.Notify(sigs, os.Interrupt, syscall.SIGTERM) // method invoked upon seeing signal go func() { s := <-sigs fmt.Printf("[db] got %s signal, shutting down\n", s) Close() os.Exit(0) }() // auto-prune & delete go dbCron() go dataMigrations() return nil } // Close will close the database, and delete if a temporary table func Close() { if db != nil { if err := db.Close(); err != nil { logger.Log().Warn("[db] error closing database, ignoring") } } if dbIsTemp && isFile(dbFile) { logger.Log().Debugf("[db] deleting temporary file %s", dbFile) if err := os.Remove(dbFile); err != nil { logger.Log().Errorf("[db] %s", err.Error()) } } } // Ping the database connection and return an error if unsuccessful func Ping() error { return db.Ping() } // StatsGet returns the total/unread statistics for a mailbox func StatsGet() MailboxStats { var ( total = CountTotal() unread = CountUnread() tags = GetAllTags() ) dbLastAction = time.Now() return MailboxStats{ Total: total, Unread: unread, Tags: tags, } } // CountTotal returns the number of emails in the database func CountTotal() int { var total int _ = sqlf.From("mailbox"). Select("COUNT(*)").To(&total). QueryRowAndClose(context.TODO(), db) return total } // CountUnread returns the number of emails in the database that are unread. func CountUnread() int { var total int _ = sqlf.From("mailbox"). Select("COUNT(*)").To(&total). Where("Read = ?", 0). QueryRowAndClose(context.TODO(), db) return total } // CountRead returns the number of emails in the database that are read. func CountRead() int { var total int _ = sqlf.From("mailbox"). Select("COUNT(*)").To(&total). Where("Read = ?", 1). QueryRowAndClose(context.TODO(), db) return total } // IsUnread returns whether a message is unread or not. func IsUnread(id string) bool { var unread int _ = sqlf.From("mailbox"). Select("COUNT(*)").To(&unread). Where("Read = ?", 0). Where("ID = ?", id). QueryRowAndClose(context.TODO(), db) return unread == 1 } // MessageIDExists checks whether a Message-ID exists in the DB func MessageIDExists(id string) bool { var total int _ = sqlf.From("mailbox"). Select("COUNT(*)").To(&total). Where("MessageID = ?", id). QueryRowAndClose(context.TODO(), db) return total != 0 }