From 1fff221da9e1ec5d6862aa3a3782eee2c77c78c8 Mon Sep 17 00:00:00 2001 From: Miguel de la Cruz Date: Mon, 27 Mar 2023 13:56:39 +0200 Subject: [PATCH] Adds escaping when normalizing table names for MySQL on DB helpers (#4653) Co-authored-by: Mattermost Build --- server/services/store/sqlstore/migrate.go | 32 +++++++++++++---------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/server/services/store/sqlstore/migrate.go b/server/services/store/sqlstore/migrate.go index d3fd7e3c5..da48b3a5c 100644 --- a/server/services/store/sqlstore/migrate.go +++ b/server/services/store/sqlstore/migrate.go @@ -319,7 +319,7 @@ func (s *SQLStore) GetTemplateHelperFuncs() template.FuncMap { func (s *SQLStore) genAddColumnIfNeeded(tableName, columnName, datatype, constraint string) (string, error) { tableName = addPrefixIfNeeded(tableName, s.tablePrefix) - normTableName := normalizeTablename(s.schemaName, tableName) + normTableName := s.normalizeTablename(tableName) switch s.dbType { case model.SqliteDBType: @@ -358,7 +358,7 @@ func (s *SQLStore) genAddColumnIfNeeded(tableName, columnName, datatype, constra func (s *SQLStore) genDropColumnIfNeeded(tableName, columnName string) (string, error) { tableName = addPrefixIfNeeded(tableName, s.tablePrefix) - normTableName := normalizeTablename(s.schemaName, tableName) + normTableName := s.normalizeTablename(tableName) switch s.dbType { case model.SqliteDBType: @@ -395,7 +395,7 @@ func (s *SQLStore) genDropColumnIfNeeded(tableName, columnName string) (string, func (s *SQLStore) genCreateIndexIfNeeded(tableName, columns string) (string, error) { indexName := getIndexName(tableName, columns) tableName = addPrefixIfNeeded(tableName, s.tablePrefix) - normTableName := normalizeTablename(s.schemaName, tableName) + normTableName := s.normalizeTablename(tableName) switch s.dbType { case model.SqliteDBType: @@ -435,7 +435,7 @@ func (s *SQLStore) genRenameTableIfNeeded(oldTableName, newTableName string) (st oldTableName = addPrefixIfNeeded(oldTableName, s.tablePrefix) newTableName = addPrefixIfNeeded(newTableName, s.tablePrefix) - normOldTableName := normalizeTablename(s.schemaName, oldTableName) + normOldTableName := s.normalizeTablename(oldTableName) vars := map[string]string{ "schema": s.schemaName, @@ -466,14 +466,14 @@ func (s *SQLStore) genRenameTableIfNeeded(oldTableName, newTableName string) (st case model.PostgresDBType: return replaceVars(` do $$ - begin + begin if (SELECT COUNT(table_name) FROM INFORMATION_SCHEMA.TABLES WHERE table_name = '[[new_table_name]]' AND table_schema = '[[schema]]' - ) = 0 then + ) = 0 then ALTER TABLE [[norm_old_table_name]] RENAME TO [[new_table_name]]; end if; - end$$; + end$$; `, vars), nil default: return "", ErrUnsupportedDatabaseType @@ -482,7 +482,7 @@ func (s *SQLStore) genRenameTableIfNeeded(oldTableName, newTableName string) (st func (s *SQLStore) genRenameColumnIfNeeded(tableName, oldColumnName, newColumnName, dataType string) (string, error) { tableName = addPrefixIfNeeded(tableName, s.tablePrefix) - normTableName := normalizeTablename(s.schemaName, tableName) + normTableName := s.normalizeTablename(tableName) vars := map[string]string{ "schema": s.schemaName, @@ -516,15 +516,15 @@ func (s *SQLStore) genRenameColumnIfNeeded(tableName, oldColumnName, newColumnNa case model.PostgresDBType: return replaceVars(` do $$ - begin + begin if (SELECT COUNT(table_name) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = '[[table_name]]' AND table_schema = '[[schema]]' AND column_name = '[[new_column_name]]' - ) = 0 then + ) = 0 then ALTER TABLE [[norm_table_name]] RENAME COLUMN [[old_column_name]] TO [[new_column_name]]; end if; - end$$; + end$$; `, vars), nil default: return "", ErrUnsupportedDatabaseType @@ -620,7 +620,7 @@ func (s *SQLStore) doesColumnExist(tableName, columnName string) (bool, error) { func (s *SQLStore) genAddConstraintIfNeeded(tableName, constraintName, constraintType, constraintDefinition string) (string, error) { tableName = addPrefixIfNeeded(tableName, s.tablePrefix) - normTableName := normalizeTablename(s.schemaName, tableName) + normTableName := s.normalizeTablename(tableName) var query string @@ -686,8 +686,12 @@ func addPrefixIfNeeded(s, prefix string) string { return s } -func normalizeTablename(schemaName, tableName string) string { - if schemaName != "" && !strings.HasPrefix(tableName, schemaName+".") { +func (s *SQLStore) normalizeTablename(tableName string) string { + if s.schemaName != "" && !strings.HasPrefix(tableName, s.schemaName+".") { + schemaName := s.schemaName + if s.dbType == model.MysqlDBType { + schemaName = "`" + schemaName + "`" + } tableName = schemaName + "." + tableName } return tableName