From 61fb38d41889cf9d59b5da93690a921b4fa3bbbf Mon Sep 17 00:00:00 2001 From: Chen-I Lim Date: Tue, 12 Jan 2021 18:49:08 -0800 Subject: [PATCH] Allow GetSubTree without auth. WIP --- server/api/api.go | 31 +++++++++++++++++++++--- server/api/auth.go | 11 ++++++++- server/app/blocks.go | 4 +++ server/services/store/sqlstore/blocks.go | 17 +++++++++++++ server/services/store/store.go | 1 + webapp/src/octoClient.ts | 3 +++ 6 files changed, 62 insertions(+), 5 deletions(-) diff --git a/server/api/api.go b/server/api/api.go index 678d2e654..401d295fe 100644 --- a/server/api/api.go +++ b/server/api/api.go @@ -37,7 +37,7 @@ func (a *API) RegisterRoutes(r *mux.Router) { r.HandleFunc("/api/v1/blocks", a.sessionRequired(a.handleGetBlocks)).Methods("GET") r.HandleFunc("/api/v1/blocks", a.sessionRequired(a.handlePostBlocks)).Methods("POST") r.HandleFunc("/api/v1/blocks/{blockID}", a.sessionRequired(a.handleDeleteBlock)).Methods("DELETE") - r.HandleFunc("/api/v1/blocks/{blockID}/subtree", a.sessionRequired(a.handleGetSubTree)).Methods("GET") + r.HandleFunc("/api/v1/blocks/{blockID}/subtree", a.attachSession(a.handleGetSubTree, false)).Methods("GET") r.HandleFunc("/api/v1/users/me", a.sessionRequired(a.handleGetMe)).Methods("GET") r.HandleFunc("/api/v1/users/{userID}", a.sessionRequired(a.handleGetUser)).Methods("GET") @@ -242,6 +242,32 @@ func (a *API) handleGetSubTree(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) blockID := vars["blockID"] + // If not authenticated (no session), check that block is publicly shared + ctx := r.Context() + session, _ := ctx.Value("session").(*model.Session) + if session == nil { + rootID, err := a.app().GetRootID(blockID) + if err != nil { + log.Printf(`ERROR GetRootID %v: %v, REQUEST: %v`, blockID, err, r) + errorResponse(w, http.StatusInternalServerError, nil) + return + } + + sharing, err := a.app().GetSharing(rootID) + if err != nil { + log.Printf(`ERROR GetSharing %v: %v, REQUEST: %v`, rootID, err, r) + errorResponse(w, http.StatusInternalServerError, nil) + return + } + + // TODO: Check token + if sharing == nil || !(sharing.ID == rootID && sharing.Enabled) { + log.Printf(`handleGetSubTree public unauthorized, rootID: %v`, rootID) + errorResponse(w, http.StatusUnauthorized, nil) + return + } + } + query := r.URL.Query() levels, err := strconv.ParseInt(query.Get("l"), 10, 32) if err != nil { @@ -252,7 +278,6 @@ func (a *API) handleGetSubTree(w http.ResponseWriter, r *http.Request) { log.Printf(`ERROR Invalid levels: %d`, levels) errorData := map[string]string{"description": "invalid levels"} errorResponse(w, http.StatusInternalServerError, errorData) - return } @@ -260,7 +285,6 @@ func (a *API) handleGetSubTree(w http.ResponseWriter, r *http.Request) { if err != nil { log.Printf(`ERROR: %v, REQUEST: %v`, err, r) errorResponse(w, http.StatusInternalServerError, nil) - return } @@ -269,7 +293,6 @@ func (a *API) handleGetSubTree(w http.ResponseWriter, r *http.Request) { if err != nil { log.Printf(`ERROR json.Marshal: %v, REQUEST: %v`, err, r) errorResponse(w, http.StatusInternalServerError, nil) - return } diff --git a/server/api/auth.go b/server/api/auth.go index d214c5ab4..c4ccd2ffb 100644 --- a/server/api/auth.go +++ b/server/api/auth.go @@ -110,6 +110,10 @@ func (a *API) handleRegister(w http.ResponseWriter, r *http.Request) { } func (a *API) sessionRequired(handler func(w http.ResponseWriter, r *http.Request)) func(w http.ResponseWriter, r *http.Request) { + return a.attachSession(handler, true) +} + +func (a *API) attachSession(handler func(w http.ResponseWriter, r *http.Request), required bool) func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) { log.Printf(`Single User: %v`, a.singleUser) if a.singleUser { @@ -129,7 +133,12 @@ func (a *API) sessionRequired(handler func(w http.ResponseWriter, r *http.Reques token, _ := auth.ParseAuthTokenFromRequest(r) session, err := a.app().GetSession(token) if err != nil { - errorResponse(w, http.StatusUnauthorized, map[string]string{"error": err.Error()}) + if required { + errorResponse(w, http.StatusUnauthorized, map[string]string{"error": err.Error()}) + return + } + + handler(w, r) return } ctx := context.WithValue(r.Context(), "session", session) diff --git a/server/app/blocks.go b/server/app/blocks.go index 064f1c569..89d8387c8 100644 --- a/server/app/blocks.go +++ b/server/app/blocks.go @@ -16,6 +16,10 @@ func (a *App) GetBlocks(parentID string, blockType string) ([]model.Block, error return a.store.GetBlocksWithParent(parentID) } +func (a *App) GetRootID(blockID string) (string, error) { + return a.store.GetRootID(blockID) +} + func (a *App) GetParentID(blockID string) (string, error) { return a.store.GetParentID(blockID) } diff --git a/server/services/store/sqlstore/blocks.go b/server/services/store/sqlstore/blocks.go index 9e40a5f99..2aacd9bb2 100644 --- a/server/services/store/sqlstore/blocks.go +++ b/server/services/store/sqlstore/blocks.go @@ -262,6 +262,23 @@ func blocksFromRows(rows *sql.Rows) ([]model.Block, error) { return results, nil } +func (s *SQLStore) GetRootID(blockID string) (string, error) { + query := s.getQueryBuilder().Select("root_id"). + FromSelect(s.latestsBlocksSubquery(), "latest"). + Where(sq.Eq{"id": blockID}) + + row := query.QueryRow() + + var rootID string + + err := row.Scan(&rootID) + if err != nil { + return "", err + } + + return rootID, nil +} + func (s *SQLStore) GetParentID(blockID string) (string, error) { query := s.getQueryBuilder().Select("parent_id"). FromSelect(s.latestsBlocksSubquery(), "latest"). diff --git a/server/services/store/store.go b/server/services/store/store.go index 50cd675c0..5d2ebe93e 100644 --- a/server/services/store/store.go +++ b/server/services/store/store.go @@ -11,6 +11,7 @@ type Store interface { GetSubTree2(blockID string) ([]model.Block, error) GetSubTree3(blockID string) ([]model.Block, error) GetAllBlocks() ([]model.Block, error) + GetRootID(blockID string) (string, error) GetParentID(blockID string) (string, error) InsertBlock(block model.Block) error DeleteBlock(blockID string, modifiedBy string) error diff --git a/webapp/src/octoClient.ts b/webapp/src/octoClient.ts index 9c7a14f35..9c4e8eeef 100644 --- a/webapp/src/octoClient.ts +++ b/webapp/src/octoClient.ts @@ -113,6 +113,9 @@ class OctoClient { private async getBlocksWithPath(path: string): Promise { const response = await fetch(this.serverUrl + path, {headers: this.headers()}) + if (response.status !== 200) { + return [] + } const blocks = (await response.json() || []) as IMutableBlock[] this.fixBlocks(blocks) return blocks