1
0
mirror of https://github.com/mattermost/focalboard.git synced 2024-12-24 13:43:12 +02:00

Some improvements based on golangci-lint checks, and adding more rules

This commit is contained in:
Jesús Espino 2020-10-22 13:34:42 +02:00
parent 3516e6b26c
commit 607b8aa063
14 changed files with 122 additions and 90 deletions

View File

@ -25,11 +25,14 @@ server-lint:
echo "golangci-lint is not installed. Please see https://github.com/golangci/golangci-lint#install for installation instructions."; \
exit 1; \
fi; \
cd server; golangci-lint run ./...
cd server; golangci-lint run -p format -p unused -p complexity -p bugs -p performance -E asciicheck -E depguard -E dogsled -E dupl -E funlen -E gochecknoglobals -E gochecknoinits -E goconst -E gocritic -E godot -E godox -E goerr113 -E goheader -E golint -E gomnd -E gomodguard -E goprintffuncname -E gosimple -E interfacer -E lll -E misspell -E nlreturn -E nolintlint -E stylecheck -E unconvert -E whitespace -E wsl --skip-dirs services/store/sqlstore/migrations/ ./...
server-test:
cd server; go test ./...
server-doc:
cd server; go doc ./...
watch-server:
cd server; modd

View File

@ -83,7 +83,7 @@ func (a *API) handlePostBlocks(w http.ResponseWriter, r *http.Request) {
}()
var blocks []model.Block
err = json.Unmarshal([]byte(requestBody), &blocks)
err = json.Unmarshal(requestBody, &blocks)
if err != nil {
errorResponse(w, http.StatusInternalServerError, ``)
return
@ -95,15 +95,16 @@ func (a *API) handlePostBlocks(w http.ResponseWriter, r *http.Request) {
errorResponse(w, http.StatusInternalServerError, fmt.Sprintf(`{"description": "missing type", "id": "%s"}`, block.ID))
return
}
if block.CreateAt < 1 {
errorResponse(w, http.StatusInternalServerError, fmt.Sprintf(`{"description": "invalid createAt", "id": "%s"}`, block.ID))
return
}
if block.UpdateAt < 1 {
errorResponse(w, http.StatusInternalServerError, fmt.Sprintf(`{"description": "invalid updateAt", "id": "%s"}`, block.ID))
return
}
}
err = a.app().InsertBlocks(blocks)
@ -190,7 +191,7 @@ func (a *API) handleImport(w http.ResponseWriter, r *http.Request) {
}()
var blocks []model.Block
err = json.Unmarshal([]byte(requestBody), &blocks)
err = json.Unmarshal(requestBody, &blocks)
if err != nil {
errorResponse(w, http.StatusInternalServerError, ``)
return
@ -229,6 +230,7 @@ func (a *API) handleServeFile(w http.ResponseWriter, r *http.Request) {
func (a *API) handleUploadFile(w http.ResponseWriter, r *http.Request) {
fmt.Println(`handleUploadFile`)
file, handle, err := r.FormFile("file")
if err != nil {
fmt.Fprintf(w, "%v", err)
@ -243,6 +245,7 @@ func (a *API) handleUploadFile(w http.ResponseWriter, r *http.Request) {
jsonStringResponse(w, http.StatusInternalServerError, `{}`)
return
}
log.Printf(`saveFile, url: %s`, url)
json := fmt.Sprintf(`{ "url": "%s" }`, url)
jsonStringResponse(w, http.StatusOK, json)

View File

@ -10,10 +10,10 @@ import (
type App struct {
config *config.Configuration
store store.Store
wsServer *ws.WSServer
wsServer *ws.Server
filesBackend filesstore.FileBackend
}
func New(config *config.Configuration, store store.Store, wsServer *ws.WSServer, filesBackend filesstore.FileBackend) *App {
func New(config *config.Configuration, store store.Store, wsServer *ws.Server, filesBackend filesstore.FileBackend) *App {
return &App{config: config, store: store, wsServer: wsServer, filesBackend: filesBackend}
}

View File

@ -24,12 +24,14 @@ func (a *App) InsertBlock(block model.Block) error {
func (a *App) InsertBlocks(blocks []model.Block) error {
var blockIDsToNotify = []string{}
uniqueBlockIDs := make(map[string]bool)
for _, block := range blocks {
if !uniqueBlockIDs[block.ID] {
blockIDsToNotify = append(blockIDsToNotify, block.ID)
}
if len(block.ParentID) > 0 && !uniqueBlockIDs[block.ParentID] {
blockIDsToNotify = append(blockIDsToNotify, block.ParentID)
}

View File

@ -16,14 +16,16 @@ func TestGetParentID(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
store := mockstore.NewMockStore(ctrl)
wsserver := ws.NewWSServer()
wsserver := ws.NewServer()
app := New(&config.Configuration{}, store, wsserver, &mocks.FileBackend{})
t.Run("success query", func(t *testing.T) {
store.EXPECT().GetParentID(gomock.Eq("test-id")).Return("test-parent-id", nil)
result, err := app.GetParentID("test-id")
require.NoError(t, err)
require.Equal(t, "test-parent-id", result)
})
t.Run("fail query", func(t *testing.T) {
store.EXPECT().GetParentID(gomock.Eq("test-id")).Return("", errors.New("block-not-found"))
_, err := app.GetParentID("test-id")

View File

@ -15,6 +15,10 @@ import (
// ----------------------------------------------------------------------------------------------------
// WebSocket OnChange listener
const (
timeBetweenPidMonitoringChecks = 2 * time.Second
)
func isProcessRunning(pid int) bool {
process, err := os.FindProcess(pid)
if err != nil {
@ -27,13 +31,15 @@ func isProcessRunning(pid int) bool {
func monitorPid(pid int) {
log.Printf("Monitoring PID: %d", pid)
go func() {
for {
if !isProcessRunning(pid) {
log.Printf("Monitored process not found, exiting.")
os.Exit(1)
}
time.Sleep(2 * time.Second)
time.Sleep(timeBetweenPidMonitoringChecks)
}
}()
}

View File

@ -26,11 +26,11 @@ const CurrentVersion = "0.0.1"
type Server struct {
config *config.Configuration
wsServer *ws.WSServer
wsServer *ws.Server
webServer *web.WebServer
store store.Store
filesBackend filesstore.FileBackend
telemetry *telemetry.TelemetryService
telemetry *telemetry.Service
logger *zap.Logger
}
@ -46,7 +46,7 @@ func New(cfg *config.Configuration) (*Server, error) {
return nil, err
}
wsServer := ws.NewWSServer()
wsServer := ws.NewServer()
filesBackendSettings := model.FileSettings{}
filesBackendSettings.SetDefaults(false)
@ -67,11 +67,13 @@ func New(cfg *config.Configuration) (*Server, error) {
// Ctrl+C handling
handler := make(chan os.Signal, 1)
signal.Notify(handler, os.Interrupt)
go func() {
for sig := range handler {
// sig is a ^C, handle it
if sig == os.Interrupt {
os.Exit(1)
break
}
}

View File

@ -41,9 +41,7 @@ func createTask(name string, function TaskFunc, interval time.Duration, recurrin
defer close(task.cancelled)
ticker := time.NewTicker(interval)
defer func() {
ticker.Stop()
}()
defer ticker.Stop()
for {
select {

View File

@ -12,67 +12,67 @@ import (
)
func TestCreateTask(t *testing.T) {
TASK_NAME := "Test Task"
TASK_TIME := time.Millisecond * 200
TASK_WAIT := time.Millisecond * 100
taskName := "Test Task"
taskTime := time.Millisecond * 200
taskWait := time.Millisecond * 100
executionCount := new(int32)
testFunc := func() {
atomic.AddInt32(executionCount, 1)
}
task := CreateTask(TASK_NAME, testFunc, TASK_TIME)
task := CreateTask(taskName, testFunc, taskTime)
assert.EqualValues(t, 0, atomic.LoadInt32(executionCount))
time.Sleep(TASK_TIME + TASK_WAIT)
time.Sleep(taskTime + taskWait)
assert.EqualValues(t, 1, atomic.LoadInt32(executionCount))
assert.Equal(t, TASK_NAME, task.Name)
assert.Equal(t, TASK_TIME, task.Interval)
assert.Equal(t, taskName, task.Name)
assert.Equal(t, taskTime, task.Interval)
assert.False(t, task.Recurring)
}
func TestCreateRecurringTask(t *testing.T) {
TASK_NAME := "Test Recurring Task"
TASK_TIME := time.Millisecond * 200
TASK_WAIT := time.Millisecond * 100
taskName := "Test Recurring Task"
taskTime := time.Millisecond * 200
taskWait := time.Millisecond * 100
executionCount := new(int32)
testFunc := func() {
atomic.AddInt32(executionCount, 1)
}
task := CreateRecurringTask(TASK_NAME, testFunc, TASK_TIME)
task := CreateRecurringTask(taskName, testFunc, taskTime)
assert.EqualValues(t, 0, atomic.LoadInt32(executionCount))
time.Sleep(TASK_TIME + TASK_WAIT)
time.Sleep(taskTime + taskWait)
assert.EqualValues(t, 1, atomic.LoadInt32(executionCount))
time.Sleep(TASK_TIME)
time.Sleep(taskTime)
assert.EqualValues(t, 2, atomic.LoadInt32(executionCount))
assert.Equal(t, TASK_NAME, task.Name)
assert.Equal(t, TASK_TIME, task.Interval)
assert.Equal(t, taskName, task.Name)
assert.Equal(t, taskTime, task.Interval)
assert.True(t, task.Recurring)
task.Cancel()
}
func TestCancelTask(t *testing.T) {
TASK_NAME := "Test Task"
TASK_TIME := time.Millisecond * 100
TASK_WAIT := time.Millisecond * 100
taskName := "Test Task"
taskTime := time.Millisecond * 100
taskWait := time.Millisecond * 100
executionCount := new(int32)
testFunc := func() {
atomic.AddInt32(executionCount, 1)
}
task := CreateTask(TASK_NAME, testFunc, TASK_TIME)
task := CreateTask(taskName, testFunc, taskTime)
assert.EqualValues(t, 0, atomic.LoadInt32(executionCount))
task.Cancel()
time.Sleep(TASK_TIME + TASK_WAIT)
time.Sleep(taskTime + taskWait)
assert.EqualValues(t, 0, atomic.LoadInt32(executionCount))
}

View File

@ -17,7 +17,8 @@ func (s *SQLStore) latestsBlocksSubquery() sq.SelectBuilder {
}
func (s *SQLStore) GetBlocksWithParentAndType(parentID string, blockType string) ([]model.Block, error) {
query := s.getQueryBuilder().Select("id", "parent_id", "schema", "type", "title", "COALESCE(\"fields\", '{}')", "create_at", "update_at", "delete_at").
query := s.getQueryBuilder().
Select("id", "parent_id", "schema", "type", "title", "COALESCE(\"fields\", '{}')", "create_at", "update_at", "delete_at").
FromSelect(s.latestsBlocksSubquery(), "latest").
Where(sq.Eq{"delete_at": 0}).
Where(sq.Eq{"parent_id": parentID}).
@ -25,6 +26,7 @@ func (s *SQLStore) GetBlocksWithParentAndType(parentID string, blockType string)
rows, err := query.Query()
if err != nil {
log.Printf(`getBlocksWithParentAndType ERROR: %v`, err)
return nil, err
}
@ -32,7 +34,8 @@ func (s *SQLStore) GetBlocksWithParentAndType(parentID string, blockType string)
}
func (s *SQLStore) GetBlocksWithParent(parentID string) ([]model.Block, error) {
query := s.getQueryBuilder().Select("id", "parent_id", "schema", "type", "title", "COALESCE(\"fields\", '{}')", "create_at", "update_at", "delete_at").
query := s.getQueryBuilder().
Select("id", "parent_id", "schema", "type", "title", "COALESCE(\"fields\", '{}')", "create_at", "update_at", "delete_at").
FromSelect(s.latestsBlocksSubquery(), "latest").
Where(sq.Eq{"delete_at": 0}).
Where(sq.Eq{"parent_id": parentID})
@ -47,7 +50,8 @@ func (s *SQLStore) GetBlocksWithParent(parentID string) ([]model.Block, error) {
}
func (s *SQLStore) GetBlocksWithType(blockType string) ([]model.Block, error) {
query := s.getQueryBuilder().Select("id", "parent_id", "schema", "type", "title", "COALESCE(\"fields\", '{}')", "create_at", "update_at", "delete_at").
query := s.getQueryBuilder().
Select("id", "parent_id", "schema", "type", "title", "COALESCE(\"fields\", '{}')", "create_at", "update_at", "delete_at").
FromSelect(s.latestsBlocksSubquery(), "latest").
Where(sq.Eq{"delete_at": 0}).
Where(sq.Eq{"type": blockType})
@ -61,7 +65,8 @@ func (s *SQLStore) GetBlocksWithType(blockType string) ([]model.Block, error) {
}
func (s *SQLStore) GetSubTree(blockID string) ([]model.Block, error) {
query := s.getQueryBuilder().Select("id", "parent_id", "schema", "type", "title", "COALESCE(\"fields\", '{}')", "create_at", "update_at", "delete_at").
query := s.getQueryBuilder().
Select("id", "parent_id", "schema", "type", "title", "COALESCE(\"fields\", '{}')", "create_at", "update_at", "delete_at").
FromSelect(s.latestsBlocksSubquery(), "latest").
Where(sq.Eq{"delete_at": 0}).
Where(sq.Or{sq.Eq{"id": blockID}, sq.Eq{"parent_id": blockID}})
@ -76,7 +81,8 @@ func (s *SQLStore) GetSubTree(blockID string) ([]model.Block, error) {
}
func (s *SQLStore) GetAllBlocks() ([]model.Block, error) {
query := s.getQueryBuilder().Select("id", "parent_id", "schema", "type", "title", "COALESCE(\"fields\", '{}')", "create_at", "update_at", "delete_at").
query := s.getQueryBuilder().
Select("id", "parent_id", "schema", "type", "title", "COALESCE(\"fields\", '{}')", "create_at", "update_at", "delete_at").
FromSelect(s.latestsBlocksSubquery(), "latest").
Where(sq.Eq{"delete_at": 0})
@ -97,6 +103,7 @@ func blocksFromRows(rows *sql.Rows) ([]model.Block, error) {
for rows.Next() {
var block model.Block
var fieldsJSON string
err := rows.Scan(
&block.ID,
&block.ParentID,

View File

@ -1,6 +1,8 @@
package sqlstore
import (
"errors"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database"
"github.com/golang-migrate/migrate/v4/database/postgres"
@ -13,28 +15,24 @@ import (
)
func (s *SQLStore) Migrate() error {
var bresource *bindata.AssetSource
var driver database.Driver
var err error
var bresource *bindata.AssetSource
if s.dbType == "sqlite3" {
driver, err = sqlite3.WithInstance(s.db, &sqlite3.Config{})
if err != nil {
return err
}
bresource = bindata.Resource(sqlite.AssetNames(),
func(name string) ([]byte, error) {
return sqlite.Asset(name)
})
bresource = bindata.Resource(sqlite.AssetNames(), sqlite.Asset)
}
if s.dbType == "postgres" {
driver, err = postgres.WithInstance(s.db, &postgres.Config{})
if err != nil {
return err
}
bresource = bindata.Resource(pgmigrations.AssetNames(),
func(name string) ([]byte, error) {
return pgmigrations.Asset(name)
})
bresource = bindata.Resource(pgmigrations.AssetNames(), pgmigrations.Asset)
}
d, err := bindata.WithInstance(bresource)
@ -48,8 +46,9 @@ func (s *SQLStore) Migrate() error {
}
err = m.Up()
if err != nil && err != migrate.ErrNoChange {
if err != nil && errors.Is(err, migrate.ErrNoChange) {
return err
}
return nil
}

View File

@ -14,19 +14,15 @@ import (
)
const (
DAY_MILLISECONDS = 24 * 60 * 60 * 1000
MONTH_MILLISECONDS = 31 * DAY_MILLISECONDS
RUDDER_KEY = "placeholder_rudder_key"
RUDDER_DATAPLANE_URL = "placeholder_rudder_dataplane_url"
TRACK_CONFIG = "config"
rudderKey = "placeholder_rudder_key"
rudderDataplaneURL = "placeholder_rudder_dataplane_url"
timeBetweenTelemetryChecks = 10 * time.Minute
)
type telemetryTracker func() map[string]interface{}
type Tracker func() map[string]interface{}
type TelemetryService struct {
trackers map[string]telemetryTracker
type Service struct {
trackers map[string]Tracker
log *log.Logger
rudderClient rudder.Client
telemetryID string
@ -35,25 +31,25 @@ type TelemetryService struct {
type RudderConfig struct {
RudderKey string
DataplaneUrl string
DataplaneURL string
}
func New(telemetryID string, log *log.Logger) *TelemetryService {
service := &TelemetryService{
func New(telemetryID string, log *log.Logger) *Service {
service := &Service{
log: log,
telemetryID: telemetryID,
trackers: map[string]telemetryTracker{},
trackers: map[string]Tracker{},
}
return service
}
func (ts *TelemetryService) RegisterTracker(name string, tracker telemetryTracker) {
func (ts *Service) RegisterTracker(name string, tracker Tracker) {
ts.trackers[name] = tracker
}
func (ts *TelemetryService) getRudderConfig() RudderConfig {
if !strings.Contains(RUDDER_KEY, "placeholder") && !strings.Contains(RUDDER_DATAPLANE_URL, "placeholder") {
return RudderConfig{RUDDER_KEY, RUDDER_DATAPLANE_URL}
func (ts *Service) getRudderConfig() RudderConfig {
if !strings.Contains(rudderKey, "placeholder") && !strings.Contains(rudderDataplaneURL, "placeholder") {
return RudderConfig{rudderKey, rudderDataplaneURL}
} else if os.Getenv("RUDDER_KEY") != "" && os.Getenv("RUDDER_DATAPLANE_URL") != "" {
return RudderConfig{os.Getenv("RUDDER_KEY"), os.Getenv("RUDDER_DATAPLANE_URL")}
} else {
@ -61,17 +57,18 @@ func (ts *TelemetryService) getRudderConfig() RudderConfig {
}
}
func (ts *TelemetryService) sendDailyTelemetry(override bool) {
func (ts *Service) sendDailyTelemetry(override bool) {
config := ts.getRudderConfig()
if (config.DataplaneUrl != "" && config.RudderKey != "") || override {
ts.initRudder(config.DataplaneUrl, config.RudderKey)
if (config.DataplaneURL != "" && config.RudderKey != "") || override {
ts.initRudder(config.DataplaneURL, config.RudderKey)
for name, tracker := range ts.trackers {
ts.sendTelemetry(name, tracker())
}
}
}
func (ts *TelemetryService) sendTelemetry(event string, properties map[string]interface{}) {
func (ts *Service) sendTelemetry(event string, properties map[string]interface{}) {
if ts.rudderClient != nil {
var context *rudder.Context
ts.rudderClient.Enqueue(rudder.Track{
@ -83,13 +80,13 @@ func (ts *TelemetryService) sendTelemetry(event string, properties map[string]in
}
}
func (ts *TelemetryService) initRudder(endpoint string, rudderKey string) {
func (ts *Service) initRudder(endpoint string, rudderKey string) {
if ts.rudderClient == nil {
config := rudder.Config{}
config.Logger = rudder.StdLogger(ts.log)
config.Endpoint = endpoint
// For testing
if endpoint != RUDDER_DATAPLANE_URL {
if endpoint != rudderDataplaneURL {
config.Verbose = true
config.BatchSize = 1
}
@ -106,7 +103,7 @@ func (ts *TelemetryService) initRudder(endpoint string, rudderKey string) {
}
}
func (ts *TelemetryService) doTelemetryIfNeeded(firstRun time.Time) {
func (ts *Service) doTelemetryIfNeeded(firstRun time.Time) {
hoursSinceFirstServerRun := time.Since(firstRun).Hours()
// Send once every 10 minutes for the first hour
// Send once every hour thereafter for the first 12 hours
@ -120,21 +117,21 @@ func (ts *TelemetryService) doTelemetryIfNeeded(firstRun time.Time) {
}
}
func (ts *TelemetryService) RunTelemetryJob(firstRun int64) {
func (ts *Service) RunTelemetryJob(firstRun int64) {
// Send on boot
ts.doTelemetry()
scheduler.CreateRecurringTask("Telemetry", func() {
ts.doTelemetryIfNeeded(time.Unix(0, firstRun*int64(time.Millisecond)))
}, time.Minute*10)
}, timeBetweenTelemetryChecks)
}
func (ts *TelemetryService) doTelemetry() {
func (ts *Service) doTelemetry() {
ts.timestampLastTelemetrySent = time.Now()
ts.sendDailyTelemetry(false)
}
// Shutdown closes the telemetry client.
func (ts *TelemetryService) Shutdown() error {
func (ts *Service) Shutdown() error {
if ts.rudderClient != nil {
return ts.rudderClient.Close()
}

View File

@ -53,19 +53,23 @@ func (ws *WebServer) Start() error {
urlPort := fmt.Sprintf(`:%d`, ws.port)
var isSSL = ws.ssl && fileExists("./cert/cert.pem") && fileExists("./cert/key.pem")
if isSSL {
log.Println("https server started on ", urlPort)
err := http.ListenAndServeTLS(urlPort, "./cert/cert.pem", "./cert/key.pem", nil)
if err != nil {
return err
}
return nil
}
log.Println("http server started on ", urlPort)
err := http.ListenAndServe(urlPort, nil)
if err != nil {
return err
}
return nil
}
@ -75,5 +79,6 @@ func fileExists(path string) bool {
if os.IsNotExist(err) {
return false
}
return err == nil
}

View File

@ -11,27 +11,29 @@ import (
)
// RegisterRoutes registeres routes
func (ws *WSServer) RegisterRoutes(r *mux.Router) {
func (ws *Server) RegisterRoutes(r *mux.Router) {
r.HandleFunc("/ws/onchange", ws.handleWebSocketOnChange)
}
// AddListener adds a listener for a block's change
func (ws *WSServer) AddListener(client *websocket.Conn, blockIDs []string) {
func (ws *Server) AddListener(client *websocket.Conn, blockIDs []string) {
ws.mu.Lock()
for _, blockID := range blockIDs {
if ws.listeners[blockID] == nil {
ws.listeners[blockID] = []*websocket.Conn{}
}
ws.listeners[blockID] = append(ws.listeners[blockID], client)
}
ws.mu.Unlock()
}
// RemoveListener removes a webSocket listener from all blocks
func (ws *WSServer) RemoveListener(client *websocket.Conn) {
func (ws *Server) RemoveListener(client *websocket.Conn) {
ws.mu.Lock()
for key, clients := range ws.listeners {
var listeners = []*websocket.Conn{}
for _, existingClient := range clients {
if client != existingClient {
listeners = append(listeners, existingClient)
@ -43,7 +45,7 @@ func (ws *WSServer) RemoveListener(client *websocket.Conn) {
}
// RemoveListenerFromBlocks removes a webSocket listener from a set of block
func (ws *WSServer) RemoveListenerFromBlocks(client *websocket.Conn, blockIDs []string) {
func (ws *Server) RemoveListenerFromBlocks(client *websocket.Conn, blockIDs []string) {
ws.mu.Lock()
for _, blockID := range blockIDs {
@ -58,6 +60,7 @@ func (ws *WSServer) RemoveListenerFromBlocks(client *websocket.Conn, blockIDs []
if client == listener {
newListeners := append(listeners[:index], listeners[index+1:]...)
ws.listeners[blockID] = newListeners
break
}
}
@ -67,7 +70,7 @@ func (ws *WSServer) RemoveListenerFromBlocks(client *websocket.Conn, blockIDs []
}
// GetListeners returns the listeners to a blockID's changes
func (ws *WSServer) GetListeners(blockID string) []*websocket.Conn {
func (ws *Server) GetListeners(blockID string) []*websocket.Conn {
ws.mu.Lock()
listeners := ws.listeners[blockID]
ws.mu.Unlock()
@ -75,16 +78,16 @@ func (ws *WSServer) GetListeners(blockID string) []*websocket.Conn {
return listeners
}
// WSServer is a WebSocket server
type WSServer struct {
// Server is a WebSocket server
type Server struct {
upgrader websocket.Upgrader
listeners map[string][]*websocket.Conn
mu sync.RWMutex
}
// NewWSServer creates a new WSServer
func NewWSServer() *WSServer {
return &WSServer{
// NewServer creates a new Server
func NewServer() *Server {
return &Server{
listeners: make(map[string][]*websocket.Conn),
upgrader: websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
@ -106,7 +109,7 @@ type WebsocketCommand struct {
BlockIDs []string `json:"blockIds"`
}
func (ws *WSServer) handleWebSocketOnChange(w http.ResponseWriter, r *http.Request) {
func (ws *Server) handleWebSocketOnChange(w http.ResponseWriter, r *http.Request) {
// Upgrade initial GET request to a websocket
client, err := ws.upgrader.Upgrade(w, r, nil)
if err != nil {
@ -133,6 +136,7 @@ func (ws *WSServer) handleWebSocketOnChange(w http.ResponseWriter, r *http.Reque
if err != nil {
log.Printf("ERROR WebSocket onChange, client: %s, err: %v", client.RemoteAddr(), err)
ws.RemoveListener(client)
break
}
@ -141,6 +145,7 @@ func (ws *WSServer) handleWebSocketOnChange(w http.ResponseWriter, r *http.Reque
if err != nil {
// handle this error
log.Printf(`ERROR webSocket parsing command JSON: %v`, string(p))
continue
}
@ -148,9 +153,11 @@ func (ws *WSServer) handleWebSocketOnChange(w http.ResponseWriter, r *http.Reque
case "ADD":
log.Printf(`Command: Add blockID: %v, client: %s`, command.BlockIDs, client.RemoteAddr())
ws.AddListener(client, command.BlockIDs)
case "REMOVE":
log.Printf(`Command: Remove blockID: %v, client: %s`, command.BlockIDs, client.RemoteAddr())
ws.RemoveListenerFromBlocks(client, command.BlockIDs)
default:
log.Printf(`ERROR webSocket command, invalid action: %v`, command.Action)
}
@ -158,7 +165,7 @@ func (ws *WSServer) handleWebSocketOnChange(w http.ResponseWriter, r *http.Reque
}
// BroadcastBlockChangeToWebsocketClients broadcasts change to clients
func (ws *WSServer) BroadcastBlockChangeToWebsocketClients(blockIDs []string) {
func (ws *Server) BroadcastBlockChangeToWebsocketClients(blockIDs []string) {
for _, blockID := range blockIDs {
listeners := ws.GetListeners(blockID)
log.Printf("%d listener(s) for blockID: %s", len(listeners), blockID)
@ -168,6 +175,7 @@ func (ws *WSServer) BroadcastBlockChangeToWebsocketClients(blockIDs []string) {
Action: "UPDATE_BLOCK",
BlockID: blockID,
}
for _, listener := range listeners {
log.Printf("Broadcast change, blockID: %s, remoteAddr: %s", blockID, listener.RemoteAddr())
err := listener.WriteJSON(message)