1
0
mirror of https://github.com/axllent/mailpit.git synced 2025-01-28 03:56:50 +02:00

Chore: Convert to many-to-many message tag relationships

This commit is contained in:
Ralph Slooten 2024-01-01 23:46:34 +13:00
parent e032d27ef6
commit 06ca217cde
10 changed files with 392 additions and 271 deletions

View File

@ -14,7 +14,6 @@ import (
"os/signal" "os/signal"
"path" "path"
"path/filepath" "path/filepath"
"sort"
"strings" "strings"
"syscall" "syscall"
"time" "time"
@ -181,11 +180,6 @@ func Store(body []byte) (string, error) {
tagData := uniqueTagsFromString(tagStr) tagData := uniqueTagsFromString(tagStr)
tagJSON, err := json.Marshal(tagData)
if err != nil {
return "", err
}
// begin a transaction to ensure both the message // begin a transaction to ensure both the message
// and data are stored successfully // and data are stored successfully
ctx := context.Background() ctx := context.Background()
@ -204,8 +198,8 @@ func Store(body []byte) (string, error) {
snippet := tools.CreateSnippet(env.Text, env.HTML) snippet := tools.CreateSnippet(env.Text, env.HTML)
// insert mail summary data // insert mail summary data
_, err = tx.Exec("INSERT INTO mailbox(Created, ID, MessageID, Subject, Metadata, Size, Inline, Attachments, SearchText, Tags, Read, Snippet) values(?,?,?,?,?,?,?,?,?,?,0, ?)", _, err = tx.Exec("INSERT INTO mailbox(Created, ID, MessageID, Subject, Metadata, Size, Inline, Attachments, SearchText, Read, Snippet) values(?,?,?,?,?,?,?,?,?,0,?)",
created.UnixMilli(), id, messageID, subject, string(summaryJSON), size, inline, attachments, searchText, string(tagJSON), snippet) created.UnixMilli(), id, messageID, subject, string(summaryJSON), size, inline, attachments, searchText, snippet)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -221,6 +215,13 @@ func Store(body []byte) (string, error) {
return "", err return "", err
} }
if len(tagData) > 0 {
// set tags after tx.Commit()
if err := SetMessageTags(id, tagData); err != nil {
return "", err
}
}
c := &MessageSummary{} c := &MessageSummary{}
if err := json.Unmarshal(summaryJSON, c); err != nil { if err := json.Unmarshal(summaryJSON, c); err != nil {
return "", err return "", err
@ -249,10 +250,11 @@ func Store(body []byte) (string, error) {
// sorted latest to oldest // sorted latest to oldest
func List(start, limit int) ([]MessageSummary, error) { func List(start, limit int) ([]MessageSummary, error) {
results := []MessageSummary{} results := []MessageSummary{}
tsStart := time.Now()
q := sqlf.From("mailbox"). q := sqlf.From("mailbox m").
Select(`Created, ID, MessageID, Subject, Metadata, Size, Attachments, Read, Tags, Snippet`). Select(`m.Created, m.ID, m.MessageID, m.Subject, m.Metadata, m.Size, m.Attachments, m.Read, m.Snippet`).
OrderBy("Created DESC"). OrderBy("m.Created DESC").
Limit(limit). Limit(limit).
Offset(start) Offset(start)
@ -264,12 +266,11 @@ func List(start, limit int) ([]MessageSummary, error) {
var metadata string var metadata string
var size int var size int
var attachments int var attachments int
var tags string
var read int var read int
var snippet string var snippet string
em := MessageSummary{} em := MessageSummary{}
if err := row.Scan(&created, &id, &messageID, &subject, &metadata, &size, &attachments, &read, &tags, &snippet); err != nil { if err := row.Scan(&created, &id, &messageID, &subject, &metadata, &size, &attachments, &read, &snippet); err != nil {
logger.Log().Error(err) logger.Log().Error(err)
return return
} }
@ -279,11 +280,6 @@ func List(start, limit int) ([]MessageSummary, error) {
return return
} }
if err := json.Unmarshal([]byte(tags), &em.Tags); err != nil {
logger.Log().Error(err)
return
}
em.Created = time.UnixMilli(created) em.Created = time.UnixMilli(created)
em.ID = id em.ID = id
em.MessageID = messageID em.MessageID = messageID
@ -298,8 +294,17 @@ func List(start, limit int) ([]MessageSummary, error) {
return results, err return results, err
} }
// set tags for listed messages only
for i, m := range results {
results[i].Tags = getMessageTags(m.ID)
}
dbLastAction = time.Now() dbLastAction = time.Now()
elapsed := time.Since(tsStart)
logger.Log().Debugf("[db] list INBOX in %s", elapsed)
return results, nil return results, nil
} }
@ -616,6 +621,10 @@ func DeleteOneMessage(id string) error {
logger.Log().Debugf("[db] deleted message %s", id) logger.Log().Debugf("[db] deleted message %s", id)
} }
if err := DeleteAllMessageTags(id); err != nil {
return err
}
dbLastAction = time.Now() dbLastAction = time.Now()
dbDataDeleted = true dbDataDeleted = true
@ -655,6 +664,16 @@ func DeleteAllMessages() error {
return err return err
} }
_, err = tx.Exec("DELETE FROM tags")
if err != nil {
return err
}
_, err = tx.Exec("DELETE FROM message_tags")
if err != nil {
return err
}
if err := tx.Commit(); err != nil { if err := tx.Commit(); err != nil {
return err return err
} }
@ -676,38 +695,19 @@ func DeleteAllMessages() error {
// GetAllTags returns all used tags // GetAllTags returns all used tags
func GetAllTags() []string { func GetAllTags() []string {
q := sqlf.From("mailbox").
Select(`DISTINCT Tags`).
Where("Tags != ?", "[]")
var tags = []string{} var tags = []string{}
var name string
if err := q.QueryAndClose(nil, db, func(row *sql.Rows) { if err := sqlf.
var tagData string Select(`DISTINCT Name`).
t := []string{} From("tags").To(&name).
OrderBy("Name").
if err := row.Scan(&tagData); err != nil { QueryAndClose(nil, db, func(row *sql.Rows) {
logger.Log().Error(err) tags = append(tags, name)
return }); err != nil {
}
if err := json.Unmarshal([]byte(tagData), &t); err != nil {
logger.Log().Error(err)
return
}
for _, tag := range t {
if !inArray(tag, tags) {
tags = append(tags, tag)
}
}
}); err != nil {
logger.Log().Error(err) logger.Log().Error(err)
} }
sort.Strings(tags)
return tags return tags
} }

View File

@ -1,200 +1,73 @@
package storage package storage
import ( // These functions are used to migrate data formats/structure on startup.
"bytes"
"context" import (
"database/sql" "database/sql"
"strings" "encoding/json"
"time"
"github.com/axllent/mailpit/config"
"github.com/axllent/mailpit/internal/logger" "github.com/axllent/mailpit/internal/logger"
"github.com/jhillyerd/enmime"
"github.com/leporo/sqlf" "github.com/leporo/sqlf"
"golang.org/x/text/language"
"golang.org/x/text/message"
) )
func dataMigrations() { func dataMigrations() {
updateOrderByCreatedTask() migrateTagsToManyMany()
assignMessageIDsTask()
} }
// Update Created column using Created metadata datetime <= v1.6.5 // Migrate tags to ManyMany structure
// Migration task implemented 05/2023 - can be removed end 2023 // Migration task implemented 12/2023
func updateOrderByCreatedTask() { // Can be removed end 06/2024 and Tags column & index dropped from mailbox
q := sqlf.From("mailbox"). func migrateTagsToManyMany() {
Select("ID"). toConvert := make(map[string][]string)
Select(`json_extract(Metadata, '$.Created') as Created`). q := sqlf.
Where("Created < ?", 1155000600) Select("ID, Tags").
From("mailbox").
toUpdate := make(map[string]int64) Where("Tags != ?", "[]").
p := message.NewPrinter(language.English) Where("Tags IS NOT NULL")
if err := q.QueryAndClose(nil, db, func(row *sql.Rows) { if err := q.QueryAndClose(nil, db, func(row *sql.Rows) {
var id string var id string
var ts sql.NullString var jsonTags string
if err := row.Scan(&id, &ts); err != nil { if err := row.Scan(&id, &jsonTags); err != nil {
logger.Log().Error("[migration]", err) logger.Log().Errorf("[migration] %s", err.Error())
return return
} }
if !ts.Valid { tags := []string{}
logger.Log().Errorf("[migration] cannot get Created timestamp from %s", id)
if err := json.Unmarshal([]byte(jsonTags), &tags); err != nil {
logger.Log().Error(err)
return return
} }
t, _ := time.Parse(time.RFC3339Nano, ts.String) toConvert[id] = tags
toUpdate[id] = t.UnixMilli()
}); err != nil { }); err != nil {
logger.Log().Error("[migration]", err) logger.Log().Errorf("[migration] %s", err.Error())
return
} }
total := len(toUpdate) if len(toConvert) > 0 {
logger.Log().Infof("[migration] converting %d message tags", len(toConvert))
if total == 0 { for id, tags := range toConvert {
return if err := SetMessageTags(id, tags); err != nil {
} logger.Log().Errorf("[migration] %s", err.Error())
} else {
logger.Log().Infof("[migration] updating timestamp for %s messages", p.Sprintf("%d", len(toUpdate))) if _, err := sqlf.Update("mailbox").
Set("Tags", nil).
// begin a transaction Where("ID = ?", id).
ctx := context.Background() ExecAndClose(nil, db); err != nil {
tx, err := db.BeginTx(ctx, nil) logger.Log().Errorf("[migration] %s", err.Error())
if err != nil { }
logger.Log().Error("[migration]", err) }
return
}
// roll back if it fails
defer tx.Rollback()
var blockTime = time.Now()
count := 0
for id, ts := range toUpdate {
count++
_, err := tx.Exec(`UPDATE mailbox SET Created = ? WHERE ID = ?`, ts, id)
if err != nil {
logger.Log().Error("[migration]", err)
} }
if count%1000 == 0 { logger.Log().Info("[migration] tags conversion complete")
percent := (100 * count) / total
logger.Log().Infof("[migration] updated timestamp for 1,000 messages [%d%%] in %s", percent, time.Since(blockTime))
blockTime = time.Now()
}
} }
logger.Log().Infof("[migration] commit %s changes", p.Sprintf("%d", count)) // set all legacy `[]` tags to NULL
if _, err := sqlf.Update("mailbox").
if err := tx.Commit(); err != nil { Set("Tags", nil).
logger.Log().Error("[migration]", err) Where("Tags = ?", "[]").
return ExecAndClose(nil, db); err != nil {
logger.Log().Errorf("[migration] %s", err.Error())
} }
logger.Log().Infof("[migration] complete")
}
// Find any messages without a stored Message-ID and update it <= v1.6.5
// Migration task implemented 05/2023 - can be removed end 2023
func assignMessageIDsTask() {
if !config.IgnoreDuplicateIDs {
return
}
q := sqlf.From("mailbox").
Select("ID").
Where("MessageID = ''")
missingIDS := make(map[string]string)
if err := q.QueryAndClose(nil, db, func(row *sql.Rows) {
var id string
if err := row.Scan(&id); err != nil {
logger.Log().Error("[migration]", err)
return
}
missingIDS[id] = ""
}); err != nil {
logger.Log().Error("[migration]", err)
}
if len(missingIDS) == 0 {
return
}
var count int
var blockTime = time.Now()
p := message.NewPrinter(language.English)
total := len(missingIDS)
logger.Log().Infof("[migration] extracting Message-IDs for %s messages", p.Sprintf("%d", total))
for id := range missingIDS {
raw, err := GetMessageRaw(id)
if err != nil {
logger.Log().Error("[migration]", err)
continue
}
r := bytes.NewReader(raw)
env, err := enmime.ReadEnvelope(r)
if err != nil {
logger.Log().Error("[migration]", err)
continue
}
messageID := strings.Trim(env.GetHeader("Message-ID"), "<>")
missingIDS[id] = messageID
count++
if count%1000 == 0 {
percent := (100 * count) / total
logger.Log().Infof("[migration] extracted 1,000 Message-IDs [%d%%] in %s", percent, time.Since(blockTime))
blockTime = time.Now()
}
}
// begin a transaction
ctx := context.Background()
tx, err := db.BeginTx(ctx, nil)
if err != nil {
logger.Log().Error("[migration]", err)
return
}
// roll back if it fails
defer tx.Rollback()
count = 0
for id, mid := range missingIDS {
_, err = tx.Exec(`UPDATE mailbox SET MessageID = ? WHERE ID = ?`, mid, id)
if err != nil {
logger.Log().Error("[migration]", err)
}
count++
if count%1000 == 0 {
percent := (100 * count) / total
logger.Log().Infof("[migration] stored 1,000 Message-IDs [%d%%] in %s", percent, time.Since(blockTime))
blockTime = time.Now()
}
}
logger.Log().Infof("[migration] commit %s changes", p.Sprintf("%d", count))
if err := tx.Commit(); err != nil {
logger.Log().Error("[migration]", err)
return
}
logger.Log().Infof("[migration] complete")
} }

View File

@ -71,6 +71,23 @@ var (
Description: "Create snippet column", Description: "Create snippet column",
Script: `ALTER TABLE mailbox ADD COLUMN Snippet Text NOT NULL DEFAULT '';`, Script: `ALTER TABLE mailbox ADD COLUMN Snippet Text NOT NULL DEFAULT '';`,
}, },
{
Version: 1.4,
Description: "Create tag tables",
Script: `CREATE TABLE IF NOT EXISTS tags (
ID INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
Name TEXT COLLATE NOCASE
);
CREATE UNIQUE INDEX IF NOT EXISTS idx_tag_name ON tags (Name);
CREATE TABLE IF NOT EXISTS message_tags(
Key INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
ID TEXT REFERENCES mailbox(ID),
TagID INT REFERENCES tags(ID)
);
CREATE INDEX IF NOT EXISTS idx_message_tag_id ON message_tags (ID);
CREATE INDEX IF NOT EXISTS idx_message_tag_tagid ON message_tags (TagID);`,
},
} }
) )

View File

@ -37,13 +37,12 @@ func Search(search string, start, limit int) ([]MessageSummary, int, error) {
var metadata string var metadata string
var size int var size int
var attachments int var attachments int
var tags string
var snippet string var snippet string
var read int var read int
var ignore string var ignore string
em := MessageSummary{} em := MessageSummary{}
if err := row.Scan(&created, &id, &messageID, &subject, &metadata, &size, &attachments, &read, &tags, &snippet, &ignore, &ignore, &ignore, &ignore); err != nil { if err := row.Scan(&created, &id, &messageID, &subject, &metadata, &size, &attachments, &read, &snippet, &ignore, &ignore, &ignore, &ignore); err != nil {
logger.Log().Error(err) logger.Log().Error(err)
return return
} }
@ -53,11 +52,6 @@ func Search(search string, start, limit int) ([]MessageSummary, int, error) {
return return
} }
if err := json.Unmarshal([]byte(tags), &em.Tags); err != nil {
logger.Log().Error(err)
return
}
em.Created = time.UnixMilli(created) em.Created = time.UnixMilli(created)
em.ID = id em.ID = id
em.MessageID = messageID em.MessageID = messageID
@ -85,6 +79,11 @@ func Search(search string, start, limit int) ([]MessageSummary, int, error) {
results = allResults[start:end] results = allResults[start:end]
} }
// set tags for listed messages only
for i, m := range results {
results[i].Tags = getMessageTags(m.ID)
}
elapsed := time.Since(tsStart) elapsed := time.Since(tsStart)
logger.Log().Debugf("[db] search for \"%s\" in %s", search, elapsed) logger.Log().Debugf("[db] search for \"%s\" in %s", search, elapsed)
@ -109,12 +108,12 @@ func DeleteSearch(search string) error {
var metadata string var metadata string
var size int var size int
var attachments int var attachments int
var tags string // var tags string
var read int var read int
var snippet string var snippet string
var ignore string var ignore string
if err := row.Scan(&created, &id, &messageID, &subject, &metadata, &size, &attachments, &read, &tags, &snippet, &ignore, &ignore, &ignore, &ignore); err != nil { if err := row.Scan(&created, &id, &messageID, &subject, &metadata, &size, &attachments, &read, &snippet, &ignore, &ignore, &ignore, &ignore); err != nil {
logger.Log().Error(err) logger.Log().Error(err)
return return
} }
@ -172,10 +171,21 @@ func DeleteSearch(search string) error {
if err != nil { if err != nil {
return err return err
} }
sqlDelete3 := `DELETE FROM message_tags WHERE ID IN (?` + strings.Repeat(",?", len(ids)-1) + `)`
_, err = tx.Exec(sqlDelete3, delIDs...)
if err != nil {
return err
}
} }
err = tx.Commit() err = tx.Commit()
if err := pruneUnusedTags(); err != nil {
return err
}
if err == nil { if err == nil {
logger.Log().Debugf("[db] deleted %d messages matching %s", total, search) logger.Log().Debugf("[db] deleted %d messages matching %s", total, search)
} }
@ -195,13 +205,15 @@ func searchQueryBuilder(searchString string) *sqlf.Stmt {
// group strings with quotes as a single argument and remove quotes // group strings with quotes as a single argument and remove quotes
args := tools.ArgsParser(searchString) args := tools.ArgsParser(searchString)
q := sqlf.From("mailbox"). q := sqlf.From("mailbox m").
Select(`Created, ID, MessageID, Subject, Metadata, Size, Attachments, Read, Tags, Snippet, Select(`m.Created, m.ID, m.MessageID, m.Subject, m.Metadata, m.Size, m.Attachments, m.Read,
m.Snippet,
IFNULL(json_extract(Metadata, '$.To'), '{}') as ToJSON, IFNULL(json_extract(Metadata, '$.To'), '{}') as ToJSON,
IFNULL(json_extract(Metadata, '$.From'), '{}') as FromJSON, IFNULL(json_extract(Metadata, '$.From'), '{}') as FromJSON,
IFNULL(json_extract(Metadata, '$.Cc'), '{}') as CcJSON, IFNULL(json_extract(Metadata, '$.Cc'), '{}') as CcJSON,
IFNULL(json_extract(Metadata, '$.Bcc'), '{}') as BccJSON IFNULL(json_extract(Metadata, '$.Bcc'), '{}') as BccJSON
`).OrderBy("Created DESC") `).
OrderBy("m.Created DESC")
for _, w := range args { for _, w := range args {
if cleanString(w) == "" { if cleanString(w) == "" {
@ -278,9 +290,9 @@ func searchQueryBuilder(searchString string) *sqlf.Stmt {
w = cleanString(w[4:]) w = cleanString(w[4:])
if w != "" { if w != "" {
if exclude { if exclude {
q.Where("Tags NOT LIKE ?", "%\""+escPercentChar(w)+"\"%") q.Where(`m.ID NOT IN (SELECT mt.ID FROM message_tags mt JOIN tags t ON mt.TagID = t.ID WHERE t.Name = ?)`, w)
} else { } else {
q.Where("Tags LIKE ?", "%\""+escPercentChar(w)+"\"%") q.Where(`m.ID IN (SELECT mt.ID FROM message_tags mt JOIN tags t ON mt.TagID = t.ID WHERE t.Name = ?)`, w)
} }
} }
} else if w == "is:read" { } else if w == "is:read" {
@ -297,9 +309,9 @@ func searchQueryBuilder(searchString string) *sqlf.Stmt {
} }
} else if w == "is:tagged" { } else if w == "is:tagged" {
if exclude { if exclude {
q.Where("Tags = ?", "[]") q.Where(`m.ID NOT IN (SELECT DISTINCT mt.ID FROM message_tags mt JOIN tags t ON mt.TagID = t.ID)`)
} else { } else {
q.Where("Tags != ?", "[]") q.Where(`m.ID IN (SELECT DISTINCT mt.ID FROM message_tags mt JOIN tags t ON mt.TagID = t.ID)`)
} }
} else if w == "has:attachment" || w == "has:attachments" { } else if w == "has:attachment" || w == "has:attachments" {
if exclude { if exclude {

View File

@ -1,8 +1,7 @@
package storage package storage
import ( import (
"context" "database/sql"
"encoding/json"
"sort" "sort"
"strings" "strings"
@ -12,8 +11,8 @@ import (
"github.com/leporo/sqlf" "github.com/leporo/sqlf"
) )
// SetTags will set the tags for a given database ID, used via API // SetMessageTags will set the tags for a given database ID
func SetTags(id string, tags []string) error { func SetMessageTags(id string, tags []string) error {
applyTags := []string{} applyTags := []string{}
for _, t := range tags { for _, t := range tags {
t = tools.CleanTag(t) t = tools.CleanTag(t)
@ -22,26 +21,162 @@ func SetTags(id string, tags []string) error {
} }
} }
sort.Strings(applyTags) currentTags := getMessageTags(id)
origTagCount := len(currentTags)
tagJSON, err := json.Marshal(applyTags) for _, t := range applyTags {
if err != nil { t = tools.CleanTag(t)
logger.Log().Errorf("[db] setting tags for message %s", id) if t == "" || !config.ValidTagRegexp.MatchString(t) || inArray(t, currentTags) {
continue
}
if err := AddMessageTag(id, t); err != nil {
return err
}
}
if origTagCount > 0 {
currentTags = getMessageTags(id)
for _, t := range currentTags {
if !inArray(t, applyTags) {
if err := DeleteMessageTag(id, t); err != nil {
return err
}
}
}
}
return nil
}
// AddMessageTag adds a tag to a message
func AddMessageTag(id, name string) error {
var tagID int
q := sqlf.From("tags").
Select("ID").To(&tagID).
Where("Name = ?", name)
// tag exists - add tag to message
if err := q.QueryRowAndClose(nil, db); err == nil {
// check message does not already have this tag
var count int
if _, err := sqlf.From("message_tags").
Select("COUNT(ID)").To(&count).
Where("ID = ?", id).
Where("TagID = ?", tagID).
ExecAndClose(nil, db); err != nil {
return err
}
if count != 0 {
// already exists
return nil
}
logger.Log().Debugf("[tags] adding tag \"%s\" to %s", name, id)
_, err := sqlf.InsertInto("message_tags").
Set("ID", id).
Set("TagID", tagID).
ExecAndClose(nil, db)
return err return err
} }
_, err = sqlf.Update("mailbox"). logger.Log().Debugf("[tags] adding tag \"%s\" to %s", name, id)
Set("Tags", string(tagJSON)).
Where("ID = ?", id).
ExecAndClose(context.Background(), db)
if err == nil { // tag dos not exist, add new one
logger.Log().Debugf("[db] set tags %s for message %s", string(tagJSON), id) if err := sqlf.InsertInto("tags").
Set("Name", name).
Returning("ID").To(&tagID).
QueryRowAndClose(nil, db); err != nil {
return err
} }
// check message does not already have this tag
var count int
if _, err := sqlf.From("message_tags").
Select("COUNT(ID)").To(&count).
Where("ID = ?", id).
Where("TagID = ?", tagID).
ExecAndClose(nil, db); err != nil {
return err
}
if count != 0 {
return nil // already exists
}
// add tag to message
_, err := sqlf.InsertInto("message_tags").
Set("ID", id).
Set("TagID", tagID).
ExecAndClose(nil, db)
return err return err
} }
// DeleteMessageTag deleted a tag from a message
func DeleteMessageTag(id, name string) error {
if _, err := sqlf.DeleteFrom("message_tags").
Where("message_tags.ID = ?", id).
Where(`message_tags.Key IN (SELECT Key FROM message_tags LEFT JOIN tags ON TagID=tags.ID WHERE Name = ?)`, name).
ExecAndClose(nil, db); err != nil {
return err
}
return pruneUnusedTags()
}
// DeleteAllMessageTags deleted all tags from a message
func DeleteAllMessageTags(id string) error {
if _, err := sqlf.DeleteFrom("message_tags").
Where("message_tags.ID = ?", id).
ExecAndClose(nil, db); err != nil {
return err
}
return pruneUnusedTags()
}
// PruneUnusedTags will delete all unused tags from the database
func pruneUnusedTags() error {
q := sqlf.From("tags").
Select("tags.ID, tags.Name, COUNT(message_tags.ID) as COUNT").
LeftJoin("message_tags", "tags.ID = message_tags.TagID").
GroupBy("tags.ID")
toDel := []int{}
if err := q.QueryAndClose(nil, db, func(row *sql.Rows) {
var n string
var id int
var c int
if err := row.Scan(&id, &n, &c); err != nil {
logger.Log().Error("[tags]", err)
return
}
if c == 0 {
logger.Log().Debugf("[tags] deleting unused tag \"%s\"", n)
toDel = append(toDel, id)
}
}); err != nil {
return err
}
if len(toDel) > 0 {
for _, id := range toDel {
if _, err := sqlf.DeleteFrom("tags").
Where("ID = ?", id).
ExecAndClose(nil, db); err != nil {
return err
}
}
}
return nil
}
// Find tags set via --tags in raw message. // Find tags set via --tags in raw message.
// Returns a comma-separated string. // Returns a comma-separated string.
func findTagsInRawMessage(message *[]byte) string { func findTagsInRawMessage(message *[]byte) string {
@ -64,20 +199,18 @@ func findTagsInRawMessage(message *[]byte) string {
// Used when parsing a raw email. // Used when parsing a raw email.
func getMessageTags(id string) []string { func getMessageTags(id string) []string {
tags := []string{} tags := []string{}
var data string var name string
q := sqlf.From("mailbox"). if err := sqlf.
Select(`Tags`).To(&data). Select(`Name`).To(&name).
Where(`ID = ?`, id) From("Tags").
LeftJoin("message_tags", "Tags.ID=message_tags.TagID").
err := q.QueryRowAndClose(context.Background(), db) Where(`message_tags.ID = ?`, id).
if err != nil { OrderBy("Name").
logger.Log().Error(err) QueryAndClose(nil, db, func(row *sql.Rows) {
return tags tags = append(tags, name)
} }); err != nil {
logger.Log().Errorf("[tags] %s", err.Error())
if err := json.Unmarshal([]byte(data), &tags); err != nil {
logger.Log().Error(err)
return tags return tags
} }
@ -103,7 +236,7 @@ func uniqueTagsFromString(s string) []string {
tags = append(tags, w) tags = append(tags, w)
} }
} else { } else {
logger.Log().Debugf("[db] ignoring invalid tag: %s", w) logger.Log().Debugf("[tags] ignoring invalid tag: %s", w)
} }
} }

View File

@ -2,6 +2,7 @@ package storage
import ( import (
"fmt" "fmt"
"strings"
"testing" "testing"
) )
@ -23,7 +24,7 @@ func TestTags(t *testing.T) {
} }
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
if err := SetTags(ids[i], []string{fmt.Sprintf("Tag-%d", i)}); err != nil { if err := SetMessageTags(ids[i], []string{fmt.Sprintf("Tag-%d", i)}); err != nil {
t.Log("error ", err) t.Log("error ", err)
t.Fail() t.Fail()
} }
@ -40,4 +41,71 @@ func TestTags(t *testing.T) {
t.Fatal("Message tags do not match") t.Fatal("Message tags do not match")
} }
} }
if err := DeleteAllMessages(); err != nil {
t.Log("error ", err)
t.Fail()
}
// test 20 tags
id, err := Store(testMimeEmail)
if err != nil {
t.Log("error ", err)
t.Fail()
}
newTags := []string{}
for i := 0; i < 20; i++ {
// pad number with 0 to ensure they are returned alphabetically
newTags = append(newTags, fmt.Sprintf("AnotherTag %02d", i))
}
if err := SetMessageTags(id, newTags); err != nil {
t.Log("error ", err)
t.Fail()
}
returnedTags := getMessageTags(id)
assertEqual(t, strings.Join(newTags, "|"), strings.Join(returnedTags, "|"), "Message tags do not match")
// remove first tag
if err := DeleteMessageTag(id, newTags[0]); err != nil {
t.Log("error ", err)
t.Fail()
}
returnedTags = getMessageTags(id)
assertEqual(t, strings.Join(newTags[1:], "|"), strings.Join(returnedTags, "|"), "Message tags do not match after deleting 1")
// remove all tags
if err := DeleteAllMessageTags(id); err != nil {
t.Log("error ", err)
t.Fail()
}
returnedTags = getMessageTags(id)
assertEqual(t, "", strings.Join(returnedTags, "|"), "Message tags should be empty")
// apply the same tag twice
if err := SetMessageTags(id, []string{"Duplicate Tag", "Duplicate Tag"}); err != nil {
t.Log("error ", err)
t.Fail()
}
returnedTags = getMessageTags(id)
assertEqual(t, "Duplicate Tag", strings.Join(returnedTags, "|"), "Message tags should be duplicated")
if err := DeleteAllMessageTags(id); err != nil {
t.Log("error ", err)
t.Fail()
}
// apply tag with invalid characters
if err := SetMessageTags(id, []string{"Dirty! \"Tag\""}); err != nil {
t.Log("error ", err)
t.Fail()
}
returnedTags = getMessageTags(id)
assertEqual(t, "Dirty Tag", strings.Join(returnedTags, "|"), "Dirty message tag did not clean as expected")
if err := DeleteAllMessageTags(id); err != nil {
t.Log("error ", err)
t.Fail()
}
// Check deleted message tags also prune the tags database
allTags := GetAllTags()
assertEqual(t, "", strings.Join(allTags, "|"), "Dirty message tag did not clean as expected")
} }

View File

@ -139,6 +139,12 @@ func dbCron() {
continue continue
} }
_, err = tx.Query(`DELETE FROM message_tags WHERE ID IN (?`+strings.Repeat(",?", len(ids)-1)+`)`, args...) // #nosec
if err != nil {
logger.Log().Errorf("[db] %s", err.Error())
continue
}
err = tx.Commit() err = tx.Commit()
if err != nil { if err != nil {
@ -148,6 +154,10 @@ func dbCron() {
} }
} }
if err := pruneUnusedTags(); err != nil {
logger.Log().Errorf("[db] %s", err.Error())
}
dbDataDeleted = true dbDataDeleted = true
elapsed := time.Since(start) elapsed := time.Since(start)

View File

@ -511,9 +511,9 @@ func SetReadStatus(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte("ok")) _, _ = w.Write([]byte("ok"))
} }
// GetTags (method: GET) will get all tags currently in use // GetAllTags (method: GET) will get all tags currently in use
func GetTags(w http.ResponseWriter, _ *http.Request) { func GetAllTags(w http.ResponseWriter, _ *http.Request) {
// swagger:route GET /api/v1/tags tags GetTags // swagger:route GET /api/v1/tags tags GetAllTags
// //
// # Get all current tags // # Get all current tags
// //
@ -541,7 +541,7 @@ func GetTags(w http.ResponseWriter, _ *http.Request) {
} }
// SetTags (method: PUT) will set the tags for all provided IDs // SetTags (method: PUT) will set the tags for all provided IDs
func SetTags(w http.ResponseWriter, r *http.Request) { func SetMessageTags(w http.ResponseWriter, r *http.Request) {
// swagger:route PUT /api/v1/tags tags SetTags // swagger:route PUT /api/v1/tags tags SetTags
// //
// # Set message tags // # Set message tags
@ -577,7 +577,7 @@ func SetTags(w http.ResponseWriter, r *http.Request) {
if len(ids) > 0 { if len(ids) > 0 {
for _, id := range ids { for _, id := range ids {
if err := storage.SetTags(id, data.Tags); err != nil { if err := storage.SetMessageTags(id, data.Tags); err != nil {
httpError(w, err.Error()) httpError(w, err.Error())
return return
} }

View File

@ -108,8 +108,8 @@ func apiRoutes() *mux.Router {
r.HandleFunc(config.Webroot+"api/v1/messages", middleWareFunc(apiv1.GetMessages)).Methods("GET") r.HandleFunc(config.Webroot+"api/v1/messages", middleWareFunc(apiv1.GetMessages)).Methods("GET")
r.HandleFunc(config.Webroot+"api/v1/messages", middleWareFunc(apiv1.SetReadStatus)).Methods("PUT") r.HandleFunc(config.Webroot+"api/v1/messages", middleWareFunc(apiv1.SetReadStatus)).Methods("PUT")
r.HandleFunc(config.Webroot+"api/v1/messages", middleWareFunc(apiv1.DeleteMessages)).Methods("DELETE") r.HandleFunc(config.Webroot+"api/v1/messages", middleWareFunc(apiv1.DeleteMessages)).Methods("DELETE")
r.HandleFunc(config.Webroot+"api/v1/tags", middleWareFunc(apiv1.GetTags)).Methods("GET") r.HandleFunc(config.Webroot+"api/v1/tags", middleWareFunc(apiv1.GetAllTags)).Methods("GET")
r.HandleFunc(config.Webroot+"api/v1/tags", middleWareFunc(apiv1.SetTags)).Methods("PUT") r.HandleFunc(config.Webroot+"api/v1/tags", middleWareFunc(apiv1.SetMessageTags)).Methods("PUT")
r.HandleFunc(config.Webroot+"api/v1/search", middleWareFunc(apiv1.Search)).Methods("GET") r.HandleFunc(config.Webroot+"api/v1/search", middleWareFunc(apiv1.Search)).Methods("GET")
r.HandleFunc(config.Webroot+"api/v1/search", middleWareFunc(apiv1.DeleteSearch)).Methods("DELETE") r.HandleFunc(config.Webroot+"api/v1/search", middleWareFunc(apiv1.DeleteSearch)).Methods("DELETE")
r.HandleFunc(config.Webroot+"api/v1/message/{id}/part/{partID}", middleWareFunc(apiv1.DownloadAttachment)).Methods("GET") r.HandleFunc(config.Webroot+"api/v1/message/{id}/part/{partID}", middleWareFunc(apiv1.DownloadAttachment)).Methods("GET")

View File

@ -186,7 +186,7 @@ func TestAPIv1Search(t *testing.T) {
defer ts.Close() defer ts.Close()
// insert 100 // insert 100
t.Log("Insert 100 messages") t.Log("Insert 100 messages & tag")
insertEmailData(t) insertEmailData(t)
assertStatsEqual(t, ts.URL+"/api/v1/messages", 100, 100) assertStatsEqual(t, ts.URL+"/api/v1/messages", 100, 100)
@ -201,6 +201,8 @@ func TestAPIv1Search(t *testing.T) {
assertSearchEqual(t, ts.URL+"/api/v1/search", "!thisdoesnotexist", 100) assertSearchEqual(t, ts.URL+"/api/v1/search", "!thisdoesnotexist", 100)
assertSearchEqual(t, ts.URL+"/api/v1/search", "-thisdoesnotexist", 100) assertSearchEqual(t, ts.URL+"/api/v1/search", "-thisdoesnotexist", 100)
assertSearchEqual(t, ts.URL+"/api/v1/search", "thisdoesnotexist", 0) assertSearchEqual(t, ts.URL+"/api/v1/search", "thisdoesnotexist", 0)
assertSearchEqual(t, ts.URL+"/api/v1/search", "tag:\"Test tag 065\"", 1)
assertSearchEqual(t, ts.URL+"/api/v1/search", "!tag:\"Test tag 023\"", 99)
} }
func setup() { func setup() {
@ -272,7 +274,13 @@ func insertEmailData(t *testing.T) {
t.Fail() t.Fail()
} }
if _, err := storage.Store(buf.Bytes()); err != nil { id, err := storage.Store(buf.Bytes())
if err != nil {
t.Log("error ", err)
t.Fail()
}
if err := storage.SetMessageTags(id, []string{fmt.Sprintf("Test tag %03d", i)}); err != nil {
t.Log("error ", err) t.Log("error ", err)
t.Fail() t.Fail()
} }