From 0192ef6f3f51dbebb4e212c24ae6e5fb2423f11d Mon Sep 17 00:00:00 2001 From: Scott Bishel Date: Wed, 29 Mar 2023 16:27:18 -0600 Subject: [PATCH] only send category updates to user (#4672) * only send category updates to user * remove unused function * Revert "remove unused function" This reverts commit 8c4fc9b2002635ea13c73cde31fb32845eca8cb2. * remove unused function * fix test --------- Co-authored-by: Mattermost Build --- server/ws/server.go | 64 +++++++++++++++------------------------- server/ws/server_test.go | 42 ++++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 40 deletions(-) diff --git a/server/ws/server.go b/server/ws/server.go index 9df1b1f6e..dfa4af8a1 100644 --- a/server/ws/server.go +++ b/server/ws/server.go @@ -465,10 +465,15 @@ func (ws *Server) getListenersForBlock(blockID string) []*websocketSession { return ws.listenersByBlock[blockID] } -// getListenersForTeam returns the listeners subscribed to a +// getListenersForUser returns the listener for a user subscribed to a // team changes. -func (ws *Server) getListenersForTeam(teamID string) []*websocketSession { - return ws.listenersByTeam[teamID] +func (ws *Server) getListenerForUser(teamID, userID string) *websocketSession { + for _, listener := range ws.listenersByTeam[teamID] { + if listener.userID == userID { + return listener + } + } + return nil } // getListenersForTeamAndBoard returns the listeners subscribed to a @@ -567,16 +572,10 @@ func (ws *Server) BroadcastCategoryChange(category model.Category) { Category: &category, } - listeners := ws.getListenersForTeam(category.TeamID) - ws.logger.Debug("listener(s) for teamID", - mlog.Int("listener_count", len(listeners)), - mlog.String("teamID", category.TeamID), - mlog.String("categoryID", category.ID), - ) - - for _, listener := range listeners { - ws.logger.Debug("Broadcast block change", - mlog.Int("listener_count", len(listeners)), + listener := ws.getListenerForUser(category.TeamID, category.UserID) + if listener != nil { + ws.logger.Debug("Broadcast category change", + mlog.String("userID", category.UserID), mlog.String("teamID", category.TeamID), mlog.String("categoryID", category.ID), mlog.Stringer("remoteAddr", listener.conn.RemoteAddr()), @@ -596,15 +595,10 @@ func (ws *Server) BroadcastCategoryReorder(teamID, userID string, categoryOrder TeamID: teamID, } - listeners := ws.getListenersForTeam(teamID) - ws.logger.Debug("listener(s) for teamID", - mlog.Int("listener_count", len(listeners)), - mlog.String("teamID", teamID), - ) - - for _, listener := range listeners { + listener := ws.getListenerForUser(teamID, userID) + if listener != nil { ws.logger.Debug("Broadcast category order change", - mlog.Int("listener_count", len(listeners)), + mlog.String("userID", userID), mlog.String("teamID", teamID), mlog.Stringer("remoteAddr", listener.conn.RemoteAddr()), ) @@ -624,21 +618,17 @@ func (ws *Server) BroadcastCategoryBoardsReorder(teamID, userID, categoryID stri TeamID: teamID, } - listeners := ws.getListenersForTeam(teamID) - ws.logger.Debug("listener(s) for teamID", - mlog.Int("listener_count", len(listeners)), - mlog.String("teamID", teamID), - ) - - for _, listener := range listeners { + listener := ws.getListenerForUser(teamID, userID) + if listener != nil { ws.logger.Debug("Broadcast board category order change", - mlog.Int("listener_count", len(listeners)), + mlog.String("userID", userID), mlog.String("teamID", teamID), + mlog.String("categoryID", categoryID), mlog.Stringer("remoteAddr", listener.conn.RemoteAddr()), ) if err := listener.WriteJSON(message); err != nil { - ws.logger.Error("broadcast category order change error", mlog.Err(err)) + ws.logger.Error("broadcast category boards order change error", mlog.Err(err)) listener.conn.Close() } } @@ -651,16 +641,10 @@ func (ws *Server) BroadcastCategoryBoardChange(teamID, userID string, boardCateg BoardCategories: boardCategories, } - listeners := ws.getListenersForTeam(teamID) - ws.logger.Debug("listener(s) for teamID", - mlog.Int("listener_count", len(listeners)), - mlog.String("teamID", teamID), - mlog.Int("numEntries", len(boardCategories)), - ) - - for _, listener := range listeners { - ws.logger.Debug("Broadcast block change", - mlog.Int("listener_count", len(listeners)), + listener := ws.getListenerForUser(teamID, userID) + if listener != nil { + ws.logger.Debug("Broadcast category board change", + mlog.String("userID", userID), mlog.String("teamID", teamID), mlog.Int("numEntries", len(boardCategories)), mlog.Stringer("remoteAddr", listener.conn.RemoteAddr()), diff --git a/server/ws/server_test.go b/server/ws/server_test.go index ba56bacc1..fa728e2cd 100644 --- a/server/ws/server_test.go +++ b/server/ws/server_test.go @@ -101,6 +101,48 @@ func TestTeamSubscription(t *testing.T) { require.Empty(t, server.listenersByTeam[teamID]) require.Empty(t, server.listenersByTeam[teamID2]) }) + + t.Run("Subscribe users to team retrieve by user", func(t *testing.T) { + userID1 := "fake-user-id" + userSession1 := &websocketSession{ + conn: &websocket.Conn{}, + mu: sync.Mutex{}, + userID: userID1, + teams: []string{}, + blocks: []string{}, + } + userID2 := "fake-user-id2" + userSession2 := &websocketSession{ + conn: &websocket.Conn{}, + mu: sync.Mutex{}, + userID: userID2, + teams: []string{}, + blocks: []string{}, + } + teamID := "fake-team-id" + + server.addListener(session) + server.subscribeListenerToTeam(session, teamID) + server.addListener(userSession1) + server.subscribeListenerToTeam(userSession1, teamID) + server.addListener(userSession2) + server.subscribeListenerToTeam(userSession2, teamID) + + require.Len(t, server.listeners, 3) + require.Len(t, server.listenersByTeam[teamID], 3) + + listener := server.getListenerForUser(teamID, userID1) + require.NotNil(t, listener) + require.Equal(t, listener.userID, userID1) + + server.removeListener(session) + server.removeListener(userSession1) + server.removeListener(userSession2) + + require.Empty(t, server.listeners) + require.Empty(t, server.listenersByTeam[teamID]) + require.Empty(t, server.getListenerForUser(teamID, userID1)) + }) } func TestBlocksSubscription(t *testing.T) {