1
0
mirror of https://github.com/axllent/mailpit.git synced 2025-01-04 00:15:54 +02:00
mailpit/internal/htmlcheck/css.go
2023-10-31 15:46:25 +13:00

220 lines
4.9 KiB
Go

package htmlcheck
import (
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/PuerkitoBio/goquery"
"github.com/axllent/mailpit/config"
"github.com/axllent/mailpit/internal/logger"
"github.com/axllent/mailpit/internal/tools"
"github.com/vanng822/go-premailer/premailer"
"golang.org/x/net/html"
"golang.org/x/net/html/atom"
)
// Go cannot calculate any rendered CSS attributes, so we merge all styles
// into the HTML and detect elements with styles containing the keywords.
func runCSSTests(html string) ([]Warning, int, error) {
results := []Warning{}
totalTests := 0
inlined, err := inlineRemoteCSS(html)
if err != nil {
// logger.Log().Warn(err)
inlined = html
}
// merge all CSS inline
merged, err := mergeInlineCSS(inlined)
if err != nil {
// logger.Log().Warn(err)
merged = inlined
}
reader := strings.NewReader(merged)
// Load the HTML document
doc, err := goquery.NewDocumentFromReader(reader)
if err != nil {
return results, totalTests, err
}
for key, test := range cssInlineTests {
totalTests++
found := len(doc.Find(test).Nodes)
if found > 0 {
result, err := cie.getTest(key)
if err != nil {
return results, totalTests, err
}
result.Score.Found = found
results = append(results, result)
}
}
// get a list of all generated styles from all nodes
allNodeStyles := []string{}
for _, n := range doc.Find("*[style]").Nodes {
style, err := tools.GetHTMLAttributeVal(n, "style")
if err == nil {
allNodeStyles = append(allNodeStyles, style)
}
}
for key, re := range cssRegexpUnitTests {
totalTests++
result, err := cie.getTest(key)
if err != nil {
return results, totalTests, err
}
found := 0
// loop through all styles to count total
for _, styles := range allNodeStyles {
found = found + len(re.FindAllString(styles, -1))
}
if found > 0 {
result.Score.Found = found
results = append(results, result)
}
}
// get all inline CSS block data
reader = strings.NewReader(inlined)
// Load the HTML document
doc, _ = goquery.NewDocumentFromReader(reader)
cssCode := ""
for _, n := range doc.Find("style").Nodes {
for c := n.FirstChild; c != nil; c = c.NextSibling {
cssCode = cssCode + c.Data
}
}
for key, re := range cssRegexpTests {
totalTests++
result, err := cie.getTest(key)
if err != nil {
return results, totalTests, err
}
found := len(re.FindAllString(cssCode, -1))
if found > 0 {
result.Score.Found = found
results = append(results, result)
}
}
return results, totalTests, nil
}
// MergeInlineCSS merges header CSS into element attributes
func mergeInlineCSS(html string) (string, error) {
options := premailer.NewOptions()
// options.RemoveClasses = true
// options.CssToAttributes = false
options.KeepBangImportant = true
pre, err := premailer.NewPremailerFromString(html, options)
if err != nil {
return "", err
}
return pre.Transform()
}
// InlineRemoteCSS searches the HTML for linked stylesheets, downloads the content, and
// inserts new <style> blocks into the head, unless BlockRemoteCSSAndFonts is set
func inlineRemoteCSS(h string) (string, error) {
reader := strings.NewReader(h)
// Load the HTML document
doc, err := goquery.NewDocumentFromReader(reader)
if err != nil {
return h, err
}
remoteCSS := doc.Find("link[rel=\"stylesheet\"]").Nodes
for _, link := range remoteCSS {
attributes := link.Attr
for _, a := range attributes {
if a.Key == "href" {
if !isURL(a.Val) {
// skip invalid URL
continue
}
if config.BlockRemoteCSSAndFonts {
logger.Log().Debugf("[html-check] skip testing remote CSS content: %s (--block-remote-css-and-fonts)", a.Val)
return h, nil
}
resp, err := downloadToBytes(a.Val)
if err != nil {
logger.Log().Warningf("html check failed to download %s", a.Val)
continue
}
// create new <style> block and insert downloaded CSS
styleBlock := &html.Node{
Type: html.ElementNode,
Data: "style",
DataAtom: atom.Style,
}
styleBlock.AppendChild(&html.Node{
Type: html.TextNode,
Data: string(resp),
})
link.Parent.AppendChild(styleBlock)
}
}
}
newDoc, err := doc.Html()
if err != nil {
logger.Log().Warning(err)
return h, err
}
return newDoc, nil
}
// DownloadToBytes returns a []byte slice from a URL
func downloadToBytes(url string) ([]byte, error) {
client := http.Client{
Timeout: 5 * time.Second,
}
// Get the link response data
resp, err := client.Get(url)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
err := fmt.Errorf("Error downloading %s", url)
return nil, err
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
return body, nil
}
// Test if str is a URL
func isURL(str string) bool {
u, err := url.Parse(str)
return err == nil && (u.Scheme == "http" || u.Scheme == "https") && u.Host != ""
}