From ac52befb5ba4fa961ae4ac1ebdff68a0465c949c Mon Sep 17 00:00:00 2001 From: Gani Georgiev Date: Thu, 20 Jul 2023 16:32:21 +0300 Subject: [PATCH] changed subscription.Message.Data to []byte and added client.Send(m) helper --- CHANGELOG.md | 2 ++ apis/realtime.go | 33 +++++++++++------------ apis/record_auth.go | 4 +-- apis/record_auth_test.go | 2 +- tools/subscriptions/client.go | 14 +++++++++- tools/subscriptions/client_test.go | 43 ++++++++++++++++++++++++++++++ 6 files changed, 76 insertions(+), 22 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 12ce2b98..f875e9ce 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -87,6 +87,8 @@ - **!** renamed `models.RequestData` to `models.RequestInfo` and soft-deprecated `apis.RequestData(c)` to `apis.RequestInfo(c)` to avoid the stuttering with the `Data` field. _The old `apis.RequestData()` method still works to minimize the breaking changes but it is recommended to replace it with `apis.RequestInfo(c)`._ +- **!** Changed the type of `subscriptions.Message.Data` from `string` to `[]byte` because `Data` usually is a json bytes slice anyway. + - Added `?download` file query parameter option to instruct the browser to always download a file and not show a preview. diff --git a/apis/realtime.go b/apis/realtime.go index 7d4a38d1..24446e1a 100644 --- a/apis/realtime.go +++ b/apis/realtime.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "errors" - "fmt" "log" "net/http" "strings" @@ -83,14 +82,16 @@ func (api *realtimeApi) connect(c echo.Context) error { Client: client, Message: &subscriptions.Message{ Name: "PB_CONNECT", - Data: `{"clientId":"` + client.Id() + `"}`, + Data: []byte(`{"clientId":"` + client.Id() + `"}`), }, } connectMsgErr := api.app.OnRealtimeBeforeMessageSend().Trigger(connectMsgEvent, func(e *core.RealtimeMessageEvent) error { w := e.HttpContext.Response() - fmt.Fprint(w, "id:"+client.Id()+"\n") - fmt.Fprint(w, "event:"+e.Message.Name+"\n") - fmt.Fprint(w, "data:"+e.Message.Data+"\n\n") + w.Write([]byte("id:" + client.Id() + "\n")) + w.Write([]byte("event:" + e.Message.Name + "\n")) + w.Write([]byte("data:")) + w.Write(e.Message.Data) + w.Write([]byte("\n\n")) w.Flush() return api.app.OnRealtimeAfterMessageSend().Trigger(e) }) @@ -126,9 +127,11 @@ func (api *realtimeApi) connect(c echo.Context) error { } msgErr := api.app.OnRealtimeBeforeMessageSend().Trigger(msgEvent, func(e *core.RealtimeMessageEvent) error { w := e.HttpContext.Response() - fmt.Fprint(w, "id:"+e.Client.Id()+"\n") - fmt.Fprint(w, "event:"+e.Message.Name+"\n") - fmt.Fprint(w, "data:"+e.Message.Data+"\n\n") + w.Write([]byte("id:" + e.Client.Id() + "\n")) + w.Write([]byte("event:" + e.Message.Name + "\n")) + w.Write([]byte("data:")) + w.Write(e.Message.Data) + w.Write([]byte("\n\n")) w.Flush() return api.app.OnRealtimeAfterMessageSend().Trigger(msgEvent) }) @@ -406,8 +409,6 @@ func (api *realtimeApi) broadcastRecord(action string, record *models.Record, dr return err } - encodedData := string(dataBytes) - for _, client := range clients { client := client @@ -422,7 +423,7 @@ func (api *realtimeApi) broadcastRecord(action string, record *models.Record, dr msg := subscriptions.Message{ Name: subscription, - Data: encodedData, + Data: dataBytes, } // ignore the auth record email visibility checks for @@ -433,7 +434,7 @@ func (api *realtimeApi) broadcastRecord(action string, record *models.Record, dr api.canAccessRecord(client, data.Record, collection.AuthOptions().ManageRule) { data.Record.IgnoreEmailVisibility(true) // ignore if newData, err := json.Marshal(data); err == nil { - msg.Data = string(newData) + msg.Data = newData } data.Record.IgnoreEmailVisibility(false) // restore } @@ -443,9 +444,7 @@ func (api *realtimeApi) broadcastRecord(action string, record *models.Record, dr client.Set(action+"/"+data.Record.Id, msg) } else { routine.FireAndForget(func() { - if !client.IsDiscarded() { - client.Channel() <- msg - } + client.Send(msg) }) } } @@ -471,9 +470,7 @@ func (api *realtimeApi) broadcastDryCachedRecord(action string, record *models.R client := client routine.FireAndForget(func() { - if !client.IsDiscarded() { - client.Channel() <- msg - } + client.Send(msg) }) } return nil diff --git a/apis/record_auth.go b/apis/record_auth.go index e2ce0e6e..c2b7b84e 100644 --- a/apis/record_auth.go +++ b/apis/record_auth.go @@ -661,10 +661,10 @@ func (api *recordAuthApi) oauth2SubscriptionRedirect(c echo.Context) error { msg := subscriptions.Message{ Name: oauth2SubscriptionTopic, - Data: string(encodedData), + Data: encodedData, } - client.Channel() <- msg + client.Send(msg) return c.Redirect(http.StatusTemporaryRedirect, "../_/#/auth/oauth2-redirect") } diff --git a/apis/record_auth_test.go b/apis/record_auth_test.go index cfc2484d..9144480f 100644 --- a/apis/record_auth_test.go +++ b/apis/record_auth_test.go @@ -1362,7 +1362,7 @@ func TestRecordAuthOAuth2Redirect(t *testing.T) { expectedParams := []string{`"state"`, `"code"`} for _, p := range expectedParams { - if !strings.Contains(msg.Data, p) { + if !strings.Contains(string(msg.Data), p) { t.Errorf("Couldn't find %s in \n%v", p, msg.Data) } } diff --git a/tools/subscriptions/client.go b/tools/subscriptions/client.go index b50fd28f..85b40827 100644 --- a/tools/subscriptions/client.go +++ b/tools/subscriptions/client.go @@ -9,7 +9,7 @@ import ( // Message defines a client's channel data. type Message struct { Name string - Data string + Data []byte } // Client is an interface for a generic subscription client. @@ -50,6 +50,9 @@ type Client interface { // IsDiscarded indicates whether the client has been "discarded" // and should no longer be used. IsDiscarded() bool + + // Send sends the specified message to the client's channel (if not discarded). + Send(m Message) } // ensures that DefaultClient satisfies the Client interface @@ -183,3 +186,12 @@ func (c *DefaultClient) IsDiscarded() bool { return c.isDiscarded } + +// Send sends the specified message to the client's channel (if not discarded). +func (c *DefaultClient) Send(m Message) { + if c.IsDiscarded() { + return + } + + c.Channel() <- m +} diff --git a/tools/subscriptions/client_test.go b/tools/subscriptions/client_test.go index 00ffe33f..58bbd939 100644 --- a/tools/subscriptions/client_test.go +++ b/tools/subscriptions/client_test.go @@ -2,6 +2,7 @@ package subscriptions_test import ( "testing" + "time" "github.com/pocketbase/pocketbase/tools/subscriptions" ) @@ -143,3 +144,45 @@ func TestDiscard(t *testing.T) { t.Fatal("Expected true, got false") } } + +func TestSend(t *testing.T) { + c := subscriptions.NewDefaultClient() + + received := []string{} + go func() { + for { + select { + case m, ok := <-c.Channel(): + if !ok { + return + } + received = append(received, m.Name) + } + } + }() + + c.Send(subscriptions.Message{Name: "m1"}) + c.Send(subscriptions.Message{Name: "m2"}) + c.Discard() + c.Send(subscriptions.Message{Name: "m3"}) + c.Send(subscriptions.Message{Name: "m4"}) + time.Sleep(5 * time.Millisecond) + + expected := []string{"m1", "m2"} + + if len(received) != len(expected) { + t.Fatalf("Expected %d messages, got %d", len(expected), len(received)) + } + for _, name := range expected { + var exists bool + for _, n := range received { + if n == name { + exists = true + break + } + } + if !exists { + t.Fatalf("Missing expected %q message, got %v", name, received) + } + } +}