1
0
mirror of https://github.com/mattermost/focalboard.git synced 2025-07-12 23:50:27 +02:00

Prevented concurrent writes to websocket (#658)

* retained individual connection objects

* unlocking lock in defer

* Completely abstracted internal connection object

* Completely removed direct use of WS connection
This commit is contained in:
Harshil Sharma
2021-07-01 11:41:29 +05:30
committed by GitHub
parent ba69c8b083
commit 1020c03924

View File

@ -24,10 +24,22 @@ type Hub interface {
SetReceiveWSMessage(func(data []byte)) SetReceiveWSMessage(func(data []byte))
} }
type wsClient struct {
*websocket.Conn
lock *sync.RWMutex
}
func (c *wsClient) WriteJSON(v interface{}) error {
c.lock.Lock()
defer c.lock.Unlock()
err := c.Conn.WriteJSON(v)
return err
}
// Server is a WebSocket server. // Server is a WebSocket server.
type Server struct { type Server struct {
upgrader websocket.Upgrader upgrader websocket.Upgrader
listeners map[string][]*websocket.Conn listeners map[string][]*wsClient
mu sync.RWMutex mu sync.RWMutex
auth *auth.Auth auth *auth.Auth
hub Hub hub Hub
@ -64,7 +76,7 @@ type WebsocketCommand struct {
} }
type websocketSession struct { type websocketSession struct {
client *websocket.Conn client *wsClient
isAuthenticated bool isAuthenticated bool
workspaceID string workspaceID string
} }
@ -72,7 +84,7 @@ type websocketSession struct {
// NewServer creates a new Server. // NewServer creates a new Server.
func NewServer(auth *auth.Auth, singleUserToken string, isMattermostAuth bool, logger *mlog.Logger) *Server { func NewServer(auth *auth.Auth, singleUserToken string, isMattermostAuth bool, logger *mlog.Logger) *Server {
return &Server{ return &Server{
listeners: make(map[string][]*websocket.Conn), listeners: make(map[string][]*wsClient),
upgrader: websocket.Upgrader{ upgrader: websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { CheckOrigin: func(r *http.Request) bool {
return true return true
@ -98,36 +110,34 @@ func (ws *Server) handleWebSocketOnChange(w http.ResponseWriter, r *http.Request
return return
} }
// Make sure we close the connection when the function returns
defer func() {
ws.logger.Debug("DISCONNECT WebSocket onChange", mlog.Stringer("client", client.RemoteAddr()))
// Remove client from listeners
ws.removeListener(client)
client.Close()
}()
userID := "" userID := ""
if ws.isMattermostAuth { if ws.isMattermostAuth {
userID = r.Header.Get("Mattermost-User-Id") userID = r.Header.Get("Mattermost-User-Id")
} }
wsSession := websocketSession{ wsSession := websocketSession{
client: client, client: &wsClient{client, &sync.RWMutex{}},
isAuthenticated: userID != "", isAuthenticated: userID != "",
} }
// Make sure we close the connection when the function returns
defer func() {
ws.logger.Debug("DISCONNECT WebSocket onChange", mlog.Stringer("client", wsSession.client.RemoteAddr()))
// Remove client from listeners
ws.removeListener(wsSession.client)
wsSession.client.Close()
}()
// Simple message handling loop // Simple message handling loop
for { for {
_, p, err := client.ReadMessage() _, p, err := wsSession.client.ReadMessage()
if err != nil { if err != nil {
ws.logger.Error("ERROR WebSocket onChange", ws.logger.Error("ERROR WebSocket onChange",
mlog.Stringer("client", client.RemoteAddr()), mlog.Stringer("client", wsSession.client.RemoteAddr()),
mlog.Err(err), mlog.Err(err),
) )
ws.removeListener(client) ws.removeListener(wsSession.client)
break break
} }
@ -152,20 +162,20 @@ func (ws *Server) handleWebSocketOnChange(w http.ResponseWriter, r *http.Request
switch command.Action { switch command.Action {
case "AUTH": case "AUTH":
ws.logger.Debug(`Command: AUTH`, mlog.Stringer("client", client.RemoteAddr())) ws.logger.Debug(`Command: AUTH`, mlog.Stringer("client", wsSession.client.RemoteAddr()))
ws.authenticateListener(&wsSession, command.WorkspaceID, command.Token) ws.authenticateListener(&wsSession, command.WorkspaceID, command.Token)
case "ADD": case "ADD":
ws.logger.Debug(`Command: ADD`, ws.logger.Debug(`Command: ADD`,
mlog.String("workspaceID", wsSession.workspaceID), mlog.String("workspaceID", wsSession.workspaceID),
mlog.Array("blockIDs", command.BlockIDs), mlog.Array("blockIDs", command.BlockIDs),
mlog.Stringer("client", client.RemoteAddr()), mlog.Stringer("client", wsSession.client.RemoteAddr()),
) )
ws.addListener(&wsSession, &command) ws.addListener(&wsSession, &command)
case "REMOVE": case "REMOVE":
ws.logger.Debug(`Command: REMOVE`, ws.logger.Debug(`Command: REMOVE`,
mlog.String("workspaceID", wsSession.workspaceID), mlog.String("workspaceID", wsSession.workspaceID),
mlog.Array("blockIDs", command.BlockIDs), mlog.Array("blockIDs", command.BlockIDs),
mlog.Stringer("client", client.RemoteAddr()), mlog.Stringer("client", wsSession.client.RemoteAddr()),
) )
ws.removeListenerFromBlocks(&wsSession, &command) ws.removeListenerFromBlocks(&wsSession, &command)
@ -258,7 +268,7 @@ func (ws *Server) addListener(wsSession *websocketSession, command *WebsocketCom
for _, blockID := range command.BlockIDs { for _, blockID := range command.BlockIDs {
itemID := makeItemID(workspaceID, blockID) itemID := makeItemID(workspaceID, blockID)
if ws.listeners[itemID] == nil { if ws.listeners[itemID] == nil {
ws.listeners[itemID] = []*websocket.Conn{} ws.listeners[itemID] = []*wsClient{}
} }
ws.listeners[itemID] = append(ws.listeners[itemID], wsSession.client) ws.listeners[itemID] = append(ws.listeners[itemID], wsSession.client)
@ -267,10 +277,10 @@ func (ws *Server) addListener(wsSession *websocketSession, command *WebsocketCom
} }
// removeListener removes a webSocket listener from all blocks. // removeListener removes a webSocket listener from all blocks.
func (ws *Server) removeListener(client *websocket.Conn) { func (ws *Server) removeListener(client *wsClient) {
ws.mu.Lock() ws.mu.Lock()
for key, clients := range ws.listeners { for key, clients := range ws.listeners {
listeners := []*websocket.Conn{} listeners := []*wsClient{}
for _, existingClient := range clients { for _, existingClient := range clients {
if client != existingClient { if client != existingClient {
@ -315,15 +325,15 @@ func (ws *Server) removeListenerFromBlocks(wsSession *websocketSession, command
ws.mu.Unlock() ws.mu.Unlock()
} }
func (ws *Server) sendError(conn *websocket.Conn, message string) { func (ws *Server) sendError(wsClient *wsClient, message string) {
errorMsg := ErrorMsg{ errorMsg := ErrorMsg{
Error: message, Error: message,
} }
err := conn.WriteJSON(errorMsg) err := wsClient.WriteJSON(errorMsg)
if err != nil { if err != nil {
ws.logger.Error("sendError error", mlog.Err(err)) ws.logger.Error("sendError error", mlog.Err(err))
conn.Close() wsClient.Close()
} }
} }
@ -358,7 +368,7 @@ func (ws *Server) SetHub(hub Hub) {
} }
// getListeners returns the listeners to a blockID's changes. // getListeners returns the listeners to a blockID's changes.
func (ws *Server) getListeners(workspaceID string, blockID string) []*websocket.Conn { func (ws *Server) getListeners(workspaceID string, blockID string) []*wsClient {
ws.mu.Lock() ws.mu.Lock()
itemID := makeItemID(workspaceID, blockID) itemID := makeItemID(workspaceID, blockID)
listeners := ws.listeners[itemID] listeners := ws.listeners[itemID]