From 5a4d13b15aeae8bfa51dbc9a193c174a693e65ba Mon Sep 17 00:00:00 2001 From: Ralph Slooten Date: Fri, 25 Jul 2025 20:32:31 +1200 Subject: [PATCH] Security: Prevent integer overflow conversion to uint64 --- internal/pop3/server.go | 8 ++++---- internal/stats/stats.go | 3 ++- internal/storage/messages.go | 4 ++-- internal/storage/utils.go | 3 ++- internal/tools/utils.go | 19 +++++++++++++++++++ server/apiv1/messages.go | 9 +++++---- 6 files changed, 34 insertions(+), 12 deletions(-) diff --git a/internal/pop3/server.go b/internal/pop3/server.go index 454c3df..ed624ca 100644 --- a/internal/pop3/server.go +++ b/internal/pop3/server.go @@ -215,7 +215,7 @@ func handleTransactionCommand(conn net.Conn, cmd string, args []string, messages for _, m := range messages { totalSize += m.Size } - sendResponse(conn, fmt.Sprintf("+OK %d %d", len(messages), int64(totalSize))) + sendResponse(conn, fmt.Sprintf("+OK %d %d", len(messages), totalSize)) case "LIST": totalSize := uint64(0) for _, m := range messages { @@ -229,12 +229,12 @@ func handleTransactionCommand(conn net.Conn, cmd string, args []string, messages sendResponse(conn, "-ERR no such message") return } - sendResponse(conn, fmt.Sprintf("+OK %d %d", nr, int64(messages[nr-1].Size))) + sendResponse(conn, fmt.Sprintf("+OK %d %d", nr, messages[nr-1].Size)) } else { - sendResponse(conn, fmt.Sprintf("+OK %d messages (%d octets)", len(messages), int64(totalSize))) + sendResponse(conn, fmt.Sprintf("+OK %d messages (%d octets)", len(messages), totalSize)) for row, m := range messages { - sendResponse(conn, fmt.Sprintf("%d %d", row+1, int64(m.Size))) // Convert Size to int64 when printing + sendResponse(conn, fmt.Sprintf("%d %d", row+1, m.Size)) } sendResponse(conn, ".") } diff --git a/internal/stats/stats.go b/internal/stats/stats.go index 0869e64..af7793f 100644 --- a/internal/stats/stats.go +++ b/internal/stats/stats.go @@ -9,6 +9,7 @@ import ( "github.com/axllent/mailpit/config" "github.com/axllent/mailpit/internal/logger" "github.com/axllent/mailpit/internal/storage" + "github.com/axllent/mailpit/internal/tools" ) // Stores cached version along with its expiry time and error count. @@ -146,7 +147,7 @@ func Track() { func LogSMTPAccepted(size int) { mu.Lock() smtpAccepted = smtpAccepted + 1 - smtpAcceptedSize = smtpAcceptedSize + uint64(size) + smtpAcceptedSize = smtpAcceptedSize + tools.SafeUint64(size) mu.Unlock() } diff --git a/internal/storage/messages.go b/internal/storage/messages.go index f4e206d..34d9776 100644 --- a/internal/storage/messages.go +++ b/internal/storage/messages.go @@ -108,7 +108,7 @@ func Store(body *[]byte, username *string) (string, error) { if config.Compression > 0 { // insert compressed raw message - compressed := dbEncoder.EncodeAll(*body, make([]byte, 0, int(size))) + compressed := dbEncoder.EncodeAll(*body, make([]byte, 0, size)) if sqlDriver == "rqlite" { // rqlite does not support binary data in query, so we need to encode the compressed message into hexadecimal @@ -202,7 +202,7 @@ func Store(body *[]byte, username *string) (string, error) { BroadcastMailboxStats() - logger.Log().Debugf("[db] saved message %s (%d bytes)", id, int64(size)) + logger.Log().Debugf("[db] saved message %s (%d bytes)", id, size) return id, nil } diff --git a/internal/storage/utils.go b/internal/storage/utils.go index 5e4e816..dc08675 100644 --- a/internal/storage/utils.go +++ b/internal/storage/utils.go @@ -9,6 +9,7 @@ import ( "github.com/axllent/mailpit/internal/html2text" "github.com/axllent/mailpit/internal/logger" + "github.com/axllent/mailpit/internal/tools" "github.com/jhillyerd/enmime/v2" ) @@ -88,7 +89,7 @@ func cleanString(str string) string { // LogMessagesDeleted logs the number of messages deleted func logMessagesDeleted(n int) { mu.Lock() - StatsDeleted = StatsDeleted + uint64(n) + StatsDeleted = StatsDeleted + tools.SafeUint64(n) mu.Unlock() } diff --git a/internal/tools/utils.go b/internal/tools/utils.go index f7a1262..b971a87 100644 --- a/internal/tools/utils.go +++ b/internal/tools/utils.go @@ -36,3 +36,22 @@ func Normalize(s string) string { return strings.TrimSpace(s) } + +// SafeUint64 converts an int or int64 to uint64, ensuring it does not exceed the maximum value for uint64. +func SafeUint64(i any) uint64 { + switch v := i.(type) { + case int: + if v < 0 { + return 0 + } + return uint64(v) + case int64: + if v < 0 { + return 0 + } + return uint64(v) + default: + // only accepts int or int64 + return 0 + } +} diff --git a/server/apiv1/messages.go b/server/apiv1/messages.go index 48009cc..dadb0fa 100644 --- a/server/apiv1/messages.go +++ b/server/apiv1/messages.go @@ -6,6 +6,7 @@ import ( "strings" "github.com/axllent/mailpit/internal/storage" + "github.com/axllent/mailpit/internal/tools" ) // MessagesSummary is a summary of a list of messages @@ -241,9 +242,9 @@ func Search(w http.ResponseWriter, r *http.Request) { res.Start = start res.Messages = messages - res.Count = uint64(len(messages)) // legacy - now undocumented in API specs - res.Total = stats.Total // total messages in mailbox - res.MessagesCount = uint64(results) + res.Count = tools.SafeUint64(len(messages)) // legacy - now undocumented in API specs + res.Total = stats.Total // total messages in mailbox + res.MessagesCount = tools.SafeUint64(results) res.Unread = stats.Unread res.Tags = stats.Tags @@ -253,7 +254,7 @@ func Search(w http.ResponseWriter, r *http.Request) { return } - res.MessagesUnreadCount = uint64(unread) + res.MessagesUnreadCount = tools.SafeUint64(unread) w.Header().Add("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(res); err != nil {