mirror of
https://github.com/rclone/rclone.git
synced 2025-11-23 21:44:49 +02:00
Added support for reading and writing zstd-compressed archives in seekable format using "github.com/klauspost/compress/zstd" and "github.com/SaveTheRbtz/zstd-seekable-format-go/pkg". Bumped Go version from 1.24.0 to 1.24.4 due to requirements of "github.com/SaveTheRbtz/zstd-seekable-format-go/pkg".
328 lines
7.5 KiB
Go
328 lines
7.5 KiB
Go
package compress
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"io"
|
|
"runtime"
|
|
"sync"
|
|
|
|
szstd "github.com/a1ex3/zstd-seekable-format-go/pkg"
|
|
"github.com/klauspost/compress/zstd"
|
|
)
|
|
|
|
// szstdChunkSize is the number of uncompressed bytes placed in each szstd
// frame (1 MiB); it becomes SzstdMetadata.BlockSize for new writers.
const szstdChunkSize int = 1024 * 1024
|
|
|
|
// SzstdMetadata describes the block layout of an szstd-compressed file so
// that ranges of it can be decoded without reading the whole stream.
//
// Offsets are stored as uint32, so they can only describe compressed
// positions below 4 GiB.
type SzstdMetadata struct {
	// BlockSize is the uncompressed size of every block except, possibly,
	// the final one.
	BlockSize int
	// Size is the total uncompressed length of the file in bytes.
	Size int64
	// BlockData is the cumulative table of compressed frame offsets: block i
	// occupies compressed bytes [BlockData[i], BlockData[i+1]).
	BlockData []uint32
}
|
|
|
|
// SzstdWriter is a writer that compresses data in szstd format.
|
|
type SzstdWriter struct {
|
|
enc *zstd.Encoder
|
|
w szstd.ConcurrentWriter
|
|
metadata SzstdMetadata
|
|
mu sync.Mutex
|
|
}
|
|
|
|
// NewWriterSzstd creates a new szstd writer with the specified options.
|
|
// It initializes the szstd writer with a zstd encoder and returns a pointer to the SzstdWriter.
|
|
// The writer can be used to write data in chunks, and it will automatically handle block sizes and metadata.
|
|
func NewWriterSzstd(w io.Writer, opts ...zstd.EOption) (*SzstdWriter, error) {
|
|
encoder, err := zstd.NewWriter(nil, opts...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
sw, err := szstd.NewWriter(w, encoder)
|
|
if err != nil {
|
|
if err := encoder.Close(); err != nil {
|
|
return nil, err
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
return &SzstdWriter{
|
|
enc: encoder,
|
|
w: sw,
|
|
metadata: SzstdMetadata{
|
|
BlockSize: szstdChunkSize,
|
|
Size: 0,
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
// Write writes data to the szstd writer in chunks of szstdChunkSize.
|
|
// It handles the block size and metadata updates automatically.
|
|
func (w *SzstdWriter) Write(p []byte) (int, error) {
|
|
if len(p) == 0 {
|
|
return 0, nil
|
|
}
|
|
|
|
if w.metadata.BlockData == nil {
|
|
numBlocks := (len(p) + w.metadata.BlockSize - 1) / w.metadata.BlockSize
|
|
w.metadata.BlockData = make([]uint32, 1, numBlocks+1)
|
|
w.metadata.BlockData[0] = 0
|
|
}
|
|
|
|
start := 0
|
|
total := len(p)
|
|
|
|
var writerFunc szstd.FrameSource = func() ([]byte, error) {
|
|
if start >= total {
|
|
return nil, nil
|
|
}
|
|
|
|
end := min(start+w.metadata.BlockSize, total)
|
|
chunk := p[start:end]
|
|
size := end - start
|
|
|
|
w.mu.Lock()
|
|
w.metadata.Size += int64(size)
|
|
w.mu.Unlock()
|
|
|
|
start = end
|
|
return chunk, nil
|
|
}
|
|
|
|
// write sizes of compressed blocks in the callback
|
|
err := w.w.WriteMany(context.Background(), writerFunc,
|
|
szstd.WithWriteCallback(func(size uint32) {
|
|
w.mu.Lock()
|
|
lastOffset := w.metadata.BlockData[len(w.metadata.BlockData)-1]
|
|
w.metadata.BlockData = append(w.metadata.BlockData, lastOffset+size)
|
|
w.mu.Unlock()
|
|
}),
|
|
)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return total, nil
|
|
}
|
|
|
|
// Close closes the SzstdWriter and its underlying encoder.
|
|
func (w *SzstdWriter) Close() error {
|
|
if err := w.w.Close(); err != nil {
|
|
return err
|
|
}
|
|
if err := w.enc.Close(); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetMetadata returns the metadata of the szstd writer.
|
|
func (w *SzstdWriter) GetMetadata() SzstdMetadata {
|
|
return w.metadata
|
|
}
|
|
|
|
// SzstdReaderAt is a reader that allows random access in szstd compressed data.
|
|
type SzstdReaderAt struct {
|
|
r szstd.Reader
|
|
decoder *zstd.Decoder
|
|
metadata *SzstdMetadata
|
|
pos int64
|
|
mu sync.Mutex
|
|
}
|
|
|
|
// NewReaderAtSzstd creates a new SzstdReaderAt at the specified io.ReadSeeker.
|
|
func NewReaderAtSzstd(rs io.ReadSeeker, meta *SzstdMetadata, offset int64, opts ...zstd.DOption) (*SzstdReaderAt, error) {
|
|
decoder, err := zstd.NewReader(nil, opts...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
r, err := szstd.NewReader(rs, decoder)
|
|
if err != nil {
|
|
decoder.Close()
|
|
return nil, err
|
|
}
|
|
|
|
sr := &SzstdReaderAt{
|
|
r: r,
|
|
decoder: decoder,
|
|
metadata: meta,
|
|
pos: 0,
|
|
}
|
|
|
|
// Set initial position to the provided offset
|
|
if _, err := sr.Seek(offset, io.SeekStart); err != nil {
|
|
if err := sr.Close(); err != nil {
|
|
return nil, err
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
return sr, nil
|
|
}
|
|
|
|
// Seek sets the offset for the next Read.
|
|
func (s *SzstdReaderAt) Seek(offset int64, whence int) (int64, error) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
pos, err := s.r.Seek(offset, whence)
|
|
if err == nil {
|
|
s.pos = pos
|
|
}
|
|
return pos, err
|
|
}
|
|
|
|
func (s *SzstdReaderAt) Read(p []byte) (int, error) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
n, err := s.r.Read(p)
|
|
if err == nil {
|
|
s.pos += int64(n)
|
|
}
|
|
return n, err
|
|
}
|
|
|
|
// ReadAt reads data at the specified offset.
|
|
func (s *SzstdReaderAt) ReadAt(p []byte, off int64) (int, error) {
|
|
if off < 0 {
|
|
return 0, errors.New("invalid offset")
|
|
}
|
|
if off >= s.metadata.Size {
|
|
return 0, io.EOF
|
|
}
|
|
|
|
endOff := min(off+int64(len(p)), s.metadata.Size)
|
|
|
|
// Find all blocks covered by the range
|
|
type blockInfo struct {
|
|
index int // Block index
|
|
offsetInBlock int64 // Offset within the block for starting reading
|
|
bytesToRead int64 // How many bytes to read from this block
|
|
}
|
|
|
|
var blocks []blockInfo
|
|
uncompressedOffset := int64(0)
|
|
currentOff := off
|
|
|
|
for i := 0; i < len(s.metadata.BlockData)-1; i++ {
|
|
blockUncompressedEnd := min(uncompressedOffset+int64(s.metadata.BlockSize), s.metadata.Size)
|
|
|
|
if currentOff < blockUncompressedEnd && endOff > uncompressedOffset {
|
|
offsetInBlock := max(0, currentOff-uncompressedOffset)
|
|
bytesToRead := min(blockUncompressedEnd-uncompressedOffset-offsetInBlock, endOff-currentOff)
|
|
|
|
blocks = append(blocks, blockInfo{
|
|
index: i,
|
|
offsetInBlock: offsetInBlock,
|
|
bytesToRead: bytesToRead,
|
|
})
|
|
|
|
currentOff += bytesToRead
|
|
if currentOff >= endOff {
|
|
break
|
|
}
|
|
}
|
|
uncompressedOffset = blockUncompressedEnd
|
|
}
|
|
|
|
if len(blocks) == 0 {
|
|
return 0, io.EOF
|
|
}
|
|
|
|
// Parallel block decoding
|
|
type decodeResult struct {
|
|
index int
|
|
data []byte
|
|
err error
|
|
}
|
|
|
|
resultCh := make(chan decodeResult, len(blocks))
|
|
var wg sync.WaitGroup
|
|
sem := make(chan struct{}, runtime.NumCPU())
|
|
|
|
for _, block := range blocks {
|
|
wg.Add(1)
|
|
go func(block blockInfo) {
|
|
defer wg.Done()
|
|
sem <- struct{}{}
|
|
defer func() { <-sem }()
|
|
|
|
startOffset := int64(s.metadata.BlockData[block.index])
|
|
endOffset := int64(s.metadata.BlockData[block.index+1])
|
|
compressedSize := endOffset - startOffset
|
|
|
|
compressed := make([]byte, compressedSize)
|
|
_, err := s.r.ReadAt(compressed, startOffset)
|
|
if err != nil && err != io.EOF {
|
|
resultCh <- decodeResult{index: block.index, err: err}
|
|
return
|
|
}
|
|
|
|
decoded, err := s.decoder.DecodeAll(compressed, nil)
|
|
if err != nil {
|
|
resultCh <- decodeResult{index: block.index, err: err}
|
|
return
|
|
}
|
|
|
|
resultCh <- decodeResult{index: block.index, data: decoded, err: nil}
|
|
}(block)
|
|
}
|
|
|
|
go func() {
|
|
wg.Wait()
|
|
close(resultCh)
|
|
}()
|
|
|
|
// Collect results in block index order
|
|
totalRead := 0
|
|
results := make(map[int]decodeResult)
|
|
expected := len(blocks)
|
|
minIndex := blocks[0].index
|
|
|
|
for res := range resultCh {
|
|
results[res.index] = res
|
|
for {
|
|
if result, ok := results[minIndex]; ok {
|
|
if result.err != nil {
|
|
return 0, result.err
|
|
}
|
|
// find the corresponding blockInfo
|
|
var blk blockInfo
|
|
for _, b := range blocks {
|
|
if b.index == result.index {
|
|
blk = b
|
|
break
|
|
}
|
|
}
|
|
|
|
start := blk.offsetInBlock
|
|
end := start + blk.bytesToRead
|
|
copy(p[totalRead:totalRead+int(blk.bytesToRead)], result.data[start:end])
|
|
totalRead += int(blk.bytesToRead)
|
|
minIndex++
|
|
if minIndex-blocks[0].index >= len(blocks) {
|
|
break
|
|
}
|
|
} else {
|
|
break
|
|
}
|
|
}
|
|
if len(results) == expected && minIndex-blocks[0].index >= len(blocks) {
|
|
break
|
|
}
|
|
}
|
|
|
|
return totalRead, nil
|
|
}
|
|
|
|
// Close closes the SzstdReaderAt and underlying decoder.
|
|
func (s *SzstdReaderAt) Close() error {
|
|
if err := s.r.Close(); err != nil {
|
|
return err
|
|
}
|
|
s.decoder.Close()
|
|
return nil
|
|
}
|