diff --git a/src/db/dblist.go b/src/db/dblist.go index 08239aa..f20eea6 100644 --- a/src/db/dblist.go +++ b/src/db/dblist.go @@ -13,18 +13,20 @@ import ( "github.com/sirupsen/logrus" ) +// *** SQL connections *** + // Gets SQL server connection by GUID -func (o *DbList) GetById(guid string, updateTimestamp bool) (*sql.DB, bool) { - val, ok := o.items.Load(guid) +func (o *DbList) GetById(id string, updateTimestamp bool) (*sql.DB, bool) { + val, ok := o.items.Load(id) if ok { res := val.(*DbConn) if updateTimestamp { res.Timestamp = time.Now() - o.items.Store(guid, res) + o.items.Store(id, res) } return res.DB, true } - app.Log.Error(fmt.Sprintf("SQL connection with guid='%s' not found", guid)) + app.Log.Error(fmt.Sprintf("SQL connection with guid='%s' not found", id)) return nil, false } @@ -141,6 +143,71 @@ func (o *DbList) getNewConnection(connInfo *DbConnInfo, hash [32]byte) (string, return newId, true } +// Deletes SQL server connection +func (o *DbList) Delete(id string) { + o.items.Delete(id) + app.Log.Debug(fmt.Sprintf("DB connection with id %s was deleted by query", id)) +} + +// *** SQL prepared statements *** + +// Saves SQL prepared statement +func (o *DbList) PutPreparedStatement(id string, stmt *sql.Stmt) (string, bool) { + val, ok := o.items.Load(id) + if !ok { + app.Log.Error(fmt.Sprintf("SQL connection with guid='%s' not found", id)) + return "", false + } + + newId := uuid.New().String() + dbStmt := DbStmt{ + Id: newId, + Stmt: stmt, + } + res := val.(*DbConn) + res.Timestamp = time.Now() + res.Stmt = append(res.Stmt, dbStmt) + o.items.Store(id, res) + return newId, true +} + +// Gets SQL prepared statement +func (o *DbList) GetPreparedStatement(conn_id, stmt_id string) (*sql.Stmt, bool) { + val, ok := o.items.Load(conn_id) + if !ok { + app.Log.Error(fmt.Sprintf("SQL connection with guid='%s' not found", conn_id)) + return nil, false + } + res := val.(*DbConn) + for i := 0; i < len(res.Stmt); i++ { + if res.Stmt[i].Id == stmt_id { + return res.Stmt[i].Stmt, true + } + } + return nil, false +} + +// Closes and deletes SQL prepared statement +func (o *DbList) ClosePreparedStatement(conn_id, stmt_id string) bool { + val, ok := o.items.Load(conn_id) + if !ok { + app.Log.Error(fmt.Sprintf("SQL connection with guid='%s' not found", conn_id)) + return false + } + res := val.(*DbConn) + for i := 0; i < len(res.Stmt); i++ { + if res.Stmt[i].Id == stmt_id { + res.Stmt[i].Stmt.Close() + res.Stmt = append(res.Stmt[:i], res.Stmt[i+1:]...) + break + } + } + o.items.Store(conn_id, res) + return true +} + +// *** Maintenance *** + func (o *DbList) RunMaintenance() { ticker := time.NewTicker(5 * time.Second) defer ticker.Stop() @@ -151,8 +218,10 @@ func (o *DbList) RunMaintenance() { // detect dead connections var deadItems []string + var count int o.items.Range( func(key, value interface{}) bool { + count++ err := value.(*DbConn).DB.Ping() if err != nil { deadItems = append(deadItems, key.(string)) @@ -160,6 +229,8 @@ func (o *DbList) RunMaintenance() { return true // continue iteration }) + app.Log.Debug(fmt.Sprintf("Regular task: pool size = %d", count)) + // remove dead connections if len(deadItems) > 0 { for _, item := range deadItems { @@ -172,9 +243,3 @@ func (o *DbList) RunMaintenance() { } } - -func (o *DbList) Delete(id string) { - - o.items.Delete(id) - -} diff --git a/src/db/driver/mysql.go b/src/db/mysql.go similarity index 100% rename from src/db/driver/mysql.go rename to src/db/mysql.go diff --git a/src/db/driver/postgres.go b/src/db/postgres.go similarity index 100% rename from src/db/driver/postgres.go rename to src/db/postgres.go diff --git a/src/db/driver/sqlserver.go b/src/db/sqlserver.go similarity index 100% rename from src/db/driver/sqlserver.go rename to src/db/sqlserver.go diff --git a/src/handlers/envelope.go b/src/handlers/envelope.go index 3e7926e..e52ccb0 100644 --- a/src/handlers/envelope.go +++ b/src/handlers/envelope.go @@ -8,8 +8,3 @@ type ResponseEnvelope struct { ExceedsMaxRows bool `json:"exceeds_max_rows"` Rows []map[string]interface{} } - -type ExecuteQueryEnvelope struct { - SQL string `json:"sql"` - Conn string `json:"connection_id"` -} diff --git a/src/handlers/query.go b/src/handlers/query.go index a861cbd..d858198 100644 --- a/src/handlers/query.go +++ b/src/handlers/query.go @@ -12,8 +12,8 @@ import ( ) func SelectQuery(w http.ResponseWriter, r *http.Request) { - query := r.URL.Query().Get("query") - conn := r.URL.Query().Get("conn") + query := r.URL.Query().Get("sql") + conn := r.URL.Query().Get("connection_id") if query == "" || conn == "" { errorText := "Missing parameter" app.Log.Error(errorText) @@ -22,8 +22,8 @@ func SelectQuery(w http.ResponseWriter, r *http.Request) { } app.Log.WithFields(logrus.Fields{ - "query": query, - "conn": conn, + "sql": query, + "connection_id": conn, }).Debug("SQL query received:") // Search existings connection in the pool @@ -62,18 +62,27 @@ func SelectQuery(w http.ResponseWriter, r *http.Request) { } func ExecuteQuery(w http.ResponseWriter, r *http.Request) { - var payload ExecuteQueryEnvelope - err := json.NewDecoder(r.Body).Decode(&payload) + + var requestBody map[string]interface{} + err := json.NewDecoder(r.Body).Decode(&requestBody) if err != nil { - http.Error(w, "Invalid JSON payload", http.StatusBadRequest) + http.Error(w, "Error decoding JSON", http.StatusBadRequest) return } - conn, ok := db.Handler.GetById(payload.Conn, true) + defer r.Body.Close() + + conn, ok := db.Handler.GetById(requestBody["connection_id"].(string), true) if !ok { http.Error(w, "Invalid connection id", http.StatusBadRequest) return } - _, err = conn.Exec(payload.SQL) + + app.Log.WithFields(logrus.Fields{ + "sql": requestBody["sql"].(string), + "connection_id": requestBody["connection_id"].(string), + }).Debug("SQL execute query received:") + + _, err = conn.Exec(requestBody["sql"].(string)) if err != nil { http.Error(w, "Invalid SQL query", http.StatusBadRequest) } diff --git a/src/handlers/statement.go b/src/handlers/statement.go index d52f0b6..7dc81d0 100644 --- a/src/handlers/statement.go +++ b/src/handlers/statement.go @@ -3,35 +3,56 @@ package handlers import ( "encoding/json" "net/http" + "sql-proxy/src/db" ) func PrepareStatement(w http.ResponseWriter, r *http.Request) { - var payload ExecuteQueryEnvelope - err := json.NewDecoder(r.Body).Decode(&payload) + var requestBody map[string]interface{} + err := json.NewDecoder(r.Body).Decode(&requestBody) if err != nil { - http.Error(w, "Invalid JSON payload", http.StatusBadRequest) + http.Error(w, "Error decoding JSON", http.StatusBadRequest) return } - /* - conn, ok := app.DbHandler.GetById(payload.Conn, true) + defer r.Body.Close() - if !ok { - http.Error(w, "Invalid connection id", http.StatusBadRequest) - return - } + conn, ok := db.Handler.GetById(requestBody["connection_id"].(string), true) + if !ok { + http.Error(w, "Invalid connection id", http.StatusBadRequest) + return + } - stmt, err := conn.Prepare(payload.SQL) + stmt, err := conn.Prepare(requestBody["sql"].(string)) + if err != nil { + http.Error(w, "Failed to prepare statement", http.StatusBadRequest) + } - if err != nil { - http.Error(w, "Failed to prepare statement", http.StatusBadRequest) - } - */ + stmt_id, ok := db.Handler.PutPreparedStatement(requestBody["connection_id"].(string), stmt) + if !ok { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + + _, err = w.Write([]byte(stmt_id)) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } } func SelectStatement(w http.ResponseWriter, r *http.Request) { + var requestBody map[string]interface{} + err := json.NewDecoder(r.Body).Decode(&requestBody) + if err != nil { + http.Error(w, "Error decoding JSON", http.StatusBadRequest) + return + } + defer r.Body.Close() + // to do } func ExecuteStatement(w http.ResponseWriter, r *http.Request) { // to do } + +func CloseStatement(w http.ResponseWriter, r *http.Request) { + // to do +} diff --git a/src/main.go b/src/main.go index ff4c90a..0805aa4 100644 --- a/src/main.go +++ b/src/main.go @@ -36,6 +36,7 @@ func main() { router.HandleFunc("/api/v1/statement/prepare", handlers.PrepareStatement).Methods("POST") router.HandleFunc("/api/v1/statement/select", handlers.SelectStatement).Methods("GET") router.HandleFunc("/api/v1/statement/execute", handlers.ExecuteStatement).Methods("POST") + router.HandleFunc("/api/v1/statement/close", handlers.CloseStatement).Methods("DELETE") router.HandleFunc("/healthz", handlers.Healthz).Methods("GET") router.HandleFunc("/readyz", handlers.Readyz).Methods("GET") router.HandleFunc("/livez", handlers.Livez).Methods("GET")