diff --git a/storage/database.go b/storage/database.go index f638410..4b83004 100644 --- a/storage/database.go +++ b/storage/database.go @@ -140,40 +140,44 @@ func MailboxExists(name string) bool { } // CreateMailbox will create a collection if it does not exist -func CreateMailbox(name string) error { - if !MailboxExists(name) { - logger.Log().Infof("[db] creating mailbox: %s", name) +func CreateMailbox(mailbox string) error { + mailbox = sanitizeMailboxName(mailbox) - if err := db.CreateCollection(name); err != nil { + if !MailboxExists(mailbox) { + logger.Log().Infof("[db] creating mailbox: %s", mailbox) + + if err := db.CreateCollection(mailbox); err != nil { return err } // create Created index - if err := db.CreateIndex(name, "Created"); err != nil { + if err := db.CreateIndex(mailbox, "Created"); err != nil { return err } // create Read index - if err := db.CreateIndex(name, "Read"); err != nil { + if err := db.CreateIndex(mailbox, "Read"); err != nil { return err } // create separate collection for data - if err := db.CreateCollection(name + "_data"); err != nil { + if err := db.CreateCollection(mailbox + "_data"); err != nil { return err } // create Created index - if err := db.CreateIndex(name+"_data", "Created"); err != nil { + if err := db.CreateIndex(mailbox+"_data", "Created"); err != nil { return err } } - return statsRefresh(name) + return statsRefresh(mailbox) } // Store will store a message in the database and return the unique ID func Store(mailbox string, b []byte) (string, error) { + mailbox = sanitizeMailboxName(mailbox) + r := bytes.NewReader(b) // Parse message body with enmime. env, err := enmime.ReadEnvelope(r) @@ -254,6 +258,8 @@ func Store(mailbox string, b []byte) (string, error) { // as clover's `Skip()` returns a subset of all results which is much slower. // @see https://github.com/ostafen/clover/issues/73 func List(mailbox string, start, limit int) ([]data.Summary, error) { + mailbox = sanitizeMailboxName(mailbox) + var lastDoc *clover.Document count := 0 startAddingAt := start + 1 @@ -314,6 +320,8 @@ func List(mailbox string, start, limit int) ([]data.Summary, error) { // Search returns a summary of items mathing a search. It searched the SearchText field. func Search(mailbox, search string, start, limit int) ([]data.Summary, error) { + mailbox = sanitizeMailboxName(mailbox) + sq := fmt.Sprintf("(?i)%s", cleanString(regexp.QuoteMeta(search))) q, err := db.FindAll(clover.NewQuery(mailbox). Skip(start). @@ -340,11 +348,15 @@ func Search(mailbox, search string, start, limit int) ([]data.Summary, error) { // Count returns the total number of messages in a mailbox func Count(mailbox string) (int, error) { + mailbox = sanitizeMailboxName(mailbox) + return db.Count(clover.NewQuery(mailbox)) } // CountUnread returns the unread number of messages in a mailbox func CountUnread(mailbox string) (int, error) { + mailbox = sanitizeMailboxName(mailbox) + return db.Count( clover.NewQuery(mailbox). Where(clover.Field("Read").IsFalse()), @@ -355,6 +367,8 @@ func CountUnread(mailbox string) (int, error) { // ID must be supplied as this is not stored within the CloverStore but rather the // *clover.Document func GetMessage(mailbox, id string) (*data.Message, error) { + mailbox = sanitizeMailboxName(mailbox) + q, err := db.FindById(mailbox+"_data", id) if err != nil { return nil, err @@ -440,6 +454,8 @@ func GetMessage(mailbox, id string) (*data.Message, error) { // GetAttachmentPart returns an *enmime.Part (attachment or inline) from a message func GetAttachmentPart(mailbox, id, partID string) (*enmime.Part, error) { + mailbox = sanitizeMailboxName(mailbox) + data, err := GetMessageRaw(mailbox, id) if err != nil { return nil, err @@ -475,6 +491,8 @@ func GetAttachmentPart(mailbox, id, partID string) (*enmime.Part, error) { // GetMessageRaw returns an []byte of the full message func GetMessageRaw(mailbox, id string) ([]byte, error) { + mailbox = sanitizeMailboxName(mailbox) + q, err := db.FindById(mailbox+"_data", id) if err != nil { return nil, err @@ -491,6 +509,8 @@ func GetMessageRaw(mailbox, id string) ([]byte, error) { // UnreadMessage will delete all messages from a mailbox func UnreadMessage(mailbox, id string) error { + mailbox = sanitizeMailboxName(mailbox) + updates := make(map[string]interface{}) updates["Read"] = false @@ -501,6 +521,8 @@ func UnreadMessage(mailbox, id string) error { // DeleteOneMessage will delete a single message from a mailbox func DeleteOneMessage(mailbox, id string) error { + mailbox = sanitizeMailboxName(mailbox) + q, err := db.FindById(mailbox, id) if err != nil { return err @@ -519,6 +541,7 @@ func DeleteOneMessage(mailbox, id string) error { // DeleteAllMessages will delete all messages from a mailbox func DeleteAllMessages(mailbox string) error { + mailbox = sanitizeMailboxName(mailbox) totalStart := time.Now() diff --git a/storage/stats.go b/storage/stats.go index 37a8162..f78b81e 100644 --- a/storage/stats.go +++ b/storage/stats.go @@ -15,6 +15,8 @@ var ( // StatsGet returns the total/unread statistics for a mailbox func StatsGet(mailbox string) data.MailboxStats { + mailbox = sanitizeMailboxName(mailbox) + statsLock.Lock() defer statsLock.Unlock() s, ok := mailboxStats[mailbox] diff --git a/storage/utils.go b/storage/utils.go index f2bbb18..511ffdc 100644 --- a/storage/utils.go +++ b/storage/utils.go @@ -92,3 +92,11 @@ func pruneCron() { } } } + +// SanitizeMailboxName returns a clean mailbox name +// allowing only `alphanumeric` characters and `-`` +func sanitizeMailboxName(mailbox string) string { + re := regexp.MustCompile(`[^a-zA-Z0-9\-]`) + + return re.ReplaceAllString(mailbox, "") +}