From 818857dea2a6b5a157f3ab3da67d35509d280e89 Mon Sep 17 00:00:00 2001 From: Gani Georgiev Date: Thu, 20 Apr 2023 10:44:20 +0300 Subject: [PATCH] [#2325] trigger the related record realtime events on custom record model change --- apis/realtime.go | 48 ++++++++++++++++++---- apis/realtime_test.go | 95 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 135 insertions(+), 8 deletions(-) diff --git a/apis/realtime.go b/apis/realtime.go index 74c9ae8c..556b14a8 100644 --- a/apis/realtime.go +++ b/apis/realtime.go @@ -7,6 +7,7 @@ import ( "fmt" "log" "net/http" + "strings" "time" "github.com/labstack/echo/v5" @@ -215,7 +216,9 @@ func (api *realtimeApi) setSubscriptions(c echo.Context) error { func (api *realtimeApi) updateClientsAuthModel(contextKey string, newModel models.Model) error { for _, client := range api.app.SubscriptionsBroker().Clients() { clientModel, _ := client.Get(contextKey).(models.Model) - if clientModel != nil && clientModel.GetId() == newModel.GetId() { + if clientModel != nil && + clientModel.TableName() == newModel.TableName() && + clientModel.GetId() == newModel.GetId() { client.Set(contextKey, newModel) } } @@ -227,7 +230,9 @@ func (api *realtimeApi) updateClientsAuthModel(contextKey string, newModel model func (api *realtimeApi) unregisterClientsByAuthModel(contextKey string, model models.Model) error { for _, client := range api.app.SubscriptionsBroker().Clients() { clientModel, _ := client.Get(contextKey).(models.Model) - if clientModel != nil && clientModel.GetId() == model.GetId() { + if clientModel != nil && + clientModel.TableName() == model.TableName() && + clientModel.GetId() == model.GetId() { api.app.SubscriptionsBroker().Unregister(client.Id()) } } @@ -238,7 +243,7 @@ func (api *realtimeApi) unregisterClientsByAuthModel(contextKey string, model mo func (api *realtimeApi) bindEvents() { // update the clients that has admin or auth record association api.app.OnModelAfterUpdate().PreAdd(func(e *core.ModelEvent) error { - if record, ok := e.Model.(*models.Record); ok && record != nil && record.Collection().IsAuth() { + if record := api.resolveRecord(e.Model); record != nil && record.Collection().IsAuth() { return api.updateClientsAuthModel(ContextAuthRecordKey, record) } @@ -251,8 +256,8 @@ func (api *realtimeApi) bindEvents() { // remove the client(s) associated to the deleted admin or auth record api.app.OnModelAfterDelete().PreAdd(func(e *core.ModelEvent) error { - if record, ok := e.Model.(*models.Record); ok && record != nil && record.Collection().IsAuth() { - return api.unregisterClientsByAuthModel(ContextAuthRecordKey, record) + if collection := api.resolveRecordCollection(e.Model); collection != nil && collection.IsAuth() { + return api.unregisterClientsByAuthModel(ContextAuthRecordKey, e.Model) } if admin, ok := e.Model.(*models.Admin); ok && admin != nil { @@ -263,7 +268,7 @@ func (api *realtimeApi) bindEvents() { }) api.app.OnModelAfterCreate().PreAdd(func(e *core.ModelEvent) error { - if record, ok := e.Model.(*models.Record); ok { + if record := api.resolveRecord(e.Model); record != nil { if err := api.broadcastRecord("create", record); err != nil && api.app.IsDebug() { log.Println(err) } @@ -272,7 +277,7 @@ func (api *realtimeApi) bindEvents() { }) api.app.OnModelAfterUpdate().PreAdd(func(e *core.ModelEvent) error { - if record, ok := e.Model.(*models.Record); ok { + if record := api.resolveRecord(e.Model); record != nil { if err := api.broadcastRecord("update", record); err != nil && api.app.IsDebug() { log.Println(err) } @@ -281,7 +286,7 @@ func (api *realtimeApi) bindEvents() { }) api.app.OnModelBeforeDelete().Add(func(e *core.ModelEvent) error { - if record, ok := e.Model.(*models.Record); ok { + if record := api.resolveRecord(e.Model); record != nil { if err := api.broadcastRecord("delete", record); err != nil && api.app.IsDebug() { log.Println(err) } @@ -290,6 +295,33 @@ func (api *realtimeApi) bindEvents() { }) } +// resolveRecord converts *if possible* the provided model interface to a Record. +// This is usually helpful if the provided model is a custom Record model struct. +func (api *realtimeApi) resolveRecord(model models.Model) (record *models.Record) { + record, _ = model.(*models.Record) + + // check if it is custom Record model struct (ignore "private" tables) + if record == nil && !strings.HasPrefix(model.TableName(), "_") { + record, _ = api.app.Dao().FindRecordById(model.TableName(), model.GetId()) + } + + return record +} + +// resolveRecordCollection extracts *if possible* the Collection model from the provided model interface. +// This is usually helpful if the provided model is a custom Record model struct. +func (api *realtimeApi) resolveRecordCollection(model models.Model) (collection *models.Collection) { + if record, ok := model.(*models.Record); ok { + collection = record.Collection() + } else if !strings.HasPrefix(model.TableName(), "_") { + // check if it is custom Record model struct (ignore "private" tables) + collection, _ = api.app.Dao().FindCollectionByNameOrId(model.TableName()) + } + + return collection +} + +// canAccessRecord checks if the subscription client has access to the specified record model. func (api *realtimeApi) canAccessRecord(client subscriptions.Client, record *models.Record, accessRule *string) bool { admin, _ := client.Get(ContextAdminKey).(*models.Admin) if admin != nil { diff --git a/apis/realtime_test.go b/apis/realtime_test.go index 9c65bfb9..b715dcb5 100644 --- a/apis/realtime_test.go +++ b/apis/realtime_test.go @@ -7,8 +7,10 @@ import ( "testing" "github.com/labstack/echo/v5" + "github.com/pocketbase/dbx" "github.com/pocketbase/pocketbase/apis" "github.com/pocketbase/pocketbase/core" + "github.com/pocketbase/pocketbase/daos" "github.com/pocketbase/pocketbase/models" "github.com/pocketbase/pocketbase/tests" "github.com/pocketbase/pocketbase/tools/hook" @@ -353,3 +355,96 @@ func TestRealtimeAdminUpdateEvent(t *testing.T) { t.Fatalf("Expected authRecord with email %q, got %q", admin2.Email, clientAdmin.Email) } } + +// Custom auth record model struct +// ------------------------------------------------------------------- +var _ models.Model = (*CustomUser)(nil) + +type CustomUser struct { + models.BaseModel + + Email string `db:"email" json:"email"` +} + +func (m *CustomUser) TableName() string { + return "users" // the name of your collection +} + +func findCustomUserByEmail(dao *daos.Dao, email string) (*CustomUser, error) { + model := &CustomUser{} + + err := dao.ModelQuery(model). + AndWhere(dbx.HashExp{"email": email}). + Limit(1). + One(model) + + if err != nil { + return nil, err + } + + return model, nil +} + +func TestRealtimeCustomAuthModelDeleteEvent(t *testing.T) { + testApp, _ := tests.NewTestApp() + defer testApp.Cleanup() + + apis.InitApi(testApp) + + authRecord, err := testApp.Dao().FindFirstRecordByData("users", "email", "test@example.com") + if err != nil { + t.Fatal(err) + } + + client := subscriptions.NewDefaultClient() + client.Set(apis.ContextAuthRecordKey, authRecord) + testApp.SubscriptionsBroker().Register(client) + + // refetch the authRecord as CustomUser + customUser, err := findCustomUserByEmail(testApp.Dao(), "test@example.com") + if err != nil { + t.Fatal(err) + } + + // delete the custom user (should unset the client auth record) + if err := testApp.Dao().Delete(customUser); err != nil { + t.Fatal(err) + } + + if len(testApp.SubscriptionsBroker().Clients()) != 0 { + t.Fatalf("Expected no subscription clients, found %d", len(testApp.SubscriptionsBroker().Clients())) + } +} + +func TestRealtimeCustomAuthModelUpdateEvent(t *testing.T) { + testApp, _ := tests.NewTestApp() + defer testApp.Cleanup() + + apis.InitApi(testApp) + + authRecord, err := testApp.Dao().FindFirstRecordByData("users", "email", "test@example.com") + if err != nil { + t.Fatal(err) + } + + client := subscriptions.NewDefaultClient() + client.Set(apis.ContextAuthRecordKey, authRecord) + testApp.SubscriptionsBroker().Register(client) + + // refetch the authRecord as CustomUser + customUser, err := findCustomUserByEmail(testApp.Dao(), "test@example.com") + if err != nil { + t.Fatal(err) + } + + // change its email + customUser.Email = "new@example.com" + if err := testApp.Dao().Save(customUser); err != nil { + t.Fatal(err) + } + + clientAuthRecord, _ := client.Get(apis.ContextAuthRecordKey).(*models.Record) + if clientAuthRecord.Email() != customUser.Email { + t.Fatalf("Expected authRecord with email %q, got %q", customUser.Email, clientAuthRecord.Email()) + } +}