You've already forked pocketbase
mirror of
https://github.com/pocketbase/pocketbase.git
synced 2025-11-27 08:27:06 +02:00
initial public commit
This commit is contained in:
345
apis/realtime.go
Normal file
345
apis/realtime.go
Normal file
@@ -0,0 +1,345 @@
|
||||
package apis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/pocketbase/dbx"
|
||||
"github.com/labstack/echo/v5"
|
||||
"github.com/pocketbase/pocketbase/core"
|
||||
"github.com/pocketbase/pocketbase/forms"
|
||||
"github.com/pocketbase/pocketbase/models"
|
||||
"github.com/pocketbase/pocketbase/resolvers"
|
||||
"github.com/pocketbase/pocketbase/tools/rest"
|
||||
"github.com/pocketbase/pocketbase/tools/search"
|
||||
"github.com/pocketbase/pocketbase/tools/subscriptions"
|
||||
)
|
||||
|
||||
// BindRealtimeApi registers the realtime api endpoints.
|
||||
func BindRealtimeApi(app core.App, rg *echo.Group) {
|
||||
api := realtimeApi{app: app}
|
||||
|
||||
subGroup := rg.Group("/realtime", ActivityLogger(app))
|
||||
subGroup.GET("", api.connect)
|
||||
subGroup.POST("", api.setSubscriptions)
|
||||
|
||||
api.bindEvents()
|
||||
}
|
||||
|
||||
type realtimeApi struct {
|
||||
app core.App
|
||||
}
|
||||
|
||||
func (api *realtimeApi) connect(c echo.Context) error {
|
||||
cancelCtx, cancelRequest := context.WithCancel(c.Request().Context())
|
||||
defer cancelRequest()
|
||||
c.SetRequest(c.Request().Clone(cancelCtx))
|
||||
|
||||
// register new subscription client
|
||||
client := subscriptions.NewDefaultClient()
|
||||
api.app.SubscriptionsBroker().Register(client)
|
||||
defer api.app.SubscriptionsBroker().Unregister(client.Id())
|
||||
|
||||
c.Response().Header().Set("Content-Type", "text/event-stream; charset=UTF-8")
|
||||
c.Response().Header().Set("Cache-Control", "no-store")
|
||||
c.Response().Header().Set("Connection", "keep-alive")
|
||||
|
||||
event := &core.RealtimeConnectEvent{
|
||||
HttpContext: c,
|
||||
Client: client,
|
||||
}
|
||||
|
||||
if err := api.app.OnRealtimeConnectRequest().Trigger(event); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// signalize established connection (aka. fire "connect" message)
|
||||
fmt.Fprint(c.Response(), "id:"+client.Id()+"\n")
|
||||
fmt.Fprint(c.Response(), "event:PB_CONNECT\n")
|
||||
fmt.Fprint(c.Response(), "data:{\"clientId\":\""+client.Id()+"\"}\n\n")
|
||||
c.Response().Flush()
|
||||
|
||||
// start an idle timer to keep track of inactive/forgotten connections
|
||||
idleDuration := 5 * time.Minute
|
||||
idleTimer := time.NewTimer(idleDuration)
|
||||
defer idleTimer.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-idleTimer.C:
|
||||
cancelRequest()
|
||||
case msg, ok := <-client.Channel():
|
||||
if !ok {
|
||||
// channel is closed
|
||||
if api.app.IsDebug() {
|
||||
log.Println("Realtime connection closed (closed channel):", client.Id())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
w := c.Response()
|
||||
fmt.Fprint(w, "id:"+client.Id()+"\n")
|
||||
fmt.Fprint(w, "event:"+msg.Name+"\n")
|
||||
fmt.Fprint(w, "data:"+msg.Data+"\n\n")
|
||||
w.Flush()
|
||||
|
||||
idleTimer.Stop()
|
||||
idleTimer.Reset(idleDuration)
|
||||
case <-c.Request().Context().Done():
|
||||
// connection is closed
|
||||
if api.app.IsDebug() {
|
||||
log.Println("Realtime connection closed (cancelled request):", client.Id())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// note: in case of reconnect, clients will have to resubmit all subscriptions again
|
||||
func (api *realtimeApi) setSubscriptions(c echo.Context) error {
|
||||
form := forms.NewRealtimeSubscribe()
|
||||
|
||||
// read request data
|
||||
if err := c.Bind(form); err != nil {
|
||||
return rest.NewBadRequestError("", err)
|
||||
}
|
||||
|
||||
// validate request data
|
||||
if err := form.Validate(); err != nil {
|
||||
return rest.NewBadRequestError("", err)
|
||||
}
|
||||
|
||||
// find subscription client
|
||||
client, err := api.app.SubscriptionsBroker().ClientById(form.ClientId)
|
||||
if err != nil {
|
||||
return rest.NewNotFoundError("Missing or invalid client id.", err)
|
||||
}
|
||||
|
||||
// check if the previous request was authorized
|
||||
oldAuthId := extractAuthIdFromGetter(client)
|
||||
newAuthId := extractAuthIdFromGetter(c)
|
||||
if oldAuthId != "" && oldAuthId != newAuthId {
|
||||
return rest.NewForbiddenError("The current and the previous request authorization don't match.", nil)
|
||||
}
|
||||
|
||||
event := &core.RealtimeSubscribeEvent{
|
||||
HttpContext: c,
|
||||
Client: client,
|
||||
Subscriptions: form.Subscriptions,
|
||||
}
|
||||
|
||||
handlerErr := api.app.OnRealtimeBeforeSubscribeRequest().Trigger(event, func(e *core.RealtimeSubscribeEvent) error {
|
||||
// update auth state
|
||||
e.Client.Set(ContextAdminKey, e.HttpContext.Get(ContextAdminKey))
|
||||
e.Client.Set(ContextUserKey, e.HttpContext.Get(ContextUserKey))
|
||||
|
||||
// unsubscribe from any previous existing subscriptions
|
||||
e.Client.Unsubscribe()
|
||||
|
||||
// subscribe to the new subscriptions
|
||||
e.Client.Subscribe(e.Subscriptions...)
|
||||
|
||||
return e.HttpContext.NoContent(http.StatusNoContent)
|
||||
})
|
||||
|
||||
if handlerErr == nil {
|
||||
api.app.OnRealtimeAfterSubscribeRequest().Trigger(event)
|
||||
}
|
||||
|
||||
return handlerErr
|
||||
}
|
||||
|
||||
func (api *realtimeApi) bindEvents() {
|
||||
userTable := (&models.User{}).TableName()
|
||||
adminTable := (&models.Admin{}).TableName()
|
||||
|
||||
// update user/admin auth state
|
||||
api.app.OnModelAfterUpdate().Add(func(data *core.ModelEvent) error {
|
||||
modelTable := data.Model.TableName()
|
||||
|
||||
var contextKey string
|
||||
if modelTable == userTable {
|
||||
contextKey = ContextUserKey
|
||||
} else if modelTable == adminTable {
|
||||
contextKey = ContextAdminKey
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, client := range api.app.SubscriptionsBroker().Clients() {
|
||||
model, _ := client.Get(contextKey).(models.Model)
|
||||
if model != nil && model.GetId() == data.Model.GetId() {
|
||||
client.Set(contextKey, data.Model)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
// remove user/admin client(s)
|
||||
api.app.OnModelAfterDelete().Add(func(data *core.ModelEvent) error {
|
||||
modelTable := data.Model.TableName()
|
||||
|
||||
var contextKey string
|
||||
if modelTable == userTable {
|
||||
contextKey = ContextUserKey
|
||||
} else if modelTable == adminTable {
|
||||
contextKey = ContextAdminKey
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, client := range api.app.SubscriptionsBroker().Clients() {
|
||||
model, _ := client.Get(contextKey).(models.Model)
|
||||
if model != nil && model.GetId() == data.Model.GetId() {
|
||||
api.app.SubscriptionsBroker().Unregister(client.Id())
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
api.app.OnRecordAfterCreateRequest().Add(func(data *core.RecordCreateEvent) error {
|
||||
api.broadcastRecord("create", data.Record)
|
||||
return nil
|
||||
})
|
||||
|
||||
api.app.OnRecordAfterUpdateRequest().Add(func(data *core.RecordUpdateEvent) error {
|
||||
api.broadcastRecord("update", data.Record)
|
||||
return nil
|
||||
})
|
||||
|
||||
api.app.OnRecordAfterDeleteRequest().Add(func(data *core.RecordDeleteEvent) error {
|
||||
api.broadcastRecord("delete", data.Record)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (api *realtimeApi) canAccessRecord(client subscriptions.Client, record *models.Record, accessRule *string) bool {
|
||||
admin, _ := client.Get(ContextAdminKey).(*models.Admin)
|
||||
if admin != nil {
|
||||
// admins can access everything
|
||||
return true
|
||||
}
|
||||
|
||||
if accessRule == nil {
|
||||
// only admins can access this record
|
||||
return false
|
||||
}
|
||||
|
||||
ruleFunc := func(q *dbx.SelectQuery) error {
|
||||
if *accessRule == "" {
|
||||
return nil // empty public rule
|
||||
}
|
||||
|
||||
// emulate request data
|
||||
requestData := map[string]any{
|
||||
"method": "get",
|
||||
"query": map[string]any{},
|
||||
"data": map[string]any{},
|
||||
"user": nil,
|
||||
}
|
||||
user, _ := client.Get(ContextUserKey).(*models.User)
|
||||
if user != nil {
|
||||
requestData["user"], _ = user.AsMap()
|
||||
}
|
||||
|
||||
resolver := resolvers.NewRecordFieldResolver(api.app.Dao(), record.Collection(), requestData)
|
||||
expr, err := search.FilterData(*accessRule).BuildExpr(resolver)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resolver.UpdateQuery(q)
|
||||
q.AndWhere(expr)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
foundRecord, err := api.app.Dao().FindRecordById(record.Collection(), record.Id, ruleFunc)
|
||||
if err == nil && foundRecord != nil {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
type recordData struct {
|
||||
Action string `json:"action"`
|
||||
Record *models.Record `json:"record"`
|
||||
}
|
||||
|
||||
func (api *realtimeApi) broadcastRecord(action string, record *models.Record) error {
|
||||
collection := record.Collection()
|
||||
if collection == nil {
|
||||
return errors.New("Record collection not set.")
|
||||
}
|
||||
|
||||
clients := api.app.SubscriptionsBroker().Clients()
|
||||
if len(clients) == 0 {
|
||||
return nil // no subscribers
|
||||
}
|
||||
|
||||
subscriptionRuleMap := map[string]*string{
|
||||
(collection.Name + "/" + record.Id): collection.ViewRule,
|
||||
(collection.Id + "/" + record.Id): collection.ViewRule,
|
||||
collection.Name: collection.ListRule,
|
||||
collection.Id: collection.ListRule,
|
||||
}
|
||||
|
||||
recordData := &recordData{
|
||||
Action: action,
|
||||
Record: record,
|
||||
}
|
||||
|
||||
serializedData, err := json.Marshal(recordData)
|
||||
if err != nil {
|
||||
if api.app.IsDebug() {
|
||||
log.Println(err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
for _, client := range clients {
|
||||
for subscription, rule := range subscriptionRuleMap {
|
||||
if !client.HasSubscription(subscription) {
|
||||
continue
|
||||
}
|
||||
|
||||
if !api.canAccessRecord(client, record, rule) {
|
||||
continue
|
||||
}
|
||||
|
||||
msg := subscriptions.Message{
|
||||
Name: subscription,
|
||||
Data: string(serializedData),
|
||||
}
|
||||
|
||||
client.Channel() <- msg
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type getter interface {
|
||||
Get(string) any
|
||||
}
|
||||
|
||||
func extractAuthIdFromGetter(val getter) string {
|
||||
user, _ := val.Get(ContextUserKey).(*models.User)
|
||||
if user != nil {
|
||||
return user.Id
|
||||
}
|
||||
|
||||
admin, _ := val.Get(ContextAdminKey).(*models.Admin)
|
||||
if admin != nil {
|
||||
return admin.Id
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
Reference in New Issue
Block a user