package apis_test import ( "errors" "net/http" "strings" "testing" "github.com/labstack/echo/v5" "github.com/pocketbase/pocketbase/apis" "github.com/pocketbase/pocketbase/core" "github.com/pocketbase/pocketbase/models" "github.com/pocketbase/pocketbase/tests" "github.com/pocketbase/pocketbase/tools/hook" "github.com/pocketbase/pocketbase/tools/subscriptions" ) func TestRealtimeConnect(t *testing.T) { scenarios := []tests.ApiScenario{ { Method: http.MethodGet, Url: "/api/realtime", ExpectedStatus: 200, ExpectedContent: []string{ `id:`, `event:PB_CONNECT`, `data:{"clientId":`, }, ExpectedEvents: map[string]int{ "OnRealtimeConnectRequest": 1, "OnRealtimeBeforeMessageSend": 1, "OnRealtimeAfterMessageSend": 1, "OnRealtimeDisconnectRequest": 1, }, AfterTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { if len(app.SubscriptionsBroker().Clients()) != 0 { t.Errorf("Expected the subscribers to be removed after connection close, found %d", len(app.SubscriptionsBroker().Clients())) } }, }, { Name: "PB_CONNECT interrupt", Method: http.MethodGet, Url: "/api/realtime", ExpectedStatus: 200, ExpectedEvents: map[string]int{ "OnRealtimeConnectRequest": 1, "OnRealtimeBeforeMessageSend": 1, "OnRealtimeDisconnectRequest": 1, }, BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { app.OnRealtimeBeforeMessageSend().Add(func(e *core.RealtimeMessageEvent) error { if e.Message.Name == "PB_CONNECT" { return errors.New("PB_CONNECT error") } return nil }) }, AfterTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { if len(app.SubscriptionsBroker().Clients()) != 0 { t.Errorf("Expected the subscribers to be removed after connection close, found %d", len(app.SubscriptionsBroker().Clients())) } }, }, { Name: "Skipping/ignoring messages", Method: http.MethodGet, Url: "/api/realtime", ExpectedStatus: 200, ExpectedEvents: map[string]int{ "OnRealtimeConnectRequest": 1, "OnRealtimeBeforeMessageSend": 1, "OnRealtimeAfterMessageSend": 1, "OnRealtimeDisconnectRequest": 1, }, BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { app.OnRealtimeBeforeMessageSend().Add(func(e *core.RealtimeMessageEvent) error { return hook.StopPropagation }) }, AfterTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { if len(app.SubscriptionsBroker().Clients()) != 0 { t.Errorf("Expected the subscribers to be removed after connection close, found %d", len(app.SubscriptionsBroker().Clients())) } }, }, } for _, scenario := range scenarios { scenario.Test(t) } } func TestRealtimeSubscribe(t *testing.T) { client := subscriptions.NewDefaultClient() resetClient := func() { client.Unsubscribe() client.Set(apis.ContextAdminKey, nil) client.Set(apis.ContextAuthRecordKey, nil) } scenarios := []tests.ApiScenario{ { Name: "missing client", Method: http.MethodPost, Url: "/api/realtime", Body: strings.NewReader(`{"clientId":"missing","subscriptions":["test1", "test2"]}`), ExpectedStatus: 404, ExpectedContent: []string{`"data":{}`}, }, { Name: "existing client - empty subscriptions", Method: http.MethodPost, Url: "/api/realtime", Body: strings.NewReader(`{"clientId":"` + client.Id() + `","subscriptions":[]}`), ExpectedStatus: 204, ExpectedEvents: map[string]int{ "OnRealtimeBeforeSubscribeRequest": 1, "OnRealtimeAfterSubscribeRequest": 1, }, BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { client.Subscribe("test0") app.SubscriptionsBroker().Register(client) }, AfterTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { if len(client.Subscriptions()) != 0 { t.Errorf("Expected no subscriptions, got %v", client.Subscriptions()) } resetClient() }, }, { Name: "existing client - 2 new subscriptions", Method: http.MethodPost, Url: "/api/realtime", Body: strings.NewReader(`{"clientId":"` + client.Id() + `","subscriptions":["test1", "test2"]}`), ExpectedStatus: 204, ExpectedEvents: map[string]int{ "OnRealtimeBeforeSubscribeRequest": 1, "OnRealtimeAfterSubscribeRequest": 1, }, BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { client.Subscribe("test0") app.SubscriptionsBroker().Register(client) }, AfterTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { expectedSubs := []string{"test1", "test2"} if len(expectedSubs) != len(client.Subscriptions()) { t.Errorf("Expected subscriptions %v, got %v", expectedSubs, client.Subscriptions()) } for _, s := range expectedSubs { if !client.HasSubscription(s) { t.Errorf("Cannot find %q subscription in %v", s, client.Subscriptions()) } } resetClient() }, }, { Name: "existing client - authorized admin", Method: http.MethodPost, Url: "/api/realtime", Body: strings.NewReader(`{"clientId":"` + client.Id() + `","subscriptions":["test1", "test2"]}`), RequestHeaders: map[string]string{ "Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6InN5d2JoZWNuaDQ2cmhtMCIsInR5cGUiOiJhZG1pbiIsImV4cCI6MjIwODk4NTI2MX0.M1m--VOqGyv0d23eeUc0r9xE8ZzHaYVmVFw1VZW6gT8", }, ExpectedStatus: 204, ExpectedEvents: map[string]int{ "OnRealtimeBeforeSubscribeRequest": 1, "OnRealtimeAfterSubscribeRequest": 1, }, BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { app.SubscriptionsBroker().Register(client) }, AfterTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { admin, _ := client.Get(apis.ContextAdminKey).(*models.Admin) if admin == nil { t.Errorf("Expected admin auth model, got nil") } resetClient() }, }, { Name: "existing client - authorized record", Method: http.MethodPost, Url: "/api/realtime", Body: strings.NewReader(`{"clientId":"` + client.Id() + `","subscriptions":["test1", "test2"]}`), RequestHeaders: map[string]string{ "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc", }, ExpectedStatus: 204, ExpectedEvents: map[string]int{ "OnRealtimeBeforeSubscribeRequest": 1, "OnRealtimeAfterSubscribeRequest": 1, }, BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { app.SubscriptionsBroker().Register(client) }, AfterTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { authRecord, _ := client.Get(apis.ContextAuthRecordKey).(*models.Record) if authRecord == nil { t.Errorf("Expected auth record model, got nil") } resetClient() }, }, { Name: "existing client - mismatched auth", Method: http.MethodPost, Url: "/api/realtime", Body: strings.NewReader(`{"clientId":"` + client.Id() + `","subscriptions":["test1", "test2"]}`), RequestHeaders: map[string]string{ "Authorization": "eyJhbGciOiJIUzI1NiJ9.eyJpZCI6IjRxMXhsY2xtZmxva3UzMyIsInR5cGUiOiJhdXRoUmVjb3JkIiwiY29sbGVjdGlvbklkIjoiX3BiX3VzZXJzX2F1dGhfIiwiZXhwIjoyMjA4OTg1MjYxfQ.UwD8JvkbQtXpymT09d7J6fdA0aP9g4FJ1GPh_ggEkzc", }, ExpectedStatus: 403, ExpectedContent: []string{`"data":{}`}, BeforeTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { initialAuth := &models.Record{} initialAuth.RefreshId() client.Set(apis.ContextAuthRecordKey, initialAuth) app.SubscriptionsBroker().Register(client) }, AfterTestFunc: func(t *testing.T, app *tests.TestApp, e *echo.Echo) { authRecord, _ := client.Get(apis.ContextAuthRecordKey).(*models.Record) if authRecord == nil { t.Errorf("Expected auth record model, got nil") } resetClient() }, }, } for _, scenario := range scenarios { scenario.Test(t) } } func TestRealtimeAuthRecordDeleteEvent(t *testing.T) { testApp, _ := tests.NewTestApp() defer testApp.Cleanup() apis.InitApi(testApp) authRecord, err := testApp.Dao().FindFirstRecordByData("users", "email", "test@example.com") if err != nil { t.Fatal(err) } client := subscriptions.NewDefaultClient() client.Set(apis.ContextAuthRecordKey, authRecord) testApp.SubscriptionsBroker().Register(client) testApp.OnModelAfterDelete().Trigger(&core.ModelEvent{Dao: testApp.Dao(), Model: authRecord}) if len(testApp.SubscriptionsBroker().Clients()) != 0 { t.Fatalf("Expected no subscription clients, found %d", len(testApp.SubscriptionsBroker().Clients())) } } func TestRealtimeAuthRecordUpdateEvent(t *testing.T) { testApp, _ := tests.NewTestApp() defer testApp.Cleanup() apis.InitApi(testApp) authRecord1, err := testApp.Dao().FindFirstRecordByData("users", "email", "test@example.com") if err != nil { t.Fatal(err) } client := subscriptions.NewDefaultClient() client.Set(apis.ContextAuthRecordKey, authRecord1) testApp.SubscriptionsBroker().Register(client) // refetch the authRecord and change its email authRecord2, err := testApp.Dao().FindFirstRecordByData("users", "email", "test@example.com") if err != nil { t.Fatal(err) } authRecord2.SetEmail("new@example.com") testApp.OnModelAfterUpdate().Trigger(&core.ModelEvent{Dao: testApp.Dao(), Model: authRecord2}) clientAuthRecord, _ := client.Get(apis.ContextAuthRecordKey).(*models.Record) if clientAuthRecord.Email() != authRecord2.Email() { t.Fatalf("Expected authRecord with email %q, got %q", authRecord2.Email(), clientAuthRecord.Email()) } } func TestRealtimeAdminDeleteEvent(t *testing.T) { testApp, _ := tests.NewTestApp() defer testApp.Cleanup() apis.InitApi(testApp) admin, err := testApp.Dao().FindAdminByEmail("test@example.com") if err != nil { t.Fatal(err) } client := subscriptions.NewDefaultClient() client.Set(apis.ContextAdminKey, admin) testApp.SubscriptionsBroker().Register(client) testApp.OnModelAfterDelete().Trigger(&core.ModelEvent{Dao: testApp.Dao(), Model: admin}) if len(testApp.SubscriptionsBroker().Clients()) != 0 { t.Fatalf("Expected no subscription clients, found %d", len(testApp.SubscriptionsBroker().Clients())) } } func TestRealtimeAdminUpdateEvent(t *testing.T) { testApp, _ := tests.NewTestApp() defer testApp.Cleanup() apis.InitApi(testApp) admin1, err := testApp.Dao().FindAdminByEmail("test@example.com") if err != nil { t.Fatal(err) } client := subscriptions.NewDefaultClient() client.Set(apis.ContextAdminKey, admin1) testApp.SubscriptionsBroker().Register(client) // refetch the authRecord and change its email admin2, err := testApp.Dao().FindAdminByEmail("test@example.com") if err != nil { t.Fatal(err) } admin2.Email = "new@example.com" testApp.OnModelAfterUpdate().Trigger(&core.ModelEvent{Dao: testApp.Dao(), Model: admin2}) clientAdmin, _ := client.Get(apis.ContextAdminKey).(*models.Admin) if clientAdmin.Email != admin2.Email { t.Fatalf("Expected authRecord with email %q, got %q", admin2.Email, clientAdmin.Email) } }