package apis import ( "bytes" "encoding/json" "errors" "io" "mime/multipart" "net/http" "net/http/httptest" "regexp" "slices" "strconv" "strings" "time" validation "github.com/go-ozzo/ozzo-validation/v4" "github.com/pocketbase/pocketbase/core" "github.com/pocketbase/pocketbase/tools/filesystem" "github.com/pocketbase/pocketbase/tools/router" "github.com/pocketbase/pocketbase/tools/types" "github.com/spf13/cast" ) func bindBatchApi(app core.App, rg *router.RouterGroup[*core.RequestEvent]) { sub := rg.Group("/batch") sub.POST("", batchTransaction).Unbind(DefaultBodyLimitMiddlewareId) // the body limit is inlined } type HandleFunc func(e *core.RequestEvent) error type BatchActionHandlerFunc func(app core.App, ir *core.InternalRequest, params map[string]string, next func(data any) error) HandleFunc // ValidBatchActions defines a map with the supported batch InternalRequest actions. // // Note: when adding new routes make sure that their middlewares are inlined! var ValidBatchActions = map[*regexp.Regexp]BatchActionHandlerFunc{ // "upsert" handler regexp.MustCompile(`^PUT /api/collections/(?P[^\/\?]+)/records(?P\?.*)?$`): func(app core.App, ir *core.InternalRequest, params map[string]string, next func(any) error) HandleFunc { var id string if len(ir.Body) > 0 && ir.Body["id"] != "" { id = cast.ToString(ir.Body["id"]) } if id != "" { _, err := app.FindRecordById(params["collection"], id) if err == nil { // update // --- params["id"] = id // required for the path value ir.Method = "PATCH" ir.URL = "/api/collections/" + params["collection"] + "/records/" + id + params["query"] return recordUpdate(next) } } // create // --- ir.Method = "POST" ir.URL = "/api/collections/" + params["collection"] + "/records" + params["query"] return recordCreate(next) }, regexp.MustCompile(`^POST /api/collections/(?P[^\/\?]+)/records(\?.*)?$`): func(app core.App, ir *core.InternalRequest, params map[string]string, next func(any) error) HandleFunc { return recordCreate(next) }, regexp.MustCompile(`^PATCH /api/collections/(?P[^\/\?]+)/records/(?P[^\/\?]+)(\?.*)?$`): func(app core.App, ir *core.InternalRequest, params map[string]string, next func(any) error) HandleFunc { return recordUpdate(next) }, regexp.MustCompile(`^DELETE /api/collections/(?P[^\/\?]+)/records/(?P[^\/\?]+)(\?.*)?$`): func(app core.App, ir *core.InternalRequest, params map[string]string, next func(any) error) HandleFunc { return recordDelete(next) }, } type BatchRequestResult struct { Body any `json:"body"` Status int `json:"status"` } type batchRequestsForm struct { Requests []*core.InternalRequest `form:"requests" json:"requests"` max int } func (brs batchRequestsForm) validate() error { return validation.ValidateStruct(&brs, validation.Field(&brs.Requests, validation.Required, validation.Length(0, brs.max)), ) } // NB! When the request is submitted as multipart/form-data, // the regular fields data is expected to be submitted as serailized // json under the @jsonPayload field and file keys need to follow the // pattern "requests.N.fileField" or requests[N].fileField. func batchTransaction(e *core.RequestEvent) error { maxRequests := e.App.Settings().Batch.MaxRequests if !e.App.Settings().Batch.Enabled || maxRequests <= 0 { return e.ForbiddenError("Batch requests are not allowed.", nil) } txTimeout := time.Duration(e.App.Settings().Batch.Timeout) * time.Second if txTimeout <= 0 { txTimeout = 3 * time.Second // for now always limit } maxBodySize := e.App.Settings().Batch.MaxBodySize if maxBodySize <= 0 { maxBodySize = 128 << 20 } err := applyBodyLimit(e, maxBodySize) if err != nil { return err } form := &batchRequestsForm{max: maxRequests} // load base requests data err = e.BindBody(form) if err != nil { return e.BadRequestError("Failed to read the submitted batch data.", err) } // load uploaded files into each request item // note: expects the files to be under "requests.N.fileField" or "requests[N].fileField" format // (the other regular fields must be put under `@jsonPayload` as serialized json) if strings.HasPrefix(e.Request.Header.Get("Content-Type"), "multipart/form-data") { for i, ir := range form.Requests { iStr := strconv.Itoa(i) files, err := extractPrefixedFiles(e.Request, "requests."+iStr+".", "requests["+iStr+"].") if err != nil { return e.BadRequestError("Failed to read the submitted batch files data.", err) } for key, files := range files { if ir.Body == nil { ir.Body = map[string]any{} } ir.Body[key] = files } } } // validate batch request form err = form.validate() if err != nil { return e.BadRequestError("Invalid batch request data.", err) } event := new(core.BatchRequestEvent) event.RequestEvent = e event.Batch = form.Requests return e.App.OnBatchRequest().Trigger(event, func(e *core.BatchRequestEvent) error { bp := batchProcessor{ app: e.App, baseEvent: e.RequestEvent, infoContext: core.RequestInfoContextBatch, } if err := bp.Process(e.Batch, txTimeout); err != nil { return firstApiError(err, e.BadRequestError("Batch transaction failed.", err)) } return e.JSON(http.StatusOK, bp.results) }) } type batchProcessor struct { app core.App baseEvent *core.RequestEvent infoContext string results []*BatchRequestResult failedIndex int errCh chan error stopCh chan struct{} } func (p *batchProcessor) Process(batch []*core.InternalRequest, timeout time.Duration) error { p.results = make([]*BatchRequestResult, 0, len(batch)) if p.stopCh != nil { close(p.stopCh) } p.stopCh = make(chan struct{}, 1) if p.errCh != nil { close(p.errCh) } p.errCh = make(chan error, 1) return p.app.RunInTransaction(func(txApp core.App) error { // used to interupts the recursive processing calls in case of a timeout or connection close defer func() { p.stopCh <- struct{}{} }() go func() { err := p.process(txApp, batch, 0) if err != nil { err = validation.Errors{ "requests": validation.Errors{ strconv.Itoa(p.failedIndex): &BatchResponseError{ code: "batch_request_failed", message: "Batch request failed.", err: router.ToApiError(err), }, }, } } // note: to avoid copying and due to the process recursion the final results order is reversed if err == nil { slices.Reverse(p.results) } p.errCh <- err }() select { case responseErr := <-p.errCh: return responseErr case <-time.After(timeout): // note: we don't return 408 Reques Timeout error because // some browsers perform automatic retry behind the scenes // which are hard to debug and unnecessary return errors.New("batch transaction timeout") case <-p.baseEvent.Request.Context().Done(): return errors.New("batch request interrupted") } }) } func (p *batchProcessor) process(activeApp core.App, batch []*core.InternalRequest, i int) error { select { case <-p.stopCh: return nil default: if len(batch) == 0 { return nil } result, err := processInternalRequest( activeApp, p.baseEvent, batch[0], p.infoContext, func(_ any) error { if len(batch) == 1 { return nil } err := p.process(activeApp, batch[1:], i+1) // update the failed batch index (if not already) if err != nil && p.failedIndex == 0 { p.failedIndex = i + 1 } return err }, ) if err != nil { return err } p.results = append(p.results, result) return nil } } func processInternalRequest( activeApp core.App, baseEvent *core.RequestEvent, ir *core.InternalRequest, infoContext string, optNext func(data any) error, ) (*BatchRequestResult, error) { handle, params, ok := prepareInternalAction(activeApp, ir, optNext) if !ok { return nil, errors.New("unknown batch request action") } // construct a new http.Request // --------------------------------------------------------------- buf, mw, err := multipartDataFromInternalRequest(ir) if err != nil { return nil, err } r, err := http.NewRequest(strings.ToUpper(ir.Method), ir.URL, buf) if err != nil { return nil, err } // cleanup multipart temp files defer func() { if r.MultipartForm != nil { if err := r.MultipartForm.RemoveAll(); err != nil { activeApp.Logger().Warn("failed to cleanup temp batch files", "error", err) } } }() // load batch request path params // --- for k, v := range params { r.SetPathValue(k, v) } // clone original request // --- r.RequestURI = r.URL.RequestURI() r.Proto = baseEvent.Request.Proto r.ProtoMajor = baseEvent.Request.ProtoMajor r.ProtoMinor = baseEvent.Request.ProtoMinor r.Host = baseEvent.Request.Host r.RemoteAddr = baseEvent.Request.RemoteAddr r.TLS = baseEvent.Request.TLS if s := baseEvent.Request.TransferEncoding; s != nil { s2 := make([]string, len(s)) copy(s2, s) r.TransferEncoding = s2 } if baseEvent.Request.Trailer != nil { r.Trailer = baseEvent.Request.Trailer.Clone() } if baseEvent.Request.Header != nil { r.Header = baseEvent.Request.Header.Clone() } // apply batch request specific headers // --- for k, v := range ir.Headers { r.Header.Set(k, v) } r.Header.Set("Content-Type", mw.FormDataContentType()) // construct a new RequestEvent // --------------------------------------------------------------- event := &core.RequestEvent{} event.App = activeApp event.Auth = baseEvent.Auth event.SetAll(baseEvent.GetAll()) // load RequestInfo context if infoContext == "" { infoContext = core.RequestInfoContextDefault } event.Set(core.RequestEventKeyInfoContext, infoContext) // assign request event.Request = r event.Request.Body = &router.RereadableReadCloser{ReadCloser: r.Body} // enables multiple reads // assign response rec := httptest.NewRecorder() event.Response = &router.ResponseWriter{ResponseWriter: rec} // enables status and write tracking // execute // --------------------------------------------------------------- if err := handle(event); err != nil { return nil, err } result := rec.Result() defer result.Body.Close() body, _ := types.ParseJSONRaw(rec.Body.Bytes()) return &BatchRequestResult{ Status: result.StatusCode, Body: body, }, nil } func multipartDataFromInternalRequest(ir *core.InternalRequest) (*bytes.Buffer, *multipart.Writer, error) { buf := &bytes.Buffer{} mw := multipart.NewWriter(buf) regularFields := map[string]any{} fileFields := map[string][]*filesystem.File{} // separate regular fields from files // --- for k, rawV := range ir.Body { switch v := rawV.(type) { case *filesystem.File: fileFields[k] = append(fileFields[k], v) case []*filesystem.File: fileFields[k] = append(fileFields[k], v...) default: regularFields[k] = v } } // submit regularFields as @jsonPayload // --- rawBody, err := json.Marshal(regularFields) if err != nil { return nil, nil, errors.Join(err, mw.Close()) } jsonPayload, err := mw.CreateFormField("@jsonPayload") if err != nil { return nil, nil, errors.Join(err, mw.Close()) } _, err = jsonPayload.Write(rawBody) if err != nil { return nil, nil, errors.Join(err, mw.Close()) } // submit fileFields as multipart files // --- for key, files := range fileFields { for _, file := range files { part, err := mw.CreateFormFile(key, file.Name) if err != nil { return nil, nil, errors.Join(err, mw.Close()) } fr, err := file.Reader.Open() if err != nil { return nil, nil, errors.Join(err, mw.Close()) } _, err = io.Copy(part, fr) if err != nil { return nil, nil, errors.Join(err, fr.Close(), mw.Close()) } err = fr.Close() if err != nil { return nil, nil, errors.Join(err, mw.Close()) } } } return buf, mw, mw.Close() } func extractPrefixedFiles(request *http.Request, prefixes ...string) (map[string][]*filesystem.File, error) { if request.MultipartForm == nil { if err := request.ParseMultipartForm(router.DefaultMaxMemory); err != nil { return nil, err } } result := make(map[string][]*filesystem.File) for k, fhs := range request.MultipartForm.File { for _, p := range prefixes { if strings.HasPrefix(k, p) { resultKey := strings.TrimPrefix(k, p) for _, fh := range fhs { file, err := filesystem.NewFileFromMultipart(fh) if err != nil { return nil, err } result[resultKey] = append(result[resultKey], file) } } } } return result, nil } func prepareInternalAction(activeApp core.App, ir *core.InternalRequest, optNext func(data any) error) (HandleFunc, map[string]string, bool) { full := strings.ToUpper(ir.Method) + " " + ir.URL for re, actionFactory := range ValidBatchActions { params, ok := findNamedMatches(re, full) if ok { return actionFactory(activeApp, ir, params, optNext), params, true } } return nil, nil, false } func findNamedMatches(re *regexp.Regexp, str string) (map[string]string, bool) { match := re.FindStringSubmatch(str) if match == nil { return nil, false } result := map[string]string{} names := re.SubexpNames() for i, m := range match { if names[i] != "" { result[names[i]] = m } } return result, true } // ------------------------------------------------------------------- var ( _ router.SafeErrorItem = (*BatchResponseError)(nil) _ router.SafeErrorResolver = (*BatchResponseError)(nil) ) type BatchResponseError struct { err *router.ApiError code string message string } func (e *BatchResponseError) Error() string { return e.message } func (e *BatchResponseError) Code() string { return e.code } func (e *BatchResponseError) Resolve(errData map[string]any) any { errData["response"] = e.err return errData } func (e BatchResponseError) MarshalJSON() ([]byte, error) { return json.Marshal(map[string]any{ "message": e.message, "code": e.code, "response": e.err, }) }