diff --git a/apis/realtime_test.go b/apis/realtime_test.go index abe854fb..ed6627ba 100644 --- a/apis/realtime_test.go +++ b/apis/realtime_test.go @@ -2,10 +2,13 @@ package apis_test import ( "context" + "encoding/json" "errors" "fmt" "net/http" + "slices" "strings" + "sync" "testing" "time" @@ -14,6 +17,7 @@ import ( "github.com/pocketbase/pocketbase/core" "github.com/pocketbase/pocketbase/tests" "github.com/pocketbase/pocketbase/tools/subscriptions" + "github.com/pocketbase/pocketbase/tools/types" ) func TestRealtimeConnect(t *testing.T) { @@ -632,3 +636,245 @@ func TestRealtimeCustomAuthModelUpdateEvent(t *testing.T) { t.Fatalf("Expected authRecord with email %q, got %q", customUser.Email, clientAuthRecord.Email()) } } + +// ------------------------------------------------------------------- + +var _ core.Model = (*CustomModelResolve)(nil) + +type CustomModelResolve struct { + core.BaseModel + tableName string + + Created string `db:"created"` +} + +func (m *CustomModelResolve) TableName() string { + return m.tableName +} + +func TestRealtimeRecordResove(t *testing.T) { + t.Parallel() + + const testCollectionName = "realtime_test_collection" + + testRecordId := core.GenerateDefaultRandomId() + + client0 := subscriptions.NewDefaultClient() + client0.Subscribe(testCollectionName + "/*") + client0.Discard() + // --- + client1 := subscriptions.NewDefaultClient() + client1.Subscribe(testCollectionName + "/*") + // --- + client2 := subscriptions.NewDefaultClient() + client2.Subscribe(testCollectionName + "/" + testRecordId) + // --- + client3 := subscriptions.NewDefaultClient() + client3.Subscribe("demo1/*") + + scenarios := []struct { + name string + op func(testApp core.App) error + expected map[string][]string // clientId -> [events] + }{ + { + "core.Record", + func(testApp core.App) error { + c, err := testApp.FindCollectionByNameOrId(testCollectionName) + if err != nil { + return err + } + + r := core.NewRecord(c) + r.Id = testRecordId + + // create + err = testApp.Save(r) + if err != nil { + return err + } + + // update + err = testApp.Save(r) + if err != nil { + return err + } + + // delete + err = testApp.Delete(r) + if err != nil { + return err + } + + return nil + }, + map[string][]string{ + client1.Id(): {"create", "update", "delete"}, + client2.Id(): {"create", "update", "delete"}, + }, + }, + { + "core.RecordProxy", + func(testApp core.App) error { + c, err := testApp.FindCollectionByNameOrId(testCollectionName) + if err != nil { + return err + } + + r := core.NewRecord(c) + + proxy := &struct { + core.BaseRecordProxy + }{} + proxy.SetProxyRecord(r) + proxy.Id = testRecordId + + // create + err = testApp.Save(proxy) + if err != nil { + return err + } + + // update + err = testApp.Save(proxy) + if err != nil { + return err + } + + // delete + err = testApp.Delete(proxy) + if err != nil { + return err + } + + return nil + }, + map[string][]string{ + client1.Id(): {"create", "update", "delete"}, + client2.Id(): {"create", "update", "delete"}, + }, + }, + { + "custom model struct", + func(testApp core.App) error { + m := &CustomModelResolve{tableName: testCollectionName} + m.Id = testRecordId + + // create + err := testApp.Save(m) + if err != nil { + return err + } + + // update + m.Created = "123" + err = testApp.Save(m) + if err != nil { + return err + } + + // delete + err = testApp.Delete(m) + if err != nil { + return err + } + + return nil + }, + map[string][]string{ + client1.Id(): {"create", "update", "delete"}, + client2.Id(): {"create", "update", "delete"}, + }, + }, + } + + for _, s := range scenarios { + t.Run(s.name, func(t *testing.T) { + testApp, _ := tests.NewTestApp() + defer testApp.Cleanup() + + // init realtime handlers + apis.NewRouter(testApp) + + // create new test collection with public read access + testCollection := core.NewBaseCollection(testCollectionName) + testCollection.Fields.Add(&core.AutodateField{Name: "created", OnCreate: true, OnUpdate: true}) + testCollection.ListRule = types.Pointer("") + testCollection.ViewRule = types.Pointer("") + err := testApp.Save(testCollection) + if err != nil { + t.Fatal(err) + } + + testApp.SubscriptionsBroker().Register(client0) + testApp.SubscriptionsBroker().Register(client1) + testApp.SubscriptionsBroker().Register(client2) + testApp.SubscriptionsBroker().Register(client3) + + var wg sync.WaitGroup + + var notifications = map[string][]string{} + + var mu sync.Mutex + notify := func(clientId string, eventData []byte) { + data := struct{ Action string }{} + _ = json.Unmarshal(eventData, &data) + + mu.Lock() + defer mu.Unlock() + + if notifications[clientId] == nil { + notifications[clientId] = []string{} + } + notifications[clientId] = append(notifications[clientId], data.Action) + } + + wg.Add(1) + go func() { + defer wg.Done() + + timeout := time.After(250 * time.Millisecond) + + for { + select { + case e, ok := <-client0.Channel(): + if ok { + notify(client0.Id(), e.Data) + } + case e, ok := <-client1.Channel(): + if ok { + notify(client1.Id(), e.Data) + } + case e, ok := <-client2.Channel(): + if ok { + notify(client2.Id(), e.Data) + } + case e, ok := <-client3.Channel(): + if ok { + notify(client3.Id(), e.Data) + } + case <-timeout: + return + } + } + }() + + err = s.op(testApp) + if err != nil { + t.Fatal(err) + } + + wg.Wait() + + if len(s.expected) != len(notifications) { + t.Fatalf("Expected %d notified clients, got %d:\n%v", len(s.expected), len(notifications), notifications) + } + + for id, events := range s.expected { + if !slices.Equal(notifications[id], events) { + t.Fatalf("[%s] Expected %d events, got %d\n%v\nvs\n%v", id, len(events), len(notifications[id]), s.expected, notifications) + } + } + }) + } +} diff --git a/core/db.go b/core/db.go index 455a70c0..e799d832 100644 --- a/core/db.go +++ b/core/db.go @@ -116,8 +116,7 @@ func (app *BaseApp) delete(ctx context.Context, model Model, isForAuxDB bool) er deleteErr := app.OnModelDelete().Trigger(event, func(e *ModelEvent) error { pk := cast.ToString(e.Model.LastSavedPK()) - - if cast.ToString(pk) == "" { + if pk == "" { return errors.New("the model can be deleted only if it is existing and has a non-empty primary key") }