1
0
mirror of https://github.com/pocketbase/pocketbase.git synced 2024-11-21 13:35:49 +02:00

added option to call Dao.RecordQuery() with the collection id or name

This commit is contained in:
Gani Georgiev 2023-07-13 22:38:55 +03:00
parent a38bd5bedc
commit fdccdcebad
4 changed files with 127 additions and 68 deletions

View File

@ -77,6 +77,9 @@
- Minor Admin UI fixes (typos, grammar fixes, removed unnecessary 404 error check, etc.).
- (@todo docs) For consistency and convenience it is now possible to call `Dao.RecordQuery(collectionModelOrIdentifier)` with just the collection id or name.
In case an invalid collection id/name string is passed the query will be resolved with cancelled context error.
## v0.16.8

View File

@ -1,6 +1,7 @@
package daos
import (
"context"
"database/sql"
"errors"
"fmt"
@ -18,79 +19,114 @@ import (
"github.com/spf13/cast"
)
// RecordQuery returns a new Record select query.
func (dao *Dao) RecordQuery(collection *models.Collection) *dbx.SelectQuery {
tableName := collection.Name
// RecordQuery returns a new Record select query from a collection model, id or name.
//
// In case a collection id or name is provided and that collection doesn't
// actually exists, the generated query will be created with a cancelled context
// and will fail once an executor (Row(), One(), All(), etc.) is called.
func (dao *Dao) RecordQuery(collectionModelOrIdentifier any) *dbx.SelectQuery {
var tableName string
var collection *models.Collection
var collectionErr error
switch c := collectionModelOrIdentifier.(type) {
case *models.Collection:
collection = c
tableName = collection.Name
case models.Collection:
collection = &c
tableName = collection.Name
case string:
collection, collectionErr = dao.FindCollectionByNameOrId(c)
if collection != nil {
tableName = collection.Name
} else {
// update with some fake table name for easier debugging
tableName = "@@__missing_" + c
}
default:
// update with some fake table name for easier debugging
tableName = "@@__invalidCollectionModelOrIdentifier"
collectionErr = errors.New("unsupported collection identifier, must be collection model, id or name")
}
selectCols := fmt.Sprintf("%s.*", dao.DB().QuoteSimpleColumnName(tableName))
return dao.DB().
Select(selectCols).
From(tableName).
WithBuildHook(func(query *dbx.Query) {
query.WithExecHook(execLockRetry(dao.ModelQueryTimeout, dao.MaxLockRetries)).
WithOneHook(func(q *dbx.Query, a any, op func(b any) error) error {
switch v := a.(type) {
case *models.Record:
if v == nil {
return op(a)
}
query := dao.DB().Select(selectCols).From(tableName)
row := dbx.NullStringMap{}
if err := op(&row); err != nil {
return err
}
// in case of an error attach a new context and cancel it immediately with the error
if collectionErr != nil {
// @todo consider changing to WithCancelCause when upgrading
// the min Go requirement to 1.20, so that we can pass the error
ctx, cancelFunc := context.WithCancel(context.Background())
query.WithContext(ctx)
cancelFunc()
}
record := models.NewRecordFromNullStringMap(collection, row)
*v = *record
return nil
default:
return query.WithBuildHook(func(q *dbx.Query) {
q.WithExecHook(execLockRetry(dao.ModelQueryTimeout, dao.MaxLockRetries)).
WithOneHook(func(q *dbx.Query, a any, op func(b any) error) error {
switch v := a.(type) {
case *models.Record:
if v == nil {
return op(a)
}
}).
WithAllHook(func(q *dbx.Query, sliceA any, op func(sliceB any) error) error {
switch v := sliceA.(type) {
case *[]*models.Record:
if v == nil {
return op(sliceA)
}
rows := []dbx.NullStringMap{}
if err := op(&rows); err != nil {
return err
}
row := dbx.NullStringMap{}
if err := op(&row); err != nil {
return err
}
records := models.NewRecordsFromNullStringMaps(collection, rows)
record := models.NewRecordFromNullStringMap(collection, row)
*v = records
*v = *record
return nil
case *[]models.Record:
if v == nil {
return op(sliceA)
}
rows := []dbx.NullStringMap{}
if err := op(&rows); err != nil {
return err
}
records := models.NewRecordsFromNullStringMaps(collection, rows)
nonPointers := make([]models.Record, len(records))
for i, r := range records {
nonPointers[i] = *r
}
*v = nonPointers
return nil
default:
return nil
default:
return op(a)
}
}).
WithAllHook(func(q *dbx.Query, sliceA any, op func(sliceB any) error) error {
switch v := sliceA.(type) {
case *[]*models.Record:
if v == nil {
return op(sliceA)
}
})
})
rows := []dbx.NullStringMap{}
if err := op(&rows); err != nil {
return err
}
records := models.NewRecordsFromNullStringMaps(collection, rows)
*v = records
return nil
case *[]models.Record:
if v == nil {
return op(sliceA)
}
rows := []dbx.NullStringMap{}
if err := op(&rows); err != nil {
return err
}
records := models.NewRecordsFromNullStringMaps(collection, rows)
nonPointers := make([]models.Record, len(records))
for i, r := range records {
nonPointers[i] = *r
}
*v = nonPointers
return nil
default:
return op(sliceA)
}
})
})
}
// FindRecordById finds the Record model by its id.

View File

@ -4,7 +4,6 @@ import (
"context"
"database/sql"
"errors"
"fmt"
"regexp"
"strings"
"testing"
@ -19,7 +18,7 @@ import (
"github.com/pocketbase/pocketbase/tools/types"
)
func TestRecordQuery(t *testing.T) {
func TestRecordQueryWithDifferentCollectionValues(t *testing.T) {
app, _ := tests.NewTestApp()
defer app.Cleanup()
@ -28,11 +27,33 @@ func TestRecordQuery(t *testing.T) {
t.Fatal(err)
}
expected := fmt.Sprintf("SELECT `%s`.* FROM `%s`", collection.Name, collection.Name)
scenarios := []struct {
name any
collection any
expectedTotal int
expectError bool
}{
{"with nil value", nil, 0, true},
{"with invalid or missing collection id/name", "missing", 0, true},
{"with pointer model", collection, 3, false},
{"with value model", *collection, 3, false},
{"with name", "demo1", 3, false},
{"with id", "wsmn24bux7wo113", 3, false},
}
sql := app.Dao().RecordQuery(collection).Build().SQL()
if sql != expected {
t.Errorf("Expected sql %s, got %s", expected, sql)
for _, s := range scenarios {
var records []*models.Record
err := app.Dao().RecordQuery(s.collection).All(&records)
hasErr := err != nil
if hasErr != s.expectError {
t.Errorf("[%s] Expected hasError %v, got %v", s.name, s.expectError, hasErr)
continue
}
if total := len(records); total != s.expectedTotal {
t.Errorf("[%s] Expected %d records, got %d", s.name, s.expectedTotal, total)
}
}
}

View File

@ -27,7 +27,6 @@ func SubtractSlice[T comparable](base []T, subtract []T) []T {
// ExistInSlice checks whether a comparable element exists in a slice of the same type.
func ExistInSlice[T comparable](item T, list []T) bool {
for _, v := range list {
if v == item {
return true