mirror of
https://github.com/axllent/mailpit.git
synced 2025-01-22 03:39:59 +02:00
260 lines
5.8 KiB
Go
260 lines
5.8 KiB
Go
// Package storage handles all database actions
|
|
package storage
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"os"
|
|
"os/signal"
|
|
"path"
|
|
"path/filepath"
|
|
"strings"
|
|
"syscall"
|
|
"time"
|
|
|
|
"github.com/axllent/mailpit/config"
|
|
"github.com/axllent/mailpit/internal/logger"
|
|
"github.com/klauspost/compress/zstd"
|
|
"github.com/leporo/sqlf"
|
|
|
|
// sqlite - https://gitlab.com/cznic/sqlite
|
|
_ "modernc.org/sqlite"
|
|
|
|
// rqlite - https://github.com/rqlite/gorqlite | https://rqlite.io/
|
|
_ "github.com/rqlite/gorqlite/stdlib"
|
|
)
|
|
|
|
var (
|
|
db *sql.DB
|
|
dbFile string
|
|
dbIsTemp bool
|
|
sqlDriver string
|
|
dbLastAction time.Time
|
|
|
|
// zstd compression encoder & decoder
|
|
dbEncoder, _ = zstd.NewWriter(nil)
|
|
dbDecoder, _ = zstd.NewReader(nil)
|
|
)
|
|
|
|
// InitDB will initialise the database
|
|
func InitDB() error {
|
|
p := config.Database
|
|
var dsn string
|
|
|
|
if p == "" {
|
|
// when no path is provided then we create a temporary file
|
|
// which will get deleted on Close(), SIGINT or SIGTERM
|
|
p = fmt.Sprintf("%s-%d.db", path.Join(os.TempDir(), "mailpit"), time.Now().UnixNano())
|
|
dbIsTemp = true
|
|
sqlDriver = "sqlite"
|
|
dsn = p
|
|
logger.Log().Debugf("[db] using temporary database: %s", p)
|
|
} else if strings.HasPrefix(p, "http://") || strings.HasPrefix(p, "https://") {
|
|
sqlDriver = "rqlite"
|
|
dsn = p
|
|
logger.Log().Debugf("[db] opening rqlite database %s", p)
|
|
} else {
|
|
p = filepath.Clean(p)
|
|
sqlDriver = "sqlite"
|
|
dsn = fmt.Sprintf("file:%s?cache=shared", p)
|
|
logger.Log().Debugf("[db] opening database %s", p)
|
|
}
|
|
|
|
config.Database = p
|
|
|
|
if sqlDriver == "sqlite" {
|
|
if !isFile(p) {
|
|
// try create a file to ensure permissions
|
|
f, err := os.Create(p)
|
|
if err != nil {
|
|
return fmt.Errorf("[db] %s", err.Error())
|
|
}
|
|
_ = f.Close()
|
|
}
|
|
}
|
|
|
|
var err error
|
|
|
|
db, err = sql.Open(sqlDriver, dsn)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for i := 1; i < 6; i++ {
|
|
if err := Ping(); err != nil {
|
|
logger.Log().Errorf("[db] %s", err.Error())
|
|
logger.Log().Infof("[db] reconnecting in 5 seconds (attempt %d/5)", i)
|
|
time.Sleep(5 * time.Second)
|
|
} else {
|
|
continue
|
|
}
|
|
}
|
|
|
|
// prevent "database locked" errors
|
|
// @see https://github.com/mattn/go-sqlite3#faq
|
|
db.SetMaxOpenConns(1)
|
|
|
|
if sqlDriver == "sqlite" {
|
|
// SQLite performance tuning (https://phiresky.github.io/blog/2020/sqlite-performance-tuning/)
|
|
_, err = db.Exec("PRAGMA journal_mode = WAL; PRAGMA synchronous = normal;")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// create tables if necessary & apply migrations
|
|
if err := dbApplySchemas(); err != nil {
|
|
return err
|
|
}
|
|
|
|
LoadTagFilters()
|
|
|
|
dbFile = p
|
|
dbLastAction = time.Now()
|
|
|
|
sigs := make(chan os.Signal, 1)
|
|
// catch all signals since not explicitly listing
|
|
// Program that will listen to the SIGINT and SIGTERM
|
|
// SIGINT will listen to CTRL-C.
|
|
// SIGTERM will be caught if kill command executed
|
|
signal.Notify(sigs, os.Interrupt, syscall.SIGTERM)
|
|
// method invoked upon seeing signal
|
|
go func() {
|
|
s := <-sigs
|
|
fmt.Printf("[db] got %s signal, shutting down\n", s)
|
|
Close()
|
|
os.Exit(0)
|
|
}()
|
|
|
|
// auto-prune & delete
|
|
go dbCron()
|
|
|
|
go dataMigrations()
|
|
|
|
return nil
|
|
}
|
|
|
|
// Tenant applies an optional prefix to the table name
|
|
func tenant(table string) string {
|
|
return fmt.Sprintf("%s%s", config.TenantID, table)
|
|
}
|
|
|
|
// Close will close the database, and delete if temporary
|
|
func Close() {
|
|
// on a fatal exit (eg: ports blocked), allow Mailpit to run migration tasks before closing the DB
|
|
time.Sleep(200 * time.Millisecond)
|
|
|
|
if db != nil {
|
|
if err := db.Close(); err != nil {
|
|
logger.Log().Warn("[db] error closing database, ignoring")
|
|
}
|
|
}
|
|
|
|
// allow SQLite to finish closing DB & write WAL logs if local
|
|
time.Sleep(100 * time.Millisecond)
|
|
|
|
if dbIsTemp && isFile(dbFile) {
|
|
logger.Log().Debugf("[db] deleting temporary file %s", dbFile)
|
|
if err := os.Remove(dbFile); err != nil {
|
|
logger.Log().Errorf("[db] %s", err.Error())
|
|
}
|
|
}
|
|
}
|
|
|
|
// Ping the database connection and return an error if unsuccessful
|
|
func Ping() error {
|
|
return db.Ping()
|
|
}
|
|
|
|
// StatsGet returns the total/unread statistics for a mailbox
|
|
func StatsGet() MailboxStats {
|
|
var (
|
|
total = CountTotal()
|
|
unread = CountUnread()
|
|
tags = GetAllTags()
|
|
)
|
|
|
|
dbLastAction = time.Now()
|
|
|
|
return MailboxStats{
|
|
Total: total,
|
|
Unread: unread,
|
|
Tags: tags,
|
|
}
|
|
}
|
|
|
|
// CountTotal returns the number of emails in the database
|
|
func CountTotal() float64 {
|
|
var total float64
|
|
|
|
_ = sqlf.From(tenant("mailbox")).
|
|
Select("COUNT(*)").To(&total).
|
|
QueryRowAndClose(context.TODO(), db)
|
|
|
|
return total
|
|
}
|
|
|
|
// CountUnread returns the number of emails in the database that are unread.
|
|
func CountUnread() float64 {
|
|
var total float64
|
|
|
|
_ = sqlf.From(tenant("mailbox")).
|
|
Select("COUNT(*)").To(&total).
|
|
Where("Read = ?", 0).
|
|
QueryRowAndClose(context.TODO(), db)
|
|
|
|
return total
|
|
}
|
|
|
|
// CountRead returns the number of emails in the database that are read.
|
|
func CountRead() float64 {
|
|
var total float64
|
|
|
|
_ = sqlf.From(tenant("mailbox")).
|
|
Select("COUNT(*)").To(&total).
|
|
Where("Read = ?", 1).
|
|
QueryRowAndClose(context.TODO(), db)
|
|
|
|
return total
|
|
}
|
|
|
|
// DbSize returns the size of the SQLite database.
|
|
func DbSize() float64 {
|
|
var total sql.NullFloat64
|
|
|
|
err := db.QueryRow("SELECT page_count * page_size AS size FROM pragma_page_count(), pragma_page_size()").Scan(&total)
|
|
|
|
if err != nil {
|
|
logger.Log().Errorf("[db] %s", err.Error())
|
|
return total.Float64
|
|
}
|
|
|
|
return total.Float64
|
|
}
|
|
|
|
// IsUnread returns whether a message is unread or not.
|
|
func IsUnread(id string) bool {
|
|
var unread int
|
|
|
|
_ = sqlf.From(tenant("mailbox")).
|
|
Select("COUNT(*)").To(&unread).
|
|
Where("Read = ?", 0).
|
|
Where("ID = ?", id).
|
|
QueryRowAndClose(context.TODO(), db)
|
|
|
|
return unread == 1
|
|
}
|
|
|
|
// MessageIDExists checks whether a Message-ID exists in the DB
|
|
func MessageIDExists(id string) bool {
|
|
var total int
|
|
|
|
_ = sqlf.From(tenant("mailbox")).
|
|
Select("COUNT(*)").To(&total).
|
|
Where("MessageID = ?", id).
|
|
QueryRowAndClose(context.TODO(), db)
|
|
|
|
return total != 0
|
|
}
|