1
0
mirror of https://github.com/rclone/rclone.git synced 2025-11-23 21:44:49 +02:00
Files
rclone/backend/compress/szstd_helper.go
Alex 8f74e7d331 backend/compress: add zstd compression
Added support for reading and writing zstd-compressed archives in seekable format
using "github.com/klauspost/compress/zstd" and
"github.com/SaveTheRbtz/zstd-seekable-format-go/pkg".

Bumped Go version from 1.24.0 to 1.24.4 due to requirements of
"github.com/SaveTheRbtz/zstd-seekable-format-go/pkg".
2025-11-04 14:50:56 +00:00

328 lines
7.5 KiB
Go

package compress
import (
"context"
"errors"
"io"
"runtime"
"sync"
szstd "github.com/a1ex3/zstd-seekable-format-go/pkg"
"github.com/klauspost/compress/zstd"
)
const szstdChunkSize int = 1 << 20 // 1 MiB chunk size
// SzstdMetadata holds metadata for szstd compressed files.
type SzstdMetadata struct {
BlockSize int // BlockSize is the size of the blocks in the zstd file
Size int64 // Size is the uncompressed size of the file
BlockData []uint32 // BlockData is the block data for the zstd file, used for seeking
}
// SzstdWriter is a writer that compresses data in szstd format.
type SzstdWriter struct {
enc *zstd.Encoder
w szstd.ConcurrentWriter
metadata SzstdMetadata
mu sync.Mutex
}
// NewWriterSzstd creates a new szstd writer with the specified options.
// It initializes the szstd writer with a zstd encoder and returns a pointer to the SzstdWriter.
// The writer can be used to write data in chunks, and it will automatically handle block sizes and metadata.
func NewWriterSzstd(w io.Writer, opts ...zstd.EOption) (*SzstdWriter, error) {
encoder, err := zstd.NewWriter(nil, opts...)
if err != nil {
return nil, err
}
sw, err := szstd.NewWriter(w, encoder)
if err != nil {
if err := encoder.Close(); err != nil {
return nil, err
}
return nil, err
}
return &SzstdWriter{
enc: encoder,
w: sw,
metadata: SzstdMetadata{
BlockSize: szstdChunkSize,
Size: 0,
},
}, nil
}
// Write writes data to the szstd writer in chunks of szstdChunkSize.
// It handles the block size and metadata updates automatically.
func (w *SzstdWriter) Write(p []byte) (int, error) {
if len(p) == 0 {
return 0, nil
}
if w.metadata.BlockData == nil {
numBlocks := (len(p) + w.metadata.BlockSize - 1) / w.metadata.BlockSize
w.metadata.BlockData = make([]uint32, 1, numBlocks+1)
w.metadata.BlockData[0] = 0
}
start := 0
total := len(p)
var writerFunc szstd.FrameSource = func() ([]byte, error) {
if start >= total {
return nil, nil
}
end := min(start+w.metadata.BlockSize, total)
chunk := p[start:end]
size := end - start
w.mu.Lock()
w.metadata.Size += int64(size)
w.mu.Unlock()
start = end
return chunk, nil
}
// write sizes of compressed blocks in the callback
err := w.w.WriteMany(context.Background(), writerFunc,
szstd.WithWriteCallback(func(size uint32) {
w.mu.Lock()
lastOffset := w.metadata.BlockData[len(w.metadata.BlockData)-1]
w.metadata.BlockData = append(w.metadata.BlockData, lastOffset+size)
w.mu.Unlock()
}),
)
if err != nil {
return 0, err
}
return total, nil
}
// Close closes the SzstdWriter and its underlying encoder.
func (w *SzstdWriter) Close() error {
if err := w.w.Close(); err != nil {
return err
}
if err := w.enc.Close(); err != nil {
return err
}
return nil
}
// GetMetadata returns the metadata of the szstd writer.
func (w *SzstdWriter) GetMetadata() SzstdMetadata {
return w.metadata
}
// SzstdReaderAt is a reader that allows random access in szstd compressed data.
type SzstdReaderAt struct {
r szstd.Reader
decoder *zstd.Decoder
metadata *SzstdMetadata
pos int64
mu sync.Mutex
}
// NewReaderAtSzstd creates a new SzstdReaderAt at the specified io.ReadSeeker.
func NewReaderAtSzstd(rs io.ReadSeeker, meta *SzstdMetadata, offset int64, opts ...zstd.DOption) (*SzstdReaderAt, error) {
decoder, err := zstd.NewReader(nil, opts...)
if err != nil {
return nil, err
}
r, err := szstd.NewReader(rs, decoder)
if err != nil {
decoder.Close()
return nil, err
}
sr := &SzstdReaderAt{
r: r,
decoder: decoder,
metadata: meta,
pos: 0,
}
// Set initial position to the provided offset
if _, err := sr.Seek(offset, io.SeekStart); err != nil {
if err := sr.Close(); err != nil {
return nil, err
}
return nil, err
}
return sr, nil
}
// Seek sets the offset for the next Read.
func (s *SzstdReaderAt) Seek(offset int64, whence int) (int64, error) {
s.mu.Lock()
defer s.mu.Unlock()
pos, err := s.r.Seek(offset, whence)
if err == nil {
s.pos = pos
}
return pos, err
}
func (s *SzstdReaderAt) Read(p []byte) (int, error) {
s.mu.Lock()
defer s.mu.Unlock()
n, err := s.r.Read(p)
if err == nil {
s.pos += int64(n)
}
return n, err
}
// ReadAt reads data at the specified offset.
func (s *SzstdReaderAt) ReadAt(p []byte, off int64) (int, error) {
if off < 0 {
return 0, errors.New("invalid offset")
}
if off >= s.metadata.Size {
return 0, io.EOF
}
endOff := min(off+int64(len(p)), s.metadata.Size)
// Find all blocks covered by the range
type blockInfo struct {
index int // Block index
offsetInBlock int64 // Offset within the block for starting reading
bytesToRead int64 // How many bytes to read from this block
}
var blocks []blockInfo
uncompressedOffset := int64(0)
currentOff := off
for i := 0; i < len(s.metadata.BlockData)-1; i++ {
blockUncompressedEnd := min(uncompressedOffset+int64(s.metadata.BlockSize), s.metadata.Size)
if currentOff < blockUncompressedEnd && endOff > uncompressedOffset {
offsetInBlock := max(0, currentOff-uncompressedOffset)
bytesToRead := min(blockUncompressedEnd-uncompressedOffset-offsetInBlock, endOff-currentOff)
blocks = append(blocks, blockInfo{
index: i,
offsetInBlock: offsetInBlock,
bytesToRead: bytesToRead,
})
currentOff += bytesToRead
if currentOff >= endOff {
break
}
}
uncompressedOffset = blockUncompressedEnd
}
if len(blocks) == 0 {
return 0, io.EOF
}
// Parallel block decoding
type decodeResult struct {
index int
data []byte
err error
}
resultCh := make(chan decodeResult, len(blocks))
var wg sync.WaitGroup
sem := make(chan struct{}, runtime.NumCPU())
for _, block := range blocks {
wg.Add(1)
go func(block blockInfo) {
defer wg.Done()
sem <- struct{}{}
defer func() { <-sem }()
startOffset := int64(s.metadata.BlockData[block.index])
endOffset := int64(s.metadata.BlockData[block.index+1])
compressedSize := endOffset - startOffset
compressed := make([]byte, compressedSize)
_, err := s.r.ReadAt(compressed, startOffset)
if err != nil && err != io.EOF {
resultCh <- decodeResult{index: block.index, err: err}
return
}
decoded, err := s.decoder.DecodeAll(compressed, nil)
if err != nil {
resultCh <- decodeResult{index: block.index, err: err}
return
}
resultCh <- decodeResult{index: block.index, data: decoded, err: nil}
}(block)
}
go func() {
wg.Wait()
close(resultCh)
}()
// Collect results in block index order
totalRead := 0
results := make(map[int]decodeResult)
expected := len(blocks)
minIndex := blocks[0].index
for res := range resultCh {
results[res.index] = res
for {
if result, ok := results[minIndex]; ok {
if result.err != nil {
return 0, result.err
}
// find the corresponding blockInfo
var blk blockInfo
for _, b := range blocks {
if b.index == result.index {
blk = b
break
}
}
start := blk.offsetInBlock
end := start + blk.bytesToRead
copy(p[totalRead:totalRead+int(blk.bytesToRead)], result.data[start:end])
totalRead += int(blk.bytesToRead)
minIndex++
if minIndex-blocks[0].index >= len(blocks) {
break
}
} else {
break
}
}
if len(results) == expected && minIndex-blocks[0].index >= len(blocks) {
break
}
}
return totalRead, nil
}
// Close closes the SzstdReaderAt and underlying decoder.
func (s *SzstdReaderAt) Close() error {
if err := s.r.Close(); err != nil {
return err
}
s.decoder.Close()
return nil
}