package xmlparser import ( "bufio" "bytes" "errors" "fmt" "io" "regexp" "sync" "unicode/utf8" "golang.org/x/net/html/charset" ) var bufreaderPool = sync.Pool{ New: func() any { return bufio.NewReader(nil) }, } var ( endTagStart = []byte("") commentStart = []byte("") cdataStart = []byte("") doctypeStart = []byte("' { d.setSyntaxErrorf("expected '>' at the end of self-closing tag, got %q", b) return nil, false } selfClosing = true break } if b == '>' { // End of start tag break } // Unread the byte for further processing d.unreadByte(b) // Read attribute name attrName, ok := d.readNSName() if !ok { if d.err == nil { d.setSyntaxErrorf("expected attribute name") } return nil, false } if !d.skipSpaces() { return nil, false } b, ok = d.mustReadByte() if !ok { return nil, false } var attrValue string // If the next byte is not '=', this is an attribute without value. if b != '=' { d.unreadByte(b) } else { if !d.skipSpaces() { return nil, false } // Read attribute value val, ok := d.readAttrValue() if !ok { if d.err == nil { d.setSyntaxErrorf("expected value for attribute %q", attrName) } return nil, false } attrValue = string(val) } attrs = append(attrs, &Attribute{ Name: attrName, Value: attrValue, }) } return &StartElement{ Name: name, Attrs: NewAttributes(attrs...), SelfClosing: selfClosing, }, true } // readAttrValue reads an attribute value. func (d *Decoder) readAttrValue() ([]byte, bool) { d.buf.Reset() b, ok := d.mustReadByte() if !ok { return nil, false } if b == '"' || b == '\'' { // Quoted attribute value // We can just read until the closing quote. if !d.mustReadUntil(b) { return nil, false } // Remove the trailing quote from the buffer d.buf.Remove(1) } else { // Unquoted attribute value. // Unread the byte for further processing d.unreadByte(b) // Read until we meet a byte that is not valid in an unquoted attribute value. if !d.mustReadWhileFn(isValueByte) { return nil, false } } return d.buf.Bytes(), true } // isValueByte checks if a byte is valid in an unquoted attribute value. // See: https://www.w3.org/TR/REC-html40/intro/sgmltut.html#h-3.2.2 func isValueByte(c byte) bool { return 'A' <= c && c <= 'Z' || 'a' <= c && c <= 'z' || '0' <= c && c <= '9' || c == '_' || c == ':' || c == '-' } // readEndTag reads an end tag. func (d *Decoder) readEndTag() (Name, bool) { // Discard '' if !d.skipSpaces() { return name, false } // Expect '>' b, ok := d.mustReadByte() if !ok { return name, false } if b != '>' { d.setSyntaxErrorf("expected '>' at the end of end element, got %q", b) return name, false } return name, true } // readProcInst reads a processing instruction (until `?>`). // // If the processing instruction specifies an encoding, it recreates // the reader with the specified encoding. func (d *Decoder) readProcInst() ([]byte, []byte, bool) { // Discard '' // We don't reset the buffer here, as we don't want target name to be overwritten. for { if !d.mustReadUntil('>') { return nil, nil, false } if d.buf.HasSuffix(procInstEnd) { break } } // Trim the trailing '?>' d.buf.Remove(len(procInstEnd)) // Separate the target and data data := d.buf.Bytes()[len(target):] if bytes.Equal(target, targetXML) { // Get the encoding from the processing instruction data data = d.handleProcInstEncoding(data) } return target, data, true } // handleProcInstEncoding replaces the encoding declaration in the processing instruction data // with "UTF-8" and returns the updated data. // It also recreates the reader with defined encoding. func (d *Decoder) handleProcInstEncoding(data []byte) []byte { matches := encodingRE.FindSubmatch(data) if matches == nil { // No encoding declaration found, return original data without changes return data } // Get the encoding from the processing instruction data encoding := bytes.Trim(matches[2], `"'`) if bytes.EqualFold(encoding, []byte("utf-8")) || bytes.EqualFold(encoding, []byte("utf8")) { // No need for special handling if encoding is already UTF-8 return data } // Recreate the reader with defined encoding. // If the encoding is UTF-16/32, we have already handled it in the BOM check. if len(encoding) < 3 || !bytes.EqualFold(encoding[:3], []byte("utf")) { if !d.setEncoding(string(encoding)) { return data } } // Build the updated data with "UTF-8" encoding. // We write it to the buffer that already contains the processing instruction data, // so we mark the position of the updated data start. start := d.buf.Len() d.buf.Write(matches[1]) // Up to encoding= d.buf.Write([]byte(`"UTF-8"`)) // New encoding d.buf.Write(matches[3]) // After encoding declaration updated := d.buf.Bytes()[start:] return updated } // readComment reads a comment (until `-->`). func (d *Decoder) readComment() ([]byte, bool) { if !d.checkAndDiscardPrefix(commentStart) { if d.err == nil { d.setSyntaxErrorf("invalid sequence ') { return nil, false } if d.buf.HasSuffix(commentEnd) { break } } // Trim the trailing '-->' d.buf.Remove(len(commentEnd)) return d.buf.Bytes(), true } // readCData reads a CDATA section (until `]]>`). func (d *Decoder) readCData() ([]byte, bool) { if !d.checkAndDiscardPrefix(cdataStart) { if d.err == nil { d.setSyntaxErrorf("invalid sequence ' for { if !d.mustReadUntil('>') { return nil, false } if d.buf.HasSuffix(cdataEnd) { break } } // Trim the trailing ']]>' d.buf.Remove(len(cdataEnd)) return d.buf.Bytes(), true } // readDoctype reads a directive (until `>`). func (d *Decoder) readDoctype() ([]byte, bool) { if !d.checkAndDiscardPrefix(doctypeStart) { if d.err == nil { d.setSyntaxErrorf("invalid sequence ' for { b, ok := d.mustReadByte() if !ok { return nil, false } d.buf.WriteByte(b) switch { case b == inQuote: // We met the closing quote, exit quote mode. inQuote = 0 case inQuote != 0: // Inside a quote, do nothing. case b == '"' || b == '\'': // We met an opening quote, enter quote mode. inQuote = b case b == ']': // We met a closing bracket. // If we are not inside brackets, this is an error. if !inBrackets { d.setSyntaxErrorf("unexpected ']' in directive") return nil, false } // Otherwise, exit brackets mode. inBrackets = false case b == '[': // We met an opening bracket. // If we are already inside brackets, this is an error. if inBrackets { d.setSyntaxErrorf("nested '[' in directive") return nil, false } // Otherwise, enter brackets mode. inBrackets = true case inBrackets: // Inside brackets, do nothing case b == '<': // Unexpected '<' outside quotes and brackets. d.setSyntaxErrorf("unexpected '<' in directive") return nil, false case b == '>': // End of directive. // Trim the trailing '>' from the buffer and return. d.buf.Remove(1) return d.buf.Bytes(), true } } } // readNSName reads a name with optional namespace prefix (e.g., "svg:svg"). func (d *Decoder) readNSName() (Name, bool) { if !d.readName() { return Name(""), false } return Name(d.buf.Bytes()), true } // readName reads a name (tag or attribute) to the buffer until a non-name byte is encountered. func (d *Decoder) readName() bool { d.buf.Reset() if !d.mustReadWhileFn(isNameByte) { return false } return d.buf.Len() > 0 } func isNameByte(c byte) bool { // We allow all non-ASCII bytes as names. return c >= utf8.RuneSelf || 'A' <= c && c <= 'Z' || 'a' <= c && c <= 'z' || '0' <= c && c <= '9' || c == '_' || c == ':' || c == '.' || c == '-' } // skipSpaces skips whitespace characters. func (d *Decoder) skipSpaces() bool { for { b, ok := d.mustPeekBuffered() if !ok { return false } found := false for i, c := range b { if !isSpace(c) { // Found a non-space byte. // Trim the bytes up to (but not including) this byte. b = b[:i] found = true break } } // Discard the spaces we've read. if !d.discard(len(b)) { return false } if found { // We've skipped all spaces, break the loop. return true } } } // isSpace checks if a byte is a whitespace character. func isSpace(b byte) bool { return b == ' ' || b == '\r' || b == '\n' || b == '\t' } func (d *Decoder) checkAndDiscardPrefix(prefix []byte) bool { prefixLen := len(prefix) b, ok := d.mustPeek(prefixLen) if !ok { return false } if !bytes.Equal(b, prefix) { return false } return d.discard(prefixLen) } // readByte reads a single byte from the reader. // If an error occurs, it sets d.err and returns false. func (d *Decoder) readByte() (byte, bool) { b, err := d.r.ReadByte() if err != nil { d.err = err return 0, false } if b == '\n' { d.line++ } return b, true } // mustReadByte reads a single byte from the reader. // If an error occurs, it sets d.err and returns false. // If io.EOF is encountered, it sets d.err to a more descriptive error. func (d *Decoder) mustReadByte() (byte, bool) { b, ok := d.readByte() if !ok { if errors.Is(d.err, io.EOF) { d.setSyntaxErrorf("unexpected EOF") } } return b, ok } // unreadByte unreads the last byte read from the reader. // If an error occurs, it sets d.err and returns false. // //nolint:unparam func (d *Decoder) unreadByte(b byte) bool { if err := d.r.UnreadByte(); err != nil { d.err = err return false } if b == '\n' { d.line-- } return true } // readUntil reads bytes to the buffer until the specified delimiter byte is encountered. // The delimiter byte is included in the buffer. // If an error occurs, it sets d.err and returns false. func (d *Decoder) readUntil(delim byte) bool { for { b, err := d.r.ReadSlice(delim) if err != nil && !errors.Is(err, bufio.ErrBufferFull) && !errors.Is(err, io.EOF) { d.err = err return false } d.buf.Write(b) d.countNewLines(b) if err == nil { // We've read up to the delimiter byte, break the loop. return true } if err == io.EOF { // Reached EOF without finding the delimiter. d.err = err return false } } } // mustReadUntil reads bytes to the buffer until the specified delimiter byte is encountered. // The delimiter byte is included in the buffer. // If an error occurs, it sets d.err and returns false. // If io.EOF is encountered, it sets d.err to a more descriptive error. func (d *Decoder) mustReadUntil(delim byte) bool { if !d.readUntil(delim) { if errors.Is(d.err, io.EOF) { d.setSyntaxErrorf("unexpected EOF") } return false } return true } // mustReadWhileFn reads bytes to the buffer while the provided function returns true. // The byte that causes the function to return false is not included in the buffer. func (d *Decoder) mustReadWhileFn(f func(byte) bool) bool { for { b, ok := d.mustPeekBuffered() if !ok { return false } found := false for i, c := range b { if !f(c) { // Found a byte that does not satisfy the condition. // Trim the bytes up to (but not including) this byte. b = b[:i] found = true break } } d.buf.Write(b) // Discard the bytes we've read. if !d.discard(len(b)) { return false } if found { // We've read up to the delimiter byte, break the loop. return true } } } // peek peeks at the next n bytes without advancing the reader. // If an error occurs, it sets d.err and returns false. func (d *Decoder) peek(n int) ([]byte, bool) { b, err := d.r.Peek(n) if err != nil { d.err = err return nil, false } return b, true } // mustPeek peeks at the next n bytes without advancing the reader. // If an error occurs, it sets d.err and returns false. // If io.EOF is encountered, it sets d.err to a more descriptive error. func (d *Decoder) mustPeek(n int) ([]byte, bool) { b, ok := d.peek(n) if !ok { if errors.Is(d.err, io.EOF) { d.setSyntaxErrorf("unexpected EOF") } } return b, ok } // mustPeekBuffered peeks at all currently buffered bytes without advancing the reader. // If no bytes are buffered, it peeks at least 1 byte. // If an error occurs, it sets d.err and returns false. // If io.EOF is encountered, it sets d.err to a more descriptive error. func (d *Decoder) mustPeekBuffered() ([]byte, bool) { toPeek := max(d.r.Buffered(), 1) return d.mustPeek(toPeek) } // discard discards the next n bytes from the reader. func (d *Decoder) discard(n int) bool { // Peek bytes we want to discard to count new lines. if b, err := d.r.Peek(n); err == nil { d.countNewLines(b) } _, err := d.r.Discard(n) if err != nil { d.err = err return false } return true } // countNewLines counts the number of new lines in the given byte slice // and increments the decoder's line counter accordingly. func (d *Decoder) countNewLines(b []byte) { // Somehow this is more efficient than bytes.Count... for { ind := bytes.IndexByte(b, '\n') if ind < 0 { break } d.line++ b = b[ind+1:] } } // setSyntaxErrorf sets a syntax error with the current line number. func (d *Decoder) setSyntaxErrorf(format string, a ...any) { msg := fmt.Sprintf(format, a...) d.err = newSyntaxError("%s (line %d)", msg, d.line) }