From 14038868402f03fd96dd54d78ced88e8597df76f Mon Sep 17 00:00:00 2001 From: DarthSim Date: Sat, 22 Oct 2022 17:55:06 +0600 Subject: [PATCH] Wrap custom transport bodies with a context reader --- ctxreader/ctxreader.go | 44 ++++++++++++++ ctxreader/ctxreader_test.go | 114 ++++++++++++++++++++++++++++++++++++ transport/fs/fs.go | 3 +- transport/gcs/gcs.go | 3 +- transport/swift/swift.go | 3 +- 5 files changed, 164 insertions(+), 3 deletions(-) create mode 100644 ctxreader/ctxreader.go create mode 100644 ctxreader/ctxreader_test.go diff --git a/ctxreader/ctxreader.go b/ctxreader/ctxreader.go new file mode 100644 index 00000000..ed394776 --- /dev/null +++ b/ctxreader/ctxreader.go @@ -0,0 +1,44 @@ +package ctxreader + +import ( + "context" + "io" + "sync" + "sync/atomic" +) + +type ctxReader struct { + r io.ReadCloser + err atomic.Value + closeOnce sync.Once +} + +func (r *ctxReader) Read(p []byte) (int, error) { + if err := r.err.Load(); err != nil { + return 0, err.(error) + } + return r.r.Read(p) +} + +func (r *ctxReader) Close() (err error) { + r.closeOnce.Do(func() { err = r.r.Close() }) + return +} + +func New(ctx context.Context, r io.ReadCloser, closeOnDone bool) io.ReadCloser { + if ctx.Done() == nil { + return r + } + + ctxr := ctxReader{r: r} + + go func(ctx context.Context) { + <-ctx.Done() + ctxr.err.Store(ctx.Err()) + if closeOnDone { + ctxr.closeOnce.Do(func() { ctxr.r.Close() }) + } + }(ctx) + + return &ctxr +} diff --git a/ctxreader/ctxreader_test.go b/ctxreader/ctxreader_test.go new file mode 100644 index 00000000..60c3859c --- /dev/null +++ b/ctxreader/ctxreader_test.go @@ -0,0 +1,114 @@ +package ctxreader + +import ( + "context" + "crypto/rand" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +type testReader struct { + closed bool +} + +func (r *testReader) Read(p []byte) (int, error) { + return rand.Reader.Read(p) +} + +func (r *testReader) Close() error { + r.closed = true + return nil +} + +type CtxReaderTestSuite struct { + suite.Suite +} + +func (s *CtxReaderTestSuite) TestReadUntilCanceled() { + ctx, cancel := context.WithCancel(context.Background()) + + r := New(ctx, &testReader{}, false) + p := make([]byte, 1024) + + _, err := r.Read(p) + require.Nil(s.T(), err) + + cancel() + time.Sleep(time.Second) + + _, err = r.Read(p) + require.Equal(s.T(), err, context.Canceled) +} + +func (s *CtxReaderTestSuite) TestReturnOriginalOnBackgroundContext() { + rr := &testReader{} + r := New(context.Background(), rr, false) + + require.Equal(s.T(), rr, r) +} + +func (s *CtxReaderTestSuite) TestClose() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + rr := &testReader{} + New(ctx, rr, true).Close() + + require.True(s.T(), rr.closed) +} + +func (s *CtxReaderTestSuite) TestCloseOnCancel() { + ctx, cancel := context.WithCancel(context.Background()) + + rr := &testReader{} + New(ctx, rr, true) + + cancel() + time.Sleep(time.Second) + + require.True(s.T(), rr.closed) +} + +func (s *CtxReaderTestSuite) TestDontCloseOnCancel() { + ctx, cancel := context.WithCancel(context.Background()) + + rr := &testReader{} + New(ctx, rr, false) + + cancel() + time.Sleep(time.Second) + + require.False(s.T(), rr.closed) +} + +func TestCtxReader(t *testing.T) { + suite.Run(t, new(CtxReaderTestSuite)) +} + +func BenchmarkRawReader(b *testing.B) { + r := testReader{} + + b.ResetTimer() + + p := make([]byte, 1024) + for i := 0; i < b.N; i++ { + r.Read(p) + } +} + +func BenchmarkCtxReader(b *testing.B) { + ctx, cancel := context.WithTimeout(context.Background(), time.Hour) + defer cancel() + + r := New(ctx, &testReader{}, true) + + b.ResetTimer() + + p := make([]byte, 1024) + for i := 0; i < b.N; i++ { + r.Read(p) + } +} diff --git a/transport/fs/fs.go b/transport/fs/fs.go index e8a33107..1797ba7c 100644 --- a/transport/fs/fs.go +++ b/transport/fs/fs.go @@ -14,6 +14,7 @@ import ( "strings" "github.com/imgproxy/imgproxy/v3/config" + "github.com/imgproxy/imgproxy/v3/ctxreader" "github.com/imgproxy/imgproxy/v3/httprange" ) @@ -103,7 +104,7 @@ func (t transport) RoundTrip(req *http.Request) (resp *http.Response, err error) ProtoMinor: 0, Header: header, ContentLength: size, - Body: body, + Body: ctxreader.New(req.Context(), body, true), Close: true, Request: req, }, nil diff --git a/transport/gcs/gcs.go b/transport/gcs/gcs.go index d3ecee74..e3f1f31d 100644 --- a/transport/gcs/gcs.go +++ b/transport/gcs/gcs.go @@ -12,6 +12,7 @@ import ( "google.golang.org/api/option" "github.com/imgproxy/imgproxy/v3/config" + "github.com/imgproxy/imgproxy/v3/ctxreader" "github.com/imgproxy/imgproxy/v3/httprange" ) @@ -141,7 +142,7 @@ func (t transport) RoundTrip(req *http.Request) (*http.Response, error) { ProtoMinor: 0, Header: header, ContentLength: reader.Attrs.Size, - Body: reader, + Body: ctxreader.New(req.Context(), reader, true), Close: true, Request: req, }, nil diff --git a/transport/swift/swift.go b/transport/swift/swift.go index 3fc289b4..79907b41 100644 --- a/transport/swift/swift.go +++ b/transport/swift/swift.go @@ -12,6 +12,7 @@ import ( "github.com/ncw/swift/v2" "github.com/imgproxy/imgproxy/v3/config" + "github.com/imgproxy/imgproxy/v3/ctxreader" ) type transport struct { @@ -105,7 +106,7 @@ func (t transport) RoundTrip(req *http.Request) (resp *http.Response, err error) ProtoMajor: 1, ProtoMinor: 0, Header: header, - Body: object, + Body: ctxreader.New(req.Context(), object, true), Close: true, Request: req, }, nil