1
0
mirror of https://github.com/axllent/mailpit.git synced 2025-09-16 09:26:37 +02:00

Security: Prevent integer overflow conversion to uint64

This commit is contained in:
Ralph Slooten
2025-07-25 20:32:31 +12:00
parent fbc1dc6118
commit 5a4d13b15a
6 changed files with 34 additions and 12 deletions

View File

@@ -215,7 +215,7 @@ func handleTransactionCommand(conn net.Conn, cmd string, args []string, messages
for _, m := range messages {
totalSize += m.Size
}
sendResponse(conn, fmt.Sprintf("+OK %d %d", len(messages), int64(totalSize)))
sendResponse(conn, fmt.Sprintf("+OK %d %d", len(messages), totalSize))
case "LIST":
totalSize := uint64(0)
for _, m := range messages {
@@ -229,12 +229,12 @@ func handleTransactionCommand(conn net.Conn, cmd string, args []string, messages
sendResponse(conn, "-ERR no such message")
return
}
sendResponse(conn, fmt.Sprintf("+OK %d %d", nr, int64(messages[nr-1].Size)))
sendResponse(conn, fmt.Sprintf("+OK %d %d", nr, messages[nr-1].Size))
} else {
sendResponse(conn, fmt.Sprintf("+OK %d messages (%d octets)", len(messages), int64(totalSize)))
sendResponse(conn, fmt.Sprintf("+OK %d messages (%d octets)", len(messages), totalSize))
for row, m := range messages {
sendResponse(conn, fmt.Sprintf("%d %d", row+1, int64(m.Size))) // Convert Size to int64 when printing
sendResponse(conn, fmt.Sprintf("%d %d", row+1, m.Size))
}
sendResponse(conn, ".")
}

View File

@@ -9,6 +9,7 @@ import (
"github.com/axllent/mailpit/config"
"github.com/axllent/mailpit/internal/logger"
"github.com/axllent/mailpit/internal/storage"
"github.com/axllent/mailpit/internal/tools"
)
// Stores cached version along with its expiry time and error count.
@@ -146,7 +147,7 @@ func Track() {
func LogSMTPAccepted(size int) {
mu.Lock()
smtpAccepted = smtpAccepted + 1
smtpAcceptedSize = smtpAcceptedSize + uint64(size)
smtpAcceptedSize = smtpAcceptedSize + tools.SafeUint64(size)
mu.Unlock()
}

View File

@@ -108,7 +108,7 @@ func Store(body *[]byte, username *string) (string, error) {
if config.Compression > 0 {
// insert compressed raw message
compressed := dbEncoder.EncodeAll(*body, make([]byte, 0, int(size)))
compressed := dbEncoder.EncodeAll(*body, make([]byte, 0, size))
if sqlDriver == "rqlite" {
// rqlite does not support binary data in query, so we need to encode the compressed message into hexadecimal
@@ -202,7 +202,7 @@ func Store(body *[]byte, username *string) (string, error) {
BroadcastMailboxStats()
logger.Log().Debugf("[db] saved message %s (%d bytes)", id, int64(size))
logger.Log().Debugf("[db] saved message %s (%d bytes)", id, size)
return id, nil
}

View File

@@ -9,6 +9,7 @@ import (
"github.com/axllent/mailpit/internal/html2text"
"github.com/axllent/mailpit/internal/logger"
"github.com/axllent/mailpit/internal/tools"
"github.com/jhillyerd/enmime/v2"
)
@@ -88,7 +89,7 @@ func cleanString(str string) string {
// LogMessagesDeleted logs the number of messages deleted
func logMessagesDeleted(n int) {
mu.Lock()
StatsDeleted = StatsDeleted + uint64(n)
StatsDeleted = StatsDeleted + tools.SafeUint64(n)
mu.Unlock()
}

View File

@@ -36,3 +36,22 @@ func Normalize(s string) string {
return strings.TrimSpace(s)
}
// SafeUint64 converts an int or int64 to uint64, ensuring it does not exceed the maximum value for uint64.
func SafeUint64(i any) uint64 {
switch v := i.(type) {
case int:
if v < 0 {
return 0
}
return uint64(v)
case int64:
if v < 0 {
return 0
}
return uint64(v)
default:
// only accepts int or int64
return 0
}
}

View File

@@ -6,6 +6,7 @@ import (
"strings"
"github.com/axllent/mailpit/internal/storage"
"github.com/axllent/mailpit/internal/tools"
)
// MessagesSummary is a summary of a list of messages
@@ -241,9 +242,9 @@ func Search(w http.ResponseWriter, r *http.Request) {
res.Start = start
res.Messages = messages
res.Count = uint64(len(messages)) // legacy - now undocumented in API specs
res.Total = stats.Total // total messages in mailbox
res.MessagesCount = uint64(results)
res.Count = tools.SafeUint64(len(messages)) // legacy - now undocumented in API specs
res.Total = stats.Total // total messages in mailbox
res.MessagesCount = tools.SafeUint64(results)
res.Unread = stats.Unread
res.Tags = stats.Tags
@@ -253,7 +254,7 @@ func Search(w http.ResponseWriter, r *http.Request) {
return
}
res.MessagesUnreadCount = uint64(unread)
res.MessagesUnreadCount = tools.SafeUint64(unread)
w.Header().Add("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(res); err != nil {