mirror of
https://github.com/pocketbase/pocketbase.git
synced 2025-04-03 01:56:17 +02:00
415 lines
9.1 KiB
Go
415 lines
9.1 KiB
Go
package s3
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/xml"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"slices"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
|
|
"golang.org/x/sync/errgroup"
|
|
)
|
|
|
|
// ErrUsedUploader is returned when an Uploader that was already marked
// as used is asked to perform another upload related operation.
var ErrUsedUploader = errors.New("the Uploader has been already used")

// defaultMaxConcurrency is the fallback for Uploader.MaxConcurrency
// when it is zero or negative.
const defaultMaxConcurrency int = 5

// defaultMinPartSize is the fallback (in bytes, 6MiB) for Uploader.MinPartSize
// when it is zero or negative.
const defaultMinPartSize int = 6 * 1024 * 1024
|
|
|
|
// Uploader handles S3 object upload.
|
|
//
|
|
// If the Payload size is less than the configured MinPartSize it sends
|
|
// a single (PutObject) request, otherwise performs chunked/multipart upload.
|
|
type Uploader struct {
|
|
// S3 is the S3 client instance performing the upload object request (required).
|
|
S3 *S3
|
|
|
|
// Payload is the object content to upload (required).
|
|
Payload io.Reader
|
|
|
|
// Key is the destination key of the uploaded object (required).
|
|
Key string
|
|
|
|
// Metadata specifies the optional metadata to write with the object upload.
|
|
Metadata map[string]string
|
|
|
|
// MaxConcurrency specifies the max number of workers to use when
|
|
// performing chunked/multipart upload.
|
|
//
|
|
// If zero or negative, defaults to 5.
|
|
//
|
|
// This option is used only when the Payload size is > MinPartSize.
|
|
MaxConcurrency int
|
|
|
|
// MinPartSize specifies the min Payload size required to perform
|
|
// chunked/multipart upload.
|
|
//
|
|
// If zero or negative, defaults to ~6MB.
|
|
MinPartSize int
|
|
|
|
uploadId string
|
|
uploadedParts []*mpPart
|
|
lastPartNumber int
|
|
mu sync.Mutex // guards lastPartNumber and the uploadedParts slice
|
|
used bool
|
|
}
|
|
|
|
// Upload processes the current Uploader instance.
|
|
//
|
|
// Users can specify an optional optReqFuncs that will be passed down to all Upload internal requests
|
|
// (single upload, multipart init, multipart parts upload, multipart complete, multipart abort).
|
|
//
|
|
// Note that after this call the Uploader should be discarded (aka. no longer can be used).
|
|
func (u *Uploader) Upload(ctx context.Context, optReqFuncs ...func(*http.Request)) error {
|
|
if u.used {
|
|
return ErrUsedUploader
|
|
}
|
|
|
|
err := u.validateAndNormalize()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
initPart, _, err := u.nextPart()
|
|
if err != nil && !errors.Is(err, io.EOF) {
|
|
return err
|
|
}
|
|
|
|
if len(initPart) < u.MinPartSize {
|
|
return u.singleUpload(ctx, initPart, optReqFuncs...)
|
|
}
|
|
|
|
err = u.multipartInit(ctx, optReqFuncs...)
|
|
if err != nil {
|
|
return fmt.Errorf("multipart init error: %w", err)
|
|
}
|
|
|
|
err = u.multipartUpload(ctx, initPart, optReqFuncs...)
|
|
if err != nil {
|
|
return errors.Join(
|
|
u.multipartAbort(ctx, optReqFuncs...),
|
|
fmt.Errorf("multipart upload error: %w", err),
|
|
)
|
|
}
|
|
|
|
err = u.multipartComplete(ctx, optReqFuncs...)
|
|
if err != nil {
|
|
return errors.Join(
|
|
u.multipartAbort(ctx, optReqFuncs...),
|
|
fmt.Errorf("multipart complete error: %w", err),
|
|
)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// -------------------------------------------------------------------
|
|
|
|
func (u *Uploader) validateAndNormalize() error {
|
|
if u.S3 == nil {
|
|
return errors.New("Uploader.S3 must be a non-empty and properly initialized S3 client instance")
|
|
}
|
|
|
|
if u.Key == "" {
|
|
return errors.New("Uploader.Key is required")
|
|
}
|
|
|
|
if u.Payload == nil {
|
|
return errors.New("Uploader.Payload must be non-nill")
|
|
}
|
|
|
|
if u.MaxConcurrency <= 0 {
|
|
u.MaxConcurrency = defaultMaxConcurrency
|
|
}
|
|
|
|
if u.MinPartSize <= 0 {
|
|
u.MinPartSize = defaultMinPartSize
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (u *Uploader) singleUpload(ctx context.Context, part []byte, optReqFuncs ...func(*http.Request)) error {
|
|
if u.used {
|
|
return ErrUsedUploader
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPut, u.S3.URL(u.Key), bytes.NewReader(part))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
req.Header.Set("Content-Length", strconv.Itoa(len(part)))
|
|
|
|
for k, v := range u.Metadata {
|
|
req.Header.Set(metadataPrefix+k, v)
|
|
}
|
|
|
|
// apply optional request funcs
|
|
for _, fn := range optReqFuncs {
|
|
if fn != nil {
|
|
fn(req)
|
|
}
|
|
}
|
|
|
|
resp, err := u.S3.SignAndSend(req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
return nil
|
|
}
|
|
|
|
// -------------------------------------------------------------------
|
|
|
|
// mpPart describes a single uploaded part of a multipart upload.
//
// It is serialized as a "Part" XML element with "ETag" and "PartNumber"
// children (in that order) as part of the multipart complete request payload.
type mpPart struct {
	XMLName xml.Name `xml:"Part"`

	// ETag is the value of the "ETag" response header of the part upload request.
	ETag string `xml:"ETag"`

	// PartNumber is the sequence number under which the part was uploaded.
	PartNumber int `xml:"PartNumber"`
}
|
|
|
|
// https://docs.aws.amazon.com/AmazonS3/latest/API/API_CreateMultipartUpload.html
|
|
func (u *Uploader) multipartInit(ctx context.Context, optReqFuncs ...func(*http.Request)) error {
|
|
if u.used {
|
|
return ErrUsedUploader
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, u.S3.URL(u.Key+"?uploads"), nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for k, v := range u.Metadata {
|
|
req.Header.Set(metadataPrefix+k, v)
|
|
}
|
|
|
|
// apply optional request funcs
|
|
for _, fn := range optReqFuncs {
|
|
if fn != nil {
|
|
fn(req)
|
|
}
|
|
}
|
|
|
|
resp, err := u.S3.SignAndSend(req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
body := &struct {
|
|
XMLName xml.Name `xml:"InitiateMultipartUploadResult"`
|
|
UploadId string `xml:"UploadId"`
|
|
}{}
|
|
|
|
err = xml.NewDecoder(resp.Body).Decode(body)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
u.uploadId = body.UploadId
|
|
|
|
return nil
|
|
}
|
|
|
|
// https://docs.aws.amazon.com/AmazonS3/latest/API/API_AbortMultipartUpload.html
|
|
func (u *Uploader) multipartAbort(ctx context.Context, optReqFuncs ...func(*http.Request)) error {
|
|
u.mu.Lock()
|
|
defer u.mu.Unlock()
|
|
|
|
u.used = true
|
|
|
|
// ensure that the specified abort context is always valid to allow cleanup
|
|
var abortCtx = ctx
|
|
if abortCtx.Err() != nil {
|
|
abortCtx = context.Background()
|
|
}
|
|
|
|
query := url.Values{"uploadId": []string{u.uploadId}}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, u.S3.URL(u.Key+"?"+query.Encode()), nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// apply optional request funcs
|
|
for _, fn := range optReqFuncs {
|
|
if fn != nil {
|
|
fn(req)
|
|
}
|
|
}
|
|
|
|
resp, err := u.S3.SignAndSend(req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
return nil
|
|
}
|
|
|
|
// https://docs.aws.amazon.com/AmazonS3/latest/API/API_CompleteMultipartUpload.html
|
|
func (u *Uploader) multipartComplete(ctx context.Context, optReqFuncs ...func(*http.Request)) error {
|
|
u.mu.Lock()
|
|
defer u.mu.Unlock()
|
|
|
|
u.used = true
|
|
|
|
// the list of parts must be sorted in ascending order
|
|
slices.SortFunc(u.uploadedParts, func(a, b *mpPart) int {
|
|
if a.PartNumber < b.PartNumber {
|
|
return -1
|
|
}
|
|
if a.PartNumber > b.PartNumber {
|
|
return 1
|
|
}
|
|
return 0
|
|
})
|
|
|
|
// build a request payload with the uploaded parts
|
|
xmlParts := &struct {
|
|
XMLName xml.Name `xml:"CompleteMultipartUpload"`
|
|
Parts []*mpPart
|
|
}{
|
|
Parts: u.uploadedParts,
|
|
}
|
|
rawXMLParts, err := xml.Marshal(xmlParts)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
reqPayload := strings.NewReader(xml.Header + string(rawXMLParts))
|
|
|
|
query := url.Values{"uploadId": []string{u.uploadId}}
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, u.S3.URL(u.Key+"?"+query.Encode()), reqPayload)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// apply optional request funcs
|
|
for _, fn := range optReqFuncs {
|
|
if fn != nil {
|
|
fn(req)
|
|
}
|
|
}
|
|
|
|
resp, err := u.S3.SignAndSend(req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
return nil
|
|
}
|
|
|
|
func (u *Uploader) nextPart() ([]byte, int, error) {
|
|
u.mu.Lock()
|
|
defer u.mu.Unlock()
|
|
|
|
part := make([]byte, u.MinPartSize)
|
|
n, err := io.ReadFull(u.Payload, part)
|
|
|
|
// normalize io.EOF errors and ensure that io.EOF is returned only when there were no read bytes
|
|
if err != nil && (errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF)) {
|
|
if n == 0 {
|
|
err = io.EOF
|
|
} else {
|
|
err = nil
|
|
}
|
|
}
|
|
|
|
u.lastPartNumber++
|
|
|
|
return part[0:n], u.lastPartNumber, err
|
|
}
|
|
|
|
func (u *Uploader) multipartUpload(ctx context.Context, initPart []byte, optReqFuncs ...func(*http.Request)) error {
|
|
var g errgroup.Group
|
|
g.SetLimit(u.MaxConcurrency)
|
|
|
|
totalParallel := u.MaxConcurrency
|
|
|
|
if len(initPart) != 0 {
|
|
totalParallel--
|
|
initPartNumber := u.lastPartNumber
|
|
g.Go(func() error {
|
|
mp, err := u.uploadPart(ctx, initPartNumber, initPart, optReqFuncs...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
u.mu.Lock()
|
|
u.uploadedParts = append(u.uploadedParts, mp)
|
|
u.mu.Unlock()
|
|
|
|
return nil
|
|
})
|
|
}
|
|
|
|
for i := 0; i < totalParallel; i++ {
|
|
g.Go(func() error {
|
|
for {
|
|
part, num, err := u.nextPart()
|
|
if err != nil {
|
|
if errors.Is(err, io.EOF) {
|
|
break
|
|
}
|
|
return err
|
|
}
|
|
|
|
mp, err := u.uploadPart(ctx, num, part, optReqFuncs...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
u.mu.Lock()
|
|
u.uploadedParts = append(u.uploadedParts, mp)
|
|
u.mu.Unlock()
|
|
}
|
|
|
|
return nil
|
|
})
|
|
}
|
|
|
|
return g.Wait()
|
|
}
|
|
|
|
// https://docs.aws.amazon.com/AmazonS3/latest/API/API_UploadPart.html
|
|
func (u *Uploader) uploadPart(ctx context.Context, partNumber int, partData []byte, optReqFuncs ...func(*http.Request)) (*mpPart, error) {
|
|
query := url.Values{}
|
|
query.Set("uploadId", u.uploadId)
|
|
query.Set("partNumber", strconv.Itoa(partNumber))
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPut, u.S3.URL(u.Key+"?"+query.Encode()), bytes.NewReader(partData))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
req.Header.Set("Content-Length", strconv.Itoa(len(partData)))
|
|
|
|
// apply optional request funcs
|
|
for _, fn := range optReqFuncs {
|
|
if fn != nil {
|
|
fn(req)
|
|
}
|
|
}
|
|
|
|
resp, err := u.S3.SignAndSend(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
return &mpPart{
|
|
PartNumber: partNumber,
|
|
ETag: resp.Header.Get("ETag"),
|
|
}, nil
|
|
}
|