diff --git a/pkg/database/sql/sql.go b/pkg/database/sql/sql.go index a59997264..f0406c002 100644 --- a/pkg/database/sql/sql.go +++ b/pkg/database/sql/sql.go @@ -4,7 +4,6 @@ import ( "context" "database/sql" "fmt" - "strings" "sync/atomic" "time" @@ -13,6 +12,7 @@ import ( "github.com/bilibili/kratos/pkg/net/netutil/breaker" "github.com/bilibili/kratos/pkg/net/trace" + "github.com/go-sql-driver/mysql" "github.com/pkg/errors" ) @@ -660,17 +660,12 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) { // parseDSNAddr parse dsn name and return addr. func parseDSNAddr(dsn string) (addr string) { - if dsn == "" { - return + cfg, err := mysql.ParseDSN(dsn) + if err != nil { + // just ignore parseDSN error, mysql client will return error for us when connect. + return "" } - part0 := strings.Split(dsn, "@") - if len(part0) > 1 { - part1 := strings.Split(part0[1], "?") - if len(part1) > 0 { - addr = part1[0] - } - } - return + return cfg.Addr } func slowLog(statement string, now time.Time) { diff --git a/pkg/database/sql/sql_test.go b/pkg/database/sql/sql_test.go new file mode 100644 index 000000000..191ab7937 --- /dev/null +++ b/pkg/database/sql/sql_test.go @@ -0,0 +1,18 @@ +package sql + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestParseAddrDSN(t *testing.T) { + t.Run("test parse addr dsn", func(t *testing.T) { + addr := parseDSNAddr("test:test@tcp(172.16.0.148:3306)/test?timeout=5s&readTimeout=5s&writeTimeout=5s&parseTime=true&loc=Local&charset=utf8") + assert.Equal(t, "172.16.0.148:3306", addr) + }) + t.Run("test password has @", func(t *testing.T) { + addr := parseDSNAddr("root:root@dev@tcp(1.2.3.4:3306)/abc?timeout=1s&readTimeout=1s&writeTimeout=1s&parseTime=true&loc=Local&charset=utf8mb4,utf8") + assert.Equal(t, "1.2.3.4:3306", addr) + }) +}