1
0
mirror of https://github.com/pocketbase/pocketbase.git synced 2025-02-05 10:45:09 +02:00
pocketbase/core/record_query.go

608 lines
15 KiB
Go

package core
import (
"context"
"database/sql"
"errors"
"fmt"
"reflect"
"strings"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/tools/inflector"
"github.com/pocketbase/pocketbase/tools/list"
"github.com/pocketbase/pocketbase/tools/search"
"github.com/pocketbase/pocketbase/tools/security"
)
var recordProxyType = reflect.TypeOf((*RecordProxy)(nil)).Elem()
// RecordQuery returns a new Record select query from a collection model, id or name.
//
// In case a collection id or name is provided and that collection doesn't
// actually exists, the generated query will be created with a cancelled context
// and will fail once an executor (Row(), One(), All(), etc.) is called.
func (app *BaseApp) RecordQuery(collectionModelOrIdentifier any) *dbx.SelectQuery {
var tableName string
collection, collectionErr := getCollectionByModelOrIdentifier(app, collectionModelOrIdentifier)
if collection != nil {
tableName = collection.Name
}
if tableName == "" {
// update with some fake table name for easier debugging
tableName = "@@__invalidCollectionModelOrIdentifier"
}
query := app.DB().Select(app.DB().QuoteSimpleColumnName(tableName) + ".*").From(tableName)
// in case of an error attach a new context and cancel it immediately with the error
if collectionErr != nil {
ctx, cancelFunc := context.WithCancelCause(context.Background())
query.WithContext(ctx)
cancelFunc(collectionErr)
}
return query.WithBuildHook(func(q *dbx.Query) {
q.WithExecHook(execLockRetry(app.config.QueryTimeout, defaultMaxLockRetries)).
WithOneHook(func(q *dbx.Query, a any, op func(b any) error) error {
if a == nil {
return op(a)
}
switch v := a.(type) {
case *Record:
record, err := resolveRecordOneHook(collection, op)
if err != nil {
return err
}
*v = *record
return nil
case RecordProxy:
record, err := resolveRecordOneHook(collection, op)
if err != nil {
return err
}
v.SetProxyRecord(record)
return nil
default:
return op(a)
}
}).
WithAllHook(func(q *dbx.Query, sliceA any, op func(sliceB any) error) error {
if sliceA == nil {
return op(sliceA)
}
switch v := sliceA.(type) {
case *[]*Record:
records, err := resolveRecordAllHook(collection, op)
if err != nil {
return err
}
*v = records
return nil
case *[]Record:
records, err := resolveRecordAllHook(collection, op)
if err != nil {
return err
}
nonPointers := make([]Record, len(records))
for i, r := range records {
nonPointers[i] = *r
}
*v = nonPointers
return nil
default: // expects []RecordProxy slice
records, err := resolveRecordAllHook(collection, op)
if err != nil {
return err
}
rv := reflect.ValueOf(v)
if rv.Kind() != reflect.Ptr || rv.IsNil() {
return errors.New("must be a pointer")
}
rv = dereference(rv)
if rv.Kind() != reflect.Slice {
return errors.New("must be a slice of RecordSetters")
}
// create an empty slice
if rv.IsNil() {
rv.Set(reflect.MakeSlice(rv.Type(), 0, len(records)))
}
et := rv.Type().Elem()
var isSliceOfPointers bool
if et.Kind() == reflect.Ptr {
isSliceOfPointers = true
et = et.Elem()
}
if !reflect.PointerTo(et).Implements(recordProxyType) {
return op(sliceA)
}
for _, record := range records {
ev := reflect.New(et)
if !ev.CanInterface() {
continue
}
ps, ok := ev.Interface().(RecordProxy)
if !ok {
continue
}
ps.SetProxyRecord(record)
ev = ev.Elem()
if isSliceOfPointers {
ev = ev.Addr()
}
rv.Set(reflect.Append(rv, ev))
}
return nil
}
})
})
}
func resolveRecordOneHook(collection *Collection, op func(dst any) error) (*Record, error) {
data := dbx.NullStringMap{}
if err := op(&data); err != nil {
return nil, err
}
return newRecordFromNullStringMap(collection, data)
}
func resolveRecordAllHook(collection *Collection, op func(dst any) error) ([]*Record, error) {
data := []dbx.NullStringMap{}
if err := op(&data); err != nil {
return nil, err
}
return newRecordsFromNullStringMaps(collection, data)
}
// dereference returns the underlying value v points to.
func dereference(v reflect.Value) reflect.Value {
for v.Kind() == reflect.Ptr {
if v.IsNil() {
// initialize with a new value and continue searching
v.Set(reflect.New(v.Type().Elem()))
}
v = v.Elem()
}
return v
}
func getCollectionByModelOrIdentifier(app App, collectionModelOrIdentifier any) (*Collection, error) {
switch c := collectionModelOrIdentifier.(type) {
case *Collection:
return c, nil
case Collection:
return &c, nil
case string:
return app.FindCachedCollectionByNameOrId(c)
default:
return nil, errors.New("unknown collection identifier - must be collection model, id or name")
}
}
// -------------------------------------------------------------------
// FindRecordById finds the Record model by its id.
func (app *BaseApp) FindRecordById(
collectionModelOrIdentifier any,
recordId string,
optFilters ...func(q *dbx.SelectQuery) error,
) (*Record, error) {
collection, err := getCollectionByModelOrIdentifier(app, collectionModelOrIdentifier)
if err != nil {
return nil, err
}
record := &Record{}
query := app.RecordQuery(collection).
AndWhere(dbx.HashExp{collection.Name + ".id": recordId})
// apply filter funcs (if any)
for _, filter := range optFilters {
if filter == nil {
continue
}
if err = filter(query); err != nil {
return nil, err
}
}
err = query.Limit(1).One(record)
if err != nil {
return nil, err
}
return record, nil
}
// FindRecordsByIds finds all records by the specified ids.
// If no records are found, returns an empty slice.
func (app *BaseApp) FindRecordsByIds(
collectionModelOrIdentifier any,
recordIds []string,
optFilters ...func(q *dbx.SelectQuery) error,
) ([]*Record, error) {
collection, err := getCollectionByModelOrIdentifier(app, collectionModelOrIdentifier)
if err != nil {
return nil, err
}
query := app.RecordQuery(collection).
AndWhere(dbx.In(
collection.Name+".id",
list.ToInterfaceSlice(recordIds)...,
))
for _, filter := range optFilters {
if filter == nil {
continue
}
if err = filter(query); err != nil {
return nil, err
}
}
records := make([]*Record, 0, len(recordIds))
err = query.All(&records)
if err != nil {
return nil, err
}
return records, nil
}
// FindAllRecords finds all records matching specified db expressions.
//
// Returns all collection records if no expression is provided.
//
// Returns an empty slice if no records are found.
//
// Example:
//
// // no extra expressions
// app.FindAllRecords("example")
//
// // with extra expressions
// expr1 := dbx.HashExp{"email": "test@example.com"}
// expr2 := dbx.NewExp("LOWER(username) = {:username}", dbx.Params{"username": "test"})
// app.FindAllRecords("example", expr1, expr2)
func (app *BaseApp) FindAllRecords(collectionModelOrIdentifier any, exprs ...dbx.Expression) ([]*Record, error) {
query := app.RecordQuery(collectionModelOrIdentifier)
for _, expr := range exprs {
if expr != nil { // add only the non-nil expressions
query.AndWhere(expr)
}
}
var records []*Record
if err := query.All(&records); err != nil {
return nil, err
}
return records, nil
}
// FindFirstRecordByData returns the first found record matching
// the provided key-value pair.
func (app *BaseApp) FindFirstRecordByData(collectionModelOrIdentifier any, key string, value any) (*Record, error) {
record := &Record{}
err := app.RecordQuery(collectionModelOrIdentifier).
AndWhere(dbx.HashExp{inflector.Columnify(key): value}).
Limit(1).
One(record)
if err != nil {
return nil, err
}
return record, nil
}
// FindRecordsByFilter returns limit number of records matching the
// provided string filter.
//
// NB! Use the last "params" argument to bind untrusted user variables!
//
// The filter argument is optional and can be empty string to target
// all available records.
//
// The sort argument is optional and can be empty string OR the same format
// used in the web APIs, ex. "-created,title".
//
// If the limit argument is <= 0, no limit is applied to the query and
// all matching records are returned.
//
// Returns an empty slice if no records are found.
//
// Example:
//
// app.FindRecordsByFilter(
// "posts",
// "title ~ {:title} && visible = {:visible}",
// "-created",
// 10,
// 0,
// dbx.Params{"title": "lorem ipsum", "visible": true}
// )
func (app *BaseApp) FindRecordsByFilter(
collectionModelOrIdentifier any,
filter string,
sort string,
limit int,
offset int,
params ...dbx.Params,
) ([]*Record, error) {
collection, err := getCollectionByModelOrIdentifier(app, collectionModelOrIdentifier)
if err != nil {
return nil, err
}
q := app.RecordQuery(collection)
// build a fields resolver and attach the generated conditions to the query
// ---
resolver := NewRecordFieldResolver(
app,
collection, // the base collection
nil, // no request data
true, // allow searching hidden/protected fields like "email"
)
if filter != "" {
expr, err := search.FilterData(filter).BuildExpr(resolver, params...)
if err != nil {
return nil, fmt.Errorf("invalid filter expression: %w", err)
}
q.AndWhere(expr)
}
if sort != "" {
for _, sortField := range search.ParseSortFromString(sort) {
expr, err := sortField.BuildExpr(resolver)
if err != nil {
return nil, err
}
if expr != "" {
q.AndOrderBy(expr)
}
}
}
resolver.UpdateQuery(q) // attaches any adhoc joins and aliases
// ---
if offset > 0 {
q.Offset(int64(offset))
}
if limit > 0 {
q.Limit(int64(limit))
}
records := []*Record{}
if err := q.All(&records); err != nil {
return nil, err
}
return records, nil
}
// FindFirstRecordByFilter returns the first available record matching the provided filter (if any).
//
// NB! Use the last params argument to bind untrusted user variables!
//
// Returns sql.ErrNoRows if no record is found.
//
// Example:
//
// app.FindFirstRecordByFilter("posts", "")
// app.FindFirstRecordByFilter("posts", "slug={:slug} && status='public'", dbx.Params{"slug": "test"})
func (app *BaseApp) FindFirstRecordByFilter(
collectionModelOrIdentifier any,
filter string,
params ...dbx.Params,
) (*Record, error) {
result, err := app.FindRecordsByFilter(collectionModelOrIdentifier, filter, "", 1, 0, params...)
if err != nil {
return nil, err
}
if len(result) == 0 {
return nil, sql.ErrNoRows
}
return result[0], nil
}
// CountRecords returns the total number of records in a collection.
func (app *BaseApp) CountRecords(collectionModelOrIdentifier any, exprs ...dbx.Expression) (int64, error) {
var total int64
q := app.RecordQuery(collectionModelOrIdentifier).Select("count(*)")
for _, expr := range exprs {
if expr != nil { // add only the non-nil expressions
q.AndWhere(expr)
}
}
err := q.Row(&total)
return total, err
}
// FindAuthRecordByToken finds the auth record associated with the provided JWT
// (auth, file, verifyEmail, changeEmail, passwordReset types).
//
// Optionally specify a list of validTypes to check tokens only from those types.
//
// Returns an error if the JWT is invalid, expired or not associated to an auth collection record.
func (app *BaseApp) FindAuthRecordByToken(token string, validTypes ...string) (*Record, error) {
if token == "" {
return nil, errors.New("missing token")
}
unverifiedClaims, err := security.ParseUnverifiedJWT(token)
if err != nil {
return nil, err
}
// check required claims
id, _ := unverifiedClaims[TokenClaimId].(string)
collectionId, _ := unverifiedClaims[TokenClaimCollectionId].(string)
tokenType, _ := unverifiedClaims[TokenClaimType].(string)
if id == "" || collectionId == "" || tokenType == "" {
return nil, errors.New("missing or invalid token claims")
}
// check types (if explicitly set)
if len(validTypes) > 0 && !list.ExistInSlice(tokenType, validTypes) {
return nil, fmt.Errorf("invalid token type %q, expects %q", tokenType, strings.Join(validTypes, ","))
}
record, err := app.FindRecordById(collectionId, id)
if err != nil {
return nil, err
}
if !record.Collection().IsAuth() {
return nil, errors.New("the token is not associated to an auth collection record")
}
var baseTokenKey string
switch tokenType {
case TokenTypeAuth:
baseTokenKey = record.Collection().AuthToken.Secret
case TokenTypeFile:
baseTokenKey = record.Collection().FileToken.Secret
case TokenTypeVerification:
baseTokenKey = record.Collection().VerificationToken.Secret
case TokenTypePasswordReset:
baseTokenKey = record.Collection().PasswordResetToken.Secret
case TokenTypeEmailChange:
baseTokenKey = record.Collection().EmailChangeToken.Secret
default:
return nil, errors.New("unknown token type " + tokenType)
}
secret := record.TokenKey() + baseTokenKey
// verify token signature
_, err = security.ParseJWT(token, secret)
if err != nil {
return nil, err
}
return record, nil
}
// FindAuthRecordByEmail finds the auth record associated with the provided email.
//
// Returns an error if it is not an auth collection or the record is not found.
func (app *BaseApp) FindAuthRecordByEmail(collectionModelOrIdentifier any, email string) (*Record, error) {
collection, err := getCollectionByModelOrIdentifier(app, collectionModelOrIdentifier)
if err != nil {
return nil, fmt.Errorf("failed to fetch auth collection: %w", err)
}
if !collection.IsAuth() {
return nil, fmt.Errorf("%q is not an auth collection", collection.Name)
}
record := &Record{}
err = app.RecordQuery(collection).
AndWhere(dbx.HashExp{FieldNameEmail: email}).
Limit(1).
One(record)
if err != nil {
return nil, err
}
return record, nil
}
// CanAccessRecord checks if a record is allowed to be accessed by the
// specified requestInfo and accessRule.
//
// Rule and db checks are ignored in case requestInfo.Auth is a superuser.
//
// The returned error indicate that something unexpected happened during
// the check (eg. invalid rule or db query error).
//
// The method always return false on invalid rule or db query error.
//
// Example:
//
// requestInfo, _ := e.RequestInfo()
// record, _ := app.FindRecordById("example", "RECORD_ID")
// rule := types.Pointer("@request.auth.id != '' || status = 'public'")
// // ... or use one of the record collection's rule, eg. record.Collection().ViewRule
//
// if ok, _ := app.CanAccessRecord(record, requestInfo, rule); ok { ... }
func (app *BaseApp) CanAccessRecord(record *Record, requestInfo *RequestInfo, accessRule *string) (bool, error) {
// superusers can access everything
if requestInfo.HasSuperuserAuth() {
return true, nil
}
// only superusers can access this record
if accessRule == nil {
return false, nil
}
// empty public rule, aka. everyone can access
if *accessRule == "" {
return true, nil
}
var exists bool
query := app.RecordQuery(record.Collection()).
Select("(1)").
AndWhere(dbx.HashExp{record.Collection().Name + ".id": record.Id})
// parse and apply the access rule filter
resolver := NewRecordFieldResolver(app, record.Collection(), requestInfo, true)
expr, err := search.FilterData(*accessRule).BuildExpr(resolver)
if err != nil {
return false, err
}
resolver.UpdateQuery(query)
err = query.AndWhere(expr).Limit(1).Row(&exists)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return false, err
}
return exists, nil
}