1
0
mirror of https://github.com/ManyakRus/crud_generator.git synced 2025-01-24 21:32:42 +02:00
2024-03-06 16:03:41 +03:00

284 lines
6.5 KiB
Go

package postgres
import (
"database/sql"
"fmt"
"regexp"
"strconv"
"strings"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/stdlib"
"gorm.io/gorm"
"gorm.io/gorm/callbacks"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
"gorm.io/gorm/migrator"
"gorm.io/gorm/schema"
)
// Dialector is the PostgreSQL dialector for GORM. It embeds *Config, so every
// configuration field (DSN, DriverName, Conn, ...) is reachable directly on
// the dialector value. Open and New return a *Dialector as a gorm.Dialector.
type Dialector struct {
*Config
}
// Config holds the options consulted by Dialector.
type Config struct {
// DriverName, when non-empty (and Conn is nil), is the database/sql driver
// name passed to sql.Open by Initialize.
DriverName string
// DSN is the connection string handed to sql.Open or pgx.ParseConfig.
DSN string
// WithoutQuotingCheck makes QuoteTo write identifiers verbatim, without
// adding or escaping double quotes.
WithoutQuotingCheck bool
// PreferSimpleProtocol switches the pgx connection config to
// pgx.QueryExecModeSimpleProtocol. It is only read on the pgx code path
// of Initialize (no Conn and no DriverName).
PreferSimpleProtocol bool
// WithoutReturning stops Initialize from appending the "RETURNING" clause
// to the create, update and delete clause lists.
WithoutReturning bool
// Conn, when non-nil, is used as the connection pool as-is; no new
// connection is opened.
Conn gorm.ConnPool
}
var (
// timeZoneMatcher locates a "time_zone=" or "TimeZone=" setting inside a
// DSN. Capture group 2 is the value, which ends at the end of the string,
// at "&", or at a space.
timeZoneMatcher = regexp.MustCompile("(time_zone|TimeZone)=(.*?)($|&| )")
// defaultIdentifierLength is the IdentifierMaxLength that Apply sets on the
// naming strategy when none (or a non-positive one) is configured.
defaultIdentifierLength = 63 //maximum identifier length for postgres
)
// Open returns a PostgreSQL dialector whose configuration contains only the
// given DSN; every other Config field keeps its zero value.
func Open(dsn string) gorm.Dialector {
	cfg := &Config{DSN: dsn}
	return &Dialector{Config: cfg}
}
// New returns a PostgreSQL dialector backed by a copy of config. The argument
// is received by value, so later changes to the caller's Config are not seen
// by the dialector.
func New(config Config) gorm.Dialector {
	d := new(Dialector)
	d.Config = &config
	return d
}
// Name returns the dialect identifier, "postgres".
func (dialector Dialector) Name() string {
	const dialectName = "postgres"
	return dialectName
}
// Apply makes sure the configured naming strategy has an identifier length
// limit, defaulting it to defaultIdentifierLength when it is unset or not
// positive.
//
// Handling per strategy kind:
//   - nil interface or nil *schema.NamingStrategy: a fresh schema.NamingStrategy
//     with the default limit is installed;
//   - schema.NamingStrategy / *schema.NamingStrategy: the limit is defaulted and
//     the strategy is stored back as a pointer;
//   - any other schema.Namer implementation: left untouched, because it has no
//     IdentifierMaxLength field to adjust.
//
// It always returns nil.
func (dialector Dialector) Apply(config *gorm.Config) error {
	var namingStrategy *schema.NamingStrategy
	switch v := config.NamingStrategy.(type) {
	case *schema.NamingStrategy:
		namingStrategy = v
	case schema.NamingStrategy:
		namingStrategy = &v
	case nil:
		// Nothing configured; fall through to the nil check below.
	default:
		// Custom namer: previously this path dereferenced a nil pointer.
		return nil
	}
	if namingStrategy == nil {
		// Covers both the nil interface and a typed-nil pointer.
		namingStrategy = &schema.NamingStrategy{}
	}
	if namingStrategy.IdentifierMaxLength <= 0 {
		namingStrategy.IdentifierMaxLength = defaultIdentifierLength
	}
	config.NamingStrategy = namingStrategy
	return nil
}
// Initialize registers GORM's default callbacks with PostgreSQL clause lists
// and sets db.ConnPool.
//
// The "RETURNING" clause is appended to the create, update and delete clause
// lists unless WithoutReturning is set.
//
// The connection pool is chosen in this order:
//  1. Conn, when non-nil, is used as-is;
//  2. otherwise, when DriverName is non-empty, sql.Open(DriverName, DSN);
//  3. otherwise the DSN is parsed with pgx.ParseConfig and opened through
//     stdlib.OpenDB. On this path PreferSimpleProtocol selects the simple
//     query protocol, and a time_zone/TimeZone setting found in the DSN is
//     copied into the "timezone" runtime parameter.
//
// On error db.ConnPool is left unchanged and the error is returned.
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
	callbackConfig := &callbacks.Config{
		CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT"},
		UpdateClauses: []string{"UPDATE", "SET", "FROM", "WHERE"},
		DeleteClauses: []string{"DELETE", "FROM", "WHERE"},
	}
	// register callbacks
	if !dialector.WithoutReturning {
		callbackConfig.CreateClauses = append(callbackConfig.CreateClauses, "RETURNING")
		callbackConfig.UpdateClauses = append(callbackConfig.UpdateClauses, "RETURNING")
		callbackConfig.DeleteClauses = append(callbackConfig.DeleteClauses, "RETURNING")
	}
	callbacks.RegisterDefaultCallbacks(db, callbackConfig)

	switch {
	case dialector.Conn != nil:
		db.ConnPool = dialector.Conn
		return nil
	case dialector.DriverName != "":
		sqlDB, openErr := sql.Open(dialector.DriverName, dialector.Config.DSN)
		if openErr != nil {
			// Do not store the nil *sql.DB: wrapped in the ConnPool
			// interface it would compare as non-nil.
			return openErr
		}
		db.ConnPool = sqlDB
		return nil
	}

	config, parseErr := pgx.ParseConfig(dialector.Config.DSN)
	if parseErr != nil {
		return parseErr
	}
	if dialector.Config.PreferSimpleProtocol {
		config.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol
	}
	// result[2] is the value captured after "time_zone=" / "TimeZone=".
	result := timeZoneMatcher.FindStringSubmatch(dialector.Config.DSN)
	if len(result) > 2 {
		config.RuntimeParams["timezone"] = result[2]
	}
	db.ConnPool = stdlib.OpenDB(*config)
	return nil
}
// Migrator returns the package's Migrator wrapping a generic migrator.Migrator
// that is bound to db and this dialector, with CreateIndexAfterCreateTable
// enabled.
func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
	base := migrator.Migrator{
		Config: migrator.Config{
			CreateIndexAfterCreateTable: true,
			Dialector:                   dialector,
			DB:                          db,
		},
	}
	return Migrator{base}
}
// DefaultValueOf returns the SQL expression DEFAULT for every field; the field
// argument is not inspected.
func (dialector Dialector) DefaultValueOf(field *schema.Field) clause.Expression {
	var expr clause.Expr
	expr.SQL = "DEFAULT"
	return expr
}
// BindVarTo writes a numbered placeholder ("$N") to writer. N is the current
// length of stmt.Vars, reduced by one when the first element of stmt.Vars is
// a pgx.QueryExecMode (that leading value is an option rather than a bound
// argument, so it must not consume a placeholder number). The value v itself
// is not inspected.
// NOTE(review): presumably the caller has already appended v to stmt.Vars
// before calling this method — confirm against gorm's Statement.AddVar.
func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) {
	writer.WriteByte('$')
	position := len(stmt.Vars)
	if position > 0 {
		if _, leadingExecMode := stmt.Vars[0].(pgx.QueryExecMode); leadingExecMode {
			position--
		}
	}
	writer.WriteString(strconv.Itoa(position))
}
// QuoteTo writes str to writer as a double-quoted identifier.
//
// When WithoutQuotingCheck is set, str is written unchanged. Otherwise str is
// scanned byte by byte: ordinary bytes are emitted inside double quotes, a '.'
// normally ends the current quoted part so the next part gets its own quotes
// (for example a.b becomes "a"."b"), and '"' bytes already present in str are
// counted so they can be either treated as the caller's own opening quote or
// emitted doubled (""). A closing '"' is always written at the end.
//
// NOTE(review): the handling of embedded quotes depends on the exact statement
// order below; verify with tests before restructuring. The counters are int8,
// so inputs longer than 127 bytes wrap shiftDelimiter — confirm this is benign.
func (dialector Dialector) QuoteTo(writer clause.Writer, str string) {
if dialector.WithoutQuotingCheck {
writer.WriteString(str)
return
}
var (
// underQuoted: an opening quote has been written for the current part.
// selfQuoted: the current part was opened by a '"' that came from str.
underQuoted, selfQuoted bool
// continuousBacktick counts consecutive '"' bytes seen and not yet
// emitted (despite the name, it tracks double quotes, not backticks).
continuousBacktick int8
// shiftDelimiter counts bytes processed since the start of the current
// part; it is reset when a '.' closes a part.
shiftDelimiter int8
)
for _, v := range []byte(str) {
switch v {
case '"':
continuousBacktick++
// Two quotes in a row are emitted immediately as an escaped quote.
if continuousBacktick == 2 {
writer.WriteString(`""`)
continuousBacktick = 0
}
case '.':
// Close the current part unless it was opened by the caller's own quote
// and no quote is pending.
if continuousBacktick > 0 || !selfQuoted {
shiftDelimiter = 0
underQuoted = false
continuousBacktick = 0
writer.WriteByte('"')
}
writer.WriteByte(v)
// Skip the shiftDelimiter increment so the next part starts at zero.
continue
default:
// At the start of a part (no opening quote written yet) emit the opening
// quote; a pending '"' from str is consumed as that opening quote.
if shiftDelimiter-continuousBacktick <= 0 && !underQuoted {
writer.WriteByte('"')
underQuoted = true
if selfQuoted = continuousBacktick > 0; selfQuoted {
continuousBacktick -= 1
}
}
// Any quote still pending is emitted in doubled (escaped) form.
for ; continuousBacktick > 0; continuousBacktick -= 1 {
writer.WriteString(`""`)
}
writer.WriteByte(v)
}
shiftDelimiter++
}
// A trailing pending quote that did not belong to caller-supplied quoting is
// emitted in doubled form before the closing quote.
if continuousBacktick > 0 && !selfQuoted {
writer.WriteString(`""`)
}
writer.WriteByte('"')
}
// numericPlaceholder matches a numbered placeholder such as $1 or $12,
// capturing the digits; Explain passes it to logger.ExplainSQL.
var numericPlaceholder = regexp.MustCompile(`\$(\d+)`)
// Explain returns the result of logger.ExplainSQL for the given statement and
// arguments, using the $N placeholder pattern and a single quote as the
// quoting string.
func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
	explained := logger.ExplainSQL(sql, numericPlaceholder, `'`, vars...)
	return explained
}
// DataTypeOf maps a schema field to the PostgreSQL column type name.
//
//   - Bool: boolean
//   - Int/Uint: smallint/integer/bigint chosen by field.Size (<=16, <=32,
//     otherwise), or smallserial/serial/bigserial when the field
//     auto-increments. Uint sizes are bumped by one before the comparison, so
//     an unsigned field whose size sits exactly on a boundary gets the next
//     larger type.
//   - Float: numeric(p, s), numeric(p), or decimal depending on whether
//     Precision and Scale are positive.
//   - String: varchar(n) when Size is positive, otherwise text.
//   - Time: timestamptz, with (p) when Precision is positive.
//   - Bytes: bytea
//   - anything else is delegated to getSchemaCustomType.
func (dialector Dialector) DataTypeOf(field *schema.Field) string {
	switch field.DataType {
	case schema.Bool:
		return "boolean"
	case schema.Int, schema.Uint:
		bits := field.Size
		if field.DataType == schema.Uint {
			bits++
		}
		// Index 0/1/2 correspond to the <=16, <=32 and larger size classes.
		names := [3]string{"smallint", "integer", "bigint"}
		if field.AutoIncrement {
			names = [3]string{"smallserial", "serial", "bigserial"}
		}
		if bits <= 16 {
			return names[0]
		}
		if bits <= 32 {
			return names[1]
		}
		return names[2]
	case schema.Float:
		switch {
		case field.Precision <= 0:
			return "decimal"
		case field.Scale > 0:
			return fmt.Sprintf("numeric(%d, %d)", field.Precision, field.Scale)
		default:
			return fmt.Sprintf("numeric(%d)", field.Precision)
		}
	case schema.String:
		if field.Size <= 0 {
			return "text"
		}
		return fmt.Sprintf("varchar(%d)", field.Size)
	case schema.Time:
		if field.Precision <= 0 {
			return "timestamptz"
		}
		return fmt.Sprintf("timestamptz(%d)", field.Precision)
	case schema.Bytes:
		return "bytea"
	default:
		return dialector.getSchemaCustomType(field)
	}
}
// getSchemaCustomType returns the column type for a field with a custom data
// type. The declared type is returned as-is unless the field auto-increments
// and the declared type does not already mention "serial" (case-insensitive);
// in that case a serial type is picked from field.Size (<=16 smallserial,
// <=32 serial, otherwise bigserial), with the size bumped by one when the
// field's GORM data type is unsigned.
func (dialector Dialector) getSchemaCustomType(field *schema.Field) string {
	declared := string(field.DataType)
	if !field.AutoIncrement || strings.Contains(strings.ToLower(declared), "serial") {
		return declared
	}

	bits := field.Size
	if field.GORMDataType == schema.Uint {
		bits++
	}
	if bits <= 16 {
		return "smallserial"
	}
	if bits <= 32 {
		return "serial"
	}
	return "bigserial"
}
// SavePoint executes "SAVEPOINT <name>" on tx and returns the error recorded
// by that execution, or nil on success. Previously the result was discarded
// and nil was always returned, so a failed SAVEPOINT looked successful.
//
// name is concatenated into the statement verbatim (not quoted or escaped),
// so it must be a valid, trusted identifier.
func (dialector Dialector) SavePoint(tx *gorm.DB, name string) error {
	return tx.Exec("SAVEPOINT " + name).Error
}
// RollbackTo executes "ROLLBACK TO SAVEPOINT <name>" on tx and returns the
// error recorded by that execution, or nil on success. Previously the result
// was discarded and nil was always returned, hiding failed rollbacks.
//
// name is concatenated into the statement verbatim (not quoted or escaped),
// so it must be a valid, trusted identifier.
func (dialector Dialector) RollbackTo(tx *gorm.DB, name string) error {
	return tx.Exec("ROLLBACK TO SAVEPOINT " + name).Error
}
// getSerialDatabaseType translates a serial pseudo-type name into the integer
// type name paired with it: smallserial -> smallint, serial -> integer,
// bigserial -> bigint. The match is exact (case-sensitive). For any other
// input it returns "" and false.
func getSerialDatabaseType(s string) (dbType string, ok bool) {
	switch s {
	case "smallserial":
		dbType = "smallint"
	case "serial":
		dbType = "integer"
	case "bigserial":
		dbType = "bigint"
	}
	// Every recognised name maps to a non-empty type, so emptiness signals
	// "not a serial type".
	return dbType, dbType != ""
}