1
0
mirror of https://github.com/imgproxy/imgproxy.git synced 2025-02-12 11:46:10 +02:00

Add client-side decryption support to S3

This commit is contained in:
Garen J. Torikian 2023-10-17 16:03:34 -05:00 committed by DarthSim
parent 8b5255da7d
commit b384e2bb7f
4 changed files with 92 additions and 28 deletions

View File

@ -3,6 +3,7 @@
## [Unreleased] ## [Unreleased]
### Add ### Add
- Add `status_codes_total` counter to Prometheus metrics. - Add `status_codes_total` counter to Prometheus metrics.
- Add client-side decryprion support for S3 integration.
- (pro) Add the `IMGPROXY_VIDEO_THUMBNAIL_KEYFRAMES` config and the [video_thumbnail_keyframes](https://docs.imgproxy.net/latest/generating_the_url?id=video-thumbnail-keyframes) processing option. - (pro) Add the `IMGPROXY_VIDEO_THUMBNAIL_KEYFRAMES` config and the [video_thumbnail_keyframes](https://docs.imgproxy.net/latest/generating_the_url?id=video-thumbnail-keyframes) processing option.
- (pro) Add the [video_thumbnail_tile](https://docs.imgproxy.net/latest/generating_the_url?id=video-thumbnail-tile) processing option. - (pro) Add the [video_thumbnail_tile](https://docs.imgproxy.net/latest/generating_the_url?id=video-thumbnail-tile) processing option.

View File

@ -99,11 +99,12 @@ var (
LocalFileSystemRoot string LocalFileSystemRoot string
S3Enabled bool S3Enabled bool
S3Region string S3Region string
S3Endpoint string S3Endpoint string
S3AssumeRoleArn string S3AssumeRoleArn string
S3MultiRegion bool S3MultiRegion bool
S3DecryptionClientEnabled bool
GCSEnabled bool GCSEnabled bool
GCSKey string GCSKey string
@ -300,6 +301,7 @@ func Reset() {
S3Endpoint = "" S3Endpoint = ""
S3AssumeRoleArn = "" S3AssumeRoleArn = ""
S3MultiRegion = false S3MultiRegion = false
S3DecryptionClientEnabled = false
GCSEnabled = false GCSEnabled = false
GCSKey = "" GCSKey = ""
ABSEnabled = false ABSEnabled = false
@ -501,6 +503,7 @@ func Configure() error {
configurators.String(&S3Endpoint, "IMGPROXY_S3_ENDPOINT") configurators.String(&S3Endpoint, "IMGPROXY_S3_ENDPOINT")
configurators.String(&S3AssumeRoleArn, "IMGPROXY_S3_ASSUME_ROLE_ARN") configurators.String(&S3AssumeRoleArn, "IMGPROXY_S3_ASSUME_ROLE_ARN")
configurators.Bool(&S3MultiRegion, "IMGPROXY_S3_MULTI_REGION") configurators.Bool(&S3MultiRegion, "IMGPROXY_S3_MULTI_REGION")
configurators.Bool(&S3DecryptionClientEnabled, "IMGPROXY_S3_USE_DECRYPTION_CLIENT")
configurators.Bool(&GCSEnabled, "IMGPROXY_USE_GCS") configurators.Bool(&GCSEnabled, "IMGPROXY_USE_GCS")
configurators.String(&GCSKey, "IMGPROXY_GCS_KEY") configurators.String(&GCSKey, "IMGPROXY_GCS_KEY")

View File

@ -4,7 +4,8 @@ import (
"context" "context"
"fmt" "fmt"
"io" "io"
http "net/http" "net/http"
"strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -14,20 +15,27 @@ import (
"github.com/aws/aws-sdk-go/aws/credentials/stscreds" "github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/kms"
"github.com/aws/aws-sdk-go/service/s3" "github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3crypto"
"github.com/aws/aws-sdk-go/service/s3/s3manager" "github.com/aws/aws-sdk-go/service/s3/s3manager"
"github.com/imgproxy/imgproxy/v3/config" "github.com/imgproxy/imgproxy/v3/config"
defaultTransport "github.com/imgproxy/imgproxy/v3/transport" defaultTransport "github.com/imgproxy/imgproxy/v3/transport"
) )
type s3Client interface {
GetObjectRequest(input *s3.GetObjectInput) (req *request.Request, output *s3.GetObjectOutput)
}
// transport implements RoundTripper for the 's3' protocol. // transport implements RoundTripper for the 's3' protocol.
type transport struct { type transport struct {
session *session.Session session *session.Session
defaultClient *s3.S3 defaultClient s3Client
defaultConfig *aws.Config
clientsByRegion map[string]*s3.S3 clientsByRegion map[string]s3Client
clientsByBucket map[string]*s3.S3 clientsByBucket map[string]s3Client
mu sync.RWMutex mu sync.RWMutex
} }
@ -49,7 +57,7 @@ func New() (http.RoundTripper, error) {
sess, err := session.NewSession() sess, err := session.NewSession()
if err != nil { if err != nil {
return nil, fmt.Errorf("Can't create S3 session: %s", err) return nil, fmt.Errorf("can't create S3 session: %s", err)
} }
if len(config.S3Region) != 0 { if len(config.S3Region) != 0 {
@ -64,19 +72,19 @@ func New() (http.RoundTripper, error) {
conf.Credentials = stscreds.NewCredentials(sess, config.S3AssumeRoleArn) conf.Credentials = stscreds.NewCredentials(sess, config.S3AssumeRoleArn)
} }
client := s3.New(sess, conf) client, err := createClient(sess, conf)
if err != nil {
clientRegion := "us-west-1" return nil, fmt.Errorf("can't create S3 client: %s", err)
if client.Config.Region != nil {
clientRegion = *client.Config.Region
} }
return &transport{ clientRegion := *sess.Config.Region
session: sess,
defaultClient: client,
clientsByRegion: map[string]*s3.S3{clientRegion: client}, return &transport{
clientsByBucket: make(map[string]*s3.S3), session: sess,
defaultClient: client,
defaultConfig: conf,
clientsByRegion: map[string]s3Client{clientRegion: client},
clientsByBucket: make(map[string]s3Client),
}, nil }, nil
} }
@ -113,7 +121,7 @@ func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) {
return handleError(req, err) return handleError(req, err)
} }
s3req, _ := client.GetObjectRequest(input) s3req, objectOutput := client.GetObjectRequest(input)
s3req.SetContext(req.Context()) s3req.SetContext(req.Context())
if err := s3req.Send(); err != nil { if err := s3req.Send(); err != nil {
@ -124,15 +132,27 @@ func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) {
return handleError(req, err) return handleError(req, err)
} }
if config.S3DecryptionClientEnabled {
s3req.HTTPResponse.Body = objectOutput.Body
if unencryptedContentLength := s3req.HTTPResponse.Header.Get("X-Amz-Meta-X-Amz-Unencrypted-Content-Length"); len(unencryptedContentLength) != 0 {
contentLength, err := strconv.ParseInt(unencryptedContentLength, 10, 64)
if err != nil {
handleError(req, err)
}
s3req.HTTPResponse.ContentLength = contentLength
}
}
return s3req.HTTPResponse, nil return s3req.HTTPResponse, nil
} }
func (t *transport) getClient(ctx context.Context, bucket string) (*s3.S3, error) { func (t *transport) getClient(ctx context.Context, bucket string) (s3Client, error) {
if !config.S3MultiRegion { if !config.S3MultiRegion {
return t.defaultClient, nil return t.defaultClient, nil
} }
var client *s3.S3 var client s3Client
func() { func() {
t.mu.RLock() t.mu.RLock()
@ -152,7 +172,7 @@ func (t *transport) getClient(ctx context.Context, bucket string) (*s3.S3, error
return client, nil return client, nil
} }
region, err := s3manager.GetBucketRegionWithClient(ctx, t.defaultClient, bucket) region, err := s3manager.GetBucketRegion(ctx, t.session, bucket, *t.session.Config.Region)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -162,10 +182,13 @@ func (t *transport) getClient(ctx context.Context, bucket string) (*s3.S3, error
return client, nil return client, nil
} }
conf := t.defaultClient.Config.Copy() conf := t.defaultConfig.Copy()
conf.Region = aws.String(region) conf.Region = aws.String(region)
client = s3.New(t.session, conf) client, err = createClient(t.session, conf)
if err != nil {
return nil, fmt.Errorf("can't create regional S3 client: %s", err)
}
t.clientsByRegion[region] = client t.clientsByRegion[region] = client
t.clientsByBucket[bucket] = client t.clientsByBucket[bucket] = client
@ -173,6 +196,40 @@ func (t *transport) getClient(ctx context.Context, bucket string) (*s3.S3, error
return client, nil return client, nil
} }
func createClient(sess *session.Session, conf *aws.Config) (s3Client, error) {
if config.S3DecryptionClientEnabled {
// `s3crypto.NewDecryptionClientV2` doesn't accept additional configs, so we
// need to copy the session with an additional config
sess = sess.Copy(conf)
cryptoRegistry, err := createCryptoRegistry(sess)
if err != nil {
return nil, err
}
return s3crypto.NewDecryptionClientV2(sess, cryptoRegistry)
} else {
return s3.New(sess, conf), nil
}
}
func createCryptoRegistry(sess *session.Session) (*s3crypto.CryptoRegistry, error) {
kmsClient := kms.New(sess)
cr := s3crypto.NewCryptoRegistry()
if err := s3crypto.RegisterKMSWrapWithAnyCMK(cr, kmsClient); err != nil {
return nil, err
}
if err := s3crypto.RegisterKMSContextWrapWithAnyCMK(cr, kmsClient); err != nil {
return nil, err
}
if err := s3crypto.RegisterAESGCMContentCipher(cr); err != nil {
return nil, err
}
return cr, nil
}
func handleError(req *http.Request, err error) (*http.Response, error) { func handleError(req *http.Request, err error) (*http.Response, error) {
if s3err, ok := err.(awserr.Error); ok && s3err.Code() == request.CanceledErrorCode { if s3err, ok := err.(awserr.Error); ok && s3err.Code() == request.CanceledErrorCode {
if e := s3err.OrigErr(); e != nil { if e := s3err.OrigErr(); e != nil {

View File

@ -50,15 +50,18 @@ func (s *S3TestSuite) SetupSuite() {
svc, err := s.transport.(*transport).getClient(context.Background(), "test") svc, err := s.transport.(*transport).getClient(context.Background(), "test")
require.Nil(s.T(), err) require.Nil(s.T(), err)
require.NotNil(s.T(), svc) require.NotNil(s.T(), svc)
require.IsType(s.T(), &s3.S3{}, svc)
_, err = svc.PutObject(&s3.PutObjectInput{ client := svc.(*s3.S3)
_, err = client.PutObject(&s3.PutObjectInput{
Body: bytes.NewReader(make([]byte, 32)), Body: bytes.NewReader(make([]byte, 32)),
Bucket: aws.String("test"), Bucket: aws.String("test"),
Key: aws.String("foo/test.png"), Key: aws.String("foo/test.png"),
}) })
require.Nil(s.T(), err) require.Nil(s.T(), err)
obj, err := svc.GetObject(&s3.GetObjectInput{ obj, err := client.GetObject(&s3.GetObjectInput{
Bucket: aws.String("test"), Bucket: aws.String("test"),
Key: aws.String("foo/test.png"), Key: aws.String("foo/test.png"),
}) })