mirror of
https://github.com/imgproxy/imgproxy.git
synced 2024-11-24 08:12:38 +02:00
Add client-side decryption support to S3
This commit is contained in:
parent
8b5255da7d
commit
b384e2bb7f
@ -3,6 +3,7 @@
|
||||
## [Unreleased]
|
||||
### Add
|
||||
- Add `status_codes_total` counter to Prometheus metrics.
|
||||
- Add client-side decryprion support for S3 integration.
|
||||
- (pro) Add the `IMGPROXY_VIDEO_THUMBNAIL_KEYFRAMES` config and the [video_thumbnail_keyframes](https://docs.imgproxy.net/latest/generating_the_url?id=video-thumbnail-keyframes) processing option.
|
||||
- (pro) Add the [video_thumbnail_tile](https://docs.imgproxy.net/latest/generating_the_url?id=video-thumbnail-tile) processing option.
|
||||
|
||||
|
@ -99,11 +99,12 @@ var (
|
||||
|
||||
LocalFileSystemRoot string
|
||||
|
||||
S3Enabled bool
|
||||
S3Region string
|
||||
S3Endpoint string
|
||||
S3AssumeRoleArn string
|
||||
S3MultiRegion bool
|
||||
S3Enabled bool
|
||||
S3Region string
|
||||
S3Endpoint string
|
||||
S3AssumeRoleArn string
|
||||
S3MultiRegion bool
|
||||
S3DecryptionClientEnabled bool
|
||||
|
||||
GCSEnabled bool
|
||||
GCSKey string
|
||||
@ -300,6 +301,7 @@ func Reset() {
|
||||
S3Endpoint = ""
|
||||
S3AssumeRoleArn = ""
|
||||
S3MultiRegion = false
|
||||
S3DecryptionClientEnabled = false
|
||||
GCSEnabled = false
|
||||
GCSKey = ""
|
||||
ABSEnabled = false
|
||||
@ -501,6 +503,7 @@ func Configure() error {
|
||||
configurators.String(&S3Endpoint, "IMGPROXY_S3_ENDPOINT")
|
||||
configurators.String(&S3AssumeRoleArn, "IMGPROXY_S3_ASSUME_ROLE_ARN")
|
||||
configurators.Bool(&S3MultiRegion, "IMGPROXY_S3_MULTI_REGION")
|
||||
configurators.Bool(&S3DecryptionClientEnabled, "IMGPROXY_S3_USE_DECRYPTION_CLIENT")
|
||||
|
||||
configurators.Bool(&GCSEnabled, "IMGPROXY_USE_GCS")
|
||||
configurators.String(&GCSKey, "IMGPROXY_GCS_KEY")
|
||||
|
@ -4,7 +4,8 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
http "net/http"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@ -14,20 +15,27 @@ import (
|
||||
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
"github.com/aws/aws-sdk-go/aws/session"
|
||||
"github.com/aws/aws-sdk-go/service/kms"
|
||||
"github.com/aws/aws-sdk-go/service/s3"
|
||||
"github.com/aws/aws-sdk-go/service/s3/s3crypto"
|
||||
"github.com/aws/aws-sdk-go/service/s3/s3manager"
|
||||
|
||||
"github.com/imgproxy/imgproxy/v3/config"
|
||||
defaultTransport "github.com/imgproxy/imgproxy/v3/transport"
|
||||
)
|
||||
|
||||
type s3Client interface {
|
||||
GetObjectRequest(input *s3.GetObjectInput) (req *request.Request, output *s3.GetObjectOutput)
|
||||
}
|
||||
|
||||
// transport implements RoundTripper for the 's3' protocol.
|
||||
type transport struct {
|
||||
session *session.Session
|
||||
defaultClient *s3.S3
|
||||
defaultClient s3Client
|
||||
defaultConfig *aws.Config
|
||||
|
||||
clientsByRegion map[string]*s3.S3
|
||||
clientsByBucket map[string]*s3.S3
|
||||
clientsByRegion map[string]s3Client
|
||||
clientsByBucket map[string]s3Client
|
||||
|
||||
mu sync.RWMutex
|
||||
}
|
||||
@ -49,7 +57,7 @@ func New() (http.RoundTripper, error) {
|
||||
|
||||
sess, err := session.NewSession()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Can't create S3 session: %s", err)
|
||||
return nil, fmt.Errorf("can't create S3 session: %s", err)
|
||||
}
|
||||
|
||||
if len(config.S3Region) != 0 {
|
||||
@ -64,19 +72,19 @@ func New() (http.RoundTripper, error) {
|
||||
conf.Credentials = stscreds.NewCredentials(sess, config.S3AssumeRoleArn)
|
||||
}
|
||||
|
||||
client := s3.New(sess, conf)
|
||||
|
||||
clientRegion := "us-west-1"
|
||||
if client.Config.Region != nil {
|
||||
clientRegion = *client.Config.Region
|
||||
client, err := createClient(sess, conf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("can't create S3 client: %s", err)
|
||||
}
|
||||
|
||||
return &transport{
|
||||
session: sess,
|
||||
defaultClient: client,
|
||||
clientRegion := *sess.Config.Region
|
||||
|
||||
clientsByRegion: map[string]*s3.S3{clientRegion: client},
|
||||
clientsByBucket: make(map[string]*s3.S3),
|
||||
return &transport{
|
||||
session: sess,
|
||||
defaultClient: client,
|
||||
defaultConfig: conf,
|
||||
clientsByRegion: map[string]s3Client{clientRegion: client},
|
||||
clientsByBucket: make(map[string]s3Client),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -113,7 +121,7 @@ func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return handleError(req, err)
|
||||
}
|
||||
|
||||
s3req, _ := client.GetObjectRequest(input)
|
||||
s3req, objectOutput := client.GetObjectRequest(input)
|
||||
s3req.SetContext(req.Context())
|
||||
|
||||
if err := s3req.Send(); err != nil {
|
||||
@ -124,15 +132,27 @@ func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return handleError(req, err)
|
||||
}
|
||||
|
||||
if config.S3DecryptionClientEnabled {
|
||||
s3req.HTTPResponse.Body = objectOutput.Body
|
||||
|
||||
if unencryptedContentLength := s3req.HTTPResponse.Header.Get("X-Amz-Meta-X-Amz-Unencrypted-Content-Length"); len(unencryptedContentLength) != 0 {
|
||||
contentLength, err := strconv.ParseInt(unencryptedContentLength, 10, 64)
|
||||
if err != nil {
|
||||
handleError(req, err)
|
||||
}
|
||||
s3req.HTTPResponse.ContentLength = contentLength
|
||||
}
|
||||
}
|
||||
|
||||
return s3req.HTTPResponse, nil
|
||||
}
|
||||
|
||||
func (t *transport) getClient(ctx context.Context, bucket string) (*s3.S3, error) {
|
||||
func (t *transport) getClient(ctx context.Context, bucket string) (s3Client, error) {
|
||||
if !config.S3MultiRegion {
|
||||
return t.defaultClient, nil
|
||||
}
|
||||
|
||||
var client *s3.S3
|
||||
var client s3Client
|
||||
|
||||
func() {
|
||||
t.mu.RLock()
|
||||
@ -152,7 +172,7 @@ func (t *transport) getClient(ctx context.Context, bucket string) (*s3.S3, error
|
||||
return client, nil
|
||||
}
|
||||
|
||||
region, err := s3manager.GetBucketRegionWithClient(ctx, t.defaultClient, bucket)
|
||||
region, err := s3manager.GetBucketRegion(ctx, t.session, bucket, *t.session.Config.Region)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -162,10 +182,13 @@ func (t *transport) getClient(ctx context.Context, bucket string) (*s3.S3, error
|
||||
return client, nil
|
||||
}
|
||||
|
||||
conf := t.defaultClient.Config.Copy()
|
||||
conf := t.defaultConfig.Copy()
|
||||
conf.Region = aws.String(region)
|
||||
|
||||
client = s3.New(t.session, conf)
|
||||
client, err = createClient(t.session, conf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("can't create regional S3 client: %s", err)
|
||||
}
|
||||
|
||||
t.clientsByRegion[region] = client
|
||||
t.clientsByBucket[bucket] = client
|
||||
@ -173,6 +196,40 @@ func (t *transport) getClient(ctx context.Context, bucket string) (*s3.S3, error
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func createClient(sess *session.Session, conf *aws.Config) (s3Client, error) {
|
||||
if config.S3DecryptionClientEnabled {
|
||||
// `s3crypto.NewDecryptionClientV2` doesn't accept additional configs, so we
|
||||
// need to copy the session with an additional config
|
||||
sess = sess.Copy(conf)
|
||||
|
||||
cryptoRegistry, err := createCryptoRegistry(sess)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s3crypto.NewDecryptionClientV2(sess, cryptoRegistry)
|
||||
} else {
|
||||
return s3.New(sess, conf), nil
|
||||
}
|
||||
}
|
||||
|
||||
func createCryptoRegistry(sess *session.Session) (*s3crypto.CryptoRegistry, error) {
|
||||
kmsClient := kms.New(sess)
|
||||
|
||||
cr := s3crypto.NewCryptoRegistry()
|
||||
if err := s3crypto.RegisterKMSWrapWithAnyCMK(cr, kmsClient); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := s3crypto.RegisterKMSContextWrapWithAnyCMK(cr, kmsClient); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := s3crypto.RegisterAESGCMContentCipher(cr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return cr, nil
|
||||
}
|
||||
|
||||
func handleError(req *http.Request, err error) (*http.Response, error) {
|
||||
if s3err, ok := err.(awserr.Error); ok && s3err.Code() == request.CanceledErrorCode {
|
||||
if e := s3err.OrigErr(); e != nil {
|
||||
|
@ -50,15 +50,18 @@ func (s *S3TestSuite) SetupSuite() {
|
||||
svc, err := s.transport.(*transport).getClient(context.Background(), "test")
|
||||
require.Nil(s.T(), err)
|
||||
require.NotNil(s.T(), svc)
|
||||
require.IsType(s.T(), &s3.S3{}, svc)
|
||||
|
||||
_, err = svc.PutObject(&s3.PutObjectInput{
|
||||
client := svc.(*s3.S3)
|
||||
|
||||
_, err = client.PutObject(&s3.PutObjectInput{
|
||||
Body: bytes.NewReader(make([]byte, 32)),
|
||||
Bucket: aws.String("test"),
|
||||
Key: aws.String("foo/test.png"),
|
||||
})
|
||||
require.Nil(s.T(), err)
|
||||
|
||||
obj, err := svc.GetObject(&s3.GetObjectInput{
|
||||
obj, err := client.GetObject(&s3.GetObjectInput{
|
||||
Bucket: aws.String("test"),
|
||||
Key: aws.String("foo/test.png"),
|
||||
})
|
||||
|
Loading…
Reference in New Issue
Block a user