diff --git a/server/app/blocks.go b/server/app/blocks.go index 36a5a93bf..d7a4c8abd 100644 --- a/server/app/blocks.go +++ b/server/app/blocks.go @@ -49,7 +49,7 @@ func (a *App) InsertBlocks(c store.Container, blocks []model.Block) error { return err } - a.wsServer.BroadcastBlockChange(block) + a.wsServer.BroadcastBlockChange(c.WorkspaceID, block) go a.webhook.NotifyUpdate(block) } @@ -84,7 +84,7 @@ func (a *App) DeleteBlock(c store.Container, blockID string, modifiedBy string) return err } - a.wsServer.BroadcastBlockDelete(blockID, parentID) + a.wsServer.BroadcastBlockDelete(c.WorkspaceID, blockID, parentID) return nil } diff --git a/server/ws/websockets.go b/server/ws/websockets.go index f8a002e16..69b38abe5 100644 --- a/server/ws/websockets.go +++ b/server/ws/websockets.go @@ -14,16 +14,21 @@ import ( "github.com/mattermost/focalboard/server/services/store" ) +type WorkspaceAuthenticator interface { + DoesUserHaveWorkspaceAccess(session *model.Session, workspaceID string) bool +} + // IsValidSessionToken authenticates session tokens type IsValidSessionToken func(token string) bool // Server is a WebSocket server. type Server struct { - upgrader websocket.Upgrader - listeners map[string][]*websocket.Conn - mu sync.RWMutex - auth *auth.Auth - singleUserToken string + upgrader websocket.Upgrader + listeners map[string][]*websocket.Conn + mu sync.RWMutex + auth *auth.Auth + singleUserToken string + WorkspaceAuthenticator WorkspaceAuthenticator } // UpdateMsg is sent on block updates @@ -39,15 +44,17 @@ type ErrorMsg struct { // WebsocketCommand is an incoming command from the client. type WebsocketCommand struct { - Action string `json:"action"` - Token string `json:"token"` - ReadToken string `json:"readToken"` - BlockIDs []string `json:"blockIds"` + Action string `json:"action"` + WorkspaceID string `json:"workspaceId"` + Token string `json:"token"` + ReadToken string `json:"readToken"` + BlockIDs []string `json:"blockIds"` } type websocketSession struct { client *websocket.Conn isAuthenticated bool + workspaceID string } // NewServer creates a new Server. @@ -73,7 +80,8 @@ func (ws *Server) handleWebSocketOnChange(w http.ResponseWriter, r *http.Request // Upgrade initial GET request to a websocket client, err := ws.upgrader.Upgrade(w, r, nil) if err != nil { - log.Fatal(err) + log.Printf("ERROR upgrading to websocket: %v", err) + return } // TODO: Auth @@ -118,14 +126,14 @@ func (ws *Server) handleWebSocketOnChange(w http.ResponseWriter, r *http.Request switch command.Action { case "AUTH": log.Printf(`Command: AUTH, client: %s`, client.RemoteAddr()) - ws.authenticateListener(&wsSession, command.Token, command.ReadToken) + ws.authenticateListener(&wsSession, command.WorkspaceID, command.Token, command.ReadToken) case "ADD": - log.Printf(`Command: Add blockID: %v, client: %s`, command.BlockIDs, client.RemoteAddr()) + log.Printf(`Command: Add workspaceID: %s, blockIDs: %v, client: %s`, wsSession.workspaceID, command.BlockIDs, client.RemoteAddr()) ws.addListener(&wsSession, &command) case "REMOVE": - log.Printf(`Command: Remove blockID: %v, client: %s`, command.BlockIDs, client.RemoteAddr()) + log.Printf(`Command: Remove workspaceID: %s, blockID: %v, client: %s`, wsSession.workspaceID, command.BlockIDs, client.RemoteAddr()) ws.removeListenerFromBlocks(&wsSession, &command) default: @@ -134,30 +142,49 @@ func (ws *Server) handleWebSocketOnChange(w http.ResponseWriter, r *http.Request } } -func (ws *Server) isValidSessionToken(token string) bool { +func (ws *Server) isValidSessionToken(token, workspaceID string) bool { if len(ws.singleUserToken) > 0 { return token == ws.singleUserToken } session, err := ws.auth.GetSession(token) - if session != nil && err == nil { - return true + if session == nil || err != nil { + return false } - return false + // Check workspace permission + if ws.WorkspaceAuthenticator != nil { + if !ws.WorkspaceAuthenticator.DoesUserHaveWorkspaceAccess(session, workspaceID) { + return false + } + } + + return true } -func (ws *Server) authenticateListener(wsSession *websocketSession, token, readToken string) { +func (ws *Server) authenticateListener(wsSession *websocketSession, workspaceID, token, readToken string) { + if wsSession.isAuthenticated { + // Do not allow multiple auth calls (for security) + log.Printf("authenticateListener: Ignoring already authenticated session") + return + } + // Authenticate session - isValidSession := ws.isValidSessionToken(token) + isValidSession := ws.isValidSessionToken(token, workspaceID) if !isValidSession { wsSession.client.Close() return } // Authenticated + + // Special case: Default workspace is blank + if workspaceID == "0" { + workspaceID = "" + } + wsSession.workspaceID = workspaceID wsSession.isAuthenticated = true - log.Printf("authenticateListener: Authenticated") + log.Printf("authenticateListener: Authenticated, workspaceID: %s", workspaceID) } func (ws *Server) getContainer(wsSession *websocketSession) (store.Container, error) { @@ -195,6 +222,11 @@ func (ws *Server) checkAuthentication(wsSession *websocketSession, command *Webs return false } +// TODO: Refactor workspace hashing +func makeItemID(workspaceID, blockID string) string { + return workspaceID + "-" + blockID +} + // addListener adds a listener for a block's change. func (ws *Server) addListener(wsSession *websocketSession, command *WebsocketCommand) { if !ws.checkAuthentication(wsSession, command) { @@ -203,13 +235,16 @@ func (ws *Server) addListener(wsSession *websocketSession, command *WebsocketCom return } + workspaceID := wsSession.workspaceID + ws.mu.Lock() for _, blockID := range command.BlockIDs { - if ws.listeners[blockID] == nil { - ws.listeners[blockID] = []*websocket.Conn{} + itemID := makeItemID(workspaceID, blockID) + if ws.listeners[itemID] == nil { + ws.listeners[itemID] = []*websocket.Conn{} } - ws.listeners[blockID] = append(ws.listeners[blockID], wsSession.client) + ws.listeners[itemID] = append(ws.listeners[itemID], wsSession.client) } ws.mu.Unlock() } @@ -239,10 +274,12 @@ func (ws *Server) removeListenerFromBlocks(wsSession *websocketSession, command return } - ws.mu.Lock() + workspaceID := wsSession.workspaceID + ws.mu.Lock() for _, blockID := range command.BlockIDs { - listeners := ws.listeners[blockID] + itemID := makeItemID(workspaceID, blockID) + listeners := ws.listeners[itemID] if listeners == nil { return } @@ -252,7 +289,7 @@ func (ws *Server) removeListenerFromBlocks(wsSession *websocketSession, command for index, listener := range listeners { if wsSession.client == listener { newListeners := append(listeners[:index], listeners[index+1:]...) - ws.listeners[blockID] = newListeners + ws.listeners[itemID] = newListeners break } @@ -275,16 +312,17 @@ func sendError(conn *websocket.Conn, message string) { } // getListeners returns the listeners to a blockID's changes. -func (ws *Server) getListeners(blockID string) []*websocket.Conn { +func (ws *Server) getListeners(workspaceID string, blockID string) []*websocket.Conn { ws.mu.Lock() - listeners := ws.listeners[blockID] + itemID := makeItemID(workspaceID, blockID) + listeners := ws.listeners[itemID] ws.mu.Unlock() return listeners } // BroadcastBlockDelete broadcasts delete messages to clients -func (ws *Server) BroadcastBlockDelete(blockID, parentID string) { +func (ws *Server) BroadcastBlockDelete(workspaceID, blockID, parentID string) { now := time.Now().Unix() block := model.Block{} block.ID = blockID @@ -292,15 +330,15 @@ func (ws *Server) BroadcastBlockDelete(blockID, parentID string) { block.UpdateAt = now block.DeleteAt = now - ws.BroadcastBlockChange(block) + ws.BroadcastBlockChange(workspaceID, block) } // BroadcastBlockChange broadcasts update messages to clients -func (ws *Server) BroadcastBlockChange(block model.Block) { +func (ws *Server) BroadcastBlockChange(workspaceID string, block model.Block) { blockIDsToNotify := []string{block.ID, block.ParentID} for _, blockID := range blockIDsToNotify { - listeners := ws.getListeners(blockID) + listeners := ws.getListeners(workspaceID, blockID) log.Printf("%d listener(s) for blockID: %s", len(listeners), blockID) if listeners != nil { @@ -310,7 +348,7 @@ func (ws *Server) BroadcastBlockChange(block model.Block) { } for _, listener := range listeners { - log.Printf("Broadcast change, blockID: %s, remoteAddr: %s", blockID, listener.RemoteAddr()) + log.Printf("Broadcast change, workspaceID: %s, blockID: %s, remoteAddr: %s", workspaceID, blockID, listener.RemoteAddr()) err := listener.WriteJSON(message) if err != nil { diff --git a/webapp/src/components/cardDialog.tsx b/webapp/src/components/cardDialog.tsx index 3b67dcdd3..34a7db9bf 100644 --- a/webapp/src/components/cardDialog.tsx +++ b/webapp/src/components/cardDialog.tsx @@ -4,6 +4,7 @@ import React from 'react' import {FormattedMessage, injectIntl, IntlShape} from 'react-intl' import mutator from '../mutator' +import octoClient from '../octoClient' import {OctoListener} from '../octoListener' import {Utils} from '../utils' import {BoardTree} from '../viewModel/boardTree' @@ -53,6 +54,7 @@ class CardDialog extends React.Component { this.cardListener = new OctoListener() this.cardListener.open( + octoClient.workspaceId, [this.props.cardId], async (blocks) => { Utils.log(`cardListener.onChanged: ${blocks.length}`) diff --git a/webapp/src/octoListener.ts b/webapp/src/octoListener.ts index 898eb03b4..d569d0148 100644 --- a/webapp/src/octoListener.ts +++ b/webapp/src/octoListener.ts @@ -6,6 +6,7 @@ import {Utils} from './utils' // These are outgoing commands to the server type WSCommand = { action: string + workspaceId?: string readToken?: string blockIds: string[] } @@ -54,7 +55,7 @@ class OctoListener { return readToken } - open(blockIds: string[], onChange: OnChangeHandler, onReconnect: () => void): void { + open(workspaceId: string, blockIds: string[], onChange: OnChangeHandler, onReconnect: () => void): void { if (this.ws) { this.close() } @@ -70,7 +71,7 @@ class OctoListener { ws.onopen = () => { Utils.log('OctoListener webSocket opened.') - this.authenticate() + this.authenticate(workspaceId) this.addBlocks(blockIds) this.isInitialized = true } @@ -86,7 +87,7 @@ class OctoListener { const reopenBlockIds = this.isInitialized ? this.blockIds.slice() : blockIds.slice() Utils.logError(`Unexpected close, re-opening with ${reopenBlockIds.length} blocks...`) setTimeout(() => { - this.open(reopenBlockIds, onChange, onReconnect) + this.open(workspaceId, reopenBlockIds, onChange, onReconnect) onReconnect() }, this.reopenDelay) } @@ -135,7 +136,7 @@ class OctoListener { ws.close() } - authenticate(): void { + private authenticate(workspaceId: string): void { if (!this.ws) { Utils.assertFailure('OctoListener.addBlocks: ws is not open') return @@ -147,11 +148,12 @@ class OctoListener { const command = { action: 'AUTH', token: this.token, + workspaceId, } this.ws.send(JSON.stringify(command)) } - addBlocks(blockIds: string[]): void { + private addBlocks(blockIds: string[]): void { if (!this.ws) { Utils.assertFailure('OctoListener.addBlocks: ws is not open') return @@ -167,7 +169,7 @@ class OctoListener { this.blockIds.push(...blockIds) } - removeBlocks(blockIds: string[]): void { + private removeBlocks(blockIds: string[]): void { if (!this.ws) { Utils.assertFailure('OctoListener.removeBlocks: ws is not open') return diff --git a/webapp/src/pages/boardPage.tsx b/webapp/src/pages/boardPage.tsx index faacf28ae..c57fc52d4 100644 --- a/webapp/src/pages/boardPage.tsx +++ b/webapp/src/pages/boardPage.tsx @@ -221,6 +221,7 @@ class BoardPage extends React.Component { // Listen to boards plus all blocks at root (Empty string for parentId) this.workspaceListener.open( + octoClient.workspaceId, boardIdsToListen, async (blocks) => { Utils.log(`workspaceListener.onChanged: ${blocks.length}`)