1
0
mirror of https://github.com/mattermost/focalboard.git synced 2025-10-31 00:17:42 +02:00

fix remaining golangci linter warnings (#686)

* fix remaining linter warnings
This commit is contained in:
Doug Lauder
2021-07-08 21:09:02 -04:00
committed by GitHub
parent 0531d2eefc
commit ebd477464b
23 changed files with 215 additions and 445 deletions

View File

@@ -56,7 +56,7 @@ func (p *Plugin) OnActivate() error {
mmconfig := p.API.GetUnsanitizedConfig() mmconfig := p.API.GetUnsanitizedConfig()
filesS3Config := config.AmazonS3Config{} filesS3Config := config.AmazonS3Config{}
if mmconfig.FileSettings.AmazonS3AccessKeyId != nil { if mmconfig.FileSettings.AmazonS3AccessKeyId != nil {
filesS3Config.AccessKeyId = *mmconfig.FileSettings.AmazonS3AccessKeyId filesS3Config.AccessKeyID = *mmconfig.FileSettings.AmazonS3AccessKeyId
} }
if mmconfig.FileSettings.AmazonS3SecretAccessKey != nil { if mmconfig.FileSettings.AmazonS3SecretAccessKey != nil {
filesS3Config.SecretAccessKey = *mmconfig.FileSettings.AmazonS3SecretAccessKey filesS3Config.SecretAccessKey = *mmconfig.FileSettings.AmazonS3SecretAccessKey

View File

@@ -43,11 +43,10 @@ linters:
- depguard - depguard
- dogsled - dogsled
- dupl - dupl
- gochecknoinits
- goconst - goconst
- gocritic - gocritic
- godot - godot
# - goerr113 - goerr113
- goheader - goheader
- golint - golint
# - gomnd # - gomnd

View File

@@ -2,7 +2,6 @@ package api
import ( import (
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
@@ -31,6 +30,14 @@ const (
ErrorNoWorkspaceMessage = "No workspace" ErrorNoWorkspaceMessage = "No workspace"
) )
type PermissionError struct {
msg string
}
func (pe PermissionError) Error() string {
return pe.msg
}
// ---------------------------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------------------------
// REST APIs // REST APIs
@@ -128,7 +135,7 @@ func (a *API) hasValidReadTokenForBlock(r *http.Request, container store.Contain
func (a *API) getContainerAllowingReadTokenForBlock(r *http.Request, blockID string) (*store.Container, error) { func (a *API) getContainerAllowingReadTokenForBlock(r *http.Request, blockID string) (*store.Container, error) {
ctx := r.Context() ctx := r.Context()
session, _ := ctx.Value("session").(*model.Session) session, _ := ctx.Value(sessionContextKey).(*model.Session)
if a.MattermostAuth { if a.MattermostAuth {
// Workspace auth // Workspace auth
@@ -153,7 +160,7 @@ func (a *API) getContainerAllowingReadTokenForBlock(r *http.Request, blockID str
return &container, nil return &container, nil
} }
return nil, errors.New("access denied to workspace") return nil, PermissionError{"access denied to workspace"}
} }
// Native auth: always use root workspace // Native auth: always use root workspace
@@ -171,7 +178,7 @@ func (a *API) getContainerAllowingReadTokenForBlock(r *http.Request, blockID str
return &container, nil return &container, nil
} }
return nil, errors.New("access denied to workspace") return nil, PermissionError{"access denied to workspace"}
} }
func (a *API) getContainer(r *http.Request) (*store.Container, error) { func (a *API) getContainer(r *http.Request) (*store.Container, error) {
@@ -256,7 +263,7 @@ func (a *API) handleGetBlocks(w http.ResponseWriter, r *http.Request) {
func stampModificationMetadata(r *http.Request, blocks []model.Block, auditRec *audit.Record) { func stampModificationMetadata(r *http.Request, blocks []model.Block, auditRec *audit.Record) {
ctx := r.Context() ctx := r.Context()
session := ctx.Value("session").(*model.Session) session := ctx.Value(sessionContextKey).(*model.Session)
userID := session.UserID userID := session.UserID
if userID == SingleUser { if userID == SingleUser {
userID = "" userID = ""
@@ -352,7 +359,7 @@ func (a *API) handlePostBlocks(w http.ResponseWriter, r *http.Request) {
stampModificationMetadata(r, blocks, auditRec) stampModificationMetadata(r, blocks, auditRec)
ctx := r.Context() ctx := r.Context()
session := ctx.Value("session").(*model.Session) session := ctx.Value(sessionContextKey).(*model.Session)
err = a.app.InsertBlocks(*container, blocks, session.UserID) err = a.app.InsertBlocks(*container, blocks, session.UserID)
if err != nil { if err != nil {
@@ -437,7 +444,7 @@ func (a *API) handleGetMe(w http.ResponseWriter, r *http.Request) {
// "$ref": "#/definitions/ErrorResponse" // "$ref": "#/definitions/ErrorResponse"
ctx := r.Context() ctx := r.Context()
session := ctx.Value("session").(*model.Session) session := ctx.Value(sessionContextKey).(*model.Session)
var user *model.User var user *model.User
var err error var err error
@@ -503,7 +510,7 @@ func (a *API) handleDeleteBlock(w http.ResponseWriter, r *http.Request) {
// "$ref": "#/definitions/ErrorResponse" // "$ref": "#/definitions/ErrorResponse"
ctx := r.Context() ctx := r.Context()
session := ctx.Value("session").(*model.Session) session := ctx.Value(sessionContextKey).(*model.Session)
userID := session.UserID userID := session.UserID
vars := mux.Vars(r) vars := mux.Vars(r)
@@ -782,7 +789,7 @@ func (a *API) handleImport(w http.ResponseWriter, r *http.Request) {
stampModificationMetadata(r, blocks, auditRec) stampModificationMetadata(r, blocks, auditRec)
ctx := r.Context() ctx := r.Context()
session := ctx.Value("session").(*model.Session) session := ctx.Value(sessionContextKey).(*model.Session)
err = a.app.InsertBlocks(*container, blocks, session.UserID) err = a.app.InsertBlocks(*container, blocks, session.UserID)
if err != nil { if err != nil {
a.errorResponse(w, http.StatusInternalServerError, "", err) a.errorResponse(w, http.StatusInternalServerError, "", err)
@@ -931,7 +938,7 @@ func (a *API) handlePostSharing(w http.ResponseWriter, r *http.Request) {
// Stamp ModifiedBy // Stamp ModifiedBy
ctx := r.Context() ctx := r.Context()
session := ctx.Value("session").(*model.Session) session := ctx.Value(sessionContextKey).(*model.Session)
userID := session.UserID userID := session.UserID
if userID == SingleUser { if userID == SingleUser {
userID = "" userID = ""
@@ -986,7 +993,7 @@ func (a *API) handleGetWorkspace(w http.ResponseWriter, r *http.Request) {
workspaceID := vars["workspaceID"] workspaceID := vars["workspaceID"]
ctx := r.Context() ctx := r.Context()
session := ctx.Value("session").(*model.Session) session := ctx.Value(sessionContextKey).(*model.Session)
if !a.app.DoesUserHaveWorkspaceAccess(session.UserID, workspaceID) { if !a.app.DoesUserHaveWorkspaceAccess(session.UserID, workspaceID) {
a.errorResponse(w, http.StatusUnauthorized, "", nil) a.errorResponse(w, http.StatusUnauthorized, "", nil)
return return
@@ -1264,9 +1271,9 @@ func (a *API) getWorkspaceUsers(w http.ResponseWriter, r *http.Request) {
workspaceID := vars["workspaceID"] workspaceID := vars["workspaceID"]
ctx := r.Context() ctx := r.Context()
session := ctx.Value("session").(*model.Session) session := ctx.Value(sessionContextKey).(*model.Session)
if !a.app.DoesUserHaveWorkspaceAccess(session.UserID, workspaceID) { if !a.app.DoesUserHaveWorkspaceAccess(session.UserID, workspaceID) {
a.errorResponse(w, http.StatusForbidden, "Access denied to workspace", errors.New("access denied to workspace")) a.errorResponse(w, http.StatusForbidden, "Access denied to workspace", PermissionError{"access denied to workspace"})
return return
} }

View File

@@ -12,7 +12,7 @@ func (a *API) makeAuditRecord(r *http.Request, event string, initialStatus strin
ctx := r.Context() ctx := r.Context()
var sessionID string var sessionID string
var userID string var userID string
if session, ok := ctx.Value("session").(*model.Session); ok { if session, ok := ctx.Value(sessionContextKey).(*model.Session); ok {
sessionID = session.ID sessionID = session.ID
userID = session.UserID userID = session.UserID
} }

View File

@@ -3,7 +3,6 @@ package api
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net" "net"
@@ -12,7 +11,6 @@ import (
"time" "time"
"github.com/gorilla/mux" "github.com/gorilla/mux"
serverContext "github.com/mattermost/focalboard/server/context"
"github.com/mattermost/focalboard/server/model" "github.com/mattermost/focalboard/server/model"
"github.com/mattermost/focalboard/server/services/audit" "github.com/mattermost/focalboard/server/services/audit"
"github.com/mattermost/focalboard/server/services/auth" "github.com/mattermost/focalboard/server/services/auth"
@@ -23,6 +21,14 @@ const (
MinimumPasswordLength = 8 MinimumPasswordLength = 8
) )
type ParamError struct {
msg string
}
func (pe ParamError) Error() string {
return pe.msg
}
// LoginRequest is a login request // LoginRequest is a login request
// swagger:model // swagger:model
type LoginRequest struct { type LoginRequest struct {
@@ -78,16 +84,16 @@ type RegisterRequest struct {
func (rd *RegisterRequest) IsValid() error { func (rd *RegisterRequest) IsValid() error {
if strings.TrimSpace(rd.Username) == "" { if strings.TrimSpace(rd.Username) == "" {
return errors.New("username is required") return ParamError{"username is required"}
} }
if strings.TrimSpace(rd.Email) == "" { if strings.TrimSpace(rd.Email) == "" {
return errors.New("email is required") return ParamError{"email is required"}
} }
if !auth.IsEmailValid(rd.Email) { if !auth.IsEmailValid(rd.Email) {
return errors.New("invalid email format") return ParamError{"invalid email format"}
} }
if rd.Password == "" { if rd.Password == "" {
return errors.New("password is required") return ParamError{"password is required"}
} }
if err := isValidPassword(rd.Password); err != nil { if err := isValidPassword(rd.Password); err != nil {
return err return err
@@ -110,10 +116,10 @@ type ChangePasswordRequest struct {
// IsValid validates a password change request. // IsValid validates a password change request.
func (rd *ChangePasswordRequest) IsValid() error { func (rd *ChangePasswordRequest) IsValid() error {
if rd.OldPassword == "" { if rd.OldPassword == "" {
return errors.New("old password is required") return ParamError{"old password is required"}
} }
if rd.NewPassword == "" { if rd.NewPassword == "" {
return errors.New("new password is required") return ParamError{"new password is required"}
} }
if err := isValidPassword(rd.NewPassword); err != nil { if err := isValidPassword(rd.NewPassword); err != nil {
return err return err
@@ -124,7 +130,7 @@ func (rd *ChangePasswordRequest) IsValid() error {
func isValidPassword(password string) error { func isValidPassword(password string) error {
if len(password) < MinimumPasswordLength { if len(password) < MinimumPasswordLength {
return fmt.Errorf("password must be at least %d characters", MinimumPasswordLength) return ParamError{fmt.Sprintf("password must be at least %d characters", MinimumPasswordLength)}
} }
return nil return nil
} }
@@ -387,7 +393,7 @@ func (a *API) attachSession(handler func(w http.ResponseWriter, r *http.Request)
CreateAt: now, CreateAt: now,
UpdateAt: now, UpdateAt: now,
} }
ctx := context.WithValue(r.Context(), "session", session) ctx := context.WithValue(r.Context(), sessionContextKey, session)
handler(w, r.WithContext(ctx)) handler(w, r.WithContext(ctx))
return return
} }
@@ -404,7 +410,7 @@ func (a *API) attachSession(handler func(w http.ResponseWriter, r *http.Request)
CreateAt: now, CreateAt: now,
UpdateAt: now, UpdateAt: now,
} }
ctx := context.WithValue(r.Context(), "session", session) ctx := context.WithValue(r.Context(), sessionContextKey, session)
handler(w, r.WithContext(ctx)) handler(w, r.WithContext(ctx))
return return
} }
@@ -431,7 +437,7 @@ func (a *API) attachSession(handler func(w http.ResponseWriter, r *http.Request)
return return
} }
ctx := context.WithValue(r.Context(), "session", session) ctx := context.WithValue(r.Context(), sessionContextKey, session)
handler(w, r.WithContext(ctx)) handler(w, r.WithContext(ctx))
} }
} }
@@ -439,7 +445,7 @@ func (a *API) attachSession(handler func(w http.ResponseWriter, r *http.Request)
func (a *API) adminRequired(handler func(w http.ResponseWriter, r *http.Request)) func(w http.ResponseWriter, r *http.Request) { func (a *API) adminRequired(handler func(w http.ResponseWriter, r *http.Request)) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
// Currently, admin APIs require local unix connections // Currently, admin APIs require local unix connections
conn := serverContext.GetContextConn(r) conn := GetContextConn(r)
if _, isUnix := conn.(*net.UnixConn); !isUnix { if _, isUnix := conn.(*net.UnixConn); !isUnix {
a.errorResponse(w, http.StatusUnauthorized, "", nil) a.errorResponse(w, http.StatusUnauthorized, "", nil)
return return

View File

@@ -1,4 +1,4 @@
package context package api
import ( import (
"context" "context"
@@ -6,20 +6,21 @@ import (
"net/http" "net/http"
) )
type contextKey struct { type contextKey int
key string
}
var connContextKey = &contextKey{"http-conn"} const (
httpConnContextKey contextKey = iota
sessionContextKey
)
// SetContextConn stores the connection in the request context. // SetContextConn stores the connection in the request context.
func SetContextConn(ctx context.Context, c net.Conn) context.Context { func SetContextConn(ctx context.Context, c net.Conn) context.Context {
return context.WithValue(ctx, connContextKey, c) return context.WithValue(ctx, httpConnContextKey, c)
} }
// GetContextConn gets the stored connection from the request context. // GetContextConn gets the stored connection from the request context.
func GetContextConn(r *http.Request) net.Conn { func GetContextConn(r *http.Request) net.Conn {
value := r.Context().Value(connContextKey) value := r.Context().Value(httpConnContextKey)
if value == nil { if value == nil {
return nil return nil
} }

View File

@@ -1,15 +1,23 @@
package app package app
import ( import (
"errors"
"github.com/mattermost/focalboard/server/model"
"testing" "testing"
"github.com/mattermost/focalboard/server/model"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
st "github.com/mattermost/focalboard/server/services/store" st "github.com/mattermost/focalboard/server/services/store"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
type blockError struct {
msg string
}
func (be blockError) Error() string {
return be.msg
}
func TestGetParentID(t *testing.T) { func TestGetParentID(t *testing.T) {
th := SetupTestHelper(t) th := SetupTestHelper(t)
@@ -24,10 +32,10 @@ func TestGetParentID(t *testing.T) {
}) })
t.Run("fail query", func(t *testing.T) { t.Run("fail query", func(t *testing.T) {
th.Store.EXPECT().GetParentID(gomock.Eq(container), gomock.Eq("test-id")).Return("", errors.New("block-not-found")) th.Store.EXPECT().GetParentID(gomock.Eq(container), gomock.Eq("test-id")).Return("", blockError{"block-not-found"})
_, err := th.App.GetParentID(container, "test-id") _, err := th.App.GetParentID(container, "test-id")
require.Error(t, err) require.Error(t, err)
require.Equal(t, "block-not-found", err.Error()) require.ErrorIs(t, err, blockError{"block-not-found"})
}) })
} }
@@ -47,8 +55,8 @@ func TestInsertBlock(t *testing.T) {
t.Run("error scenerio", func(t *testing.T) { t.Run("error scenerio", func(t *testing.T) {
block := model.Block{} block := model.Block{}
th.Store.EXPECT().InsertBlock(gomock.Eq(container), gomock.Eq(&block), gomock.Eq("user-id-1")).Return(errors.New("dummy error")) th.Store.EXPECT().InsertBlock(gomock.Eq(container), gomock.Eq(&block), gomock.Eq("user-id-1")).Return(blockError{"error"})
err := th.App.InsertBlock(container, block, "user-id-1") err := th.App.InsertBlock(container, block, "user-id-1")
require.Error(t, err, "dummy error") require.Error(t, err, "error")
}) })
} }

View File

@@ -1,7 +1,6 @@
package app package app
import ( import (
"errors"
"fmt" "fmt"
"io" "io"
"path/filepath" "path/filepath"
@@ -25,7 +24,7 @@ func (a *App) SaveFile(reader io.Reader, workspaceID, rootID, filename string) (
_, appErr := a.filesBackend.WriteFile(reader, filePath) _, appErr := a.filesBackend.WriteFile(reader, filePath)
if appErr != nil { if appErr != nil {
return "", errors.New("unable to store the file in the files storage") return "", fmt.Errorf("unable to store the file in the files storage: %w", appErr)
} }
return createdFilename, nil return createdFilename, nil

View File

@@ -16,6 +16,14 @@ const (
APIURLSuffix = "/api/v1" APIURLSuffix = "/api/v1"
) )
type RequestReaderError struct {
buf []byte
}
func (rre RequestReaderError) Error() string {
return "payload: " + string(rre.buf)
}
type Response struct { type Response struct {
StatusCode int StatusCode int
Error error Error error
@@ -131,7 +139,7 @@ func (c *Client) doAPIRequestReader(method, url string, data io.Reader, _ /* eta
if err != nil { if err != nil {
return rp, fmt.Errorf("error when parsing response with code %d: %w", rp.StatusCode, err) return rp, fmt.Errorf("error when parsing response with code %d: %w", rp.StatusCode, err)
} }
return rp, fmt.Errorf(string(b)) return rp, RequestReaderError{b}
} }
return rp, nil return rp, nil

View File

@@ -72,7 +72,7 @@ func SetupTestHelper() *TestHelper {
cfg := getTestConfig() cfg := getTestConfig()
db, err := server.NewStore(cfg, logger) db, err := server.NewStore(cfg, logger)
if err != nil { if err != nil {
logger.Fatal("server.NewStore ERROR", mlog.Err(err)) panic(err)
} }
srv, err := server.New(cfg, sessionToken, db, logger) srv, err := server.New(cfg, sessionToken, db, logger)
if err != nil { if err != nil {
@@ -94,7 +94,7 @@ func (th *TestHelper) InitBasic() *TestHelper {
for { for {
URL := th.Server.Config().ServerRoot URL := th.Server.Config().ServerRoot
th.Server.Logger().Info("Polling server", mlog.String("url", URL)) th.Server.Logger().Info("Polling server", mlog.String("url", URL))
resp, err := http.Get(URL) resp, err := http.Get(URL) //nolint:gosec
if err != nil { if err != nil {
th.Server.Logger().Error("Polling failed", mlog.Err(err)) th.Server.Logger().Error("Polling failed", mlog.Err(err))
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)

View File

@@ -18,7 +18,6 @@ import (
"github.com/mattermost/focalboard/server/api" "github.com/mattermost/focalboard/server/api"
"github.com/mattermost/focalboard/server/app" "github.com/mattermost/focalboard/server/app"
"github.com/mattermost/focalboard/server/auth" "github.com/mattermost/focalboard/server/auth"
"github.com/mattermost/focalboard/server/context"
appModel "github.com/mattermost/focalboard/server/model" appModel "github.com/mattermost/focalboard/server/model"
"github.com/mattermost/focalboard/server/services/audit" "github.com/mattermost/focalboard/server/services/audit"
"github.com/mattermost/focalboard/server/services/config" "github.com/mattermost/focalboard/server/services/config"
@@ -42,7 +41,6 @@ const (
cleanupSessionTaskFrequency = 10 * time.Minute cleanupSessionTaskFrequency = 10 * time.Minute
updateMetricsTaskFrequency = 15 * time.Minute updateMetricsTaskFrequency = 15 * time.Minute
//nolint:gomnd
minSessionExpiryTime = int64(60 * 60 * 24 * 31) // 31 days minSessionExpiryTime = int64(60 * 60 * 24 * 31) // 31 days
MattermostAuthMod = "mattermost" MattermostAuthMod = "mattermost"
@@ -76,7 +74,7 @@ func New(cfg *config.Configuration, singleUserToken string, db store.Store, logg
filesBackendSettings := filestore.FileBackendSettings{} filesBackendSettings := filestore.FileBackendSettings{}
filesBackendSettings.DriverName = cfg.FilesDriver filesBackendSettings.DriverName = cfg.FilesDriver
filesBackendSettings.Directory = cfg.FilesPath filesBackendSettings.Directory = cfg.FilesPath
filesBackendSettings.AmazonS3AccessKeyId = cfg.FilesS3Config.AccessKeyId filesBackendSettings.AmazonS3AccessKeyId = cfg.FilesS3Config.AccessKeyID
filesBackendSettings.AmazonS3SecretAccessKey = cfg.FilesS3Config.SecretAccessKey filesBackendSettings.AmazonS3SecretAccessKey = cfg.FilesS3Config.SecretAccessKey
filesBackendSettings.AmazonS3Bucket = cfg.FilesS3Config.Bucket filesBackendSettings.AmazonS3Bucket = cfg.FilesS3Config.Bucket
filesBackendSettings.AmazonS3PathPrefix = cfg.FilesS3Config.PathPrefix filesBackendSettings.AmazonS3PathPrefix = cfg.FilesS3Config.PathPrefix
@@ -320,7 +318,7 @@ func (s *Server) Logger() *mlog.Logger {
func (s *Server) startLocalModeServer() error { func (s *Server) startLocalModeServer() error {
s.localModeServer = &http.Server{ s.localModeServer = &http.Server{
Handler: s.localRouter, Handler: s.localRouter,
ConnContext: context.SetContextConn, ConnContext: api.SetContextConn,
} }
// TODO: Close and delete socket file on shutdown // TODO: Close and delete socket file on shutdown

View File

@@ -135,7 +135,10 @@ func TestIsPasswordValidWithSettings(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
} else { } else {
require.Error(t, err) require.Error(t, err)
assert.Equal(t, tc.ExpectedFailingCriterias, err.(*InvalidPasswordError).FailingCriterias) var errFC *InvalidPasswordError
if assert.ErrorAs(t, err, &errFC) {
assert.Equal(t, tc.ExpectedFailingCriterias, errFC.FailingCriterias)
}
} }
}) })
} }

View File

@@ -12,7 +12,7 @@ const (
) )
type AmazonS3Config struct { type AmazonS3Config struct {
AccessKeyId string AccessKeyID string
SecretAccessKey string SecretAccessKey string
Bucket string Bucket string
PathPrefix string PathPrefix string

View File

@@ -3,8 +3,6 @@ package mattermostauthlayer
import ( import (
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"errors"
"log"
"strings" "strings"
"time" "time"
@@ -20,6 +18,14 @@ const (
postgresDBType = "postgres" postgresDBType = "postgres"
) )
type NotSupportedError struct {
msg string
}
func (pe NotSupportedError) Error() string {
return pe.msg
}
// Store represents the abstraction of the data storage. // Store represents the abstraction of the data storage.
type MattermostAuthLayer struct { type MattermostAuthLayer struct {
store.Store store.Store
@@ -99,19 +105,19 @@ func (s *MattermostAuthLayer) GetUserByUsername(username string) (*model.User, e
} }
func (s *MattermostAuthLayer) CreateUser(user *model.User) error { func (s *MattermostAuthLayer) CreateUser(user *model.User) error {
return errors.New("no user creation allowed from focalboard, create it using mattermost") return NotSupportedError{"no user creation allowed from focalboard, create it using mattermost"}
} }
func (s *MattermostAuthLayer) UpdateUser(user *model.User) error { func (s *MattermostAuthLayer) UpdateUser(user *model.User) error {
return errors.New("no update allowed from focalboard, update it using mattermost") return NotSupportedError{"no update allowed from focalboard, update it using mattermost"}
} }
func (s *MattermostAuthLayer) UpdateUserPassword(username, password string) error { func (s *MattermostAuthLayer) UpdateUserPassword(username, password string) error {
return errors.New("no update allowed from focalboard, update it using mattermost") return NotSupportedError{"no update allowed from focalboard, update it using mattermost"}
} }
func (s *MattermostAuthLayer) UpdateUserPasswordByID(userID, password string) error { func (s *MattermostAuthLayer) UpdateUserPasswordByID(userID, password string) error {
return errors.New("no update allowed from focalboard, update it using mattermost") return NotSupportedError{"no update allowed from focalboard, update it using mattermost"}
} }
// GetActiveUserCount returns the number of users with active sessions within N seconds ago. // GetActiveUserCount returns the number of users with active sessions within N seconds ago.
@@ -133,27 +139,27 @@ func (s *MattermostAuthLayer) GetActiveUserCount(updatedSecondsAgo int64) (int,
} }
func (s *MattermostAuthLayer) GetSession(token string, expireTime int64) (*model.Session, error) { func (s *MattermostAuthLayer) GetSession(token string, expireTime int64) (*model.Session, error) {
return nil, errors.New("sessions not used when using mattermost") return nil, NotSupportedError{"sessions not used when using mattermost"}
} }
func (s *MattermostAuthLayer) CreateSession(session *model.Session) error { func (s *MattermostAuthLayer) CreateSession(session *model.Session) error {
return errors.New("no update allowed from focalboard, update it using mattermost") return NotSupportedError{"no update allowed from focalboard, update it using mattermost"}
} }
func (s *MattermostAuthLayer) RefreshSession(session *model.Session) error { func (s *MattermostAuthLayer) RefreshSession(session *model.Session) error {
return errors.New("no update allowed from focalboard, update it using mattermost") return NotSupportedError{"no update allowed from focalboard, update it using mattermost"}
} }
func (s *MattermostAuthLayer) UpdateSession(session *model.Session) error { func (s *MattermostAuthLayer) UpdateSession(session *model.Session) error {
return errors.New("no update allowed from focalboard, update it using mattermost") return NotSupportedError{"no update allowed from focalboard, update it using mattermost"}
} }
func (s *MattermostAuthLayer) DeleteSession(sessionID string) error { func (s *MattermostAuthLayer) DeleteSession(sessionID string) error {
return errors.New("no update allowed from focalboard, update it using mattermost") return NotSupportedError{"no update allowed from focalboard, update it using mattermost"}
} }
func (s *MattermostAuthLayer) CleanUpSessions(expireTime int64) error { func (s *MattermostAuthLayer) CleanUpSessions(expireTime int64) error {
return errors.New("no update allowed from focalboard, update it using mattermost") return NotSupportedError{"no update allowed from focalboard, update it using mattermost"}
} }
func (s *MattermostAuthLayer) GetWorkspace(id string) (*model.Workspace, error) { func (s *MattermostAuthLayer) GetWorkspace(id string) (*model.Workspace, error) {
@@ -205,7 +211,7 @@ func (s *MattermostAuthLayer) GetWorkspace(id string) (*model.Workspace, error)
} }
var name string var name string
if err := rows.Scan(&name); err != nil { if err := rows.Scan(&name); err != nil {
log.Fatal(err) return nil, err
} }
sb.WriteString(name) sb.WriteString(name)
} }

View File

@@ -4,18 +4,24 @@ import (
"context" "context"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"errors"
"github.com/mattermost/focalboard/server/utils"
"time" "time"
"github.com/mattermost/focalboard/server/utils"
sq "github.com/Masterminds/squirrel" sq "github.com/Masterminds/squirrel"
_ "github.com/lib/pq" _ "github.com/lib/pq" // postgres driver
"github.com/mattermost/focalboard/server/model" "github.com/mattermost/focalboard/server/model"
"github.com/mattermost/focalboard/server/services/mlog" "github.com/mattermost/focalboard/server/services/mlog"
"github.com/mattermost/focalboard/server/services/store" "github.com/mattermost/focalboard/server/services/store"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3" // sqlite driver
) )
type RootIDNilError struct{}
func (re RootIDNilError) Error() string {
return "rootId is nil"
}
func (s *SQLStore) GetBlocksWithParentAndType(c store.Container, parentID string, blockType string) ([]model.Block, error) { func (s *SQLStore) GetBlocksWithParentAndType(c store.Container, parentID string, blockType string) ([]model.Block, error) {
query := s.getQueryBuilder(). query := s.getQueryBuilder().
Select( Select(
@@ -327,7 +333,7 @@ func (s *SQLStore) GetParentID(c store.Container, blockID string) (string, error
func (s *SQLStore) InsertBlock(c store.Container, block *model.Block, userID string) error { func (s *SQLStore) InsertBlock(c store.Container, block *model.Block, userID string) error {
if block.RootID == "" { if block.RootID == "" {
return errors.New("rootId is nil") return RootIDNilError{}
} }
fieldsJSON, err := json.Marshal(block.Fields) fieldsJSON, err := json.Marshal(block.Fields)

View File

@@ -66,7 +66,7 @@ func (s *SQLStore) isInitializationNeeded() (bool, error) {
var count int var count int
err := row.Scan(&count) err := row.Scan(&count)
if err != nil { if err != nil {
s.logger.Fatal("isInitializationNeeded", mlog.Err(err)) s.logger.Error("isInitializationNeeded", mlog.Err(err))
return false, err return false, err
} }

View File

@@ -17,9 +17,9 @@ import (
"github.com/golang-migrate/migrate/v4/database/postgres" "github.com/golang-migrate/migrate/v4/database/postgres"
"github.com/golang-migrate/migrate/v4/database/sqlite3" "github.com/golang-migrate/migrate/v4/database/sqlite3"
"github.com/golang-migrate/migrate/v4/source" "github.com/golang-migrate/migrate/v4/source"
_ "github.com/golang-migrate/migrate/v4/source/file" _ "github.com/golang-migrate/migrate/v4/source/file" // fileystem driver
bindata "github.com/golang-migrate/migrate/v4/source/go_bindata" bindata "github.com/golang-migrate/migrate/v4/source/go_bindata"
_ "github.com/lib/pq" _ "github.com/lib/pq" // postgres driver
"github.com/mattermost/focalboard/server/services/store/sqlstore/migrations" "github.com/mattermost/focalboard/server/services/store/sqlstore/migrations"
) )

View File

@@ -72,7 +72,7 @@ func (s *SQLStore) getQueryBuilder() sq.StatementBuilderType {
return builder.RunWith(s.db) return builder.RunWith(s.db)
} }
func (s *SQLStore) escapeField(fieldName string) string { func (s *SQLStore) escapeField(fieldName string) string { //nolint:unparam
if s.dbType == mysqlDBType { if s.dbType == mysqlDBType {
return "`" + fieldName + "`" return "`" + fieldName + "`"
} }

View File

@@ -3,7 +3,7 @@ package sqlstore
import ( import (
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"errors" "fmt"
"log" "log"
"time" "time"
@@ -12,6 +12,14 @@ import (
sq "github.com/Masterminds/squirrel" sq "github.com/Masterminds/squirrel"
) )
type UserNotFoundError struct {
id string
}
func (unf UserNotFoundError) Error() string {
return fmt.Sprintf("user not found (%s)", unf.id)
}
func (s *SQLStore) GetRegisteredUserCount() (int, error) { func (s *SQLStore) GetRegisteredUserCount() (int, error) {
query := s.getQueryBuilder(). query := s.getQueryBuilder().
Select("count(*)"). Select("count(*)").
@@ -132,7 +140,7 @@ func (s *SQLStore) UpdateUser(user *model.User) error {
} }
if rowCount < 1 { if rowCount < 1 {
return errors.New("user not found") return UserNotFoundError{user.ID}
} }
return nil return nil
@@ -157,7 +165,7 @@ func (s *SQLStore) UpdateUserPassword(username, password string) error {
} }
if rowCount < 1 { if rowCount < 1 {
return errors.New("user not found") return UserNotFoundError{username}
} }
return nil return nil
@@ -182,7 +190,7 @@ func (s *SQLStore) UpdateUserPasswordByID(userID, password string) error {
} }
if rowCount < 1 { if rowCount < 1 {
return errors.New("user not found") return UserNotFoundError{userID}
} }
return nil return nil

View File

@@ -1,15 +1,20 @@
package storetests package storetests
import ( import (
"github.com/stretchr/testify/assert"
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert"
"github.com/mattermost/focalboard/server/model" "github.com/mattermost/focalboard/server/model"
"github.com/mattermost/focalboard/server/services/store" "github.com/mattermost/focalboard/server/services/store"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
const (
testUserID = "user-id"
)
func StoreTestBlocksStore(t *testing.T, setup func(t *testing.T) (store.Store, func())) { func StoreTestBlocksStore(t *testing.T, setup func(t *testing.T) (store.Store, func())) {
container := store.Container{ container := store.Container{
WorkspaceID: "0", WorkspaceID: "0",
@@ -38,42 +43,17 @@ func StoreTestBlocksStore(t *testing.T, setup func(t *testing.T) (store.Store, f
t.Run("GetParentID", func(t *testing.T) { t.Run("GetParentID", func(t *testing.T) {
store, tearDown := setup(t) store, tearDown := setup(t)
defer tearDown() defer tearDown()
testGetParentID(t, store, container) testGetParents(t, store, container)
}) })
t.Run("GetRootID", func(t *testing.T) { t.Run("GetBlocks", func(t *testing.T) {
store, tearDown := setup(t) store, tearDown := setup(t)
defer tearDown() defer tearDown()
testGetRootID(t, store, container) testGetBlocks(t, store, container)
})
t.Run("GetBlocksWithParentAndType", func(t *testing.T) {
store, tearDown := setup(t)
defer tearDown()
testGetBlocksWithParentAndType(t, store, container)
})
t.Run("GetBlocksWithParent", func(t *testing.T) {
store, tearDown := setup(t)
defer tearDown()
testGetBlocksWithParent(t, store, container)
})
t.Run("GetBlocksWithType", func(t *testing.T) {
store, tearDown := setup(t)
defer tearDown()
testGetBlocksWithType(t, store, container)
})
t.Run("GetBlocksWithRootID", func(t *testing.T) {
store, tearDown := setup(t)
defer tearDown()
testGetBlocksWithRootID(t, store, container)
})
t.Run("GetBlock", func(t *testing.T) {
store, tearDown := setup(t)
defer tearDown()
testGetBlock(t, store, container)
}) })
} }
func testInsertBlock(t *testing.T, store store.Store, container store.Container) { func testInsertBlock(t *testing.T, store store.Store, container store.Container) {
userID := "user-id" userID := testUserID
blocks, err := store.GetAllBlocks(container) blocks, err := store.GetAllBlocks(container)
require.NoError(t, err) require.NoError(t, err)
@@ -175,12 +155,12 @@ func testInsertBlock(t *testing.T, store store.Store, container store.Container)
t.Run("data tamper attempt", func(t *testing.T) { t.Run("data tamper attempt", func(t *testing.T) {
block := model.Block{ block := model.Block{
ID: "id-10", ID: "id-10",
RootID: "root-id", RootID: "root-id",
Title: "Old Title", Title: "Old Title",
CreateAt: createdAt.Unix(), CreateAt: createdAt.Unix(),
UpdateAt: updateAt.Unix(), UpdateAt: updateAt.Unix(),
CreatedBy: "user-id-5", CreatedBy: "user-id-5",
ModifiedBy: "user-id-6", ModifiedBy: "user-id-6",
} }
@@ -193,58 +173,58 @@ func testInsertBlock(t *testing.T, store store.Store, container store.Container)
assert.NotNil(t, retrievedBlock) assert.NotNil(t, retrievedBlock)
assert.Equal(t, "user-id-1", retrievedBlock.CreatedBy) assert.Equal(t, "user-id-1", retrievedBlock.CreatedBy)
assert.Equal(t, "user-id-1", retrievedBlock.ModifiedBy) assert.Equal(t, "user-id-1", retrievedBlock.ModifiedBy)
assert.WithinDurationf(t, time.Now(), time.Unix(retrievedBlock.CreateAt / 1000, 0), 1 * time.Second, "create time should be current time") assert.WithinDurationf(t, time.Now(), time.Unix(retrievedBlock.CreateAt/1000, 0), 1*time.Second, "create time should be current time")
assert.WithinDurationf(t, time.Now(), time.Unix(retrievedBlock.UpdateAt / 1000, 0), 1 * time.Second, "update time should be current time") assert.WithinDurationf(t, time.Now(), time.Unix(retrievedBlock.UpdateAt/1000, 0), 1*time.Second, "update time should be current time")
}) })
} }
func testGetSubTree2(t *testing.T, store store.Store, container store.Container) { var (
userID := "user-id" subtreeSampleBlocks = []model.Block{
blocks, err := store.GetAllBlocks(container)
require.NoError(t, err)
initialCount := len(blocks)
blocksToInsert := []model.Block{
{ {
ID: "parent", ID: "parent",
RootID: "parent", RootID: "parent",
ModifiedBy: userID, ModifiedBy: testUserID,
}, },
{ {
ID: "child1", ID: "child1",
RootID: "parent", RootID: "parent",
ParentID: "parent", ParentID: "parent",
ModifiedBy: userID, ModifiedBy: testUserID,
}, },
{ {
ID: "child2", ID: "child2",
RootID: "parent", RootID: "parent",
ParentID: "parent", ParentID: "parent",
ModifiedBy: userID, ModifiedBy: testUserID,
}, },
{ {
ID: "grandchild1", ID: "grandchild1",
RootID: "parent", RootID: "parent",
ParentID: "child1", ParentID: "child1",
ModifiedBy: userID, ModifiedBy: testUserID,
}, },
{ {
ID: "grandchild2", ID: "grandchild2",
RootID: "parent", RootID: "parent",
ParentID: "child2", ParentID: "child2",
ModifiedBy: userID, ModifiedBy: testUserID,
}, },
{ {
ID: "greatgrandchild1", ID: "greatgrandchild1",
RootID: "parent", RootID: "parent",
ParentID: "grandchild1", ParentID: "grandchild1",
ModifiedBy: userID, ModifiedBy: testUserID,
}, },
} }
)
InsertBlocks(t, store, container, blocksToInsert, "user-id-1") func testGetSubTree2(t *testing.T, store store.Store, container store.Container) {
defer DeleteBlocks(t, store, container, blocksToInsert, "test") blocks, err := store.GetAllBlocks(container)
require.NoError(t, err)
initialCount := len(blocks)
InsertBlocks(t, store, container, subtreeSampleBlocks, "user-id-1")
defer DeleteBlocks(t, store, container, subtreeSampleBlocks, "test")
blocks, err = store.GetAllBlocks(container) blocks, err = store.GetAllBlocks(container)
require.NoError(t, err) require.NoError(t, err)
@@ -275,52 +255,12 @@ func testGetSubTree2(t *testing.T, store store.Store, container store.Container)
} }
func testGetSubTree3(t *testing.T, store store.Store, container store.Container) { func testGetSubTree3(t *testing.T, store store.Store, container store.Container) {
userID := "user-id"
blocks, err := store.GetAllBlocks(container) blocks, err := store.GetAllBlocks(container)
require.NoError(t, err) require.NoError(t, err)
initialCount := len(blocks) initialCount := len(blocks)
blocksToInsert := []model.Block{ InsertBlocks(t, store, container, subtreeSampleBlocks, "user-id-1")
{ defer DeleteBlocks(t, store, container, subtreeSampleBlocks, "test")
ID: "parent",
RootID: "parent",
ModifiedBy: userID,
},
{
ID: "child1",
RootID: "parent",
ParentID: "parent",
ModifiedBy: userID,
},
{
ID: "child2",
RootID: "parent",
ParentID: "parent",
ModifiedBy: userID,
},
{
ID: "grandchild1",
RootID: "parent",
ParentID: "child1",
ModifiedBy: userID,
},
{
ID: "grandchild2",
RootID: "parent",
ParentID: "child2",
ModifiedBy: userID,
},
{
ID: "greatgrandchild1",
RootID: "parent",
ParentID: "grandchild1",
ModifiedBy: userID,
},
}
InsertBlocks(t, store, container, blocksToInsert, "user-id-1")
defer DeleteBlocks(t, store, container, blocksToInsert, "test")
blocks, err = store.GetAllBlocks(container) blocks, err = store.GetAllBlocks(container)
require.NoError(t, err) require.NoError(t, err)
@@ -353,148 +293,55 @@ func testGetSubTree3(t *testing.T, store store.Store, container store.Container)
}) })
} }
func testGetRootID(t *testing.T, store store.Store, container store.Container) { func testGetParents(t *testing.T, store store.Store, container store.Container) {
userID := "user-id"
blocks, err := store.GetAllBlocks(container) blocks, err := store.GetAllBlocks(container)
require.NoError(t, err) require.NoError(t, err)
initialCount := len(blocks) initialCount := len(blocks)
blocksToInsert := []model.Block{ InsertBlocks(t, store, container, subtreeSampleBlocks, "user-id-1")
{ defer DeleteBlocks(t, store, container, subtreeSampleBlocks, "test")
ID: "parent",
RootID: "parent",
ModifiedBy: userID,
},
{
ID: "child1",
RootID: "parent",
ParentID: "parent",
ModifiedBy: userID,
},
{
ID: "child2",
RootID: "parent",
ParentID: "parent",
ModifiedBy: userID,
},
{
ID: "grandchild1",
RootID: "parent",
ParentID: "child1",
ModifiedBy: userID,
},
{
ID: "grandchild2",
RootID: "parent",
ParentID: "child2",
ModifiedBy: userID,
},
{
ID: "greatgrandchild1",
RootID: "parent",
ParentID: "grandchild1",
ModifiedBy: userID,
},
}
InsertBlocks(t, store, container, blocksToInsert, "user-id-1")
defer DeleteBlocks(t, store, container, blocksToInsert, "test")
blocks, err = store.GetAllBlocks(container) blocks, err = store.GetAllBlocks(container)
require.NoError(t, err) require.NoError(t, err)
require.Len(t, blocks, initialCount+6) require.Len(t, blocks, initialCount+6)
t.Run("from root id", func(t *testing.T) { t.Run("root from root id", func(t *testing.T) {
rootID, err := store.GetRootID(container, "parent") rootID, err := store.GetRootID(container, "parent")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, "parent", rootID) require.Equal(t, "parent", rootID)
}) })
t.Run("from child id", func(t *testing.T) { t.Run("root from child id", func(t *testing.T) {
rootID, err := store.GetRootID(container, "child1") rootID, err := store.GetRootID(container, "child1")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, "parent", rootID) require.Equal(t, "parent", rootID)
}) })
t.Run("from not existing id", func(t *testing.T) { t.Run("root from not existing id", func(t *testing.T) {
_, err := store.GetRootID(container, "not-exists") _, err := store.GetRootID(container, "not-exists")
require.Error(t, err) require.Error(t, err)
}) })
}
func testGetParentID(t *testing.T, store store.Store, container store.Container) { t.Run("parent from root id", func(t *testing.T) {
userID := "user-id"
blocks, err := store.GetAllBlocks(container)
require.NoError(t, err)
initialCount := len(blocks)
blocksToInsert := []model.Block{
{
ID: "parent",
RootID: "parent",
ModifiedBy: userID,
},
{
ID: "child1",
RootID: "parent",
ParentID: "parent",
ModifiedBy: userID,
},
{
ID: "child2",
RootID: "parent",
ParentID: "parent",
ModifiedBy: userID,
},
{
ID: "grandchild1",
RootID: "parent",
ParentID: "child1",
ModifiedBy: userID,
},
{
ID: "grandchild2",
RootID: "parent",
ParentID: "child2",
ModifiedBy: userID,
},
{
ID: "greatgrandchild1",
RootID: "parent",
ParentID: "grandchild1",
ModifiedBy: userID,
},
}
InsertBlocks(t, store, container, blocksToInsert, "user-id-1")
defer DeleteBlocks(t, store, container, blocksToInsert, "test")
blocks, err = store.GetAllBlocks(container)
require.NoError(t, err)
require.Len(t, blocks, initialCount+6)
t.Run("from root id", func(t *testing.T) {
parentID, err := store.GetParentID(container, "parent") parentID, err := store.GetParentID(container, "parent")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, "", parentID) require.Equal(t, "", parentID)
}) })
t.Run("from child id", func(t *testing.T) { t.Run("parent from child id", func(t *testing.T) {
parentID, err := store.GetParentID(container, "grandchild1") parentID, err := store.GetParentID(container, "grandchild1")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, "child1", parentID) require.Equal(t, "child1", parentID)
}) })
t.Run("from not existing id", func(t *testing.T) { t.Run("parent from not existing id", func(t *testing.T) {
_, err := store.GetParentID(container, "not-exists") _, err := store.GetParentID(container, "not-exists")
require.Error(t, err) require.Error(t, err)
}) })
} }
func testDeleteBlock(t *testing.T, store store.Store, container store.Container) { func testDeleteBlock(t *testing.T, store store.Store, container store.Container) {
userID := "user-id" userID := testUserID
blocks, err := store.GetAllBlocks(container) blocks, err := store.GetAllBlocks(container)
require.NoError(t, err) require.NoError(t, err)
@@ -550,9 +397,7 @@ func testDeleteBlock(t *testing.T, store store.Store, container store.Container)
}) })
} }
func testGetBlocksWithParentAndType(t *testing.T, store store.Store, container store.Container) { func testGetBlocks(t *testing.T, store store.Store, container store.Container) {
userID := "user-id"
blocks, err := store.GetAllBlocks(container) blocks, err := store.GetAllBlocks(container)
require.NoError(t, err) require.NoError(t, err)
@@ -561,38 +406,39 @@ func testGetBlocksWithParentAndType(t *testing.T, store store.Store, container s
ID: "block1", ID: "block1",
ParentID: "", ParentID: "",
RootID: "block1", RootID: "block1",
ModifiedBy: userID, ModifiedBy: testUserID,
Type: "test", Type: "test",
}, },
{ {
ID: "block2", ID: "block2",
ParentID: "block1", ParentID: "block1",
RootID: "block1", RootID: "block1",
ModifiedBy: userID, ModifiedBy: testUserID,
Type: "test", Type: "test",
}, },
{ {
ID: "block3", ID: "block3",
ParentID: "block1", ParentID: "block1",
RootID: "block1", RootID: "block1",
ModifiedBy: userID, ModifiedBy: testUserID,
Type: "test", Type: "test",
}, },
{ {
ID: "block4", ID: "block4",
ParentID: "block1", ParentID: "block1",
RootID: "block1", RootID: "block1",
ModifiedBy: userID, ModifiedBy: testUserID,
Type: "test2", Type: "test2",
}, },
{ {
ID: "block5", ID: "block5",
ParentID: "block2", ParentID: "block2",
RootID: "block1", RootID: "block2",
ModifiedBy: userID, ModifiedBy: testUserID,
Type: "test", Type: "test",
}, },
} }
InsertBlocks(t, store, container, blocksToInsert, "user-id-1") InsertBlocks(t, store, container, blocksToInsert, "user-id-1")
defer DeleteBlocks(t, store, container, blocksToInsert, "test") defer DeleteBlocks(t, store, container, blocksToInsert, "test")
@@ -616,53 +462,6 @@ func testGetBlocksWithParentAndType(t *testing.T, store store.Store, container s
require.NoError(t, err) require.NoError(t, err)
require.Len(t, blocks, 2) require.Len(t, blocks, 2)
}) })
}
func testGetBlocksWithParent(t *testing.T, store store.Store, container store.Container) {
userID := "user-id"
blocks, err := store.GetAllBlocks(container)
require.NoError(t, err)
blocksToInsert := []model.Block{
{
ID: "block1",
ParentID: "",
RootID: "block1",
ModifiedBy: userID,
Type: "test",
},
{
ID: "block2",
ParentID: "block1",
RootID: "block1",
ModifiedBy: userID,
Type: "test",
},
{
ID: "block3",
ParentID: "block1",
RootID: "block1",
ModifiedBy: userID,
Type: "test",
},
{
ID: "block4",
ParentID: "block1",
RootID: "block1",
ModifiedBy: userID,
Type: "test2",
},
{
ID: "block5",
ParentID: "block2",
RootID: "block1",
ModifiedBy: userID,
Type: "test",
},
}
InsertBlocks(t, store, container, blocksToInsert, "user-id-1")
defer DeleteBlocks(t, store, container, blocksToInsert, "test")
t.Run("not existing parent", func(t *testing.T) { t.Run("not existing parent", func(t *testing.T) {
time.Sleep(1 * time.Millisecond) time.Sleep(1 * time.Millisecond)
@@ -677,53 +476,6 @@ func testGetBlocksWithParent(t *testing.T, store store.Store, container store.Co
require.NoError(t, err) require.NoError(t, err)
require.Len(t, blocks, 3) require.Len(t, blocks, 3)
}) })
}
func testGetBlocksWithType(t *testing.T, store store.Store, container store.Container) {
userID := "user-id"
blocks, err := store.GetAllBlocks(container)
require.NoError(t, err)
blocksToInsert := []model.Block{
{
ID: "block1",
ParentID: "",
RootID: "block1",
ModifiedBy: userID,
Type: "test",
},
{
ID: "block2",
ParentID: "block1",
RootID: "block1",
ModifiedBy: userID,
Type: "test",
},
{
ID: "block3",
ParentID: "block1",
RootID: "block1",
ModifiedBy: userID,
Type: "test",
},
{
ID: "block4",
ParentID: "block1",
RootID: "block1",
ModifiedBy: userID,
Type: "test2",
},
{
ID: "block5",
ParentID: "block2",
RootID: "block1",
ModifiedBy: userID,
Type: "test",
},
}
InsertBlocks(t, store, container, blocksToInsert, "user-id-1")
defer DeleteBlocks(t, store, container, blocksToInsert, "test")
t.Run("not existing type", func(t *testing.T) { t.Run("not existing type", func(t *testing.T) {
time.Sleep(1 * time.Millisecond) time.Sleep(1 * time.Millisecond)
@@ -738,53 +490,6 @@ func testGetBlocksWithType(t *testing.T, store store.Store, container store.Cont
require.NoError(t, err) require.NoError(t, err)
require.Len(t, blocks, 4) require.Len(t, blocks, 4)
}) })
}
func testGetBlocksWithRootID(t *testing.T, store store.Store, container store.Container) {
userID := "user-id"
blocks, err := store.GetAllBlocks(container)
require.NoError(t, err)
blocksToInsert := []model.Block{
{
ID: "block1",
ParentID: "",
RootID: "block1",
ModifiedBy: userID,
Type: "test",
},
{
ID: "block2",
ParentID: "block1",
RootID: "block1",
ModifiedBy: userID,
Type: "test",
},
{
ID: "block3",
ParentID: "block1",
RootID: "block1",
ModifiedBy: userID,
Type: "test",
},
{
ID: "block4",
ParentID: "block1",
RootID: "block1",
ModifiedBy: userID,
Type: "test2",
},
{
ID: "block5",
ParentID: "block2",
RootID: "block2",
ModifiedBy: userID,
Type: "test",
},
}
InsertBlocks(t, store, container, blocksToInsert, "user-id-1")
defer DeleteBlocks(t, store, container, blocksToInsert, "test")
t.Run("not existing parent", func(t *testing.T) { t.Run("not existing parent", func(t *testing.T) {
time.Sleep(1 * time.Millisecond) time.Sleep(1 * time.Millisecond)
@@ -819,8 +524,8 @@ func testGetBlock(t *testing.T, store store.Store, container store.Container) {
require.Equal(t, "root-id-1", fetchedBlock.RootID) require.Equal(t, "root-id-1", fetchedBlock.RootID)
require.Equal(t, "user-id-1", fetchedBlock.CreatedBy) require.Equal(t, "user-id-1", fetchedBlock.CreatedBy)
require.Equal(t, "user-id-1", fetchedBlock.ModifiedBy) require.Equal(t, "user-id-1", fetchedBlock.ModifiedBy)
assert.WithinDurationf(t, time.Now(), time.Unix(fetchedBlock.CreateAt / 1000, 0), 1 * time.Second, "create time should be current time") assert.WithinDurationf(t, time.Now(), time.Unix(fetchedBlock.CreateAt/1000, 0), 1*time.Second, "create time should be current time")
assert.WithinDurationf(t, time.Now(), time.Unix(fetchedBlock.UpdateAt / 1000, 0), 1 * time.Second, "update time should be current time") assert.WithinDurationf(t, time.Now(), time.Unix(fetchedBlock.UpdateAt/1000, 0), 1*time.Second, "update time should be current time")
}) })
t.Run("get a non-existing block", func(t *testing.T) { t.Run("get a non-existing block", func(t *testing.T) {

View File

@@ -26,7 +26,7 @@ func testUpsertSharingAndGetSharing(t *testing.T, store store.Store, container s
ID: "sharing-id", ID: "sharing-id",
Enabled: true, Enabled: true,
Token: "token", Token: "token",
ModifiedBy: "user-id", ModifiedBy: testUserID,
} }
err := store.UpsertSharing(container, sharing) err := store.UpsertSharing(container, sharing)

View File

@@ -53,11 +53,11 @@ func (ts *Service) RegisterTracker(name string, f TrackerFunc) {
func (ts *Service) getRudderConfig() RudderConfig { func (ts *Service) getRudderConfig() RudderConfig {
if !strings.Contains(rudderKey, "placeholder") && !strings.Contains(rudderDataplaneURL, "placeholder") { if !strings.Contains(rudderKey, "placeholder") && !strings.Contains(rudderDataplaneURL, "placeholder") {
return RudderConfig{rudderKey, rudderDataplaneURL} return RudderConfig{rudderKey, rudderDataplaneURL}
} else if os.Getenv("RUDDER_KEY") != "" && os.Getenv("RUDDER_DATAPLANE_URL") != "" {
return RudderConfig{os.Getenv("RUDDER_KEY"), os.Getenv("RUDDER_DATAPLANE_URL")}
} else {
return RudderConfig{}
} }
if os.Getenv("RUDDER_KEY") != "" && os.Getenv("RUDDER_DATAPLANE_URL") != "" {
return RudderConfig{os.Getenv("RUDDER_KEY"), os.Getenv("RUDDER_DATAPLANE_URL")}
}
return RudderConfig{}
} }
func (ts *Service) sendDailyTelemetry(override bool) { func (ts *Service) sendDailyTelemetry(override bool) {
@@ -113,15 +113,23 @@ func (ts *Service) initRudder(endpoint, rudderKey string) {
func (ts *Service) doTelemetryIfNeeded(firstRun time.Time) { func (ts *Service) doTelemetryIfNeeded(firstRun time.Time) {
hoursSinceFirstServerRun := time.Since(firstRun).Hours() hoursSinceFirstServerRun := time.Since(firstRun).Hours()
// Send once every 10 minutes for the first hour // Send once every 10 minutes for the first hour
// Send once every hour thereafter for the first 12 hours
// Send at the 24 hour mark and every 24 hours after
if hoursSinceFirstServerRun < 1 { if hoursSinceFirstServerRun < 1 {
ts.doTelemetry() ts.doTelemetry()
} else if hoursSinceFirstServerRun <= 12 && time.Since(ts.timestampLastTelemetrySent) >= time.Hour { return
}
// Send once every hour thereafter for the first 12 hours
if hoursSinceFirstServerRun <= 12 && time.Since(ts.timestampLastTelemetrySent) >= time.Hour {
ts.doTelemetry() ts.doTelemetry()
} else if hoursSinceFirstServerRun > 12 && time.Since(ts.timestampLastTelemetrySent) >= 24*time.Hour { return
}
// Send at the 24 hour mark and every 24 hours after
if hoursSinceFirstServerRun > 12 && time.Since(ts.timestampLastTelemetrySent) >= 24*time.Hour {
ts.doTelemetry() ts.doTelemetry()
return
} }
} }

View File

@@ -2,7 +2,6 @@ package ws
import ( import (
"encoding/json" "encoding/json"
"errors"
"log" "log"
"net/http" "net/http"
"sync" "sync"
@@ -220,6 +219,14 @@ func (ws *Server) authenticateListener(wsSession *websocketSession, workspaceID,
ws.logger.Debug("authenticateListener: Authenticated", mlog.String("workspaceID", workspaceID)) ws.logger.Debug("authenticateListener: Authenticated", mlog.String("workspaceID", workspaceID))
} }
type AuthWorkspaceError struct {
msg string
}
func (awe AuthWorkspaceError) Error() string {
return awe.msg
}
func (ws *Server) getAuthenticatedWorkspaceID(wsSession *websocketSession, command *WebsocketCommand) (string, error) { func (ws *Server) getAuthenticatedWorkspaceID(wsSession *websocketSession, command *WebsocketCommand) (string, error) {
if wsSession.isAuthenticated { if wsSession.isAuthenticated {
return wsSession.workspaceID, nil return wsSession.workspaceID, nil
@@ -229,7 +236,7 @@ func (ws *Server) getAuthenticatedWorkspaceID(wsSession *websocketSession, comma
workspaceID := command.WorkspaceID workspaceID := command.WorkspaceID
if len(workspaceID) == 0 { if len(workspaceID) == 0 {
ws.logger.Error("getAuthenticatedWorkspaceID: No workspace") ws.logger.Error("getAuthenticatedWorkspaceID: No workspace")
return "", errors.New("no workspace") return "", AuthWorkspaceError{"no workspace"}
} }
container := store.Container{ container := store.Container{
@@ -241,13 +248,13 @@ func (ws *Server) getAuthenticatedWorkspaceID(wsSession *websocketSession, comma
for _, blockID := range command.BlockIDs { for _, blockID := range command.BlockIDs {
isValid, _ := ws.auth.IsValidReadToken(container, blockID, command.ReadToken) isValid, _ := ws.auth.IsValidReadToken(container, blockID, command.ReadToken)
if !isValid { if !isValid {
return "", errors.New("invalid read token for workspace") return "", AuthWorkspaceError{"invalid read token for workspace"}
} }
} }
return workspaceID, nil return workspaceID, nil
} }
return "", errors.New("no read token") return "", AuthWorkspaceError{"no read token"}
} }
// TODO: Refactor workspace hashing. // TODO: Refactor workspace hashing.
@@ -314,7 +321,8 @@ func (ws *Server) removeListenerFromBlocks(wsSession *websocketSession, command
// Note: A client can listen multiple times to the same block // Note: A client can listen multiple times to the same block
for index, listener := range listeners { for index, listener := range listeners {
if wsSession.client == listener { if wsSession.client == listener {
newListeners := append(listeners[:index], listeners[index+1:]...) newListeners := listeners[:index]
newListeners = append(newListeners, listeners[index+1:]...)
ws.listeners[itemID] = newListeners ws.listeners[itemID] = newListeners
break break