From 181fecaec3017e0e01ab417f72d8bc4559797d5a Mon Sep 17 00:00:00 2001
From: Vitor Gomes <mail@vitorgomes.com>
Date: Tue, 25 Jul 2023 17:19:37 +0200
Subject: [PATCH] multithread: refactor multithread operation to use
 OpenChunkWriter if available #7056

If the destination does not provide the OpenChunkWriter feature, multi-thread copy falls back to an adapter that implements OpenChunkWriter on top of OpenWriterAt, and fails if neither feature is available.
---
 fs/operations/multithread.go      | 276 ++++++++++++++++++------------
 fs/operations/multithread_test.go |  27 ++-
 2 files changed, 180 insertions(+), 123 deletions(-)

diff --git a/fs/operations/multithread.go b/fs/operations/multithread.go
index 5ab41b9a2..deb7c88ba 100644
--- a/fs/operations/multithread.go
+++ b/fs/operations/multithread.go
@@ -9,13 +9,13 @@ import (
 
 	"github.com/rclone/rclone/fs"
 	"github.com/rclone/rclone/fs/accounting"
+	"github.com/rclone/rclone/lib/readers"
 	"golang.org/x/sync/errgroup"
+	"golang.org/x/sync/semaphore"
 )
 
 const (
-	multithreadChunkSize      = 64 << 10
-	multithreadChunkSizeMask  = multithreadChunkSize - 1
-	multithreadReadBufferSize = 32 * 1024
+	multithreadChunkSize = 64 << 10
 )
 
 // An offsetWriter maps writes at offset base to offset base+off in the underlying writer.
@@ -60,7 +60,7 @@ func doMultiThreadCopy(ctx context.Context, f fs.Fs, src fs.Object) bool {
 	}
 	// ...destination doesn't support it
 	dstFeatures := f.Features()
-	if dstFeatures.OpenWriterAt == nil {
+	if dstFeatures.OpenChunkWriter == nil && dstFeatures.OpenWriterAt == nil {
 		return false
 	}
 	// ...if --multi-thread-streams not in use and source and
@@ -73,21 +73,20 @@ func doMultiThreadCopy(ctx context.Context, f fs.Fs, src fs.Object) bool {
 
 // state for a multi-thread copy
 type multiThreadCopyState struct {
-	ctx      context.Context
-	partSize int64
-	size     int64
-	wc       fs.WriterAtCloser
-	src      fs.Object
-	acc      *accounting.Account
-	streams  int
+	ctx       context.Context
+	partSize  int64
+	size      int64
+	src       fs.Object
+	acc       *accounting.Account
+	streams   int
+	numChunks int
 }
 
 // Copy a single stream into place
-func (mc *multiThreadCopyState) copyStream(ctx context.Context, stream int) (err error) {
-	ci := fs.GetConfig(ctx)
+func (mc *multiThreadCopyState) copyStream(ctx context.Context, stream int, writer fs.ChunkWriter) (err error) {
 	defer func() {
 		if err != nil {
-			fs.Debugf(mc.src, "multi-thread copy: stream %d/%d failed: %v", stream+1, mc.streams, err)
+			fs.Debugf(mc.src, "multi-thread copy: stream %d/%d failed: %v", stream+1, mc.numChunks, err)
 		}
 	}()
 	start := int64(stream) * mc.partSize
@@ -99,7 +98,7 @@ func (mc *multiThreadCopyState) copyStream(ctx context.Context, stream int) (err
 		end = mc.size
 	}
 
-	fs.Debugf(mc.src, "multi-thread copy: stream %d/%d (%d-%d) size %v starting", stream+1, mc.streams, start, end, fs.SizeSuffix(end-start))
+	fs.Debugf(mc.src, "multi-thread copy: stream %d/%d (%d-%d) size %v starting", stream+1, mc.numChunks, start, end, fs.SizeSuffix(end-start))
 
 	rc, err := Open(ctx, mc.src, &fs.RangeOption{Start: start, End: end - 1})
 	if err != nil {
@@ -107,119 +106,99 @@ func (mc *multiThreadCopyState) copyStream(ctx context.Context, stream int) (err
 	}
 	defer fs.CheckClose(rc, &err)
 
-	var writer io.Writer = newOffsetWriter(mc.wc, start)
-	if ci.MultiThreadWriteBufferSize > 0 {
-		writer = bufio.NewWriterSize(writer, int(ci.MultiThreadWriteBufferSize))
-		fs.Debugf(mc.src, "multi-thread copy: write buffer set to %v", ci.MultiThreadWriteBufferSize)
+	bytesWritten, err := writer.WriteChunk(stream, readers.NewRepeatableReader(rc))
+	if err != nil {
+		return err
 	}
-	// Copy the data
-	buf := make([]byte, multithreadReadBufferSize)
-	offset := start
-	for {
-		// Check if context cancelled and exit if so
-		if mc.ctx.Err() != nil {
-			return mc.ctx.Err()
-		}
-		nr, er := rc.Read(buf)
-		if nr > 0 {
-			err = mc.acc.AccountRead(nr)
-			if err != nil {
-				return fmt.Errorf("multipart copy: accounting failed: %w", err)
-			}
-			nw, ew := writer.Write(buf[0:nr])
-			if nw > 0 {
-				offset += int64(nw)
-			}
-			if ew != nil {
-				return fmt.Errorf("multipart copy: write failed: %w", ew)
-			}
-			if nr != nw {
-				return fmt.Errorf("multipart copy: %w", io.ErrShortWrite)
-			}
-		}
-		if er != nil {
-			if er != io.EOF {
-				return fmt.Errorf("multipart copy: read failed: %w", er)
-			}
-
-			// if we were buffering, flush do disk
-			switch w := writer.(type) {
-			case *bufio.Writer:
-				er2 := w.Flush()
-				if er2 != nil {
-					return fmt.Errorf("multipart copy: flush failed: %w", er2)
-				}
-			}
-
-			break
-		}
+	// FIXME: Wrap ReadSeeker for Accounting
+	// However, to ensure reporting is correct, seeks have to be handled properly
+	errAccRead := mc.acc.AccountRead(int(bytesWritten))
+	if errAccRead != nil {
+		return errAccRead
 	}
 
-	if offset != end {
-		return fmt.Errorf("multipart copy: wrote %d bytes but expected to write %d", offset-start, end-start)
-	}
-
-	fs.Debugf(mc.src, "multi-thread copy: stream %d/%d (%d-%d) size %v finished", stream+1, mc.streams, start, end, fs.SizeSuffix(end-start))
+	fs.Debugf(mc.src, "multi-thread copy: stream %d/%d (%d-%d) size %v finished", stream+1, mc.numChunks, start, end, fs.SizeSuffix(bytesWritten))
 	return nil
 }
 
-// Calculate the chunk sizes and updated number of streams
-func (mc *multiThreadCopyState) calculateChunks() {
-	partSize := mc.size / int64(mc.streams)
-	// Round partition size up so partSize * streams >= size
-	if (mc.size % int64(mc.streams)) != 0 {
-		partSize++
-	}
-	// round partSize up to nearest multithreadChunkSize boundary
-	mc.partSize = (partSize + multithreadChunkSizeMask) &^ multithreadChunkSizeMask
-	// recalculate number of streams
-	mc.streams = int(mc.size / mc.partSize)
-	// round streams up so partSize * streams >= size
-	if (mc.size % mc.partSize) != 0 {
-		mc.streams++
+// calculateNumChunks returns the number of chunks of chunkSize bytes needed to cover
+// size bytes, so that chunkSize * numChunks >= size. chunkSize is a divisor and must not be 0.
+func calculateNumChunks(size int64, chunkSize int64) int {
+	numChunks := size / chunkSize
+	if size%chunkSize != 0 {
+		numChunks++
 	}
+
+	return int(numChunks)
 }
 
-// Copy src to (f, remote) using streams download threads and the OpenWriterAt feature
+// Copy src to (f, remote) using streams download threads. It tries to use the OpenChunkWriter feature
+// and if that's not available it creates an adapter using OpenWriterAt
 func multiThreadCopy(ctx context.Context, f fs.Fs, remote string, src fs.Object, streams int, tr *accounting.Transfer) (newDst fs.Object, err error) {
-	openWriterAt := f.Features().OpenWriterAt
-	if openWriterAt == nil {
-		return nil, errors.New("multi-thread copy: OpenWriterAt not supported")
+	openChunkWriter := f.Features().OpenChunkWriter
+	ci := fs.GetConfig(ctx)
+	if openChunkWriter == nil {
+		openWriterAt := f.Features().OpenWriterAt
+		if openWriterAt == nil {
+			return nil, errors.New("multi-part copy: neither OpenChunkWriter nor OpenWriterAt supported")
+		}
+		openChunkWriter = openChunkWriterFromOpenWriterAt(openWriterAt, int64(ci.MultiThreadChunkSize), int64(ci.MultiThreadWriteBufferSize), f)
 	}
+
 	if src.Size() < 0 {
-		return nil, errors.New("multi-thread copy: can't copy unknown sized file")
+		return nil, fmt.Errorf("multi-thread copy: can't copy unknown sized file")
 	}
 	if src.Size() == 0 {
-		return nil, errors.New("multi-thread copy: can't copy zero sized file")
+		return nil, fmt.Errorf("multi-thread copy: can't copy zero sized file")
 	}
 
 	g, gCtx := errgroup.WithContext(ctx)
-	mc := &multiThreadCopyState{
-		ctx:     gCtx,
-		size:    src.Size(),
-		src:     src,
-		streams: streams,
+	chunkSize, chunkWriter, err := openChunkWriter(ctx, remote, src)
+	// Check err before touching chunkSize: on failure it is not a usable size (the adapter returns -1)
+	if err != nil {
+		return nil, fmt.Errorf("multipart copy: failed to open chunk writer: %w", err)
+	}
+
+	if chunkSize > src.Size() {
+		fs.Debugf(src, "multi-thread copy: chunk size %v was bigger than source file size %v", fs.SizeSuffix(chunkSize), fs.SizeSuffix(src.Size()))
+		chunkSize = src.Size()
+	}
+
+	numChunks := calculateNumChunks(src.Size(), chunkSize)
+	if streams > numChunks {
+		fs.Debugf(src, "multi-thread copy: number of streams '%d' was bigger than number of chunks '%d'", streams, numChunks)
+		streams = numChunks
+	}
+
+	mc := &multiThreadCopyState{
+		ctx:       gCtx,
+		size:      src.Size(),
+		src:       src,
+		partSize:  chunkSize,
+		streams:   streams,
+		numChunks: numChunks,
 	}
-	mc.calculateChunks()
 
 	// Make accounting
 	mc.acc = tr.Account(ctx, nil)
 
-	// create write file handle
-	mc.wc, err = openWriterAt(gCtx, remote, mc.size)
-	if err != nil {
-		return nil, fmt.Errorf("multipart copy: failed to open destination: %w", err)
-	}
-
-	fs.Debugf(src, "Starting multi-thread copy with %d parts of size %v", mc.streams, fs.SizeSuffix(mc.partSize))
-	for stream := 0; stream < mc.streams; stream++ {
-		stream := stream
+	fs.Debugf(src, "Starting multi-thread copy with %d parts of size %v with %v parallel streams", mc.numChunks, fs.SizeSuffix(mc.partSize), mc.streams)
+	sem := semaphore.NewWeighted(int64(mc.streams))
+	for chunk := 0; chunk < mc.numChunks; chunk++ {
+		// Acquire on gCtx to fail fast; report via g so Wait() can't return nil with chunks unwritten
+		if acquireErr := sem.Acquire(gCtx, 1); acquireErr != nil {
+			g.Go(func() error { return acquireErr })
+			break
+		}
+		currChunk := chunk
 		g.Go(func() (err error) {
-			return mc.copyStream(gCtx, stream)
+			defer sem.Release(1)
+			return mc.copyStream(gCtx, currChunk, chunkWriter)
 		})
 	}
+
 	err = g.Wait()
-	closeErr := mc.wc.Close()
+	closeErr := chunkWriter.Close()
 	if err != nil {
 		return nil, err
 	}
@@ -232,13 +211,94 @@ func multiThreadCopy(ctx context.Context, f fs.Fs, remote string, src fs.Object,
 		return nil, fmt.Errorf("multi-thread copy: failed to find object after copy: %w", err)
 	}
 
-	err = obj.SetModTime(ctx, src.ModTime(ctx))
-	switch err {
-	case nil, fs.ErrorCantSetModTime, fs.ErrorCantSetModTimeWithoutDelete:
-	default:
-		return nil, fmt.Errorf("multi-thread copy: failed to set modification time: %w", err)
+	if f.Features().PartialUploads {
+		err = obj.SetModTime(ctx, src.ModTime(ctx))
+		switch err {
+		case nil, fs.ErrorCantSetModTime, fs.ErrorCantSetModTimeWithoutDelete:
+		default:
+			return nil, fmt.Errorf("multi-thread copy: failed to set modification time: %w", err)
+		}
 	}
 
-	fs.Debugf(src, "Finished multi-thread copy with %d parts of size %v", mc.streams, fs.SizeSuffix(mc.partSize))
+	fs.Debugf(src, "Finished multi-thread copy with %d parts of size %v", mc.numChunks, fs.SizeSuffix(mc.partSize))
 	return obj, nil
 }
+
+// writerAtChunkWriter implements fs.ChunkWriter on top of an fs.WriterAtCloser
+type writerAtChunkWriter struct {
+	ctx             context.Context
+	remote          string
+	size, chunkSize int64
+	writerAt        fs.WriterAtCloser
+	chunks          int
+	writeBufferSize int64
+	f               fs.Fs
+}
+
+// WriteChunk copies reader to offset chunkNumber*chunkSize and returns the bytes written (-1 on error)
+func (w writerAtChunkWriter) WriteChunk(chunkNumber int, reader io.ReadSeeker) (int64, error) {
+	fs.Debugf(w.remote, "writing chunk %v", chunkNumber)
+	bytesToWrite := w.chunkSize
+	if chunkNumber == (w.chunks-1) && w.size%w.chunkSize != 0 {
+		bytesToWrite = w.size % w.chunkSize
+	}
+
+	var writer io.Writer = newOffsetWriter(w.writerAt, int64(chunkNumber)*w.chunkSize)
+	if w.writeBufferSize > 0 {
+		writer = bufio.NewWriterSize(writer, int(w.writeBufferSize))
+	}
+	n, err := io.Copy(writer, reader)
+	if err != nil {
+		return -1, err
+	}
+	if n != bytesToWrite {
+		return -1, fmt.Errorf("expected to write %v bytes for chunk %v, but wrote %v bytes", bytesToWrite, chunkNumber, n)
+	}
+	// if we were buffering, flush to disk
+	if bw, ok := writer.(*bufio.Writer); ok {
+		if er2 := bw.Flush(); er2 != nil {
+			return -1, fmt.Errorf("multipart copy: flush failed: %w", er2)
+		}
+	}
+	return n, nil
+}
+
+// Close closes the underlying fs.WriterAtCloser
+func (w writerAtChunkWriter) Close() error {
+	return w.writerAt.Close()
+}
+
+// Abort looks up the destination object and removes it
+func (w writerAtChunkWriter) Abort() error {
+	obj, err := w.f.NewObject(w.ctx, w.remote)
+	if err != nil {
+		return fmt.Errorf("multi-thread copy: failed to find temp file when aborting chunk writer: %w", err)
+	}
+	return obj.Remove(w.ctx)
+}
+
+// openChunkWriterFromOpenWriterAt wraps openWriterAt in a function with the OpenChunkWriter
+// signature; the fs.ChunkWriter it returns writes chunkSize sized chunks via the fs.WriterAtCloser
+func openChunkWriterFromOpenWriterAt(openWriterAt func(ctx context.Context, remote string, size int64) (fs.WriterAtCloser, error), chunkSize int64, writeBufferSize int64, f fs.Fs) func(ctx context.Context, remote string, src fs.ObjectInfo, options ...fs.OpenOption) (chunkSizeResult int64, writer fs.ChunkWriter, err error) {
+	return func(ctx context.Context, remote string, src fs.ObjectInfo, options ...fs.OpenOption) (chunkSizeResult int64, writer fs.ChunkWriter, err error) {
+		dst, err := openWriterAt(ctx, remote, src.Size())
+		if err != nil {
+			return -1, nil, err
+		}
+		if writeBufferSize > 0 {
+			fs.Debugf(src.Remote(), "multi-thread copy: write buffer set to %v", writeBufferSize)
+		}
+
+		// hand back the configured chunkSize together with the adapter
+		return chunkSize, &writerAtChunkWriter{
+			ctx:             ctx,
+			remote:          remote,
+			size:            src.Size(),
+			chunkSize:       chunkSize,
+			chunks:          calculateNumChunks(src.Size(), chunkSize),
+			writerAt:        dst,
+			writeBufferSize: writeBufferSize,
+			f:               f,
+		}, nil
+	}
+}
diff --git a/fs/operations/multithread_test.go b/fs/operations/multithread_test.go
index aef6f75af..cd5f430f8 100644
--- a/fs/operations/multithread_test.go
+++ b/fs/operations/multithread_test.go
@@ -86,27 +86,24 @@ func TestDoMultiThreadCopy(t *testing.T) {
 	assert.True(t, doMultiThreadCopy(ctx, f, src))
 }
 
-func TestMultithreadCalculateChunks(t *testing.T) {
+func TestMultithreadCalculateNumChunks(t *testing.T) {
 	for _, test := range []struct {
-		size         int64
-		streams      int
-		wantPartSize int64
-		wantStreams  int
+		size          int64
+		chunkSize     int64
+		wantNumChunks int
 	}{
-		{size: 1, streams: 10, wantPartSize: multithreadChunkSize, wantStreams: 1},
-		{size: 1 << 20, streams: 1, wantPartSize: 1 << 20, wantStreams: 1},
-		{size: 1 << 20, streams: 2, wantPartSize: 1 << 19, wantStreams: 2},
-		{size: (1 << 20) + 1, streams: 2, wantPartSize: (1 << 19) + multithreadChunkSize, wantStreams: 2},
-		{size: (1 << 20) - 1, streams: 2, wantPartSize: (1 << 19), wantStreams: 2},
+		{size: 1, chunkSize: multithreadChunkSize, wantNumChunks: 1},
+		{size: 1 << 20, chunkSize: 1, wantNumChunks: 1 << 20},
+		{size: 1 << 20, chunkSize: 2, wantNumChunks: 1 << 19},
+		{size: (1 << 20) + 1, chunkSize: 2, wantNumChunks: (1 << 19) + 1},
+		{size: (1 << 20) - 1, chunkSize: 2, wantNumChunks: 1 << 19},
 	} {
 		t.Run(fmt.Sprintf("%+v", test), func(t *testing.T) {
 			mc := &multiThreadCopyState{
-				size:    test.size,
-				streams: test.streams,
+				size: test.size,
 			}
-			mc.calculateChunks()
-			assert.Equal(t, test.wantPartSize, mc.partSize)
-			assert.Equal(t, test.wantStreams, mc.streams)
+			mc.numChunks = calculateNumChunks(test.size, test.chunkSize)
+			assert.Equal(t, test.wantNumChunks, mc.numChunks)
 		})
 	}
 }