commit cc53828937edeb61573b5cd19c8a0b0ecb4fb5b4 Author: almaz Date: Mon Feb 10 15:45:39 2025 +0300 Initial commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..95d7b3d --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +bin/ +.vscode/ +go.sum diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..15dcc98 --- /dev/null +++ b/Makefile @@ -0,0 +1,59 @@ +PROJECT_NAME := sql-proxy +BUILD_VERSION := 1.1 +BUILD_TIME := $(shell date -u '+%Y-%m-%d_%H:%M:%S') +BUILD_DIR := bin +GO_FILES := src/main.go + +# Build with SQL drivers, comment out if unused: +BUILD_WITH_POSTGRES_TAG := postgres +BUILD_WITH_MSSQL_TAG := sqlserver +BUILD_WITH_MYSQL_TAG := mysql + +# Go compiler basic settings +GOOS := linux +GOARCH := amd64 + +# Application settings to run: +LOG_LEVEL := 6 +BIND_PORT := 8080 +BIND_ADDR := localhost +MAX_ROWS := 10000 +#TLS_CERT := $(BUILD_DIR)/server.crt +#TLS_KEY := $(BUILD_DIR)/server.key + +TAGS := -tags=$(BUILD_WITH_POSTGRES_TAG),$(BUILD_WITH_MSSQL_TAG),$(BUILD_WITH_MYSQL_TAG) + +# Default +all: prod + +clean: + rm -f $(BUILD_DIR)/$(PROJECT_NAME) + rm -f $(BUILD_DIR)/$(PROJECT_NAME)-debug + +# Build for production +prod: clean + @echo "Building $(PROJECT_NAME) or production..." + GOOS=${GOOS} GOARCH=${GOARCH} go build $(TAGS) \ + -ldflags="-s -w \ + -X ${PROJECT_NAME}/src/version.BuildVersion=${BUILD_VERSION} \ + -X ${PROJECT_NAME}/src/version.BuildTime=${BUILD_TIME}" -o $(BUILD_DIR)/$(PROJECT_NAME) $(GO_FILES) + @echo "Production build completed." + +# Build for debugging +debug: clean + @echo "Building $(PROJECT_NAME) or production..." + GOOS=${GOOS} GOARCH=${GOARCH} go build $(TAGS) \ + -ldflags="\ + -X ${PROJECT_NAME}/src/version.BuildVersion=${BUILD_VERSION}-debug \ + -X ${PROJECT_NAME}/src/version.BuildTime=${BUILD_TIME}" -o $(BUILD_DIR)/$(PROJECT_NAME)-debug $(GO_FILES) + @echo "Debug build completed." + +# Run +run: debug + @echo "Running $(PROJECT_NAME) in debug mode..." + BIND_ADDR=$(BIND_ADDR) BIND_PORT=$(BIND_PORT) MAX_ROWS=$(MAX_ROWS) TLS_CERT=$(TLS_CERT) TLS_KEY=$(TLS_KEY) LOG_LEVEL=$(LOG_LEVEL) $(BUILD_DIR)/$(PROJECT_NAME)-debug + +# Run test +test: + @echo "Running tests..." + @go test ./... -v diff --git a/README.md b/README.md new file mode 100644 index 0000000..ef6a020 --- /dev/null +++ b/README.md @@ -0,0 +1,3 @@ +# SQL-PROXY + +Simple REST service to replace ADODB calls from legacy software \ No newline at end of file diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..dc855b8 --- /dev/null +++ b/go.mod @@ -0,0 +1,29 @@ +module sql-proxy + +go 1.23.5 + +require ( + github.com/google/uuid v1.6.0 + github.com/gorilla/mux v1.8.1 + github.com/lib/pq v1.10.9 + github.com/prometheus/client_golang v1.20.5 + github.com/sirupsen/logrus v1.9.3 + golang.org/x/crypto v0.32.0 +) + +require ( + filippo.io/edwards25519 v1.1.0 // indirect + github.com/beorn7/perks v1.0.1 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/denisenkom/go-mssqldb v0.12.3 // indirect + github.com/go-sql-driver/mysql v1.8.1 // indirect + github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe // indirect + github.com/golang-sql/sqlexp v0.1.0 // indirect + github.com/klauspost/compress v1.17.11 // indirect + github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect + github.com/prometheus/client_model v0.6.1 // indirect + github.com/prometheus/common v0.62.0 // indirect + github.com/prometheus/procfs v0.15.1 // indirect + golang.org/x/sys v0.29.0 // indirect + google.golang.org/protobuf v1.36.4 // indirect +) diff --git a/src/db/dbconn.go b/src/db/dbconn.go new file mode 100644 index 0000000..b9ed92d --- /dev/null +++ b/src/db/dbconn.go @@ -0,0 +1,15 @@ +package db + +import ( + "database/sql" + "time" +) + +type DbConn struct { + // Hash, as sql.DB does not store credentials + Hash [32]byte + // SQL server connection pool (provided by the driver) + DB *sql.DB + // Last use + Timestamp time.Time +} diff --git a/src/db/dbconninfo.go b/src/db/dbconninfo.go new file mode 100644 index 0000000..907ef91 --- /dev/null +++ b/src/db/dbconninfo.go @@ -0,0 +1,31 @@ +package db + +import ( + "bytes" + "crypto/sha256" + "encoding/gob" +) + +type DbConnInfo struct { + DbType string `json:"db_type"` + Host string `json:"host"` + Port uint16 `json:"port"` + User string `json:"user"` + Password string `json:"password"` + DbName string `json:"db_name"` + SSL bool `json:"ssl"` +} + +func (o DbConnInfo) GetHash() ([32]byte, error) { + var buf bytes.Buffer + var hash [32]byte + + enc := gob.NewEncoder(&buf) + err := enc.Encode(o) + if err != nil { + return hash, err + } + + hash = sha256.Sum256(buf.Bytes()) + return hash, nil +} diff --git a/src/db/dblist.go b/src/db/dblist.go new file mode 100644 index 0000000..12dc34d --- /dev/null +++ b/src/db/dblist.go @@ -0,0 +1,167 @@ +package db + +import ( + "bytes" + "database/sql" + "fmt" + "net/url" + "sql-proxy/src/utils" + "sync" + "time" + + "github.com/google/uuid" + "github.com/sirupsen/logrus" +) + +// Class model to keep open SQL connections in the pool +// with concurrent read/write access +type DbList struct { + items sync.Map +} + +// Gets SQL server connection by GUID +func (o *DbList) GetById(guid string, updateTimestamp bool) (*sql.DB, bool) { + val, ok := o.items.Load(guid) + if ok { + res := val.(*DbConn) + if updateTimestamp { + res.Timestamp = time.Now() + o.items.Store(guid, res) + } + return res.DB, true + } + utils.Log.Error(fmt.Sprintf("SQL connection with guid='%s' not found", guid)) + return nil, false +} + +// Gets the new SQL server connection with parameters given. +// First lookups in pool, if fails opens new one and returns GUID value +func (o *DbList) GetByParams(connInfo *DbConnInfo) (string, bool) { + hash, err := connInfo.GetHash() + if err != nil { + errMsg := "Hash calculation failed" + utils.Log.WithError(err).Error(errMsg) + return errMsg, false + } + + // Step 1. Search existing connection by hash to reuse + guid := "" + o.items.Range( + func(key, value interface{}) bool { + if bytes.Equal(value.(*DbConn).Hash[:], hash[:]) { + guid = key.(string) + utils.Log.Debug(fmt.Sprintf("DB connection with id %s found in the pool", guid)) + return false + } + return true + }) + + // Step 2. Perform checks and return guid if passed + if len(guid) > 0 { + conn, ok := o.items.Load(guid) + if ok { + err = conn.(*DbConn).DB.Ping() + if err == nil { + // Everything is ok, return guid + return guid, true + } else { + // Remove dead connection from the pool + o.items.Delete(guid) + utils.Log.Debug(fmt.Sprintf("DB connection with id %s is dead and removed from the pool", guid)) + } + + } + } + + // Step 3. Nothing found, create the new + return o.getNewConnection(connInfo, hash) +} + +// Creates the new SQL connection regarding concurrency +func (o *DbList) getNewConnection(connInfo *DbConnInfo, hash [32]byte) (string, bool) { + // 1. Prepare DSN string + var dsn string + + encodedPassword := url.QueryEscape(connInfo.Password) + + switch connInfo.DbType { + case "postgres": + sslMode := "enable" + if !connInfo.SSL { + sslMode = "disable" + } + dsn = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", + connInfo.Host, connInfo.Port, connInfo.User, encodedPassword, connInfo.DbName, sslMode) + case "sqlserver": + dsn = fmt.Sprintf("server=%s;user id=%s;password=%s;database=%s;port=%d", + connInfo.Host, connInfo.User, encodedPassword, connInfo.DbName, connInfo.Port) + case "mysql": + dsn = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", + connInfo.User, encodedPassword, connInfo.Host, connInfo.Port, connInfo.DbName) + default: + errMsg := fmt.Sprintf("No suitable driver implemented for server type '%s'", connInfo.DbType) + utils.Log.Error(errMsg) + return errMsg, false + } + + // 2. Open new SQL server connection + var err error + var newDb *sql.DB + + newDb, err = sql.Open(connInfo.DbType, dsn) + + // 3. Check for failure + if err != nil { + errMsg := "Error establishing SQL server connection" + utils.Log.WithError(err).Error(errMsg) + return errMsg, false + } + + // 4. Check if alive + err = newDb.Ping() + if err != nil { + errMsg := "Just created SQL connection is dead" + utils.Log.WithError(err).Error(errMsg) + return errMsg, false + } + + // 5. Insert into pool + newId := uuid.New().String() + newItem := DbConn{ + Hash: hash, + DB: newDb, + Timestamp: time.Now(), + } + + o.items.Store(newId, &newItem) + + utils.Log.WithFields(logrus.Fields{ + "Host": connInfo.Host, + "Port": connInfo.Port, + "dbName": connInfo.DbName, + "user": connInfo.User, + "dbType": connInfo.DbType, + "Id": newId, + }).Info(fmt.Sprintf("New SQL connection with id %s was added to the pool", newId)) + + return newId, true +} + +func (o *DbList) RunMaintenance() { + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() + + for { + <-ticker.C + utils.Log.Debug("Regular task: checking if pooled SQL connections are alive...") + + // to do: detect and remove dead connections + //for o.items.Range() + } +} + +func (o *DbList) Delete(id string) { + + o.items.Delete(id) + +} diff --git a/src/db/mysql.go b/src/db/mysql.go new file mode 100644 index 0000000..65343be --- /dev/null +++ b/src/db/mysql.go @@ -0,0 +1,6 @@ +//go:build mysql +// +build mysql + +package db + +import _ "github.com/go-sql-driver/mysql" diff --git a/src/db/postges.go b/src/db/postges.go new file mode 100644 index 0000000..38fb375 --- /dev/null +++ b/src/db/postges.go @@ -0,0 +1,6 @@ +//go:build postgres +// +build postgres + +package db + +import _ "github.com/lib/pq" diff --git a/src/db/sqlserver.go b/src/db/sqlserver.go new file mode 100644 index 0000000..15c8a83 --- /dev/null +++ b/src/db/sqlserver.go @@ -0,0 +1,6 @@ +//go:build sqlserver +// +build sqlserver + +package db + +import _ "github.com/denisenkom/go-mssqldb" diff --git a/src/db/vars.go b/src/db/vars.go new file mode 100644 index 0000000..e9dac01 --- /dev/null +++ b/src/db/vars.go @@ -0,0 +1,7 @@ +package db + +// Global vars +var ( + DbHandler DbList + MaxRows uint32 = 10000 +) diff --git a/src/handlers/connection.go b/src/handlers/connection.go new file mode 100644 index 0000000..abacdf8 --- /dev/null +++ b/src/handlers/connection.go @@ -0,0 +1,37 @@ +package handlers + +import ( + "encoding/json" + "net/http" + "sql-proxy/src/db" + "sql-proxy/src/utils" +) + +func CreateConnection(w http.ResponseWriter, r *http.Request) { + var dbConnInfo db.DbConnInfo + + err := json.NewDecoder(r.Body).Decode(&dbConnInfo) + if err != nil { + errorMsg := "Error decoding JSON" + utils.Log.Error(errorMsg) + http.Error(w, errorMsg, http.StatusBadRequest) + return + } + + connGuid, ok := db.DbHandler.GetByParams(&dbConnInfo) + + if !ok { + errorMsg := "Failed to get SQL connection" + utils.Log.Error(errorMsg) + http.Error(w, errorMsg, http.StatusInternalServerError) + } else { + _, err := w.Write([]byte(connGuid)) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + } +} + +func CloseConnection(w http.ResponseWriter, r *http.Request) { + // to do +} diff --git a/src/handlers/envelope.go b/src/handlers/envelope.go new file mode 100644 index 0000000..e52ccb0 --- /dev/null +++ b/src/handlers/envelope.go @@ -0,0 +1,10 @@ +package handlers + +type ResponseEnvelope struct { + ApiVersion uint8 `json:"api_version"` + ConnectionId string `json:"connection_id"` + Info string `json:"info"` + RowsCount uint32 `json:"rows_count"` + ExceedsMaxRows bool `json:"exceeds_max_rows"` + Rows []map[string]interface{} +} diff --git a/src/handlers/k8s.go b/src/handlers/k8s.go new file mode 100644 index 0000000..d350230 --- /dev/null +++ b/src/handlers/k8s.go @@ -0,0 +1,21 @@ +package handlers + +import ( + "net/http" +) + +// Deprecated +func Healthz(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) +} + +func Readyz(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("Ready")) +} + +func Livez(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("Live")) +} diff --git a/src/handlers/query.go b/src/handlers/query.go new file mode 100644 index 0000000..7747675 --- /dev/null +++ b/src/handlers/query.go @@ -0,0 +1,102 @@ +package handlers + +import ( + "database/sql" + "encoding/json" + "net/http" + + "sql-proxy/src/db" + "sql-proxy/src/utils" + + "github.com/sirupsen/logrus" +) + +func GetQuery(w http.ResponseWriter, r *http.Request) { + query := r.URL.Query().Get("query") + conn := r.URL.Query().Get("conn") + if query == "" || conn == "" { + errorText := "Missing parameter" + utils.Log.Error(errorText) + http.Error(w, errorText, http.StatusBadRequest) + return + } + + utils.Log.WithFields(logrus.Fields{ + "query": query, + "conn": conn, + }).Debug("SQL query received:") + + // Search existings connection in the pool + dbConn, ok := db.DbHandler.GetById(conn, true) + if !ok { + errorText := "Failed to get SQL connection" + utils.Log.Error(errorText, ": ", conn) + http.Error(w, errorText, http.StatusForbidden) + return + } + + rows, err := dbConn.Query(query) + if err != nil { + utils.Log.WithError(err).Error("SQL query error") + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + defer rows.Close() + + columns, err := rows.Columns() + if err != nil { + utils.Log.WithError(err).Error("Invalid query return value") + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + tableData, rowsCount, exceedsMaxRows := convertRows(rows, &columns) + + var envelope ResponseEnvelope + envelope.RowsCount = rowsCount + envelope.ExceedsMaxRows = exceedsMaxRows + envelope.Rows = *tableData + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(envelope) +} + +func ExecuteQuery(w http.ResponseWriter, r *http.Request) { + // to do +} + +// Converts SQL query result to json +func convertRows(rows *sql.Rows, columns *[]string) (*[]map[string]interface{}, uint32, bool) { + var rowsCount uint32 = 0 + colsCount := len(*columns) + tableData := make([]map[string]interface{}, 0) + values := make([]interface{}, colsCount) + valuePtrs := make([]interface{}, colsCount) + exceedsMaxRows := false + + for rows.Next() { + for i := range *columns { + valuePtrs[i] = &values[i] + } + rows.Scan(valuePtrs...) + entry := make(map[string]interface{}) + for i, col := range *columns { + var v interface{} + val := values[i] + b, ok := val.([]byte) + if ok { + v = string(b) + } else { + v = val + } + entry[col] = v + } + if rowsCount > db.MaxRows { + exceedsMaxRows = true + break + } + tableData = append(tableData, entry) + rowsCount++ + } + return &tableData, rowsCount, exceedsMaxRows +} diff --git a/src/main.go b/src/main.go new file mode 100644 index 0000000..ae23054 --- /dev/null +++ b/src/main.go @@ -0,0 +1,63 @@ +package main + +import ( + "fmt" + "net/http" + "os" + + "sql-proxy/src/db" + "sql-proxy/src/handlers" + "sql-proxy/src/utils" + "sql-proxy/src/version" + + "github.com/gorilla/mux" + "github.com/prometheus/client_golang/prometheus/promhttp" + "github.com/sirupsen/logrus" +) + +func main() { + var err error + + // Application params taken from OS environment + utils.Log.SetLevel(logrus.Level(utils.GetIntEnvOrDefault("LOG_LEVEL", 2))) + bindAddress := os.Getenv("BIND_ADDR") + bindPort := utils.GetIntEnvOrDefault("BIND_PORT", 8080) + db.MaxRows = utils.GetIntEnvOrDefault("MAX_ROWS", 10000) + tlsCert := os.Getenv("TLS_CERT") + tlsKey := os.Getenv("TLS_KEY") + + // Scheduled maintenance task + go db.DbHandler.RunMaintenance() + + router := mux.NewRouter() + router.HandleFunc("/api/v1/query/select", handlers.GetQuery).Methods("GET") + router.HandleFunc("/api/v1/query/execute", handlers.ExecuteQuery).Methods("POST") + router.HandleFunc("/api/v1/connection/create", handlers.CreateConnection).Methods("POST") + router.HandleFunc("/api/v1/connection/delete", handlers.CloseConnection).Methods("DELETE") + router.HandleFunc("/healthz", handlers.Healthz).Methods("GET") + router.HandleFunc("/readyz", handlers.Readyz).Methods("GET") + router.HandleFunc("/livez", handlers.Livez).Methods("GET") + router.Handle("/metrics", promhttp.Handler()) + + utils.Log.WithFields(logrus.Fields{ + "build_version": version.BuildVersion, + "build_time": version.BuildTime, + }).Info("Starting server sql-proxy:") + + utils.Log.WithFields(logrus.Fields{ + "bind_port": bindPort, + "bind_address": bindAddress, + "tls_cert": tlsCert, + "tls_key": tlsKey, + }).Info("Server started with the following parameters:") + + addr := fmt.Sprintf("%s:%d", bindAddress, bindPort) + if len(tlsCert) > 0 && len(tlsKey) > 0 { + err = http.ListenAndServeTLS(addr, tlsCert, tlsKey, router) + } else { + err = http.ListenAndServe(addr, router) + } + if err != nil { + utils.Log.WithError(err).Fatal("Fatal error occurred, service stopped") + } +} diff --git a/src/utils/env.go b/src/utils/env.go new file mode 100644 index 0000000..a177fa7 --- /dev/null +++ b/src/utils/env.go @@ -0,0 +1,23 @@ +package utils + +import ( + "os" + "strconv" +) + +func GetIntEnvOrDefault(env string, defaultValue uint32) uint32 { + strValue := os.Getenv(env) + + if len(strValue) == 0 { + return defaultValue + } else { + uintValue, err := strconv.ParseUint(strValue, 10, 32) + if err != nil { + Log.Error(env + " env value cannot be parsed as uint, reset to " + strconv.FormatUint(uint64(defaultValue), 10)) + return defaultValue + } else { + return uint32(uintValue) + } + } + +} diff --git a/src/utils/log.go b/src/utils/log.go new file mode 100644 index 0000000..8e62d17 --- /dev/null +++ b/src/utils/log.go @@ -0,0 +1,14 @@ +package utils + +import ( + "os" + + "github.com/sirupsen/logrus" +) + +var Log = &logrus.Logger{ + Out: os.Stderr, + Formatter: new(logrus.TextFormatter), + Hooks: make(logrus.LevelHooks), + Level: logrus.ErrorLevel, +} diff --git a/src/version/version.go b/src/version/version.go new file mode 100644 index 0000000..57e433b --- /dev/null +++ b/src/version/version.go @@ -0,0 +1,6 @@ +package version + +var ( + BuildTime = "none" + BuildVersion = "none" +)