1
0
mirror of https://github.com/pocketbase/pocketbase.git synced 2025-01-10 00:43:36 +02:00
pocketbase/daos/record_expand.go

275 lines
8.1 KiB
Go
Raw Normal View History

2022-07-06 23:19:05 +02:00
package daos
import (
"errors"
"fmt"
2022-10-30 10:28:14 +02:00
"regexp"
2022-07-06 23:19:05 +02:00
"strings"
2022-10-30 10:28:14 +02:00
"github.com/pocketbase/dbx"
2022-07-06 23:19:05 +02:00
"github.com/pocketbase/pocketbase/models"
"github.com/pocketbase/pocketbase/models/schema"
2022-10-30 10:28:14 +02:00
"github.com/pocketbase/pocketbase/tools/inflector"
2022-07-06 23:19:05 +02:00
"github.com/pocketbase/pocketbase/tools/list"
2022-10-30 10:28:14 +02:00
"github.com/pocketbase/pocketbase/tools/security"
"github.com/pocketbase/pocketbase/tools/types"
2022-07-06 23:19:05 +02:00
)
// MaxExpandDepth specifies the max allowed nested expand depth path.
//
// Path segments nested deeper than this limit are silently ignored (no error
// is reported for them), eg. with a limit of 6 only the first 6 segments of
// "a.b.c.d.e.f.g" are expanded.
const MaxExpandDepth = 6
// ExpandFetchFunc defines the function that is used to fetch the expanded relation records.
//
// relCollection is the collection the relation points to and relIds are the
// ids collected from the relation field values of all records being expanded
// (the list is not deduplicated, so it may contain the same id more than once).
// Returned records whose id is not referenced by any of the expanded records
// are ignored. A non-nil error aborts the expand path that is being processed.
type ExpandFetchFunc func(relCollection *models.Collection, relIds []string) ([]*models.Record, error)
// ExpandRecord expands the relations of a single Record model.
//
// Returns a map with the failed expand parameters and their errors.
func (dao *Dao) ExpandRecord(record *models.Record, expands []string, fetchFunc ExpandFetchFunc) map[string]error {
2022-07-06 23:19:05 +02:00
return dao.ExpandRecords([]*models.Record{record}, expands, fetchFunc)
}
// ExpandRecords expands the relations of the provided Record models list.
//
// Returns a map with the failed expand parameters and their errors.
func (dao *Dao) ExpandRecords(records []*models.Record, expands []string, fetchFunc ExpandFetchFunc) map[string]error {
2022-07-06 23:19:05 +02:00
normalized := normalizeExpands(expands)
failed := map[string]error{}
2022-07-06 23:19:05 +02:00
for _, expand := range normalized {
if err := dao.expandRecords(records, expand, fetchFunc, 1); err != nil {
failed[expand] = err
2022-07-06 23:19:05 +02:00
}
}
return failed
2022-07-06 23:19:05 +02:00
}
2022-10-30 10:28:14 +02:00
// indirectExpandRegex matches an indirect expand path segment in the format
// "collection(field)" (eg. "comments(post)"), where both parts consist only of
// word characters. Group 1 captures the collection name or id and group 2
// captures the name of its relation field.
var indirectExpandRegex = regexp.MustCompile(`^(\w+)\((\w+)\)$`)
2022-07-06 23:19:05 +02:00
// notes:
// - fetchFunc must be non-nil func
// - all records are expected to be from the same collection
// - if MaxExpandDepth is reached, the function returns nil ignoring the remaining expand path
2022-10-30 10:28:14 +02:00
// - indirect expands are supported only with single relation fields
2022-07-06 23:19:05 +02:00
func (dao *Dao) expandRecords(records []*models.Record, expandPath string, fetchFunc ExpandFetchFunc, recursionLevel int) error {
if fetchFunc == nil {
return errors.New("Relation records fetchFunc is not set.")
}
if expandPath == "" || recursionLevel > MaxExpandDepth || len(records) == 0 {
return nil
}
2022-10-30 10:28:14 +02:00
mainCollection := records[0].Collection()
var relField *schema.SchemaField
var relFieldOptions *schema.RelationOptions
var relCollection *models.Collection
2022-07-06 23:19:05 +02:00
parts := strings.SplitN(expandPath, ".", 2)
2022-10-30 10:28:14 +02:00
matches := indirectExpandRegex.FindStringSubmatch(parts[0])
2022-07-06 23:19:05 +02:00
2022-10-30 10:28:14 +02:00
if len(matches) == 3 {
indirectRel, _ := dao.FindCollectionByNameOrId(matches[1])
if indirectRel == nil {
return fmt.Errorf("Couldn't find indirect related collection %q.", matches[1])
}
indirectRelField := indirectRel.Schema.GetFieldByName(matches[2])
if indirectRelField == nil || indirectRelField.Type != schema.FieldTypeRelation {
return fmt.Errorf("Couldn't find indirect relation field %q in collection %q.", matches[2], mainCollection.Name)
}
indirectRelField.InitOptions()
indirectRelFieldOptions, _ := indirectRelField.Options.(*schema.RelationOptions)
if indirectRelFieldOptions == nil || indirectRelFieldOptions.CollectionId != mainCollection.Id {
return fmt.Errorf("Invalid indirect relation field path %q.", parts[0])
}
if indirectRelFieldOptions.MaxSelect != nil && *indirectRelFieldOptions.MaxSelect != 1 {
// for now don't allow multi-relation indirect fields expand
// due to eventual poor query performance with large data sets.
return fmt.Errorf("Multi-relation fields cannot be indirectly expanded in %q.", parts[0])
}
recordIds := make([]any, len(records))
for _, record := range records {
recordIds = append(recordIds, record.Id)
}
indirectRecords, err := dao.FindRecordsByExpr(
indirectRel.Id,
dbx.In(inflector.Columnify(matches[2]), recordIds...),
)
if err != nil {
return err
}
mappedIndirectRecordIds := make(map[string][]string, len(indirectRecords))
for _, indirectRecord := range indirectRecords {
recId := indirectRecord.GetString(matches[2])
if recId != "" {
mappedIndirectRecordIds[recId] = append(mappedIndirectRecordIds[recId], indirectRecord.Id)
}
}
// add the indirect relation ids as a new relation field value
for _, record := range records {
relIds, ok := mappedIndirectRecordIds[record.Id]
if ok && len(relIds) > 0 {
record.Set(parts[0], relIds)
}
}
2022-07-06 23:19:05 +02:00
2022-10-30 10:28:14 +02:00
relFieldOptions = &schema.RelationOptions{
MaxSelect: nil,
CollectionId: indirectRel.Id,
}
if indirectRelField.Unique {
relFieldOptions.MaxSelect = types.Pointer(1)
}
// indirect relation
relField = &schema.SchemaField{
2022-11-06 15:26:34 +02:00
Id: "indirect_" + security.PseudoRandomString(5),
2022-10-30 10:28:14 +02:00
Type: schema.FieldTypeRelation,
Name: parts[0],
Options: relFieldOptions,
}
relCollection = indirectRel
} else {
// direct relation
relField = mainCollection.Schema.GetFieldByName(parts[0])
if relField == nil || relField.Type != schema.FieldTypeRelation {
return fmt.Errorf("Couldn't find relation field %q in collection %q.", parts[0], mainCollection.Name)
}
relField.InitOptions()
relFieldOptions, _ = relField.Options.(*schema.RelationOptions)
if relFieldOptions == nil {
return fmt.Errorf("Couldn't initialize the options of relation field %q.", parts[0])
}
relCollection, _ = dao.FindCollectionByNameOrId(relFieldOptions.CollectionId)
if relCollection == nil {
return fmt.Errorf("Couldn't find related collection %q.", relFieldOptions.CollectionId)
}
2022-07-06 23:19:05 +02:00
}
2022-10-30 10:28:14 +02:00
// ---------------------------------------------------------------
2022-07-06 23:19:05 +02:00
// extract the id of the relations to expand
relIds := make([]string, 0, len(records))
2022-07-06 23:19:05 +02:00
for _, record := range records {
2022-10-30 10:28:14 +02:00
relIds = append(relIds, record.GetStringSlice(relField.Name)...)
2022-07-06 23:19:05 +02:00
}
// fetch rels
rels, relsErr := fetchFunc(relCollection, relIds)
if relsErr != nil {
return relsErr
}
// expand nested fields
if len(parts) > 1 {
err := dao.expandRecords(rels, parts[1], fetchFunc, recursionLevel+1)
if err != nil {
return err
}
}
// reindex with the rel id
indexedRels := map[string]*models.Record{}
for _, rel := range rels {
indexedRels[rel.GetId()] = rel
}
for _, model := range records {
2022-10-30 10:28:14 +02:00
relIds := model.GetStringSlice(relField.Name)
2022-07-06 23:19:05 +02:00
validRels := make([]*models.Record, 0, len(relIds))
2022-07-06 23:19:05 +02:00
for _, id := range relIds {
if rel, ok := indexedRels[id]; ok {
validRels = append(validRels, rel)
}
}
if len(validRels) == 0 {
continue // no valid relations
}
2022-10-30 10:28:14 +02:00
expandData := model.Expand()
2022-07-06 23:19:05 +02:00
// normalize access to the previously expanded rel records (if any)
var oldExpandedRels []*models.Record
switch v := expandData[relField.Name].(type) {
case nil:
// no old expands
case *models.Record:
oldExpandedRels = []*models.Record{v}
case []*models.Record:
oldExpandedRels = v
}
// merge expands
for _, oldExpandedRel := range oldExpandedRels {
// find a matching rel record
for _, rel := range validRels {
if rel.Id != oldExpandedRel.Id {
continue
}
2022-10-30 10:28:14 +02:00
oldRelExpand := oldExpandedRel.Expand()
newRelExpand := rel.Expand()
for k, v := range oldRelExpand {
newRelExpand[k] = v
}
rel.SetExpand(newRelExpand)
}
}
// update the expanded data
2022-10-30 10:28:14 +02:00
if relFieldOptions.MaxSelect != nil && *relFieldOptions.MaxSelect <= 1 {
2022-07-06 23:19:05 +02:00
expandData[relField.Name] = validRels[0]
} else {
expandData[relField.Name] = validRels
}
2022-07-06 23:19:05 +02:00
model.SetExpand(expandData)
}
return nil
}
// normalizeExpands normalizes expand strings and merges self containing paths
// (eg. ["a.b.c", "a.b", " test ", " ", "test"] -> ["a.b.c", "test"]).
func normalizeExpands(paths []string) []string {
// normalize paths
normalized := make([]string, 0, len(paths))
2022-07-06 23:19:05 +02:00
for _, p := range paths {
p = strings.ReplaceAll(p, " ", "") // replace spaces
p = strings.Trim(p, ".") // trim incomplete paths
if p != "" {
normalized = append(normalized, p)
2022-07-06 23:19:05 +02:00
}
}
// merge containing paths
result := make([]string, 0, len(normalized))
2022-07-06 23:19:05 +02:00
for i, p1 := range normalized {
var skip bool
for j, p2 := range normalized {
if i == j {
continue
}
if strings.HasPrefix(p2, p1+".") {
// skip because there is more detailed expand path
skip = true
break
}
}
if !skip {
result = append(result, p1)
}
}
return list.ToUniqueStringSlice(result)
}