mirror of
https://github.com/ManyakRus/crud_generator.git
synced 2024-12-23 12:44:13 +02:00
250 lines
5.7 KiB
Go
250 lines
5.7 KiB
Go
package postgres
|
|
|
|
import (
|
|
"database/sql"
|
|
"fmt"
|
|
"github.com/jackc/pgx/v5"
|
|
"regexp"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/jackc/pgx/v5/stdlib"
|
|
"gorm.io/gorm"
|
|
"gorm.io/gorm/callbacks"
|
|
"gorm.io/gorm/clause"
|
|
"gorm.io/gorm/logger"
|
|
"gorm.io/gorm/migrator"
|
|
"gorm.io/gorm/schema"
|
|
)
|
|
|
|
type Dialector struct {
|
|
*Config
|
|
}
|
|
|
|
type Config struct {
|
|
DriverName string
|
|
DSN string
|
|
PreferSimpleProtocol bool
|
|
WithoutReturning bool
|
|
Conn gorm.ConnPool
|
|
}
|
|
|
|
func Open(dsn string) gorm.Dialector {
|
|
return &Dialector{&Config{DSN: dsn}}
|
|
}
|
|
|
|
func New(config Config) gorm.Dialector {
|
|
return &Dialector{Config: &config}
|
|
}
|
|
|
|
func (dialector Dialector) Name() string {
|
|
return "postgres"
|
|
}
|
|
|
|
// timeZoneMatcher finds a "time_zone=" or "TimeZone=" setting inside a DSN.
// Capture group 2 is the value, which ends at the first '&', the first
// space, or the end of the string.
var timeZoneMatcher = regexp.MustCompile(`(time_zone|TimeZone)=(.*?)($|&| )`)
|
|
|
|
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
|
|
callbackConfig := &callbacks.Config{
|
|
CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT"},
|
|
UpdateClauses: []string{"UPDATE", "SET", "WHERE"},
|
|
DeleteClauses: []string{"DELETE", "FROM", "WHERE"},
|
|
}
|
|
// register callbacks
|
|
if !dialector.WithoutReturning {
|
|
callbackConfig.CreateClauses = append(callbackConfig.CreateClauses, "RETURNING")
|
|
callbackConfig.UpdateClauses = append(callbackConfig.UpdateClauses, "RETURNING")
|
|
callbackConfig.DeleteClauses = append(callbackConfig.DeleteClauses, "RETURNING")
|
|
}
|
|
callbacks.RegisterDefaultCallbacks(db, callbackConfig)
|
|
|
|
if dialector.Conn != nil {
|
|
db.ConnPool = dialector.Conn
|
|
} else if dialector.DriverName != "" {
|
|
db.ConnPool, err = sql.Open(dialector.DriverName, dialector.Config.DSN)
|
|
} else {
|
|
var config *pgx.ConnConfig
|
|
|
|
config, err = pgx.ParseConfig(dialector.Config.DSN)
|
|
if err != nil {
|
|
return
|
|
}
|
|
if dialector.Config.PreferSimpleProtocol {
|
|
config.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol
|
|
}
|
|
result := timeZoneMatcher.FindStringSubmatch(dialector.Config.DSN)
|
|
if len(result) > 2 {
|
|
config.RuntimeParams["timezone"] = result[2]
|
|
}
|
|
db.ConnPool = stdlib.OpenDB(*config)
|
|
}
|
|
return
|
|
}
|
|
|
|
func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
|
|
return Migrator{migrator.Migrator{Config: migrator.Config{
|
|
DB: db,
|
|
Dialector: dialector,
|
|
CreateIndexAfterCreateTable: true,
|
|
}}}
|
|
}
|
|
|
|
func (dialector Dialector) DefaultValueOf(field *schema.Field) clause.Expression {
|
|
return clause.Expr{SQL: "DEFAULT"}
|
|
}
|
|
|
|
func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) {
|
|
writer.WriteByte('$')
|
|
writer.WriteString(strconv.Itoa(len(stmt.Vars)))
|
|
}
|
|
|
|
func (dialector Dialector) QuoteTo(writer clause.Writer, str string) {
|
|
var (
|
|
underQuoted, selfQuoted bool
|
|
continuousBacktick int8
|
|
shiftDelimiter int8
|
|
)
|
|
|
|
for _, v := range []byte(str) {
|
|
switch v {
|
|
case '"':
|
|
continuousBacktick++
|
|
if continuousBacktick == 2 {
|
|
writer.WriteString(`""`)
|
|
continuousBacktick = 0
|
|
}
|
|
case '.':
|
|
if continuousBacktick > 0 || !selfQuoted {
|
|
shiftDelimiter = 0
|
|
underQuoted = false
|
|
continuousBacktick = 0
|
|
writer.WriteByte('"')
|
|
}
|
|
writer.WriteByte(v)
|
|
continue
|
|
default:
|
|
if shiftDelimiter-continuousBacktick <= 0 && !underQuoted {
|
|
writer.WriteByte('"')
|
|
underQuoted = true
|
|
if selfQuoted = continuousBacktick > 0; selfQuoted {
|
|
continuousBacktick -= 1
|
|
}
|
|
}
|
|
|
|
for ; continuousBacktick > 0; continuousBacktick -= 1 {
|
|
writer.WriteString(`""`)
|
|
}
|
|
|
|
writer.WriteByte(v)
|
|
}
|
|
shiftDelimiter++
|
|
}
|
|
|
|
if continuousBacktick > 0 && !selfQuoted {
|
|
writer.WriteString(`""`)
|
|
}
|
|
writer.WriteByte('"')
|
|
}
|
|
|
|
// numericPlaceholder matches a numbered bind placeholder such as "$1" or
// "$12"; capture group 1 holds the digits.
var numericPlaceholder = regexp.MustCompile("\\$(\\d+)")
|
|
|
|
func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
|
|
return logger.ExplainSQL(sql, numericPlaceholder, `'`, vars...)
|
|
}
|
|
|
|
func (dialector Dialector) DataTypeOf(field *schema.Field) string {
|
|
switch field.DataType {
|
|
case schema.Bool:
|
|
return "boolean"
|
|
case schema.Int, schema.Uint:
|
|
size := field.Size
|
|
if field.DataType == schema.Uint {
|
|
size++
|
|
}
|
|
if field.AutoIncrement {
|
|
switch {
|
|
case size <= 16:
|
|
return "smallserial"
|
|
case size <= 32:
|
|
return "serial"
|
|
default:
|
|
return "bigserial"
|
|
}
|
|
} else {
|
|
switch {
|
|
case size <= 16:
|
|
return "smallint"
|
|
case size <= 32:
|
|
return "integer"
|
|
default:
|
|
return "bigint"
|
|
}
|
|
}
|
|
case schema.Float:
|
|
if field.Precision > 0 {
|
|
if field.Scale > 0 {
|
|
return fmt.Sprintf("numeric(%d, %d)", field.Precision, field.Scale)
|
|
}
|
|
return fmt.Sprintf("numeric(%d)", field.Precision)
|
|
}
|
|
return "decimal"
|
|
case schema.String:
|
|
if field.Size > 0 {
|
|
return fmt.Sprintf("varchar(%d)", field.Size)
|
|
}
|
|
return "text"
|
|
case schema.Time:
|
|
if field.Precision > 0 {
|
|
return fmt.Sprintf("timestamptz(%d)", field.Precision)
|
|
}
|
|
return "timestamptz"
|
|
case schema.Bytes:
|
|
return "bytea"
|
|
default:
|
|
return dialector.getSchemaCustomType(field)
|
|
}
|
|
}
|
|
|
|
func (dialector Dialector) getSchemaCustomType(field *schema.Field) string {
|
|
sqlType := string(field.DataType)
|
|
|
|
if field.AutoIncrement && !strings.Contains(strings.ToLower(sqlType), "serial") {
|
|
size := field.Size
|
|
if field.GORMDataType == schema.Uint {
|
|
size++
|
|
}
|
|
switch {
|
|
case size <= 16:
|
|
sqlType = "smallserial"
|
|
case size <= 32:
|
|
sqlType = "serial"
|
|
default:
|
|
sqlType = "bigserial"
|
|
}
|
|
}
|
|
|
|
return sqlType
|
|
}
|
|
|
|
func (dialector Dialector) SavePoint(tx *gorm.DB, name string) error {
|
|
tx.Exec("SAVEPOINT " + name)
|
|
return nil
|
|
}
|
|
|
|
func (dialector Dialector) RollbackTo(tx *gorm.DB, name string) error {
|
|
tx.Exec("ROLLBACK TO SAVEPOINT " + name)
|
|
return nil
|
|
}
|
|
|
|
// getSerialDatabaseType translates a serial pseudo-type name into the integer
// type that backs it: smallserial -> smallint, serial -> integer,
// bigserial -> bigint. For any other input it returns "" and false. The
// comparison is exact (case-sensitive, no trimming).
func getSerialDatabaseType(s string) (dbType string, ok bool) {
	switch s {
	case "smallserial":
		dbType = "smallint"
	case "serial":
		dbType = "integer"
	case "bigserial":
		dbType = "bigint"
	}
	// dbType is non-empty only when one of the cases above matched.
	return dbType, dbType != ""
}
|