1
0
mirror of https://github.com/pocketbase/pocketbase.git synced 2025-01-07 17:06:20 +02:00

[#2325] trigger the related record realtime events on custom record model change

This commit is contained in:
Gani Georgiev 2023-04-20 10:44:20 +03:00
parent fdf4f6d3bd
commit 818857dea2
2 changed files with 135 additions and 8 deletions

View File

@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"log" "log"
"net/http" "net/http"
"strings"
"time" "time"
"github.com/labstack/echo/v5" "github.com/labstack/echo/v5"
@ -215,7 +216,9 @@ func (api *realtimeApi) setSubscriptions(c echo.Context) error {
func (api *realtimeApi) updateClientsAuthModel(contextKey string, newModel models.Model) error { func (api *realtimeApi) updateClientsAuthModel(contextKey string, newModel models.Model) error {
for _, client := range api.app.SubscriptionsBroker().Clients() { for _, client := range api.app.SubscriptionsBroker().Clients() {
clientModel, _ := client.Get(contextKey).(models.Model) clientModel, _ := client.Get(contextKey).(models.Model)
if clientModel != nil && clientModel.GetId() == newModel.GetId() { if clientModel != nil &&
clientModel.TableName() == newModel.TableName() &&
clientModel.GetId() == newModel.GetId() {
client.Set(contextKey, newModel) client.Set(contextKey, newModel)
} }
} }
@ -227,7 +230,9 @@ func (api *realtimeApi) updateClientsAuthModel(contextKey string, newModel model
func (api *realtimeApi) unregisterClientsByAuthModel(contextKey string, model models.Model) error { func (api *realtimeApi) unregisterClientsByAuthModel(contextKey string, model models.Model) error {
for _, client := range api.app.SubscriptionsBroker().Clients() { for _, client := range api.app.SubscriptionsBroker().Clients() {
clientModel, _ := client.Get(contextKey).(models.Model) clientModel, _ := client.Get(contextKey).(models.Model)
if clientModel != nil && clientModel.GetId() == model.GetId() { if clientModel != nil &&
clientModel.TableName() == model.TableName() &&
clientModel.GetId() == model.GetId() {
api.app.SubscriptionsBroker().Unregister(client.Id()) api.app.SubscriptionsBroker().Unregister(client.Id())
} }
} }
@ -238,7 +243,7 @@ func (api *realtimeApi) unregisterClientsByAuthModel(contextKey string, model mo
func (api *realtimeApi) bindEvents() { func (api *realtimeApi) bindEvents() {
// update the clients that has admin or auth record association // update the clients that has admin or auth record association
api.app.OnModelAfterUpdate().PreAdd(func(e *core.ModelEvent) error { api.app.OnModelAfterUpdate().PreAdd(func(e *core.ModelEvent) error {
if record, ok := e.Model.(*models.Record); ok && record != nil && record.Collection().IsAuth() { if record := api.resolveRecord(e.Model); record != nil && record.Collection().IsAuth() {
return api.updateClientsAuthModel(ContextAuthRecordKey, record) return api.updateClientsAuthModel(ContextAuthRecordKey, record)
} }
@ -251,8 +256,8 @@ func (api *realtimeApi) bindEvents() {
// remove the client(s) associated to the deleted admin or auth record // remove the client(s) associated to the deleted admin or auth record
api.app.OnModelAfterDelete().PreAdd(func(e *core.ModelEvent) error { api.app.OnModelAfterDelete().PreAdd(func(e *core.ModelEvent) error {
if record, ok := e.Model.(*models.Record); ok && record != nil && record.Collection().IsAuth() { if collection := api.resolveRecordCollection(e.Model); collection != nil && collection.IsAuth() {
return api.unregisterClientsByAuthModel(ContextAuthRecordKey, record) return api.unregisterClientsByAuthModel(ContextAuthRecordKey, e.Model)
} }
if admin, ok := e.Model.(*models.Admin); ok && admin != nil { if admin, ok := e.Model.(*models.Admin); ok && admin != nil {
@ -263,7 +268,7 @@ func (api *realtimeApi) bindEvents() {
}) })
api.app.OnModelAfterCreate().PreAdd(func(e *core.ModelEvent) error { api.app.OnModelAfterCreate().PreAdd(func(e *core.ModelEvent) error {
if record, ok := e.Model.(*models.Record); ok { if record := api.resolveRecord(e.Model); record != nil {
if err := api.broadcastRecord("create", record); err != nil && api.app.IsDebug() { if err := api.broadcastRecord("create", record); err != nil && api.app.IsDebug() {
log.Println(err) log.Println(err)
} }
@ -272,7 +277,7 @@ func (api *realtimeApi) bindEvents() {
}) })
api.app.OnModelAfterUpdate().PreAdd(func(e *core.ModelEvent) error { api.app.OnModelAfterUpdate().PreAdd(func(e *core.ModelEvent) error {
if record, ok := e.Model.(*models.Record); ok { if record := api.resolveRecord(e.Model); record != nil {
if err := api.broadcastRecord("update", record); err != nil && api.app.IsDebug() { if err := api.broadcastRecord("update", record); err != nil && api.app.IsDebug() {
log.Println(err) log.Println(err)
} }
@ -281,7 +286,7 @@ func (api *realtimeApi) bindEvents() {
}) })
api.app.OnModelBeforeDelete().Add(func(e *core.ModelEvent) error { api.app.OnModelBeforeDelete().Add(func(e *core.ModelEvent) error {
if record, ok := e.Model.(*models.Record); ok { if record := api.resolveRecord(e.Model); record != nil {
if err := api.broadcastRecord("delete", record); err != nil && api.app.IsDebug() { if err := api.broadcastRecord("delete", record); err != nil && api.app.IsDebug() {
log.Println(err) log.Println(err)
} }
@ -290,6 +295,33 @@ func (api *realtimeApi) bindEvents() {
}) })
} }
// resolveRecord converts *if possible* the provided model interface to a Record.
// This is usually helpful if the provided model is a custom Record model struct.
func (api *realtimeApi) resolveRecord(model models.Model) (record *models.Record) {
record, _ = model.(*models.Record)
// check if it is custom Record model struct (ignore "private" tables)
if record == nil && !strings.HasPrefix(model.TableName(), "_") {
record, _ = api.app.Dao().FindRecordById(model.TableName(), model.GetId())
}
return record
}
// resolveRecordCollection extracts *if possible* the Collection model from the provided model interface.
// This is usually helpful if the provided model is a custom Record model struct.
func (api *realtimeApi) resolveRecordCollection(model models.Model) (collection *models.Collection) {
if record, ok := model.(*models.Record); ok {
collection = record.Collection()
} else if !strings.HasPrefix(model.TableName(), "_") {
// check if it is custom Record model struct (ignore "private" tables)
collection, _ = api.app.Dao().FindCollectionByNameOrId(model.TableName())
}
return collection
}
// canAccessRecord checks if the subscription client has access to the specified record model.
func (api *realtimeApi) canAccessRecord(client subscriptions.Client, record *models.Record, accessRule *string) bool { func (api *realtimeApi) canAccessRecord(client subscriptions.Client, record *models.Record, accessRule *string) bool {
admin, _ := client.Get(ContextAdminKey).(*models.Admin) admin, _ := client.Get(ContextAdminKey).(*models.Admin)
if admin != nil { if admin != nil {

View File

@ -7,8 +7,10 @@ import (
"testing" "testing"
"github.com/labstack/echo/v5" "github.com/labstack/echo/v5"
"github.com/pocketbase/dbx"
"github.com/pocketbase/pocketbase/apis" "github.com/pocketbase/pocketbase/apis"
"github.com/pocketbase/pocketbase/core" "github.com/pocketbase/pocketbase/core"
"github.com/pocketbase/pocketbase/daos"
"github.com/pocketbase/pocketbase/models" "github.com/pocketbase/pocketbase/models"
"github.com/pocketbase/pocketbase/tests" "github.com/pocketbase/pocketbase/tests"
"github.com/pocketbase/pocketbase/tools/hook" "github.com/pocketbase/pocketbase/tools/hook"
@ -353,3 +355,96 @@ func TestRealtimeAdminUpdateEvent(t *testing.T) {
t.Fatalf("Expected authRecord with email %q, got %q", admin2.Email, clientAdmin.Email) t.Fatalf("Expected authRecord with email %q, got %q", admin2.Email, clientAdmin.Email)
} }
} }
// Custom auth record model struct
// -------------------------------------------------------------------
var _ models.Model = (*CustomUser)(nil)
type CustomUser struct {
models.BaseModel
Email string `db:"email" json:"email"`
}
func (m *CustomUser) TableName() string {
return "users" // the name of your collection
}
func findCustomUserByEmail(dao *daos.Dao, email string) (*CustomUser, error) {
model := &CustomUser{}
err := dao.ModelQuery(model).
AndWhere(dbx.HashExp{"email": email}).
Limit(1).
One(model)
if err != nil {
return nil, err
}
return model, nil
}
func TestRealtimeCustomAuthModelDeleteEvent(t *testing.T) {
testApp, _ := tests.NewTestApp()
defer testApp.Cleanup()
apis.InitApi(testApp)
authRecord, err := testApp.Dao().FindFirstRecordByData("users", "email", "test@example.com")
if err != nil {
t.Fatal(err)
}
client := subscriptions.NewDefaultClient()
client.Set(apis.ContextAuthRecordKey, authRecord)
testApp.SubscriptionsBroker().Register(client)
// refetch the authRecord as CustomUser
customUser, err := findCustomUserByEmail(testApp.Dao(), "test@example.com")
if err != nil {
t.Fatal(err)
}
// delete the custom user (should unset the client auth record)
if err := testApp.Dao().Delete(customUser); err != nil {
t.Fatal(err)
}
if len(testApp.SubscriptionsBroker().Clients()) != 0 {
t.Fatalf("Expected no subscription clients, found %d", len(testApp.SubscriptionsBroker().Clients()))
}
}
func TestRealtimeCustomAuthModelUpdateEvent(t *testing.T) {
testApp, _ := tests.NewTestApp()
defer testApp.Cleanup()
apis.InitApi(testApp)
authRecord, err := testApp.Dao().FindFirstRecordByData("users", "email", "test@example.com")
if err != nil {
t.Fatal(err)
}
client := subscriptions.NewDefaultClient()
client.Set(apis.ContextAuthRecordKey, authRecord)
testApp.SubscriptionsBroker().Register(client)
// refetch the authRecord as CustomUser
customUser, err := findCustomUserByEmail(testApp.Dao(), "test@example.com")
if err != nil {
t.Fatal(err)
}
// change its email
customUser.Email = "new@example.com"
if err := testApp.Dao().Save(customUser); err != nil {
t.Fatal(err)
}
clientAuthRecord, _ := client.Get(apis.ContextAuthRecordKey).(*models.Record)
if clientAuthRecord.Email() != customUser.Email {
t.Fatalf("Expected authRecord with email %q, got %q", customUser.Email, clientAuthRecord.Email())
}
}