From ecd52aa809775e68c6745bcbf8987d93fdafb7d8 Mon Sep 17 00:00:00 2001 From: Sudipto Baral Date: Mon, 18 Aug 2025 11:29:18 -0400 Subject: [PATCH] smb: improve multithreaded upload performance using multiple connections In the current design, OpenWriterAt provides the interface for random-access writes, and openChunkWriterFromOpenWriterAt wraps this interface to enable parallel chunk uploads using multiple goroutines. A global connection pool is already in place to manage SMB connections across files. However, currently only one connection is used per file, which makes multiple goroutines compete for the connection during multithreaded writes. This changes create separate connections for each goroutine, which allows true parallelism by giving each goroutine its own SMB connection Signed-off-by: sudipto baral --- backend/smb/filepool.go | 99 +++++++++++++++ backend/smb/filepool_test.go | 228 +++++++++++++++++++++++++++++++++++ backend/smb/smb.go | 96 +++++++++++++-- 3 files changed, 413 insertions(+), 10 deletions(-) create mode 100644 backend/smb/filepool.go create mode 100644 backend/smb/filepool_test.go diff --git a/backend/smb/filepool.go b/backend/smb/filepool.go new file mode 100644 index 000000000..0061ec1ff --- /dev/null +++ b/backend/smb/filepool.go @@ -0,0 +1,99 @@ +package smb + +import ( + "context" + "fmt" + "os" + "sync" + + "github.com/cloudsoda/go-smb2" + "golang.org/x/sync/errgroup" +) + +// FsInterface defines the methods that filePool needs from Fs +type FsInterface interface { + getConnection(ctx context.Context, share string) (*conn, error) + putConnection(pc **conn, err error) + removeSession() +} + +type file struct { + *smb2.File + c *conn +} + +type filePool struct { + ctx context.Context + fs FsInterface + share string + path string + + mu sync.Mutex + pool []*file +} + +func newFilePool(ctx context.Context, fs FsInterface, share, path string) *filePool { + return &filePool{ + ctx: ctx, + fs: fs, + share: share, + path: path, + } +} + +func (p *filePool) get() (*file, error) { + p.mu.Lock() + if len(p.pool) > 0 { + f := p.pool[len(p.pool)-1] + p.pool = p.pool[:len(p.pool)-1] + p.mu.Unlock() + return f, nil + } + p.mu.Unlock() + + c, err := p.fs.getConnection(p.ctx, p.share) + if err != nil { + return nil, err + } + + fl, err := c.smbShare.OpenFile(p.path, os.O_WRONLY, 0o644) + if err != nil { + p.fs.putConnection(&c, err) + return nil, fmt.Errorf("failed to open: %w", err) + } + + return &file{File: fl, c: c}, nil +} + +func (p *filePool) put(f *file, err error) { + if f == nil { + return + } + + if err != nil { + _ = f.Close() + p.fs.putConnection(&f.c, err) + return + } + + p.mu.Lock() + p.pool = append(p.pool, f) + p.mu.Unlock() +} + +func (p *filePool) drain() error { + p.mu.Lock() + files := p.pool + p.pool = nil + p.mu.Unlock() + + g, _ := errgroup.WithContext(p.ctx) + for _, f := range files { + g.Go(func() error { + err := f.Close() + p.fs.putConnection(&f.c, err) + return err + }) + } + return g.Wait() +} diff --git a/backend/smb/filepool_test.go b/backend/smb/filepool_test.go new file mode 100644 index 000000000..7cf90988b --- /dev/null +++ b/backend/smb/filepool_test.go @@ -0,0 +1,228 @@ +package smb + +import ( + "context" + "errors" + "sync" + "testing" + + "github.com/cloudsoda/go-smb2" + "github.com/stretchr/testify/assert" +) + +// Mock Fs that implements FsInterface +type mockFs struct { + mu sync.Mutex + putConnectionCalled bool + putConnectionErr error + getConnectionCalled bool + getConnectionErr error + getConnectionResult *conn + removeSessionCalled bool +} + +func (m *mockFs) putConnection(pc **conn, err error) { + m.mu.Lock() + defer m.mu.Unlock() + m.putConnectionCalled = true + m.putConnectionErr = err +} + +func (m *mockFs) getConnection(ctx context.Context, share string) (*conn, error) { + m.mu.Lock() + defer m.mu.Unlock() + m.getConnectionCalled = true + if m.getConnectionErr != nil { + return nil, m.getConnectionErr + } + if m.getConnectionResult != nil { + return m.getConnectionResult, nil + } + return &conn{}, nil +} + +func (m *mockFs) removeSession() { + m.mu.Lock() + defer m.mu.Unlock() + m.removeSessionCalled = true +} + +func (m *mockFs) isPutConnectionCalled() bool { + m.mu.Lock() + defer m.mu.Unlock() + return m.putConnectionCalled +} + +func (m *mockFs) getPutConnectionErr() error { + m.mu.Lock() + defer m.mu.Unlock() + return m.putConnectionErr +} + +func (m *mockFs) isGetConnectionCalled() bool { + m.mu.Lock() + defer m.mu.Unlock() + return m.getConnectionCalled +} + +func newMockFs() *mockFs { + return &mockFs{} +} + +// Helper function to create a mock file +func newMockFile() *file { + return &file{ + File: &smb2.File{}, + c: &conn{}, + } +} + +// Test filePool creation +func TestNewFilePool(t *testing.T) { + ctx := context.Background() + fs := newMockFs() + share := "testshare" + path := "/test/path" + + pool := newFilePool(ctx, fs, share, path) + + assert.NotNil(t, pool) + assert.Equal(t, ctx, pool.ctx) + assert.Equal(t, fs, pool.fs) + assert.Equal(t, share, pool.share) + assert.Equal(t, path, pool.path) + assert.Empty(t, pool.pool) +} + +// Test getting file from pool when pool has files +func TestFilePool_Get_FromPool(t *testing.T) { + ctx := context.Background() + fs := newMockFs() + pool := newFilePool(ctx, fs, "testshare", "/test/path") + + // Add a mock file to the pool + mockFile := newMockFile() + pool.pool = append(pool.pool, mockFile) + + // Get file from pool + f, err := pool.get() + + assert.NoError(t, err) + assert.NotNil(t, f) + assert.Equal(t, mockFile, f) + assert.Empty(t, pool.pool) +} + +// Test getting file when pool is empty +func TestFilePool_Get_EmptyPool(t *testing.T) { + ctx := context.Background() + fs := newMockFs() + + // Set up the mock to return an error from getConnection + // This tests that the pool calls getConnection when empty + fs.getConnectionErr = errors.New("connection failed") + + pool := newFilePool(ctx, fs, "testshare", "test/path") + + // This should call getConnection and return the error + f, err := pool.get() + assert.Error(t, err) + assert.Nil(t, f) + assert.True(t, fs.isGetConnectionCalled()) + assert.Equal(t, "connection failed", err.Error()) +} + +// Test putting file successfully +func TestFilePool_Put_Success(t *testing.T) { + ctx := context.Background() + fs := newMockFs() + pool := newFilePool(ctx, fs, "testshare", "/test/path") + + mockFile := newMockFile() + + pool.put(mockFile, nil) + + assert.Len(t, pool.pool, 1) + assert.Equal(t, mockFile, pool.pool[0]) +} + +// Test putting file with error +func TestFilePool_Put_WithError(t *testing.T) { + ctx := context.Background() + fs := newMockFs() + pool := newFilePool(ctx, fs, "testshare", "/test/path") + + mockFile := newMockFile() + + pool.put(mockFile, errors.New("write error")) + + // Should call putConnection with error + assert.True(t, fs.isPutConnectionCalled()) + assert.Equal(t, errors.New("write error"), fs.getPutConnectionErr()) + assert.Empty(t, pool.pool) +} + +// Test putting nil file +func TestFilePool_Put_NilFile(t *testing.T) { + ctx := context.Background() + fs := newMockFs() + pool := newFilePool(ctx, fs, "testshare", "/test/path") + + // Should not panic + pool.put(nil, nil) + pool.put(nil, errors.New("some error")) + + assert.Empty(t, pool.pool) +} + +// Test draining pool with files +func TestFilePool_Drain_WithFiles(t *testing.T) { + ctx := context.Background() + fs := newMockFs() + pool := newFilePool(ctx, fs, "testshare", "/test/path") + + // Add mock files to pool + mockFile1 := newMockFile() + mockFile2 := newMockFile() + pool.pool = append(pool.pool, mockFile1, mockFile2) + + // Before draining + assert.Len(t, pool.pool, 2) + + _ = pool.drain() + assert.Empty(t, pool.pool) +} + +// Test concurrent access to pool +func TestFilePool_ConcurrentAccess(t *testing.T) { + ctx := context.Background() + fs := newMockFs() + pool := newFilePool(ctx, fs, "testshare", "/test/path") + + const numGoroutines = 10 + for i := 0; i < numGoroutines; i++ { + mockFile := newMockFile() + pool.pool = append(pool.pool, mockFile) + } + + // Test concurrent get operations + done := make(chan bool, numGoroutines) + + for i := 0; i < numGoroutines; i++ { + go func() { + defer func() { done <- true }() + + f, err := pool.get() + if err == nil { + pool.put(f, nil) + } + }() + } + + for i := 0; i < numGoroutines; i++ { + <-done + } + + // Pool should be in a consistent after the concurrence access + assert.Len(t, pool.pool, numGoroutines) +} diff --git a/backend/smb/smb.go b/backend/smb/smb.go index d102233f5..993b2ec7d 100644 --- a/backend/smb/smb.go +++ b/backend/smb/smb.go @@ -3,6 +3,7 @@ package smb import ( "context" + "errors" "fmt" "io" "os" @@ -503,13 +504,73 @@ func (f *Fs) About(ctx context.Context) (_ *fs.Usage, err error) { return usage, nil } +type smbWriterAt struct { + pool *filePool + closed bool + closeMu sync.Mutex + wg sync.WaitGroup +} + +func (w *smbWriterAt) WriteAt(p []byte, off int64) (int, error) { + w.closeMu.Lock() + if w.closed { + w.closeMu.Unlock() + return 0, errors.New("writer already closed") + } + w.wg.Add(1) + w.closeMu.Unlock() + defer w.wg.Done() + + f, err := w.pool.get() + if err != nil { + return 0, fmt.Errorf("failed to get file from pool: %w", err) + } + + n, writeErr := f.WriteAt(p, off) + w.pool.put(f, writeErr) + + if writeErr != nil { + return n, fmt.Errorf("failed to write at offset %d: %w", off, writeErr) + } + + return n, writeErr +} + +func (w *smbWriterAt) Close() error { + w.closeMu.Lock() + defer w.closeMu.Unlock() + + if w.closed { + return nil + } + w.closed = true + + // Wait for all pending writes to finish + w.wg.Wait() + + var errs []error + + // Drain the pool + if err := w.pool.drain(); err != nil { + errs = append(errs, fmt.Errorf("failed to drain file pool: %w", err)) + } + + // Remove session + w.pool.fs.removeSession() + + if len(errs) > 0 { + return errors.Join(errs...) + } + + return nil +} + // OpenWriterAt opens with a handle for random access writes // // Pass in the remote desired and the size if known. // // It truncates any existing object func (f *Fs) OpenWriterAt(ctx context.Context, remote string, size int64) (fs.WriterAtCloser, error) { - var err error o := &Object{ fs: f, remote: remote, @@ -519,27 +580,42 @@ func (f *Fs) OpenWriterAt(ctx context.Context, remote string, size int64) (fs.Wr return nil, fs.ErrorIsDir } - err = o.fs.ensureDirectory(ctx, share, filename) + err := o.fs.ensureDirectory(ctx, share, filename) if err != nil { return nil, fmt.Errorf("failed to make parent directories: %w", err) } - filename = o.fs.toSambaPath(filename) - - o.fs.addSession() // Show session in use - defer o.fs.removeSession() + smbPath := o.fs.toSambaPath(filename) + // One-time truncate cn, err := o.fs.getConnection(ctx, share) if err != nil { return nil, err } - - fl, err := cn.smbShare.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o644) + file, err := cn.smbShare.OpenFile(smbPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o644) if err != nil { - return nil, fmt.Errorf("failed to open: %w", err) + o.fs.putConnection(&cn, err) + return nil, err } + if size > 0 { + if truncateErr := file.Truncate(size); truncateErr != nil { + _ = file.Close() + o.fs.putConnection(&cn, truncateErr) + return nil, fmt.Errorf("failed to truncate file: %w", truncateErr) + } + } + if closeErr := file.Close(); closeErr != nil { + o.fs.putConnection(&cn, closeErr) + return nil, fmt.Errorf("failed to close file after truncate: %w", closeErr) + } + o.fs.putConnection(&cn, nil) - return fl, nil + // Add a new session + o.fs.addSession() + + return &smbWriterAt{ + pool: newFilePool(ctx, o.fs, share, smbPath), + }, nil } // Shutdown the backend, closing any background tasks and any