diff --git a/server/main/api.go b/server/main/api.go index 447efed35..11ff69410 100644 --- a/server/main/api.go +++ b/server/main/api.go @@ -1,6 +1,7 @@ package main import ( + "context" "encoding/json" "fmt" "io/ioutil" @@ -17,10 +18,16 @@ import ( // ---------------------------------------------------------------------------------------------------- // REST APIs -type API struct{} +type API struct { + appBuilder func() *App +} -func NewAPI() *API { - return &API{} +func NewAPI(appBuilder func() *App) *API { + return &API{appBuilder: appBuilder} +} + +func (a *API) app() *App { + return a.appBuilder() } func (a *API) RegisterRoutes(r *mux.Router) { @@ -41,13 +48,11 @@ func (a *API) handleGetBlocks(w http.ResponseWriter, r *http.Request) { parentID := query.Get("parent_id") blockType := query.Get("type") - var blocks []Block - if len(blockType) > 0 && len(parentID) > 0 { - blocks = store.getBlocksWithParentAndType(parentID, blockType) - } else if len(blockType) > 0 { - blocks = store.getBlocksWithType(blockType) - } else { - blocks = store.getBlocksWithParent(parentID) + blocks, err := a.app().GetBlocks(parentID, blockType) + if err != nil { + log.Printf(`ERROR GetBlocks: %v`, r) + errorResponse(w, http.StatusInternalServerError, `{}`) + return } log.Printf("GetBlocks parentID: %s, type: %s, %d result(s)", parentID, blockType, len(blocks)) @@ -84,9 +89,6 @@ func (a *API) handlePostBlocks(w http.ResponseWriter, r *http.Request) { return } - var blockIDsToNotify = []string{} - uniqueBlockIDs := make(map[string]bool) - for _, block := range blocks { // Error checking if len(block.Type) < 1 { @@ -102,17 +104,14 @@ func (a *API) handlePostBlocks(w http.ResponseWriter, r *http.Request) { return } - if !uniqueBlockIDs[block.ID] { - blockIDsToNotify = append(blockIDsToNotify, block.ID) - } - if len(block.ParentID) > 0 && !uniqueBlockIDs[block.ParentID] { - blockIDsToNotify = append(blockIDsToNotify, block.ParentID) - } - - store.insertBlock(block) } - wsServer.broadcastBlockChangeToWebsocketClients(blockIDsToNotify) + err = a.app().InsertBlocks(blocks) + if err != nil { + log.Printf(`ERROR: %v`, r) + errorResponse(w, http.StatusInternalServerError, `{}`) + return + } log.Printf("POST Blocks %d block(s)", len(blocks)) jsonStringResponse(w, http.StatusOK, "{}") @@ -122,18 +121,13 @@ func (a *API) handleDeleteBlock(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) blockID := vars["blockID"] - var blockIDsToNotify = []string{blockID} - - parentID := store.getParentID(blockID) - - if len(parentID) > 0 { - blockIDsToNotify = append(blockIDsToNotify, parentID) + err := a.app().DeleteBlock(blockID) + if err != nil { + log.Printf(`ERROR: %v`, r) + errorResponse(w, http.StatusInternalServerError, `{}`) + return } - store.deleteBlock(blockID) - - wsServer.broadcastBlockChangeToWebsocketClients(blockIDsToNotify) - log.Printf("DELETE Block %s", blockID) jsonStringResponse(w, http.StatusOK, "{}") } @@ -142,7 +136,12 @@ func (a *API) handleGetSubTree(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) blockID := vars["blockID"] - blocks := store.getSubTree(blockID) + blocks, err := a.app().GetSubTree(blockID) + if err != nil { + log.Printf(`ERROR: %v`, r) + errorResponse(w, http.StatusInternalServerError, `{}`) + return + } log.Printf("GetSubTree blockID: %s, %d result(s)", blockID, len(blocks)) json, err := json.Marshal(blocks) @@ -156,7 +155,12 @@ func (a *API) handleGetSubTree(w http.ResponseWriter, r *http.Request) { } func (a *API) handleExport(w http.ResponseWriter, r *http.Request) { - blocks := store.getAllBlocks() + blocks, err := a.app().GetAllBlocks() + if err != nil { + log.Printf(`ERROR: %v`, r) + errorResponse(w, http.StatusInternalServerError, `{}`) + return + } log.Printf("EXPORT Blocks, %d result(s)", len(blocks)) json, err := json.Marshal(blocks) @@ -193,7 +197,12 @@ func (a *API) handleImport(w http.ResponseWriter, r *http.Request) { } for _, block := range blocks { - store.insertBlock(block) + err := a.app().InsertBlock(block) + if err != nil { + log.Printf(`ERROR: %v`, r) + errorResponse(w, http.StatusInternalServerError, `{}`) + return + } } log.Printf("IMPORT Blocks %d block(s)", len(blocks)) @@ -281,3 +290,9 @@ func errorResponse(w http.ResponseWriter, code int, message string) { w.WriteHeader(code) fmt.Fprint(w, message) } + +func addUserID(rw http.ResponseWriter, req *http.Request, next http.Handler) { + ctx := context.WithValue(req.Context(), "userid", req.Header.Get("userid")) + req = req.WithContext(ctx) + next.ServeHTTP(rw, req) +} diff --git a/server/main/app.go b/server/main/app.go new file mode 100644 index 000000000..813ee083e --- /dev/null +++ b/server/main/app.go @@ -0,0 +1,74 @@ +package main + +type App struct { + store *SQLStore + wsServer *WSServer +} + +func (a *App) GetBlocks(parentID string, blockType string) ([]Block, error) { + if len(blockType) > 0 && len(parentID) > 0 { + return a.store.getBlocksWithParentAndType(parentID, blockType) + } + if len(blockType) > 0 { + return a.store.getBlocksWithType(blockType) + } + return a.store.getBlocksWithParent(parentID) +} + +func (a *App) GetParentID(blockID string) (string, error) { + return a.store.getParentID(blockID) +} + +func (a *App) InsertBlock(block Block) error { + return a.store.insertBlock(block) +} + +func (a *App) InsertBlocks(blocks []Block) error { + var blockIDsToNotify = []string{} + uniqueBlockIDs := make(map[string]bool) + + for _, block := range blocks { + if !uniqueBlockIDs[block.ID] { + blockIDsToNotify = append(blockIDsToNotify, block.ID) + } + if len(block.ParentID) > 0 && !uniqueBlockIDs[block.ParentID] { + blockIDsToNotify = append(blockIDsToNotify, block.ParentID) + } + + err := a.store.insertBlock(block) + if err != nil { + return err + } + } + + wsServer.broadcastBlockChangeToWebsocketClients(blockIDsToNotify) + return nil +} + +func (a *App) GetSubTree(blockID string) ([]Block, error) { + return a.store.getSubTree(blockID) +} + +func (a *App) GetAllBlocks() ([]Block, error) { + return a.store.getAllBlocks() +} + +func (a *App) DeleteBlock(blockID string) error { + var blockIDsToNotify = []string{blockID} + parentID, err := a.GetParentID(blockID) + if err != nil { + return err + } + + if len(parentID) > 0 { + blockIDsToNotify = append(blockIDsToNotify, parentID) + } + + err = a.store.deleteBlock(blockID) + if err != nil { + return err + } + + a.wsServer.broadcastBlockChangeToWebsocketClients(blockIDsToNotify) + return nil +} diff --git a/server/main/main.go b/server/main/main.go index 00ec04530..15aba2604 100644 --- a/server/main/main.go +++ b/server/main/main.go @@ -7,7 +7,6 @@ import ( "log" "os" - "os/signal" ) var config *Configuration @@ -70,32 +69,12 @@ func main() { config.Port = *pPort } - wsServer = NewWSServer() - webServer = NewWebServer(config.Port, config.UseSSL) - api = NewAPI() - webServer.AddRoutes(api) - webServer.AddRoutes(wsServer) - - store, err = NewSQLStore(config.DBType, config.DBConfigString) + server, err := NewServer(config) if err != nil { - log.Fatal("Unable to start the database", err) - panic(err) + log.Fatal("ListenAndServeTLS: ", err) } - // Ctrl+C handling - handler := make(chan os.Signal, 1) - signal.Notify(handler, os.Interrupt) - go func() { - for sig := range handler { - // sig is a ^C, handle it - if sig == os.Interrupt { - os.Exit(1) - break - } - } - }() - - if err := webServer.Start(); err != nil { + if err := server.Start(); err != nil { log.Fatal("ListenAndServeTLS: ", err) } } diff --git a/server/main/octoDatabase.go b/server/main/octoDatabase.go index 5437fe608..eb2ab9160 100644 --- a/server/main/octoDatabase.go +++ b/server/main/octoDatabase.go @@ -102,7 +102,7 @@ func (s *SQLStore) createTablesIfNotExists() error { return nil } -func (s *SQLStore) getBlocksWithParentAndType(parentID string, blockType string) []Block { +func (s *SQLStore) getBlocksWithParentAndType(parentID string, blockType string) ([]Block, error) { query := `WITH latest AS ( SELECT * FROM @@ -122,13 +122,13 @@ func (s *SQLStore) getBlocksWithParentAndType(parentID string, blockType string) rows, err := s.db.Query(query, parentID, blockType) if err != nil { log.Printf(`getBlocksWithParentAndType ERROR: %v`, err) - panic(err) + return nil, err } return blocksFromRows(rows) } -func (s *SQLStore) getBlocksWithParent(parentID string) []Block { +func (s *SQLStore) getBlocksWithParent(parentID string) ([]Block, error) { query := `WITH latest AS ( SELECT * FROM @@ -148,13 +148,13 @@ func (s *SQLStore) getBlocksWithParent(parentID string) []Block { rows, err := s.db.Query(query, parentID) if err != nil { log.Printf(`getBlocksWithParent ERROR: %v`, err) - panic(err) + return nil, err } return blocksFromRows(rows) } -func (s *SQLStore) getBlocksWithType(blockType string) []Block { +func (s *SQLStore) getBlocksWithType(blockType string) ([]Block, error) { query := `WITH latest AS ( SELECT * FROM @@ -174,13 +174,13 @@ func (s *SQLStore) getBlocksWithType(blockType string) []Block { rows, err := s.db.Query(query, blockType) if err != nil { log.Printf(`getBlocksWithParentAndType ERROR: %v`, err) - panic(err) + return nil, err } return blocksFromRows(rows) } -func (s *SQLStore) getSubTree(blockID string) []Block { +func (s *SQLStore) getSubTree(blockID string) ([]Block, error) { query := `WITH latest AS ( SELECT * FROM @@ -202,13 +202,13 @@ func (s *SQLStore) getSubTree(blockID string) []Block { rows, err := s.db.Query(query, blockID) if err != nil { log.Printf(`getSubTree ERROR: %v`, err) - panic(err) + return nil, err } return blocksFromRows(rows) } -func (s *SQLStore) getAllBlocks() []Block { +func (s *SQLStore) getAllBlocks() ([]Block, error) { query := `WITH latest AS ( SELECT * FROM @@ -228,13 +228,13 @@ func (s *SQLStore) getAllBlocks() []Block { rows, err := s.db.Query(query) if err != nil { log.Printf(`getAllBlocks ERROR: %v`, err) - panic(err) + return nil, err } return blocksFromRows(rows) } -func blocksFromRows(rows *sql.Rows) []Block { +func blocksFromRows(rows *sql.Rows) ([]Block, error) { defer rows.Close() var results []Block @@ -255,23 +255,23 @@ func blocksFromRows(rows *sql.Rows) []Block { if err != nil { // handle this error log.Printf(`ERROR blocksFromRows: %v`, err) - panic(err) + return nil, err } err = json.Unmarshal([]byte(fieldsJSON), &block.Fields) if err != nil { // handle this error log.Printf(`ERROR blocksFromRows fields: %v`, err) - panic(err) + return nil, err } results = append(results, block) } - return results + return results, nil } -func (s *SQLStore) getParentID(blockID string) string { +func (s *SQLStore) getParentID(blockID string) (string, error) { statement := `WITH latest AS ( @@ -295,16 +295,16 @@ func (s *SQLStore) getParentID(blockID string) string { var parentID string err := row.Scan(&parentID) if err != nil { - return "" + return "", err } - return parentID + return parentID, nil } -func (s *SQLStore) insertBlock(block Block) { +func (s *SQLStore) insertBlock(block Block) error { fieldsJSON, err := json.Marshal(block.Fields) if err != nil { - panic(err) + return err } statement := `INSERT INTO blocks( @@ -331,15 +331,17 @@ func (s *SQLStore) insertBlock(block Block) { block.UpdateAt, block.DeleteAt) if err != nil { - panic(err) + return err } + return nil } -func (s *SQLStore) deleteBlock(blockID string) { +func (s *SQLStore) deleteBlock(blockID string) error { now := time.Now().Unix() statement := `INSERT INTO blocks(id, update_at, delete_at) VALUES($1, $2, $3)` _, err := s.db.Exec(statement, blockID, now, now) if err != nil { - panic(err) + return err } + return nil } diff --git a/server/main/server.go b/server/main/server.go new file mode 100644 index 000000000..646a9df6d --- /dev/null +++ b/server/main/server.go @@ -0,0 +1,58 @@ +package main + +import ( + "log" + "os" + "os/signal" +) + +type Server struct { + config *Configuration + wsServer *WSServer + webServer *WebServer + store *SQLStore +} + +func NewServer(config *Configuration) (*Server, error) { + store, err := NewSQLStore(config.DBType, config.DBConfigString) + if err != nil { + log.Fatal("Unable to start the database", err) + return nil, err + } + + wsServer = NewWSServer() + + appBuilder := func() *App { return &App{store: store, wsServer: wsServer} } + + webServer = NewWebServer(config.WebPath, config.Port, config.UseSSL) + api = NewAPI(appBuilder) + webServer.AddRoutes(wsServer) + webServer.AddRoutes(api) + + // Ctrl+C handling + handler := make(chan os.Signal, 1) + signal.Notify(handler, os.Interrupt) + go func() { + for sig := range handler { + // sig is a ^C, handle it + if sig == os.Interrupt { + os.Exit(1) + break + } + } + }() + + return &Server{ + config: config, + wsServer: wsServer, + webServer: webServer, + store: store, + }, nil +} + +func (s *Server) Start() error { + if err := webServer.Start(); err != nil { + return err + } + return nil +} diff --git a/server/main/webserver.go b/server/main/webserver.go index 1407b8f3c..e444849e4 100644 --- a/server/main/webserver.go +++ b/server/main/webserver.go @@ -14,30 +14,34 @@ type RoutedService interface { } type WebServer struct { - router *mux.Router - port int - ssl bool + router *mux.Router + rootPath string + port int + ssl bool } -func NewWebServer(port int, ssl bool) *WebServer { +func NewWebServer(rootPath string, port int, ssl bool) *WebServer { r := mux.NewRouter() - // Static files - handleDefault(r, "/") - handleStaticFile(r, "/login", "index.html", "text/html; charset=utf-8") - handleStaticFile(r, "/board", "index.html", "text/html; charset=utf-8") - handleStaticFile(r, "/main.js", "main.js", "text/javascript; charset=utf-8") - handleStaticFile(r, "/boardPage.js", "boardPage.js", "text/javascript; charset=utf-8") - handleStaticFile(r, "/favicon.ico", "static/favicon.svg", "image/svg+xml; charset=utf-8") - handleStaticFile(r, "/easymde.min.css", "static/easymde.min.css", "text/css") - handleStaticFile(r, "/main.css", "static/main.css", "text/css") - handleStaticFile(r, "/colors.css", "static/colors.css", "text/css") - handleStaticFile(r, "/images.css", "static/images.css", "text/css") - return &WebServer{ - router: r, - port: port, - ssl: ssl, + ws := &WebServer{ + router: r, + rootPath: rootPath, + port: port, + ssl: ssl, } + + // Static files + ws.handleDefault(r, "/") + ws.handleStaticFile(r, "/login", "index.html", "text/html; charset=utf-8") + ws.handleStaticFile(r, "/board", "index.html", "text/html; charset=utf-8") + ws.handleStaticFile(r, "/main.js", "main.js", "text/javascript; charset=utf-8") + ws.handleStaticFile(r, "/boardPage.js", "boardPage.js", "text/javascript; charset=utf-8") + ws.handleStaticFile(r, "/favicon.ico", "static/favicon.svg", "image/svg+xml; charset=utf-8") + ws.handleStaticFile(r, "/easymde.min.css", "static/easymde.min.css", "text/css") + ws.handleStaticFile(r, "/main.css", "static/main.css", "text/css") + ws.handleStaticFile(r, "/colors.css", "static/colors.css", "text/css") + ws.handleStaticFile(r, "/images.css", "static/images.css", "text/css") + return ws } func (ws *WebServer) AddRoutes(rs RoutedService) { @@ -68,21 +72,21 @@ func (ws *WebServer) Start() error { // ---------------------------------------------------------------------------------------------------- // HTTP handlers -func serveWebFile(w http.ResponseWriter, r *http.Request, relativeFilePath string) { - folderPath := config.WebPath +func (ws *WebServer) serveWebFile(w http.ResponseWriter, r *http.Request, relativeFilePath string) { + folderPath := ws.rootPath filePath := filepath.Join(folderPath, relativeFilePath) http.ServeFile(w, r, filePath) } -func handleStaticFile(r *mux.Router, requestPath string, filePath string, contentType string) { +func (ws *WebServer) handleStaticFile(r *mux.Router, requestPath string, filePath string, contentType string) { r.HandleFunc(requestPath, func(w http.ResponseWriter, r *http.Request) { log.Printf("handleStaticFile: %s", requestPath) w.Header().Set("Content-Type", contentType) - serveWebFile(w, r, filePath) + ws.serveWebFile(w, r, filePath) }) } -func handleDefault(r *mux.Router, requestPath string) { +func (ws *WebServer) handleDefault(r *mux.Router, requestPath string) { r.HandleFunc(requestPath, func(w http.ResponseWriter, r *http.Request) { log.Printf("handleDefault") http.Redirect(w, r, "/board", http.StatusFound)