mirror of https://github.com/ManyakRus/crud_generator.git synced 2025-03-11 14:59:21 +02:00

Implemented NEED_CREATE_CACHE_API

Nikitin Aleksandr 2024-03-06 16:03:41 +03:00
parent 55ca8fa57d
commit c9fa7505f3
164 changed files with 2716 additions and 1267 deletions


@ -14,7 +14,7 @@ run:
./bin/$(SERVICENAME)
mod:
clear
go get -u ./...
go get -u ./cmd/crud_generator/... ./internal/... ./pkg/...
go mod tidy -compat=1.18
go mod vendor
go fmt ./...

go.mod

@ -3,7 +3,7 @@ module github.com/ManyakRus/crud_generator
go 1.20
require (
github.com/ManyakRus/starter v0.0.0-20231227074038-f1cc2e5171fa
github.com/ManyakRus/starter v1.0.10
github.com/bxcodec/faker/v3 v3.8.1
github.com/davecgh/go-spew v1.1.1
github.com/iancoleman/strcase v0.3.0
@ -13,22 +13,23 @@ require (
github.com/ompluscator/dynamic-struct v1.4.0
github.com/otiai10/copy v1.14.0
github.com/serenize/snaker v0.0.0-20201027110005-a7ad2135616e
golang.org/x/tools v0.16.1
gorm.io/gorm v1.25.5
golang.org/x/tools v0.19.0
gorm.io/gorm v1.25.7
)
require (
github.com/ManyakRus/logrus v0.0.0-20231019115155-9e6fede0d792 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 // indirect
github.com/jackc/pgx/v5 v5.5.1 // indirect
github.com/jackc/pgx/v5 v5.5.4 // indirect
github.com/jackc/puddle/v2 v2.2.1 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/joho/godotenv v1.5.1 // indirect
golang.org/x/crypto v0.17.0 // indirect
golang.org/x/mod v0.14.0 // indirect
golang.org/x/sync v0.5.0 // indirect
golang.org/x/sys v0.15.0 // indirect
golang.org/x/crypto v0.21.0 // indirect
golang.org/x/mod v0.16.0 // indirect
golang.org/x/sync v0.6.0 // indirect
golang.org/x/sys v0.18.0 // indirect
golang.org/x/text v0.14.0 // indirect
gorm.io/driver/postgres v1.5.4 // indirect
google.golang.org/protobuf v1.33.0 // indirect
gorm.io/driver/postgres v1.5.6 // indirect
)

go.sum

@ -3,8 +3,8 @@ github.com/Azure/azure-sdk-for-go/sdk/azidentity v0.11.0/go.mod h1:HcM1YX14R7CJc
github.com/Azure/azure-sdk-for-go/sdk/internal v0.7.0/go.mod h1:yqy467j36fJxcRV2TzfVZ1pCb5vxm4BtZPUdYWe/Xo8=
github.com/ManyakRus/logrus v0.0.0-20231019115155-9e6fede0d792 h1:bxwxD0H3kSUAH3uNk/b74gkImcUiP7dyibmMoVwk338=
github.com/ManyakRus/logrus v0.0.0-20231019115155-9e6fede0d792/go.mod h1:OUyxCVbPW/2lC1e6cM7Am941SJiC88BhNnb24x2R3a8=
github.com/ManyakRus/starter v0.0.0-20231227074038-f1cc2e5171fa h1:v+rN/vdH1ODCja4iTka/3QToxu4ey76u76pkFHcSfOI=
github.com/ManyakRus/starter v0.0.0-20231227074038-f1cc2e5171fa/go.mod h1:1fRj4AUMGeQTtnwBa52pvMd9zwqPDms+uaxozhHkM1Q=
github.com/ManyakRus/starter v1.0.10 h1:7KMmVKEi7uogtg6+Z2RQEq08i4ynUgq+dWPSeUfAwGw=
github.com/ManyakRus/starter v1.0.10/go.mod h1:bIpOiyctURuRRuEJs0xCsBPvpoZFbpFEPX37wTvtuOI=
github.com/bxcodec/faker/v3 v3.8.1 h1:qO/Xq19V6uHt2xujwpaetgKhraGCapqY2CRWGD/SqcM=
github.com/bxcodec/faker/v3 v3.8.1/go.mod h1:DdSDccxF5msjFo5aO4vrobRQ8nIApg8kq3QWPEQD6+o=
github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI=
@ -62,8 +62,8 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA=
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.5.1 h1:5I9etrGkLrN+2XPCsi6XLlV5DITbSL/xBZdmAxFcXPI=
github.com/jackc/pgx/v5 v5.5.1/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA=
github.com/jackc/pgx/v5 v5.5.4 h1:Xp2aQS8uXButQdnCMWNmvx6UysWQQC+u1EoizjguY+8=
github.com/jackc/pgx/v5 v5.5.4/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A=
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/jimsmart/schema v0.2.1 h1:MsSsqq0i86bUskhJJZ6RnrgscbDeBMalLZym6Hx9l3U=
@ -135,15 +135,15 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y
golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw=
golang.org/x/crypto v0.5.0/go.mod h1:NK/OQwhpMQP3MwtdjgLlYHnH9ebylxKWv3e0fK+mkQU=
golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k=
golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.6.0/go.mod h1:4mET923SAdbXp2ki8ey+zGs1SLqsuM2Y0uvdZR/fUNI=
golang.org/x/mod v0.7.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.14.0 h1:dGoOF9QVLYng8IHTm7BAyWqCqSheQ5pYWGhzW00YJr0=
golang.org/x/mod v0.14.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/mod v0.16.0 h1:QX4fJ0Rr5cPQCF7O9lh9Se4pmwfwskqZfq5moyldzic=
golang.org/x/mod v0.16.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
@ -162,7 +162,7 @@ golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY=
golang.org/x/net v0.3.0/go.mod h1:MBQ8lrhLObU/6UmLb4fmbmk5OcyYmqtbGd/9yIeKjEE=
golang.org/x/net v0.4.0/go.mod h1:MBQ8lrhLObU/6UmLb4fmbmk5OcyYmqtbGd/9yIeKjEE=
golang.org/x/net v0.5.0/go.mod h1:DivGGAXEgPSlEBzxGzZI+ZLohi+xUj054jfeKui00ws=
golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c=
golang.org/x/net v0.22.0 h1:9sGLhx7iRIHEiX0oAJ3MRZMUCElJgy7Br1nO+AMN3Tc=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@ -170,8 +170,8 @@ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20220513210516-0976fa681c29/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE=
golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ=
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@ -197,8 +197,8 @@ golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
@ -221,8 +221,8 @@ golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.2.0/go.mod h1:y4OqIKeOV/fWJetJ8bXPU1sEVniLMIyDAZWeHdV+NTA=
golang.org/x/tools v0.4.0/go.mod h1:UE5sM2OK9E/d67R0ANs2xJizIymRP5gJU295PvKXxjQ=
golang.org/x/tools v0.16.1 h1:TLyB3WofjdOEepBHAU20JdNC1Zbg87elYofWYAY5oZA=
golang.org/x/tools v0.16.1/go.mod h1:kYVVN6I1mBNoB1OX+noeBjbRk4IUEPa7JJ+TJMEooJ0=
golang.org/x/tools v0.19.0 h1:tfGCXNR1OsFG+sVdLAitlpjAvD/I6dHDKnYrpEZUHkw=
golang.org/x/tools v0.19.0/go.mod h1:qoJWxmGSIBmAeriMx19ogtrEPrGtDbPK634QFIcLAhc=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@ -238,7 +238,8 @@ google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQ
google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.32.0 h1:pPC6BG5ex8PDFnkbrGU3EixyhKcQ2aDuBS36lqK/C7I=
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
@ -252,8 +253,8 @@ gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/driver/postgres v1.5.4 h1:Iyrp9Meh3GmbSuyIAGyjkN+n9K+GHX9b9MqsTL4EJCo=
gorm.io/driver/postgres v1.5.4/go.mod h1:Bgo89+h0CRcdA33Y6frlaHHVuTdOf87pmyzwW9C/BH0=
gorm.io/driver/postgres v1.5.6 h1:ydr9xEd5YAM0vxVDY0X139dyzNz10spDiDlC7+ibLeU=
gorm.io/driver/postgres v1.5.6/go.mod h1:3e019WlBaYI5o5LIdNV+LyxCMNtLOQETBXL2h4chKpA=
gorm.io/gorm v1.25.4/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls=
gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
gorm.io/gorm v1.25.7 h1:VsD6acwRjz2zFxGO50gPO6AkNs7KKnvfzUjHQhZDz/A=
gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=


@ -3,26 +3,44 @@
package config_main
import (
"github.com/ManyakRus/starter/logger"
"github.com/ManyakRus/starter/log"
"github.com/ManyakRus/starter/micro"
"github.com/joho/godotenv"
"os"
"strings"
//log "github.com/sirupsen/logrus"
//log "github.com/sirupsen/logrus"
//"gitlab.aescorp.ru/dsp_dev/notifier/notifier_adp_eml/internal/v0/app/types"
//"gitlab.aescorp.ru/dsp_dev/notifier/notifier_adp_eml/internal/v0/app/micro"
)
// log stores the logger in use
var log = logger.GetLog()
// LoadEnv - loads variables from the .env file into environment variables
func LoadEnv() {
dir := micro.ProgramDir()
filename := dir + ".env"
LoadEnv_from_file(filename)
}
// LoadEnvTest - loads variables from the .env file into environment variables, except when STAGE=dev or prod
// intended for the _test.go test modules
func LoadEnvTest() {
dir := micro.ProgramDir()
filename := dir + ".env"
// do not load for STAGE=dev, since Kubernetes provides the environment variables
stage := os.Getenv("STAGE")
stage = strings.ToLower(stage)
stage = strings.TrimSpace(stage)
log.Info("STAGE: ", stage)
if stage == "dev" || stage == "prod" {
log.Info("LoadEnv() ignore STAGE: dev, filename: ", filename)
return
}
//
LoadEnv_from_file(filename)
}
// LoadEnv_err - loads variables from the .env file into environment variables; returns an error
func LoadEnv_err() error {
var err error
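
A minimal usage sketch for the new LoadEnvTest, assuming the import path below matches this repository's layout (it is not shown in this view):

```go
package example_test

import (
	"os"
	"testing"

	// Assumed import path; adjust to where config_main actually lives in this repository.
	"github.com/ManyakRus/crud_generator/internal/config_main"
)

// TestMain loads .env variables before running the tests; on STAGE=dev or prod
// LoadEnvTest skips the file and relies on the environment provided by the cluster.
func TestMain(m *testing.M) {
	config_main.LoadEnvTest()
	os.Exit(m.Run())
}
```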


@ -0,0 +1,16 @@
// Package constants stores persistent variables and constants
package constants
import (
"time"
)
var Loc = time.Local
// CONNECTION_ID - ID in the Rapira database, connections table
var CONNECTION_ID int64 = 3 //7
// BRANCH_ID - ID in the Rapira database, branches table
var BRANCH_ID int64 = 2 //20954
var TIME_ZONE = "Europe/Moscow"


@ -802,3 +802,9 @@ func Int64FromString(s string) (int64, error) {
return Otvet, err
}
// FindLastPos - returns the position of the last occurrence of TextFind in s
func FindLastPos(s, TextFind string) int {
Otvet := strings.LastIndex(s, TextFind)
return Otvet
}


@ -6,6 +6,7 @@ import (
"context"
"errors"
"fmt"
"github.com/ManyakRus/starter/constants"
"github.com/ManyakRus/starter/logger"
"github.com/ManyakRus/starter/port_checker"
"strings"
@ -38,12 +39,12 @@ var log = logger.GetLog()
// mutexReconnect - protects Reconnect() from concurrent use
var mutexReconnect = &sync.Mutex{}
// Settings stores all required environment variables
var Settings SettingsINI
// NeedReconnect - flag indicating that a reconnect is required
var NeedReconnect bool
// Settings stores all required environment variables
var Settings SettingsINI
// SettingsINI - structure that stores all required environment variables
type SettingsINI struct {
DB_HOST string
@ -97,7 +98,9 @@ func Connect_WithApplicationName_err(ApplicationName string) error {
dsn := GetDSN(ApplicationName)
//
conf := &gorm.Config{}
conf := &gorm.Config{
Logger: gormlogger.Default.LogMode(gormlogger.Silent),
}
//conn := postgres.Open(dsn)
dialect := postgres.New(postgres.Config{
@ -108,7 +111,7 @@ func Connect_WithApplicationName_err(ApplicationName string) error {
//Conn, err = gorm.Open(conn, conf)
Conn.Config.NamingStrategy = schema.NamingStrategy{TablePrefix: Settings.DB_SCHEMA + "."}
Conn.Config.Logger = gormlogger.Default.LogMode(gormlogger.Warn)
//Conn.Config.Logger = gormlogger.Default.LogMode(gormlogger.Error)
if err == nil {
DB, err := Conn.DB()
@ -334,7 +337,8 @@ func GetDSN(ApplicationName string) string {
dsn += "user=" + Settings.DB_USER + " "
dsn += "password=" + Settings.DB_PASSWORD + " "
dsn += "dbname=" + Settings.DB_NAME + " "
dsn += "port=" + Settings.DB_PORT + " sslmode=disable TimeZone=UTC "
dsn += "port=" + Settings.DB_PORT + " "
dsn += "sslmode=disable TimeZone=" + constants.TIME_ZONE + " "
dsn += "application_name=" + ApplicationName
return dsn
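
For illustration only, with made-up settings values, the DSN assembled above comes out in keyword/value form, with the time zone now taken from constants.TIME_ZONE:

```go
package main

import "fmt"

// Sketch of the DSN shape produced by GetDSN; host, user, password, dbname and port
// are placeholders, only the format mirrors the code above.
func main() {
	dsn := "host=localhost " +
		"user=postgres " +
		"password=secret " +
		"dbname=mydb " +
		"port=5432 " +
		"sslmode=disable TimeZone=Europe/Moscow " +
		"application_name=crud_generator"
	fmt.Println(dsn)
}
```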
@ -382,7 +386,11 @@ loop:
} else if NeedReconnect == true {
log.Warn("postgres_gorm CheckPort(", addr, ") OK. Start Reconnect()")
NeedReconnect = false
Connect()
err = Connect_err()
if err != nil {
NeedReconnect = true
log.Error("Connect_err() error: ", err)
}
}
}
}
@ -390,32 +398,72 @@ loop:
stopapp.GetWaitGroup_Main().Done()
}
//// RawMultipleSQL - executes the SQL text, running each statement separately
//func RawMultipleSQL(db *gorm.DB, TextSQL string) *gorm.DB {
// var tx *gorm.DB
// var err error
// tx = db
//
// // run each statement separately
// sqlSlice := strings.Split(TextSQL, ";")
// len1 := len(sqlSlice)
// for i, v := range sqlSlice {
// if i == len1-1 {
// tx = tx.Raw(v)
// err = tx.Error
// } else {
// tx = tx.Exec(v)
// err = tx.Error
// }
// if err != nil {
// TextError := fmt.Sprint("db.Raw() error: ", err, ", TextSQL: \n", v)
// err = errors.New(TextError)
// break
// }
// }
//
// if tx == nil {
// log.Panic("db.Raw() error: rows =nil")
// }
//
// return tx
//}
// RawMultipleSQL - executes the SQL text, running each statement separately
func RawMultipleSQL(db *gorm.DB, TextSQL string) *gorm.DB {
var tx *gorm.DB
var err error
tx = db
// run each statement separately
sqlSlice := strings.Split(TextSQL, ";")
len1 := len(sqlSlice)
for i, v := range sqlSlice {
if i == len1-1 {
tx = tx.Raw(v)
err = tx.Error
} else {
tx = tx.Exec(v)
err = tx.Error
}
if tx == nil {
log.Error("RawMultipleSQL() error: db =nil")
return tx
}
TextSQL1 := ""
TextSQL2 := TextSQL
// run all statements except the last one
pos1 := strings.LastIndex(TextSQL, ";")
if pos1 > 0 {
TextSQL1 = TextSQL[0:pos1]
TextSQL2 = TextSQL[pos1:]
tx = tx.Exec(TextSQL1)
err = tx.Error
if err != nil {
TextError := fmt.Sprint("db.Raw() error: ", err, ", TextSQL: \n", v)
TextError := fmt.Sprint("db.Exec() error: ", err, ", TextSQL: \n", TextSQL1)
err = errors.New(TextError)
break
return tx
}
}
if tx == nil {
log.Panic("db.Raw() error: rows =nil")
// run the last statement and return its result
tx = tx.Raw(TextSQL2)
err = tx.Error
if err != nil {
TextError := fmt.Sprint("db.Raw() error: ", err, ", TextSQL: \n", TextSQL2)
err = errors.New(TextError)
return tx
}
return tx
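
A usage sketch for the new RawMultipleSQL: everything up to the last ';' is run with Exec, and only the final statement goes through Raw so its result can be scanned. The SQL below is invented for illustration; Conn is this package's *gorm.DB shown above, so no imports are needed.

```go
// exampleRawMultipleSQL runs several statements in one call and scans only the
// result of the final SELECT (sketch, same package as RawMultipleSQL).
func exampleRawMultipleSQL() (int64, error) {
	textSQL := `
UPDATE users SET is_deleted = false WHERE id = 1;
SELECT count(*) FROM users`

	tx := RawMultipleSQL(Conn, textSQL)
	if tx.Error != nil {
		return 0, tx.Error
	}

	var count int64
	tx.Scan(&count) // executes the final statement and scans its result
	return count, tx.Error
}
```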


@ -1,3 +1,43 @@
# 5.5.4 (March 4, 2024)
Fix CVE-2024-27304
SQL injection can occur if an attacker can cause a single query or bind message to exceed 4 GB in size. An integer
overflow in the calculated message size can cause the one large message to be sent as multiple messages under the
attacker's control.
Thanks to Paul Gerste for reporting this issue.
* Fix behavior of CollectRows to return empty slice if Rows are empty (Felix)
* Fix simple protocol encoding of json.RawMessage
* Fix *Pipeline.getResults should close pipeline on error
* Fix panic in TryFindUnderlyingTypeScanPlan (David Kurman)
* Fix deallocation of invalidated cached statements in a transaction
* Handle invalid sslkey file
* Fix scan float4 into sql.Scanner
* Fix pgtype.Bits not making copy of data from read buffer. This would cause the data to be corrupted by future reads.
# 5.5.3 (February 3, 2024)
* Fix: prepared statement already exists
* Improve CopyFrom auto-conversion of text-ish values
* Add ltree type support (Florent Viel)
* Make some properties of Batch and QueuedQuery public (Pavlo Golub)
* Add AppendRows function (Edoardo Spadolini)
* Optimize convert UUID [16]byte to string (Kirill Malikov)
* Fix: LargeObject Read and Write of more than ~1GB at a time (Mitar)
# 5.5.2 (January 13, 2024)
* Allow NamedArgs to start with underscore
* pgproto3: Maximum message body length support (jeremy.spriet)
* Upgrade golang.org/x/crypto to v0.17.0
* Add snake_case support to RowToStructByName (Tikhon Fedulov)
* Fix: update description cache after exec prepare (James Hartig)
* Fix: pipeline checks if it is closed (James Hartig and Ryan Fowler)
* Fix: normalize timeout / context errors during TLS startup (Samuel Stauffer)
* Add OnPgError for easier centralized error handling (James Hartig)
# 5.5.1 (December 9, 2023)
* Add CopyFromFunc helper function. (robford)


@ -79,20 +79,11 @@ echo "listen_addresses = '127.0.0.1'" >> .testdb/$POSTGRESQL_DATA_DIR/postgresql
echo "port = $PGPORT" >> .testdb/$POSTGRESQL_DATA_DIR/postgresql.conf
cat testsetup/postgresql_ssl.conf >> .testdb/$POSTGRESQL_DATA_DIR/postgresql.conf
cp testsetup/pg_hba.conf .testdb/$POSTGRESQL_DATA_DIR/pg_hba.conf
cp testsetup/ca.cnf .testdb
cp testsetup/localhost.cnf .testdb
cp testsetup/pgx_sslcert.cnf .testdb
cd .testdb
# Generate a CA public / private key pair.
openssl genrsa -out ca.key 4096
openssl req -x509 -config ca.cnf -new -nodes -key ca.key -sha256 -days 365 -subj '/O=pgx-test-root' -out ca.pem
# Generate the certificate for localhost (the server).
openssl genrsa -out localhost.key 2048
openssl req -new -config localhost.cnf -key localhost.key -out localhost.csr
openssl x509 -req -in localhost.csr -CA ca.pem -CAkey ca.key -CAcreateserial -out localhost.crt -days 364 -sha256 -extfile localhost.cnf -extensions v3_req
# Generate CA, server, and encrypted client certificates.
go run ../testsetup/generate_certs.go
# Copy certificates to server directory and set permissions.
cp ca.pem $POSTGRESQL_DATA_DIR/root.crt
@ -100,11 +91,6 @@ cp localhost.key $POSTGRESQL_DATA_DIR/server.key
chmod 600 $POSTGRESQL_DATA_DIR/server.key
cp localhost.crt $POSTGRESQL_DATA_DIR/server.crt
# Generate the certificate for client authentication.
openssl genrsa -des3 -out pgx_sslcert.key -passout pass:certpw 2048
openssl req -new -config pgx_sslcert.cnf -key pgx_sslcert.key -passin pass:certpw -out pgx_sslcert.csr
openssl x509 -req -in pgx_sslcert.csr -CA ca.pem -CAkey ca.key -CAcreateserial -out pgx_sslcert.crt -days 363 -sha256 -extfile pgx_sslcert.cnf -extensions v3_req
cd ..
```


@ -120,6 +120,7 @@ pgerrcode contains constants for the PostgreSQL error codes.
* [github.com/jackc/pgx-gofrs-uuid](https://github.com/jackc/pgx-gofrs-uuid)
* [github.com/jackc/pgx-shopspring-decimal](https://github.com/jackc/pgx-shopspring-decimal)
* [github.com/twpayne/pgx-geos](https://github.com/twpayne/pgx-geos) ([PostGIS](https://postgis.net/) and [GEOS](https://libgeos.org/) via [go-geos](https://github.com/twpayne/go-geos))
* [github.com/vgarvardt/pgx-google-uuid](https://github.com/vgarvardt/pgx-google-uuid)


@ -10,8 +10,8 @@ import (
// QueuedQuery is a query that has been queued for execution via a Batch.
type QueuedQuery struct {
query string
arguments []any
SQL string
Arguments []any
fn batchItemFunc
sd *pgconn.StatementDescription
}
@ -57,7 +57,7 @@ func (qq *QueuedQuery) Exec(fn func(ct pgconn.CommandTag) error) {
// Batch queries are a way of bundling multiple queries together to avoid
// unnecessary network round trips. A Batch must only be sent once.
type Batch struct {
queuedQueries []*QueuedQuery
QueuedQueries []*QueuedQuery
}
// Queue queues a query to batch b. query can be an SQL query or the name of a prepared statement.
@ -65,16 +65,16 @@ type Batch struct {
// connection's DefaultQueryExecMode.
func (b *Batch) Queue(query string, arguments ...any) *QueuedQuery {
qq := &QueuedQuery{
query: query,
arguments: arguments,
SQL: query,
Arguments: arguments,
}
b.queuedQueries = append(b.queuedQueries, qq)
b.QueuedQueries = append(b.QueuedQueries, qq)
return qq
}
// Len returns number of queries that have been queued so far.
func (b *Batch) Len() int {
return len(b.queuedQueries)
return len(b.QueuedQueries)
}
type BatchResults interface {
@ -227,9 +227,9 @@ func (br *batchResults) Close() error {
}
// Read and run fn for all remaining items
for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.queuedQueries) {
if br.b.queuedQueries[br.qqIdx].fn != nil {
err := br.b.queuedQueries[br.qqIdx].fn(br)
for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.QueuedQueries) {
if br.b.QueuedQueries[br.qqIdx].fn != nil {
err := br.b.QueuedQueries[br.qqIdx].fn(br)
if err != nil {
br.err = err
}
@ -253,10 +253,10 @@ func (br *batchResults) earlyError() error {
}
func (br *batchResults) nextQueryAndArgs() (query string, args []any, ok bool) {
if br.b != nil && br.qqIdx < len(br.b.queuedQueries) {
bi := br.b.queuedQueries[br.qqIdx]
query = bi.query
args = bi.arguments
if br.b != nil && br.qqIdx < len(br.b.QueuedQueries) {
bi := br.b.QueuedQueries[br.qqIdx]
query = bi.SQL
args = bi.Arguments
ok = true
br.qqIdx++
}
@ -396,9 +396,9 @@ func (br *pipelineBatchResults) Close() error {
}
// Read and run fn for all remaining items
for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.queuedQueries) {
if br.b.queuedQueries[br.qqIdx].fn != nil {
err := br.b.queuedQueries[br.qqIdx].fn(br)
for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.QueuedQueries) {
if br.b.QueuedQueries[br.qqIdx].fn != nil {
err := br.b.QueuedQueries[br.qqIdx].fn(br)
if err != nil {
br.err = err
}
@ -422,10 +422,10 @@ func (br *pipelineBatchResults) earlyError() error {
}
func (br *pipelineBatchResults) nextQueryAndArgs() (query string, args []any, ok bool) {
if br.b != nil && br.qqIdx < len(br.b.queuedQueries) {
bi := br.b.queuedQueries[br.qqIdx]
query = bi.query
args = bi.arguments
if br.b != nil && br.qqIdx < len(br.b.QueuedQueries) {
bi := br.b.QueuedQueries[br.qqIdx]
query = bi.SQL
args = bi.Arguments
ok = true
br.qqIdx++
}
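
With queuedQueries renamed to the exported QueuedQueries (and SQL / Arguments exported on QueuedQuery), callers can now inspect a batch before sending it. A sketch against a hypothetical widgets table:

```go
package main

import (
	"context"
	"fmt"
	"os"

	"github.com/jackc/pgx/v5"
)

func main() {
	ctx := context.Background()
	conn, err := pgx.Connect(ctx, os.Getenv("DATABASE_URL"))
	if err != nil {
		panic(err)
	}
	defer conn.Close(ctx)

	batch := &pgx.Batch{}
	batch.Queue("insert into widgets(name) values($1)", "foo") // table is hypothetical
	batch.Queue("select count(*) from widgets")

	// New in this change: queued queries are visible to the caller.
	for _, qq := range batch.QueuedQueries {
		fmt.Println("queued:", qq.SQL, qq.Arguments)
	}

	br := conn.SendBatch(ctx, batch)
	defer br.Close()

	if _, err := br.Exec(); err != nil { // result of the insert
		panic(err)
	}
	var n int64
	if err := br.QueryRow().Scan(&n); err != nil { // result of the select
		panic(err)
	}
	fmt.Println("widgets:", n)
}
```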


@ -513,6 +513,7 @@ optionLoop:
if err != nil {
return pgconn.CommandTag{}, err
}
c.descriptionCache.Put(sd)
}
return c.execParams(ctx, sd, arguments)
@ -902,10 +903,10 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) {
return &batchResults{ctx: ctx, conn: c, err: err}
}
for _, bi := range b.queuedQueries {
for _, bi := range b.QueuedQueries {
var queryRewriter QueryRewriter
sql := bi.query
arguments := bi.arguments
sql := bi.SQL
arguments := bi.Arguments
optionLoop:
for len(arguments) > 0 {
@ -927,8 +928,8 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) {
}
}
bi.query = sql
bi.arguments = arguments
bi.SQL = sql
bi.Arguments = arguments
}
// TODO: changing mode per batch? Update Batch.Queue function comment when implemented
@ -938,8 +939,8 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) {
}
// All other modes use extended protocol and thus can use prepared statements.
for _, bi := range b.queuedQueries {
if sd, ok := c.preparedStatements[bi.query]; ok {
for _, bi := range b.QueuedQueries {
if sd, ok := c.preparedStatements[bi.SQL]; ok {
bi.sd = sd
}
}
@ -960,11 +961,11 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) {
func (c *Conn) sendBatchQueryExecModeSimpleProtocol(ctx context.Context, b *Batch) *batchResults {
var sb strings.Builder
for i, bi := range b.queuedQueries {
for i, bi := range b.QueuedQueries {
if i > 0 {
sb.WriteByte(';')
}
sql, err := c.sanitizeForSimpleQuery(bi.query, bi.arguments...)
sql, err := c.sanitizeForSimpleQuery(bi.SQL, bi.Arguments...)
if err != nil {
return &batchResults{ctx: ctx, conn: c, err: err}
}
@ -983,21 +984,21 @@ func (c *Conn) sendBatchQueryExecModeSimpleProtocol(ctx context.Context, b *Batc
func (c *Conn) sendBatchQueryExecModeExec(ctx context.Context, b *Batch) *batchResults {
batch := &pgconn.Batch{}
for _, bi := range b.queuedQueries {
for _, bi := range b.QueuedQueries {
sd := bi.sd
if sd != nil {
err := c.eqb.Build(c.typeMap, sd, bi.arguments)
err := c.eqb.Build(c.typeMap, sd, bi.Arguments)
if err != nil {
return &batchResults{ctx: ctx, conn: c, err: err}
}
batch.ExecPrepared(sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats)
} else {
err := c.eqb.Build(c.typeMap, nil, bi.arguments)
err := c.eqb.Build(c.typeMap, nil, bi.Arguments)
if err != nil {
return &batchResults{ctx: ctx, conn: c, err: err}
}
batch.ExecParams(bi.query, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats)
batch.ExecParams(bi.SQL, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats)
}
}
@ -1022,18 +1023,18 @@ func (c *Conn) sendBatchQueryExecModeCacheStatement(ctx context.Context, b *Batc
distinctNewQueries := []*pgconn.StatementDescription{}
distinctNewQueriesIdxMap := make(map[string]int)
for _, bi := range b.queuedQueries {
for _, bi := range b.QueuedQueries {
if bi.sd == nil {
sd := c.statementCache.Get(bi.query)
sd := c.statementCache.Get(bi.SQL)
if sd != nil {
bi.sd = sd
} else {
if idx, present := distinctNewQueriesIdxMap[bi.query]; present {
if idx, present := distinctNewQueriesIdxMap[bi.SQL]; present {
bi.sd = distinctNewQueries[idx]
} else {
sd = &pgconn.StatementDescription{
Name: stmtcache.StatementName(bi.query),
SQL: bi.query,
Name: stmtcache.StatementName(bi.SQL),
SQL: bi.SQL,
}
distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries)
distinctNewQueries = append(distinctNewQueries, sd)
@ -1054,17 +1055,17 @@ func (c *Conn) sendBatchQueryExecModeCacheDescribe(ctx context.Context, b *Batch
distinctNewQueries := []*pgconn.StatementDescription{}
distinctNewQueriesIdxMap := make(map[string]int)
for _, bi := range b.queuedQueries {
for _, bi := range b.QueuedQueries {
if bi.sd == nil {
sd := c.descriptionCache.Get(bi.query)
sd := c.descriptionCache.Get(bi.SQL)
if sd != nil {
bi.sd = sd
} else {
if idx, present := distinctNewQueriesIdxMap[bi.query]; present {
if idx, present := distinctNewQueriesIdxMap[bi.SQL]; present {
bi.sd = distinctNewQueries[idx]
} else {
sd = &pgconn.StatementDescription{
SQL: bi.query,
SQL: bi.SQL,
}
distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries)
distinctNewQueries = append(distinctNewQueries, sd)
@ -1081,13 +1082,13 @@ func (c *Conn) sendBatchQueryExecModeDescribeExec(ctx context.Context, b *Batch)
distinctNewQueries := []*pgconn.StatementDescription{}
distinctNewQueriesIdxMap := make(map[string]int)
for _, bi := range b.queuedQueries {
for _, bi := range b.QueuedQueries {
if bi.sd == nil {
if idx, present := distinctNewQueriesIdxMap[bi.query]; present {
if idx, present := distinctNewQueriesIdxMap[bi.SQL]; present {
bi.sd = distinctNewQueries[idx]
} else {
sd := &pgconn.StatementDescription{
SQL: bi.query,
SQL: bi.SQL,
}
distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries)
distinctNewQueries = append(distinctNewQueries, sd)
@ -1153,11 +1154,11 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d
}
// Queue the queries.
for _, bi := range b.queuedQueries {
err := c.eqb.Build(c.typeMap, bi.sd, bi.arguments)
for _, bi := range b.QueuedQueries {
err := c.eqb.Build(c.typeMap, bi.sd, bi.Arguments)
if err != nil {
// we wrap the error so we the user can understand which query failed inside the batch
err = fmt.Errorf("error building query %s: %w", bi.query, err)
err = fmt.Errorf("error building query %s: %w", bi.SQL, err)
return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
}
@ -1202,7 +1203,15 @@ func (c *Conn) sanitizeForSimpleQuery(sql string, args ...any) (string, error) {
return sanitize.SanitizeSQL(sql, valueArgs...)
}
// LoadType inspects the database for typeName and produces a pgtype.Type suitable for registration.
// LoadType inspects the database for typeName and produces a pgtype.Type suitable for registration. typeName must be
// the name of a type where the underlying type(s) is already understood by pgx. It is for derived types. In particular,
// typeName must be one of the following:
// - An array type name of a type that is already registered. e.g. "_foo" when "foo" is registered.
// - A composite type name where all field types are already registered.
// - A domain type name where the base type is already registered.
// - An enum type name.
// - A range type name where the element type is already registered.
// - A multirange type name where the element type is already registered.
func (c *Conn) LoadType(ctx context.Context, typeName string) (*pgtype.Type, error) {
var oid uint32
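
The expanded LoadType comment above spells out which derived types it can resolve; a registration sketch for a hypothetical enum type and its array type:

```go
package example

import (
	"context"

	"github.com/jackc/pgx/v5"
)

// registerColorType registers a hypothetical enum "color" and its array type "_color"
// with the connection's type map, in the order LoadType requires (base type first).
func registerColorType(ctx context.Context, conn *pgx.Conn) error {
	for _, name := range []string{"color", "_color"} {
		t, err := conn.LoadType(ctx, name)
		if err != nil {
			return err
		}
		conn.TypeMap().RegisterType(t)
	}
	return nil
}
```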
@ -1345,17 +1354,17 @@ order by attnum`,
}
func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error {
if c.pgConn.TxStatus() != 'I' {
if txStatus := c.pgConn.TxStatus(); txStatus != 'I' && txStatus != 'T' {
return nil
}
if c.descriptionCache != nil {
c.descriptionCache.HandleInvalidated()
c.descriptionCache.RemoveInvalidated()
}
var invalidatedStatements []*pgconn.StatementDescription
if c.statementCache != nil {
invalidatedStatements = c.statementCache.HandleInvalidated()
invalidatedStatements = c.statementCache.GetInvalidated()
}
if len(invalidatedStatements) == 0 {
@ -1367,7 +1376,6 @@ func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error
for _, sd := range invalidatedStatements {
pipeline.SendDeallocate(sd.Name)
delete(c.preparedStatements, sd.Name)
}
err := pipeline.Sync()
@ -1380,5 +1388,10 @@ func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error
return fmt.Errorf("failed to deallocate cached statement(s): %w", err)
}
c.statementCache.RemoveInvalidated()
for _, sd := range invalidatedStatements {
delete(c.preparedStatements, sd.Name)
}
return nil
}


@ -187,7 +187,7 @@ implemented on top of pgconn. The Conn.PgConn() method can be used to access thi
PgBouncer
By default pgx automatically uses prepared statements. Prepared statements are incompaptible with PgBouncer. This can be
By default pgx automatically uses prepared statements. Prepared statements are incompatible with PgBouncer. This can be
disabled by setting a different QueryExecMode in ConnConfig.DefaultQueryExecMode.
*/
package pgx


@ -63,6 +63,10 @@ func (q *Query) Sanitize(args ...any) (string, error) {
return "", fmt.Errorf("invalid arg type: %T", arg)
}
argUse[argIdx] = true
// Prevent SQL injection via Line Comment Creation
// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
str = "(" + str + ")"
default:
return "", fmt.Errorf("invalid Part type: %T", part)
}


@ -81,12 +81,16 @@ func (c *LRUCache) InvalidateAll() {
c.l = list.New()
}
// HandleInvalidated returns a slice of all statement descriptions invalidated since the last call to HandleInvalidated.
// Typically, the caller will then deallocate them.
func (c *LRUCache) HandleInvalidated() []*pgconn.StatementDescription {
invalidStmts := c.invalidStmts
// GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated.
func (c *LRUCache) GetInvalidated() []*pgconn.StatementDescription {
return c.invalidStmts
}
// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a
// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were
// never seen by the call to GetInvalidated.
func (c *LRUCache) RemoveInvalidated() {
c.invalidStmts = nil
return invalidStmts
}
// Len returns the number of cached prepared statement descriptions.


@ -29,8 +29,13 @@ type Cache interface {
// InvalidateAll invalidates all statement descriptions.
InvalidateAll()
// HandleInvalidated returns a slice of all statement descriptions invalidated since the last call to HandleInvalidated.
HandleInvalidated() []*pgconn.StatementDescription
// GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated.
GetInvalidated() []*pgconn.StatementDescription
// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a
// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were
// never seen by the call to GetInvalidated.
RemoveInvalidated()
// Len returns the number of cached prepared statement descriptions.
Len() int


@ -54,10 +54,16 @@ func (c *UnlimitedCache) InvalidateAll() {
c.m = make(map[string]*pgconn.StatementDescription)
}
func (c *UnlimitedCache) HandleInvalidated() []*pgconn.StatementDescription {
invalidStmts := c.invalidStmts
// GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated.
func (c *UnlimitedCache) GetInvalidated() []*pgconn.StatementDescription {
return c.invalidStmts
}
// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a
// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were
// never seen by the call to GetInvalidated.
func (c *UnlimitedCache) RemoveInvalidated() {
c.invalidStmts = nil
return invalidStmts
}
// Len returns the number of cached prepared statement descriptions.


@ -6,6 +6,11 @@ import (
"io"
)
// The PostgreSQL wire protocol has a limit of 1 GB - 1 per message. See definition of
// PQ_LARGE_MESSAGE_LIMIT in the PostgreSQL source code. To allow for the other data
// in the message, maxLargeObjectMessageLength should be no larger than 1 GB - 1 KB.
var maxLargeObjectMessageLength = 1024*1024*1024 - 1024
// LargeObjects is a structure used to access the large objects API. It is only valid within the transaction where it
// was created.
//
@ -68,32 +73,64 @@ type LargeObject struct {
// Write writes p to the large object and returns the number of bytes written and an error if not all of p was written.
func (o *LargeObject) Write(p []byte) (int, error) {
var n int
err := o.tx.QueryRow(o.ctx, "select lowrite($1, $2)", o.fd, p).Scan(&n)
if err != nil {
return n, err
nTotal := 0
for {
expected := len(p) - nTotal
if expected == 0 {
break
} else if expected > maxLargeObjectMessageLength {
expected = maxLargeObjectMessageLength
}
var n int
err := o.tx.QueryRow(o.ctx, "select lowrite($1, $2)", o.fd, p[nTotal:nTotal+expected]).Scan(&n)
if err != nil {
return nTotal, err
}
if n < 0 {
return nTotal, errors.New("failed to write to large object")
}
nTotal += n
if n < expected {
return nTotal, errors.New("short write to large object")
} else if n > expected {
return nTotal, errors.New("invalid write to large object")
}
}
if n < 0 {
return 0, errors.New("failed to write to large object")
}
return n, nil
return nTotal, nil
}
// Read reads up to len(p) bytes into p returning the number of bytes read.
func (o *LargeObject) Read(p []byte) (int, error) {
var res []byte
err := o.tx.QueryRow(o.ctx, "select loread($1, $2)", o.fd, len(p)).Scan(&res)
copy(p, res)
if err != nil {
return len(res), err
nTotal := 0
for {
expected := len(p) - nTotal
if expected == 0 {
break
} else if expected > maxLargeObjectMessageLength {
expected = maxLargeObjectMessageLength
}
var res []byte
err := o.tx.QueryRow(o.ctx, "select loread($1, $2)", o.fd, expected).Scan(&res)
copy(p[nTotal:], res)
nTotal += len(res)
if err != nil {
return nTotal, err
}
if len(res) < expected {
return nTotal, io.EOF
} else if len(res) > expected {
return nTotal, errors.New("invalid read of large object")
}
}
if len(res) < len(p) {
err = io.EOF
}
return len(res), err
return nTotal, nil
}
// Seek moves the current location pointer to the new location specified by offset.
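
With the chunking above, Write and Read transparently split transfers that exceed maxLargeObjectMessageLength into multiple lowrite/loread round trips. A usage sketch (the helper name and data are illustrative):

```go
package example

import (
	"context"

	"github.com/jackc/pgx/v5"
)

// writeLargeObject stores data in a new large object and returns its OID.
// Thanks to the chunking above, data may exceed the 1 GB wire-protocol message limit.
func writeLargeObject(ctx context.Context, conn *pgx.Conn, data []byte) (uint32, error) {
	tx, err := conn.Begin(ctx) // large objects must be used inside a transaction
	if err != nil {
		return 0, err
	}
	defer tx.Rollback(ctx)

	los := tx.LargeObjects()
	oid, err := los.Create(ctx, 0) // 0 lets the server pick the OID
	if err != nil {
		return 0, err
	}
	obj, err := los.Open(ctx, oid, pgx.LargeObjectModeWrite)
	if err != nil {
		return 0, err
	}
	if _, err := obj.Write(data); err != nil {
		return 0, err
	}
	if err := obj.Close(); err != nil {
		return 0, err
	}
	return oid, tx.Commit(ctx)
}
```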


@ -14,6 +14,9 @@ import (
//
// conn.Query(ctx, "select * from widgets where foo = @foo and bar = @bar", pgx.NamedArgs{"foo": 1, "bar": 2})
// conn.Query(ctx, "select * from widgets where foo = $1 and bar = $2", 1, 2)
//
// Named placeholders are case sensitive and must start with a letter or underscore. Subsequent characters can be
// letters, numbers, or underscores.
type NamedArgs map[string]any
// RewriteQuery implements the QueryRewriter interface.
@ -80,7 +83,7 @@ func rawState(l *sqlLexer) stateFn {
return doubleQuoteState
case '@':
nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:])
if isLetter(nextRune) {
if isLetter(nextRune) || nextRune == '_' {
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos-width])
}
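
The isLetter check above now also accepts '_', which is what allows named placeholders to start with an underscore. A small sketch (table and columns invented):

```go
package example

import (
	"context"

	"github.com/jackc/pgx/v5"
)

// queryWidgets shows a named placeholder that begins with an underscore.
func queryWidgets(ctx context.Context, conn *pgx.Conn) (pgx.Rows, error) {
	return conn.Query(ctx,
		"select * from widgets where foo = @_foo and bar = @bar",
		pgx.NamedArgs{"_foo": 1, "bar": 2},
	)
}
```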


@ -60,6 +60,11 @@ type Config struct {
// OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received.
OnNotification NotificationHandler
// OnPgError is a callback function called when a Postgres error is received by the server. The default handler will close
// the connection on any FATAL errors. If you override this handler you should call the previously set handler or ensure
// that you close on FATAL errors by returning false.
OnPgError PgErrorHandler
createdByParseConfig bool // Used to enforce created by ParseConfig rule.
}
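
A sketch of wiring up the new OnPgError hook: log every server error and keep the default behavior of closing the connection on FATAL (the logging itself is illustrative):

```go
package example

import (
	"context"
	"log"
	"os"
	"strings"

	"github.com/jackc/pgx/v5/pgconn"
)

// connectWithErrorHook parses a connection string and installs an OnPgError handler.
// Returning false closes the connection, so FATAL errors still shut it down.
func connectWithErrorHook(ctx context.Context) (*pgconn.PgConn, error) {
	cfg, err := pgconn.ParseConfig(os.Getenv("DATABASE_URL"))
	if err != nil {
		return nil, err
	}
	cfg.OnPgError = func(_ *pgconn.PgConn, pgErr *pgconn.PgError) bool {
		log.Printf("postgres error %s: %s", pgErr.Code, pgErr.Message)
		return !strings.EqualFold(pgErr.Severity, "FATAL") // false on FATAL => close
	}
	return pgconn.ConnectConfig(ctx, cfg)
}
```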
@ -232,12 +237,12 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") {
connStringSettings, err = parseURLSettings(connString)
if err != nil {
return nil, &parseConfigError{connString: connString, msg: "failed to parse as URL", err: err}
return nil, &ParseConfigError{ConnString: connString, msg: "failed to parse as URL", err: err}
}
} else {
connStringSettings, err = parseDSNSettings(connString)
if err != nil {
return nil, &parseConfigError{connString: connString, msg: "failed to parse as DSN", err: err}
return nil, &ParseConfigError{ConnString: connString, msg: "failed to parse as DSN", err: err}
}
}
}
@ -246,7 +251,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
if service, present := settings["service"]; present {
serviceSettings, err := parseServiceSettings(settings["servicefile"], service)
if err != nil {
return nil, &parseConfigError{connString: connString, msg: "failed to read service", err: err}
return nil, &ParseConfigError{ConnString: connString, msg: "failed to read service", err: err}
}
settings = mergeSettings(defaultSettings, envSettings, serviceSettings, connStringSettings)
@ -261,12 +266,19 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
BuildFrontend: func(r io.Reader, w io.Writer) *pgproto3.Frontend {
return pgproto3.NewFrontend(r, w)
},
OnPgError: func(_ *PgConn, pgErr *PgError) bool {
// we want to automatically close any fatal errors
if strings.EqualFold(pgErr.Severity, "FATAL") {
return false
}
return true
},
}
if connectTimeoutSetting, present := settings["connect_timeout"]; present {
connectTimeout, err := parseConnectTimeoutSetting(connectTimeoutSetting)
if err != nil {
return nil, &parseConfigError{connString: connString, msg: "invalid connect_timeout", err: err}
return nil, &ParseConfigError{ConnString: connString, msg: "invalid connect_timeout", err: err}
}
config.ConnectTimeout = connectTimeout
config.DialFunc = makeConnectTimeoutDialFunc(connectTimeout)
@ -328,7 +340,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
port, err := parsePort(portStr)
if err != nil {
return nil, &parseConfigError{connString: connString, msg: "invalid port", err: err}
return nil, &ParseConfigError{ConnString: connString, msg: "invalid port", err: err}
}
var tlsConfigs []*tls.Config
@ -340,7 +352,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
var err error
tlsConfigs, err = configTLS(settings, host, options)
if err != nil {
return nil, &parseConfigError{connString: connString, msg: "failed to configure TLS", err: err}
return nil, &ParseConfigError{ConnString: connString, msg: "failed to configure TLS", err: err}
}
}
@ -384,7 +396,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
case "any":
// do nothing
default:
return nil, &parseConfigError{connString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)}
return nil, &ParseConfigError{ConnString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)}
}
return config, nil
@ -709,6 +721,9 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P
return nil, fmt.Errorf("unable to read sslkey: %w", err)
}
block, _ := pem.Decode(buf)
if block == nil {
return nil, errors.New("failed to decode sslkey")
}
var pemKey []byte
var decryptedKey []byte
var decryptedError error


@ -57,22 +57,23 @@ func (pe *PgError) SQLState() string {
return pe.Code
}
type connectError struct {
config *Config
// ConnectError is the error returned when a connection attempt fails.
type ConnectError struct {
Config *Config // The configuration that was used in the connection attempt.
msg string
err error
}
func (e *connectError) Error() string {
func (e *ConnectError) Error() string {
sb := &strings.Builder{}
fmt.Fprintf(sb, "failed to connect to `host=%s user=%s database=%s`: %s", e.config.Host, e.config.User, e.config.Database, e.msg)
fmt.Fprintf(sb, "failed to connect to `host=%s user=%s database=%s`: %s", e.Config.Host, e.Config.User, e.Config.Database, e.msg)
if e.err != nil {
fmt.Fprintf(sb, " (%s)", e.err.Error())
}
return sb.String()
}
func (e *connectError) Unwrap() error {
func (e *ConnectError) Unwrap() error {
return e.err
}
@ -88,33 +89,38 @@ func (e *connLockError) Error() string {
return e.status
}
type parseConfigError struct {
connString string
// ParseConfigError is the error returned when a connection string cannot be parsed.
type ParseConfigError struct {
ConnString string // The connection string that could not be parsed.
msg string
err error
}
func (e *parseConfigError) Error() string {
connString := redactPW(e.connString)
func (e *ParseConfigError) Error() string {
// Now that ParseConfigError is public and ConnString is available to the developer, perhaps it would be better only
// return a static string. That would ensure that the error message cannot leak a password. The ConnString field would
// allow access to the original string if desired and Unwrap would allow access to the underlying error.
connString := redactPW(e.ConnString)
if e.err == nil {
return fmt.Sprintf("cannot parse `%s`: %s", connString, e.msg)
}
return fmt.Sprintf("cannot parse `%s`: %s (%s)", connString, e.msg, e.err.Error())
}
func (e *parseConfigError) Unwrap() error {
func (e *ParseConfigError) Unwrap() error {
return e.err
}
func normalizeTimeoutError(ctx context.Context, err error) error {
if err, ok := err.(net.Error); ok && err.Timeout() {
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
if ctx.Err() == context.Canceled {
// Since the timeout was caused by a context cancellation, the actual error is context.Canceled not the timeout error.
return context.Canceled
} else if ctx.Err() == context.DeadlineExceeded {
return &errTimeout{err: ctx.Err()}
} else {
return &errTimeout{err: err}
return &errTimeout{err: netErr}
}
}
return err
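
Now that these error types are exported, callers can branch on them with errors.As; a sketch:

```go
package example

import (
	"context"
	"errors"
	"fmt"

	"github.com/jackc/pgx/v5/pgconn"
)

// describeConnError distinguishes the newly exported pgconn error types.
func describeConnError(ctx context.Context, connString string) {
	_, err := pgconn.Connect(ctx, connString)
	if err == nil {
		return
	}
	var parseErr *pgconn.ParseConfigError
	var connectErr *pgconn.ConnectError
	switch {
	case errors.As(err, &parseErr):
		fmt.Println("bad connection string:", parseErr.ConnString)
	case errors.As(err, &connectErr):
		fmt.Println("could not connect to host:", connectErr.Config.Host)
	default:
		fmt.Println("other error:", err)
	}
}
```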


@ -52,6 +52,12 @@ type LookupFunc func(ctx context.Context, host string) (addrs []string, err erro
// BuildFrontendFunc is a function that can be used to create Frontend implementation for connection.
type BuildFrontendFunc func(r io.Reader, w io.Writer) *pgproto3.Frontend
// PgErrorHandler is a function that handles errors returned from Postgres. This function must return true to keep
// the connection open. Returning false will cause the connection to be closed immediately. You should return
// false on any FATAL-severity errors. This will not receive network errors. The *PgConn is provided so the handler is
// aware of the origin of the error, but it must not invoke any query method.
type PgErrorHandler func(*PgConn, *PgError) bool
// NoticeHandler is a function that can handle notices received from the PostgreSQL server. Notices can be received at
// any time, usually during handling of a query response. The *PgConn is provided so the handler is aware of the origin
// of the notice, but it must not invoke any query method. Be aware that this is distinct from LISTEN/NOTIFY
@ -146,11 +152,11 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er
ctx := octx
fallbackConfigs, err = expandWithIPs(ctx, config.LookupFunc, fallbackConfigs)
if err != nil {
return nil, &connectError{config: config, msg: "hostname resolving error", err: err}
return nil, &ConnectError{Config: config, msg: "hostname resolving error", err: err}
}
if len(fallbackConfigs) == 0 {
return nil, &connectError{config: config, msg: "hostname resolving error", err: errors.New("ip addr wasn't found")}
return nil, &ConnectError{Config: config, msg: "hostname resolving error", err: errors.New("ip addr wasn't found")}
}
foundBestServer := false
@ -172,7 +178,7 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er
foundBestServer = true
break
} else if pgerr, ok := err.(*PgError); ok {
err = &connectError{config: config, msg: "server error", err: pgerr}
err = &ConnectError{Config: config, msg: "server error", err: pgerr}
const ERRCODE_INVALID_PASSWORD = "28P01" // wrong password
const ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION = "28000" // wrong password or bad pg_hba.conf settings
const ERRCODE_INVALID_CATALOG_NAME = "3D000" // db does not exist
@ -183,7 +189,7 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er
pgerr.Code == ERRCODE_INSUFFICIENT_PRIVILEGE {
break
}
} else if cerr, ok := err.(*connectError); ok {
} else if cerr, ok := err.(*ConnectError); ok {
if _, ok := cerr.err.(*NotPreferredError); ok {
fallbackConfig = fc
}
@ -193,7 +199,7 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er
if !foundBestServer && fallbackConfig != nil {
pgConn, err = connect(ctx, config, fallbackConfig, true)
if pgerr, ok := err.(*PgError); ok {
err = &connectError{config: config, msg: "server error", err: pgerr}
err = &ConnectError{Config: config, msg: "server error", err: pgerr}
}
}
@ -205,7 +211,7 @@ func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err er
err := config.AfterConnect(ctx, pgConn)
if err != nil {
pgConn.conn.Close()
return nil, &connectError{config: config, msg: "AfterConnect error", err: err}
return nil, &ConnectError{Config: config, msg: "AfterConnect error", err: err}
}
}
@ -277,7 +283,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port)
netConn, err := config.DialFunc(ctx, network, address)
if err != nil {
return nil, &connectError{config: config, msg: "dial error", err: normalizeTimeoutError(ctx, err)}
return nil, &ConnectError{Config: config, msg: "dial error", err: normalizeTimeoutError(ctx, err)}
}
pgConn.conn = netConn
@ -289,7 +295,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS.
if err != nil {
netConn.Close()
return nil, &connectError{config: config, msg: "tls error", err: err}
return nil, &ConnectError{Config: config, msg: "tls error", err: normalizeTimeoutError(ctx, err)}
}
pgConn.conn = nbTLSConn
@ -330,7 +336,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
pgConn.frontend.Send(&startupMsg)
if err := pgConn.flushWithPotentialWriteReadDeadlock(); err != nil {
pgConn.conn.Close()
return nil, &connectError{config: config, msg: "failed to write startup message", err: normalizeTimeoutError(ctx, err)}
return nil, &ConnectError{Config: config, msg: "failed to write startup message", err: normalizeTimeoutError(ctx, err)}
}
for {
@ -340,7 +346,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
if err, ok := err.(*PgError); ok {
return nil, err
}
return nil, &connectError{config: config, msg: "failed to receive message", err: normalizeTimeoutError(ctx, err)}
return nil, &ConnectError{Config: config, msg: "failed to receive message", err: normalizeTimeoutError(ctx, err)}
}
switch msg := msg.(type) {
@ -353,26 +359,26 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
err = pgConn.txPasswordMessage(pgConn.config.Password)
if err != nil {
pgConn.conn.Close()
return nil, &connectError{config: config, msg: "failed to write password message", err: err}
return nil, &ConnectError{Config: config, msg: "failed to write password message", err: err}
}
case *pgproto3.AuthenticationMD5Password:
digestedPassword := "md5" + hexMD5(hexMD5(pgConn.config.Password+pgConn.config.User)+string(msg.Salt[:]))
err = pgConn.txPasswordMessage(digestedPassword)
if err != nil {
pgConn.conn.Close()
return nil, &connectError{config: config, msg: "failed to write password message", err: err}
return nil, &ConnectError{Config: config, msg: "failed to write password message", err: err}
}
case *pgproto3.AuthenticationSASL:
err = pgConn.scramAuth(msg.AuthMechanisms)
if err != nil {
pgConn.conn.Close()
return nil, &connectError{config: config, msg: "failed SASL auth", err: err}
return nil, &ConnectError{Config: config, msg: "failed SASL auth", err: err}
}
case *pgproto3.AuthenticationGSS:
err = pgConn.gssAuth()
if err != nil {
pgConn.conn.Close()
return nil, &connectError{config: config, msg: "failed GSS auth", err: err}
return nil, &ConnectError{Config: config, msg: "failed GSS auth", err: err}
}
case *pgproto3.ReadyForQuery:
pgConn.status = connStatusIdle
@ -390,7 +396,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
return pgConn, nil
}
pgConn.conn.Close()
return nil, &connectError{config: config, msg: "ValidateConnect failed", err: err}
return nil, &ConnectError{Config: config, msg: "ValidateConnect failed", err: err}
}
}
return pgConn, nil
@ -401,7 +407,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
return nil, ErrorResponseToPgError(msg)
default:
pgConn.conn.Close()
return nil, &connectError{config: config, msg: "received unexpected message", err: err}
return nil, &ConnectError{Config: config, msg: "received unexpected message", err: err}
}
}
}
@ -547,11 +553,12 @@ func (pgConn *PgConn) receiveMessage() (pgproto3.BackendMessage, error) {
case *pgproto3.ParameterStatus:
pgConn.parameterStatuses[msg.Name] = msg.Value
case *pgproto3.ErrorResponse:
if msg.Severity == "FATAL" {
err := ErrorResponseToPgError(msg)
if pgConn.config.OnPgError != nil && !pgConn.config.OnPgError(pgConn, err) {
pgConn.status = connStatusClosed
pgConn.conn.Close() // Ignore error as the connection is already broken and there is already an error to return.
close(pgConn.cleanupDone)
return nil, ErrorResponseToPgError(msg)
return nil, err
}
case *pgproto3.NoticeResponse:
if pgConn.config.OnNotice != nil {
@ -1667,25 +1674,55 @@ func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) {
// Batch is a collection of queries that can be sent to the PostgreSQL server in a single round-trip.
type Batch struct {
buf []byte
err error
}
// ExecParams appends an ExecParams command to the batch. See PgConn.ExecParams for parameter descriptions.
func (batch *Batch) ExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) {
batch.buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(batch.buf)
if batch.err != nil {
return
}
batch.buf, batch.err = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(batch.buf)
if batch.err != nil {
return
}
batch.ExecPrepared("", paramValues, paramFormats, resultFormats)
}
// ExecPrepared appends an ExecPrepared e command to the batch. See PgConn.ExecPrepared for parameter descriptions.
func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) {
batch.buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(batch.buf)
batch.buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(batch.buf)
batch.buf = (&pgproto3.Execute{}).Encode(batch.buf)
if batch.err != nil {
return
}
batch.buf, batch.err = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(batch.buf)
if batch.err != nil {
return
}
batch.buf, batch.err = (&pgproto3.Describe{ObjectType: 'P'}).Encode(batch.buf)
if batch.err != nil {
return
}
batch.buf, batch.err = (&pgproto3.Execute{}).Encode(batch.buf)
if batch.err != nil {
return
}
}
// ExecBatch executes all the queries in batch in a single round-trip. Execution is implicitly transactional unless a
// transaction is already in progress or SQL contains transaction control statements. This is a simpler way of executing
// multiple queries in a single round trip than using pipeline mode.
func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultReader {
if batch.err != nil {
return &MultiResultReader{
closed: true,
err: batch.err,
}
}
if err := pgConn.lock(); err != nil {
return &MultiResultReader{
closed: true,
@ -1711,7 +1748,13 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
pgConn.contextWatcher.Watch(ctx)
}
batch.buf = (&pgproto3.Sync{}).Encode(batch.buf)
batch.buf, batch.err = (&pgproto3.Sync{}).Encode(batch.buf)
if batch.err != nil {
multiResult.closed = true
multiResult.err = batch.err
pgConn.unlock()
return multiResult
}
pgConn.enterPotentialWriteReadDeadlock()
defer pgConn.exitPotentialWriteReadDeadlock()
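
Since Encode can now fail, the Batch records the first error and ExecBatch surfaces it through the MultiResultReader, so the caller-side pattern stays the same (table invented):

```go
package example

import (
	"context"

	"github.com/jackc/pgx/v5/pgconn"
)

// execBatchSketch queues two low-level commands and executes them in one round trip.
// Any encode error recorded inside the Batch is returned by ReadAll.
func execBatchSketch(ctx context.Context, conn *pgconn.PgConn) error {
	batch := &pgconn.Batch{}
	batch.ExecParams("insert into widgets(name) values($1)",
		[][]byte{[]byte("foo")}, nil, nil, nil)
	batch.ExecParams("select count(*) from widgets", nil, nil, nil, nil)

	_, err := conn.ExecBatch(ctx, batch).ReadAll()
	return err
}
```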
@ -2046,6 +2089,13 @@ func (p *Pipeline) Flush() error {
// Sync establishes a synchronization point and flushes the queued requests.
func (p *Pipeline) Sync() error {
if p.closed {
if p.err != nil {
return p.err
}
return errors.New("pipeline closed")
}
p.conn.frontend.SendSync(&pgproto3.Sync{})
err := p.Flush()
if err != nil {
@ -2062,13 +2112,26 @@ func (p *Pipeline) Sync() error {
// *PipelineSync. If an ErrorResponse is received from the server, results will be nil and err will be a *PgError. If no
// results are available, results and err will both be nil.
func (p *Pipeline) GetResults() (results any, err error) {
if p.closed {
if p.err != nil {
return nil, p.err
}
return nil, errors.New("pipeline closed")
}
if p.expectedReadyForQueryCount == 0 {
return nil, nil
}
return p.getResults()
}
func (p *Pipeline) getResults() (results any, err error) {
for {
msg, err := p.conn.receiveMessage()
if err != nil {
p.closed = true
p.err = err
p.conn.asyncClose()
return nil, normalizeTimeoutError(p.ctx, err)
}
@ -2092,7 +2155,8 @@ func (p *Pipeline) GetResults() (results any, err error) {
case *pgproto3.ParseComplete:
peekedMsg, err := p.conn.peekMessage()
if err != nil {
return nil, err
p.conn.asyncClose()
return nil, normalizeTimeoutError(p.ctx, err)
}
if _, ok := peekedMsg.(*pgproto3.ParameterDescription); ok {
return p.getResultsPrepare()
@ -2152,6 +2216,7 @@ func (p *Pipeline) Close() error {
if p.closed {
return p.err
}
p.closed = true
if p.pendingSync {
@ -2164,7 +2229,7 @@ func (p *Pipeline) Close() error {
}
for p.expectedReadyForQueryCount > 0 {
_, err := p.GetResults()
_, err := p.getResults()
if err != nil {
p.err = err
var pgErr *PgError
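
A rough sketch of the pipeline flow these changes harden (assuming the usual StartPipeline/SendQueryParams entry points; the queries are placeholders). Once an error marks the pipeline closed, Sync, GetResults, and Close now all report that error.

package main

import (
	"context"
	"log"

	"github.com/jackc/pgx/v5/pgconn"
)

func runPipeline(ctx context.Context, conn *pgconn.PgConn) {
	pipeline := conn.StartPipeline(ctx)
	pipeline.SendQueryParams("select 1", nil, nil, nil, nil)
	pipeline.SendQueryParams("select 2", nil, nil, nil, nil)
	if err := pipeline.Sync(); err != nil {
		log.Fatal(err)
	}
	for {
		res, err := pipeline.GetResults()
		if err != nil {
			log.Fatal(err)
		}
		if res == nil {
			break // expectedReadyForQueryCount reached zero: nothing left to read
		}
		switch r := res.(type) {
		case *pgconn.ResultReader:
			// Drain one query's rows; Close returns its command tag.
			if _, err := r.Close(); err != nil {
				log.Fatal(err)
			}
		case *pgconn.PipelineSync:
			// The synchronization point established by Sync was reached.
		}
	}
	if err := pipeline.Close(); err != nil {
		log.Fatal(err)
	}
}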

View File

@ -35,11 +35,10 @@ func (dst *AuthenticationCleartextPassword) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationCleartextPassword) Encode(dst []byte) []byte {
dst = append(dst, 'R')
dst = pgio.AppendInt32(dst, 8)
func (src *AuthenticationCleartextPassword) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'R')
dst = pgio.AppendUint32(dst, AuthTypeCleartextPassword)
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -27,11 +27,10 @@ func (a *AuthenticationGSS) Decode(src []byte) error {
return nil
}
func (a *AuthenticationGSS) Encode(dst []byte) []byte {
dst = append(dst, 'R')
dst = pgio.AppendInt32(dst, 4)
func (a *AuthenticationGSS) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'R')
dst = pgio.AppendUint32(dst, AuthTypeGSS)
return dst
return finishMessage(dst, sp)
}
func (a *AuthenticationGSS) MarshalJSON() ([]byte, error) {

View File

@ -31,12 +31,11 @@ func (a *AuthenticationGSSContinue) Decode(src []byte) error {
return nil
}
func (a *AuthenticationGSSContinue) Encode(dst []byte) []byte {
dst = append(dst, 'R')
dst = pgio.AppendInt32(dst, int32(len(a.Data))+8)
func (a *AuthenticationGSSContinue) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'R')
dst = pgio.AppendUint32(dst, AuthTypeGSSCont)
dst = append(dst, a.Data...)
return dst
return finishMessage(dst, sp)
}
func (a *AuthenticationGSSContinue) MarshalJSON() ([]byte, error) {

View File

@ -38,12 +38,11 @@ func (dst *AuthenticationMD5Password) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationMD5Password) Encode(dst []byte) []byte {
dst = append(dst, 'R')
dst = pgio.AppendInt32(dst, 12)
func (src *AuthenticationMD5Password) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'R')
dst = pgio.AppendUint32(dst, AuthTypeMD5Password)
dst = append(dst, src.Salt[:]...)
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -35,11 +35,10 @@ func (dst *AuthenticationOk) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationOk) Encode(dst []byte) []byte {
dst = append(dst, 'R')
dst = pgio.AppendInt32(dst, 8)
func (src *AuthenticationOk) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'R')
dst = pgio.AppendUint32(dst, AuthTypeOk)
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -47,10 +47,8 @@ func (dst *AuthenticationSASL) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationSASL) Encode(dst []byte) []byte {
dst = append(dst, 'R')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
func (src *AuthenticationSASL) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'R')
dst = pgio.AppendUint32(dst, AuthTypeSASL)
for _, s := range src.AuthMechanisms {
@ -59,9 +57,7 @@ func (src *AuthenticationSASL) Encode(dst []byte) []byte {
}
dst = append(dst, 0)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -38,17 +38,11 @@ func (dst *AuthenticationSASLContinue) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationSASLContinue) Encode(dst []byte) []byte {
dst = append(dst, 'R')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
func (src *AuthenticationSASLContinue) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'R')
dst = pgio.AppendUint32(dst, AuthTypeSASLContinue)
dst = append(dst, src.Data...)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -38,17 +38,11 @@ func (dst *AuthenticationSASLFinal) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationSASLFinal) Encode(dst []byte) []byte {
dst = append(dst, 'R')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
func (src *AuthenticationSASLFinal) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'R')
dst = pgio.AppendUint32(dst, AuthTypeSASLFinal)
dst = append(dst, src.Data...)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Unmarshaler.

View File

@ -16,7 +16,8 @@ type Backend struct {
// before it is actually transmitted (i.e. before Flush).
tracer *tracer
wbuf []byte
wbuf []byte
encodeError error
// Frontend message flyweights
bind Bind
@ -38,6 +39,7 @@ type Backend struct {
terminate Terminate
bodyLen int
maxBodyLen int // maxBodyLen is the maximum length of a message body in octets. If a message body exceeds this length, Receive will return an error.
msgType byte
partialMsg bool
authType uint32
@ -54,11 +56,21 @@ func NewBackend(r io.Reader, w io.Writer) *Backend {
return &Backend{cr: cr, w: w}
}
// Send sends a message to the frontend (i.e. the client). The message is not guaranteed to be written until Flush is
// called.
// Send sends a message to the frontend (i.e. the client). The message is buffered until Flush is called. Any error
// encountered will be returned from Flush.
func (b *Backend) Send(msg BackendMessage) {
if b.encodeError != nil {
return
}
prevLen := len(b.wbuf)
b.wbuf = msg.Encode(b.wbuf)
newBuf, err := msg.Encode(b.wbuf)
if err != nil {
b.encodeError = err
return
}
b.wbuf = newBuf
if b.tracer != nil {
b.tracer.traceMessage('B', int32(len(b.wbuf)-prevLen), msg)
}
@ -66,6 +78,12 @@ func (b *Backend) Send(msg BackendMessage) {
// Flush writes any pending messages to the frontend (i.e. the client).
func (b *Backend) Flush() error {
if err := b.encodeError; err != nil {
b.encodeError = nil
b.wbuf = b.wbuf[:0]
return &writeError{err: err, safeToRetry: true}
}
n, err := b.w.Write(b.wbuf)
const maxLen = 1024
@ -158,6 +176,9 @@ func (b *Backend) Receive() (FrontendMessage, error) {
b.msgType = header[0]
b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
if b.maxBodyLen > 0 && b.bodyLen > b.maxBodyLen {
return nil, &ExceededMaxBodyLenErr{b.maxBodyLen, b.bodyLen}
}
b.partialMsg = true
}
@ -260,3 +281,12 @@ func (b *Backend) SetAuthType(authType uint32) error {
return nil
}
// SetMaxBodyLen sets the maximum length of a message body in octets. If a message body exceeds this length, Receive will return
// an error. This is useful for protecting against malicious clients that send large messages with the intent of
// causing memory exhaustion.
// The default value is 0.
// If maxBodyLen is 0, then no maximum is enforced.
func (b *Backend) SetMaxBodyLen(maxBodyLen int) {
b.maxBodyLen = maxBodyLen
}
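
A short sketch of how a server-side consumer might use the new cap (handleConn and the 1 MiB limit are illustrative, not from this diff):

package main

import (
	"errors"
	"log"
	"net"

	"github.com/jackc/pgx/v5/pgproto3"
)

func handleConn(conn net.Conn) {
	backend := pgproto3.NewBackend(conn, conn)
	backend.SetMaxBodyLen(1024 * 1024) // reject message bodies larger than 1 MiB

	if _, err := backend.ReceiveStartupMessage(); err != nil {
		log.Print(err)
		return
	}
	for {
		msg, err := backend.Receive()
		if err != nil {
			var tooBig *pgproto3.ExceededMaxBodyLenErr
			if errors.As(err, &tooBig) {
				log.Printf("dropping client: %v", tooBig)
			}
			return
		}
		log.Printf("received %T", msg)
	}
}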

View File

@ -29,12 +29,11 @@ func (dst *BackendKeyData) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *BackendKeyData) Encode(dst []byte) []byte {
dst = append(dst, 'K')
dst = pgio.AppendUint32(dst, 12)
func (src *BackendKeyData) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'K')
dst = pgio.AppendUint32(dst, src.ProcessID)
dst = pgio.AppendUint32(dst, src.SecretKey)
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -5,7 +5,9 @@ import (
"encoding/binary"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"math"
"github.com/jackc/pgx/v5/internal/pgio"
)
@ -108,21 +110,25 @@ func (dst *Bind) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Bind) Encode(dst []byte) []byte {
dst = append(dst, 'B')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
func (src *Bind) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'B')
dst = append(dst, src.DestinationPortal...)
dst = append(dst, 0)
dst = append(dst, src.PreparedStatement...)
dst = append(dst, 0)
if len(src.ParameterFormatCodes) > math.MaxUint16 {
return nil, errors.New("too many parameter format codes")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterFormatCodes)))
for _, fc := range src.ParameterFormatCodes {
dst = pgio.AppendInt16(dst, fc)
}
if len(src.Parameters) > math.MaxUint16 {
return nil, errors.New("too many parameters")
}
dst = pgio.AppendUint16(dst, uint16(len(src.Parameters)))
for _, p := range src.Parameters {
if p == nil {
@ -134,14 +140,15 @@ func (src *Bind) Encode(dst []byte) []byte {
dst = append(dst, p...)
}
if len(src.ResultFormatCodes) > math.MaxUint16 {
return nil, errors.New("too many result format codes")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ResultFormatCodes)))
for _, fc := range src.ResultFormatCodes {
dst = pgio.AppendInt16(dst, fc)
}
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -20,8 +20,8 @@ func (dst *BindComplete) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *BindComplete) Encode(dst []byte) []byte {
return append(dst, '2', 0, 0, 0, 4)
func (src *BindComplete) Encode(dst []byte) ([]byte, error) {
return append(dst, '2', 0, 0, 0, 4), nil
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -36,12 +36,12 @@ func (dst *CancelRequest) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 4 byte message length.
func (src *CancelRequest) Encode(dst []byte) []byte {
func (src *CancelRequest) Encode(dst []byte) ([]byte, error) {
dst = pgio.AppendInt32(dst, 16)
dst = pgio.AppendInt32(dst, cancelRequestCode)
dst = pgio.AppendUint32(dst, src.ProcessID)
dst = pgio.AppendUint32(dst, src.SecretKey)
return dst
return dst, nil
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -4,8 +4,6 @@ import (
"bytes"
"encoding/json"
"errors"
"github.com/jackc/pgx/v5/internal/pgio"
)
type Close struct {
@ -37,18 +35,12 @@ func (dst *Close) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Close) Encode(dst []byte) []byte {
dst = append(dst, 'C')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
func (src *Close) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'C')
dst = append(dst, src.ObjectType)
dst = append(dst, src.Name...)
dst = append(dst, 0)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -20,8 +20,8 @@ func (dst *CloseComplete) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CloseComplete) Encode(dst []byte) []byte {
return append(dst, '3', 0, 0, 0, 4)
func (src *CloseComplete) Encode(dst []byte) ([]byte, error) {
return append(dst, '3', 0, 0, 0, 4), nil
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -3,8 +3,6 @@ package pgproto3
import (
"bytes"
"encoding/json"
"github.com/jackc/pgx/v5/internal/pgio"
)
type CommandComplete struct {
@ -31,17 +29,11 @@ func (dst *CommandComplete) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CommandComplete) Encode(dst []byte) []byte {
dst = append(dst, 'C')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
func (src *CommandComplete) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'C')
dst = append(dst, src.CommandTag...)
dst = append(dst, 0)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -5,6 +5,7 @@ import (
"encoding/binary"
"encoding/json"
"errors"
"math"
"github.com/jackc/pgx/v5/internal/pgio"
)
@ -44,19 +45,18 @@ func (dst *CopyBothResponse) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CopyBothResponse) Encode(dst []byte) []byte {
dst = append(dst, 'W')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
func (src *CopyBothResponse) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'W')
dst = append(dst, src.OverallFormat)
if len(src.ColumnFormatCodes) > math.MaxUint16 {
return nil, errors.New("too many column format codes")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
for _, fc := range src.ColumnFormatCodes {
dst = pgio.AppendUint16(dst, fc)
}
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -3,8 +3,6 @@ package pgproto3
import (
"encoding/hex"
"encoding/json"
"github.com/jackc/pgx/v5/internal/pgio"
)
type CopyData struct {
@ -25,11 +23,10 @@ func (dst *CopyData) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CopyData) Encode(dst []byte) []byte {
dst = append(dst, 'd')
dst = pgio.AppendInt32(dst, int32(4+len(src.Data)))
func (src *CopyData) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'd')
dst = append(dst, src.Data...)
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -24,8 +24,8 @@ func (dst *CopyDone) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CopyDone) Encode(dst []byte) []byte {
return append(dst, 'c', 0, 0, 0, 4)
func (src *CopyDone) Encode(dst []byte) ([]byte, error) {
return append(dst, 'c', 0, 0, 0, 4), nil
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -3,8 +3,6 @@ package pgproto3
import (
"bytes"
"encoding/json"
"github.com/jackc/pgx/v5/internal/pgio"
)
type CopyFail struct {
@ -28,17 +26,11 @@ func (dst *CopyFail) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CopyFail) Encode(dst []byte) []byte {
dst = append(dst, 'f')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
func (src *CopyFail) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'f')
dst = append(dst, src.Message...)
dst = append(dst, 0)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -5,6 +5,7 @@ import (
"encoding/binary"
"encoding/json"
"errors"
"math"
"github.com/jackc/pgx/v5/internal/pgio"
)
@ -44,20 +45,19 @@ func (dst *CopyInResponse) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CopyInResponse) Encode(dst []byte) []byte {
dst = append(dst, 'G')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
func (src *CopyInResponse) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'G')
dst = append(dst, src.OverallFormat)
if len(src.ColumnFormatCodes) > math.MaxUint16 {
return nil, errors.New("too many column format codes")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
for _, fc := range src.ColumnFormatCodes {
dst = pgio.AppendUint16(dst, fc)
}
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -5,6 +5,7 @@ import (
"encoding/binary"
"encoding/json"
"errors"
"math"
"github.com/jackc/pgx/v5/internal/pgio"
)
@ -43,21 +44,20 @@ func (dst *CopyOutResponse) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CopyOutResponse) Encode(dst []byte) []byte {
dst = append(dst, 'H')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
func (src *CopyOutResponse) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'H')
dst = append(dst, src.OverallFormat)
if len(src.ColumnFormatCodes) > math.MaxUint16 {
return nil, errors.New("too many column format codes")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
for _, fc := range src.ColumnFormatCodes {
dst = pgio.AppendUint16(dst, fc)
}
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -4,6 +4,8 @@ import (
"encoding/binary"
"encoding/hex"
"encoding/json"
"errors"
"math"
"github.com/jackc/pgx/v5/internal/pgio"
)
@ -63,11 +65,12 @@ func (dst *DataRow) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *DataRow) Encode(dst []byte) []byte {
dst = append(dst, 'D')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
func (src *DataRow) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'D')
if len(src.Values) > math.MaxUint16 {
return nil, errors.New("too many values")
}
dst = pgio.AppendUint16(dst, uint16(len(src.Values)))
for _, v := range src.Values {
if v == nil {
@ -79,9 +82,7 @@ func (src *DataRow) Encode(dst []byte) []byte {
dst = append(dst, v...)
}
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -4,8 +4,6 @@ import (
"bytes"
"encoding/json"
"errors"
"github.com/jackc/pgx/v5/internal/pgio"
)
type Describe struct {
@ -37,18 +35,12 @@ func (dst *Describe) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Describe) Encode(dst []byte) []byte {
dst = append(dst, 'D')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
func (src *Describe) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'D')
dst = append(dst, src.ObjectType)
dst = append(dst, src.Name...)
dst = append(dst, 0)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -20,8 +20,8 @@ func (dst *EmptyQueryResponse) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *EmptyQueryResponse) Encode(dst []byte) []byte {
return append(dst, 'I', 0, 0, 0, 4)
func (src *EmptyQueryResponse) Encode(dst []byte) ([]byte, error) {
return append(dst, 'I', 0, 0, 0, 4), nil
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -2,7 +2,6 @@ package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
"strconv"
)
@ -111,119 +110,113 @@ func (dst *ErrorResponse) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *ErrorResponse) Encode(dst []byte) []byte {
return append(dst, src.marshalBinary('E')...)
func (src *ErrorResponse) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'E')
dst = src.appendFields(dst)
return finishMessage(dst, sp)
}
func (src *ErrorResponse) marshalBinary(typeByte byte) []byte {
var bigEndian BigEndianBuf
buf := &bytes.Buffer{}
buf.WriteByte(typeByte)
buf.Write(bigEndian.Uint32(0))
func (src *ErrorResponse) appendFields(dst []byte) []byte {
if src.Severity != "" {
buf.WriteByte('S')
buf.WriteString(src.Severity)
buf.WriteByte(0)
dst = append(dst, 'S')
dst = append(dst, src.Severity...)
dst = append(dst, 0)
}
if src.SeverityUnlocalized != "" {
buf.WriteByte('V')
buf.WriteString(src.SeverityUnlocalized)
buf.WriteByte(0)
dst = append(dst, 'V')
dst = append(dst, src.SeverityUnlocalized...)
dst = append(dst, 0)
}
if src.Code != "" {
buf.WriteByte('C')
buf.WriteString(src.Code)
buf.WriteByte(0)
dst = append(dst, 'C')
dst = append(dst, src.Code...)
dst = append(dst, 0)
}
if src.Message != "" {
buf.WriteByte('M')
buf.WriteString(src.Message)
buf.WriteByte(0)
dst = append(dst, 'M')
dst = append(dst, src.Message...)
dst = append(dst, 0)
}
if src.Detail != "" {
buf.WriteByte('D')
buf.WriteString(src.Detail)
buf.WriteByte(0)
dst = append(dst, 'D')
dst = append(dst, src.Detail...)
dst = append(dst, 0)
}
if src.Hint != "" {
buf.WriteByte('H')
buf.WriteString(src.Hint)
buf.WriteByte(0)
dst = append(dst, 'H')
dst = append(dst, src.Hint...)
dst = append(dst, 0)
}
if src.Position != 0 {
buf.WriteByte('P')
buf.WriteString(strconv.Itoa(int(src.Position)))
buf.WriteByte(0)
dst = append(dst, 'P')
dst = append(dst, strconv.Itoa(int(src.Position))...)
dst = append(dst, 0)
}
if src.InternalPosition != 0 {
buf.WriteByte('p')
buf.WriteString(strconv.Itoa(int(src.InternalPosition)))
buf.WriteByte(0)
dst = append(dst, 'p')
dst = append(dst, strconv.Itoa(int(src.InternalPosition))...)
dst = append(dst, 0)
}
if src.InternalQuery != "" {
buf.WriteByte('q')
buf.WriteString(src.InternalQuery)
buf.WriteByte(0)
dst = append(dst, 'q')
dst = append(dst, src.InternalQuery...)
dst = append(dst, 0)
}
if src.Where != "" {
buf.WriteByte('W')
buf.WriteString(src.Where)
buf.WriteByte(0)
dst = append(dst, 'W')
dst = append(dst, src.Where...)
dst = append(dst, 0)
}
if src.SchemaName != "" {
buf.WriteByte('s')
buf.WriteString(src.SchemaName)
buf.WriteByte(0)
dst = append(dst, 's')
dst = append(dst, src.SchemaName...)
dst = append(dst, 0)
}
if src.TableName != "" {
buf.WriteByte('t')
buf.WriteString(src.TableName)
buf.WriteByte(0)
dst = append(dst, 't')
dst = append(dst, src.TableName...)
dst = append(dst, 0)
}
if src.ColumnName != "" {
buf.WriteByte('c')
buf.WriteString(src.ColumnName)
buf.WriteByte(0)
dst = append(dst, 'c')
dst = append(dst, src.ColumnName...)
dst = append(dst, 0)
}
if src.DataTypeName != "" {
buf.WriteByte('d')
buf.WriteString(src.DataTypeName)
buf.WriteByte(0)
dst = append(dst, 'd')
dst = append(dst, src.DataTypeName...)
dst = append(dst, 0)
}
if src.ConstraintName != "" {
buf.WriteByte('n')
buf.WriteString(src.ConstraintName)
buf.WriteByte(0)
dst = append(dst, 'n')
dst = append(dst, src.ConstraintName...)
dst = append(dst, 0)
}
if src.File != "" {
buf.WriteByte('F')
buf.WriteString(src.File)
buf.WriteByte(0)
dst = append(dst, 'F')
dst = append(dst, src.File...)
dst = append(dst, 0)
}
if src.Line != 0 {
buf.WriteByte('L')
buf.WriteString(strconv.Itoa(int(src.Line)))
buf.WriteByte(0)
dst = append(dst, 'L')
dst = append(dst, strconv.Itoa(int(src.Line))...)
dst = append(dst, 0)
}
if src.Routine != "" {
buf.WriteByte('R')
buf.WriteString(src.Routine)
buf.WriteByte(0)
dst = append(dst, 'R')
dst = append(dst, src.Routine...)
dst = append(dst, 0)
}
for k, v := range src.UnknownFields {
buf.WriteByte(k)
buf.WriteString(v)
buf.WriteByte(0)
dst = append(dst, k)
dst = append(dst, v...)
dst = append(dst, 0)
}
buf.WriteByte(0)
dst = append(dst, 0)
binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1))
return buf.Bytes()
return dst
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -36,19 +36,12 @@ func (dst *Execute) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Execute) Encode(dst []byte) []byte {
dst = append(dst, 'E')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
func (src *Execute) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'E')
dst = append(dst, src.Portal...)
dst = append(dst, 0)
dst = pgio.AppendUint32(dst, src.MaxRows)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -20,8 +20,8 @@ func (dst *Flush) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Flush) Encode(dst []byte) []byte {
return append(dst, 'H', 0, 0, 0, 4)
func (src *Flush) Encode(dst []byte) ([]byte, error) {
return append(dst, 'H', 0, 0, 0, 4), nil
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -18,7 +18,8 @@ type Frontend struct {
// idle. Setting and unsetting tracer provides equivalent functionality to PQtrace and PQuntrace in libpq.
tracer *tracer
wbuf []byte
wbuf []byte
encodeError error
// Backend message flyweights
authenticationOk AuthenticationOk
@ -64,16 +65,26 @@ func NewFrontend(r io.Reader, w io.Writer) *Frontend {
return &Frontend{cr: cr, w: w}
}
// Send sends a message to the backend (i.e. the server). The message is not guaranteed to be written until Flush is
// called.
// Send sends a message to the backend (i.e. the server). The message is buffered until Flush is called. Any error
// encountered will be returned from Flush.
//
// Send can work with any FrontendMessage. Some commonly used message types such as Bind have specialized send methods
// such as SendBind. These methods should be preferred when the type of message is known up front (e.g. when building an
// extended query protocol query) as they may be faster due to knowing the type of msg rather than it being hidden
// behind an interface.
func (f *Frontend) Send(msg FrontendMessage) {
if f.encodeError != nil {
return
}
prevLen := len(f.wbuf)
f.wbuf = msg.Encode(f.wbuf)
newBuf, err := msg.Encode(f.wbuf)
if err != nil {
f.encodeError = err
return
}
f.wbuf = newBuf
if f.tracer != nil {
f.tracer.traceMessage('F', int32(len(f.wbuf)-prevLen), msg)
}
@ -81,6 +92,12 @@ func (f *Frontend) Send(msg FrontendMessage) {
// Flush writes any pending messages to the backend (i.e. the server).
func (f *Frontend) Flush() error {
if err := f.encodeError; err != nil {
f.encodeError = nil
f.wbuf = f.wbuf[:0]
return &writeError{err: err, safeToRetry: true}
}
if len(f.wbuf) == 0 {
return nil
}
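
In practice callers no longer check Send itself; a quick sketch (the net.Conn and query are placeholders):

package main

import (
	"log"
	"net"

	"github.com/jackc/pgx/v5/pgproto3"
)

func sendSimpleQuery(conn net.Conn) {
	frontend := pgproto3.NewFrontend(conn, conn)
	frontend.Send(&pgproto3.Query{String: "select 1"})
	// Send never fails directly; a message that could not be encoded is
	// reported here, wrapped as a retryable write error.
	if err := frontend.Flush(); err != nil {
		log.Fatal(err)
	}
}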
@ -116,71 +133,141 @@ func (f *Frontend) Untrace() {
f.tracer = nil
}
// SendBind sends a Bind message to the backend (i.e. the server). The message is not guaranteed to be written until
// Flush is called.
// SendBind sends a Bind message to the backend (i.e. the server). The message is buffered until Flush is called. Any
// error encountered will be returned from Flush.
func (f *Frontend) SendBind(msg *Bind) {
if f.encodeError != nil {
return
}
prevLen := len(f.wbuf)
f.wbuf = msg.Encode(f.wbuf)
newBuf, err := msg.Encode(f.wbuf)
if err != nil {
f.encodeError = err
return
}
f.wbuf = newBuf
if f.tracer != nil {
f.tracer.traceBind('F', int32(len(f.wbuf)-prevLen), msg)
}
}
// SendParse sends a Parse message to the backend (i.e. the server). The message is not guaranteed to be written until
// Flush is called.
// SendParse sends a Parse message to the backend (i.e. the server). The message is buffered until Flush is called. Any
// error encountered will be returned from Flush.
func (f *Frontend) SendParse(msg *Parse) {
if f.encodeError != nil {
return
}
prevLen := len(f.wbuf)
f.wbuf = msg.Encode(f.wbuf)
newBuf, err := msg.Encode(f.wbuf)
if err != nil {
f.encodeError = err
return
}
f.wbuf = newBuf
if f.tracer != nil {
f.tracer.traceParse('F', int32(len(f.wbuf)-prevLen), msg)
}
}
// SendClose sends a Close message to the backend (i.e. the server). The message is not guaranteed to be written until
// Flush is called.
// SendClose sends a Close message to the backend (i.e. the server). The message is buffered until Flush is called. Any
// error encountered will be returned from Flush.
func (f *Frontend) SendClose(msg *Close) {
if f.encodeError != nil {
return
}
prevLen := len(f.wbuf)
f.wbuf = msg.Encode(f.wbuf)
newBuf, err := msg.Encode(f.wbuf)
if err != nil {
f.encodeError = err
return
}
f.wbuf = newBuf
if f.tracer != nil {
f.tracer.traceClose('F', int32(len(f.wbuf)-prevLen), msg)
}
}
// SendDescribe sends a Describe message to the backend (i.e. the server). The message is not guaranteed to be written until
// Flush is called.
// SendDescribe sends a Describe message to the backend (i.e. the server). The message is buffered until Flush is
// called. Any error encountered will be returned from Flush.
func (f *Frontend) SendDescribe(msg *Describe) {
if f.encodeError != nil {
return
}
prevLen := len(f.wbuf)
f.wbuf = msg.Encode(f.wbuf)
newBuf, err := msg.Encode(f.wbuf)
if err != nil {
f.encodeError = err
return
}
f.wbuf = newBuf
if f.tracer != nil {
f.tracer.traceDescribe('F', int32(len(f.wbuf)-prevLen), msg)
}
}
// SendExecute sends an Execute message to the backend (i.e. the server). The message is not guaranteed to be written until
// Flush is called.
// SendExecute sends an Execute message to the backend (i.e. the server). The message is buffered until Flush is called.
// Any error encountered will be returned from Flush.
func (f *Frontend) SendExecute(msg *Execute) {
if f.encodeError != nil {
return
}
prevLen := len(f.wbuf)
f.wbuf = msg.Encode(f.wbuf)
newBuf, err := msg.Encode(f.wbuf)
if err != nil {
f.encodeError = err
return
}
f.wbuf = newBuf
if f.tracer != nil {
f.tracer.TraceQueryute('F', int32(len(f.wbuf)-prevLen), msg)
}
}
// SendSync sends a Sync message to the backend (i.e. the server). The message is not guaranteed to be written until
// Flush is called.
// SendSync sends a Sync message to the backend (i.e. the server). The message is buffered until Flush is called. Any
// error encountered will be returned from Flush.
func (f *Frontend) SendSync(msg *Sync) {
if f.encodeError != nil {
return
}
prevLen := len(f.wbuf)
f.wbuf = msg.Encode(f.wbuf)
newBuf, err := msg.Encode(f.wbuf)
if err != nil {
f.encodeError = err
return
}
f.wbuf = newBuf
if f.tracer != nil {
f.tracer.traceSync('F', int32(len(f.wbuf)-prevLen), msg)
}
}
// SendQuery sends a Query message to the backend (i.e. the server). The message is not guaranteed to be written until
// Flush is called.
// SendQuery sends a Query message to the backend (i.e. the server). The message is buffered until Flush is called. Any
// error encountered will be returned from Flush.
func (f *Frontend) SendQuery(msg *Query) {
if f.encodeError != nil {
return
}
prevLen := len(f.wbuf)
f.wbuf = msg.Encode(f.wbuf)
newBuf, err := msg.Encode(f.wbuf)
if err != nil {
f.encodeError = err
return
}
f.wbuf = newBuf
if f.tracer != nil {
f.tracer.traceQuery('F', int32(len(f.wbuf)-prevLen), msg)
}

View File

@ -2,6 +2,8 @@ package pgproto3
import (
"encoding/binary"
"errors"
"math"
"github.com/jackc/pgx/v5/internal/pgio"
)
@ -71,15 +73,21 @@ func (dst *FunctionCall) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *FunctionCall) Encode(dst []byte) []byte {
dst = append(dst, 'F')
sp := len(dst)
dst = pgio.AppendUint32(dst, 0) // Unknown length, set it at the end
func (src *FunctionCall) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'F')
dst = pgio.AppendUint32(dst, src.Function)
if len(src.ArgFormatCodes) > math.MaxUint16 {
return nil, errors.New("too many arg format codes")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ArgFormatCodes)))
for _, argFormatCode := range src.ArgFormatCodes {
dst = pgio.AppendUint16(dst, argFormatCode)
}
if len(src.Arguments) > math.MaxUint16 {
return nil, errors.New("too many arguments")
}
dst = pgio.AppendUint16(dst, uint16(len(src.Arguments)))
for _, argument := range src.Arguments {
if argument == nil {
@ -90,6 +98,5 @@ func (src *FunctionCall) Encode(dst []byte) []byte {
}
}
dst = pgio.AppendUint16(dst, src.ResultFormatCode)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
return finishMessage(dst, sp)
}

View File

@ -39,10 +39,8 @@ func (dst *FunctionCallResponse) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *FunctionCallResponse) Encode(dst []byte) []byte {
dst = append(dst, 'V')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
func (src *FunctionCallResponse) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'V')
if src.Result == nil {
dst = pgio.AppendInt32(dst, -1)
@ -51,9 +49,7 @@ func (src *FunctionCallResponse) Encode(dst []byte) []byte {
dst = append(dst, src.Result...)
}
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -31,10 +31,10 @@ func (dst *GSSEncRequest) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 4 byte message length.
func (src *GSSEncRequest) Encode(dst []byte) []byte {
func (src *GSSEncRequest) Encode(dst []byte) ([]byte, error) {
dst = pgio.AppendInt32(dst, 8)
dst = pgio.AppendInt32(dst, gssEncReqNumber)
return dst
return dst, nil
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -2,8 +2,6 @@ package pgproto3
import (
"encoding/json"
"github.com/jackc/pgx/v5/internal/pgio"
)
type GSSResponse struct {
@ -18,11 +16,10 @@ func (g *GSSResponse) Decode(data []byte) error {
return nil
}
func (g *GSSResponse) Encode(dst []byte) []byte {
dst = append(dst, 'p')
dst = pgio.AppendInt32(dst, int32(4+len(g.Data)))
func (g *GSSResponse) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'p')
dst = append(dst, g.Data...)
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -20,8 +20,8 @@ func (dst *NoData) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *NoData) Encode(dst []byte) []byte {
return append(dst, 'n', 0, 0, 0, 4)
func (src *NoData) Encode(dst []byte) ([]byte, error) {
return append(dst, 'n', 0, 0, 0, 4), nil
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -12,6 +12,8 @@ func (dst *NoticeResponse) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *NoticeResponse) Encode(dst []byte) []byte {
return append(dst, (*ErrorResponse)(src).marshalBinary('N')...)
func (src *NoticeResponse) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'N')
dst = (*ErrorResponse)(src).appendFields(dst)
return finishMessage(dst, sp)
}

View File

@ -45,20 +45,14 @@ func (dst *NotificationResponse) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *NotificationResponse) Encode(dst []byte) []byte {
dst = append(dst, 'A')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
func (src *NotificationResponse) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'A')
dst = pgio.AppendUint32(dst, src.PID)
dst = append(dst, src.Channel...)
dst = append(dst, 0)
dst = append(dst, src.Payload...)
dst = append(dst, 0)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -4,6 +4,8 @@ import (
"bytes"
"encoding/binary"
"encoding/json"
"errors"
"math"
"github.com/jackc/pgx/v5/internal/pgio"
)
@ -39,19 +41,18 @@ func (dst *ParameterDescription) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *ParameterDescription) Encode(dst []byte) []byte {
dst = append(dst, 't')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
func (src *ParameterDescription) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 't')
if len(src.ParameterOIDs) > math.MaxUint16 {
return nil, errors.New("too many parameter oids")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs)))
for _, oid := range src.ParameterOIDs {
dst = pgio.AppendUint32(dst, oid)
}
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -3,8 +3,6 @@ package pgproto3
import (
"bytes"
"encoding/json"
"github.com/jackc/pgx/v5/internal/pgio"
)
type ParameterStatus struct {
@ -37,19 +35,13 @@ func (dst *ParameterStatus) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *ParameterStatus) Encode(dst []byte) []byte {
dst = append(dst, 'S')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
func (src *ParameterStatus) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'S')
dst = append(dst, src.Name...)
dst = append(dst, 0)
dst = append(dst, src.Value...)
dst = append(dst, 0)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -4,6 +4,8 @@ import (
"bytes"
"encoding/binary"
"encoding/json"
"errors"
"math"
"github.com/jackc/pgx/v5/internal/pgio"
)
@ -52,24 +54,23 @@ func (dst *Parse) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Parse) Encode(dst []byte) []byte {
dst = append(dst, 'P')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
func (src *Parse) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'P')
dst = append(dst, src.Name...)
dst = append(dst, 0)
dst = append(dst, src.Query...)
dst = append(dst, 0)
if len(src.ParameterOIDs) > math.MaxUint16 {
return nil, errors.New("too many parameter oids")
}
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs)))
for _, oid := range src.ParameterOIDs {
dst = pgio.AppendUint32(dst, oid)
}
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -20,8 +20,8 @@ func (dst *ParseComplete) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *ParseComplete) Encode(dst []byte) []byte {
return append(dst, '1', 0, 0, 0, 4)
func (src *ParseComplete) Encode(dst []byte) ([]byte, error) {
return append(dst, '1', 0, 0, 0, 4), nil
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -3,8 +3,6 @@ package pgproto3
import (
"bytes"
"encoding/json"
"github.com/jackc/pgx/v5/internal/pgio"
)
type PasswordMessage struct {
@ -32,14 +30,11 @@ func (dst *PasswordMessage) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *PasswordMessage) Encode(dst []byte) []byte {
dst = append(dst, 'p')
dst = pgio.AppendInt32(dst, int32(4+len(src.Password)+1))
func (src *PasswordMessage) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'p')
dst = append(dst, src.Password...)
dst = append(dst, 0)
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -4,8 +4,14 @@ import (
"encoding/hex"
"errors"
"fmt"
"github.com/jackc/pgx/v5/internal/pgio"
)
// maxMessageBodyLen is the maximum length of a message body in bytes. See PG_LARGE_MESSAGE_LIMIT in the PostgreSQL
// source. It is defined as (MaxAllocSize - 1). MaxAllocSize is defined as 0x3fffffff.
const maxMessageBodyLen = (0x3fffffff - 1)
// Message is the interface implemented by an object that can decode and encode
// a particular PostgreSQL message.
type Message interface {
@ -14,7 +20,7 @@ type Message interface {
Decode(data []byte) error
// Encode appends itself to dst and returns the new buffer.
Encode(dst []byte) []byte
Encode(dst []byte) ([]byte, error)
}
// FrontendMessage is a message sent by the frontend (i.e. the client).
@ -70,6 +76,15 @@ func (e *writeError) Unwrap() error {
return e.err
}
type ExceededMaxBodyLenErr struct {
MaxExpectedBodyLen int
ActualBodyLen int
}
func (e *ExceededMaxBodyLenErr) Error() string {
return fmt.Sprintf("invalid body length: expected at most %d, but got %d", e.MaxExpectedBodyLen, e.ActualBodyLen)
}
// getValueFromJSON gets the value from a protocol message representation in JSON.
func getValueFromJSON(v map[string]string) ([]byte, error) {
if v == nil {
@ -83,3 +98,23 @@ func getValueFromJSON(v map[string]string) ([]byte, error) {
}
return nil, errors.New("unknown protocol representation")
}
// beginMessage begins a new message of type t. It appends the message type and a placeholder for the message length to
// dst. It returns the new buffer and the position of the message length placeholder.
func beginMessage(dst []byte, t byte) ([]byte, int) {
dst = append(dst, t)
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
return dst, sp
}
// finishMessage finishes a message that was started with beginMessage. It computes the message length and writes it to
// dst[sp]. If the message length is too large it returns an error. Otherwise it returns the final message buffer.
func finishMessage(dst []byte, sp int) ([]byte, error) {
messageBodyLen := len(dst[sp:])
if messageBodyLen > maxMessageBodyLen {
return nil, errors.New("message body too large")
}
pgio.SetInt32(dst[sp:], int32(messageBodyLen))
return dst, nil
}

View File

@ -20,8 +20,8 @@ func (dst *PortalSuspended) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *PortalSuspended) Encode(dst []byte) []byte {
return append(dst, 's', 0, 0, 0, 4)
func (src *PortalSuspended) Encode(dst []byte) ([]byte, error) {
return append(dst, 's', 0, 0, 0, 4), nil
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -3,8 +3,6 @@ package pgproto3
import (
"bytes"
"encoding/json"
"github.com/jackc/pgx/v5/internal/pgio"
)
type Query struct {
@ -28,14 +26,11 @@ func (dst *Query) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Query) Encode(dst []byte) []byte {
dst = append(dst, 'Q')
dst = pgio.AppendInt32(dst, int32(4+len(src.String)+1))
func (src *Query) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'Q')
dst = append(dst, src.String...)
dst = append(dst, 0)
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -25,8 +25,8 @@ func (dst *ReadyForQuery) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *ReadyForQuery) Encode(dst []byte) []byte {
return append(dst, 'Z', 0, 0, 0, 5, src.TxStatus)
func (src *ReadyForQuery) Encode(dst []byte) ([]byte, error) {
return append(dst, 'Z', 0, 0, 0, 5, src.TxStatus), nil
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -4,6 +4,8 @@ import (
"bytes"
"encoding/binary"
"encoding/json"
"errors"
"math"
"github.com/jackc/pgx/v5/internal/pgio"
)
@ -99,11 +101,12 @@ func (dst *RowDescription) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *RowDescription) Encode(dst []byte) []byte {
dst = append(dst, 'T')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
func (src *RowDescription) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'T')
if len(src.Fields) > math.MaxUint16 {
return nil, errors.New("too many fields")
}
dst = pgio.AppendUint16(dst, uint16(len(src.Fields)))
for _, fd := range src.Fields {
dst = append(dst, fd.Name...)
@ -117,9 +120,7 @@ func (src *RowDescription) Encode(dst []byte) []byte {
dst = pgio.AppendInt16(dst, fd.Format)
}
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -39,10 +39,8 @@ func (dst *SASLInitialResponse) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *SASLInitialResponse) Encode(dst []byte) []byte {
dst = append(dst, 'p')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
func (src *SASLInitialResponse) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'p')
dst = append(dst, []byte(src.AuthMechanism)...)
dst = append(dst, 0)
@ -50,9 +48,7 @@ func (src *SASLInitialResponse) Encode(dst []byte) []byte {
dst = pgio.AppendInt32(dst, int32(len(src.Data)))
dst = append(dst, src.Data...)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -3,8 +3,6 @@ package pgproto3
import (
"encoding/hex"
"encoding/json"
"github.com/jackc/pgx/v5/internal/pgio"
)
type SASLResponse struct {
@ -22,13 +20,10 @@ func (dst *SASLResponse) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *SASLResponse) Encode(dst []byte) []byte {
dst = append(dst, 'p')
dst = pgio.AppendInt32(dst, int32(4+len(src.Data)))
func (src *SASLResponse) Encode(dst []byte) ([]byte, error) {
dst, sp := beginMessage(dst, 'p')
dst = append(dst, src.Data...)
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -31,10 +31,10 @@ func (dst *SSLRequest) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 4 byte message length.
func (src *SSLRequest) Encode(dst []byte) []byte {
func (src *SSLRequest) Encode(dst []byte) ([]byte, error) {
dst = pgio.AppendInt32(dst, 8)
dst = pgio.AppendInt32(dst, sslRequestNumber)
return dst
return dst, nil
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -64,7 +64,7 @@ func (dst *StartupMessage) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *StartupMessage) Encode(dst []byte) []byte {
func (src *StartupMessage) Encode(dst []byte) ([]byte, error) {
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
@ -77,9 +77,7 @@ func (src *StartupMessage) Encode(dst []byte) []byte {
}
dst = append(dst, 0)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
return finishMessage(dst, sp)
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -20,8 +20,8 @@ func (dst *Sync) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Sync) Encode(dst []byte) []byte {
return append(dst, 'S', 0, 0, 0, 4)
func (src *Sync) Encode(dst []byte) ([]byte, error) {
return append(dst, 'S', 0, 0, 0, 4), nil
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -20,8 +20,8 @@ func (dst *Terminate) Decode(src []byte) error {
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Terminate) Encode(dst []byte) []byte {
return append(dst, 'X', 0, 0, 0, 4)
func (src *Terminate) Encode(dst []byte) ([]byte, error) {
return append(dst, 'X', 0, 0, 0, 4), nil
}
// MarshalJSON implements encoding/json.Marshaler.

View File

@ -176,8 +176,10 @@ func (scanPlanBinaryBitsToBitsScanner) Scan(src []byte, dst any) error {
bitLen := int32(binary.BigEndian.Uint32(src))
rp := 4
buf := make([]byte, len(src[rp:]))
copy(buf, src[rp:])
return scanner.ScanBits(Bits{Bytes: src[rp:], Len: bitLen, Valid: true})
return scanner.ScanBits(Bits{Bytes: buf, Len: bitLen, Valid: true})
}
type scanPlanTextAnyToBitsScanner struct{}

View File

@ -297,12 +297,12 @@ func (c Float4Codec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, sr
return nil, nil
}
var n float64
var n float32
err := codecScan(c, m, oid, format, src, &n)
if err != nil {
return nil, err
}
return n, nil
return float64(n), nil
}
func (c Float4Codec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {

View File

@ -25,6 +25,11 @@ func (c JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) Encod
case []byte:
return encodePlanJSONCodecEitherFormatByteSlice{}
// Handle json.RawMessage specifically because if it is run through json.Marshal it may be mutated.
// e.g. `{"foo": "bar"}` -> `{"foo":"bar"}`.
case json.RawMessage:
return encodePlanJSONCodecEitherFormatJSONRawMessage{}
// Cannot rely on driver.Valuer being handled later because anything can be marshalled.
//
// https://github.com/jackc/pgx/issues/1430
@ -79,6 +84,18 @@ func (encodePlanJSONCodecEitherFormatByteSlice) Encode(value any, buf []byte) (n
return buf, nil
}
type encodePlanJSONCodecEitherFormatJSONRawMessage struct{}
func (encodePlanJSONCodecEitherFormatJSONRawMessage) Encode(value any, buf []byte) (newBuf []byte, err error) {
jsonBytes := value.(json.RawMessage)
if jsonBytes == nil {
return nil, nil
}
buf = append(buf, jsonBytes...)
return buf, nil
}
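
At the pgx level this means a json.RawMessage argument is stored byte-for-byte rather than being re-marshalled; a hedged example (the docs table is a placeholder):

package main

import (
	"context"
	"encoding/json"
	"log"

	"github.com/jackc/pgx/v5"
)

func insertRaw(ctx context.Context, conn *pgx.Conn) {
	raw := json.RawMessage(`{"foo":  "bar"}`) // whitespace and key order are preserved
	if _, err := conn.Exec(ctx, "insert into docs (body) values ($1)", raw); err != nil {
		log.Fatal(err)
	}
}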
type encodePlanJSONCodecEitherFormatMarshal struct{}
func (encodePlanJSONCodecEitherFormatMarshal) Encode(value any, buf []byte) (newBuf []byte, err error) {

122
vendor/github.com/jackc/pgx/v5/pgtype/ltree.go generated vendored Normal file
View File

@ -0,0 +1,122 @@
package pgtype
import (
"database/sql/driver"
"fmt"
)
type LtreeCodec struct{}
func (l LtreeCodec) FormatSupported(format int16) bool {
return format == TextFormatCode || format == BinaryFormatCode
}
// PreferredFormat returns the preferred format.
func (l LtreeCodec) PreferredFormat() int16 {
return TextFormatCode
}
// PlanEncode returns an EncodePlan for encoding value into PostgreSQL format for oid and format. If no plan can be
// found then nil is returned.
func (l LtreeCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
switch format {
case TextFormatCode:
return (TextCodec)(l).PlanEncode(m, oid, format, value)
case BinaryFormatCode:
switch value.(type) {
case string:
return encodeLtreeCodecBinaryString{}
case []byte:
return encodeLtreeCodecBinaryByteSlice{}
case TextValuer:
return encodeLtreeCodecBinaryTextValuer{}
}
}
return nil
}
type encodeLtreeCodecBinaryString struct{}
func (encodeLtreeCodecBinaryString) Encode(value any, buf []byte) (newBuf []byte, err error) {
ltree := value.(string)
buf = append(buf, 1)
return append(buf, ltree...), nil
}
type encodeLtreeCodecBinaryByteSlice struct{}
func (encodeLtreeCodecBinaryByteSlice) Encode(value any, buf []byte) (newBuf []byte, err error) {
ltree := value.([]byte)
buf = append(buf, 1)
return append(buf, ltree...), nil
}
type encodeLtreeCodecBinaryTextValuer struct{}
func (encodeLtreeCodecBinaryTextValuer) Encode(value any, buf []byte) (newBuf []byte, err error) {
t, err := value.(TextValuer).TextValue()
if err != nil {
return nil, err
}
if !t.Valid {
return nil, nil
}
buf = append(buf, 1)
return append(buf, t.String...), nil
}
// PlanScan returns a ScanPlan for scanning a PostgreSQL value into a destination with the same type as target. If
// no plan can be found then nil is returned.
func (l LtreeCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
switch format {
case TextFormatCode:
return (TextCodec)(l).PlanScan(m, oid, format, target)
case BinaryFormatCode:
switch target.(type) {
case *string:
return scanPlanBinaryLtreeToString{}
case TextScanner:
return scanPlanBinaryLtreeToTextScanner{}
}
}
return nil
}
type scanPlanBinaryLtreeToString struct{}
func (scanPlanBinaryLtreeToString) Scan(src []byte, target any) error {
version := src[0]
if version != 1 {
return fmt.Errorf("unsupported ltree version %d", version)
}
p := (target).(*string)
*p = string(src[1:])
return nil
}
type scanPlanBinaryLtreeToTextScanner struct{}
func (scanPlanBinaryLtreeToTextScanner) Scan(src []byte, target any) error {
version := src[0]
if version != 1 {
return fmt.Errorf("unsupported ltree version %d", version)
}
scanner := (target).(TextScanner)
return scanner.ScanText(Text{String: string(src[1:]), Valid: true})
}
// DecodeDatabaseSQLValue returns src decoded into a value compatible with the sql.Scanner interface.
func (l LtreeCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
return (TextCodec)(l).DecodeDatabaseSQLValue(m, oid, format, src)
}
// DecodeValue returns src decoded into its default format.
func (l LtreeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
return (TextCodec)(l).DecodeValue(m, oid, format, src)
}
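
A sketch of wiring the new codec up (not part of this diff): ltree is an extension type, so its OID has to be looked up before LtreeCodec can be registered on a connection's type map.

package main

import (
	"context"
	"log"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgtype"
)

func registerLtree(ctx context.Context, conn *pgx.Conn) {
	var oid uint32
	if err := conn.QueryRow(ctx, "select 'ltree'::regtype::oid").Scan(&oid); err != nil {
		log.Fatal(err)
	}
	conn.TypeMap().RegisterType(&pgtype.Type{
		Name:  "ltree",
		OID:   oid,
		Codec: pgtype.LtreeCodec{},
	})
}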

View File

@ -81,6 +81,8 @@ const (
IntervalOID = 1186
IntervalArrayOID = 1187
NumericArrayOID = 1231
TimetzOID = 1266
TimetzArrayOID = 1270
BitOID = 1560
BitArrayOID = 1561
VarbitOID = 1562
@ -559,7 +561,7 @@ func TryFindUnderlyingTypeScanPlan(dst any) (plan WrappedScanPlanNextSetter, nex
}
}
if nextDstType != nil && dstValue.Type() != nextDstType {
if nextDstType != nil && dstValue.Type() != nextDstType && dstValue.CanConvert(nextDstType) {
return &underlyingTypeScanPlan{dstType: dstValue.Type(), nextDstType: nextDstType}, dstValue.Convert(nextDstType).Interface(), true
}

View File

@ -1,6 +1,7 @@
package pgtype
import (
"encoding/json"
"net"
"net/netip"
"reflect"
@ -173,6 +174,7 @@ func initDefaultMap() {
registerDefaultPgTypeVariants[time.Time](defaultMap, "timestamptz")
registerDefaultPgTypeVariants[time.Duration](defaultMap, "interval")
registerDefaultPgTypeVariants[string](defaultMap, "text")
registerDefaultPgTypeVariants[json.RawMessage](defaultMap, "json")
registerDefaultPgTypeVariants[[]byte](defaultMap, "bytea")
registerDefaultPgTypeVariants[net.IP](defaultMap, "inet")

View File

@ -52,7 +52,19 @@ func parseUUID(src string) (dst [16]byte, err error) {
// encodeUUID converts a uuid byte array to UUID standard string form.
func encodeUUID(src [16]byte) string {
return fmt.Sprintf("%x-%x-%x-%x-%x", src[0:4], src[4:6], src[6:8], src[8:10], src[10:16])
var buf [36]byte
hex.Encode(buf[0:8], src[:4])
buf[8] = '-'
hex.Encode(buf[9:13], src[4:6])
buf[13] = '-'
hex.Encode(buf[14:18], src[6:8])
buf[18] = '-'
hex.Encode(buf[19:23], src[8:10])
buf[23] = '-'
hex.Encode(buf[24:], src[10:])
return string(buf[:])
}
// Scan implements the database/sql Scanner interface.

View File

@ -417,12 +417,10 @@ type CollectableRow interface {
// RowToFunc is a function that scans or otherwise converts row to a T.
type RowToFunc[T any] func(row CollectableRow) (T, error)
// CollectRows iterates through rows, calling fn for each row, and collecting the results into a slice of T.
func CollectRows[T any](rows Rows, fn RowToFunc[T]) ([]T, error) {
// AppendRows iterates through rows, calling fn for each row, and appending the results into a slice of T.
func AppendRows[T any, S ~[]T](slice S, rows Rows, fn RowToFunc[T]) (S, error) {
defer rows.Close()
slice := []T{}
for rows.Next() {
value, err := fn(rows)
if err != nil {
@ -438,6 +436,11 @@ func CollectRows[T any](rows Rows, fn RowToFunc[T]) ([]T, error) {
return slice, nil
}
// CollectRows iterates through rows, calling fn for each row, and collecting the results into a slice of T.
func CollectRows[T any](rows Rows, fn RowToFunc[T]) ([]T, error) {
return AppendRows([]T{}, rows, fn)
}
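
A brief usage sketch contrasting the two (query and table are placeholders): CollectRows allocates a fresh slice, while AppendRows lets the caller reuse capacity.

package main

import (
	"context"
	"log"

	"github.com/jackc/pgx/v5"
)

func collectNames(ctx context.Context, conn *pgx.Conn) []string {
	rows, err := conn.Query(ctx, "select name from widgets")
	if err != nil {
		log.Fatal(err)
	}
	names, err := pgx.CollectRows(rows, pgx.RowTo[string]) // fresh slice
	if err != nil {
		log.Fatal(err)
	}

	rows, err = conn.Query(ctx, "select name from widgets")
	if err != nil {
		log.Fatal(err)
	}
	names, err = pgx.AppendRows(names[:0], rows, pgx.RowTo[string]) // reuse capacity
	if err != nil {
		log.Fatal(err)
	}
	return names
}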
// CollectOneRow calls fn for the first row in rows and returns the result. If no rows are found returns an error where errors.Is(ErrNoRows) is true.
// CollectOneRow is to CollectRows as QueryRow is to Query.
func CollectOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) {
@ -667,7 +670,12 @@ const structTagKey = "db"
func fieldPosByName(fldDescs []pgconn.FieldDescription, field string) (i int) {
i = -1
for i, desc := range fldDescs {
if strings.EqualFold(desc.Name, field) {
// Snake case support.
field = strings.ReplaceAll(field, "_", "")
descName := strings.ReplaceAll(desc.Name, "_", "")
if strings.EqualFold(descName, field) {
return i
}
}
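
CollectRows is now a thin wrapper over the new AppendRows, which appends onto a caller-supplied slice, and fieldPosByName strips underscores before comparing, so a full_name column matches a FullName struct field without needing a db tag. A hedged sketch; the users table and its columns are hypothetical:

package example

import (
    "context"

    "github.com/jackc/pgx/v5"
)

type User struct {
    ID       int64
    FullName string // matched against the full_name column via the underscore-insensitive lookup
}

// LoadUsers appends freshly scanned rows onto an existing slice with AppendRows;
// CollectRows(rows, fn) is now simply AppendRows([]T{}, rows, fn).
func LoadUsers(ctx context.Context, conn *pgx.Conn, dst []User) ([]User, error) {
    rows, err := conn.Query(ctx, "select id, full_name from users")
    if err != nil {
        return nil, err
    }
    return pgx.AppendRows(dst, rows, pgx.RowToStructByName[User])
}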

View File

@ -21,10 +21,7 @@
// return err
// }
//
// db, err := stdlib.OpenDBFromPool(pool)
// if err != nil {
// return err
// }
// db := stdlib.OpenDBFromPool(pool)
//
// Or a pgx.ConnConfig can be used to set configuration not accessible via connection string. In this case the
// pgx.ConnConfig must first be registered with the driver. This registration returns a connection string which is used
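
The doc change reflects that stdlib.OpenDBFromPool no longer returns an error, so a pool-backed *sql.DB is obtained in one call. A hedged sketch, assuming DATABASE_URL names a reachable server:

package example

import (
    "context"
    "database/sql"
    "os"

    "github.com/jackc/pgx/v5/pgxpool"
    "github.com/jackc/pgx/v5/stdlib"
)

// OpenViaPool builds a *sql.DB on top of a pgxpool.Pool, so one pool can serve
// both native pgx calls and database/sql-based libraries. The caller closes both.
func OpenViaPool(ctx context.Context) (*pgxpool.Pool, *sql.DB, error) {
    pool, err := pgxpool.New(ctx, os.Getenv("DATABASE_URL"))
    if err != nil {
        return nil, nil, err
    }
    db := stdlib.OpenDBFromPool(pool)
    return pool, db, nil
}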

View File

@ -55,7 +55,11 @@ func encodeCopyValue(m *pgtype.Map, buf []byte, oid uint32, arg any) ([]byte, er
func tryScanStringCopyValueThenEncode(m *pgtype.Map, buf []byte, oid uint32, arg any) ([]byte, error) {
s, ok := arg.(string)
if !ok {
return nil, errors.New("not a string")
textBuf, err := m.Encode(oid, TextFormatCode, arg, nil)
if err != nil {
return nil, errors.New("not a string and cannot be encoded as text")
}
s = string(textBuf)
}
var v any
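
For COPY columns that only have a text representation, the fallback now tries to encode any non-string argument as text instead of failing outright. A hedged sketch of the CopyFrom call shape; the accounts table, its status enum column, and the Status Go type are hypothetical, and which arguments actually take this fallback depends on the column OID and the registered codecs:

package example

import (
    "context"

    "github.com/jackc/pgx/v5"
)

// Status is a named Go type standing in for a hypothetical PostgreSQL enum column.
type Status string

// BulkInsert streams rows with CopyFrom; values that are not plain Go strings
// can now still be sent for text-only columns, provided the type map can
// encode them as text.
func BulkInsert(ctx context.Context, conn *pgx.Conn) (int64, error) {
    rows := [][]any{
        {int64(1), Status("active")},
        {int64(2), Status("blocked")},
    }
    return conn.CopyFrom(ctx, pgx.Identifier{"accounts"}, []string{"id", "status"}, pgx.CopyFromRows(rows))
}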

View File

@ -4,6 +4,9 @@
// Package errgroup provides synchronization, error propagation, and Context
// cancelation for groups of goroutines working on subtasks of a common task.
//
// [errgroup.Group] is related to [sync.WaitGroup] but adds handling of tasks
// returning errors.
package errgroup
import (
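
The new doc sentence frames errgroup.Group as a sync.WaitGroup that also propagates errors. A minimal sketch of that pattern; the URLs are placeholders:

package main

import (
    "context"
    "fmt"
    "net/http"

    "golang.org/x/sync/errgroup"
)

func main() {
    g, ctx := errgroup.WithContext(context.Background())
    urls := []string{"https://example.com/", "https://example.org/"}
    for _, url := range urls {
        url := url // capture per iteration (needed before Go 1.22 loop semantics)
        g.Go(func() error {
            req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
            if err != nil {
                return err
            }
            resp, err := http.DefaultClient.Do(req)
            if err != nil {
                return err
            }
            return resp.Body.Close()
        })
    }
    // Wait blocks like WaitGroup.Wait but returns the first non-nil error.
    if err := g.Wait(); err != nil {
        fmt.Println("fetch failed:", err)
    }
}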

View File

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build (aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris || zos) && go1.9
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris || zos
package unix

View File

@ -248,6 +248,7 @@ struct ltchars {
#include <linux/module.h>
#include <linux/mount.h>
#include <linux/netfilter/nfnetlink.h>
#include <linux/netfilter/nf_tables.h>
#include <linux/netlink.h>
#include <linux/net_namespace.h>
#include <linux/nfc.h>
@ -283,10 +284,6 @@ struct ltchars {
#include <asm/termbits.h>
#endif
#ifndef MSG_FASTOPEN
#define MSG_FASTOPEN 0x20000000
#endif
#ifndef PTRACE_GETREGS
#define PTRACE_GETREGS 0xc
#endif
@ -295,14 +292,6 @@ struct ltchars {
#define PTRACE_SETREGS 0xd
#endif
#ifndef SOL_NETLINK
#define SOL_NETLINK 270
#endif
#ifndef SOL_SMC
#define SOL_SMC 286
#endif
#ifdef SOL_BLUETOOTH
// SPARC includes this in /usr/include/sparc64-linux-gnu/bits/socket.h
// but it is already in bluetooth_linux.go
@ -319,10 +308,23 @@ struct ltchars {
#undef TIPC_WAIT_FOREVER
#define TIPC_WAIT_FOREVER 0xffffffff
// Copied from linux/l2tp.h
// Including linux/l2tp.h here causes conflicts between linux/in.h
// and netinet/in.h included via net/route.h above.
#define IPPROTO_L2TP 115
// Copied from linux/netfilter/nf_nat.h
// Including linux/netfilter/nf_nat.h here causes conflicts between linux/in.h
// and netinet/in.h.
#define NF_NAT_RANGE_MAP_IPS (1 << 0)
#define NF_NAT_RANGE_PROTO_SPECIFIED (1 << 1)
#define NF_NAT_RANGE_PROTO_RANDOM (1 << 2)
#define NF_NAT_RANGE_PERSISTENT (1 << 3)
#define NF_NAT_RANGE_PROTO_RANDOM_FULLY (1 << 4)
#define NF_NAT_RANGE_PROTO_OFFSET (1 << 5)
#define NF_NAT_RANGE_NETMAP (1 << 6)
#define NF_NAT_RANGE_PROTO_RANDOM_ALL \
(NF_NAT_RANGE_PROTO_RANDOM | NF_NAT_RANGE_PROTO_RANDOM_FULLY)
#define NF_NAT_RANGE_MASK \
(NF_NAT_RANGE_MAP_IPS | NF_NAT_RANGE_PROTO_SPECIFIED | \
NF_NAT_RANGE_PROTO_RANDOM | NF_NAT_RANGE_PERSISTENT | \
NF_NAT_RANGE_PROTO_RANDOM_FULLY | NF_NAT_RANGE_PROTO_OFFSET | \
NF_NAT_RANGE_NETMAP)
// Copied from linux/hid.h.
// Keep in sync with the size of the referenced fields.
@ -582,7 +584,7 @@ ccflags="$@"
$2 ~ /^KEY_(SPEC|REQKEY_DEFL)_/ ||
$2 ~ /^KEYCTL_/ ||
$2 ~ /^PERF_/ ||
$2 ~ /^SECCOMP_MODE_/ ||
$2 ~ /^SECCOMP_/ ||
$2 ~ /^SEEK_/ ||
$2 ~ /^SCHED_/ ||
$2 ~ /^SPLICE_/ ||
@ -603,6 +605,9 @@ ccflags="$@"
$2 ~ /^FSOPT_/ ||
$2 ~ /^WDIO[CFS]_/ ||
$2 ~ /^NFN/ ||
$2 !~ /^NFT_META_IIFTYPE/ &&
$2 ~ /^NFT_/ ||
$2 ~ /^NF_NAT_/ ||
$2 ~ /^XDP_/ ||
$2 ~ /^RWF_/ ||
$2 ~ /^(HDIO|WIN|SMART)_/ ||

View File

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build darwin && go1.12
//go:build darwin
package unix

View File

@ -13,6 +13,7 @@
package unix
import (
"errors"
"sync"
"unsafe"
)
@ -169,25 +170,26 @@ func Getfsstat(buf []Statfs_t, flags int) (n int, err error) {
func Uname(uname *Utsname) error {
mib := []_C_int{CTL_KERN, KERN_OSTYPE}
n := unsafe.Sizeof(uname.Sysname)
if err := sysctl(mib, &uname.Sysname[0], &n, nil, 0); err != nil {
// Suppress ENOMEM errors to be compatible with the C library __xuname() implementation.
if err := sysctl(mib, &uname.Sysname[0], &n, nil, 0); err != nil && !errors.Is(err, ENOMEM) {
return err
}
mib = []_C_int{CTL_KERN, KERN_HOSTNAME}
n = unsafe.Sizeof(uname.Nodename)
if err := sysctl(mib, &uname.Nodename[0], &n, nil, 0); err != nil {
if err := sysctl(mib, &uname.Nodename[0], &n, nil, 0); err != nil && !errors.Is(err, ENOMEM) {
return err
}
mib = []_C_int{CTL_KERN, KERN_OSRELEASE}
n = unsafe.Sizeof(uname.Release)
if err := sysctl(mib, &uname.Release[0], &n, nil, 0); err != nil {
if err := sysctl(mib, &uname.Release[0], &n, nil, 0); err != nil && !errors.Is(err, ENOMEM) {
return err
}
mib = []_C_int{CTL_KERN, KERN_VERSION}
n = unsafe.Sizeof(uname.Version)
if err := sysctl(mib, &uname.Version[0], &n, nil, 0); err != nil {
if err := sysctl(mib, &uname.Version[0], &n, nil, 0); err != nil && !errors.Is(err, ENOMEM) {
return err
}
@ -205,7 +207,7 @@ func Uname(uname *Utsname) error {
mib = []_C_int{CTL_HW, HW_MACHINE}
n = unsafe.Sizeof(uname.Machine)
if err := sysctl(mib, &uname.Machine[0], &n, nil, 0); err != nil {
if err := sysctl(mib, &uname.Machine[0], &n, nil, 0); err != nil && !errors.Is(err, ENOMEM) {
return err
}
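
Suppressing ENOMEM matches the C library's __xuname() behaviour: oversized fields come back truncated instead of failing the whole call. A hedged sketch of reading the populated fields on a Unix platform:

package main

import (
    "fmt"
    "log"

    "golang.org/x/sys/unix"
)

func main() {
    var uts unix.Utsname
    if err := unix.Uname(&uts); err != nil {
        log.Fatal(err)
    }
    // The fields are NUL-padded byte arrays; ByteSliceToString trims at the first NUL.
    fmt.Println(unix.ByteSliceToString(uts.Sysname[:]), unix.ByteSliceToString(uts.Release[:]))
}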

View File

@ -1849,6 +1849,105 @@ func Dup2(oldfd, newfd int) error {
//sys Fsmount(fd int, flags int, mountAttrs int) (fsfd int, err error)
//sys Fsopen(fsName string, flags int) (fd int, err error)
//sys Fspick(dirfd int, pathName string, flags int) (fd int, err error)
//sys fsconfig(fd int, cmd uint, key *byte, value *byte, aux int) (err error)
func fsconfigCommon(fd int, cmd uint, key string, value *byte, aux int) (err error) {
var keyp *byte
if keyp, err = BytePtrFromString(key); err != nil {
return
}
return fsconfig(fd, cmd, keyp, value, aux)
}
// FsconfigSetFlag is equivalent to fsconfig(2) called
// with cmd == FSCONFIG_SET_FLAG.
//
// fd is the filesystem context to act upon.
// key is the parameter key to set.
func FsconfigSetFlag(fd int, key string) (err error) {
return fsconfigCommon(fd, FSCONFIG_SET_FLAG, key, nil, 0)
}
// FsconfigSetString is equivalent to fsconfig(2) called
// with cmd == FSCONFIG_SET_STRING.
//
// fd is the filesystem context to act upon.
// key is the parameter key to set.
// value is the parameter value to set.
func FsconfigSetString(fd int, key string, value string) (err error) {
var valuep *byte
if valuep, err = BytePtrFromString(value); err != nil {
return
}
return fsconfigCommon(fd, FSCONFIG_SET_STRING, key, valuep, 0)
}
// FsconfigSetBinary is equivalent to fsconfig(2) called
// with cmd == FSCONFIG_SET_BINARY.
//
// fd is the filesystem context to act upon.
// key is the parameter key to set.
// value is the parameter value to set.
func FsconfigSetBinary(fd int, key string, value []byte) (err error) {
if len(value) == 0 {
return EINVAL
}
return fsconfigCommon(fd, FSCONFIG_SET_BINARY, key, &value[0], len(value))
}
// FsconfigSetPath is equivalent to fsconfig(2) called
// with cmd == FSCONFIG_SET_PATH.
//
// fd is the filesystem context to act upon.
// key is the parameter key to set.
// path is a non-empty path for specified key.
// atfd is a file descriptor at which to start lookup from or AT_FDCWD.
func FsconfigSetPath(fd int, key string, path string, atfd int) (err error) {
var valuep *byte
if valuep, err = BytePtrFromString(path); err != nil {
return
}
return fsconfigCommon(fd, FSCONFIG_SET_PATH, key, valuep, atfd)
}
// FsconfigSetPathEmpty is equivalent to fsconfig(2) called
// with cmd == FSCONFIG_SET_PATH_EMPTY. The same as
// FsconfigSetPath but with AT_PATH_EMPTY implied.
func FsconfigSetPathEmpty(fd int, key string, path string, atfd int) (err error) {
var valuep *byte
if valuep, err = BytePtrFromString(path); err != nil {
return
}
return fsconfigCommon(fd, FSCONFIG_SET_PATH_EMPTY, key, valuep, atfd)
}
// FsconfigSetFd is equivalent to fsconfig(2) called
// with cmd == FSCONFIG_SET_FD.
//
// fd is the filesystem context to act upon.
// key is the parameter key to set.
// value is a file descriptor to be assigned to specified key.
func FsconfigSetFd(fd int, key string, value int) (err error) {
return fsconfigCommon(fd, FSCONFIG_SET_FD, key, nil, value)
}
// FsconfigCreate is equivalent to fsconfig(2) called
// with cmd == FSCONFIG_CMD_CREATE.
//
// fd is the filesystem context to act upon.
func FsconfigCreate(fd int) (err error) {
return fsconfig(fd, FSCONFIG_CMD_CREATE, nil, nil, 0)
}
// FsconfigReconfigure is equivalent to fsconfig(2) called
// with cmd == FSCONFIG_CMD_RECONFIGURE.
//
// fd is the filesystem context to act upon.
func FsconfigReconfigure(fd int) (err error) {
return fsconfig(fd, FSCONFIG_CMD_RECONFIGURE, nil, nil, 0)
}
//sys Getdents(fd int, buf []byte) (n int, err error) = SYS_GETDENTS64
//sysnb Getpgid(pid int) (pgid int, err error)
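
The Fsconfig* helpers wrap the fsconfig(2) steps of the new Linux mount API. A hedged, Linux-only sketch of the usual fsopen → fsconfig → fsmount sequence (requires CAP_SYS_ADMIN; the tmpfs parameters are arbitrary, and attaching the returned fd with move_mount(2) is left out):

//go:build linux

package example

import "golang.org/x/sys/unix"

// MountTmpfs drives the new mount API: open a filesystem context, set
// parameters with the fsconfig wrappers, create the superblock, then turn the
// context into a detached mount file descriptor.
func MountTmpfs() (int, error) {
    fsfd, err := unix.Fsopen("tmpfs", unix.FSOPEN_CLOEXEC)
    if err != nil {
        return -1, err
    }
    defer unix.Close(fsfd)

    if err := unix.FsconfigSetString(fsfd, "size", "16m"); err != nil {
        return -1, err
    }
    if err := unix.FsconfigCreate(fsfd); err != nil {
        return -1, err
    }
    // The returned fd can be attached to the mount tree with move_mount(2).
    return unix.Fsmount(fsfd, unix.FSMOUNT_CLOEXEC, unix.MOUNT_ATTR_NODEV)
}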

View File

@ -1785,6 +1785,8 @@ const (
LANDLOCK_ACCESS_FS_REMOVE_FILE = 0x20
LANDLOCK_ACCESS_FS_TRUNCATE = 0x4000
LANDLOCK_ACCESS_FS_WRITE_FILE = 0x2
LANDLOCK_ACCESS_NET_BIND_TCP = 0x1
LANDLOCK_ACCESS_NET_CONNECT_TCP = 0x2
LANDLOCK_CREATE_RULESET_VERSION = 0x1
LINUX_REBOOT_CMD_CAD_OFF = 0x0
LINUX_REBOOT_CMD_CAD_ON = 0x89abcdef
@ -2127,6 +2129,60 @@ const (
NFNL_SUBSYS_QUEUE = 0x3
NFNL_SUBSYS_ULOG = 0x4
NFS_SUPER_MAGIC = 0x6969
NFT_CHAIN_FLAGS = 0x7
NFT_CHAIN_MAXNAMELEN = 0x100
NFT_CT_MAX = 0x17
NFT_DATA_RESERVED_MASK = 0xffffff00
NFT_DATA_VALUE_MAXLEN = 0x40
NFT_EXTHDR_OP_MAX = 0x4
NFT_FIB_RESULT_MAX = 0x3
NFT_INNER_MASK = 0xf
NFT_LOGLEVEL_MAX = 0x8
NFT_NAME_MAXLEN = 0x100
NFT_NG_MAX = 0x1
NFT_OBJECT_CONNLIMIT = 0x5
NFT_OBJECT_COUNTER = 0x1
NFT_OBJECT_CT_EXPECT = 0x9
NFT_OBJECT_CT_HELPER = 0x3
NFT_OBJECT_CT_TIMEOUT = 0x7
NFT_OBJECT_LIMIT = 0x4
NFT_OBJECT_MAX = 0xa
NFT_OBJECT_QUOTA = 0x2
NFT_OBJECT_SECMARK = 0x8
NFT_OBJECT_SYNPROXY = 0xa
NFT_OBJECT_TUNNEL = 0x6
NFT_OBJECT_UNSPEC = 0x0
NFT_OBJ_MAXNAMELEN = 0x100
NFT_OSF_MAXGENRELEN = 0x10
NFT_QUEUE_FLAG_BYPASS = 0x1
NFT_QUEUE_FLAG_CPU_FANOUT = 0x2
NFT_QUEUE_FLAG_MASK = 0x3
NFT_REG32_COUNT = 0x10
NFT_REG32_SIZE = 0x4
NFT_REG_MAX = 0x4
NFT_REG_SIZE = 0x10
NFT_REJECT_ICMPX_MAX = 0x3
NFT_RT_MAX = 0x4
NFT_SECMARK_CTX_MAXLEN = 0x100
NFT_SET_MAXNAMELEN = 0x100
NFT_SOCKET_MAX = 0x3
NFT_TABLE_F_MASK = 0x3
NFT_TABLE_MAXNAMELEN = 0x100
NFT_TRACETYPE_MAX = 0x3
NFT_TUNNEL_F_MASK = 0x7
NFT_TUNNEL_MAX = 0x1
NFT_TUNNEL_MODE_MAX = 0x2
NFT_USERDATA_MAXLEN = 0x100
NFT_XFRM_KEY_MAX = 0x6
NF_NAT_RANGE_MAP_IPS = 0x1
NF_NAT_RANGE_MASK = 0x7f
NF_NAT_RANGE_NETMAP = 0x40
NF_NAT_RANGE_PERSISTENT = 0x8
NF_NAT_RANGE_PROTO_OFFSET = 0x20
NF_NAT_RANGE_PROTO_RANDOM = 0x4
NF_NAT_RANGE_PROTO_RANDOM_ALL = 0x14
NF_NAT_RANGE_PROTO_RANDOM_FULLY = 0x10
NF_NAT_RANGE_PROTO_SPECIFIED = 0x2
NILFS_SUPER_MAGIC = 0x3434
NL0 = 0x0
NL1 = 0x100
@ -2411,6 +2467,7 @@ const (
PR_MCE_KILL_GET = 0x22
PR_MCE_KILL_LATE = 0x0
PR_MCE_KILL_SET = 0x1
PR_MDWE_NO_INHERIT = 0x2
PR_MDWE_REFUSE_EXEC_GAIN = 0x1
PR_MPX_DISABLE_MANAGEMENT = 0x2c
PR_MPX_ENABLE_MANAGEMENT = 0x2b
@ -2615,8 +2672,9 @@ const (
RTAX_FEATURES = 0xc
RTAX_FEATURE_ALLFRAG = 0x8
RTAX_FEATURE_ECN = 0x1
RTAX_FEATURE_MASK = 0xf
RTAX_FEATURE_MASK = 0x1f
RTAX_FEATURE_SACK = 0x2
RTAX_FEATURE_TCP_USEC_TS = 0x10
RTAX_FEATURE_TIMESTAMP = 0x4
RTAX_HOPLIMIT = 0xa
RTAX_INITCWND = 0xb
@ -2859,9 +2917,38 @@ const (
SCM_RIGHTS = 0x1
SCM_TIMESTAMP = 0x1d
SC_LOG_FLUSH = 0x100000
SECCOMP_ADDFD_FLAG_SEND = 0x2
SECCOMP_ADDFD_FLAG_SETFD = 0x1
SECCOMP_FILTER_FLAG_LOG = 0x2
SECCOMP_FILTER_FLAG_NEW_LISTENER = 0x8
SECCOMP_FILTER_FLAG_SPEC_ALLOW = 0x4
SECCOMP_FILTER_FLAG_TSYNC = 0x1
SECCOMP_FILTER_FLAG_TSYNC_ESRCH = 0x10
SECCOMP_FILTER_FLAG_WAIT_KILLABLE_RECV = 0x20
SECCOMP_GET_ACTION_AVAIL = 0x2
SECCOMP_GET_NOTIF_SIZES = 0x3
SECCOMP_IOCTL_NOTIF_RECV = 0xc0502100
SECCOMP_IOCTL_NOTIF_SEND = 0xc0182101
SECCOMP_IOC_MAGIC = '!'
SECCOMP_MODE_DISABLED = 0x0
SECCOMP_MODE_FILTER = 0x2
SECCOMP_MODE_STRICT = 0x1
SECCOMP_RET_ACTION = 0x7fff0000
SECCOMP_RET_ACTION_FULL = 0xffff0000
SECCOMP_RET_ALLOW = 0x7fff0000
SECCOMP_RET_DATA = 0xffff
SECCOMP_RET_ERRNO = 0x50000
SECCOMP_RET_KILL = 0x0
SECCOMP_RET_KILL_PROCESS = 0x80000000
SECCOMP_RET_KILL_THREAD = 0x0
SECCOMP_RET_LOG = 0x7ffc0000
SECCOMP_RET_TRACE = 0x7ff00000
SECCOMP_RET_TRAP = 0x30000
SECCOMP_RET_USER_NOTIF = 0x7fc00000
SECCOMP_SET_MODE_FILTER = 0x1
SECCOMP_SET_MODE_STRICT = 0x0
SECCOMP_USER_NOTIF_FD_SYNC_WAKE_UP = 0x1
SECCOMP_USER_NOTIF_FLAG_CONTINUE = 0x1
SECRETMEM_MAGIC = 0x5345434d
SECURITYFS_MAGIC = 0x73636673
SEEK_CUR = 0x1
@ -3021,6 +3108,7 @@ const (
SOL_TIPC = 0x10f
SOL_TLS = 0x11a
SOL_UDP = 0x11
SOL_VSOCK = 0x11f
SOL_X25 = 0x106
SOL_XDP = 0x11b
SOMAXCONN = 0x1000

View File

@ -281,6 +281,9 @@ const (
SCM_TIMESTAMPNS = 0x23
SCM_TXTIME = 0x3d
SCM_WIFI_STATUS = 0x29
SECCOMP_IOCTL_NOTIF_ADDFD = 0x40182103
SECCOMP_IOCTL_NOTIF_ID_VALID = 0x40082102
SECCOMP_IOCTL_NOTIF_SET_FLAGS = 0x40082104
SFD_CLOEXEC = 0x80000
SFD_NONBLOCK = 0x800
SIOCATMARK = 0x8905

View File

@ -282,6 +282,9 @@ const (
SCM_TIMESTAMPNS = 0x23
SCM_TXTIME = 0x3d
SCM_WIFI_STATUS = 0x29
SECCOMP_IOCTL_NOTIF_ADDFD = 0x40182103
SECCOMP_IOCTL_NOTIF_ID_VALID = 0x40082102
SECCOMP_IOCTL_NOTIF_SET_FLAGS = 0x40082104
SFD_CLOEXEC = 0x80000
SFD_NONBLOCK = 0x800
SIOCATMARK = 0x8905

View File

@ -288,6 +288,9 @@ const (
SCM_TIMESTAMPNS = 0x23
SCM_TXTIME = 0x3d
SCM_WIFI_STATUS = 0x29
SECCOMP_IOCTL_NOTIF_ADDFD = 0x40182103
SECCOMP_IOCTL_NOTIF_ID_VALID = 0x40082102
SECCOMP_IOCTL_NOTIF_SET_FLAGS = 0x40082104
SFD_CLOEXEC = 0x80000
SFD_NONBLOCK = 0x800
SIOCATMARK = 0x8905

View File

@ -278,6 +278,9 @@ const (
SCM_TIMESTAMPNS = 0x23
SCM_TXTIME = 0x3d
SCM_WIFI_STATUS = 0x29
SECCOMP_IOCTL_NOTIF_ADDFD = 0x40182103
SECCOMP_IOCTL_NOTIF_ID_VALID = 0x40082102
SECCOMP_IOCTL_NOTIF_SET_FLAGS = 0x40082104
SFD_CLOEXEC = 0x80000
SFD_NONBLOCK = 0x800
SIOCATMARK = 0x8905

View File

@ -275,6 +275,9 @@ const (
SCM_TIMESTAMPNS = 0x23
SCM_TXTIME = 0x3d
SCM_WIFI_STATUS = 0x29
SECCOMP_IOCTL_NOTIF_ADDFD = 0x40182103
SECCOMP_IOCTL_NOTIF_ID_VALID = 0x40082102
SECCOMP_IOCTL_NOTIF_SET_FLAGS = 0x40082104
SFD_CLOEXEC = 0x80000
SFD_NONBLOCK = 0x800
SIOCATMARK = 0x8905

View File

@ -281,6 +281,9 @@ const (
SCM_TIMESTAMPNS = 0x23
SCM_TXTIME = 0x3d
SCM_WIFI_STATUS = 0x29
SECCOMP_IOCTL_NOTIF_ADDFD = 0x80182103
SECCOMP_IOCTL_NOTIF_ID_VALID = 0x80082102
SECCOMP_IOCTL_NOTIF_SET_FLAGS = 0x80082104
SFD_CLOEXEC = 0x80000
SFD_NONBLOCK = 0x80
SIOCATMARK = 0x40047307

View File

@ -281,6 +281,9 @@ const (
SCM_TIMESTAMPNS = 0x23
SCM_TXTIME = 0x3d
SCM_WIFI_STATUS = 0x29
SECCOMP_IOCTL_NOTIF_ADDFD = 0x80182103
SECCOMP_IOCTL_NOTIF_ID_VALID = 0x80082102
SECCOMP_IOCTL_NOTIF_SET_FLAGS = 0x80082104
SFD_CLOEXEC = 0x80000
SFD_NONBLOCK = 0x80
SIOCATMARK = 0x40047307

View File

@ -281,6 +281,9 @@ const (
SCM_TIMESTAMPNS = 0x23
SCM_TXTIME = 0x3d
SCM_WIFI_STATUS = 0x29
SECCOMP_IOCTL_NOTIF_ADDFD = 0x80182103
SECCOMP_IOCTL_NOTIF_ID_VALID = 0x80082102
SECCOMP_IOCTL_NOTIF_SET_FLAGS = 0x80082104
SFD_CLOEXEC = 0x80000
SFD_NONBLOCK = 0x80
SIOCATMARK = 0x40047307

View File

@ -281,6 +281,9 @@ const (
SCM_TIMESTAMPNS = 0x23
SCM_TXTIME = 0x3d
SCM_WIFI_STATUS = 0x29
SECCOMP_IOCTL_NOTIF_ADDFD = 0x80182103
SECCOMP_IOCTL_NOTIF_ID_VALID = 0x80082102
SECCOMP_IOCTL_NOTIF_SET_FLAGS = 0x80082104
SFD_CLOEXEC = 0x80000
SFD_NONBLOCK = 0x80
SIOCATMARK = 0x40047307

Some files were not shown because too many files have changed in this diff.