2022-07-06 23:19:05 +02:00
package daos
import (
"errors"
"fmt"
2022-12-11 21:25:31 +02:00
"math"
2022-07-06 23:19:05 +02:00
"strings"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/models"
"github.com/pocketbase/pocketbase/models/schema"
2022-10-30 10:28:14 +02:00
"github.com/pocketbase/pocketbase/tools/inflector"
2022-07-06 23:19:05 +02:00
"github.com/pocketbase/pocketbase/tools/list"
2022-07-18 15:26:37 +02:00
"github.com/pocketbase/pocketbase/tools/security"
2022-07-06 23:19:05 +02:00
"github.com/pocketbase/pocketbase/tools/types"
2022-10-30 10:28:14 +02:00
"github.com/spf13/cast"
2022-07-06 23:19:05 +02:00
)
// RecordQuery returns a new Record select query.
//
// The query selects every column of the collection table (the quoted
// table name followed by `.*`) FROM the table named after the collection.
func (dao *Dao) RecordQuery(collection *models.Collection) *dbx.SelectQuery {
	table := collection.Name
	allColumns := dao.DB().QuoteSimpleColumnName(table) + ".*"

	return dao.DB().Select(allColumns).From(table)
}
// FindRecordById finds the Record model by its id.
func ( dao * Dao ) FindRecordById (
2022-10-30 10:28:14 +02:00
collectionNameOrId string ,
2022-07-06 23:19:05 +02:00
recordId string ,
2022-10-30 10:28:14 +02:00
optFilters ... func ( q * dbx . SelectQuery ) error ,
2022-07-06 23:19:05 +02:00
) ( * models . Record , error ) {
2022-10-30 10:28:14 +02:00
collection , err := dao . FindCollectionByNameOrId ( collectionNameOrId )
if err != nil {
return nil , err
}
2022-07-06 23:19:05 +02:00
tableName := collection . Name
query := dao . RecordQuery ( collection ) .
AndWhere ( dbx . HashExp { tableName + ".id" : recordId } )
2022-10-30 10:28:14 +02:00
for _ , filter := range optFilters {
if filter == nil {
continue
}
2022-07-06 23:19:05 +02:00
if err := filter ( query ) ; err != nil {
return nil , err
}
}
row := dbx . NullStringMap { }
if err := query . Limit ( 1 ) . One ( row ) ; err != nil {
return nil , err
}
return models . NewRecordFromNullStringMap ( collection , row ) , nil
}
// FindRecordsByIds finds all Record models by the provided ids.
// If no records are found, returns an empty slice.
func ( dao * Dao ) FindRecordsByIds (
2022-10-30 10:28:14 +02:00
collectionNameOrId string ,
2022-07-06 23:19:05 +02:00
recordIds [ ] string ,
2022-10-30 10:28:14 +02:00
optFilters ... func ( q * dbx . SelectQuery ) error ,
2022-07-06 23:19:05 +02:00
) ( [ ] * models . Record , error ) {
2022-10-30 10:28:14 +02:00
collection , err := dao . FindCollectionByNameOrId ( collectionNameOrId )
if err != nil {
return nil , err
}
2022-07-06 23:19:05 +02:00
query := dao . RecordQuery ( collection ) .
2022-10-30 10:28:14 +02:00
AndWhere ( dbx . In (
collection . Name + ".id" ,
list . ToInterfaceSlice ( recordIds ) ... ,
) )
for _ , filter := range optFilters {
if filter == nil {
continue
}
2022-07-06 23:19:05 +02:00
if err := filter ( query ) ; err != nil {
return nil , err
}
}
2023-01-07 22:25:56 +02:00
rows := make ( [ ] dbx . NullStringMap , 0 , len ( recordIds ) )
2022-07-06 23:19:05 +02:00
if err := query . All ( & rows ) ; err != nil {
return nil , err
}
return models . NewRecordsFromNullStringMaps ( collection , rows ) , nil
}
2022-10-30 10:28:14 +02:00
// FindRecordsByExpr finds all records by the specified db expression.
//
// Returns all collection records if no expressions are provided.
//
// Returns an empty slice if no records are found.
2022-07-06 23:19:05 +02:00
//
// Example:
2022-10-30 10:28:14 +02:00
// expr1 := dbx.HashExp{"email": "test@example.com"}
2022-11-13 00:38:18 +02:00
// expr2 := dbx.NewExp("LOWER(username) = {:username}", dbx.Params{"username": "test"})
2022-10-30 10:28:14 +02:00
// dao.FindRecordsByExpr("example", expr1, expr2)
func ( dao * Dao ) FindRecordsByExpr ( collectionNameOrId string , exprs ... dbx . Expression ) ( [ ] * models . Record , error ) {
collection , err := dao . FindCollectionByNameOrId ( collectionNameOrId )
if err != nil {
return nil , err
2022-07-06 23:19:05 +02:00
}
2022-10-30 10:28:14 +02:00
query := dao . RecordQuery ( collection )
2022-07-06 23:19:05 +02:00
2022-10-30 10:28:14 +02:00
// add only the non-nil expressions
for _ , expr := range exprs {
if expr != nil {
query . AndWhere ( expr )
}
}
2022-07-06 23:19:05 +02:00
2022-10-30 10:28:14 +02:00
rows := [ ] dbx . NullStringMap { }
if err := query . All ( & rows ) ; err != nil {
2022-07-06 23:19:05 +02:00
return nil , err
}
return models . NewRecordsFromNullStringMaps ( collection , rows ) , nil
}
// FindFirstRecordByData returns the first found record matching
// the provided key-value pair.
2022-10-30 10:28:14 +02:00
func ( dao * Dao ) FindFirstRecordByData ( collectionNameOrId string , key string , value any ) ( * models . Record , error ) {
collection , err := dao . FindCollectionByNameOrId ( collectionNameOrId )
if err != nil {
return nil , err
}
2022-07-06 23:19:05 +02:00
row := dbx . NullStringMap { }
2022-10-30 10:28:14 +02:00
err = dao . RecordQuery ( collection ) .
AndWhere ( dbx . HashExp { inflector . Columnify ( key ) : value } ) .
2022-07-06 23:19:05 +02:00
Limit ( 1 ) .
One ( row )
if err != nil {
return nil , err
}
return models . NewRecordFromNullStringMap ( collection , row ) , nil
}
// IsRecordValueUnique checks if the provided key-value pair is a unique Record value.
//
2022-10-30 10:28:14 +02:00
// For correctness, if the collection is "auth" and the key is "username",
// the unique check will be case insensitive.
//
2022-07-06 23:19:05 +02:00
// NB! Array values (eg. from multiple select fields) are matched
// as a serialized json strings (eg. `["a","b"]`), so the value uniqueness
// depends on the elements order. Or in other words the following values
// are considered different: `[]string{"a","b"}` and `[]string{"b","a"}`
func ( dao * Dao ) IsRecordValueUnique (
2022-10-30 10:28:14 +02:00
collectionNameOrId string ,
2022-07-06 23:19:05 +02:00
key string ,
value any ,
2022-10-30 10:28:14 +02:00
excludeIds ... string ,
2022-07-06 23:19:05 +02:00
) bool {
2022-10-30 10:28:14 +02:00
collection , err := dao . FindCollectionByNameOrId ( collectionNameOrId )
if err != nil {
return false
}
2022-07-06 23:19:05 +02:00
2022-10-30 10:28:14 +02:00
var expr dbx . Expression
if collection . IsAuth ( ) && key == schema . FieldNameUsername {
expr = dbx . NewExp ( "LOWER([[" + schema . FieldNameUsername + "]])={:username}" , dbx . Params {
"username" : strings . ToLower ( cast . ToString ( value ) ) ,
} )
} else {
var normalizedVal any
switch val := value . ( type ) {
case [ ] string :
normalizedVal = append ( types . JsonArray { } , list . ToInterfaceSlice ( val ) ... )
case [ ] any :
normalizedVal = append ( types . JsonArray { } , val ... )
default :
normalizedVal = val
}
expr = dbx . HashExp { inflector . Columnify ( key ) : normalizedVal }
2022-07-06 23:19:05 +02:00
}
2022-10-30 10:28:14 +02:00
query := dao . RecordQuery ( collection ) .
2022-07-06 23:19:05 +02:00
Select ( "count(*)" ) .
2022-10-30 10:28:14 +02:00
AndWhere ( expr ) .
Limit ( 1 )
2023-01-07 22:25:56 +02:00
if uniqueExcludeIds := list . NonzeroUniques ( excludeIds ) ; len ( uniqueExcludeIds ) > 0 {
2022-10-30 10:28:14 +02:00
query . AndWhere ( dbx . NotIn ( collection . Name + ".id" , list . ToInterfaceSlice ( uniqueExcludeIds ) ... ) )
}
2022-07-06 23:19:05 +02:00
2022-10-30 10:28:14 +02:00
var exists bool
return query . Row ( & exists ) == nil && ! exists
2022-07-06 23:19:05 +02:00
}
2022-10-30 10:28:14 +02:00
// FindAuthRecordByToken finds the auth record associated with the provided JWT token.
//
// Returns an error if the JWT token is invalid, expired or not associated to an auth collection record.
func ( dao * Dao ) FindAuthRecordByToken ( token string , baseTokenKey string ) ( * models . Record , error ) {
unverifiedClaims , err := security . ParseUnverifiedJWT ( token )
if err != nil {
return nil , err
}
// check required claims
id , _ := unverifiedClaims [ "id" ] . ( string )
collectionId , _ := unverifiedClaims [ "collectionId" ] . ( string )
if id == "" || collectionId == "" {
2022-12-18 14:06:48 +02:00
return nil , errors . New ( "missing or invalid token claims" )
2022-07-19 12:09:54 +02:00
}
2022-10-30 10:28:14 +02:00
record , err := dao . FindRecordById ( collectionId , id )
2022-07-06 23:19:05 +02:00
if err != nil {
return nil , err
}
2022-10-30 10:28:14 +02:00
if ! record . Collection ( ) . IsAuth ( ) {
return nil , errors . New ( "The token is not associated to an auth collection record." )
}
2022-07-06 23:19:05 +02:00
2022-10-30 10:28:14 +02:00
verificationKey := record . TokenKey ( ) + baseTokenKey
2022-07-06 23:19:05 +02:00
2022-10-30 10:28:14 +02:00
// verify token signature
if _ , err := security . ParseJWT ( token , verificationKey ) ; err != nil {
return nil , err
}
2022-07-06 23:19:05 +02:00
2022-10-30 10:28:14 +02:00
return record , nil
}
// FindAuthRecordByEmail finds the auth record associated with the provided email.
//
// Returns an error if it is not an auth collection or the record is not found.
func ( dao * Dao ) FindAuthRecordByEmail ( collectionNameOrId string , email string ) ( * models . Record , error ) {
collection , err := dao . FindCollectionByNameOrId ( collectionNameOrId )
2022-12-18 14:06:48 +02:00
if err != nil {
return nil , fmt . Errorf ( "failed to fetch auth collection %q (%w)" , collectionNameOrId , err )
}
if ! collection . IsAuth ( ) {
return nil , fmt . Errorf ( "%q is not an auth collection" , collectionNameOrId )
2022-10-30 10:28:14 +02:00
}
row := dbx . NullStringMap { }
err = dao . RecordQuery ( collection ) .
AndWhere ( dbx . HashExp { schema . FieldNameEmail : email } ) .
Limit ( 1 ) .
One ( row )
if err != nil {
return nil , err
}
return models . NewRecordFromNullStringMap ( collection , row ) , nil
}
// FindAuthRecordByUsername finds the auth record associated with the provided username (case insensitive).
//
// Returns an error if it is not an auth collection or the record is not found.
func ( dao * Dao ) FindAuthRecordByUsername ( collectionNameOrId string , username string ) ( * models . Record , error ) {
collection , err := dao . FindCollectionByNameOrId ( collectionNameOrId )
2022-12-18 14:06:48 +02:00
if err != nil {
return nil , fmt . Errorf ( "failed to fetch auth collection %q (%w)" , collectionNameOrId , err )
}
if ! collection . IsAuth ( ) {
return nil , fmt . Errorf ( "%q is not an auth collection" , collectionNameOrId )
2022-10-30 10:28:14 +02:00
}
row := dbx . NullStringMap { }
2022-07-06 23:19:05 +02:00
2022-10-30 10:28:14 +02:00
err = dao . RecordQuery ( collection ) .
AndWhere ( dbx . NewExp ( "LOWER([[" + schema . FieldNameUsername + "]])={:username}" , dbx . Params {
"username" : strings . ToLower ( username ) ,
} ) ) .
Limit ( 1 ) .
One ( row )
if err != nil {
return nil , err
2022-07-06 23:19:05 +02:00
}
2022-10-30 10:28:14 +02:00
return models . NewRecordFromNullStringMap ( collection , row ) , nil
}
// SuggestUniqueAuthRecordUsername checks if the provided username is unique
// and return a new "unique" username with appended random numeric part
// (eg. "existingName" -> "existingName583").
//
// The same username will be returned if the provided string is already unique.
func ( dao * Dao ) SuggestUniqueAuthRecordUsername (
collectionNameOrId string ,
baseUsername string ,
excludeIds ... string ,
) string {
username := baseUsername
for i := 0 ; i < 10 ; i ++ { // max 10 attempts
isUnique := dao . IsRecordValueUnique (
collectionNameOrId ,
schema . FieldNameUsername ,
username ,
excludeIds ... ,
)
if isUnique {
break // already unique
}
username = baseUsername + security . RandomStringWithAlphabet ( 3 + i , "123456789" )
}
return username
2022-07-06 23:19:05 +02:00
}
// SaveRecord upserts the provided Record model.
func ( dao * Dao ) SaveRecord ( record * models . Record ) error {
2022-10-30 10:28:14 +02:00
if record . Collection ( ) . IsAuth ( ) {
if record . Username ( ) == "" {
2022-12-06 12:26:29 +02:00
return errors . New ( "unable to save auth record without username" )
2022-10-30 10:28:14 +02:00
}
// Cross-check that the auth record id is unique for all auth collections.
// This is to make sure that the filter `@request.auth.id` always returns a unique id.
authCollections , err := dao . FindCollectionsByType ( models . CollectionTypeAuth )
if err != nil {
2022-12-06 12:26:29 +02:00
return fmt . Errorf ( "unable to fetch the auth collections for cross-id unique check: %w" , err )
2022-10-30 10:28:14 +02:00
}
for _ , collection := range authCollections {
if record . Collection ( ) . Id == collection . Id {
continue // skip current collection (sqlite will do the check for us)
}
isUnique := dao . IsRecordValueUnique ( collection . Id , schema . FieldNameId , record . Id )
if ! isUnique {
2022-12-06 12:26:29 +02:00
return errors . New ( "the auth record ID must be unique across all auth collections" )
2022-10-30 10:28:14 +02:00
}
}
}
2022-07-06 23:19:05 +02:00
return dao . Save ( record )
}
// DeleteRecord deletes the provided Record model.
//
// This method will also cascade the delete operation to all linked
2022-12-13 09:08:54 +02:00
// relational records (delete or unset, depending on the rel settings).
2022-07-06 23:19:05 +02:00
//
// The delete operation may fail if the record is part of a required
2022-12-13 09:08:54 +02:00
// reference in another record (aka. cannot be deleted or unset).
2022-07-06 23:19:05 +02:00
func ( dao * Dao ) DeleteRecord ( record * models . Record ) error {
2022-12-08 10:40:42 +02:00
// fetch rel references (if any)
//
// note: the select is outside of the transaction to minimize
// SQLITE_BUSY errors when mixing read&write in a single transaction
refs , err := dao . FindCollectionReferences ( record . Collection ( ) )
if err != nil {
return err
2022-07-06 23:19:05 +02:00
}
2022-12-08 10:40:42 +02:00
return dao . RunInTransaction ( func ( txDao * Dao ) error {
2022-12-12 19:19:31 +02:00
// manually trigger delete on any linked external auth to ensure
2022-12-15 16:42:35 +02:00
// that the `OnModel*` hooks are triggered
2022-12-12 19:19:31 +02:00
if record . Collection ( ) . IsAuth ( ) {
2022-12-15 16:42:35 +02:00
// note: the select is outside of the transaction to minimize
// SQLITE_BUSY errors when mixing read&write in a single transaction
2022-12-12 19:19:31 +02:00
externalAuths , err := dao . FindAllExternalAuthsByRecord ( record )
if err != nil {
return err
}
for _ , auth := range externalAuths {
if err := txDao . DeleteExternalAuth ( auth ) ; err != nil {
return err
}
}
}
// delete the record before the relation references to ensure that there
// will be no "A<->B" relations to prevent deadlock when calling DeleteRecord recursively
2022-12-09 01:49:17 +02:00
if err := txDao . Delete ( record ) ; err != nil {
return err
}
2022-12-12 19:19:31 +02:00
return txDao . cascadeRecordDelete ( record , refs )
} )
}
// cascadeRecordDelete triggers cascade deletion for the provided references
// and split the work to a batched set of go routines.
//
// NB! This method is expected to be called inside a transaction.
func ( dao * Dao ) cascadeRecordDelete ( mainRecord * models . Record , refs map [ * models . Collection ] [ ] * schema . SchemaField ) error {
uniqueJsonEachAlias := "__je__" + security . PseudorandomString ( 4 )
for refCollection , fields := range refs {
for _ , field := range fields {
recordTableName := inflector . Columnify ( refCollection . Name )
prefixedFieldName := recordTableName + "." + inflector . Columnify ( field . Name )
query := dao . RecordQuery ( refCollection ) .
Distinct ( true ) .
LeftJoin ( fmt . Sprintf (
// note: the case is used to normalize value access for single and multiple relations.
` json_each(CASE WHEN json_valid([[%s]]) THEN [[%s]] ELSE json_array([[%s]]) END) as {{ % s }} ` ,
prefixedFieldName , prefixedFieldName , prefixedFieldName , uniqueJsonEachAlias ,
) , nil ) .
AndWhere ( dbx . Not ( dbx . HashExp { recordTableName + ".id" : mainRecord . Id } ) ) .
AndWhere ( dbx . HashExp { uniqueJsonEachAlias + ".value" : mainRecord . Id } )
// trigger cascade for each 1000 rel items until there is none
batchSize := 1000
for {
2022-12-13 09:07:50 +02:00
rows := make ( [ ] dbx . NullStringMap , 0 , batchSize )
2022-12-12 19:19:31 +02:00
if err := query . Limit ( int64 ( batchSize ) ) . All ( & rows ) ; err != nil {
2022-07-06 23:19:05 +02:00
return err
}
2022-12-11 21:25:31 +02:00
total := len ( rows )
if total == 0 {
2022-12-12 19:19:31 +02:00
break
2022-12-11 21:25:31 +02:00
}
2022-07-06 23:19:05 +02:00
2022-12-13 09:07:50 +02:00
perWorker := 50
workers := int ( math . Ceil ( float64 ( total ) / float64 ( perWorker ) ) )
2022-12-11 21:25:31 +02:00
2022-12-12 19:19:31 +02:00
batchErr := func ( ) error {
ch := make ( chan error )
defer close ( ch )
2022-12-12 19:21:41 +02:00
for i := 0 ; i < workers ; i ++ {
2022-12-12 19:19:31 +02:00
var chunks [ ] dbx . NullStringMap
2022-12-13 09:07:50 +02:00
if len ( rows ) <= perWorker {
2022-12-12 19:21:41 +02:00
chunks = rows
2022-12-12 19:19:31 +02:00
rows = nil
} else {
2022-12-13 09:07:50 +02:00
chunks = rows [ : perWorker ]
rows = rows [ perWorker : ]
2022-12-12 19:19:31 +02:00
}
go func ( ) {
refRecords := models . NewRecordsFromNullStringMaps ( refCollection , chunks )
ch <- dao . deleteRefRecords ( mainRecord , refRecords , field )
} ( )
2022-07-06 23:19:05 +02:00
}
2022-12-12 19:21:41 +02:00
for i := 0 ; i < workers ; i ++ {
2022-12-12 19:19:31 +02:00
if err := <- ch ; err != nil {
return err
}
2022-07-06 23:19:05 +02:00
}
2022-12-12 19:19:31 +02:00
return nil
} ( )
if batchErr != nil {
return batchErr
2022-07-06 23:19:05 +02:00
}
2022-12-12 19:19:31 +02:00
if total < batchSize {
break // no more items
}
2022-10-30 10:28:14 +02:00
}
}
2022-12-12 19:19:31 +02:00
}
2022-10-30 10:28:14 +02:00
2022-12-12 19:19:31 +02:00
return nil
2022-07-06 23:19:05 +02:00
}
2022-12-12 19:19:31 +02:00
// deleteRefRecords checks if related records has to be deleted (if `CascadeDelete` is set)
// OR
// just unset the record id from any relation field values (if they are not required).
//
// NB! This method is expected to be called inside a transaction.
2022-12-11 21:25:31 +02:00
func ( dao * Dao ) deleteRefRecords ( mainRecord * models . Record , refRecords [ ] * models . Record , field * schema . SchemaField ) error {
options , _ := field . Options . ( * schema . RelationOptions )
if options == nil {
return errors . New ( "relation field options are not initialized" )
}
for _ , refRecord := range refRecords {
ids := refRecord . GetStringSlice ( field . Name )
// unset the record id
for i := len ( ids ) - 1 ; i >= 0 ; i -- {
if ids [ i ] == mainRecord . Id {
ids = append ( ids [ : i ] , ids [ i + 1 : ] ... )
break
}
}
// cascade delete the reference
// (only if there are no other active references in case of multiple select)
if options . CascadeDelete && len ( ids ) == 0 {
if err := dao . DeleteRecord ( refRecord ) ; err != nil {
return err
}
// no further actions are needed (the reference is deleted)
continue
}
if field . Required && len ( ids ) == 0 {
2022-12-18 14:06:48 +02:00
return fmt . Errorf ( "the record cannot be deleted because it is part of a required reference in record %s (%s collection)" , refRecord . Id , refRecord . Collection ( ) . Name )
2022-12-11 21:25:31 +02:00
}
// save the reference changes
refRecord . Set ( field . Name , field . PrepareValue ( ids ) )
if err := dao . SaveRecord ( refRecord ) ; err != nil {
return err
}
}
return nil
}