1
0
mirror of https://github.com/imgproxy/imgproxy.git synced 2024-11-24 08:12:38 +02:00

Add client-side decryption support to S3

This commit is contained in:
Garen J. Torikian 2023-10-17 16:03:34 -05:00 committed by DarthSim
parent 8b5255da7d
commit b384e2bb7f
4 changed files with 92 additions and 28 deletions

View File

@ -3,6 +3,7 @@
## [Unreleased]
### Add
- Add `status_codes_total` counter to Prometheus metrics.
- Add client-side decryprion support for S3 integration.
- (pro) Add the `IMGPROXY_VIDEO_THUMBNAIL_KEYFRAMES` config and the [video_thumbnail_keyframes](https://docs.imgproxy.net/latest/generating_the_url?id=video-thumbnail-keyframes) processing option.
- (pro) Add the [video_thumbnail_tile](https://docs.imgproxy.net/latest/generating_the_url?id=video-thumbnail-tile) processing option.

View File

@ -104,6 +104,7 @@ var (
S3Endpoint string
S3AssumeRoleArn string
S3MultiRegion bool
S3DecryptionClientEnabled bool
GCSEnabled bool
GCSKey string
@ -300,6 +301,7 @@ func Reset() {
S3Endpoint = ""
S3AssumeRoleArn = ""
S3MultiRegion = false
S3DecryptionClientEnabled = false
GCSEnabled = false
GCSKey = ""
ABSEnabled = false
@ -501,6 +503,7 @@ func Configure() error {
configurators.String(&S3Endpoint, "IMGPROXY_S3_ENDPOINT")
configurators.String(&S3AssumeRoleArn, "IMGPROXY_S3_ASSUME_ROLE_ARN")
configurators.Bool(&S3MultiRegion, "IMGPROXY_S3_MULTI_REGION")
configurators.Bool(&S3DecryptionClientEnabled, "IMGPROXY_S3_USE_DECRYPTION_CLIENT")
configurators.Bool(&GCSEnabled, "IMGPROXY_USE_GCS")
configurators.String(&GCSKey, "IMGPROXY_GCS_KEY")

View File

@ -4,7 +4,8 @@ import (
"context"
"fmt"
"io"
http "net/http"
"net/http"
"strconv"
"strings"
"sync"
"time"
@ -14,20 +15,27 @@ import (
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/kms"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3crypto"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
"github.com/imgproxy/imgproxy/v3/config"
defaultTransport "github.com/imgproxy/imgproxy/v3/transport"
)
type s3Client interface {
GetObjectRequest(input *s3.GetObjectInput) (req *request.Request, output *s3.GetObjectOutput)
}
// transport implements RoundTripper for the 's3' protocol.
type transport struct {
session *session.Session
defaultClient *s3.S3
defaultClient s3Client
defaultConfig *aws.Config
clientsByRegion map[string]*s3.S3
clientsByBucket map[string]*s3.S3
clientsByRegion map[string]s3Client
clientsByBucket map[string]s3Client
mu sync.RWMutex
}
@ -49,7 +57,7 @@ func New() (http.RoundTripper, error) {
sess, err := session.NewSession()
if err != nil {
return nil, fmt.Errorf("Can't create S3 session: %s", err)
return nil, fmt.Errorf("can't create S3 session: %s", err)
}
if len(config.S3Region) != 0 {
@ -64,19 +72,19 @@ func New() (http.RoundTripper, error) {
conf.Credentials = stscreds.NewCredentials(sess, config.S3AssumeRoleArn)
}
client := s3.New(sess, conf)
clientRegion := "us-west-1"
if client.Config.Region != nil {
clientRegion = *client.Config.Region
client, err := createClient(sess, conf)
if err != nil {
return nil, fmt.Errorf("can't create S3 client: %s", err)
}
clientRegion := *sess.Config.Region
return &transport{
session: sess,
defaultClient: client,
clientsByRegion: map[string]*s3.S3{clientRegion: client},
clientsByBucket: make(map[string]*s3.S3),
defaultConfig: conf,
clientsByRegion: map[string]s3Client{clientRegion: client},
clientsByBucket: make(map[string]s3Client),
}, nil
}
@ -113,7 +121,7 @@ func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) {
return handleError(req, err)
}
s3req, _ := client.GetObjectRequest(input)
s3req, objectOutput := client.GetObjectRequest(input)
s3req.SetContext(req.Context())
if err := s3req.Send(); err != nil {
@ -124,15 +132,27 @@ func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) {
return handleError(req, err)
}
if config.S3DecryptionClientEnabled {
s3req.HTTPResponse.Body = objectOutput.Body
if unencryptedContentLength := s3req.HTTPResponse.Header.Get("X-Amz-Meta-X-Amz-Unencrypted-Content-Length"); len(unencryptedContentLength) != 0 {
contentLength, err := strconv.ParseInt(unencryptedContentLength, 10, 64)
if err != nil {
handleError(req, err)
}
s3req.HTTPResponse.ContentLength = contentLength
}
}
return s3req.HTTPResponse, nil
}
func (t *transport) getClient(ctx context.Context, bucket string) (*s3.S3, error) {
func (t *transport) getClient(ctx context.Context, bucket string) (s3Client, error) {
if !config.S3MultiRegion {
return t.defaultClient, nil
}
var client *s3.S3
var client s3Client
func() {
t.mu.RLock()
@ -152,7 +172,7 @@ func (t *transport) getClient(ctx context.Context, bucket string) (*s3.S3, error
return client, nil
}
region, err := s3manager.GetBucketRegionWithClient(ctx, t.defaultClient, bucket)
region, err := s3manager.GetBucketRegion(ctx, t.session, bucket, *t.session.Config.Region)
if err != nil {
return nil, err
}
@ -162,10 +182,13 @@ func (t *transport) getClient(ctx context.Context, bucket string) (*s3.S3, error
return client, nil
}
conf := t.defaultClient.Config.Copy()
conf := t.defaultConfig.Copy()
conf.Region = aws.String(region)
client = s3.New(t.session, conf)
client, err = createClient(t.session, conf)
if err != nil {
return nil, fmt.Errorf("can't create regional S3 client: %s", err)
}
t.clientsByRegion[region] = client
t.clientsByBucket[bucket] = client
@ -173,6 +196,40 @@ func (t *transport) getClient(ctx context.Context, bucket string) (*s3.S3, error
return client, nil
}
func createClient(sess *session.Session, conf *aws.Config) (s3Client, error) {
if config.S3DecryptionClientEnabled {
// `s3crypto.NewDecryptionClientV2` doesn't accept additional configs, so we
// need to copy the session with an additional config
sess = sess.Copy(conf)
cryptoRegistry, err := createCryptoRegistry(sess)
if err != nil {
return nil, err
}
return s3crypto.NewDecryptionClientV2(sess, cryptoRegistry)
} else {
return s3.New(sess, conf), nil
}
}
func createCryptoRegistry(sess *session.Session) (*s3crypto.CryptoRegistry, error) {
kmsClient := kms.New(sess)
cr := s3crypto.NewCryptoRegistry()
if err := s3crypto.RegisterKMSWrapWithAnyCMK(cr, kmsClient); err != nil {
return nil, err
}
if err := s3crypto.RegisterKMSContextWrapWithAnyCMK(cr, kmsClient); err != nil {
return nil, err
}
if err := s3crypto.RegisterAESGCMContentCipher(cr); err != nil {
return nil, err
}
return cr, nil
}
func handleError(req *http.Request, err error) (*http.Response, error) {
if s3err, ok := err.(awserr.Error); ok && s3err.Code() == request.CanceledErrorCode {
if e := s3err.OrigErr(); e != nil {

View File

@ -50,15 +50,18 @@ func (s *S3TestSuite) SetupSuite() {
svc, err := s.transport.(*transport).getClient(context.Background(), "test")
require.Nil(s.T(), err)
require.NotNil(s.T(), svc)
require.IsType(s.T(), &s3.S3{}, svc)
_, err = svc.PutObject(&s3.PutObjectInput{
client := svc.(*s3.S3)
_, err = client.PutObject(&s3.PutObjectInput{
Body: bytes.NewReader(make([]byte, 32)),
Bucket: aws.String("test"),
Key: aws.String("foo/test.png"),
})
require.Nil(s.T(), err)
obj, err := svc.GetObject(&s3.GetObjectInput{
obj, err := client.GetObject(&s3.GetObjectInput{
Bucket: aws.String("test"),
Key: aws.String("foo/test.png"),
})