1
0
mirror of https://github.com/alm494/sql_proxy.git synced 2026-04-22 19:33:55 +02:00

bugfixes, new features

This commit is contained in:
Almaz Sharipov
2026-04-12 13:42:42 +03:00
parent f7429781e1
commit c074c2bbcf
9 changed files with 118 additions and 62 deletions
+8
View File
@@ -1,3 +1,11 @@
1.4.4:
- Security: Error responses now return safe HTTP status text instead of detailed database error messages to prevent information leakage
- Security: Added MAX_CONNECTIONS environment variable to limit concurrent database connections and prevent resource exhaustion (DoS protection)
- Fix: Fixed race condition in GetById() when upgrading from read lock to write lock
- Fix: Fixed memory leak in RunMaintenance() where prepared statement deletions were not persisted to the connection pool
- Feature: Error responses now use consistent JSON format with api_version, error, and status fields
1.4.3:
- Fix: minor bugfixes
+4 -2
View File
@@ -1,5 +1,5 @@
PROJECT_NAME := sql-proxy
BUILD_VERSION := 1.4.3
BUILD_VERSION := 1.4.4
BUILD_TIME := $(shell date -u '+%Y-%m-%d_%H:%M:%S')
BUILD_DIR := build
GO_FILES := src/main.go
@@ -18,6 +18,8 @@ GOAMD64 := v2
BIND_PORT := 8080
BIND_ADDR := localhost
MAX_ROWS := 10000
MAX_CONNECTIONS := 100
DEBUG_LOG := true
#TLS_CERT := $(BUILD_DIR)/server.crt
#TLS_KEY := $(BUILD_DIR)/server.key
@@ -50,7 +52,7 @@ debug: clean
# Run
run: debug
@echo "Running $(PROJECT_NAME) in debug mode..."
BIND_ADDR=$(BIND_ADDR) BIND_PORT=$(BIND_PORT) MAX_ROWS=$(MAX_ROWS) TLS_CERT=$(TLS_CERT) TLS_KEY=$(TLS_KEY) DEBUG_LOG=$(DEBUG_LOG) $(BUILD_DIR)/$(PROJECT_NAME)-debug
BIND_ADDR=$(BIND_ADDR) BIND_PORT=$(BIND_PORT) MAX_ROWS=$(MAX_ROWS) MAX_CONNECTIONS=$(MAX_CONNECTIONS) TLS_CERT=$(TLS_CERT) TLS_KEY=$(TLS_KEY) DEBUG_LOG=$(DEBUG_LOG) $(BUILD_DIR)/$(PROJECT_NAME)-debug
# Run test
test:
+3 -2
View File
@@ -31,6 +31,7 @@ Note that this service is not limited to 1C and can be utilized in other context
* Secure Credential Management : Does not store SQL credentials, ensuring sensitive information remains protected;
* Secure Communication : Supports HTTPS for secure data transmission;
* Efficient Connection Pooling : Utilizes a shared, reusable SQL connection pool with automated maintenance tasks to remove stale or dead connections;
* Connection Limiting : Configurable maximum number of concurrent database connections to prevent resource exhaustion;
* Command Support : Currently supports all SQL commands with no limitation. The SELECT command returns query results as a flexible JSON-formatted recordset;
* Result Limitation : Allows configuration to limit the number of rows returned by SELECT statements;
* Prepared Statements : supported;
@@ -48,7 +49,7 @@ Current API version is 1.2. See Swagger OpenAPI 3.0 specification in /docs/api
## How to compile
Current version is 1.4.3. Execute in the command line:
Current version is 1.4.4. Execute in the command line:
```
make prod
@@ -59,7 +60,7 @@ make prod
Just run the binary. Settings may be passed with environment variables, see Makefile for details and default values:
```
BIND_ADDR=localhost BIND_PORT=8081 MAX_ROWS=10000 sql-proxy
BIND_ADDR=localhost BIND_PORT=8081 MAX_ROWS=10000 MAX_CONNECTIONS=100 sql-proxy
```
or install it as a systemd service with install.sh script. Parameters may be changed later in sql-proxy.service file.
+3 -2
View File
@@ -36,6 +36,7 @@
+ Безопасное управление учетными данными: не хранит данные учетных записей, гарантируя защиту конфиденциальной информации;
+ Защищённое соединение: при необходимости, поддерживает HTTPS для безопасной передачи данных;
+ Пул соединений: использует общий переиспользуемый пул SQL-соединений с регламентными задачами обслуживания для удаления устаревших или зависших соединений;
+ Ограничение соединений: настраиваемый лимит одновременных подключений к базе данных для предотвращения исчерпания ресурсов;
+ Поддержка языка SQL: поддерживает любые SQL-команды без ограничений. Команда SELECT возвращает результаты запроса в виде гибкого JSON-формата набора записей;
+ Ограничение результатов: позволяет настраивать ограничения на количество строк, возвращаемых командами SELECT;
+ Поддержка подготовленных выражений: реализована;
@@ -53,7 +54,7 @@
## Как скомпилировать
Номер текущей версии: 1.4.3. Выполнить в командной строке:
Номер текущей версии: 1.4.4. Выполнить в командной строке:
```
make prod
@@ -64,7 +65,7 @@ make prod
Просто запустите бинарник. Все параметры передаются через переменные окружения, см. Makefile для детальной информации и значений настроек по умолчанию:
```
BIND_ADDR=localhost BIND_PORT=8081 MAX_ROWS=10000 sql-proxy
BIND_ADDR=localhost BIND_PORT=8081 MAX_ROWS=10000 MAX_CONNECTIONS=100 sql-proxy
```
или установите как службу systemd с помощью скрипта install.sh. Параметры можно изменить прямо в этом скрипте перед установкой, или отредактировать потом файл sql-proxy.service.
+59 -47
View File
@@ -9,11 +9,11 @@ import (
"time"
"slices"
"github.com/google/uuid"
)
var ErrConnectionLimit = fmt.Errorf("connection pool limit reached")
// Init map
func (o *DbList) Init() {
@@ -22,23 +22,33 @@ func (o *DbList) Init() {
}
// Returns current pool size
func (o *DbList) Size() int {
o.mu.RLock()
defer o.mu.RUnlock()
return len(o.items)
}
// Gets SQL server connection by GUID
func (o *DbList) GetById(id string, updateTimestamp bool) (*sql.DB, bool) {
o.mu.RLock()
if updateTimestamp {
o.mu.Lock()
defer o.mu.Unlock()
if dbConn, ok := o.items[id]; ok {
if updateTimestamp {
o.mu.RUnlock()
o.mu.Lock()
if dbConn, ok := o.items[id]; ok {
dbConn.Timestamp = time.Now()
o.items[id] = dbConn
o.mu.Unlock()
return dbConn.DB, true
}
return dbConn.DB, true
}
} else {
o.mu.RLock()
defer o.mu.RUnlock()
o.mu.RUnlock()
if dbConn, ok := o.items[id]; ok {
return dbConn.DB, true
}
}
app.Logger.Errorf("SQL connection with guid='%s' not found", id)
return nil, false
@@ -47,12 +57,11 @@ func (o *DbList) GetById(id string, updateTimestamp bool) (*sql.DB, bool) {
// Gets the new SQL server connection with parameters given.
// First lookups in pool, if fails opens new one and returns GUID value
func (o *DbList) GetByParams(connInfo *DbConnInfo) (string, bool) {
func (o *DbList) GetByParams(connInfo *DbConnInfo) (string, error) {
hash, err := connInfo.GetHash()
if err != nil {
errMsg := "Hash calculation failed"
app.Logger.Error(errMsg)
return errMsg, false
app.Logger.Error("Hash calculation failed")
return "", fmt.Errorf("hash calculation failed")
}
guid := ""
@@ -70,7 +79,7 @@ func (o *DbList) GetByParams(connInfo *DbConnInfo) (string, bool) {
if err = dbConn.DB.Ping(); err == nil {
o.mu.RUnlock()
// Everything is ok, return guid
return guid, true
return guid, nil
} else {
// Bad connection, need to clean
o.mu.RUnlock()
@@ -90,11 +99,17 @@ func (o *DbList) GetByParams(connInfo *DbConnInfo) (string, bool) {
}
// Creates the new SQL connection regarding concurrency
func (o *DbList) getNewConnection(connInfo *DbConnInfo, hash [32]byte) (string, bool) {
func (o *DbList) getNewConnection(connInfo *DbConnInfo, hash [32]byte) (string, error) {
o.mu.Lock()
defer o.mu.Unlock()
// Check connection pool limit
if len(o.items) >= MaxConnections {
app.Logger.Errorf("Connection pool limit reached: %d", MaxConnections)
return "", ErrConnectionLimit
}
// Prepare DSN string
var dsn string
@@ -115,9 +130,9 @@ func (o *DbList) getNewConnection(connInfo *DbConnInfo, hash [32]byte) (string,
dsn = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s",
connInfo.User, encodedPassword, connInfo.Host, connInfo.Port, connInfo.DbName)
default:
errMsg := fmt.Sprintf("No suitable driver implemented for server type '%s'", connInfo.DbType)
errMsg := fmt.Sprintf("unsupported database type '%s'", connInfo.DbType)
app.Logger.Error(errMsg)
return errMsg, false
return "", fmt.Errorf("%s", errMsg)
}
// Open new SQL server connection
@@ -128,16 +143,14 @@ func (o *DbList) getNewConnection(connInfo *DbConnInfo, hash [32]byte) (string,
// Check for failure
if err != nil {
errMsg := "Error establishing SQL server connection"
app.Logger.Error(errMsg)
return errMsg, false
app.Logger.Errorf("Error establishing SQL server connection: %v", err)
return "", fmt.Errorf("error establishing SQL server connection")
}
// Check if alive
if err = newDb.Ping(); err != nil {
errMsg := "Just created SQL connection is dead"
app.Logger.Error(errMsg)
return errMsg, false
app.Logger.Errorf("Just created SQL connection is dead: %v", err)
return "", fmt.Errorf("error establishing SQL server connection")
}
// Insert into pool
@@ -161,7 +174,7 @@ func (o *DbList) getNewConnection(connInfo *DbConnInfo, hash [32]byte) (string,
newId,
)
return newId, true
return newId, nil
}
// Deletes SQL server connection
@@ -230,15 +243,17 @@ func (o *DbList) ClosePreparedStatement(connId, stmtId string) bool {
if !ok {
return false
}
for i := range dbConn.Stmt {
if dbConn.Stmt[i].Id == stmtId {
dbConn.Stmt[i].Stmt.Close()
dbConn.Stmt = slices.Delete(dbConn.Stmt, i, i+1)
break
for i, stmt := range dbConn.Stmt {
if stmt.Id == stmtId {
stmt.Stmt.Close()
dbConn.Stmt = append(dbConn.Stmt[:i], dbConn.Stmt[i+1:]...)
o.items[connId] = dbConn
return true
}
}
return true
return false
}
@@ -258,40 +273,37 @@ func (o *DbList) RunMaintenance() {
o.mu.Lock()
for key, dbConn := range o.items {
var lostStmts []string
countConn++
isDead := false
if err := dbConn.DB.Ping(); err != nil {
// dead connection
deadItems = append(deadItems, key)
countDeadConn++
isDead = true
} else if time.Since(dbConn.Timestamp).Abs().Minutes() > 20 {
// connection not used for last 20 minutes
deadItems = append(deadItems, key)
countDeadConn++
isDead = true
}
// check prepared statements
// Close and remove expired prepared statements
activeStmts := make([]DbStmt, 0, len(dbConn.Stmt))
for _, stmt := range dbConn.Stmt {
// prepared statements not used last 20 minutes
if time.Since(stmt.Timestamp).Abs().Minutes() > 20 {
lostStmts = append(lostStmts, stmt.Id)
stmt.Stmt.Close()
countStmt++
} else {
activeStmts = append(activeStmts, stmt)
}
}
dbConn.Stmt = activeStmts
// delete lost prepared statements
for _, lost := range lostStmts {
for i := range dbConn.Stmt {
if dbConn.Stmt[i].Id == lost {
dbConn.Stmt[i].Stmt.Close()
dbConn.Stmt = slices.Delete(dbConn.Stmt, i, i+1)
break
}
}
// Update connection in pool (without timestamp change)
if !isDead {
o.items[key] = dbConn
}
}
// remove dead connections
+3 -2
View File
@@ -1,6 +1,7 @@
package db
var (
Handler DbList
MaxRows uint32 = 10000
Handler DbList
MaxRows uint32 = 10000
MaxConnections int = 100
)
+21 -2
View File
@@ -17,13 +17,25 @@ type ResponseEnvelope struct {
Rows []map[string]any `json:"rows"`
}
type ErrorEnvelope struct {
ApiVersion string `json:"api_version"`
Error string `json:"error"`
Status int `json:"status"`
}
func checkApiVersion(w http.ResponseWriter, r *http.Request) bool {
apiVersion := r.Header.Get("API-Version")
if apiVersion != app.ApiVersion {
message := "Unsupported API version"
app.Logger.Error(message)
http.Error(w, message, http.StatusNotImplemented)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusNotImplemented)
json.NewEncoder(w).Encode(ErrorEnvelope{
ApiVersion: app.ApiVersion,
Error: message,
Status: http.StatusNotImplemented,
})
return false
} else {
return true
@@ -34,7 +46,14 @@ func checkApiVersion(w http.ResponseWriter, r *http.Request) bool {
func errorResponce(w http.ResponseWriter, message string, httpStatus int) {
app.Logger.Error(message)
http.Error(w, message, httpStatus)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(httpStatus)
errResp := ErrorEnvelope{
ApiVersion: app.ApiVersion,
Error: http.StatusText(httpStatus),
Status: httpStatus,
}
json.NewEncoder(w).Encode(errResp)
}
+12 -3
View File
@@ -2,6 +2,7 @@ package handlers
import (
"encoding/json"
"errors"
"net/http"
"sql-proxy/src/db"
)
@@ -19,9 +20,17 @@ func CreateConnection(w http.ResponseWriter, r *http.Request) {
return
}
if connGuid, ok := db.Handler.GetByParams(&dbConnInfo); !ok {
errorResponce(w, "Failed to get SQL connection", http.StatusInternalServerError)
} else if _, err := w.Write([]byte(connGuid)); err != nil {
connGuid, err := db.Handler.GetByParams(&dbConnInfo)
if err != nil {
if errors.Is(err, db.ErrConnectionLimit) {
errorResponce(w, err.Error(), http.StatusTooManyRequests)
} else {
errorResponce(w, err.Error(), http.StatusInternalServerError)
}
return
}
if _, err = w.Write([]byte(connGuid)); err != nil {
errorResponce(w, err.Error(), http.StatusInternalServerError)
}
+5 -2
View File
@@ -42,6 +42,8 @@ func (p *program) run() {
}
bindPort := app.GetEnvInt("BIND_PORT", 8080)
db.MaxRows = uint32(app.GetEnvInt("MAX_ROWS", 10000))
db.MaxConnections = app.GetEnvInt("MAX_CONNECTIONS", 100)
tlsCert := app.GetEnvString("TLS_CERT", "")
tlsKey := app.GetEnvString("TLS_KEY", "")
@@ -67,10 +69,11 @@ func (p *program) run() {
router.HandleFunc("/livez", handlers.Livez).Methods("GET")
router.Handle("/metrics", promhttp.Handler())
app.Logger.Info("(c) 2025 Almaz Sharipov, MIT license, https://github.com/alm494/sql_proxy ")
app.Logger.Info("(c) 2025-2026 Almaz Sharipov, MIT license, https://github.com/alm494/sql_proxy ")
app.Logger.Infof("build_version=%s, build_time=%s", app.BuildVersion, app.BuildTime)
app.Logger.Infof("Server started with the following parameters: "+
"bind_port=%d, bind_address=%s, tls_cert=%s, tls_key=%s", bindPort, bindAddress, tlsCert, tlsKey)
"bind_port=%d, bind_address=%s, max_connections=%d, max_rows=%d, tls_cert=%s, tls_key=%s",
bindPort, bindAddress, db.MaxConnections, db.MaxRows, tlsCert, tlsKey)
addr := fmt.Sprintf("%s:%d", bindAddress, bindPort)