mirror of
https://github.com/pocketbase/pocketbase.git
synced 2025-01-07 17:06:20 +02:00
[#2325] trigger the related record realtime events on custom record model change
This commit is contained in:
parent
fdf4f6d3bd
commit
818857dea2
@ -7,6 +7,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/labstack/echo/v5"
|
"github.com/labstack/echo/v5"
|
||||||
@ -215,7 +216,9 @@ func (api *realtimeApi) setSubscriptions(c echo.Context) error {
|
|||||||
func (api *realtimeApi) updateClientsAuthModel(contextKey string, newModel models.Model) error {
|
func (api *realtimeApi) updateClientsAuthModel(contextKey string, newModel models.Model) error {
|
||||||
for _, client := range api.app.SubscriptionsBroker().Clients() {
|
for _, client := range api.app.SubscriptionsBroker().Clients() {
|
||||||
clientModel, _ := client.Get(contextKey).(models.Model)
|
clientModel, _ := client.Get(contextKey).(models.Model)
|
||||||
if clientModel != nil && clientModel.GetId() == newModel.GetId() {
|
if clientModel != nil &&
|
||||||
|
clientModel.TableName() == newModel.TableName() &&
|
||||||
|
clientModel.GetId() == newModel.GetId() {
|
||||||
client.Set(contextKey, newModel)
|
client.Set(contextKey, newModel)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -227,7 +230,9 @@ func (api *realtimeApi) updateClientsAuthModel(contextKey string, newModel model
|
|||||||
func (api *realtimeApi) unregisterClientsByAuthModel(contextKey string, model models.Model) error {
|
func (api *realtimeApi) unregisterClientsByAuthModel(contextKey string, model models.Model) error {
|
||||||
for _, client := range api.app.SubscriptionsBroker().Clients() {
|
for _, client := range api.app.SubscriptionsBroker().Clients() {
|
||||||
clientModel, _ := client.Get(contextKey).(models.Model)
|
clientModel, _ := client.Get(contextKey).(models.Model)
|
||||||
if clientModel != nil && clientModel.GetId() == model.GetId() {
|
if clientModel != nil &&
|
||||||
|
clientModel.TableName() == model.TableName() &&
|
||||||
|
clientModel.GetId() == model.GetId() {
|
||||||
api.app.SubscriptionsBroker().Unregister(client.Id())
|
api.app.SubscriptionsBroker().Unregister(client.Id())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -238,7 +243,7 @@ func (api *realtimeApi) unregisterClientsByAuthModel(contextKey string, model mo
|
|||||||
func (api *realtimeApi) bindEvents() {
|
func (api *realtimeApi) bindEvents() {
|
||||||
// update the clients that has admin or auth record association
|
// update the clients that has admin or auth record association
|
||||||
api.app.OnModelAfterUpdate().PreAdd(func(e *core.ModelEvent) error {
|
api.app.OnModelAfterUpdate().PreAdd(func(e *core.ModelEvent) error {
|
||||||
if record, ok := e.Model.(*models.Record); ok && record != nil && record.Collection().IsAuth() {
|
if record := api.resolveRecord(e.Model); record != nil && record.Collection().IsAuth() {
|
||||||
return api.updateClientsAuthModel(ContextAuthRecordKey, record)
|
return api.updateClientsAuthModel(ContextAuthRecordKey, record)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -251,8 +256,8 @@ func (api *realtimeApi) bindEvents() {
|
|||||||
|
|
||||||
// remove the client(s) associated to the deleted admin or auth record
|
// remove the client(s) associated to the deleted admin or auth record
|
||||||
api.app.OnModelAfterDelete().PreAdd(func(e *core.ModelEvent) error {
|
api.app.OnModelAfterDelete().PreAdd(func(e *core.ModelEvent) error {
|
||||||
if record, ok := e.Model.(*models.Record); ok && record != nil && record.Collection().IsAuth() {
|
if collection := api.resolveRecordCollection(e.Model); collection != nil && collection.IsAuth() {
|
||||||
return api.unregisterClientsByAuthModel(ContextAuthRecordKey, record)
|
return api.unregisterClientsByAuthModel(ContextAuthRecordKey, e.Model)
|
||||||
}
|
}
|
||||||
|
|
||||||
if admin, ok := e.Model.(*models.Admin); ok && admin != nil {
|
if admin, ok := e.Model.(*models.Admin); ok && admin != nil {
|
||||||
@ -263,7 +268,7 @@ func (api *realtimeApi) bindEvents() {
|
|||||||
})
|
})
|
||||||
|
|
||||||
api.app.OnModelAfterCreate().PreAdd(func(e *core.ModelEvent) error {
|
api.app.OnModelAfterCreate().PreAdd(func(e *core.ModelEvent) error {
|
||||||
if record, ok := e.Model.(*models.Record); ok {
|
if record := api.resolveRecord(e.Model); record != nil {
|
||||||
if err := api.broadcastRecord("create", record); err != nil && api.app.IsDebug() {
|
if err := api.broadcastRecord("create", record); err != nil && api.app.IsDebug() {
|
||||||
log.Println(err)
|
log.Println(err)
|
||||||
}
|
}
|
||||||
@ -272,7 +277,7 @@ func (api *realtimeApi) bindEvents() {
|
|||||||
})
|
})
|
||||||
|
|
||||||
api.app.OnModelAfterUpdate().PreAdd(func(e *core.ModelEvent) error {
|
api.app.OnModelAfterUpdate().PreAdd(func(e *core.ModelEvent) error {
|
||||||
if record, ok := e.Model.(*models.Record); ok {
|
if record := api.resolveRecord(e.Model); record != nil {
|
||||||
if err := api.broadcastRecord("update", record); err != nil && api.app.IsDebug() {
|
if err := api.broadcastRecord("update", record); err != nil && api.app.IsDebug() {
|
||||||
log.Println(err)
|
log.Println(err)
|
||||||
}
|
}
|
||||||
@ -281,7 +286,7 @@ func (api *realtimeApi) bindEvents() {
|
|||||||
})
|
})
|
||||||
|
|
||||||
api.app.OnModelBeforeDelete().Add(func(e *core.ModelEvent) error {
|
api.app.OnModelBeforeDelete().Add(func(e *core.ModelEvent) error {
|
||||||
if record, ok := e.Model.(*models.Record); ok {
|
if record := api.resolveRecord(e.Model); record != nil {
|
||||||
if err := api.broadcastRecord("delete", record); err != nil && api.app.IsDebug() {
|
if err := api.broadcastRecord("delete", record); err != nil && api.app.IsDebug() {
|
||||||
log.Println(err)
|
log.Println(err)
|
||||||
}
|
}
|
||||||
@ -290,6 +295,33 @@ func (api *realtimeApi) bindEvents() {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// resolveRecord converts *if possible* the provided model interface to a Record.
|
||||||
|
// This is usually helpful if the provided model is a custom Record model struct.
|
||||||
|
func (api *realtimeApi) resolveRecord(model models.Model) (record *models.Record) {
|
||||||
|
record, _ = model.(*models.Record)
|
||||||
|
|
||||||
|
// check if it is custom Record model struct (ignore "private" tables)
|
||||||
|
if record == nil && !strings.HasPrefix(model.TableName(), "_") {
|
||||||
|
record, _ = api.app.Dao().FindRecordById(model.TableName(), model.GetId())
|
||||||
|
}
|
||||||
|
|
||||||
|
return record
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveRecordCollection extracts *if possible* the Collection model from the provided model interface.
|
||||||
|
// This is usually helpful if the provided model is a custom Record model struct.
|
||||||
|
func (api *realtimeApi) resolveRecordCollection(model models.Model) (collection *models.Collection) {
|
||||||
|
if record, ok := model.(*models.Record); ok {
|
||||||
|
collection = record.Collection()
|
||||||
|
} else if !strings.HasPrefix(model.TableName(), "_") {
|
||||||
|
// check if it is custom Record model struct (ignore "private" tables)
|
||||||
|
collection, _ = api.app.Dao().FindCollectionByNameOrId(model.TableName())
|
||||||
|
}
|
||||||
|
|
||||||
|
return collection
|
||||||
|
}
|
||||||
|
|
||||||
|
// canAccessRecord checks if the subscription client has access to the specified record model.
|
||||||
func (api *realtimeApi) canAccessRecord(client subscriptions.Client, record *models.Record, accessRule *string) bool {
|
func (api *realtimeApi) canAccessRecord(client subscriptions.Client, record *models.Record, accessRule *string) bool {
|
||||||
admin, _ := client.Get(ContextAdminKey).(*models.Admin)
|
admin, _ := client.Get(ContextAdminKey).(*models.Admin)
|
||||||
if admin != nil {
|
if admin != nil {
|
||||||
|
@ -7,8 +7,10 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/labstack/echo/v5"
|
"github.com/labstack/echo/v5"
|
||||||
|
"github.com/pocketbase/dbx"
|
||||||
"github.com/pocketbase/pocketbase/apis"
|
"github.com/pocketbase/pocketbase/apis"
|
||||||
"github.com/pocketbase/pocketbase/core"
|
"github.com/pocketbase/pocketbase/core"
|
||||||
|
"github.com/pocketbase/pocketbase/daos"
|
||||||
"github.com/pocketbase/pocketbase/models"
|
"github.com/pocketbase/pocketbase/models"
|
||||||
"github.com/pocketbase/pocketbase/tests"
|
"github.com/pocketbase/pocketbase/tests"
|
||||||
"github.com/pocketbase/pocketbase/tools/hook"
|
"github.com/pocketbase/pocketbase/tools/hook"
|
||||||
@ -353,3 +355,96 @@ func TestRealtimeAdminUpdateEvent(t *testing.T) {
|
|||||||
t.Fatalf("Expected authRecord with email %q, got %q", admin2.Email, clientAdmin.Email)
|
t.Fatalf("Expected authRecord with email %q, got %q", admin2.Email, clientAdmin.Email)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Custom auth record model struct
|
||||||
|
// -------------------------------------------------------------------
|
||||||
|
var _ models.Model = (*CustomUser)(nil)
|
||||||
|
|
||||||
|
type CustomUser struct {
|
||||||
|
models.BaseModel
|
||||||
|
|
||||||
|
Email string `db:"email" json:"email"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *CustomUser) TableName() string {
|
||||||
|
return "users" // the name of your collection
|
||||||
|
}
|
||||||
|
|
||||||
|
func findCustomUserByEmail(dao *daos.Dao, email string) (*CustomUser, error) {
|
||||||
|
model := &CustomUser{}
|
||||||
|
|
||||||
|
err := dao.ModelQuery(model).
|
||||||
|
AndWhere(dbx.HashExp{"email": email}).
|
||||||
|
Limit(1).
|
||||||
|
One(model)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return model, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRealtimeCustomAuthModelDeleteEvent(t *testing.T) {
|
||||||
|
testApp, _ := tests.NewTestApp()
|
||||||
|
defer testApp.Cleanup()
|
||||||
|
|
||||||
|
apis.InitApi(testApp)
|
||||||
|
|
||||||
|
authRecord, err := testApp.Dao().FindFirstRecordByData("users", "email", "test@example.com")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
client := subscriptions.NewDefaultClient()
|
||||||
|
client.Set(apis.ContextAuthRecordKey, authRecord)
|
||||||
|
testApp.SubscriptionsBroker().Register(client)
|
||||||
|
|
||||||
|
// refetch the authRecord as CustomUser
|
||||||
|
customUser, err := findCustomUserByEmail(testApp.Dao(), "test@example.com")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// delete the custom user (should unset the client auth record)
|
||||||
|
if err := testApp.Dao().Delete(customUser); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(testApp.SubscriptionsBroker().Clients()) != 0 {
|
||||||
|
t.Fatalf("Expected no subscription clients, found %d", len(testApp.SubscriptionsBroker().Clients()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRealtimeCustomAuthModelUpdateEvent(t *testing.T) {
|
||||||
|
testApp, _ := tests.NewTestApp()
|
||||||
|
defer testApp.Cleanup()
|
||||||
|
|
||||||
|
apis.InitApi(testApp)
|
||||||
|
|
||||||
|
authRecord, err := testApp.Dao().FindFirstRecordByData("users", "email", "test@example.com")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
client := subscriptions.NewDefaultClient()
|
||||||
|
client.Set(apis.ContextAuthRecordKey, authRecord)
|
||||||
|
testApp.SubscriptionsBroker().Register(client)
|
||||||
|
|
||||||
|
// refetch the authRecord as CustomUser
|
||||||
|
customUser, err := findCustomUserByEmail(testApp.Dao(), "test@example.com")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// change its email
|
||||||
|
customUser.Email = "new@example.com"
|
||||||
|
if err := testApp.Dao().Save(customUser); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
clientAuthRecord, _ := client.Get(apis.ContextAuthRecordKey).(*models.Record)
|
||||||
|
if clientAuthRecord.Email() != customUser.Email {
|
||||||
|
t.Fatalf("Expected authRecord with email %q, got %q", customUser.Email, clientAuthRecord.Email())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user