diff --git a/CHANGELOG b/CHANGELOG index a2b4acc..16a41d1 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -1,3 +1,11 @@ +1.4.4: + + - Security: Error responses now return safe HTTP status text instead of detailed database error messages to prevent information leakage + - Security: Added MAX_CONNECTIONS environment variable to limit concurrent database connections and prevent resource exhaustion (DoS protection) + - Fix: Fixed race condition in GetById() when upgrading from read lock to write lock + - Fix: Fixed memory leak in RunMaintenance() where prepared statement deletions were not persisted to the connection pool + - Feature: Error responses now use consistent JSON format with api_version, error, and status fields + 1.4.3: - Fix: minor bugfixes diff --git a/Makefile b/Makefile index 952f27c..192bf13 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ PROJECT_NAME := sql-proxy -BUILD_VERSION := 1.4.3 +BUILD_VERSION := 1.4.4 BUILD_TIME := $(shell date -u '+%Y-%m-%d_%H:%M:%S') BUILD_DIR := build GO_FILES := src/main.go @@ -18,6 +18,8 @@ GOAMD64 := v2 BIND_PORT := 8080 BIND_ADDR := localhost MAX_ROWS := 10000 +MAX_CONNECTIONS := 100 + DEBUG_LOG := true #TLS_CERT := $(BUILD_DIR)/server.crt #TLS_KEY := $(BUILD_DIR)/server.key @@ -50,7 +52,7 @@ debug: clean # Run run: debug @echo "Running $(PROJECT_NAME) in debug mode..." - BIND_ADDR=$(BIND_ADDR) BIND_PORT=$(BIND_PORT) MAX_ROWS=$(MAX_ROWS) TLS_CERT=$(TLS_CERT) TLS_KEY=$(TLS_KEY) DEBUG_LOG=$(DEBUG_LOG) $(BUILD_DIR)/$(PROJECT_NAME)-debug + BIND_ADDR=$(BIND_ADDR) BIND_PORT=$(BIND_PORT) MAX_ROWS=$(MAX_ROWS) MAX_CONNECTIONS=$(MAX_CONNECTIONS) TLS_CERT=$(TLS_CERT) TLS_KEY=$(TLS_KEY) DEBUG_LOG=$(DEBUG_LOG) $(BUILD_DIR)/$(PROJECT_NAME)-debug # Run test test: diff --git a/README.md b/README.md index 70c953e..6bd15f0 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,7 @@ Note that this service is not limited to 1C and can be utilized in other context * Secure Credential Management : Does not store SQL credentials, ensuring sensitive information remains protected; * Secure Communication : Supports HTTPS for secure data transmission; * Efficient Connection Pooling : Utilizes a shared, reusable SQL connection pool with automated maintenance tasks to remove stale or dead connections; +* Connection Limiting : Configurable maximum number of concurrent database connections to prevent resource exhaustion; * Command Support : Currently supports all SQL commands with no limitation. The SELECT command returns query results as a flexible JSON-formatted recordset; * Result Limitation : Allows configuration to limit the number of rows returned by SELECT statements; * Prepared Statements : supported; @@ -48,7 +49,7 @@ Current API version is 1.2. See Swagger OpenAPI 3.0 specification in /docs/api ## How to compile -Current version is 1.4.3. Execute in the command line: +Current version is 1.4.4. Execute in the command line: ``` make prod @@ -59,7 +60,7 @@ make prod Just run the binary. Settings may be passed with environment variables, see Makefile for details and default values: ``` -BIND_ADDR=localhost BIND_PORT=8081 MAX_ROWS=10000 sql-proxy +BIND_ADDR=localhost BIND_PORT=8081 MAX_ROWS=10000 MAX_CONNECTIONS=100 sql-proxy ``` or install it as a systemd service with install.sh script. Parameters may be changed later in sql-proxy.service file. \ No newline at end of file diff --git a/README.ru.md b/README.ru.md index 94b13f4..7cf0c17 100644 --- a/README.ru.md +++ b/README.ru.md @@ -36,6 +36,7 @@ + Безопасное управление учетными данными: не хранит данные учетных записей, гарантируя защиту конфиденциальной информации; + Защищённое соединение: при необходимости, поддерживает HTTPS для безопасной передачи данных; + Пул соединений: использует общий переиспользуемый пул SQL-соединений с регламентными задачами обслуживания для удаления устаревших или зависших соединений; ++ Ограничение соединений: настраиваемый лимит одновременных подключений к базе данных для предотвращения исчерпания ресурсов; + Поддержка языка SQL: поддерживает любые SQL-команды без ограничений. Команда SELECT возвращает результаты запроса в виде гибкого JSON-формата набора записей; + Ограничение результатов: позволяет настраивать ограничения на количество строк, возвращаемых командами SELECT; + Поддержка подготовленных выражений: реализована; @@ -53,7 +54,7 @@ ## Как скомпилировать -Номер текущей версии: 1.4.3. Выполнить в командной строке: +Номер текущей версии: 1.4.4. Выполнить в командной строке: ``` make prod @@ -64,7 +65,7 @@ make prod Просто запустите бинарник. Все параметры передаются через переменные окружения, см. Makefile для детальной информации и значений настроек по умолчанию: ``` -BIND_ADDR=localhost BIND_PORT=8081 MAX_ROWS=10000 sql-proxy +BIND_ADDR=localhost BIND_PORT=8081 MAX_ROWS=10000 MAX_CONNECTIONS=100 sql-proxy ``` или установите как службу systemd с помощью скрипта install.sh. Параметры можно изменить прямо в этом скрипте перед установкой, или отредактировать потом файл sql-proxy.service. \ No newline at end of file diff --git a/src/db/dblist.go b/src/db/dblist.go index b914291..55af80a 100644 --- a/src/db/dblist.go +++ b/src/db/dblist.go @@ -9,11 +9,11 @@ import ( "time" - "slices" - "github.com/google/uuid" ) +var ErrConnectionLimit = fmt.Errorf("connection pool limit reached") + // Init map func (o *DbList) Init() { @@ -22,23 +22,33 @@ func (o *DbList) Init() { } +// Returns current pool size +func (o *DbList) Size() int { + o.mu.RLock() + defer o.mu.RUnlock() + return len(o.items) +} + // Gets SQL server connection by GUID func (o *DbList) GetById(id string, updateTimestamp bool) (*sql.DB, bool) { - o.mu.RLock() + if updateTimestamp { + o.mu.Lock() + defer o.mu.Unlock() - if dbConn, ok := o.items[id]; ok { - if updateTimestamp { - o.mu.RUnlock() - o.mu.Lock() + if dbConn, ok := o.items[id]; ok { dbConn.Timestamp = time.Now() o.items[id] = dbConn - o.mu.Unlock() + return dbConn.DB, true } - return dbConn.DB, true - } + } else { + o.mu.RLock() + defer o.mu.RUnlock() - o.mu.RUnlock() + if dbConn, ok := o.items[id]; ok { + return dbConn.DB, true + } + } app.Logger.Errorf("SQL connection with guid='%s' not found", id) return nil, false @@ -47,12 +57,11 @@ func (o *DbList) GetById(id string, updateTimestamp bool) (*sql.DB, bool) { // Gets the new SQL server connection with parameters given. // First lookups in pool, if fails opens new one and returns GUID value -func (o *DbList) GetByParams(connInfo *DbConnInfo) (string, bool) { +func (o *DbList) GetByParams(connInfo *DbConnInfo) (string, error) { hash, err := connInfo.GetHash() if err != nil { - errMsg := "Hash calculation failed" - app.Logger.Error(errMsg) - return errMsg, false + app.Logger.Error("Hash calculation failed") + return "", fmt.Errorf("hash calculation failed") } guid := "" @@ -70,7 +79,7 @@ func (o *DbList) GetByParams(connInfo *DbConnInfo) (string, bool) { if err = dbConn.DB.Ping(); err == nil { o.mu.RUnlock() // Everything is ok, return guid - return guid, true + return guid, nil } else { // Bad connection, need to clean o.mu.RUnlock() @@ -90,11 +99,17 @@ func (o *DbList) GetByParams(connInfo *DbConnInfo) (string, bool) { } // Creates the new SQL connection regarding concurrency -func (o *DbList) getNewConnection(connInfo *DbConnInfo, hash [32]byte) (string, bool) { +func (o *DbList) getNewConnection(connInfo *DbConnInfo, hash [32]byte) (string, error) { o.mu.Lock() defer o.mu.Unlock() + // Check connection pool limit + if len(o.items) >= MaxConnections { + app.Logger.Errorf("Connection pool limit reached: %d", MaxConnections) + return "", ErrConnectionLimit + } + // Prepare DSN string var dsn string @@ -115,9 +130,9 @@ func (o *DbList) getNewConnection(connInfo *DbConnInfo, hash [32]byte) (string, dsn = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", connInfo.User, encodedPassword, connInfo.Host, connInfo.Port, connInfo.DbName) default: - errMsg := fmt.Sprintf("No suitable driver implemented for server type '%s'", connInfo.DbType) + errMsg := fmt.Sprintf("unsupported database type '%s'", connInfo.DbType) app.Logger.Error(errMsg) - return errMsg, false + return "", fmt.Errorf("%s", errMsg) } // Open new SQL server connection @@ -128,16 +143,14 @@ func (o *DbList) getNewConnection(connInfo *DbConnInfo, hash [32]byte) (string, // Check for failure if err != nil { - errMsg := "Error establishing SQL server connection" - app.Logger.Error(errMsg) - return errMsg, false + app.Logger.Errorf("Error establishing SQL server connection: %v", err) + return "", fmt.Errorf("error establishing SQL server connection") } // Check if alive if err = newDb.Ping(); err != nil { - errMsg := "Just created SQL connection is dead" - app.Logger.Error(errMsg) - return errMsg, false + app.Logger.Errorf("Just created SQL connection is dead: %v", err) + return "", fmt.Errorf("error establishing SQL server connection") } // Insert into pool @@ -161,7 +174,7 @@ func (o *DbList) getNewConnection(connInfo *DbConnInfo, hash [32]byte) (string, newId, ) - return newId, true + return newId, nil } // Deletes SQL server connection @@ -230,15 +243,17 @@ func (o *DbList) ClosePreparedStatement(connId, stmtId string) bool { if !ok { return false } - for i := range dbConn.Stmt { - if dbConn.Stmt[i].Id == stmtId { - dbConn.Stmt[i].Stmt.Close() - dbConn.Stmt = slices.Delete(dbConn.Stmt, i, i+1) - break + + for i, stmt := range dbConn.Stmt { + if stmt.Id == stmtId { + stmt.Stmt.Close() + dbConn.Stmt = append(dbConn.Stmt[:i], dbConn.Stmt[i+1:]...) + o.items[connId] = dbConn + return true } } - return true + return false } @@ -258,40 +273,37 @@ func (o *DbList) RunMaintenance() { o.mu.Lock() for key, dbConn := range o.items { - - var lostStmts []string countConn++ + isDead := false if err := dbConn.DB.Ping(); err != nil { // dead connection deadItems = append(deadItems, key) countDeadConn++ + isDead = true } else if time.Since(dbConn.Timestamp).Abs().Minutes() > 20 { // connection not used for last 20 minutes deadItems = append(deadItems, key) countDeadConn++ + isDead = true } - // check prepared statements + // Close and remove expired prepared statements + activeStmts := make([]DbStmt, 0, len(dbConn.Stmt)) for _, stmt := range dbConn.Stmt { - // prepared statements not used last 20 minutes if time.Since(stmt.Timestamp).Abs().Minutes() > 20 { - lostStmts = append(lostStmts, stmt.Id) + stmt.Stmt.Close() countStmt++ + } else { + activeStmts = append(activeStmts, stmt) } } + dbConn.Stmt = activeStmts - // delete lost prepared statements - for _, lost := range lostStmts { - for i := range dbConn.Stmt { - if dbConn.Stmt[i].Id == lost { - dbConn.Stmt[i].Stmt.Close() - dbConn.Stmt = slices.Delete(dbConn.Stmt, i, i+1) - break - } - } + // Update connection in pool (without timestamp change) + if !isDead { + o.items[key] = dbConn } - } // remove dead connections diff --git a/src/db/vars.go b/src/db/vars.go index 6b89ae5..45c0577 100644 --- a/src/db/vars.go +++ b/src/db/vars.go @@ -1,6 +1,7 @@ package db var ( - Handler DbList - MaxRows uint32 = 10000 + Handler DbList + MaxRows uint32 = 10000 + MaxConnections int = 100 ) diff --git a/src/handlers/common.go b/src/handlers/common.go index f13b226..e3acfae 100644 --- a/src/handlers/common.go +++ b/src/handlers/common.go @@ -17,13 +17,25 @@ type ResponseEnvelope struct { Rows []map[string]any `json:"rows"` } +type ErrorEnvelope struct { + ApiVersion string `json:"api_version"` + Error string `json:"error"` + Status int `json:"status"` +} + func checkApiVersion(w http.ResponseWriter, r *http.Request) bool { apiVersion := r.Header.Get("API-Version") if apiVersion != app.ApiVersion { message := "Unsupported API version" app.Logger.Error(message) - http.Error(w, message, http.StatusNotImplemented) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusNotImplemented) + json.NewEncoder(w).Encode(ErrorEnvelope{ + ApiVersion: app.ApiVersion, + Error: message, + Status: http.StatusNotImplemented, + }) return false } else { return true @@ -34,7 +46,14 @@ func checkApiVersion(w http.ResponseWriter, r *http.Request) bool { func errorResponce(w http.ResponseWriter, message string, httpStatus int) { app.Logger.Error(message) - http.Error(w, message, httpStatus) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(httpStatus) + errResp := ErrorEnvelope{ + ApiVersion: app.ApiVersion, + Error: http.StatusText(httpStatus), + Status: httpStatus, + } + json.NewEncoder(w).Encode(errResp) } diff --git a/src/handlers/connection.go b/src/handlers/connection.go index 3706692..4578737 100644 --- a/src/handlers/connection.go +++ b/src/handlers/connection.go @@ -2,6 +2,7 @@ package handlers import ( "encoding/json" + "errors" "net/http" "sql-proxy/src/db" ) @@ -19,9 +20,17 @@ func CreateConnection(w http.ResponseWriter, r *http.Request) { return } - if connGuid, ok := db.Handler.GetByParams(&dbConnInfo); !ok { - errorResponce(w, "Failed to get SQL connection", http.StatusInternalServerError) - } else if _, err := w.Write([]byte(connGuid)); err != nil { + connGuid, err := db.Handler.GetByParams(&dbConnInfo) + if err != nil { + if errors.Is(err, db.ErrConnectionLimit) { + errorResponce(w, err.Error(), http.StatusTooManyRequests) + } else { + errorResponce(w, err.Error(), http.StatusInternalServerError) + } + return + } + + if _, err = w.Write([]byte(connGuid)); err != nil { errorResponce(w, err.Error(), http.StatusInternalServerError) } diff --git a/src/main.go b/src/main.go index 46a6804..471a2c8 100644 --- a/src/main.go +++ b/src/main.go @@ -42,6 +42,8 @@ func (p *program) run() { } bindPort := app.GetEnvInt("BIND_PORT", 8080) db.MaxRows = uint32(app.GetEnvInt("MAX_ROWS", 10000)) + db.MaxConnections = app.GetEnvInt("MAX_CONNECTIONS", 100) + tlsCert := app.GetEnvString("TLS_CERT", "") tlsKey := app.GetEnvString("TLS_KEY", "") @@ -67,10 +69,11 @@ func (p *program) run() { router.HandleFunc("/livez", handlers.Livez).Methods("GET") router.Handle("/metrics", promhttp.Handler()) - app.Logger.Info("(c) 2025 Almaz Sharipov, MIT license, https://github.com/alm494/sql_proxy ") + app.Logger.Info("(c) 2025-2026 Almaz Sharipov, MIT license, https://github.com/alm494/sql_proxy ") app.Logger.Infof("build_version=%s, build_time=%s", app.BuildVersion, app.BuildTime) app.Logger.Infof("Server started with the following parameters: "+ - "bind_port=%d, bind_address=%s, tls_cert=%s, tls_key=%s", bindPort, bindAddress, tlsCert, tlsKey) + "bind_port=%d, bind_address=%s, max_connections=%d, max_rows=%d, tls_cert=%s, tls_key=%s", + bindPort, bindAddress, db.MaxConnections, db.MaxRows, tlsCert, tlsKey) addr := fmt.Sprintf("%s:%d", bindAddress, bindPort)