1
0
mirror of https://github.com/pocketbase/pocketbase.git synced 2025-11-24 23:24:00 +02:00

=added experimental expand, filter, fields, custom query and headers parameters support for the realtime subscriptions

This commit is contained in:
Gani Georgiev
2023-10-23 22:46:47 +03:00
parent e6f1b3dfe4
commit 79617e6d99
41 changed files with 553 additions and 257 deletions

View File

@@ -15,9 +15,11 @@ import (
"github.com/pocketbase/pocketbase/forms"
"github.com/pocketbase/pocketbase/models"
"github.com/pocketbase/pocketbase/resolvers"
"github.com/pocketbase/pocketbase/tools/rest"
"github.com/pocketbase/pocketbase/tools/routine"
"github.com/pocketbase/pocketbase/tools/search"
"github.com/pocketbase/pocketbase/tools/subscriptions"
"github.com/spf13/cast"
)
// bindRealtimeApi registers the realtime api endpoints.
@@ -326,58 +328,16 @@ func (api *realtimeApi) resolveRecordCollection(model models.Model) (collection
return collection
}
// canAccessRecord checks if the subscription client has access to the specified record model.
func (api *realtimeApi) canAccessRecord(client subscriptions.Client, record *models.Record, accessRule *string) bool {
admin, _ := client.Get(ContextAdminKey).(*models.Admin)
if admin != nil {
// admins can access everything
return true
}
if accessRule == nil {
// only admins can access this record
return false
}
ruleFunc := func(q *dbx.SelectQuery) error {
if *accessRule == "" {
return nil // empty public rule
}
// mock request data
requestInfo := &models.RequestInfo{
Method: "GET",
}
requestInfo.AuthRecord, _ = client.Get(ContextAuthRecordKey).(*models.Record)
resolver := resolvers.NewRecordFieldResolver(api.app.Dao(), record.Collection(), requestInfo, true)
expr, err := search.FilterData(*accessRule).BuildExpr(resolver)
if err != nil {
return err
}
resolver.UpdateQuery(q)
q.AndWhere(expr)
return nil
}
foundRecord, err := api.app.Dao().FindRecordById(record.Collection().Id, record.Id, ruleFunc)
if err == nil && foundRecord != nil {
return true
}
return false
}
// recordData represents the broadcasted record subscrition message data.
type recordData struct {
Record *models.Record `json:"record"`
Action string `json:"action"`
Record any `json:"record"` /* map or models.Record */
Action string `json:"action"`
}
func (api *realtimeApi) broadcastRecord(action string, record *models.Record, dryCache bool) error {
collection := record.Collection()
if collection == nil {
return errors.New("Record collection not set.")
return errors.New("[broadcastRecord] Record collection not set.")
}
clients := api.app.SubscriptionsBroker().Clients()
@@ -385,67 +345,106 @@ func (api *realtimeApi) broadcastRecord(action string, record *models.Record, dr
return nil // no subscribers
}
// create a clean record copy without expand and unknown fields
// because we don't know if the clients have permissions to view them
cleanRecord := record.CleanCopy()
subscriptionRuleMap := map[string]*string{
(collection.Name + "/" + cleanRecord.Id): collection.ViewRule,
(collection.Id + "/" + cleanRecord.Id): collection.ViewRule,
(collection.Name + "/*"): collection.ListRule,
(collection.Id + "/*"): collection.ListRule,
// @deprecated: the same as the wildcard topic but kept for backward compatibility
collection.Name: collection.ListRule,
collection.Id: collection.ListRule,
(collection.Name + "/" + record.Id + "?"): collection.ViewRule,
(collection.Id + "/" + record.Id + "?"): collection.ViewRule,
(collection.Name + "/*?"): collection.ListRule,
(collection.Id + "/*?"): collection.ListRule,
// @deprecated: the same as the wildcard topic but kept for backward compatibility
(collection.Name + "?"): collection.ListRule,
(collection.Id + "?"): collection.ListRule,
}
data := &recordData{
Action: action,
Record: cleanRecord,
}
dataBytes, err := json.Marshal(data)
if err != nil {
return err
}
dryCacheKey := action + "/" + record.Id
for _, client := range clients {
client := client
for subscription, rule := range subscriptionRuleMap {
if !client.HasSubscription(subscription) {
// note: not executed concurrently to avoid races and to ensure
// that the access checks are applied for the current record db state
for prefix, rule := range subscriptionRuleMap {
subs := client.Subscriptions(prefix)
if len(subs) == 0 {
continue
}
if !api.canAccessRecord(client, data.Record, rule) {
continue
}
for sub, options := range subs {
// create a clean record copy without expand and unknown fields
// because we don't know yet which exact fields the client subscription has permissions to access
cleanRecord := record.CleanCopy()
msg := subscriptions.Message{
Name: subscription,
Data: dataBytes,
}
// ignore the auth record email visibility checks for
// auth owner, admin or manager
if collection.IsAuth() {
authId := extractAuthIdFromGetter(client)
if authId == data.Record.Id ||
api.canAccessRecord(client, data.Record, collection.AuthOptions().ManageRule) {
data.Record.IgnoreEmailVisibility(true) // ignore
if newData, err := json.Marshal(data); err == nil {
msg.Data = newData
}
data.Record.IgnoreEmailVisibility(false) // restore
// mock request data
requestInfo := &models.RequestInfo{
Method: "GET",
Query: options.Query,
Headers: options.Headers,
}
}
requestInfo.Admin, _ = client.Get(ContextAdminKey).(*models.Admin)
requestInfo.AuthRecord, _ = client.Get(ContextAuthRecordKey).(*models.Record)
if dryCache {
client.Set(action+"/"+data.Record.Id, msg)
} else {
routine.FireAndForget(func() {
client.Send(msg)
})
if !api.canAccessRecord(cleanRecord, requestInfo, rule) {
continue
}
rawExpand := cast.ToString(options.Query[expandQueryParam])
if rawExpand != "" {
expandErrs := api.app.Dao().ExpandRecord(cleanRecord, strings.Split(rawExpand, ","), expandFetch(api.app.Dao(), requestInfo))
if api.app.IsDebug() && len(expandErrs) > 0 {
log.Println("[broadcastRecord] expand errors", expandErrs)
}
}
// ignore the auth record email visibility checks
// for auth owner, admin or manager
if collection.IsAuth() {
authId := extractAuthIdFromGetter(client)
if authId == cleanRecord.Id {
if api.canAccessRecord(cleanRecord, requestInfo, collection.AuthOptions().ManageRule) {
cleanRecord.IgnoreEmailVisibility(true)
}
}
}
data := &recordData{
Action: action,
Record: cleanRecord,
}
// check fields
rawFields := cast.ToString(options.Query[fieldsQueryParam])
if rawFields != "" {
decoded, err := rest.PickFields(cleanRecord, rawFields)
if err == nil {
data.Record = decoded
} else if api.app.IsDebug() {
log.Println(err)
}
}
dataBytes, err := json.Marshal(data)
if err != nil && api.app.IsDebug() {
log.Println("[broadcastRecord] data marshal error", err)
continue
}
msg := subscriptions.Message{
Name: sub,
Data: dataBytes,
}
if dryCache {
messages, ok := client.Get(dryCacheKey).([]subscriptions.Message)
if !ok {
messages = []subscriptions.Message{msg}
} else {
messages = append(messages, msg)
}
client.Set(dryCacheKey, messages)
} else {
routine.FireAndForget(func() {
client.Send(msg)
})
}
}
}
}
@@ -453,14 +452,14 @@ func (api *realtimeApi) broadcastRecord(action string, record *models.Record, dr
return nil
}
// broadcastDryCachedRecord broadcasts record if it is cached in the client context.
// broadcastDryCachedRecord broadcasts all cached record related messages.
func (api *realtimeApi) broadcastDryCachedRecord(action string, record *models.Record) error {
key := action + "/" + record.Id
clients := api.app.SubscriptionsBroker().Clients()
for _, client := range clients {
key := action + "/" + record.Id
msg, ok := client.Get(key).(subscriptions.Message)
messages, ok := client.Get(key).([]subscriptions.Message)
if !ok {
continue
}
@@ -470,9 +469,12 @@ func (api *realtimeApi) broadcastDryCachedRecord(action string, record *models.R
client := client
routine.FireAndForget(func() {
client.Send(msg)
for _, msg := range messages {
client.Send(msg)
}
})
}
return nil
}
@@ -493,3 +495,41 @@ func extractAuthIdFromGetter(val getter) string {
return ""
}
// canAccessRecord checks if the subscription client has access to the specified record model.
func (api *realtimeApi) canAccessRecord(
record *models.Record,
requestInfo *models.RequestInfo,
accessRule *string,
) bool {
// check the access rule
// ---
if ok, _ := api.app.Dao().CanAccessRecord(record, requestInfo, accessRule); !ok {
return false
}
// check the subscription client-side filter (if any)
// ---
filter := cast.ToString(requestInfo.Query[search.FilterQueryParam])
if filter == "" {
return true // no further checks needed
}
ruleFunc := func(q *dbx.SelectQuery) error {
resolver := resolvers.NewRecordFieldResolver(api.app.Dao(), record.Collection(), requestInfo, false)
expr, err := search.FilterData(filter).BuildExpr(resolver)
if err != nil {
return err
}
q.AndWhere(expr)
resolver.UpdateQuery(q)
return nil
}
_, err := api.app.Dao().FindRecordById(record.Collection().Id, record.Id, ruleFunc)
return err == nil
}