1
0
mirror of https://github.com/alm494/sql_proxy.git synced 2026-04-22 19:33:55 +02:00

Initial commit

This commit is contained in:
almaz
2025-02-10 15:45:39 +03:00
commit cc53828937
19 changed files with 608 additions and 0 deletions
+3
View File
@@ -0,0 +1,3 @@
bin/
.vscode/
go.sum
+59
View File
@@ -0,0 +1,59 @@
PROJECT_NAME := sql-proxy
BUILD_VERSION := 1.1
BUILD_TIME := $(shell date -u '+%Y-%m-%d_%H:%M:%S')
BUILD_DIR := bin
GO_FILES := src/main.go
# Build with SQL drivers, comment out if unused:
BUILD_WITH_POSTGRES_TAG := postgres
BUILD_WITH_MSSQL_TAG := sqlserver
BUILD_WITH_MYSQL_TAG := mysql
# Go compiler basic settings
GOOS := linux
GOARCH := amd64
# Application settings to run:
LOG_LEVEL := 6
BIND_PORT := 8080
BIND_ADDR := localhost
MAX_ROWS := 10000
#TLS_CERT := $(BUILD_DIR)/server.crt
#TLS_KEY := $(BUILD_DIR)/server.key
TAGS := -tags=$(BUILD_WITH_POSTGRES_TAG),$(BUILD_WITH_MSSQL_TAG),$(BUILD_WITH_MYSQL_TAG)
# Default
all: prod
clean:
rm -f $(BUILD_DIR)/$(PROJECT_NAME)
rm -f $(BUILD_DIR)/$(PROJECT_NAME)-debug
# Build for production
prod: clean
@echo "Building $(PROJECT_NAME) or production..."
GOOS=${GOOS} GOARCH=${GOARCH} go build $(TAGS) \
-ldflags="-s -w \
-X ${PROJECT_NAME}/src/version.BuildVersion=${BUILD_VERSION} \
-X ${PROJECT_NAME}/src/version.BuildTime=${BUILD_TIME}" -o $(BUILD_DIR)/$(PROJECT_NAME) $(GO_FILES)
@echo "Production build completed."
# Build for debugging
debug: clean
@echo "Building $(PROJECT_NAME) or production..."
GOOS=${GOOS} GOARCH=${GOARCH} go build $(TAGS) \
-ldflags="\
-X ${PROJECT_NAME}/src/version.BuildVersion=${BUILD_VERSION}-debug \
-X ${PROJECT_NAME}/src/version.BuildTime=${BUILD_TIME}" -o $(BUILD_DIR)/$(PROJECT_NAME)-debug $(GO_FILES)
@echo "Debug build completed."
# Run
run: debug
@echo "Running $(PROJECT_NAME) in debug mode..."
BIND_ADDR=$(BIND_ADDR) BIND_PORT=$(BIND_PORT) MAX_ROWS=$(MAX_ROWS) TLS_CERT=$(TLS_CERT) TLS_KEY=$(TLS_KEY) LOG_LEVEL=$(LOG_LEVEL) $(BUILD_DIR)/$(PROJECT_NAME)-debug
# Run test
test:
@echo "Running tests..."
@go test ./... -v
+3
View File
@@ -0,0 +1,3 @@
# SQL-PROXY
Simple REST service to replace ADODB calls from legacy software
+29
View File
@@ -0,0 +1,29 @@
module sql-proxy
go 1.23.5
require (
github.com/google/uuid v1.6.0
github.com/gorilla/mux v1.8.1
github.com/lib/pq v1.10.9
github.com/prometheus/client_golang v1.20.5
github.com/sirupsen/logrus v1.9.3
golang.org/x/crypto v0.32.0
)
require (
filippo.io/edwards25519 v1.1.0 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/denisenkom/go-mssqldb v0.12.3 // indirect
github.com/go-sql-driver/mysql v1.8.1 // indirect
github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe // indirect
github.com/golang-sql/sqlexp v0.1.0 // indirect
github.com/klauspost/compress v1.17.11 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/prometheus/client_model v0.6.1 // indirect
github.com/prometheus/common v0.62.0 // indirect
github.com/prometheus/procfs v0.15.1 // indirect
golang.org/x/sys v0.29.0 // indirect
google.golang.org/protobuf v1.36.4 // indirect
)
+15
View File
@@ -0,0 +1,15 @@
package db
import (
"database/sql"
"time"
)
type DbConn struct {
// Hash, as sql.DB does not store credentials
Hash [32]byte
// SQL server connection pool (provided by the driver)
DB *sql.DB
// Last use
Timestamp time.Time
}
+31
View File
@@ -0,0 +1,31 @@
package db
import (
"bytes"
"crypto/sha256"
"encoding/gob"
)
type DbConnInfo struct {
DbType string `json:"db_type"`
Host string `json:"host"`
Port uint16 `json:"port"`
User string `json:"user"`
Password string `json:"password"`
DbName string `json:"db_name"`
SSL bool `json:"ssl"`
}
func (o DbConnInfo) GetHash() ([32]byte, error) {
var buf bytes.Buffer
var hash [32]byte
enc := gob.NewEncoder(&buf)
err := enc.Encode(o)
if err != nil {
return hash, err
}
hash = sha256.Sum256(buf.Bytes())
return hash, nil
}
+167
View File
@@ -0,0 +1,167 @@
package db
import (
"bytes"
"database/sql"
"fmt"
"net/url"
"sql-proxy/src/utils"
"sync"
"time"
"github.com/google/uuid"
"github.com/sirupsen/logrus"
)
// Class model to keep open SQL connections in the pool
// with concurrent read/write access
type DbList struct {
items sync.Map
}
// Gets SQL server connection by GUID
func (o *DbList) GetById(guid string, updateTimestamp bool) (*sql.DB, bool) {
val, ok := o.items.Load(guid)
if ok {
res := val.(*DbConn)
if updateTimestamp {
res.Timestamp = time.Now()
o.items.Store(guid, res)
}
return res.DB, true
}
utils.Log.Error(fmt.Sprintf("SQL connection with guid='%s' not found", guid))
return nil, false
}
// Gets the new SQL server connection with parameters given.
// First lookups in pool, if fails opens new one and returns GUID value
func (o *DbList) GetByParams(connInfo *DbConnInfo) (string, bool) {
hash, err := connInfo.GetHash()
if err != nil {
errMsg := "Hash calculation failed"
utils.Log.WithError(err).Error(errMsg)
return errMsg, false
}
// Step 1. Search existing connection by hash to reuse
guid := ""
o.items.Range(
func(key, value interface{}) bool {
if bytes.Equal(value.(*DbConn).Hash[:], hash[:]) {
guid = key.(string)
utils.Log.Debug(fmt.Sprintf("DB connection with id %s found in the pool", guid))
return false
}
return true
})
// Step 2. Perform checks and return guid if passed
if len(guid) > 0 {
conn, ok := o.items.Load(guid)
if ok {
err = conn.(*DbConn).DB.Ping()
if err == nil {
// Everything is ok, return guid
return guid, true
} else {
// Remove dead connection from the pool
o.items.Delete(guid)
utils.Log.Debug(fmt.Sprintf("DB connection with id %s is dead and removed from the pool", guid))
}
}
}
// Step 3. Nothing found, create the new
return o.getNewConnection(connInfo, hash)
}
// Creates the new SQL connection regarding concurrency
func (o *DbList) getNewConnection(connInfo *DbConnInfo, hash [32]byte) (string, bool) {
// 1. Prepare DSN string
var dsn string
encodedPassword := url.QueryEscape(connInfo.Password)
switch connInfo.DbType {
case "postgres":
sslMode := "enable"
if !connInfo.SSL {
sslMode = "disable"
}
dsn = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
connInfo.Host, connInfo.Port, connInfo.User, encodedPassword, connInfo.DbName, sslMode)
case "sqlserver":
dsn = fmt.Sprintf("server=%s;user id=%s;password=%s;database=%s;port=%d",
connInfo.Host, connInfo.User, encodedPassword, connInfo.DbName, connInfo.Port)
case "mysql":
dsn = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s",
connInfo.User, encodedPassword, connInfo.Host, connInfo.Port, connInfo.DbName)
default:
errMsg := fmt.Sprintf("No suitable driver implemented for server type '%s'", connInfo.DbType)
utils.Log.Error(errMsg)
return errMsg, false
}
// 2. Open new SQL server connection
var err error
var newDb *sql.DB
newDb, err = sql.Open(connInfo.DbType, dsn)
// 3. Check for failure
if err != nil {
errMsg := "Error establishing SQL server connection"
utils.Log.WithError(err).Error(errMsg)
return errMsg, false
}
// 4. Check if alive
err = newDb.Ping()
if err != nil {
errMsg := "Just created SQL connection is dead"
utils.Log.WithError(err).Error(errMsg)
return errMsg, false
}
// 5. Insert into pool
newId := uuid.New().String()
newItem := DbConn{
Hash: hash,
DB: newDb,
Timestamp: time.Now(),
}
o.items.Store(newId, &newItem)
utils.Log.WithFields(logrus.Fields{
"Host": connInfo.Host,
"Port": connInfo.Port,
"dbName": connInfo.DbName,
"user": connInfo.User,
"dbType": connInfo.DbType,
"Id": newId,
}).Info(fmt.Sprintf("New SQL connection with id %s was added to the pool", newId))
return newId, true
}
func (o *DbList) RunMaintenance() {
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
for {
<-ticker.C
utils.Log.Debug("Regular task: checking if pooled SQL connections are alive...")
// to do: detect and remove dead connections
//for o.items.Range()
}
}
func (o *DbList) Delete(id string) {
o.items.Delete(id)
}
+6
View File
@@ -0,0 +1,6 @@
//go:build mysql
// +build mysql
package db
import _ "github.com/go-sql-driver/mysql"
+6
View File
@@ -0,0 +1,6 @@
//go:build postgres
// +build postgres
package db
import _ "github.com/lib/pq"
+6
View File
@@ -0,0 +1,6 @@
//go:build sqlserver
// +build sqlserver
package db
import _ "github.com/denisenkom/go-mssqldb"
+7
View File
@@ -0,0 +1,7 @@
package db
// Global vars
var (
DbHandler DbList
MaxRows uint32 = 10000
)
+37
View File
@@ -0,0 +1,37 @@
package handlers
import (
"encoding/json"
"net/http"
"sql-proxy/src/db"
"sql-proxy/src/utils"
)
func CreateConnection(w http.ResponseWriter, r *http.Request) {
var dbConnInfo db.DbConnInfo
err := json.NewDecoder(r.Body).Decode(&dbConnInfo)
if err != nil {
errorMsg := "Error decoding JSON"
utils.Log.Error(errorMsg)
http.Error(w, errorMsg, http.StatusBadRequest)
return
}
connGuid, ok := db.DbHandler.GetByParams(&dbConnInfo)
if !ok {
errorMsg := "Failed to get SQL connection"
utils.Log.Error(errorMsg)
http.Error(w, errorMsg, http.StatusInternalServerError)
} else {
_, err := w.Write([]byte(connGuid))
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}
}
func CloseConnection(w http.ResponseWriter, r *http.Request) {
// to do
}
+10
View File
@@ -0,0 +1,10 @@
package handlers
type ResponseEnvelope struct {
ApiVersion uint8 `json:"api_version"`
ConnectionId string `json:"connection_id"`
Info string `json:"info"`
RowsCount uint32 `json:"rows_count"`
ExceedsMaxRows bool `json:"exceeds_max_rows"`
Rows []map[string]interface{}
}
+21
View File
@@ -0,0 +1,21 @@
package handlers
import (
"net/http"
)
// Deprecated
func Healthz(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
}
func Readyz(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("Ready"))
}
func Livez(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("Live"))
}
+102
View File
@@ -0,0 +1,102 @@
package handlers
import (
"database/sql"
"encoding/json"
"net/http"
"sql-proxy/src/db"
"sql-proxy/src/utils"
"github.com/sirupsen/logrus"
)
func GetQuery(w http.ResponseWriter, r *http.Request) {
query := r.URL.Query().Get("query")
conn := r.URL.Query().Get("conn")
if query == "" || conn == "" {
errorText := "Missing parameter"
utils.Log.Error(errorText)
http.Error(w, errorText, http.StatusBadRequest)
return
}
utils.Log.WithFields(logrus.Fields{
"query": query,
"conn": conn,
}).Debug("SQL query received:")
// Search existings connection in the pool
dbConn, ok := db.DbHandler.GetById(conn, true)
if !ok {
errorText := "Failed to get SQL connection"
utils.Log.Error(errorText, ": ", conn)
http.Error(w, errorText, http.StatusForbidden)
return
}
rows, err := dbConn.Query(query)
if err != nil {
utils.Log.WithError(err).Error("SQL query error")
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
defer rows.Close()
columns, err := rows.Columns()
if err != nil {
utils.Log.WithError(err).Error("Invalid query return value")
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
tableData, rowsCount, exceedsMaxRows := convertRows(rows, &columns)
var envelope ResponseEnvelope
envelope.RowsCount = rowsCount
envelope.ExceedsMaxRows = exceedsMaxRows
envelope.Rows = *tableData
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(envelope)
}
func ExecuteQuery(w http.ResponseWriter, r *http.Request) {
// to do
}
// Converts SQL query result to json
func convertRows(rows *sql.Rows, columns *[]string) (*[]map[string]interface{}, uint32, bool) {
var rowsCount uint32 = 0
colsCount := len(*columns)
tableData := make([]map[string]interface{}, 0)
values := make([]interface{}, colsCount)
valuePtrs := make([]interface{}, colsCount)
exceedsMaxRows := false
for rows.Next() {
for i := range *columns {
valuePtrs[i] = &values[i]
}
rows.Scan(valuePtrs...)
entry := make(map[string]interface{})
for i, col := range *columns {
var v interface{}
val := values[i]
b, ok := val.([]byte)
if ok {
v = string(b)
} else {
v = val
}
entry[col] = v
}
if rowsCount > db.MaxRows {
exceedsMaxRows = true
break
}
tableData = append(tableData, entry)
rowsCount++
}
return &tableData, rowsCount, exceedsMaxRows
}
+63
View File
@@ -0,0 +1,63 @@
package main
import (
"fmt"
"net/http"
"os"
"sql-proxy/src/db"
"sql-proxy/src/handlers"
"sql-proxy/src/utils"
"sql-proxy/src/version"
"github.com/gorilla/mux"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/sirupsen/logrus"
)
func main() {
var err error
// Application params taken from OS environment
utils.Log.SetLevel(logrus.Level(utils.GetIntEnvOrDefault("LOG_LEVEL", 2)))
bindAddress := os.Getenv("BIND_ADDR")
bindPort := utils.GetIntEnvOrDefault("BIND_PORT", 8080)
db.MaxRows = utils.GetIntEnvOrDefault("MAX_ROWS", 10000)
tlsCert := os.Getenv("TLS_CERT")
tlsKey := os.Getenv("TLS_KEY")
// Scheduled maintenance task
go db.DbHandler.RunMaintenance()
router := mux.NewRouter()
router.HandleFunc("/api/v1/query/select", handlers.GetQuery).Methods("GET")
router.HandleFunc("/api/v1/query/execute", handlers.ExecuteQuery).Methods("POST")
router.HandleFunc("/api/v1/connection/create", handlers.CreateConnection).Methods("POST")
router.HandleFunc("/api/v1/connection/delete", handlers.CloseConnection).Methods("DELETE")
router.HandleFunc("/healthz", handlers.Healthz).Methods("GET")
router.HandleFunc("/readyz", handlers.Readyz).Methods("GET")
router.HandleFunc("/livez", handlers.Livez).Methods("GET")
router.Handle("/metrics", promhttp.Handler())
utils.Log.WithFields(logrus.Fields{
"build_version": version.BuildVersion,
"build_time": version.BuildTime,
}).Info("Starting server sql-proxy:")
utils.Log.WithFields(logrus.Fields{
"bind_port": bindPort,
"bind_address": bindAddress,
"tls_cert": tlsCert,
"tls_key": tlsKey,
}).Info("Server started with the following parameters:")
addr := fmt.Sprintf("%s:%d", bindAddress, bindPort)
if len(tlsCert) > 0 && len(tlsKey) > 0 {
err = http.ListenAndServeTLS(addr, tlsCert, tlsKey, router)
} else {
err = http.ListenAndServe(addr, router)
}
if err != nil {
utils.Log.WithError(err).Fatal("Fatal error occurred, service stopped")
}
}
+23
View File
@@ -0,0 +1,23 @@
package utils
import (
"os"
"strconv"
)
func GetIntEnvOrDefault(env string, defaultValue uint32) uint32 {
strValue := os.Getenv(env)
if len(strValue) == 0 {
return defaultValue
} else {
uintValue, err := strconv.ParseUint(strValue, 10, 32)
if err != nil {
Log.Error(env + " env value cannot be parsed as uint, reset to " + strconv.FormatUint(uint64(defaultValue), 10))
return defaultValue
} else {
return uint32(uintValue)
}
}
}
+14
View File
@@ -0,0 +1,14 @@
package utils
import (
"os"
"github.com/sirupsen/logrus"
)
var Log = &logrus.Logger{
Out: os.Stderr,
Formatter: new(logrus.TextFormatter),
Hooks: make(logrus.LevelHooks),
Level: logrus.ErrorLevel,
}
+6
View File
@@ -0,0 +1,6 @@
package version
var (
BuildTime = "none"
BuildVersion = "none"
)