1
0
mirror of https://github.com/rclone/rclone.git synced 2025-01-24 12:56:36 +02:00
rclone/fs/operations/reopen_test.go
Nick Craig-Wood 93955b755f operations: fix retries downloading too much data with certain backends
Before this fix, if more than one retry happened on a file that rclone
had opened for read with a backend that uses fs.FixRangeOption, then
rclone would read too much data and the transfer would fail.

Backends affected:

- azureblob, azurefiles, b2, box, dropbox, fichier, filefabric
- googlecloudstorage, hidrive, imagekit, jottacloud, koofr, netstorage
- onedrive, opendrive, oracleobjectstorage, pikpak, premiumizeme
- protondrive, qingstor, quatrix, s3, sharefile, sugarsync, swift
- uptobox, webdav, zoho

This was because rclone was emitting Range requests for the wrong data
range on the second and subsequent retries.

This was caused by fs.FixRangeOption modifying the options and the
reopen code relying on them not being modified.

This fix makes the reopen code pass a copy of the options, so that any
modifications made by fs.FixRangeOption don't affect later retries.

In future it might be better to change fs.FixRangeOption so that it
returns a new options slice rather than modifying the one passed in.

Fixes #7759
2024-04-13 19:25:15 +01:00

398 lines
10 KiB
Go

package operations
import (
"context"
"errors"
"io"
"testing"
"github.com/rclone/rclone/fs"
"github.com/rclone/rclone/fs/hash"
"github.com/rclone/rclone/fstest/mockobject"
"github.com/rclone/rclone/lib/pool"
"github.com/rclone/rclone/lib/readers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Compile-time checks that *ReOpen implements io.ReadSeekCloser and
// pool.DelayAccountinger; the build fails if either method set drifts.
var _ io.ReadSeekCloser = (*ReOpen)(nil)

var _ pool.DelayAccountinger = (*ReOpen)(nil)
var (
	// errorTestError is the error reOpenTestObject injects when one of its
	// configured break points is reached.
	errorTestError = errors.New("test error")
)
// reOpenTestObject wraps an fs.Object (a mockobject in these tests) and
// overrides Open so the test can inspect the options passed to it and
// inject read errors.
//
// Each entry in breaks is the number of bytes the next Open allows to be
// read before errorTestError is returned; an entry of 0 makes that Open
// fail immediately.
type reOpenTestObject struct {
	fs.Object // the wrapped object that supplies the real data

	t *testing.T // used to report assertion failures from inside Open

	// wantStart is the start offset Open expects to be asked for next.
	wantStart int64
	// breaks holds the remaining injected read limits, consumed one per Open.
	breaks []int64
	// unknownSize makes Open expect open-ended range options (End == -1).
	unknownSize bool
}
// Open opens the wrapped object for read. Call Close() on the returned
// io.ReadCloser when done.
//
// Before delegating, it runs fs.FixRangeOption over options (as many
// backends do, which rewrites entries of the options slice) and asserts
// that:
//   - the requested start offset, taken from a RangeOption or SeekOption,
//     equals o.wantStart;
//   - a HashesOption is only combined with a RangeOption starting at 0;
//   - when the size is unknown, any RangeOption is open-ended (End == -1).
//
// If breaks is non-empty, its first entry N is popped and wantStart is
// advanced by N. N == 0 makes Open fail with errorTestError. Otherwise the
// returned reader yields at most N bytes and then errorTestError.
func (o *reOpenTestObject) Open(ctx context.Context, options ...fs.OpenOption) (io.ReadCloser, error) {
	// Lots of backends do this - make sure it works as it modifies options
	fs.FixRangeOption(options, o.Size())

	var (
		gotHash  bool
		gotRange bool
		startPos int64
	)
	for _, option := range options {
		switch x := option.(type) {
		case *fs.HashesOption:
			gotHash = true
		case *fs.RangeOption:
			gotRange = true
			startPos = x.Start
			if o.unknownSize {
				assert.Equal(o.t, int64(-1), x.End)
			}
		case *fs.SeekOption:
			startPos = x.Offset
		}
	}
	assert.Equal(o.t, o.wantStart, startPos)

	// Check if ranging, mustn't have hash if offset != 0
	if gotHash && gotRange {
		assert.Equal(o.t, int64(0), startPos)
	}

	rc, err := o.Object.Open(ctx, options...)
	if err != nil {
		return nil, err
	}
	if len(o.breaks) == 0 {
		return rc, nil
	}

	// Pop a breakpoint off
	n := o.breaks[0]
	o.breaks = o.breaks[1:]
	o.wantStart += n

	// If 0 then return an error immediately. The caller never sees rc in
	// this case, so close it here rather than leaking it.
	if n == 0 {
		_ = rc.Close() // best effort: the injected error is what matters
		return nil, errorTestError
	}

	// Read n bytes then an error, keeping the underlying reader's Close.
	r := io.MultiReader(&io.LimitedReader{R: rc, N: n}, readers.ErrorReader{Err: errorTestError})
	return readCloser{Reader: r, Closer: rc}, nil
}
// TestReOpen exercises ReOpen against reOpenTestObject for four ways of
// opening the object: plain (with a HashesOption), with a RangeOption, with
// a SeekOption, and with an open-ended RangeOption on an object whose size
// is unknown.
func TestReOpen(t *testing.T) {
	for _, testName := range []string{"Normal", "WithRangeOption", "WithSeekOption", "UnknownSize"} {
		t.Run(testName, func(t *testing.T) {
			// Contents for the mock object
			var (
				reOpenTestcontents = []byte("0123456789")
				expectedRead       = reOpenTestcontents
				rangeOption        *fs.RangeOption
				seekOption         *fs.SeekOption
				unknownSize        = false
			)
			switch testName {
			case "Normal":
			case "WithRangeOption":
				rangeOption = &fs.RangeOption{Start: 1, End: 7} // range is inclusive
				expectedRead = reOpenTestcontents[1:8]
			case "WithSeekOption":
				seekOption = &fs.SeekOption{Offset: 2}
				expectedRead = reOpenTestcontents[2:]
			case "UnknownSize":
				rangeOption = &fs.RangeOption{Start: 1, End: -1}
				expectedRead = reOpenTestcontents[1:]
				unknownSize = true
			default:
				panic("bad test name")
			}

			// Start the test with the given breaks
			testReOpen := func(breaks []int64, maxRetries int) (*ReOpen, *reOpenTestObject, error) {
				srcOrig := mockobject.New("potato").WithContent(reOpenTestcontents, mockobject.SeekModeNone)
				srcOrig.SetUnknownSize(unknownSize)
				src := &reOpenTestObject{
					Object:      srcOrig,
					t:           t,
					breaks:      breaks,
					unknownSize: unknownSize,
				}
				opts := []fs.OpenOption{}
				if rangeOption == nil && seekOption == nil {
					opts = append(opts, &fs.HashesOption{Hashes: hash.NewHashSet(hash.MD5)})
				}
				if rangeOption != nil {
					opts = append(opts, rangeOption)
					src.wantStart = rangeOption.Start
				}
				if seekOption != nil {
					opts = append(opts, seekOption)
					src.wantStart = seekOption.Offset
				}
				rc, err := NewReOpen(context.Background(), src, maxRetries, opts...)
				return rc, src, err
			}

			t.Run("Basics", func(t *testing.T) {
				// open - a nil handle would make every later call panic
				h, _, err := testReOpen(nil, 10)
				require.NoError(t, err)

				// Check contents read correctly
				got, err := io.ReadAll(h)
				assert.NoError(t, err)
				assert.Equal(t, expectedRead, got)

				// Check read after end
				buf := make([]byte, 1)
				n, err := h.Read(buf)
				assert.Equal(t, 0, n)
				assert.Equal(t, io.EOF, err)

				// Rewind the stream
				_, err = h.Seek(0, io.SeekStart)
				require.NoError(t, err)

				// Check contents read correctly
				got, err = io.ReadAll(h)
				assert.NoError(t, err)
				assert.Equal(t, expectedRead, got)

				// Check close
				assert.NoError(t, h.Close())

				// Check double close
				assert.Equal(t, errFileClosed, h.Close())

				// Check read after close
				n, err = h.Read(buf)
				assert.Equal(t, 0, n)
				assert.Equal(t, errFileClosed, err)
			})

			t.Run("ErrorAtStart", func(t *testing.T) {
				// open with immediate breaking
				h, _, err := testReOpen([]int64{0}, 10)
				assert.Equal(t, errorTestError, err)
				assert.Nil(t, h)
			})

			t.Run("WithErrors", func(t *testing.T) {
				// open with a few break points but less than the max
				h, _, err := testReOpen([]int64{2, 1, 3}, 10)
				require.NoError(t, err)

				// check contents
				got, err := io.ReadAll(h)
				assert.NoError(t, err)
				assert.Equal(t, expectedRead, got)

				// check close
				assert.NoError(t, h.Close())
			})

			t.Run("TooManyErrors", func(t *testing.T) {
				// open with a few break points but >= the max
				h, _, err := testReOpen([]int64{2, 1, 3}, 3)
				require.NoError(t, err)

				// check contents
				got, err := io.ReadAll(h)
				assert.Equal(t, errorTestError, err)
				assert.Equal(t, expectedRead[:6], got)

				// check old error is returned
				buf := make([]byte, 1)
				n, err := h.Read(buf)
				assert.Equal(t, 0, n)
				assert.Equal(t, errTooManyTries, err)

				// Check close
				assert.Equal(t, errFileClosed, h.Close())
			})

			t.Run("Seek", func(t *testing.T) {
				// open
				h, src, err := testReOpen([]int64{2, 1, 3}, 10)
				require.NoError(t, err)

				// Seek to end
				pos, err := h.Seek(int64(len(expectedRead)), io.SeekStart)
				assert.NoError(t, err)
				assert.Equal(t, int64(len(expectedRead)), pos)

				// Seek to start
				pos, err = h.Seek(0, io.SeekStart)
				assert.NoError(t, err)
				assert.Equal(t, int64(0), pos)

				// Should not allow seek past end (allowed when the size is unknown)
				pos, err = h.Seek(int64(len(expectedRead))+1, io.SeekCurrent)
				if !unknownSize {
					assert.Equal(t, errSeekPastEnd, err)
					assert.Equal(t, len(expectedRead), int(pos))
				} else {
					assert.NoError(t, err)
					assert.Equal(t, len(expectedRead)+1, int(pos))

					// Seek back to start to get tests in sync
					pos, err = h.Seek(0, io.SeekStart)
					assert.NoError(t, err)
					assert.Equal(t, int64(0), pos)
				}

				// Should not allow seek to negative position start
				pos, err = h.Seek(-1, io.SeekCurrent)
				assert.Equal(t, errNegativeSeek, err)
				assert.Equal(t, 0, int(pos))

				// Should not allow seek with invalid whence
				pos, err = h.Seek(0, 3)
				assert.Equal(t, errInvalidWhence, err)
				assert.Equal(t, 0, int(pos))

				// check read
				dst := make([]byte, 5)
				n, err := h.Read(dst)
				assert.NoError(t, err)
				assert.Equal(t, 5, n)
				assert.Equal(t, expectedRead[:5], dst)

				// Test io.SeekCurrent
				pos, err = h.Seek(-3, io.SeekCurrent)
				assert.NoError(t, err)
				assert.Equal(t, 2, int(pos))

				// Reset the start after a seek, taking into account the offset
				setWantStart := func(x int64) {
					src.wantStart = x
					if rangeOption != nil {
						src.wantStart += rangeOption.Start
					} else if seekOption != nil {
						src.wantStart += seekOption.Offset
					}
				}

				// check read
				setWantStart(2)
				n, err = h.Read(dst)
				assert.NoError(t, err)
				assert.Equal(t, 5, n)
				assert.Equal(t, expectedRead[2:7], dst)

				pos, err = h.Seek(-2, io.SeekCurrent)
				assert.NoError(t, err)
				assert.Equal(t, 5, int(pos))

				// Test io.SeekEnd
				pos, err = h.Seek(-3, io.SeekEnd)
				if !unknownSize {
					assert.NoError(t, err)
					assert.Equal(t, len(expectedRead)-3, int(pos))
				} else {
					assert.Equal(t, errBadEndSeek, err)
					assert.Equal(t, 0, int(pos))

					// sync
					pos, err = h.Seek(1, io.SeekCurrent)
					assert.NoError(t, err)
					assert.Equal(t, 6, int(pos))
				}

				// check read
				dst = make([]byte, 3)
				setWantStart(int64(len(expectedRead) - 3))
				n, err = h.Read(dst)
				assert.NoError(t, err)
				assert.Equal(t, 3, n)
				assert.Equal(t, expectedRead[len(expectedRead)-3:], dst)

				// check close
				assert.NoError(t, h.Close())
				_, err = h.Seek(0, io.SeekCurrent)
				assert.Equal(t, errFileClosed, err)
			})

			t.Run("AccountRead", func(t *testing.T) {
				h, _, err := testReOpen(nil, 10)
				require.NoError(t, err)

				var total int
				h.SetAccounting(func(n int) error {
					total += n
					return nil
				})

				dst := make([]byte, 3)
				n, err := h.Read(dst)
				assert.Equal(t, 3, n)
				assert.NoError(t, err)
				assert.Equal(t, 3, total)

				// Don't leak the handle
				assert.NoError(t, h.Close())
			})

			t.Run("AccountReadDelay", func(t *testing.T) {
				h, _, err := testReOpen(nil, 10)
				require.NoError(t, err)

				var total int
				h.SetAccounting(func(n int) error {
					total += n
					return nil
				})

				rewind := func() {
					_, err := h.Seek(0, io.SeekStart)
					require.NoError(t, err)
				}

				h.DelayAccounting(3)

				dst := make([]byte, 16)

				n, err := h.Read(dst)
				assert.Equal(t, len(expectedRead), n)
				assert.Equal(t, io.EOF, err)
				assert.Equal(t, 0, total)
				rewind()

				n, err = h.Read(dst)
				assert.Equal(t, len(expectedRead), n)
				assert.Equal(t, io.EOF, err)
				assert.Equal(t, 0, total)
				rewind()

				n, err = h.Read(dst)
				assert.Equal(t, len(expectedRead), n)
				assert.Equal(t, io.EOF, err)
				assert.Equal(t, len(expectedRead), total)
				rewind()

				n, err = h.Read(dst)
				assert.Equal(t, len(expectedRead), n)
				assert.Equal(t, io.EOF, err)
				assert.Equal(t, 2*len(expectedRead), total)
				rewind()

				// Don't leak the handle
				assert.NoError(t, h.Close())
			})

			t.Run("AccountReadError", func(t *testing.T) {
				// Test accounting errors
				h, _, err := testReOpen(nil, 10)
				require.NoError(t, err)

				h.SetAccounting(func(n int) error {
					return errorTestError
				})

				dst := make([]byte, 3)
				n, err := h.Read(dst)
				assert.Equal(t, 3, n)
				assert.Equal(t, errorTestError, err)

				// Release the handle; the close result after an accounting
				// error is not what this subtest checks.
				_ = h.Close()
			})
		})
	}
}