1
0
mirror of https://github.com/alm494/sql_proxy.git synced 2025-10-08 22:01:51 +02:00

src reorganization

This commit is contained in:
Almaz Sharipov
2025-02-16 16:06:08 +03:00
parent f2a81d46bf
commit 6761cdebd2
16 changed files with 136 additions and 84 deletions

View File

@@ -35,8 +35,8 @@ prod: clean
@echo "Building $(PROJECT_NAME) or production..." @echo "Building $(PROJECT_NAME) or production..."
GOOS=${GOOS} GOARCH=${GOARCH} go build $(TAGS) \ GOOS=${GOOS} GOARCH=${GOARCH} go build $(TAGS) \
-ldflags="-s -w \ -ldflags="-s -w \
-X ${PROJECT_NAME}/src/version.BuildVersion=${BUILD_VERSION} \ -X ${PROJECT_NAME}/src/app.BuildVersion=${BUILD_VERSION} \
-X ${PROJECT_NAME}/src/version.BuildTime=${BUILD_TIME}" -o $(BUILD_DIR)/$(PROJECT_NAME) $(GO_FILES) -X ${PROJECT_NAME}/src/app.BuildTime=${BUILD_TIME}" -o $(BUILD_DIR)/$(PROJECT_NAME) $(GO_FILES)
@echo "Production build completed." @echo "Production build completed."
# Build for debugging # Build for debugging
@@ -44,8 +44,8 @@ debug: clean
@echo "Building $(PROJECT_NAME) or production..." @echo "Building $(PROJECT_NAME) or production..."
GOOS=${GOOS} GOARCH=${GOARCH} go build $(TAGS) \ GOOS=${GOOS} GOARCH=${GOARCH} go build $(TAGS) \
-ldflags="\ -ldflags="\
-X ${PROJECT_NAME}/src/version.BuildVersion=${BUILD_VERSION}-debug \ -X ${PROJECT_NAME}/src/app.BuildVersion=${BUILD_VERSION}-debug \
-X ${PROJECT_NAME}/src/version.BuildTime=${BUILD_TIME}" -o $(BUILD_DIR)/$(PROJECT_NAME)-debug $(GO_FILES) -X ${PROJECT_NAME}/src/app.BuildTime=${BUILD_TIME}" -o $(BUILD_DIR)/$(PROJECT_NAME)-debug $(GO_FILES)
@echo "Debug build completed." @echo "Debug build completed."
# Run # Run

3
go.mod
View File

@@ -13,12 +13,15 @@ require (
golang.org/x/crypto v0.33.0 golang.org/x/crypto v0.33.0
) )
require github.com/gorilla/securecookie v1.1.2 // indirect
require ( require (
filippo.io/edwards25519 v1.1.0 // indirect filippo.io/edwards25519 v1.1.0 // indirect
github.com/beorn7/perks v1.0.1 // indirect github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
github.com/golang-sql/sqlexp v0.1.0 // indirect github.com/golang-sql/sqlexp v0.1.0 // indirect
github.com/gorilla/sessions v1.4.0
github.com/klauspost/compress v1.17.11 // indirect github.com/klauspost/compress v1.17.11 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/prometheus/client_model v0.6.1 // indirect github.com/prometheus/client_model v0.6.1 // indirect

4
go.sum
View File

@@ -24,6 +24,10 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ=
github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA=
github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
github.com/gorilla/sessions v1.4.0 h1:kpIYOp/oi6MG/p5PgxApU8srsSw9tuFbt46Lt7auzqQ=
github.com/gorilla/sessions v1.4.0/go.mod h1:FLWm50oby91+hl7p/wRxDth9bWSuk0qVL2emc7lT5ik=
github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc= github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc=
github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0= github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=

View File

@@ -1,4 +1,4 @@
package utils package app
import ( import (
"os" "os"

View File

@@ -1,4 +1,4 @@
package utils package app
import ( import (
"os" "os"

View File

@@ -1,4 +1,4 @@
package version package app
var ( var (
BuildTime = "none" BuildTime = "none"

6
src/db/db.go Normal file
View File

@@ -0,0 +1,6 @@
package db
var (
Handler DbList
MaxRows uint32 = 10000
)

View File

@@ -1,15 +0,0 @@
package db
import (
"database/sql"
"time"
)
type DbConn struct {
// Hash, as sql.DB does not store credentials
Hash [32]byte
// SQL server connection pool (provided by the driver)
DB *sql.DB
// Last use
Timestamp time.Time
}

View File

@@ -6,16 +6,6 @@ import (
"encoding/gob" "encoding/gob"
) )
type DbConnInfo struct {
DbType string `json:"db_type"`
Host string `json:"host"`
Port uint16 `json:"port"`
User string `json:"user"`
Password string `json:"password"`
DbName string `json:"db_name"`
SSL bool `json:"ssl"`
}
func (o DbConnInfo) GetHash() ([32]byte, error) { func (o DbConnInfo) GetHash() ([32]byte, error) {
var buf bytes.Buffer var buf bytes.Buffer
var hash [32]byte var hash [32]byte

View File

@@ -5,20 +5,14 @@ import (
"database/sql" "database/sql"
"fmt" "fmt"
"net/url" "net/url"
"sql-proxy/src/utils" "sql-proxy/src/app"
"sync"
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
// Class model to keep open SQL connections in the pool
// with concurrent read/write access
type DbList struct {
items sync.Map
}
// Gets SQL server connection by GUID // Gets SQL server connection by GUID
func (o *DbList) GetById(guid string, updateTimestamp bool) (*sql.DB, bool) { func (o *DbList) GetById(guid string, updateTimestamp bool) (*sql.DB, bool) {
val, ok := o.items.Load(guid) val, ok := o.items.Load(guid)
@@ -30,7 +24,7 @@ func (o *DbList) GetById(guid string, updateTimestamp bool) (*sql.DB, bool) {
} }
return res.DB, true return res.DB, true
} }
utils.Log.Error(fmt.Sprintf("SQL connection with guid='%s' not found", guid)) app.Log.Error(fmt.Sprintf("SQL connection with guid='%s' not found", guid))
return nil, false return nil, false
} }
@@ -40,7 +34,7 @@ func (o *DbList) GetByParams(connInfo *DbConnInfo) (string, bool) {
hash, err := connInfo.GetHash() hash, err := connInfo.GetHash()
if err != nil { if err != nil {
errMsg := "Hash calculation failed" errMsg := "Hash calculation failed"
utils.Log.WithError(err).Error(errMsg) app.Log.WithError(err).Error(errMsg)
return errMsg, false return errMsg, false
} }
@@ -50,7 +44,7 @@ func (o *DbList) GetByParams(connInfo *DbConnInfo) (string, bool) {
func(key, value interface{}) bool { func(key, value interface{}) bool {
if bytes.Equal(value.(*DbConn).Hash[:], hash[:]) { if bytes.Equal(value.(*DbConn).Hash[:], hash[:]) {
guid = key.(string) guid = key.(string)
utils.Log.Debug(fmt.Sprintf("DB connection with id %s found in the pool", guid)) app.Log.Debug(fmt.Sprintf("DB connection with id %s found in the pool", guid))
return false // stop iteraton return false // stop iteraton
} }
return true // continue iteration return true // continue iteration
@@ -67,7 +61,7 @@ func (o *DbList) GetByParams(connInfo *DbConnInfo) (string, bool) {
} else { } else {
// Remove dead connection from the pool // Remove dead connection from the pool
o.items.Delete(guid) o.items.Delete(guid)
utils.Log.Debug(fmt.Sprintf("DB connection with id %s is dead and removed from the pool", guid)) app.Log.Debug(fmt.Sprintf("DB connection with id %s is dead and removed from the pool", guid))
} }
} }
@@ -100,7 +94,7 @@ func (o *DbList) getNewConnection(connInfo *DbConnInfo, hash [32]byte) (string,
connInfo.User, encodedPassword, connInfo.Host, connInfo.Port, connInfo.DbName) connInfo.User, encodedPassword, connInfo.Host, connInfo.Port, connInfo.DbName)
default: default:
errMsg := fmt.Sprintf("No suitable driver implemented for server type '%s'", connInfo.DbType) errMsg := fmt.Sprintf("No suitable driver implemented for server type '%s'", connInfo.DbType)
utils.Log.Error(errMsg) app.Log.Error(errMsg)
return errMsg, false return errMsg, false
} }
@@ -113,7 +107,7 @@ func (o *DbList) getNewConnection(connInfo *DbConnInfo, hash [32]byte) (string,
// 3. Check for failure // 3. Check for failure
if err != nil { if err != nil {
errMsg := "Error establishing SQL server connection" errMsg := "Error establishing SQL server connection"
utils.Log.WithError(err).Error(errMsg) app.Log.WithError(err).Error(errMsg)
return errMsg, false return errMsg, false
} }
@@ -121,7 +115,7 @@ func (o *DbList) getNewConnection(connInfo *DbConnInfo, hash [32]byte) (string,
err = newDb.Ping() err = newDb.Ping()
if err != nil { if err != nil {
errMsg := "Just created SQL connection is dead" errMsg := "Just created SQL connection is dead"
utils.Log.WithError(err).Error(errMsg) app.Log.WithError(err).Error(errMsg)
return errMsg, false return errMsg, false
} }
@@ -135,7 +129,7 @@ func (o *DbList) getNewConnection(connInfo *DbConnInfo, hash [32]byte) (string,
o.items.Store(newId, &newItem) o.items.Store(newId, &newItem)
utils.Log.WithFields(logrus.Fields{ app.Log.WithFields(logrus.Fields{
"Host": connInfo.Host, "Host": connInfo.Host,
"Port": connInfo.Port, "Port": connInfo.Port,
"dbName": connInfo.DbName, "dbName": connInfo.DbName,
@@ -153,7 +147,7 @@ func (o *DbList) RunMaintenance() {
for { for {
<-ticker.C <-ticker.C
utils.Log.Debug("Regular task: checking if pooled SQL connections are alive...") app.Log.Debug("Regular task: checking if pooled SQL connections are alive...")
// detect dead connections // detect dead connections
var deadItems []string var deadItems []string
@@ -173,7 +167,7 @@ func (o *DbList) RunMaintenance() {
conn.Close() conn.Close()
o.Delete(item) o.Delete(item)
} }
utils.Log.Debug(fmt.Sprintf("Regular task: %d dead connections removed", len(deadItems))) app.Log.Debug(fmt.Sprintf("Regular task: %d dead connections removed", len(deadItems)))
} }
} }

38
src/db/types.go Normal file
View File

@@ -0,0 +1,38 @@
package db
import (
"database/sql"
"sync"
"time"
)
// Class model to keep open SQL connections in the pool
// with concurrent read/write access
type DbList struct {
items sync.Map
}
// Keeps SQL Db connection information
type DbConn struct {
Hash [32]byte // Hash, as sql.DB does not store credentials
DB *sql.DB // SQL server connection pool (provided by the driver)
Timestamp time.Time // Last use (TO DO!)
Stmt []DbStmt // Prepared SQL statements
}
// Keeps SQL prepared statement information
type DbStmt struct {
Id string
Stmt *sql.Stmt
}
// Keeps SQL connection string information
type DbConnInfo struct {
DbType string `json:"db_type"`
Host string `json:"host"`
Port uint16 `json:"port"`
User string `json:"user"`
Password string `json:"password"`
DbName string `json:"db_name"`
SSL bool `json:"ssl"`
}

View File

@@ -1,7 +0,0 @@
package db
// Global vars
var (
DbHandler DbList
MaxRows uint32 = 10000
)

View File

@@ -4,8 +4,8 @@ import (
"encoding/json" "encoding/json"
"io" "io"
"net/http" "net/http"
"sql-proxy/src/app"
"sql-proxy/src/db" "sql-proxy/src/db"
"sql-proxy/src/utils"
) )
func CreateConnection(w http.ResponseWriter, r *http.Request) { func CreateConnection(w http.ResponseWriter, r *http.Request) {
@@ -14,16 +14,16 @@ func CreateConnection(w http.ResponseWriter, r *http.Request) {
err := json.NewDecoder(r.Body).Decode(&dbConnInfo) err := json.NewDecoder(r.Body).Decode(&dbConnInfo)
if err != nil { if err != nil {
errorMsg := "Error decoding JSON" errorMsg := "Error decoding JSON"
utils.Log.Error(errorMsg) app.Log.Error(errorMsg)
http.Error(w, errorMsg, http.StatusBadRequest) http.Error(w, errorMsg, http.StatusBadRequest)
return return
} }
connGuid, ok := db.DbHandler.GetByParams(&dbConnInfo) connGuid, ok := db.Handler.GetByParams(&dbConnInfo)
if !ok { if !ok {
errorMsg := "Failed to get SQL connection" errorMsg := "Failed to get SQL connection"
utils.Log.Error(errorMsg) app.Log.Error(errorMsg)
http.Error(w, errorMsg, http.StatusInternalServerError) http.Error(w, errorMsg, http.StatusInternalServerError)
} else { } else {
_, err := w.Write([]byte(connGuid)) _, err := w.Write([]byte(connGuid))
@@ -40,5 +40,5 @@ func CloseConnection(w http.ResponseWriter, r *http.Request) {
return return
} }
defer r.Body.Close() defer r.Body.Close()
db.DbHandler.Delete(string(bodyBytes)) db.Handler.Delete(string(bodyBytes))
} }

View File

@@ -5,39 +5,39 @@ import (
"encoding/json" "encoding/json"
"net/http" "net/http"
"sql-proxy/src/app"
"sql-proxy/src/db" "sql-proxy/src/db"
"sql-proxy/src/utils"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
func GetQuery(w http.ResponseWriter, r *http.Request) { func SelectQuery(w http.ResponseWriter, r *http.Request) {
query := r.URL.Query().Get("query") query := r.URL.Query().Get("query")
conn := r.URL.Query().Get("conn") conn := r.URL.Query().Get("conn")
if query == "" || conn == "" { if query == "" || conn == "" {
errorText := "Missing parameter" errorText := "Missing parameter"
utils.Log.Error(errorText) app.Log.Error(errorText)
http.Error(w, errorText, http.StatusBadRequest) http.Error(w, errorText, http.StatusBadRequest)
return return
} }
utils.Log.WithFields(logrus.Fields{ app.Log.WithFields(logrus.Fields{
"query": query, "query": query,
"conn": conn, "conn": conn,
}).Debug("SQL query received:") }).Debug("SQL query received:")
// Search existings connection in the pool // Search existings connection in the pool
dbConn, ok := db.DbHandler.GetById(conn, true) dbConn, ok := db.Handler.GetById(conn, true)
if !ok { if !ok {
errorText := "Failed to get SQL connection" errorText := "Failed to get SQL connection"
utils.Log.Error(errorText, ": ", conn) app.Log.Error(errorText, ": ", conn)
http.Error(w, errorText, http.StatusForbidden) http.Error(w, errorText, http.StatusForbidden)
return return
} }
rows, err := dbConn.Query(query) rows, err := dbConn.Query(query)
if err != nil { if err != nil {
utils.Log.WithError(err).Error("SQL query error") app.Log.WithError(err).Error("SQL query error")
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
return return
} }
@@ -45,7 +45,7 @@ func GetQuery(w http.ResponseWriter, r *http.Request) {
columns, err := rows.Columns() columns, err := rows.Columns()
if err != nil { if err != nil {
utils.Log.WithError(err).Error("Invalid query return value") app.Log.WithError(err).Error("Invalid query return value")
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
return return
} }
@@ -68,7 +68,7 @@ func ExecuteQuery(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Invalid JSON payload", http.StatusBadRequest) http.Error(w, "Invalid JSON payload", http.StatusBadRequest)
return return
} }
conn, ok := db.DbHandler.GetById(payload.Conn, true) conn, ok := db.Handler.GetById(payload.Conn, true)
if !ok { if !ok {
http.Error(w, "Invalid connection id", http.StatusBadRequest) http.Error(w, "Invalid connection id", http.StatusBadRequest)
return return

37
src/handlers/statement.go Normal file
View File

@@ -0,0 +1,37 @@
package handlers
import (
"encoding/json"
"net/http"
)
func PrepareStatement(w http.ResponseWriter, r *http.Request) {
var payload ExecuteQueryEnvelope
err := json.NewDecoder(r.Body).Decode(&payload)
if err != nil {
http.Error(w, "Invalid JSON payload", http.StatusBadRequest)
return
}
/*
conn, ok := app.DbHandler.GetById(payload.Conn, true)
if !ok {
http.Error(w, "Invalid connection id", http.StatusBadRequest)
return
}
stmt, err := conn.Prepare(payload.SQL)
if err != nil {
http.Error(w, "Failed to prepare statement", http.StatusBadRequest)
}
*/
}
func SelectStatement(w http.ResponseWriter, r *http.Request) {
// to do
}
func ExecuteStatement(w http.ResponseWriter, r *http.Request) {
// to do
}

View File

@@ -5,10 +5,9 @@ import (
"net/http" "net/http"
"os" "os"
"sql-proxy/src/app"
"sql-proxy/src/db" "sql-proxy/src/db"
"sql-proxy/src/handlers" "sql-proxy/src/handlers"
"sql-proxy/src/utils"
"sql-proxy/src/version"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/prometheus/client_golang/prometheus/promhttp" "github.com/prometheus/client_golang/prometheus/promhttp"
@@ -19,32 +18,35 @@ func main() {
var err error var err error
// Application params taken from OS environment // Application params taken from OS environment
utils.Log.SetLevel(logrus.Level(utils.GetIntEnvOrDefault("LOG_LEVEL", 2))) app.Log.SetLevel(logrus.Level(app.GetIntEnvOrDefault("LOG_LEVEL", 2)))
bindAddress := os.Getenv("BIND_ADDR") bindAddress := os.Getenv("BIND_ADDR")
bindPort := utils.GetIntEnvOrDefault("BIND_PORT", 8080) bindPort := app.GetIntEnvOrDefault("BIND_PORT", 8080)
db.MaxRows = utils.GetIntEnvOrDefault("MAX_ROWS", 10000) db.MaxRows = app.GetIntEnvOrDefault("MAX_ROWS", 10000)
tlsCert := os.Getenv("TLS_CERT") tlsCert := os.Getenv("TLS_CERT")
tlsKey := os.Getenv("TLS_KEY") tlsKey := os.Getenv("TLS_KEY")
// Scheduled maintenance task // Scheduled maintenance task
go db.DbHandler.RunMaintenance() go db.Handler.RunMaintenance()
router := mux.NewRouter() router := mux.NewRouter()
router.HandleFunc("/api/v1/query/select", handlers.GetQuery).Methods("GET")
router.HandleFunc("/api/v1/query/execute", handlers.ExecuteQuery).Methods("POST")
router.HandleFunc("/api/v1/connection/create", handlers.CreateConnection).Methods("POST") router.HandleFunc("/api/v1/connection/create", handlers.CreateConnection).Methods("POST")
router.HandleFunc("/api/v1/connection/delete", handlers.CloseConnection).Methods("DELETE") router.HandleFunc("/api/v1/connection/delete", handlers.CloseConnection).Methods("DELETE")
router.HandleFunc("/api/v1/query/select", handlers.SelectQuery).Methods("GET")
router.HandleFunc("/api/v1/query/execute", handlers.ExecuteQuery).Methods("POST")
router.HandleFunc("/api/v1/statement/prepare", handlers.PrepareStatement).Methods("POST")
router.HandleFunc("/api/v1/statement/select", handlers.SelectStatement).Methods("GET")
router.HandleFunc("/api/v1/statement/execute", handlers.ExecuteStatement).Methods("POST")
router.HandleFunc("/healthz", handlers.Healthz).Methods("GET") router.HandleFunc("/healthz", handlers.Healthz).Methods("GET")
router.HandleFunc("/readyz", handlers.Readyz).Methods("GET") router.HandleFunc("/readyz", handlers.Readyz).Methods("GET")
router.HandleFunc("/livez", handlers.Livez).Methods("GET") router.HandleFunc("/livez", handlers.Livez).Methods("GET")
router.Handle("/metrics", promhttp.Handler()) router.Handle("/metrics", promhttp.Handler())
utils.Log.WithFields(logrus.Fields{ app.Log.WithFields(logrus.Fields{
"build_version": version.BuildVersion, "build_version": app.BuildVersion,
"build_time": version.BuildTime, "build_time": app.BuildTime,
}).Info("Starting server sql-proxy:") }).Info("Starting server sql-proxy:")
utils.Log.WithFields(logrus.Fields{ app.Log.WithFields(logrus.Fields{
"bind_port": bindPort, "bind_port": bindPort,
"bind_address": bindAddress, "bind_address": bindAddress,
"tls_cert": tlsCert, "tls_cert": tlsCert,
@@ -58,6 +60,6 @@ func main() {
err = http.ListenAndServe(addr, router) err = http.ListenAndServe(addr, router)
} }
if err != nil { if err != nil {
utils.Log.WithError(err).Fatal("Fatal error occurred, service stopped") app.Log.WithError(err).Fatal("Fatal error occurred, service stopped")
} }
} }