1
0
mirror of https://github.com/pocketbase/pocketbase.git synced 2025-02-16 01:19:46 +02:00
pocketbase/apis/batch.go

549 lines
14 KiB
Go

package apis
import (
"bytes"
"encoding/json"
"errors"
"io"
"mime/multipart"
"net/http"
"net/http/httptest"
"regexp"
"slices"
"strconv"
"strings"
"time"
validation "github.com/go-ozzo/ozzo-validation/v4"
"github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/tools/filesystem"
"github.com/pocketbase/pocketbase/tools/router"
"github.com/pocketbase/pocketbase/tools/types"
"github.com/spf13/cast"
)
// bindBatchApi registers the batch transaction route (POST /batch) under rg.
//
// The default body limit middleware is unbound for this route because
// batchTransaction applies its own (batch specific) body limit inline.
func bindBatchApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) {
	rg.Group("/batch").
		POST("", batchTransaction).
		Unbind(DefaultBodyLimitMiddlewareId) // the body limit is inlined
}
type (
	// HandleFunc is the signature of a single batch action request handler.
	HandleFunc func(e *core.RequestEvent) error

	// BatchActionHandlerFunc creates the HandleFunc for a matched batch InternalRequest.
	//
	// params holds the named groups captured from the matched "METHOD URL"
	// pattern, and next is the callback that the created handler forwards to the
	// underlying record action (the batch processor uses it to chain the
	// execution of the remaining batch requests).
	BatchActionHandlerFunc func(app core.App, ir *core.InternalRequest, params map[string]string, next func(data any) error) HandleFunc
)
// ValidBatchActions defines a map with the supported batch InternalRequest actions.
//
// Each key is matched against the uppercased "METHOD URL" string of a batch
// request item and its named groups are passed as params to the factory.
//
// Note: when adding new routes make sure that their middlewares are inlined!
var ValidBatchActions = map[*regexp.Regexp]BatchActionHandlerFunc{
	// "upsert" handler: updates the record when the body "id" refers to an
	// existing record of the collection, otherwise creates a new one
	regexp.MustCompile(`^PUT /api/collections/(?P<collection>[^\/\?]+)/records(?P<query>\?.*)?$`): func(app core.App, ir *core.InternalRequest, params map[string]string, next func(any) error) HandleFunc {
		collection := params["collection"]
		id := cast.ToString(ir.Body["id"])

		if id != "" {
			if _, findErr := app.FindRecordById(collection, id); findErr == nil {
				// update
				// ---
				params["id"] = id // required for the path value
				ir.Method = "PATCH"
				ir.URL = "/api/collections/" + collection + "/records/" + id + params["query"]
				return recordUpdate(next)
			}
		}

		// create
		// ---
		ir.Method = "POST"
		ir.URL = "/api/collections/" + collection + "/records" + params["query"]
		return recordCreate(next)
	},

	regexp.MustCompile(`^POST /api/collections/(?P<collection>[^\/\?]+)/records(\?.*)?$`): func(_ core.App, _ *core.InternalRequest, _ map[string]string, next func(any) error) HandleFunc {
		return recordCreate(next)
	},

	regexp.MustCompile(`^PATCH /api/collections/(?P<collection>[^\/\?]+)/records/(?P<id>[^\/\?]+)(\?.*)?$`): func(_ core.App, _ *core.InternalRequest, _ map[string]string, next func(any) error) HandleFunc {
		return recordUpdate(next)
	},

	regexp.MustCompile(`^DELETE /api/collections/(?P<collection>[^\/\?]+)/records/(?P<id>[^\/\?]+)(\?.*)?$`): func(_ core.App, _ *core.InternalRequest, _ map[string]string, next func(any) error) HandleFunc {
		return recordDelete(next)
	},
}
// BatchRequestResult holds the recorded response of a single successfully
// processed batch request item.
type BatchRequestResult struct {
	// Body is the recorded response body, parsed with types.ParseJSONRaw.
	Body any `json:"body"`

	// Status is the recorded response HTTP status code.
	Status int `json:"status"`
}
// batchRequestsForm is the submitted batch payload.
type batchRequestsForm struct {
	// Requests are the individual batch request items (bound from the "requests" field).
	Requests []*core.InternalRequest `form:"requests" json:"requests"`

	// max is the maximum allowed number of Requests (enforced by validate).
	max int
}
// validate checks that Requests is non-empty (validation.Required) and
// contains no more than brs.max items.
func (brs batchRequestsForm) validate() error {
	requestsRules := []validation.Rule{
		validation.Required,
		validation.Length(0, brs.max),
	}

	return validation.ValidateStruct(&brs, validation.Field(&brs.Requests, requestsRules...))
}
// batchTransaction handles the batch endpoint by executing all submitted
// requests sequentially within a single app transaction (see batchProcessor).
//
// It responds with 403 when batch requests are disabled, with 400 when the
// payload can't be read or is invalid, and otherwise with the JSON list of
// the individual BatchRequestResult items.
//
// NB! When the request is submitted as multipart/form-data,
// the regular fields data is expected to be submitted as serialized
// json under the @jsonPayload field and file keys need to follow the
// pattern "requests.N.fileField" or requests[N].fileField.
func batchTransaction(e *core.RequestEvent) error {
	settings := e.App.Settings()

	maxRequests := settings.Batch.MaxRequests
	if !settings.Batch.Enabled || maxRequests <= 0 {
		return e.ForbiddenError("Batch requests are not allowed.", nil)
	}

	txTimeout := time.Duration(settings.Batch.Timeout) * time.Second
	if txTimeout <= 0 {
		txTimeout = 3 * time.Second // for now always limit
	}

	maxBodySize := settings.Batch.MaxBodySize
	if maxBodySize <= 0 {
		maxBodySize = 128 << 20
	}

	// the default body limit middleware is unbound for this route, so apply it here
	if err := applyBodyLimit(e, maxBodySize); err != nil {
		return err
	}

	form := &batchRequestsForm{max: maxRequests}

	// load base requests data
	if err := e.BindBody(form); err != nil {
		return e.BadRequestError("Failed to read the submitted batch data.", err)
	}

	// load uploaded files into each request item
	// note: expects the files to be under "requests.N.fileField" or "requests[N].fileField" format
	// (the other regular fields must be put under `@jsonPayload` as serialized json)
	if strings.HasPrefix(e.Request.Header.Get("Content-Type"), "multipart/form-data") {
		for i, ir := range form.Requests {
			iStr := strconv.Itoa(i)

			files, err := extractPrefixedFiles(e.Request, "requests."+iStr+".", "requests["+iStr+"].")
			if err != nil {
				return e.BadRequestError("Failed to read the submitted batch files data.", err)
			}

			for key, keyFiles := range files {
				if ir.Body == nil {
					ir.Body = map[string]any{}
				}
				ir.Body[key] = keyFiles
			}
		}
	}

	// validate batch request form
	if err := form.validate(); err != nil {
		return e.BadRequestError("Invalid batch request data.", err)
	}

	event := new(core.BatchRequestEvent)
	event.RequestEvent = e
	event.Batch = form.Requests

	return e.App.OnBatchRequest().Trigger(event, func(be *core.BatchRequestEvent) error {
		bp := batchProcessor{
			app:         be.App,
			baseEvent:   be.RequestEvent,
			infoContext: core.RequestInfoContextBatch,
		}

		if err := bp.Process(be.Batch, txTimeout); err != nil {
			return firstApiError(err, be.BadRequestError("Batch transaction failed.", err))
		}

		return be.JSON(http.StatusOK, bp.results)
	})
}
// batchProcessor executes a list of batch InternalRequests recursively
// within a single transaction (see Process).
type batchProcessor struct {
	app core.App

	// baseEvent is the original batch request event; its auth state, store
	// data and request attributes are inherited by every internal request.
	baseEvent *core.RequestEvent

	// infoContext is the RequestInfo context assigned to each internal request event.
	infoContext string

	// results collects the responses of the successfully processed requests.
	results []*BatchRequestResult

	// failedIndex is the index of the failed batch item
	// (it stays 0 both when nothing failed and when the first item failed).
	failedIndex int

	// errCh receives the final result of the processing goroutine.
	errCh chan error

	// stopCh signals the recursive processing calls to stop.
	stopCh chan struct{}
}
func (p *batchProcessor) Process(batch []*core.InternalRequest, timeout time.Duration) error {
p.results = make([]*BatchRequestResult, 0, len(batch))
if p.stopCh != nil {
close(p.stopCh)
}
p.stopCh = make(chan struct{}, 1)
if p.errCh != nil {
close(p.errCh)
}
p.errCh = make(chan error, 1)
return p.app.RunInTransaction(func(txApp core.App) error {
// used to interupts the recursive processing calls in case of a timeout or connection close
defer func() {
p.stopCh <- struct{}{}
}()
go func() {
err := p.process(txApp, batch, 0)
if err != nil {
err = validation.Errors{
"requests": validation.Errors{
strconv.Itoa(p.failedIndex): &BatchResponseError{
code: "batch_request_failed",
message: "Batch request failed.",
err: router.ToApiError(err),
},
},
}
}
// note: to avoid copying and due to the process recursion the final results order is reversed
if err == nil {
slices.Reverse(p.results)
}
p.errCh <- err
}()
select {
case responseErr := <-p.errCh:
return responseErr
case <-time.After(timeout):
// note: we don't return 408 Reques Timeout error because
// some browsers perform automatic retry behind the scenes
// which are hard to debug and unnecessary
return errors.New("batch transaction timeout")
case <-p.baseEvent.Request.Context().Done():
return errors.New("batch request interrupted")
}
})
}
// process executes batch[0] and, through the "next" callback handed to its
// action handler, recursively the remaining items (batch[1:] as index i+1).
//
// A result is appended only after processInternalRequest returns, i.e. after
// the nested items have been processed, so p.results is filled in reverse order.
// It returns nil right away when a stop signal is pending on p.stopCh.
// When a nested item fails, p.failedIndex is set to its index unless a
// deeper level has already set it.
func (p *batchProcessor) process(activeApp core.App, batch []*core.InternalRequest, i int) error {
	select {
	case <-p.stopCh:
		return nil
	default:
	}

	if len(batch) == 0 {
		return nil
	}

	current, rest := batch[0], batch[1:]

	next := func(_ any) error {
		if len(rest) == 0 {
			return nil
		}

		nestedErr := p.process(activeApp, rest, i+1)

		// update the failed batch index (if not already)
		if nestedErr != nil && p.failedIndex == 0 {
			p.failedIndex = i + 1
		}

		return nestedErr
	}

	result, err := processInternalRequest(activeApp, p.baseEvent, current, p.infoContext, next)
	if err != nil {
		return err
	}

	p.results = append(p.results, result)

	return nil
}
// processInternalRequest executes a single batch InternalRequest.
//
// It resolves the matching ValidBatchActions handler, converts ir into a new
// multipart/form-data *http.Request, and invokes the handler with a new
// RequestEvent. That event inherits the auth state, store data and connection
// attributes of baseEvent; its response is captured by an httptest recorder.
//
// optNext is forwarded to the action factory as its "next" callback.
//
// It returns the recorded response as a BatchRequestResult, or an error if the
// action is unknown, the request can't be constructed, or the handler fails.
func processInternalRequest(
	activeApp core.App,
	baseEvent *core.RequestEvent,
	ir *core.InternalRequest,
	infoContext string,
	optNext func(data any) error,
) (*BatchRequestResult, error) {
	// note: must run before constructing the request because the action
	// factory may rewrite ir.Method and ir.URL (e.g. the upsert action)
	handle, params, ok := prepareInternalAction(activeApp, ir, optNext)
	if !ok {
		return nil, errors.New("unknown batch request action")
	}

	// construct a new http.Request
	// ---------------------------------------------------------------

	buf, mw, err := multipartDataFromInternalRequest(ir)
	if err != nil {
		return nil, err
	}

	r, err := http.NewRequest(strings.ToUpper(ir.Method), ir.URL, buf)
	if err != nil {
		return nil, err
	}

	// cleanup multipart temp files
	defer func() {
		if r.MultipartForm != nil {
			if err := r.MultipartForm.RemoveAll(); err != nil {
				activeApp.Logger().Warn("failed to cleanup temp batch files", "error", err)
			}
		}
	}()

	// load batch request path params
	for k, v := range params {
		r.SetPathValue(k, v)
	}

	// clone original request
	r.RequestURI = r.URL.RequestURI()
	inheritBatchBaseRequest(r, baseEvent.Request)

	// apply batch request specific headers
	for k, v := range ir.Headers {
		// individual Authorization header keys have no effect
		// because the auth state is populated from the base event
		if strings.EqualFold(k, "authorization") {
			continue
		}
		r.Header.Set(k, v)
	}

	// always overwrite the content type so that it matches the generated multipart body
	r.Header.Set("Content-Type", mw.FormDataContentType())

	// construct a new RequestEvent
	// ---------------------------------------------------------------

	event := &core.RequestEvent{}
	event.App = activeApp
	event.Auth = baseEvent.Auth
	event.SetAll(baseEvent.GetAll())

	// load RequestInfo context
	if infoContext == "" {
		infoContext = core.RequestInfoContextDefault
	}
	event.Set(core.RequestEventKeyInfoContext, infoContext)

	// assign request
	event.Request = r
	event.Request.Body = &router.RereadableReadCloser{ReadCloser: r.Body} // enables multiple reads

	// assign response
	rec := httptest.NewRecorder()
	event.Response = &router.ResponseWriter{ResponseWriter: rec} // enables status and write tracking

	// execute
	// ---------------------------------------------------------------

	if err := handle(event); err != nil {
		return nil, err
	}

	result := rec.Result()
	defer result.Body.Close()

	// note: the parse error is intentionally ignored and the returned value is used as-is
	body, _ := types.ParseJSONRaw(rec.Body.Bytes())

	return &BatchRequestResult{
		Status: result.StatusCode,
		Body:   body,
	}, nil
}

// inheritBatchBaseRequest copies the protocol, host, remote address, TLS state
// (shared pointer), transfer encoding, trailer and headers of base into r,
// cloning the slice and map values.
func inheritBatchBaseRequest(r *http.Request, base *http.Request) {
	r.Proto = base.Proto
	r.ProtoMajor = base.ProtoMajor
	r.ProtoMinor = base.ProtoMinor
	r.Host = base.Host
	r.RemoteAddr = base.RemoteAddr
	r.TLS = base.TLS

	if base.TransferEncoding != nil {
		r.TransferEncoding = slices.Clone(base.TransferEncoding)
	}

	if base.Trailer != nil {
		r.Trailer = base.Trailer.Clone()
	}

	if base.Header != nil {
		r.Header = base.Header.Clone()
	}
}
func multipartDataFromInternalRequest(ir *core.InternalRequest) (*bytes.Buffer, *multipart.Writer, error) {
buf := &bytes.Buffer{}
mw := multipart.NewWriter(buf)
regularFields := map[string]any{}
fileFields := map[string][]*filesystem.File{}
// separate regular fields from files
// ---
for k, rawV := range ir.Body {
switch v := rawV.(type) {
case *filesystem.File:
fileFields[k] = append(fileFields[k], v)
case []*filesystem.File:
fileFields[k] = append(fileFields[k], v...)
default:
regularFields[k] = v
}
}
// submit regularFields as @jsonPayload
// ---
rawBody, err := json.Marshal(regularFields)
if err != nil {
return nil, nil, errors.Join(err, mw.Close())
}
jsonPayload, err := mw.CreateFormField("@jsonPayload")
if err != nil {
return nil, nil, errors.Join(err, mw.Close())
}
_, err = jsonPayload.Write(rawBody)
if err != nil {
return nil, nil, errors.Join(err, mw.Close())
}
// submit fileFields as multipart files
// ---
for key, files := range fileFields {
for _, file := range files {
part, err := mw.CreateFormFile(key, file.Name)
if err != nil {
return nil, nil, errors.Join(err, mw.Close())
}
fr, err := file.Reader.Open()
if err != nil {
return nil, nil, errors.Join(err, mw.Close())
}
_, err = io.Copy(part, fr)
if err != nil {
return nil, nil, errors.Join(err, fr.Close(), mw.Close())
}
err = fr.Close()
if err != nil {
return nil, nil, errors.Join(err, mw.Close())
}
}
}
return buf, mw, mw.Close()
}
// extractPrefixedFiles collects the uploaded multipart files whose form key
// starts with any of prefixes, keyed by the form key with that prefix stripped.
//
// The multipart form is parsed (with router.DefaultMaxMemory) only if it
// hasn't been parsed already. On error no partial result is returned.
func extractPrefixedFiles(request *http.Request, prefixes ...string) (map[string][]*filesystem.File, error) {
	if request.MultipartForm == nil {
		parseErr := request.ParseMultipartForm(router.DefaultMaxMemory)
		if parseErr != nil {
			return nil, parseErr
		}
	}

	extracted := map[string][]*filesystem.File{}

	for formKey, headers := range request.MultipartForm.File {
		for _, prefix := range prefixes {
			if !strings.HasPrefix(formKey, prefix) {
				continue
			}

			field := formKey[len(prefix):]

			for _, header := range headers {
				f, err := filesystem.NewFileFromMultipart(header)
				if err != nil {
					return nil, err
				}
				extracted[field] = append(extracted[field], f)
			}
		}
	}

	return extracted, nil
}
// prepareInternalAction looks up the ValidBatchActions entry matching the
// uppercased "METHOD URL" string of ir.
//
// On a match it returns the handler built by the action factory, the named
// group params (the same map the factory received, so any rewrite the factory
// made to it is visible) and true. Otherwise it returns nil, nil, false.
// The registered patterns are anchored on distinct HTTP methods.
func prepareInternalAction(activeApp core.App, ir *core.InternalRequest, optNext func(data any) error) (HandleFunc, map[string]string, bool) {
	target := strings.ToUpper(ir.Method) + " " + ir.URL

	for pattern, factory := range ValidBatchActions {
		params, matched := findNamedMatches(pattern, target)
		if !matched {
			continue
		}

		return factory(activeApp, ir, params, optNext), params, true
	}

	return nil, nil, false
}
// findNamedMatches matches str against re.
//
// On success it returns a map with the value of every named capture group
// (optional groups that didn't participate map to "") and true.
// It returns nil, false when str doesn't match.
func findNamedMatches(re *regexp.Regexp, str string) (map[string]string, bool) {
	groups := re.FindStringSubmatch(str)
	if groups == nil {
		return nil, false
	}

	named := make(map[string]string, len(groups))

	for idx, name := range re.SubexpNames() {
		if name == "" {
			continue // the whole match and unnamed groups
		}
		named[name] = groups[idx]
	}

	return named, true
}
// -------------------------------------------------------------------
var (
	_ router.SafeErrorItem     = (*BatchResponseError)(nil)
	_ router.SafeErrorResolver = (*BatchResponseError)(nil)
)

// BatchResponseError describes the failure of a single batch request item.
//
// It implements router.SafeErrorItem and router.SafeErrorResolver and carries
// the ApiError of the failed item, exposed under the "response" key.
type BatchResponseError struct {
	err     *router.ApiError
	code    string
	message string
}

// Error implements the error interface and returns the error message.
func (e *BatchResponseError) Error() string {
	return e.message
}

// Code returns the error code (e.g. "batch_request_failed").
func (e *BatchResponseError) Code() string {
	return e.code
}

// Resolve stores the wrapped ApiError under the "response" key of errData
// and returns errData. A nil errData is replaced with a new map instead of
// panicking on the nil map assignment.
func (e *BatchResponseError) Resolve(errData map[string]any) any {
	if errData == nil {
		errData = map[string]any{}
	}

	errData["response"] = e.err

	return errData
}

// MarshalJSON serializes the error as {"message", "code", "response"}.
//
// note: value receiver so that both BatchResponseError values and pointers
// use this custom serialization.
func (e BatchResponseError) MarshalJSON() ([]byte, error) {
	return json.Marshal(map[string]any{
		"message":  e.message,
		"code":     e.code,
		"response": e.err,
	})
}