1
0
mirror of https://github.com/alm494/sql_proxy.git synced 2025-10-08 22:01:51 +02:00

src reorganization

This commit is contained in:
Almaz Sharipov
2025-02-16 16:06:08 +03:00
parent f2a81d46bf
commit 6761cdebd2
16 changed files with 136 additions and 84 deletions

View File

@@ -35,8 +35,8 @@ prod: clean
@echo "Building $(PROJECT_NAME) or production..."
GOOS=${GOOS} GOARCH=${GOARCH} go build $(TAGS) \
-ldflags="-s -w \
-X ${PROJECT_NAME}/src/version.BuildVersion=${BUILD_VERSION} \
-X ${PROJECT_NAME}/src/version.BuildTime=${BUILD_TIME}" -o $(BUILD_DIR)/$(PROJECT_NAME) $(GO_FILES)
-X ${PROJECT_NAME}/src/app.BuildVersion=${BUILD_VERSION} \
-X ${PROJECT_NAME}/src/app.BuildTime=${BUILD_TIME}" -o $(BUILD_DIR)/$(PROJECT_NAME) $(GO_FILES)
@echo "Production build completed."
# Build for debugging
@@ -44,8 +44,8 @@ debug: clean
@echo "Building $(PROJECT_NAME) or production..."
GOOS=${GOOS} GOARCH=${GOARCH} go build $(TAGS) \
-ldflags="\
-X ${PROJECT_NAME}/src/version.BuildVersion=${BUILD_VERSION}-debug \
-X ${PROJECT_NAME}/src/version.BuildTime=${BUILD_TIME}" -o $(BUILD_DIR)/$(PROJECT_NAME)-debug $(GO_FILES)
-X ${PROJECT_NAME}/src/app.BuildVersion=${BUILD_VERSION}-debug \
-X ${PROJECT_NAME}/src/app.BuildTime=${BUILD_TIME}" -o $(BUILD_DIR)/$(PROJECT_NAME)-debug $(GO_FILES)
@echo "Debug build completed."
# Run

3
go.mod
View File

@@ -13,12 +13,15 @@ require (
golang.org/x/crypto v0.33.0
)
require github.com/gorilla/securecookie v1.1.2 // indirect
require (
filippo.io/edwards25519 v1.1.0 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
github.com/golang-sql/sqlexp v0.1.0 // indirect
github.com/gorilla/sessions v1.4.0
github.com/klauspost/compress v1.17.11 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/prometheus/client_model v0.6.1 // indirect

4
go.sum
View File

@@ -24,6 +24,10 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ=
github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA=
github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
github.com/gorilla/sessions v1.4.0 h1:kpIYOp/oi6MG/p5PgxApU8srsSw9tuFbt46Lt7auzqQ=
github.com/gorilla/sessions v1.4.0/go.mod h1:FLWm50oby91+hl7p/wRxDth9bWSuk0qVL2emc7lT5ik=
github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc=
github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=

View File

@@ -1,4 +1,4 @@
package utils
package app
import (
"os"

View File

@@ -1,4 +1,4 @@
package utils
package app
import (
"os"

View File

@@ -1,4 +1,4 @@
package version
package app
var (
BuildTime = "none"

6
src/db/db.go Normal file
View File

@@ -0,0 +1,6 @@
package db
var (
Handler DbList
MaxRows uint32 = 10000
)

View File

@@ -1,15 +0,0 @@
package db
import (
"database/sql"
"time"
)
type DbConn struct {
// Hash, as sql.DB does not store credentials
Hash [32]byte
// SQL server connection pool (provided by the driver)
DB *sql.DB
// Last use
Timestamp time.Time
}

View File

@@ -6,16 +6,6 @@ import (
"encoding/gob"
)
type DbConnInfo struct {
DbType string `json:"db_type"`
Host string `json:"host"`
Port uint16 `json:"port"`
User string `json:"user"`
Password string `json:"password"`
DbName string `json:"db_name"`
SSL bool `json:"ssl"`
}
func (o DbConnInfo) GetHash() ([32]byte, error) {
var buf bytes.Buffer
var hash [32]byte

View File

@@ -5,20 +5,14 @@ import (
"database/sql"
"fmt"
"net/url"
"sql-proxy/src/utils"
"sync"
"sql-proxy/src/app"
"time"
"github.com/google/uuid"
"github.com/sirupsen/logrus"
)
// Class model to keep open SQL connections in the pool
// with concurrent read/write access
type DbList struct {
items sync.Map
}
// Gets SQL server connection by GUID
func (o *DbList) GetById(guid string, updateTimestamp bool) (*sql.DB, bool) {
val, ok := o.items.Load(guid)
@@ -30,7 +24,7 @@ func (o *DbList) GetById(guid string, updateTimestamp bool) (*sql.DB, bool) {
}
return res.DB, true
}
utils.Log.Error(fmt.Sprintf("SQL connection with guid='%s' not found", guid))
app.Log.Error(fmt.Sprintf("SQL connection with guid='%s' not found", guid))
return nil, false
}
@@ -40,7 +34,7 @@ func (o *DbList) GetByParams(connInfo *DbConnInfo) (string, bool) {
hash, err := connInfo.GetHash()
if err != nil {
errMsg := "Hash calculation failed"
utils.Log.WithError(err).Error(errMsg)
app.Log.WithError(err).Error(errMsg)
return errMsg, false
}
@@ -50,7 +44,7 @@ func (o *DbList) GetByParams(connInfo *DbConnInfo) (string, bool) {
func(key, value interface{}) bool {
if bytes.Equal(value.(*DbConn).Hash[:], hash[:]) {
guid = key.(string)
utils.Log.Debug(fmt.Sprintf("DB connection with id %s found in the pool", guid))
app.Log.Debug(fmt.Sprintf("DB connection with id %s found in the pool", guid))
return false // stop iteraton
}
return true // continue iteration
@@ -67,7 +61,7 @@ func (o *DbList) GetByParams(connInfo *DbConnInfo) (string, bool) {
} else {
// Remove dead connection from the pool
o.items.Delete(guid)
utils.Log.Debug(fmt.Sprintf("DB connection with id %s is dead and removed from the pool", guid))
app.Log.Debug(fmt.Sprintf("DB connection with id %s is dead and removed from the pool", guid))
}
}
@@ -100,7 +94,7 @@ func (o *DbList) getNewConnection(connInfo *DbConnInfo, hash [32]byte) (string,
connInfo.User, encodedPassword, connInfo.Host, connInfo.Port, connInfo.DbName)
default:
errMsg := fmt.Sprintf("No suitable driver implemented for server type '%s'", connInfo.DbType)
utils.Log.Error(errMsg)
app.Log.Error(errMsg)
return errMsg, false
}
@@ -113,7 +107,7 @@ func (o *DbList) getNewConnection(connInfo *DbConnInfo, hash [32]byte) (string,
// 3. Check for failure
if err != nil {
errMsg := "Error establishing SQL server connection"
utils.Log.WithError(err).Error(errMsg)
app.Log.WithError(err).Error(errMsg)
return errMsg, false
}
@@ -121,7 +115,7 @@ func (o *DbList) getNewConnection(connInfo *DbConnInfo, hash [32]byte) (string,
err = newDb.Ping()
if err != nil {
errMsg := "Just created SQL connection is dead"
utils.Log.WithError(err).Error(errMsg)
app.Log.WithError(err).Error(errMsg)
return errMsg, false
}
@@ -135,7 +129,7 @@ func (o *DbList) getNewConnection(connInfo *DbConnInfo, hash [32]byte) (string,
o.items.Store(newId, &newItem)
utils.Log.WithFields(logrus.Fields{
app.Log.WithFields(logrus.Fields{
"Host": connInfo.Host,
"Port": connInfo.Port,
"dbName": connInfo.DbName,
@@ -153,7 +147,7 @@ func (o *DbList) RunMaintenance() {
for {
<-ticker.C
utils.Log.Debug("Regular task: checking if pooled SQL connections are alive...")
app.Log.Debug("Regular task: checking if pooled SQL connections are alive...")
// detect dead connections
var deadItems []string
@@ -173,7 +167,7 @@ func (o *DbList) RunMaintenance() {
conn.Close()
o.Delete(item)
}
utils.Log.Debug(fmt.Sprintf("Regular task: %d dead connections removed", len(deadItems)))
app.Log.Debug(fmt.Sprintf("Regular task: %d dead connections removed", len(deadItems)))
}
}

38
src/db/types.go Normal file
View File

@@ -0,0 +1,38 @@
package db
import (
"database/sql"
"sync"
"time"
)
// Class model to keep open SQL connections in the pool
// with concurrent read/write access
type DbList struct {
items sync.Map
}
// Keeps SQL Db connection information
type DbConn struct {
Hash [32]byte // Hash, as sql.DB does not store credentials
DB *sql.DB // SQL server connection pool (provided by the driver)
Timestamp time.Time // Last use (TO DO!)
Stmt []DbStmt // Prepared SQL statements
}
// Keeps SQL prepared statement information
type DbStmt struct {
Id string
Stmt *sql.Stmt
}
// Keeps SQL connection string information
type DbConnInfo struct {
DbType string `json:"db_type"`
Host string `json:"host"`
Port uint16 `json:"port"`
User string `json:"user"`
Password string `json:"password"`
DbName string `json:"db_name"`
SSL bool `json:"ssl"`
}

View File

@@ -1,7 +0,0 @@
package db
// Global vars
var (
DbHandler DbList
MaxRows uint32 = 10000
)

View File

@@ -4,8 +4,8 @@ import (
"encoding/json"
"io"
"net/http"
"sql-proxy/src/app"
"sql-proxy/src/db"
"sql-proxy/src/utils"
)
func CreateConnection(w http.ResponseWriter, r *http.Request) {
@@ -14,16 +14,16 @@ func CreateConnection(w http.ResponseWriter, r *http.Request) {
err := json.NewDecoder(r.Body).Decode(&dbConnInfo)
if err != nil {
errorMsg := "Error decoding JSON"
utils.Log.Error(errorMsg)
app.Log.Error(errorMsg)
http.Error(w, errorMsg, http.StatusBadRequest)
return
}
connGuid, ok := db.DbHandler.GetByParams(&dbConnInfo)
connGuid, ok := db.Handler.GetByParams(&dbConnInfo)
if !ok {
errorMsg := "Failed to get SQL connection"
utils.Log.Error(errorMsg)
app.Log.Error(errorMsg)
http.Error(w, errorMsg, http.StatusInternalServerError)
} else {
_, err := w.Write([]byte(connGuid))
@@ -40,5 +40,5 @@ func CloseConnection(w http.ResponseWriter, r *http.Request) {
return
}
defer r.Body.Close()
db.DbHandler.Delete(string(bodyBytes))
db.Handler.Delete(string(bodyBytes))
}

View File

@@ -5,39 +5,39 @@ import (
"encoding/json"
"net/http"
"sql-proxy/src/app"
"sql-proxy/src/db"
"sql-proxy/src/utils"
"github.com/sirupsen/logrus"
)
func GetQuery(w http.ResponseWriter, r *http.Request) {
func SelectQuery(w http.ResponseWriter, r *http.Request) {
query := r.URL.Query().Get("query")
conn := r.URL.Query().Get("conn")
if query == "" || conn == "" {
errorText := "Missing parameter"
utils.Log.Error(errorText)
app.Log.Error(errorText)
http.Error(w, errorText, http.StatusBadRequest)
return
}
utils.Log.WithFields(logrus.Fields{
app.Log.WithFields(logrus.Fields{
"query": query,
"conn": conn,
}).Debug("SQL query received:")
// Search existings connection in the pool
dbConn, ok := db.DbHandler.GetById(conn, true)
dbConn, ok := db.Handler.GetById(conn, true)
if !ok {
errorText := "Failed to get SQL connection"
utils.Log.Error(errorText, ": ", conn)
app.Log.Error(errorText, ": ", conn)
http.Error(w, errorText, http.StatusForbidden)
return
}
rows, err := dbConn.Query(query)
if err != nil {
utils.Log.WithError(err).Error("SQL query error")
app.Log.WithError(err).Error("SQL query error")
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
@@ -45,7 +45,7 @@ func GetQuery(w http.ResponseWriter, r *http.Request) {
columns, err := rows.Columns()
if err != nil {
utils.Log.WithError(err).Error("Invalid query return value")
app.Log.WithError(err).Error("Invalid query return value")
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
@@ -68,7 +68,7 @@ func ExecuteQuery(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Invalid JSON payload", http.StatusBadRequest)
return
}
conn, ok := db.DbHandler.GetById(payload.Conn, true)
conn, ok := db.Handler.GetById(payload.Conn, true)
if !ok {
http.Error(w, "Invalid connection id", http.StatusBadRequest)
return

37
src/handlers/statement.go Normal file
View File

@@ -0,0 +1,37 @@
package handlers
import (
"encoding/json"
"net/http"
)
func PrepareStatement(w http.ResponseWriter, r *http.Request) {
var payload ExecuteQueryEnvelope
err := json.NewDecoder(r.Body).Decode(&payload)
if err != nil {
http.Error(w, "Invalid JSON payload", http.StatusBadRequest)
return
}
/*
conn, ok := app.DbHandler.GetById(payload.Conn, true)
if !ok {
http.Error(w, "Invalid connection id", http.StatusBadRequest)
return
}
stmt, err := conn.Prepare(payload.SQL)
if err != nil {
http.Error(w, "Failed to prepare statement", http.StatusBadRequest)
}
*/
}
func SelectStatement(w http.ResponseWriter, r *http.Request) {
// to do
}
func ExecuteStatement(w http.ResponseWriter, r *http.Request) {
// to do
}

View File

@@ -5,10 +5,9 @@ import (
"net/http"
"os"
"sql-proxy/src/app"
"sql-proxy/src/db"
"sql-proxy/src/handlers"
"sql-proxy/src/utils"
"sql-proxy/src/version"
"github.com/gorilla/mux"
"github.com/prometheus/client_golang/prometheus/promhttp"
@@ -19,32 +18,35 @@ func main() {
var err error
// Application params taken from OS environment
utils.Log.SetLevel(logrus.Level(utils.GetIntEnvOrDefault("LOG_LEVEL", 2)))
app.Log.SetLevel(logrus.Level(app.GetIntEnvOrDefault("LOG_LEVEL", 2)))
bindAddress := os.Getenv("BIND_ADDR")
bindPort := utils.GetIntEnvOrDefault("BIND_PORT", 8080)
db.MaxRows = utils.GetIntEnvOrDefault("MAX_ROWS", 10000)
bindPort := app.GetIntEnvOrDefault("BIND_PORT", 8080)
db.MaxRows = app.GetIntEnvOrDefault("MAX_ROWS", 10000)
tlsCert := os.Getenv("TLS_CERT")
tlsKey := os.Getenv("TLS_KEY")
// Scheduled maintenance task
go db.DbHandler.RunMaintenance()
go db.Handler.RunMaintenance()
router := mux.NewRouter()
router.HandleFunc("/api/v1/query/select", handlers.GetQuery).Methods("GET")
router.HandleFunc("/api/v1/query/execute", handlers.ExecuteQuery).Methods("POST")
router.HandleFunc("/api/v1/connection/create", handlers.CreateConnection).Methods("POST")
router.HandleFunc("/api/v1/connection/delete", handlers.CloseConnection).Methods("DELETE")
router.HandleFunc("/api/v1/query/select", handlers.SelectQuery).Methods("GET")
router.HandleFunc("/api/v1/query/execute", handlers.ExecuteQuery).Methods("POST")
router.HandleFunc("/api/v1/statement/prepare", handlers.PrepareStatement).Methods("POST")
router.HandleFunc("/api/v1/statement/select", handlers.SelectStatement).Methods("GET")
router.HandleFunc("/api/v1/statement/execute", handlers.ExecuteStatement).Methods("POST")
router.HandleFunc("/healthz", handlers.Healthz).Methods("GET")
router.HandleFunc("/readyz", handlers.Readyz).Methods("GET")
router.HandleFunc("/livez", handlers.Livez).Methods("GET")
router.Handle("/metrics", promhttp.Handler())
utils.Log.WithFields(logrus.Fields{
"build_version": version.BuildVersion,
"build_time": version.BuildTime,
app.Log.WithFields(logrus.Fields{
"build_version": app.BuildVersion,
"build_time": app.BuildTime,
}).Info("Starting server sql-proxy:")
utils.Log.WithFields(logrus.Fields{
app.Log.WithFields(logrus.Fields{
"bind_port": bindPort,
"bind_address": bindAddress,
"tls_cert": tlsCert,
@@ -58,6 +60,6 @@ func main() {
err = http.ListenAndServe(addr, router)
}
if err != nil {
utils.Log.WithError(err).Fatal("Fatal error occurred, service stopped")
app.Log.WithError(err).Fatal("Fatal error occurred, service stopped")
}
}