From 06ca217cdeb6454286d05d106db3b3549200871a Mon Sep 17 00:00:00 2001 From: Ralph Slooten Date: Mon, 1 Jan 2024 23:46:34 +1300 Subject: [PATCH] Chore: Convert to many-to-many message tag relationships --- internal/storage/database.go | 90 ++++++------ internal/storage/migrationTasks.go | 217 ++++++----------------------- internal/storage/migrations.go | 17 +++ internal/storage/search.go | 44 +++--- internal/storage/tags.go | 189 +++++++++++++++++++++---- internal/storage/tags_test.go | 70 +++++++++- internal/storage/utils.go | 10 ++ server/apiv1/api.go | 10 +- server/server.go | 4 +- server/server_test.go | 12 +- 10 files changed, 392 insertions(+), 271 deletions(-) diff --git a/internal/storage/database.go b/internal/storage/database.go index ceeb20d..a8c0779 100644 --- a/internal/storage/database.go +++ b/internal/storage/database.go @@ -14,7 +14,6 @@ import ( "os/signal" "path" "path/filepath" - "sort" "strings" "syscall" "time" @@ -181,11 +180,6 @@ func Store(body []byte) (string, error) { tagData := uniqueTagsFromString(tagStr) - tagJSON, err := json.Marshal(tagData) - if err != nil { - return "", err - } - // begin a transaction to ensure both the message // and data are stored successfully ctx := context.Background() @@ -204,8 +198,8 @@ func Store(body []byte) (string, error) { snippet := tools.CreateSnippet(env.Text, env.HTML) // insert mail summary data - _, err = tx.Exec("INSERT INTO mailbox(Created, ID, MessageID, Subject, Metadata, Size, Inline, Attachments, SearchText, Tags, Read, Snippet) values(?,?,?,?,?,?,?,?,?,?,0, ?)", - created.UnixMilli(), id, messageID, subject, string(summaryJSON), size, inline, attachments, searchText, string(tagJSON), snippet) + _, err = tx.Exec("INSERT INTO mailbox(Created, ID, MessageID, Subject, Metadata, Size, Inline, Attachments, SearchText, Read, Snippet) values(?,?,?,?,?,?,?,?,?,0,?)", + created.UnixMilli(), id, messageID, subject, string(summaryJSON), size, inline, attachments, searchText, snippet) if err != nil { return "", err } @@ -221,6 +215,13 @@ func Store(body []byte) (string, error) { return "", err } + if len(tagData) > 0 { + // set tags after tx.Commit() + if err := SetMessageTags(id, tagData); err != nil { + return "", err + } + } + c := &MessageSummary{} if err := json.Unmarshal(summaryJSON, c); err != nil { return "", err @@ -249,10 +250,11 @@ func Store(body []byte) (string, error) { // sorted latest to oldest func List(start, limit int) ([]MessageSummary, error) { results := []MessageSummary{} + tsStart := time.Now() - q := sqlf.From("mailbox"). - Select(`Created, ID, MessageID, Subject, Metadata, Size, Attachments, Read, Tags, Snippet`). - OrderBy("Created DESC"). + q := sqlf.From("mailbox m"). + Select(`m.Created, m.ID, m.MessageID, m.Subject, m.Metadata, m.Size, m.Attachments, m.Read, m.Snippet`). + OrderBy("m.Created DESC"). Limit(limit). Offset(start) @@ -264,12 +266,11 @@ func List(start, limit int) ([]MessageSummary, error) { var metadata string var size int var attachments int - var tags string var read int var snippet string em := MessageSummary{} - if err := row.Scan(&created, &id, &messageID, &subject, &metadata, &size, &attachments, &read, &tags, &snippet); err != nil { + if err := row.Scan(&created, &id, &messageID, &subject, &metadata, &size, &attachments, &read, &snippet); err != nil { logger.Log().Error(err) return } @@ -279,11 +280,6 @@ func List(start, limit int) ([]MessageSummary, error) { return } - if err := json.Unmarshal([]byte(tags), &em.Tags); err != nil { - logger.Log().Error(err) - return - } - em.Created = time.UnixMilli(created) em.ID = id em.MessageID = messageID @@ -298,8 +294,17 @@ func List(start, limit int) ([]MessageSummary, error) { return results, err } + // set tags for listed messages only + for i, m := range results { + results[i].Tags = getMessageTags(m.ID) + } + dbLastAction = time.Now() + elapsed := time.Since(tsStart) + + logger.Log().Debugf("[db] list INBOX in %s", elapsed) + return results, nil } @@ -616,6 +621,10 @@ func DeleteOneMessage(id string) error { logger.Log().Debugf("[db] deleted message %s", id) } + if err := DeleteAllMessageTags(id); err != nil { + return err + } + dbLastAction = time.Now() dbDataDeleted = true @@ -655,6 +664,16 @@ func DeleteAllMessages() error { return err } + _, err = tx.Exec("DELETE FROM tags") + if err != nil { + return err + } + + _, err = tx.Exec("DELETE FROM message_tags") + if err != nil { + return err + } + if err := tx.Commit(); err != nil { return err } @@ -676,38 +695,19 @@ func DeleteAllMessages() error { // GetAllTags returns all used tags func GetAllTags() []string { - q := sqlf.From("mailbox"). - Select(`DISTINCT Tags`). - Where("Tags != ?", "[]") - var tags = []string{} + var name string - if err := q.QueryAndClose(nil, db, func(row *sql.Rows) { - var tagData string - t := []string{} - - if err := row.Scan(&tagData); err != nil { - logger.Log().Error(err) - return - } - - if err := json.Unmarshal([]byte(tagData), &t); err != nil { - logger.Log().Error(err) - return - } - - for _, tag := range t { - if !inArray(tag, tags) { - tags = append(tags, tag) - } - } - - }); err != nil { + if err := sqlf. + Select(`DISTINCT Name`). + From("tags").To(&name). + OrderBy("Name"). + QueryAndClose(nil, db, func(row *sql.Rows) { + tags = append(tags, name) + }); err != nil { logger.Log().Error(err) } - sort.Strings(tags) - return tags } diff --git a/internal/storage/migrationTasks.go b/internal/storage/migrationTasks.go index 85f6bb8..94ebff3 100644 --- a/internal/storage/migrationTasks.go +++ b/internal/storage/migrationTasks.go @@ -1,200 +1,73 @@ package storage -import ( - "bytes" - "context" - "database/sql" - "strings" - "time" +// These functions are used to migrate data formats/structure on startup. + +import ( + "database/sql" + "encoding/json" - "github.com/axllent/mailpit/config" "github.com/axllent/mailpit/internal/logger" - "github.com/jhillyerd/enmime" "github.com/leporo/sqlf" - "golang.org/x/text/language" - "golang.org/x/text/message" ) func dataMigrations() { - updateOrderByCreatedTask() - assignMessageIDsTask() + migrateTagsToManyMany() } -// Update Created column using Created metadata datetime <= v1.6.5 -// Migration task implemented 05/2023 - can be removed end 2023 -func updateOrderByCreatedTask() { - q := sqlf.From("mailbox"). - Select("ID"). - Select(`json_extract(Metadata, '$.Created') as Created`). - Where("Created < ?", 1155000600) - - toUpdate := make(map[string]int64) - p := message.NewPrinter(language.English) +// Migrate tags to ManyMany structure +// Migration task implemented 12/2023 +// Can be removed end 06/2024 and Tags column & index dropped from mailbox +func migrateTagsToManyMany() { + toConvert := make(map[string][]string) + q := sqlf. + Select("ID, Tags"). + From("mailbox"). + Where("Tags != ?", "[]"). + Where("Tags IS NOT NULL") if err := q.QueryAndClose(nil, db, func(row *sql.Rows) { var id string - var ts sql.NullString - if err := row.Scan(&id, &ts); err != nil { - logger.Log().Error("[migration]", err) + var jsonTags string + if err := row.Scan(&id, &jsonTags); err != nil { + logger.Log().Errorf("[migration] %s", err.Error()) return } - if !ts.Valid { - logger.Log().Errorf("[migration] cannot get Created timestamp from %s", id) + tags := []string{} + + if err := json.Unmarshal([]byte(jsonTags), &tags); err != nil { + logger.Log().Error(err) return } - t, _ := time.Parse(time.RFC3339Nano, ts.String) - toUpdate[id] = t.UnixMilli() + toConvert[id] = tags }); err != nil { - logger.Log().Error("[migration]", err) - return + logger.Log().Errorf("[migration] %s", err.Error()) } - total := len(toUpdate) - - if total == 0 { - return - } - - logger.Log().Infof("[migration] updating timestamp for %s messages", p.Sprintf("%d", len(toUpdate))) - - // begin a transaction - ctx := context.Background() - tx, err := db.BeginTx(ctx, nil) - if err != nil { - logger.Log().Error("[migration]", err) - return - } - - // roll back if it fails - defer tx.Rollback() - - var blockTime = time.Now() - - count := 0 - for id, ts := range toUpdate { - count++ - _, err := tx.Exec(`UPDATE mailbox SET Created = ? WHERE ID = ?`, ts, id) - if err != nil { - logger.Log().Error("[migration]", err) + if len(toConvert) > 0 { + logger.Log().Infof("[migration] converting %d message tags", len(toConvert)) + for id, tags := range toConvert { + if err := SetMessageTags(id, tags); err != nil { + logger.Log().Errorf("[migration] %s", err.Error()) + } else { + if _, err := sqlf.Update("mailbox"). + Set("Tags", nil). + Where("ID = ?", id). + ExecAndClose(nil, db); err != nil { + logger.Log().Errorf("[migration] %s", err.Error()) + } + } } - if count%1000 == 0 { - percent := (100 * count) / total - logger.Log().Infof("[migration] updated timestamp for 1,000 messages [%d%%] in %s", percent, time.Since(blockTime)) - blockTime = time.Now() - } + logger.Log().Info("[migration] tags conversion complete") } - logger.Log().Infof("[migration] commit %s changes", p.Sprintf("%d", count)) - - if err := tx.Commit(); err != nil { - logger.Log().Error("[migration]", err) - return + // set all legacy `[]` tags to NULL + if _, err := sqlf.Update("mailbox"). + Set("Tags", nil). + Where("Tags = ?", "[]"). + ExecAndClose(nil, db); err != nil { + logger.Log().Errorf("[migration] %s", err.Error()) } - - logger.Log().Infof("[migration] complete") -} - -// Find any messages without a stored Message-ID and update it <= v1.6.5 -// Migration task implemented 05/2023 - can be removed end 2023 -func assignMessageIDsTask() { - if !config.IgnoreDuplicateIDs { - return - } - - q := sqlf.From("mailbox"). - Select("ID"). - Where("MessageID = ''") - - missingIDS := make(map[string]string) - - if err := q.QueryAndClose(nil, db, func(row *sql.Rows) { - var id string - if err := row.Scan(&id); err != nil { - logger.Log().Error("[migration]", err) - return - } - missingIDS[id] = "" - }); err != nil { - logger.Log().Error("[migration]", err) - } - - if len(missingIDS) == 0 { - return - } - - var count int - var blockTime = time.Now() - p := message.NewPrinter(language.English) - - total := len(missingIDS) - - logger.Log().Infof("[migration] extracting Message-IDs for %s messages", p.Sprintf("%d", total)) - - for id := range missingIDS { - raw, err := GetMessageRaw(id) - if err != nil { - logger.Log().Error("[migration]", err) - continue - } - - r := bytes.NewReader(raw) - - env, err := enmime.ReadEnvelope(r) - if err != nil { - logger.Log().Error("[migration]", err) - continue - } - - messageID := strings.Trim(env.GetHeader("Message-ID"), "<>") - - missingIDS[id] = messageID - - count++ - - if count%1000 == 0 { - percent := (100 * count) / total - logger.Log().Infof("[migration] extracted 1,000 Message-IDs [%d%%] in %s", percent, time.Since(blockTime)) - blockTime = time.Now() - } - } - - // begin a transaction - ctx := context.Background() - tx, err := db.BeginTx(ctx, nil) - if err != nil { - logger.Log().Error("[migration]", err) - return - } - - // roll back if it fails - defer tx.Rollback() - - count = 0 - - for id, mid := range missingIDS { - _, err = tx.Exec(`UPDATE mailbox SET MessageID = ? WHERE ID = ?`, mid, id) - if err != nil { - logger.Log().Error("[migration]", err) - } - - count++ - - if count%1000 == 0 { - percent := (100 * count) / total - logger.Log().Infof("[migration] stored 1,000 Message-IDs [%d%%] in %s", percent, time.Since(blockTime)) - blockTime = time.Now() - } - } - - logger.Log().Infof("[migration] commit %s changes", p.Sprintf("%d", count)) - - if err := tx.Commit(); err != nil { - logger.Log().Error("[migration]", err) - return - } - - logger.Log().Infof("[migration] complete") } diff --git a/internal/storage/migrations.go b/internal/storage/migrations.go index ac5937b..dc84cec 100644 --- a/internal/storage/migrations.go +++ b/internal/storage/migrations.go @@ -71,6 +71,23 @@ var ( Description: "Create snippet column", Script: `ALTER TABLE mailbox ADD COLUMN Snippet Text NOT NULL DEFAULT '';`, }, + { + Version: 1.4, + Description: "Create tag tables", + Script: `CREATE TABLE IF NOT EXISTS tags ( + ID INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + Name TEXT COLLATE NOCASE + ); + CREATE UNIQUE INDEX IF NOT EXISTS idx_tag_name ON tags (Name); + + CREATE TABLE IF NOT EXISTS message_tags( + Key INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + ID TEXT REFERENCES mailbox(ID), + TagID INT REFERENCES tags(ID) + ); + CREATE INDEX IF NOT EXISTS idx_message_tag_id ON message_tags (ID); + CREATE INDEX IF NOT EXISTS idx_message_tag_tagid ON message_tags (TagID);`, + }, } ) diff --git a/internal/storage/search.go b/internal/storage/search.go index abb8a6a..ecd873e 100644 --- a/internal/storage/search.go +++ b/internal/storage/search.go @@ -37,13 +37,12 @@ func Search(search string, start, limit int) ([]MessageSummary, int, error) { var metadata string var size int var attachments int - var tags string var snippet string var read int var ignore string em := MessageSummary{} - if err := row.Scan(&created, &id, &messageID, &subject, &metadata, &size, &attachments, &read, &tags, &snippet, &ignore, &ignore, &ignore, &ignore); err != nil { + if err := row.Scan(&created, &id, &messageID, &subject, &metadata, &size, &attachments, &read, &snippet, &ignore, &ignore, &ignore, &ignore); err != nil { logger.Log().Error(err) return } @@ -53,11 +52,6 @@ func Search(search string, start, limit int) ([]MessageSummary, int, error) { return } - if err := json.Unmarshal([]byte(tags), &em.Tags); err != nil { - logger.Log().Error(err) - return - } - em.Created = time.UnixMilli(created) em.ID = id em.MessageID = messageID @@ -85,6 +79,11 @@ func Search(search string, start, limit int) ([]MessageSummary, int, error) { results = allResults[start:end] } + // set tags for listed messages only + for i, m := range results { + results[i].Tags = getMessageTags(m.ID) + } + elapsed := time.Since(tsStart) logger.Log().Debugf("[db] search for \"%s\" in %s", search, elapsed) @@ -109,12 +108,12 @@ func DeleteSearch(search string) error { var metadata string var size int var attachments int - var tags string + // var tags string var read int var snippet string var ignore string - if err := row.Scan(&created, &id, &messageID, &subject, &metadata, &size, &attachments, &read, &tags, &snippet, &ignore, &ignore, &ignore, &ignore); err != nil { + if err := row.Scan(&created, &id, &messageID, &subject, &metadata, &size, &attachments, &read, &snippet, &ignore, &ignore, &ignore, &ignore); err != nil { logger.Log().Error(err) return } @@ -172,10 +171,21 @@ func DeleteSearch(search string) error { if err != nil { return err } + + sqlDelete3 := `DELETE FROM message_tags WHERE ID IN (?` + strings.Repeat(",?", len(ids)-1) + `)` + + _, err = tx.Exec(sqlDelete3, delIDs...) + if err != nil { + return err + } } err = tx.Commit() + if err := pruneUnusedTags(); err != nil { + return err + } + if err == nil { logger.Log().Debugf("[db] deleted %d messages matching %s", total, search) } @@ -195,13 +205,15 @@ func searchQueryBuilder(searchString string) *sqlf.Stmt { // group strings with quotes as a single argument and remove quotes args := tools.ArgsParser(searchString) - q := sqlf.From("mailbox"). - Select(`Created, ID, MessageID, Subject, Metadata, Size, Attachments, Read, Tags, Snippet, + q := sqlf.From("mailbox m"). + Select(`m.Created, m.ID, m.MessageID, m.Subject, m.Metadata, m.Size, m.Attachments, m.Read, + m.Snippet, IFNULL(json_extract(Metadata, '$.To'), '{}') as ToJSON, IFNULL(json_extract(Metadata, '$.From'), '{}') as FromJSON, IFNULL(json_extract(Metadata, '$.Cc'), '{}') as CcJSON, IFNULL(json_extract(Metadata, '$.Bcc'), '{}') as BccJSON - `).OrderBy("Created DESC") + `). + OrderBy("m.Created DESC") for _, w := range args { if cleanString(w) == "" { @@ -278,9 +290,9 @@ func searchQueryBuilder(searchString string) *sqlf.Stmt { w = cleanString(w[4:]) if w != "" { if exclude { - q.Where("Tags NOT LIKE ?", "%\""+escPercentChar(w)+"\"%") + q.Where(`m.ID NOT IN (SELECT mt.ID FROM message_tags mt JOIN tags t ON mt.TagID = t.ID WHERE t.Name = ?)`, w) } else { - q.Where("Tags LIKE ?", "%\""+escPercentChar(w)+"\"%") + q.Where(`m.ID IN (SELECT mt.ID FROM message_tags mt JOIN tags t ON mt.TagID = t.ID WHERE t.Name = ?)`, w) } } } else if w == "is:read" { @@ -297,9 +309,9 @@ func searchQueryBuilder(searchString string) *sqlf.Stmt { } } else if w == "is:tagged" { if exclude { - q.Where("Tags = ?", "[]") + q.Where(`m.ID NOT IN (SELECT DISTINCT mt.ID FROM message_tags mt JOIN tags t ON mt.TagID = t.ID)`) } else { - q.Where("Tags != ?", "[]") + q.Where(`m.ID IN (SELECT DISTINCT mt.ID FROM message_tags mt JOIN tags t ON mt.TagID = t.ID)`) } } else if w == "has:attachment" || w == "has:attachments" { if exclude { diff --git a/internal/storage/tags.go b/internal/storage/tags.go index d4f2521..0ccc2a9 100644 --- a/internal/storage/tags.go +++ b/internal/storage/tags.go @@ -1,8 +1,7 @@ package storage import ( - "context" - "encoding/json" + "database/sql" "sort" "strings" @@ -12,8 +11,8 @@ import ( "github.com/leporo/sqlf" ) -// SetTags will set the tags for a given database ID, used via API -func SetTags(id string, tags []string) error { +// SetMessageTags will set the tags for a given database ID +func SetMessageTags(id string, tags []string) error { applyTags := []string{} for _, t := range tags { t = tools.CleanTag(t) @@ -22,26 +21,162 @@ func SetTags(id string, tags []string) error { } } - sort.Strings(applyTags) + currentTags := getMessageTags(id) + origTagCount := len(currentTags) - tagJSON, err := json.Marshal(applyTags) - if err != nil { - logger.Log().Errorf("[db] setting tags for message %s", id) + for _, t := range applyTags { + t = tools.CleanTag(t) + if t == "" || !config.ValidTagRegexp.MatchString(t) || inArray(t, currentTags) { + continue + } + + if err := AddMessageTag(id, t); err != nil { + return err + } + } + + if origTagCount > 0 { + currentTags = getMessageTags(id) + + for _, t := range currentTags { + if !inArray(t, applyTags) { + if err := DeleteMessageTag(id, t); err != nil { + return err + } + } + } + } + + return nil +} + +// AddMessageTag adds a tag to a message +func AddMessageTag(id, name string) error { + var tagID int + + q := sqlf.From("tags"). + Select("ID").To(&tagID). + Where("Name = ?", name) + + // tag exists - add tag to message + if err := q.QueryRowAndClose(nil, db); err == nil { + // check message does not already have this tag + var count int + if _, err := sqlf.From("message_tags"). + Select("COUNT(ID)").To(&count). + Where("ID = ?", id). + Where("TagID = ?", tagID). + ExecAndClose(nil, db); err != nil { + return err + } + if count != 0 { + // already exists + return nil + } + + logger.Log().Debugf("[tags] adding tag \"%s\" to %s", name, id) + + _, err := sqlf.InsertInto("message_tags"). + Set("ID", id). + Set("TagID", tagID). + ExecAndClose(nil, db) return err } - _, err = sqlf.Update("mailbox"). - Set("Tags", string(tagJSON)). - Where("ID = ?", id). - ExecAndClose(context.Background(), db) + logger.Log().Debugf("[tags] adding tag \"%s\" to %s", name, id) - if err == nil { - logger.Log().Debugf("[db] set tags %s for message %s", string(tagJSON), id) + // tag dos not exist, add new one + if err := sqlf.InsertInto("tags"). + Set("Name", name). + Returning("ID").To(&tagID). + QueryRowAndClose(nil, db); err != nil { + return err } + // check message does not already have this tag + var count int + if _, err := sqlf.From("message_tags"). + Select("COUNT(ID)").To(&count). + Where("ID = ?", id). + Where("TagID = ?", tagID). + ExecAndClose(nil, db); err != nil { + return err + } + if count != 0 { + return nil // already exists + } + + // add tag to message + _, err := sqlf.InsertInto("message_tags"). + Set("ID", id). + Set("TagID", tagID). + ExecAndClose(nil, db) return err } +// DeleteMessageTag deleted a tag from a message +func DeleteMessageTag(id, name string) error { + if _, err := sqlf.DeleteFrom("message_tags"). + Where("message_tags.ID = ?", id). + Where(`message_tags.Key IN (SELECT Key FROM message_tags LEFT JOIN tags ON TagID=tags.ID WHERE Name = ?)`, name). + ExecAndClose(nil, db); err != nil { + return err + } + + return pruneUnusedTags() +} + +// DeleteAllMessageTags deleted all tags from a message +func DeleteAllMessageTags(id string) error { + if _, err := sqlf.DeleteFrom("message_tags"). + Where("message_tags.ID = ?", id). + ExecAndClose(nil, db); err != nil { + return err + } + + return pruneUnusedTags() +} + +// PruneUnusedTags will delete all unused tags from the database +func pruneUnusedTags() error { + q := sqlf.From("tags"). + Select("tags.ID, tags.Name, COUNT(message_tags.ID) as COUNT"). + LeftJoin("message_tags", "tags.ID = message_tags.TagID"). + GroupBy("tags.ID") + + toDel := []int{} + + if err := q.QueryAndClose(nil, db, func(row *sql.Rows) { + var n string + var id int + var c int + + if err := row.Scan(&id, &n, &c); err != nil { + logger.Log().Error("[tags]", err) + return + } + + if c == 0 { + logger.Log().Debugf("[tags] deleting unused tag \"%s\"", n) + toDel = append(toDel, id) + } + }); err != nil { + return err + } + + if len(toDel) > 0 { + for _, id := range toDel { + if _, err := sqlf.DeleteFrom("tags"). + Where("ID = ?", id). + ExecAndClose(nil, db); err != nil { + return err + } + } + } + + return nil +} + // Find tags set via --tags in raw message. // Returns a comma-separated string. func findTagsInRawMessage(message *[]byte) string { @@ -64,20 +199,18 @@ func findTagsInRawMessage(message *[]byte) string { // Used when parsing a raw email. func getMessageTags(id string) []string { tags := []string{} - var data string + var name string - q := sqlf.From("mailbox"). - Select(`Tags`).To(&data). - Where(`ID = ?`, id) - - err := q.QueryRowAndClose(context.Background(), db) - if err != nil { - logger.Log().Error(err) - return tags - } - - if err := json.Unmarshal([]byte(data), &tags); err != nil { - logger.Log().Error(err) + if err := sqlf. + Select(`Name`).To(&name). + From("Tags"). + LeftJoin("message_tags", "Tags.ID=message_tags.TagID"). + Where(`message_tags.ID = ?`, id). + OrderBy("Name"). + QueryAndClose(nil, db, func(row *sql.Rows) { + tags = append(tags, name) + }); err != nil { + logger.Log().Errorf("[tags] %s", err.Error()) return tags } @@ -103,7 +236,7 @@ func uniqueTagsFromString(s string) []string { tags = append(tags, w) } } else { - logger.Log().Debugf("[db] ignoring invalid tag: %s", w) + logger.Log().Debugf("[tags] ignoring invalid tag: %s", w) } } diff --git a/internal/storage/tags_test.go b/internal/storage/tags_test.go index c7f847a..414ffdb 100644 --- a/internal/storage/tags_test.go +++ b/internal/storage/tags_test.go @@ -2,6 +2,7 @@ package storage import ( "fmt" + "strings" "testing" ) @@ -23,7 +24,7 @@ func TestTags(t *testing.T) { } for i := 0; i < 10; i++ { - if err := SetTags(ids[i], []string{fmt.Sprintf("Tag-%d", i)}); err != nil { + if err := SetMessageTags(ids[i], []string{fmt.Sprintf("Tag-%d", i)}); err != nil { t.Log("error ", err) t.Fail() } @@ -40,4 +41,71 @@ func TestTags(t *testing.T) { t.Fatal("Message tags do not match") } } + + if err := DeleteAllMessages(); err != nil { + t.Log("error ", err) + t.Fail() + } + + // test 20 tags + id, err := Store(testMimeEmail) + if err != nil { + t.Log("error ", err) + t.Fail() + } + newTags := []string{} + for i := 0; i < 20; i++ { + // pad number with 0 to ensure they are returned alphabetically + newTags = append(newTags, fmt.Sprintf("AnotherTag %02d", i)) + } + if err := SetMessageTags(id, newTags); err != nil { + t.Log("error ", err) + t.Fail() + } + returnedTags := getMessageTags(id) + assertEqual(t, strings.Join(newTags, "|"), strings.Join(returnedTags, "|"), "Message tags do not match") + + // remove first tag + if err := DeleteMessageTag(id, newTags[0]); err != nil { + t.Log("error ", err) + t.Fail() + } + returnedTags = getMessageTags(id) + assertEqual(t, strings.Join(newTags[1:], "|"), strings.Join(returnedTags, "|"), "Message tags do not match after deleting 1") + + // remove all tags + if err := DeleteAllMessageTags(id); err != nil { + t.Log("error ", err) + t.Fail() + } + returnedTags = getMessageTags(id) + assertEqual(t, "", strings.Join(returnedTags, "|"), "Message tags should be empty") + + // apply the same tag twice + if err := SetMessageTags(id, []string{"Duplicate Tag", "Duplicate Tag"}); err != nil { + t.Log("error ", err) + t.Fail() + } + returnedTags = getMessageTags(id) + assertEqual(t, "Duplicate Tag", strings.Join(returnedTags, "|"), "Message tags should be duplicated") + if err := DeleteAllMessageTags(id); err != nil { + t.Log("error ", err) + t.Fail() + } + + // apply tag with invalid characters + if err := SetMessageTags(id, []string{"Dirty! \"Tag\""}); err != nil { + t.Log("error ", err) + t.Fail() + } + returnedTags = getMessageTags(id) + assertEqual(t, "Dirty Tag", strings.Join(returnedTags, "|"), "Dirty message tag did not clean as expected") + if err := DeleteAllMessageTags(id); err != nil { + t.Log("error ", err) + t.Fail() + } + + // Check deleted message tags also prune the tags database + allTags := GetAllTags() + assertEqual(t, "", strings.Join(allTags, "|"), "Dirty message tag did not clean as expected") } diff --git a/internal/storage/utils.go b/internal/storage/utils.go index c25e7d6..a5ff20f 100644 --- a/internal/storage/utils.go +++ b/internal/storage/utils.go @@ -139,6 +139,12 @@ func dbCron() { continue } + _, err = tx.Query(`DELETE FROM message_tags WHERE ID IN (?`+strings.Repeat(",?", len(ids)-1)+`)`, args...) // #nosec + if err != nil { + logger.Log().Errorf("[db] %s", err.Error()) + continue + } + err = tx.Commit() if err != nil { @@ -148,6 +154,10 @@ func dbCron() { } } + if err := pruneUnusedTags(); err != nil { + logger.Log().Errorf("[db] %s", err.Error()) + } + dbDataDeleted = true elapsed := time.Since(start) diff --git a/server/apiv1/api.go b/server/apiv1/api.go index c04f49b..d315af5 100644 --- a/server/apiv1/api.go +++ b/server/apiv1/api.go @@ -511,9 +511,9 @@ func SetReadStatus(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte("ok")) } -// GetTags (method: GET) will get all tags currently in use -func GetTags(w http.ResponseWriter, _ *http.Request) { - // swagger:route GET /api/v1/tags tags GetTags +// GetAllTags (method: GET) will get all tags currently in use +func GetAllTags(w http.ResponseWriter, _ *http.Request) { + // swagger:route GET /api/v1/tags tags GetAllTags // // # Get all current tags // @@ -541,7 +541,7 @@ func GetTags(w http.ResponseWriter, _ *http.Request) { } // SetTags (method: PUT) will set the tags for all provided IDs -func SetTags(w http.ResponseWriter, r *http.Request) { +func SetMessageTags(w http.ResponseWriter, r *http.Request) { // swagger:route PUT /api/v1/tags tags SetTags // // # Set message tags @@ -577,7 +577,7 @@ func SetTags(w http.ResponseWriter, r *http.Request) { if len(ids) > 0 { for _, id := range ids { - if err := storage.SetTags(id, data.Tags); err != nil { + if err := storage.SetMessageTags(id, data.Tags); err != nil { httpError(w, err.Error()) return } diff --git a/server/server.go b/server/server.go index d695c58..85101ca 100644 --- a/server/server.go +++ b/server/server.go @@ -108,8 +108,8 @@ func apiRoutes() *mux.Router { r.HandleFunc(config.Webroot+"api/v1/messages", middleWareFunc(apiv1.GetMessages)).Methods("GET") r.HandleFunc(config.Webroot+"api/v1/messages", middleWareFunc(apiv1.SetReadStatus)).Methods("PUT") r.HandleFunc(config.Webroot+"api/v1/messages", middleWareFunc(apiv1.DeleteMessages)).Methods("DELETE") - r.HandleFunc(config.Webroot+"api/v1/tags", middleWareFunc(apiv1.GetTags)).Methods("GET") - r.HandleFunc(config.Webroot+"api/v1/tags", middleWareFunc(apiv1.SetTags)).Methods("PUT") + r.HandleFunc(config.Webroot+"api/v1/tags", middleWareFunc(apiv1.GetAllTags)).Methods("GET") + r.HandleFunc(config.Webroot+"api/v1/tags", middleWareFunc(apiv1.SetMessageTags)).Methods("PUT") r.HandleFunc(config.Webroot+"api/v1/search", middleWareFunc(apiv1.Search)).Methods("GET") r.HandleFunc(config.Webroot+"api/v1/search", middleWareFunc(apiv1.DeleteSearch)).Methods("DELETE") r.HandleFunc(config.Webroot+"api/v1/message/{id}/part/{partID}", middleWareFunc(apiv1.DownloadAttachment)).Methods("GET") diff --git a/server/server_test.go b/server/server_test.go index 07345d7..fe31b73 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -186,7 +186,7 @@ func TestAPIv1Search(t *testing.T) { defer ts.Close() // insert 100 - t.Log("Insert 100 messages") + t.Log("Insert 100 messages & tag") insertEmailData(t) assertStatsEqual(t, ts.URL+"/api/v1/messages", 100, 100) @@ -201,6 +201,8 @@ func TestAPIv1Search(t *testing.T) { assertSearchEqual(t, ts.URL+"/api/v1/search", "!thisdoesnotexist", 100) assertSearchEqual(t, ts.URL+"/api/v1/search", "-thisdoesnotexist", 100) assertSearchEqual(t, ts.URL+"/api/v1/search", "thisdoesnotexist", 0) + assertSearchEqual(t, ts.URL+"/api/v1/search", "tag:\"Test tag 065\"", 1) + assertSearchEqual(t, ts.URL+"/api/v1/search", "!tag:\"Test tag 023\"", 99) } func setup() { @@ -272,7 +274,13 @@ func insertEmailData(t *testing.T) { t.Fail() } - if _, err := storage.Store(buf.Bytes()); err != nil { + id, err := storage.Store(buf.Bytes()) + if err != nil { + t.Log("error ", err) + t.Fail() + } + + if err := storage.SetMessageTags(id, []string{fmt.Sprintf("Test tag %03d", i)}); err != nil { t.Log("error ", err) t.Fail() }