diff --git a/CHANGELOG.md b/CHANGELOG.md index 285cd7d8..fb424d85 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ ## [Unreleased] ### Add - Add `status_codes_total` counter to Prometheus metrics. +- Add client-side decryprion support for S3 integration. - (pro) Add the `IMGPROXY_VIDEO_THUMBNAIL_KEYFRAMES` config and the [video_thumbnail_keyframes](https://docs.imgproxy.net/latest/generating_the_url?id=video-thumbnail-keyframes) processing option. - (pro) Add the [video_thumbnail_tile](https://docs.imgproxy.net/latest/generating_the_url?id=video-thumbnail-tile) processing option. diff --git a/config/config.go b/config/config.go index 036808ab..90e7b4e0 100644 --- a/config/config.go +++ b/config/config.go @@ -99,11 +99,12 @@ var ( LocalFileSystemRoot string - S3Enabled bool - S3Region string - S3Endpoint string - S3AssumeRoleArn string - S3MultiRegion bool + S3Enabled bool + S3Region string + S3Endpoint string + S3AssumeRoleArn string + S3MultiRegion bool + S3DecryptionClientEnabled bool GCSEnabled bool GCSKey string @@ -300,6 +301,7 @@ func Reset() { S3Endpoint = "" S3AssumeRoleArn = "" S3MultiRegion = false + S3DecryptionClientEnabled = false GCSEnabled = false GCSKey = "" ABSEnabled = false @@ -501,6 +503,7 @@ func Configure() error { configurators.String(&S3Endpoint, "IMGPROXY_S3_ENDPOINT") configurators.String(&S3AssumeRoleArn, "IMGPROXY_S3_ASSUME_ROLE_ARN") configurators.Bool(&S3MultiRegion, "IMGPROXY_S3_MULTI_REGION") + configurators.Bool(&S3DecryptionClientEnabled, "IMGPROXY_S3_USE_DECRYPTION_CLIENT") configurators.Bool(&GCSEnabled, "IMGPROXY_USE_GCS") configurators.String(&GCSKey, "IMGPROXY_GCS_KEY") diff --git a/transport/s3/s3.go b/transport/s3/s3.go index 60036985..bd375782 100644 --- a/transport/s3/s3.go +++ b/transport/s3/s3.go @@ -4,7 +4,8 @@ import ( "context" "fmt" "io" - http "net/http" + "net/http" + "strconv" "strings" "sync" "time" @@ -14,20 +15,27 @@ import ( "github.com/aws/aws-sdk-go/aws/credentials/stscreds" "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/kms" "github.com/aws/aws-sdk-go/service/s3" + "github.com/aws/aws-sdk-go/service/s3/s3crypto" "github.com/aws/aws-sdk-go/service/s3/s3manager" "github.com/imgproxy/imgproxy/v3/config" defaultTransport "github.com/imgproxy/imgproxy/v3/transport" ) +type s3Client interface { + GetObjectRequest(input *s3.GetObjectInput) (req *request.Request, output *s3.GetObjectOutput) +} + // transport implements RoundTripper for the 's3' protocol. type transport struct { session *session.Session - defaultClient *s3.S3 + defaultClient s3Client + defaultConfig *aws.Config - clientsByRegion map[string]*s3.S3 - clientsByBucket map[string]*s3.S3 + clientsByRegion map[string]s3Client + clientsByBucket map[string]s3Client mu sync.RWMutex } @@ -49,7 +57,7 @@ func New() (http.RoundTripper, error) { sess, err := session.NewSession() if err != nil { - return nil, fmt.Errorf("Can't create S3 session: %s", err) + return nil, fmt.Errorf("can't create S3 session: %s", err) } if len(config.S3Region) != 0 { @@ -64,19 +72,19 @@ func New() (http.RoundTripper, error) { conf.Credentials = stscreds.NewCredentials(sess, config.S3AssumeRoleArn) } - client := s3.New(sess, conf) - - clientRegion := "us-west-1" - if client.Config.Region != nil { - clientRegion = *client.Config.Region + client, err := createClient(sess, conf) + if err != nil { + return nil, fmt.Errorf("can't create S3 client: %s", err) } - return &transport{ - session: sess, - defaultClient: client, + clientRegion := *sess.Config.Region - clientsByRegion: map[string]*s3.S3{clientRegion: client}, - clientsByBucket: make(map[string]*s3.S3), + return &transport{ + session: sess, + defaultClient: client, + defaultConfig: conf, + clientsByRegion: map[string]s3Client{clientRegion: client}, + clientsByBucket: make(map[string]s3Client), }, nil } @@ -113,7 +121,7 @@ func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) { return handleError(req, err) } - s3req, _ := client.GetObjectRequest(input) + s3req, objectOutput := client.GetObjectRequest(input) s3req.SetContext(req.Context()) if err := s3req.Send(); err != nil { @@ -124,15 +132,27 @@ func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) { return handleError(req, err) } + if config.S3DecryptionClientEnabled { + s3req.HTTPResponse.Body = objectOutput.Body + + if unencryptedContentLength := s3req.HTTPResponse.Header.Get("X-Amz-Meta-X-Amz-Unencrypted-Content-Length"); len(unencryptedContentLength) != 0 { + contentLength, err := strconv.ParseInt(unencryptedContentLength, 10, 64) + if err != nil { + handleError(req, err) + } + s3req.HTTPResponse.ContentLength = contentLength + } + } + return s3req.HTTPResponse, nil } -func (t *transport) getClient(ctx context.Context, bucket string) (*s3.S3, error) { +func (t *transport) getClient(ctx context.Context, bucket string) (s3Client, error) { if !config.S3MultiRegion { return t.defaultClient, nil } - var client *s3.S3 + var client s3Client func() { t.mu.RLock() @@ -152,7 +172,7 @@ func (t *transport) getClient(ctx context.Context, bucket string) (*s3.S3, error return client, nil } - region, err := s3manager.GetBucketRegionWithClient(ctx, t.defaultClient, bucket) + region, err := s3manager.GetBucketRegion(ctx, t.session, bucket, *t.session.Config.Region) if err != nil { return nil, err } @@ -162,10 +182,13 @@ func (t *transport) getClient(ctx context.Context, bucket string) (*s3.S3, error return client, nil } - conf := t.defaultClient.Config.Copy() + conf := t.defaultConfig.Copy() conf.Region = aws.String(region) - client = s3.New(t.session, conf) + client, err = createClient(t.session, conf) + if err != nil { + return nil, fmt.Errorf("can't create regional S3 client: %s", err) + } t.clientsByRegion[region] = client t.clientsByBucket[bucket] = client @@ -173,6 +196,40 @@ func (t *transport) getClient(ctx context.Context, bucket string) (*s3.S3, error return client, nil } +func createClient(sess *session.Session, conf *aws.Config) (s3Client, error) { + if config.S3DecryptionClientEnabled { + // `s3crypto.NewDecryptionClientV2` doesn't accept additional configs, so we + // need to copy the session with an additional config + sess = sess.Copy(conf) + + cryptoRegistry, err := createCryptoRegistry(sess) + if err != nil { + return nil, err + } + + return s3crypto.NewDecryptionClientV2(sess, cryptoRegistry) + } else { + return s3.New(sess, conf), nil + } +} + +func createCryptoRegistry(sess *session.Session) (*s3crypto.CryptoRegistry, error) { + kmsClient := kms.New(sess) + + cr := s3crypto.NewCryptoRegistry() + if err := s3crypto.RegisterKMSWrapWithAnyCMK(cr, kmsClient); err != nil { + return nil, err + } + if err := s3crypto.RegisterKMSContextWrapWithAnyCMK(cr, kmsClient); err != nil { + return nil, err + } + if err := s3crypto.RegisterAESGCMContentCipher(cr); err != nil { + return nil, err + } + + return cr, nil +} + func handleError(req *http.Request, err error) (*http.Response, error) { if s3err, ok := err.(awserr.Error); ok && s3err.Code() == request.CanceledErrorCode { if e := s3err.OrigErr(); e != nil { diff --git a/transport/s3/s3_test.go b/transport/s3/s3_test.go index 03259e59..56d5d8ea 100644 --- a/transport/s3/s3_test.go +++ b/transport/s3/s3_test.go @@ -50,15 +50,18 @@ func (s *S3TestSuite) SetupSuite() { svc, err := s.transport.(*transport).getClient(context.Background(), "test") require.Nil(s.T(), err) require.NotNil(s.T(), svc) + require.IsType(s.T(), &s3.S3{}, svc) - _, err = svc.PutObject(&s3.PutObjectInput{ + client := svc.(*s3.S3) + + _, err = client.PutObject(&s3.PutObjectInput{ Body: bytes.NewReader(make([]byte, 32)), Bucket: aws.String("test"), Key: aws.String("foo/test.png"), }) require.Nil(s.T(), err) - obj, err := svc.GetObject(&s3.GetObjectInput{ + obj, err := client.GetObject(&s3.GetObjectInput{ Bucket: aws.String("test"), Key: aws.String("foo/test.png"), })