diff --git a/fs/buffer.go b/fs/buffer.go index 436b42383..588a36d00 100644 --- a/fs/buffer.go +++ b/fs/buffer.go @@ -16,6 +16,8 @@ var asyncBufferPool = sync.Pool{ New: func() interface{} { return newBuffer() }, } +var errorStreamAbandoned = errors.New("stream abandoned") + // asyncReader will do async read-ahead from the input reader // and make the data available as an io.Reader. // This should be fully transparent, except that once an error @@ -31,6 +33,7 @@ type asyncReader struct { exited chan struct{} // Channel is closed been the async reader shuts down size int // size of buffer to use closed bool // whether we have closed the underlying stream + mu sync.Mutex // lock for Read/WriteTo/Abandon/Close } // newAsyncReader returns a reader that will asynchronously read from @@ -39,7 +42,7 @@ type asyncReader struct { // function has returned. // The input can be read from the returned reader. // When done use Close to release the buffers and close the supplied input. -func newAsyncReader(rd io.ReadCloser, buffers int) (io.ReadCloser, error) { +func newAsyncReader(rd io.ReadCloser, buffers int) (*asyncReader, error) { if buffers <= 0 { return nil, errors.New("number of buffers too small") } @@ -113,6 +116,10 @@ func (a *asyncReader) fill() (err error) { } b, ok := <-a.ready if !ok { + // Return an error to show fill failed + if a.err == nil { + return errorStreamAbandoned + } return a.err } a.cur = b @@ -122,6 +129,9 @@ func (a *asyncReader) fill() (err error) { // Read will return the next available data. func (a *asyncReader) Read(p []byte) (n int, err error) { + a.mu.Lock() + defer a.mu.Unlock() + // Swap buffer and maybe return error err = a.fill() if err != nil { @@ -144,6 +154,9 @@ func (a *asyncReader) Read(p []byte) (n int, err error) { // The return value n is the number of bytes written. // Any error encountered during the write is also returned. func (a *asyncReader) WriteTo(w io.Writer) (n int64, err error) { + a.mu.Lock() + defer a.mu.Unlock() + n = 0 for { err = a.fill() @@ -175,6 +188,9 @@ func (a *asyncReader) Abandon() { // Close and wait for go routine close(a.exit) <-a.exited + // take the lock to wait for Read/WriteTo to complete + a.mu.Lock() + defer a.mu.Unlock() // Return any outstanding buffers to the Pool if a.cur != nil { a.putBuffer(a.cur) diff --git a/fs/buffer_test.go b/fs/buffer_test.go index 278f57678..83782b086 100644 --- a/fs/buffer_test.go +++ b/fs/buffer_test.go @@ -6,8 +6,10 @@ import ( "io" "io/ioutil" "strings" + "sync" "testing" "testing/iotest" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -196,8 +198,7 @@ func TestAsyncReaderWriteTo(t *testing.T) { buf := bufio.NewReaderSize(read, bufsize) ar, _ := newAsyncReader(ioutil.NopCloser(buf), l) dst := &bytes.Buffer{} - wt := ar.(io.WriterTo) - _, err := wt.WriteTo(dst) + _, err := ar.WriteTo(dst) if err != nil && err != io.EOF && err != iotest.ErrTimeout { t.Fatal("Copy:", err) } @@ -215,3 +216,65 @@ func TestAsyncReaderWriteTo(t *testing.T) { } } } + +// Read an infinite number of zeros +type zeroReader struct { + closed bool +} + +func (z *zeroReader) Read(p []byte) (n int, err error) { + if z.closed { + return 0, io.EOF + } + for i := range p { + p[i] = 0 + } + return len(p), nil +} + +func (z *zeroReader) Close() error { + if z.closed { + panic("double close on zeroReader") + } + z.closed = true + return nil +} + +// Test closing and abandoning +func testAsyncReaderClose(t *testing.T, writeto bool) { + zr := &zeroReader{} + a, err := newAsyncReader(zr, 16) + require.NoError(t, err) + var copyN int64 + var copyErr error + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + if true { + // exercise the WriteTo path + copyN, copyErr = a.WriteTo(ioutil.Discard) + } else { + // exercise the Read path + buf := make([]byte, 64*1024) + for { + var n int + n, copyErr = a.Read(buf) + copyN += int64(n) + if copyErr != nil { + break + } + } + } + }() + // Do some copying + time.Sleep(100 * time.Millisecond) + // Abandon the copy + a.Abandon() + wg.Wait() + assert.Equal(t, errorStreamAbandoned, copyErr) + // t.Logf("Copied %d bytes, err %v", copyN, copyErr) + assert.True(t, copyN > 0) +} +func TestAsyncReaderCloseRead(t *testing.T) { testAsyncReaderClose(t, false) } +func TestAsyncReaderCloseWriteTo(t *testing.T) { testAsyncReaderClose(t, true) }