// Package storage handles all database actions
package storage

import (
	"context"
	"database/sql"
	"fmt"
	"os"
	"os/signal"
	"path"
	"path/filepath"
	"strings"
	"syscall"
	"time"

	"github.com/axllent/mailpit/config"
	"github.com/axllent/mailpit/internal/logger"
	"github.com/klauspost/compress/zstd"
	"github.com/leporo/sqlf"

	// sqlite - https://gitlab.com/cznic/sqlite
	_ "modernc.org/sqlite"

	// rqlite - https://github.com/rqlite/gorqlite | https://rqlite.io/
	_ "github.com/rqlite/gorqlite/stdlib"
)

// Package-level database state. All of these are assigned by InitDB
// (dbLastAction is also refreshed by StatsGet).
var (
	db           *sql.DB   // shared database handle opened by InitDB
	dbFile       string    // resolved database location (file path or rqlite URL)
	dbIsTemp     bool      // true when InitDB created a temporary database file
	sqlDriver    string    // "sqlite" or "rqlite", chosen from config.Database
	dbLastAction time.Time // time of the most recently recorded database activity
)

// Shared zstd compression encoder & decoder. Construction errors are
// deliberately ignored here.
var (
	dbEncoder, _ = zstd.NewWriter(nil)
	dbDecoder, _ = zstd.NewReader(nil)
)

// InitDB initialises the database connection, applies schemas & migrations,
// and starts the background maintenance goroutines.
//
// The database location is taken from config.Database:
//   - empty: a new temporary SQLite file is created in os.TempDir(); it is
//     deleted by Close(), which also runs on SIGINT or SIGTERM
//   - an http:// or https:// URL: the rqlite driver is used
//   - anything else: treated as a SQLite file path
//
// config.Database is updated to the resolved location. An error is returned
// if the database cannot be opened, is still unreachable after several ping
// attempts, or the schemas cannot be applied.
func InitDB() error {
	p := config.Database
	var dsn string

	switch {
	case p == "":
		// when no path is provided then we create a temporary file
		// which will get deleted on Close(), SIGINT or SIGTERM
		p = fmt.Sprintf("%s-%d.db", path.Join(os.TempDir(), "mailpit"), time.Now().UnixNano())
		dbIsTemp = true
		sqlDriver = "sqlite"
		dsn = p
		logger.Log().Debugf("[db] using temporary database: %s", p)
	case strings.HasPrefix(p, "http://") || strings.HasPrefix(p, "https://"):
		sqlDriver = "rqlite"
		dsn = p
		logger.Log().Debugf("[db] opening rqlite database %s", p)
	default:
		p = filepath.Clean(p)
		sqlDriver = "sqlite"
		dsn = fmt.Sprintf("file:%s?cache=shared", p)
		logger.Log().Debugf("[db] opening database %s", p)
	}

	config.Database = p
	// Record the location before anything can fail, so that Close() is still
	// able to remove a temporary database if initialisation aborts part-way.
	dbFile = p

	var err error

	db, err = sql.Open(sqlDriver, dsn)
	if err != nil {
		return err
	}

	// sql.Open does not connect; ping (with retries) to make sure the
	// database is actually reachable before continuing.
	const maxAttempts = 5
	for attempt := 1; ; attempt++ {
		err = Ping()
		if err == nil {
			break
		}

		logger.Log().Errorf("[db] %s", err.Error())

		if attempt == maxAttempts {
			return fmt.Errorf("database unreachable after %d attempts: %w", maxAttempts, err)
		}

		logger.Log().Infof("[db] reconnecting in 5 seconds (%d/%d)", attempt, maxAttempts)
		time.Sleep(5 * time.Second)
	}

	// prevent "database locked" errors
	// @see https://github.com/mattn/go-sqlite3#faq
	db.SetMaxOpenConns(1)

	if sqlDriver == "sqlite" {
		// SQLite performance tuning (https://phiresky.github.io/blog/2020/sqlite-performance-tuning/)
		_, err = db.Exec("PRAGMA journal_mode = WAL; PRAGMA synchronous = normal;")
		if err != nil {
			return err
		}
	}

	// create tables if necessary & apply migrations
	if err := dbApplySchemas(); err != nil {
		return err
	}

	dbLastAction = time.Now()

	// Listen for SIGINT (Ctrl-C) and SIGTERM (e.g. `kill`) so the database is
	// closed cleanly, and a temporary database removed, before exiting.
	sigs := make(chan os.Signal, 1)
	signal.Notify(sigs, os.Interrupt, syscall.SIGTERM)
	go func() {
		s := <-sigs
		fmt.Printf("[db] got %s signal, shutting down\n", s)
		Close()
		os.Exit(0)
	}()

	// auto-prune & delete
	go dbCron()

	go dataMigrations()

	return nil
}

// Tenant applies an optional prefix to the table name
func tenant(table string) string {
	return fmt.Sprintf("%s%s", config.TenantID, table)
}

// Close closes the database connection (if one was opened) and, when InitDB
// created a temporary database, deletes the temporary database file.
// Errors are logged rather than returned so Close can be used during shutdown.
func Close() {
	if db != nil {
		if err := db.Close(); err != nil {
			// include the cause so the failure is diagnosable
			logger.Log().Warnf("[db] error closing database, ignoring: %s", err.Error())
		}
	}

	if dbIsTemp && isFile(dbFile) {
		logger.Log().Debugf("[db] deleting temporary file %s", dbFile)
		if err := os.Remove(dbFile); err != nil {
			logger.Log().Errorf("[db] %s", err.Error())
		}
	}
}

// Ping checks the database connection and returns an error if it is
// unsuccessful. If the database has not been opened yet (InitDB not called
// or sql.Open not reached), an error is returned instead of panicking on
// the nil handle.
func Ping() error {
	if db == nil {
		return fmt.Errorf("database has not been initialised")
	}

	return db.Ping()
}

// StatsGet collects the mailbox statistics: the total message count, the
// unread message count and all tags. It also records the time of this
// database activity in dbLastAction.
func StatsGet() MailboxStats {
	// composite literal fields are evaluated left to right, preserving the
	// total -> unread -> tags query order
	stats := MailboxStats{
		Total:  CountTotal(),
		Unread: CountUnread(),
		Tags:   GetAllTags(),
	}

	dbLastAction = time.Now()

	return stats
}

// CountTotal returns the number of messages in the (tenant-prefixed)
// mailbox table. A query error is logged rather than returned.
func CountTotal() float64 {
	var total float64

	err := sqlf.From(tenant("mailbox")).
		Select("COUNT(*)").To(&total).
		QueryRowAndClose(context.TODO(), db)
	if err != nil {
		logger.Log().Errorf("[db] %s", err.Error())
	}

	return total
}

// CountUnread returns the number of messages in the (tenant-prefixed)
// mailbox table whose Read column is 0. A query error is logged rather
// than returned.
func CountUnread() float64 {
	var total float64

	err := sqlf.From(tenant("mailbox")).
		Select("COUNT(*)").To(&total).
		Where("Read = ?", 0).
		QueryRowAndClose(context.TODO(), db)
	if err != nil {
		logger.Log().Errorf("[db] %s", err.Error())
	}

	return total
}

// CountRead returns the number of messages in the (tenant-prefixed)
// mailbox table whose Read column is 1. A query error is logged rather
// than returned.
func CountRead() float64 {
	var total float64

	err := sqlf.From(tenant("mailbox")).
		Select("COUNT(*)").To(&total).
		Where("Read = ?", 1).
		QueryRowAndClose(context.TODO(), db)
	if err != nil {
		logger.Log().Errorf("[db] %s", err.Error())
	}

	return total
}

// DbSize returns the size of the SQLite database in bytes, computed as
// page_count * page_size. Errors are logged rather than returned, and a
// NULL result is reported as 0.
func DbSize() float64 {
	var size sql.NullFloat64

	row := db.QueryRow("SELECT page_count * page_size AS size FROM pragma_page_count(), pragma_page_size()")
	if err := row.Scan(&size); err != nil {
		logger.Log().Errorf("[db] %s", err.Error())
	}

	return size.Float64
}

// IsUnread reports whether the message with the given ID exists in the
// (tenant-prefixed) mailbox table with Read = 0. A query error is logged
// and reported as false.
func IsUnread(id string) bool {
	var unread int

	err := sqlf.From(tenant("mailbox")).
		Select("COUNT(*)").To(&unread).
		Where("Read = ?", 0).
		Where("ID = ?", id).
		QueryRowAndClose(context.TODO(), db)
	if err != nil {
		logger.Log().Errorf("[db] %s", err.Error())
		return false
	}

	return unread == 1
}

// MessageIDExists reports whether at least one message in the
// (tenant-prefixed) mailbox table has the given MessageID. A query error is
// logged and reported as false.
func MessageIDExists(id string) bool {
	var total int

	err := sqlf.From(tenant("mailbox")).
		Select("COUNT(*)").To(&total).
		Where("MessageID = ?", id).
		QueryRowAndClose(context.TODO(), db)
	if err != nil {
		logger.Log().Errorf("[db] %s", err.Error())
		return false
	}

	return total != 0
}