mirror of
https://github.com/raseels-repos/golang-saas-starter-kit.git
synced 2025-06-13 00:07:30 +02:00
Completed truss code gen for generating model requests and crud.
This commit is contained in:
parent
efaeeb7103
commit
bdbe3c587a
@ -120,6 +120,29 @@ If you have a lot of migrations, it can be a pain to run all them, as an example
|
||||
Another bonus with the globally defined schema allows testing to spin up database containers on demand include all the migrations. The testing package enables unit tests to programmatically execute schema migrations before running any unit tests.
|
||||
|
||||
|
||||
### Accessing Postgres
|
||||
|
||||
Login to the local postgres container
|
||||
```bash
|
||||
docker exec -it example-project_postgres_1 /bin/bash
|
||||
bash-4.4# psql -u postgres shared
|
||||
```
|
||||
|
||||
Show tables
|
||||
```commandline
|
||||
shared=# \dt
|
||||
List of relations
|
||||
Schema | Name | Type | Owner
|
||||
--------+----------------+-------+----------
|
||||
public | accounts | table | postgres
|
||||
public | migrations | table | postgres
|
||||
public | projects | table | postgres
|
||||
public | users | table | postgres
|
||||
public | users_accounts | table | postgres
|
||||
(5 rows)
|
||||
```
|
||||
|
||||
|
||||
## Development Notes
|
||||
|
||||
|
||||
|
@ -3,12 +3,12 @@ package main
|
||||
import (
|
||||
"encoding/json"
|
||||
"expvar"
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/schema"
|
||||
"github.com/lib/pq"
|
||||
"log"
|
||||
"net/url"
|
||||
"os"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/schema"
|
||||
"github.com/lib/pq"
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/flag"
|
||||
"github.com/kelseyhightower/envconfig"
|
||||
_ "github.com/lib/pq"
|
||||
@ -19,11 +19,16 @@ import (
|
||||
// build is the git version of this program. It is set using build flags in the makefile.
|
||||
var build = "develop"
|
||||
|
||||
// service is the name of the program used for logging, tracing and the
|
||||
// the prefix used for loading env variables
|
||||
// ie: export SCHEMA_ENV=dev
|
||||
var service = "SCHEMA"
|
||||
|
||||
func main() {
|
||||
// =========================================================================
|
||||
// Logging
|
||||
|
||||
log := log.New(os.Stdout, "Schema : ", log.LstdFlags|log.Lmicroseconds|log.Lshortfile)
|
||||
log := log.New(os.Stdout, service+" : ", log.LstdFlags|log.Lmicroseconds|log.Lshortfile)
|
||||
|
||||
// =========================================================================
|
||||
// Configuration
|
||||
@ -40,12 +45,8 @@ func main() {
|
||||
}
|
||||
}
|
||||
|
||||
// The prefix used for loading env variables.
|
||||
// ie: export SCHEMA_ENV=dev
|
||||
envKeyPrefix := "SCHEMA"
|
||||
|
||||
// For additional details refer to https://github.com/kelseyhightower/envconfig
|
||||
if err := envconfig.Process(envKeyPrefix, &cfg); err != nil {
|
||||
if err := envconfig.Process(service, &cfg); err != nil {
|
||||
log.Fatalf("main : Parsing Config : %v", err)
|
||||
}
|
||||
|
||||
@ -104,7 +105,7 @@ func main() {
|
||||
// Register informs the sqlxtrace package of the driver that we will be using in our program.
|
||||
// It uses a default service name, in the below case "postgres.db". To use a custom service
|
||||
// name use RegisterWithServiceName.
|
||||
sqltrace.Register(cfg.DB.Driver, &pq.Driver{}, sqltrace.WithServiceName("my-service"))
|
||||
sqltrace.Register(cfg.DB.Driver, &pq.Driver{}, sqltrace.WithServiceName(service))
|
||||
masterDb, err := sqlxtrace.Open(cfg.DB.Driver, dbUrl.String())
|
||||
if err != nil {
|
||||
log.Fatalf("main : Register DB : %s : %v", cfg.DB.Driver, err)
|
||||
|
@ -34,12 +34,17 @@ import (
|
||||
// build is the git version of this program. It is set using build flags in the makefile.
|
||||
var build = "develop"
|
||||
|
||||
// service is the name of the program used for logging, tracing and the
|
||||
// the prefix used for loading env variables
|
||||
// ie: export WEB_API_ENV=dev
|
||||
var service = "WEB_API"
|
||||
|
||||
func main() {
|
||||
|
||||
// =========================================================================
|
||||
// Logging
|
||||
|
||||
log := log.New(os.Stdout, "WEB_API : ", log.LstdFlags|log.Lmicroseconds|log.Lshortfile)
|
||||
log := log.New(os.Stdout, service+" : ", log.LstdFlags|log.Lmicroseconds|log.Lshortfile)
|
||||
|
||||
// =========================================================================
|
||||
// Configuration
|
||||
@ -110,12 +115,8 @@ func main() {
|
||||
}
|
||||
}
|
||||
|
||||
// The prefix used for loading env variables.
|
||||
// ie: export WEB_API_ENV=dev
|
||||
envKeyPrefix := "WEB_API"
|
||||
|
||||
// For additional details refer to https://github.com/kelseyhightower/envconfig
|
||||
if err := envconfig.Process(envKeyPrefix, &cfg); err != nil {
|
||||
if err := envconfig.Process(service, &cfg); err != nil {
|
||||
log.Fatalf("main : Parsing Config : %v", err)
|
||||
}
|
||||
|
||||
@ -243,7 +244,7 @@ func main() {
|
||||
// Register informs the sqlxtrace package of the driver that we will be using in our program.
|
||||
// It uses a default service name, in the below case "postgres.db". To use a custom service
|
||||
// name use RegisterWithServiceName.
|
||||
sqltrace.Register(cfg.DB.Driver, &pq.Driver{}, sqltrace.WithServiceName("my-service"))
|
||||
sqltrace.Register(cfg.DB.Driver, &pq.Driver{}, sqltrace.WithServiceName(service))
|
||||
masterDb, err := sqlxtrace.Open(cfg.DB.Driver, dbUrl.String())
|
||||
if err != nil {
|
||||
log.Fatalf("main : Register DB : %s : %v", cfg.DB.Driver, err)
|
||||
|
@ -40,12 +40,17 @@ import (
|
||||
// build is the git version of this program. It is set using build flags in the makefile.
|
||||
var build = "develop"
|
||||
|
||||
// service is the name of the program used for logging, tracing and the
|
||||
// the prefix used for loading env variables
|
||||
// ie: export WEB_APP_ENV=dev
|
||||
var service = "WEB_APP"
|
||||
|
||||
func main() {
|
||||
|
||||
// =========================================================================
|
||||
// Logging
|
||||
|
||||
log := log.New(os.Stdout, "WEB_APP : ", log.LstdFlags|log.Lmicroseconds|log.Lshortfile)
|
||||
log := log.New(os.Stdout, service+" : ", log.LstdFlags|log.Lmicroseconds|log.Lshortfile)
|
||||
|
||||
// =========================================================================
|
||||
// Configuration
|
||||
@ -125,12 +130,8 @@ func main() {
|
||||
CMD string `envconfig:"CMD"`
|
||||
}
|
||||
|
||||
// The prefix used for loading env variables.
|
||||
// ie: export WEB_APP_ENV=dev
|
||||
envKeyPrefix := "WEB_APP"
|
||||
|
||||
// For additional details refer to https://github.com/kelseyhightower/envconfig
|
||||
if err := envconfig.Process(envKeyPrefix, &cfg); err != nil {
|
||||
if err := envconfig.Process(service, &cfg); err != nil {
|
||||
log.Fatalf("main : Parsing Config : %v", err)
|
||||
}
|
||||
|
||||
@ -258,7 +259,7 @@ func main() {
|
||||
// Register informs the sqlxtrace package of the driver that we will be using in our program.
|
||||
// It uses a default service name, in the below case "postgres.db". To use a custom service
|
||||
// name use RegisterWithServiceName.
|
||||
sqltrace.Register(cfg.DB.Driver, &pq.Driver{}, sqltrace.WithServiceName("my-service"))
|
||||
sqltrace.Register(cfg.DB.Driver, &pq.Driver{}, sqltrace.WithServiceName(service))
|
||||
masterDb, err := sqlxtrace.Open(cfg.DB.Driver, dbUrl.String())
|
||||
if err != nil {
|
||||
log.Fatalf("main : Register DB : %s : %v", cfg.DB.Driver, err)
|
||||
@ -522,7 +523,7 @@ func main() {
|
||||
shutdown := make(chan os.Signal, 1)
|
||||
signal.Notify(shutdown, os.Interrupt, syscall.SIGTERM)
|
||||
|
||||
api := http.Server{
|
||||
app := http.Server{
|
||||
Addr: cfg.HTTP.Host,
|
||||
Handler: handlers.APP(shutdown, log, cfg.App.StaticDir, cfg.App.TemplateDir, masterDb, nil, renderer),
|
||||
ReadTimeout: cfg.HTTP.ReadTimeout,
|
||||
@ -537,7 +538,7 @@ func main() {
|
||||
// Start the service listening for requests.
|
||||
go func() {
|
||||
log.Printf("main : APP Listening %s", cfg.HTTP.Host)
|
||||
serverErrors <- api.ListenAndServe()
|
||||
serverErrors <- app.ListenAndServe()
|
||||
}()
|
||||
|
||||
// =========================================================================
|
||||
@ -556,10 +557,10 @@ func main() {
|
||||
defer cancel()
|
||||
|
||||
// Asking listener to shutdown and load shed.
|
||||
err := api.Shutdown(ctx)
|
||||
err := app.Shutdown(ctx)
|
||||
if err != nil {
|
||||
log.Printf("main : Graceful shutdown did not complete in %v : %v", cfg.App.ShutdownTimeout, err)
|
||||
err = api.Close()
|
||||
err = app.Close()
|
||||
}
|
||||
|
||||
// Log the status of this shutdown.
|
||||
|
@ -5,6 +5,9 @@ require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/dgrijalva/jwt-go v3.2.0+incompatible
|
||||
github.com/dimfeld/httptreemux v5.0.1+incompatible
|
||||
github.com/dustin/go-humanize v1.0.0
|
||||
github.com/fatih/camelcase v1.0.0
|
||||
github.com/fatih/structtag v1.0.0
|
||||
github.com/geeks-accelerator/sqlxmigrate v0.0.0-20190527223850-4a863a2d30db
|
||||
github.com/go-playground/locales v0.12.1
|
||||
github.com/go-playground/universal-translator v0.16.0
|
||||
@ -12,6 +15,7 @@ require (
|
||||
github.com/golang/protobuf v1.3.1 // indirect
|
||||
github.com/google/go-cmp v0.2.0
|
||||
github.com/huandu/go-sqlbuilder v1.4.0
|
||||
github.com/iancoleman/strcase v0.0.0-20190422225806-e506e3ef7365
|
||||
github.com/jmoiron/sqlx v1.2.0
|
||||
github.com/kelseyhightower/envconfig v1.3.0
|
||||
github.com/kr/pretty v0.1.0 // indirect
|
||||
@ -19,13 +23,16 @@ require (
|
||||
github.com/lib/pq v1.1.2-0.20190507191818-2ff3cb3adc01
|
||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
|
||||
github.com/onsi/ginkgo v1.8.0 // indirect
|
||||
github.com/onsi/gomega v1.5.0 // indirect
|
||||
github.com/onsi/gomega v1.5.0
|
||||
github.com/opentracing/opentracing-go v1.1.0 // indirect
|
||||
github.com/pborman/uuid v0.0.0-20180122190007-c65b2f87fee3
|
||||
github.com/philhofer/fwd v1.0.0 // indirect
|
||||
github.com/pkg/errors v0.8.1
|
||||
github.com/sergi/go-diff v1.0.0
|
||||
github.com/sethgrid/pester v0.0.0-20190127155807-68a33a018ad0
|
||||
github.com/tinylib/msgp v1.1.0 // indirect
|
||||
github.com/urfave/cli v1.20.0
|
||||
github.com/uudashr/go-module v0.0.0-20180827225833-c0ca9c3a4966 // indirect
|
||||
golang.org/x/crypto v0.0.0-20190513172903-22d7a77e9e5f
|
||||
golang.org/x/net v0.0.0-20190522155817-f3200d17e092 // indirect
|
||||
golang.org/x/sys v0.0.0-20190526052359-791d8a0f4d09 // indirect
|
||||
|
@ -8,6 +8,12 @@ github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumC
|
||||
github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ=
|
||||
github.com/dimfeld/httptreemux v5.0.1+incompatible h1:Qj3gVcDNoOthBAqftuD596rm4wg/adLLz5xh5CmpiCA=
|
||||
github.com/dimfeld/httptreemux v5.0.1+incompatible/go.mod h1:rbUlSV+CCpv/SuqUTP/8Bk2O3LyUV436/yaRGkhP6Z0=
|
||||
github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo=
|
||||
github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
|
||||
github.com/fatih/camelcase v1.0.0 h1:hxNvNX/xYBp0ovncs8WyWZrOrpBNub/JfaMvbURyft8=
|
||||
github.com/fatih/camelcase v1.0.0/go.mod h1:yN2Sb0lFhZJUdVvtELVWefmrXpuZESvPmqwoZc+/fpc=
|
||||
github.com/fatih/structtag v1.0.0 h1:pTHj65+u3RKWYPSGaU290FpI/dXxTaHdVwVwbcPKmEc=
|
||||
github.com/fatih/structtag v1.0.0/go.mod h1:IKitwq45uXL/yqi5mYghiD3w9H6eTOvI9vnk8tXMphA=
|
||||
github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I=
|
||||
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
|
||||
github.com/geeks-accelerator/sqlxmigrate v0.0.0-20190527223850-4a863a2d30db h1:mjErP7mTFHQ3cw/ibAkW3CvQ8gM4k19EkfzRzRINDAE=
|
||||
@ -30,6 +36,8 @@ github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI=
|
||||
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
|
||||
github.com/huandu/go-sqlbuilder v1.4.0 h1:2LIlTDOz63lOETLOIiKBPEu4PUbikmS5LUc3EekwYqM=
|
||||
github.com/huandu/go-sqlbuilder v1.4.0/go.mod h1:mYfGcZTUS6yJsahUQ3imkYSkGGT3A+owd54+79kkW+U=
|
||||
github.com/iancoleman/strcase v0.0.0-20190422225806-e506e3ef7365 h1:ECW73yc9MY7935nNYXUkK7Dz17YuSUI9yqRqYS8aBww=
|
||||
github.com/iancoleman/strcase v0.0.0-20190422225806-e506e3ef7365/go.mod h1:SK73tn/9oHe+/Y0h39VT4UCxmurVJkR5NA7kMEAOgSE=
|
||||
github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af h1:pmfjZENx5imkbgOkpRUYLnmbU7UEFbjtDA2hxJ1ichM=
|
||||
github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k=
|
||||
github.com/jmoiron/sqlx v1.2.0 h1:41Ip0zITnmWNR/vHV+S4m+VoUivnWY5E4OJfLZjCJMA=
|
||||
@ -69,6 +77,8 @@ github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
|
||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/sergi/go-diff v1.0.0 h1:Kpca3qRNrduNnOQeazBd0ysaKrUJiIuISHxogkT9RPQ=
|
||||
github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo=
|
||||
github.com/sethgrid/pester v0.0.0-20190127155807-68a33a018ad0 h1:X9XMOYjxEfAYSy3xK1DzO5dMkkWhs9E9UCcS1IERx2k=
|
||||
github.com/sethgrid/pester v0.0.0-20190127155807-68a33a018ad0/go.mod h1:Ad7IjTpvzZO8Fl0vh9AzQ+j/jYZfyp2diGwI8m5q+ns=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
@ -76,6 +86,10 @@ github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/tinylib/msgp v1.1.0 h1:9fQd+ICuRIu/ue4vxJZu6/LzxN0HwMds2nq/0cFvxHU=
|
||||
github.com/tinylib/msgp v1.1.0/go.mod h1:+d+yLhGm8mzTaHzB+wgMYrodPfmZrzkirds8fDWklFE=
|
||||
github.com/urfave/cli v1.20.0 h1:fDqGv3UG/4jbVl/QkFwEdddtEDjh/5Ov6X+0B/3bPaw=
|
||||
github.com/urfave/cli v1.20.0/go.mod h1:70zkFmudgCuE/ngEzBv17Jvp/497gISqfk5gWijbERA=
|
||||
github.com/uudashr/go-module v0.0.0-20180827225833-c0ca9c3a4966 h1:7dS/ZO0dIwrtj/FGTt9I6urVpx7LEHzucegv4ORYK3M=
|
||||
github.com/uudashr/go-module v0.0.0-20180827225833-c0ca9c3a4966/go.mod h1:P6Nk1sQWL6jcdBIxnLVlqCsOl0arao7gg7sPoM6gx4A=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20190513172903-22d7a77e9e5f h1:R423Cnkcp5JABoeemiGEPlt9tHXFfw5kvc0yqlxRPWo=
|
||||
golang.org/x/crypto v0.0.0-20190513172903-22d7a77e9e5f/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
|
@ -1,43 +1,96 @@
|
||||
package project
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"time"
|
||||
|
||||
"gopkg.in/mgo.v2/bson"
|
||||
"github.com/lib/pq"
|
||||
"github.com/pkg/errors"
|
||||
"gopkg.in/go-playground/validator.v9"
|
||||
)
|
||||
|
||||
// Project is an item we sell.
|
||||
// Project represents a workflow.
|
||||
type Project struct {
|
||||
ID bson.ObjectId `bson:"_id" json:"id"` // Unique identifier.
|
||||
Name string `bson:"name" json:"name"` // Display name of the project.
|
||||
Cost int `bson:"cost" json:"cost"` // Price for one item in cents.
|
||||
Quantity int `bson:"quantity" json:"quantity"` // Original number of items available.
|
||||
DateCreated time.Time `bson:"date_created" json:"date_created"` // When the project was added.
|
||||
DateModified time.Time `bson:"date_modified" json:"date_modified"` // When the project record was lost modified.
|
||||
ID string `json:"id" validate:"required,uuid"`
|
||||
AccountID string `json:"account_id" validate:"required,uuid" truss:"api-create"`
|
||||
Name string `json:"name" validate:"required"`
|
||||
Status ProjectStatus `json:"status" validate:"omitempty,oneof=active disabled"`
|
||||
CreatedAt time.Time `json:"created_at" truss:"api-read"`
|
||||
UpdatedAt time.Time `json:"updated_at" truss:"api-read"`
|
||||
ArchivedAt pq.NullTime `json:"archived_at" truss:"api-hide"`
|
||||
}
|
||||
|
||||
// NewProject is what we require from clients when adding a Project.
|
||||
type NewProject struct {
|
||||
Name string `json:"name" validate:"required"`
|
||||
Cost int `json:"cost" validate:"required,gte=0"`
|
||||
Quantity int `json:"quantity" validate:"required,gte=1"`
|
||||
// CreateProjectRequest contains information needed to create a new Project.
|
||||
type ProjectCreateRequest struct {
|
||||
AccountID string `json:"account_id" validate:"required,uuid"`
|
||||
Name string `json:"name" validate:"required"`
|
||||
Status *ProjectStatus `json:"status" validate:"omitempty,oneof=active disabled"`
|
||||
}
|
||||
|
||||
// UpdateProject defines what information may be provided to modify an
|
||||
// existing Project. All fields are optional so clients can send just the
|
||||
// fields they want changed. It uses pointer fields so we can differentiate
|
||||
// between a field that was not provided and a field that was provided as
|
||||
// explicitly blank. Normally we do not want to use pointers to basic types but
|
||||
// we make exceptions around marshalling/unmarshalling.
|
||||
type UpdateProject struct {
|
||||
Name *string `json:"name"`
|
||||
Cost *int `json:"cost" validate:"omitempty,gte=0"`
|
||||
Quantity *int `json:"quantity" validate:"omitempty,gte=1"`
|
||||
// UpdateProjectRequest defines what information may be provided to modify an existing
|
||||
// Project. All fields are optional so clients can send just the fields they want
|
||||
// changed. It uses pointer fields so we can differentiate between a field that
|
||||
// was not provided and a field that was provided as explicitly blank. Normally
|
||||
// we do not want to use pointers to basic types but we make exceptions around
|
||||
// marshalling/unmarshalling.
|
||||
type ProjectUpdateRequest struct {
|
||||
ID string `validate:"required,uuid"`
|
||||
Name *string `json:"name" validate:"omitempty"`
|
||||
Status *ProjectStatus `json:"status" validate:"omitempty,oneof=active pending disabled"`
|
||||
}
|
||||
|
||||
// Sale represents a transaction where we sold some quantity of a
|
||||
// Project.
|
||||
type Sale struct{}
|
||||
// ProjectFindRequest defines the possible options to search for projects. By default
|
||||
// archived projects will be excluded from response.
|
||||
type ProjectFindRequest struct {
|
||||
Where *string
|
||||
Args []interface{}
|
||||
Order []string
|
||||
Limit *uint
|
||||
Offset *uint
|
||||
IncludedArchived bool
|
||||
}
|
||||
|
||||
// NewSale defines what we require when creating a Sale record.
|
||||
type NewSale struct{}
|
||||
// ProjectStatus represents the status of an project.
|
||||
type ProjectStatus string
|
||||
|
||||
// ProjectStatus values define the status field of a user project.
|
||||
const (
|
||||
// ProjectStatus_Active defines the state when a user can access an project.
|
||||
ProjectStatus_Active ProjectStatus = "active"
|
||||
// ProjectStatus_Disabled defines the state when a user has been disabled from
|
||||
// accessing an project.
|
||||
ProjectStatus_Disabled ProjectStatus = "disabled"
|
||||
)
|
||||
|
||||
// ProjectStatus_Values provides list of valid ProjectStatus values.
|
||||
var ProjectStatus_Values = []ProjectStatus{
|
||||
ProjectStatus_Active,
|
||||
ProjectStatus_Disabled,
|
||||
}
|
||||
|
||||
// Scan supports reading the ProjectStatus value from the database.
|
||||
func (s *ProjectStatus) Scan(value interface{}) error {
|
||||
asBytes, ok := value.([]byte)
|
||||
if !ok {
|
||||
return errors.New("Scan source is not []byte")
|
||||
}
|
||||
*s = ProjectStatus(string(asBytes))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value converts the ProjectStatus value to be stored in the database.
|
||||
func (s ProjectStatus) Value() (driver.Value, error) {
|
||||
v := validator.New()
|
||||
|
||||
errs := v.Var(s, "required,oneof=active disabled")
|
||||
if errs != nil {
|
||||
return nil, errs
|
||||
}
|
||||
|
||||
return string(s), nil
|
||||
}
|
||||
|
||||
// String converts the ProjectStatus value to a string.
|
||||
func (s ProjectStatus) String() string {
|
||||
return string(s)
|
||||
}
|
||||
|
@ -1,162 +0,0 @@
|
||||
package project
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/pkg/errors"
|
||||
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
|
||||
"gopkg.in/mgo.v2"
|
||||
"gopkg.in/mgo.v2/bson"
|
||||
)
|
||||
|
||||
const projectsCollection = "projects"
|
||||
|
||||
var (
|
||||
// ErrNotFound abstracts the mgo not found error.
|
||||
ErrNotFound = errors.New("Entity not found")
|
||||
|
||||
// ErrInvalidID occurs when an ID is not in a valid form.
|
||||
ErrInvalidID = errors.New("ID is not in its proper form")
|
||||
)
|
||||
|
||||
// List retrieves a list of existing projects from the database.
|
||||
func List(ctx context.Context, dbConn *sqlx.DB) ([]Project, error) {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.List")
|
||||
defer span.Finish()
|
||||
|
||||
p := []Project{}
|
||||
|
||||
f := func(collection *mgo.Collection) error {
|
||||
return collection.Find(nil).All(&p)
|
||||
}
|
||||
|
||||
if _, err := dbConn.ExecContext(ctx, projectsCollection, f); err != nil {
|
||||
return nil, errors.Wrap(err, "db.projects.find()")
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// Retrieve gets the specified project from the database.
|
||||
func Retrieve(ctx context.Context, dbConn *sqlx.DB, id string) (*Project, error) {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Retrieve")
|
||||
defer span.Finish()
|
||||
|
||||
if !bson.IsObjectIdHex(id) {
|
||||
return nil, ErrInvalidID
|
||||
}
|
||||
|
||||
q := bson.M{"_id": bson.ObjectIdHex(id)}
|
||||
|
||||
var p *Project
|
||||
f := func(collection *mgo.Collection) error {
|
||||
return collection.Find(q).One(&p)
|
||||
}
|
||||
if _, err := dbConn.ExecContext(ctx, projectsCollection, f); err != nil {
|
||||
if err == mgo.ErrNotFound {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, errors.Wrap(err, fmt.Sprintf("db.projects.find(%s)", q))
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// Create inserts a new project into the database.
|
||||
func Create(ctx context.Context, dbConn *sqlx.DB, cp *NewProject, now time.Time) (*Project, error) {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Create")
|
||||
defer span.Finish()
|
||||
|
||||
// Mongo truncates times to milliseconds when storing. We and do the same
|
||||
// here so the value we return is consistent with what we store.
|
||||
now = now.Truncate(time.Millisecond)
|
||||
|
||||
p := Project{
|
||||
ID: bson.NewObjectId(),
|
||||
Name: cp.Name,
|
||||
Cost: cp.Cost,
|
||||
Quantity: cp.Quantity,
|
||||
DateCreated: now,
|
||||
DateModified: now,
|
||||
}
|
||||
|
||||
f := func(collection *mgo.Collection) error {
|
||||
return collection.Insert(&p)
|
||||
}
|
||||
if _, err := dbConn.ExecContext(ctx, projectsCollection, f); err != nil {
|
||||
return nil, errors.Wrap(err, fmt.Sprintf("db.projects.insert(%v)", &p))
|
||||
}
|
||||
|
||||
return &p, nil
|
||||
}
|
||||
|
||||
// Update replaces a project document in the database.
|
||||
func Update(ctx context.Context, dbConn *sqlx.DB, id string, upd UpdateProject, now time.Time) error {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Update")
|
||||
defer span.Finish()
|
||||
|
||||
if !bson.IsObjectIdHex(id) {
|
||||
return ErrInvalidID
|
||||
}
|
||||
|
||||
fields := make(bson.M)
|
||||
|
||||
if upd.Name != nil {
|
||||
fields["name"] = *upd.Name
|
||||
}
|
||||
if upd.Cost != nil {
|
||||
fields["cost"] = *upd.Cost
|
||||
}
|
||||
if upd.Quantity != nil {
|
||||
fields["quantity"] = *upd.Quantity
|
||||
}
|
||||
|
||||
// If there's nothing to update we can quit early.
|
||||
if len(fields) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
fields["date_modified"] = now
|
||||
|
||||
m := bson.M{"$set": fields}
|
||||
q := bson.M{"_id": bson.ObjectIdHex(id)}
|
||||
|
||||
f := func(collection *mgo.Collection) error {
|
||||
return collection.Update(q, m)
|
||||
}
|
||||
if _, err := dbConn.ExecContext(ctx, projectsCollection, f); err != nil {
|
||||
if err == mgo.ErrNotFound {
|
||||
return ErrNotFound
|
||||
}
|
||||
return errors.Wrap(err, fmt.Sprintf("db.customers.update(%s, %s)", q, m))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete removes a project from the database.
|
||||
func Delete(ctx context.Context, dbConn *sqlx.DB, id string) error {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Delete")
|
||||
defer span.Finish()
|
||||
|
||||
if !bson.IsObjectIdHex(id) {
|
||||
return ErrInvalidID
|
||||
}
|
||||
|
||||
q := bson.M{"_id": bson.ObjectIdHex(id)}
|
||||
|
||||
f := func(collection *mgo.Collection) error {
|
||||
return collection.Remove(q)
|
||||
}
|
||||
if _, err := dbConn.ExecContext(ctx, projectsCollection, f); err != nil {
|
||||
if err == mgo.ErrNotFound {
|
||||
return ErrNotFound
|
||||
}
|
||||
return errors.Wrap(err, fmt.Sprintf("db.projects.remove(%v)", q))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
@ -1,129 +0,0 @@
|
||||
package project_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/tests"
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/project"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
var test *tests.Test
|
||||
|
||||
// TestMain is the entry point for testing.
|
||||
func TestMain(m *testing.M) {
|
||||
os.Exit(testMain(m))
|
||||
}
|
||||
|
||||
func testMain(m *testing.M) int {
|
||||
test = tests.New()
|
||||
defer test.TearDown()
|
||||
return m.Run()
|
||||
}
|
||||
|
||||
// TestProject validates the full set of CRUD operations on Project values.
|
||||
func TestProject(t *testing.T) {
|
||||
defer tests.Recover(t)
|
||||
|
||||
t.Log("Given the need to work with Project records.")
|
||||
{
|
||||
t.Log("\tWhen handling a single Project.")
|
||||
{
|
||||
ctx := tests.Context()
|
||||
|
||||
dbConn := test.MasterDB.Copy()
|
||||
defer dbConn.Close()
|
||||
|
||||
np := project.NewProject{
|
||||
Name: "Comic Books",
|
||||
Cost: 25,
|
||||
Quantity: 60,
|
||||
}
|
||||
|
||||
p, err := project.Create(ctx, dbConn, &np, time.Now().UTC())
|
||||
if err != nil {
|
||||
t.Fatalf("\t%s\tShould be able to create a project : %s.", tests.Failed, err)
|
||||
}
|
||||
t.Logf("\t%s\tShould be able to create a project.", tests.Success)
|
||||
|
||||
savedP, err := project.Retrieve(ctx, dbConn, p.ID.Hex())
|
||||
if err != nil {
|
||||
t.Fatalf("\t%s\tShould be able to retrieve project by ID: %s.", tests.Failed, err)
|
||||
}
|
||||
t.Logf("\t%s\tShould be able to retrieve project by ID.", tests.Success)
|
||||
|
||||
if diff := cmp.Diff(p, savedP); diff != "" {
|
||||
t.Fatalf("\t%s\tShould get back the same project. Diff:\n%s", tests.Failed, diff)
|
||||
}
|
||||
t.Logf("\t%s\tShould get back the same project.", tests.Success)
|
||||
|
||||
upd := project.UpdateProject{
|
||||
Name: tests.StringPointer("Comics"),
|
||||
Cost: tests.IntPointer(50),
|
||||
Quantity: tests.IntPointer(40),
|
||||
}
|
||||
|
||||
if err := project.Update(ctx, dbConn, p.ID.Hex(), upd, time.Now().UTC()); err != nil {
|
||||
t.Fatalf("\t%s\tShould be able to update project : %s.", tests.Failed, err)
|
||||
}
|
||||
t.Logf("\t%s\tShould be able to update project.", tests.Success)
|
||||
|
||||
savedP, err = project.Retrieve(ctx, dbConn, p.ID.Hex())
|
||||
if err != nil {
|
||||
t.Fatalf("\t%s\tShould be able to retrieve updated project : %s.", tests.Failed, err)
|
||||
}
|
||||
t.Logf("\t%s\tShould be able to retrieve updated project.", tests.Success)
|
||||
|
||||
// Build a project matching what we expect to see. We just use the
|
||||
// modified time from the database.
|
||||
want := &project.Project{
|
||||
ID: p.ID,
|
||||
Name: *upd.Name,
|
||||
Cost: *upd.Cost,
|
||||
Quantity: *upd.Quantity,
|
||||
DateCreated: p.DateCreated,
|
||||
DateModified: savedP.DateModified,
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(want, savedP); diff != "" {
|
||||
t.Fatalf("\t%s\tShould get back the same project. Diff:\n%s", tests.Failed, diff)
|
||||
}
|
||||
t.Logf("\t%s\tShould get back the same project.", tests.Success)
|
||||
|
||||
upd = project.UpdateProject{
|
||||
Name: tests.StringPointer("Graphic Novels"),
|
||||
}
|
||||
|
||||
if err := project.Update(ctx, dbConn, p.ID.Hex(), upd, time.Now().UTC()); err != nil {
|
||||
t.Fatalf("\t%s\tShould be able to update just some fields of project : %s.", tests.Failed, err)
|
||||
}
|
||||
t.Logf("\t%s\tShould be able to update just some fields of project.", tests.Success)
|
||||
|
||||
savedP, err = project.Retrieve(ctx, dbConn, p.ID.Hex())
|
||||
if err != nil {
|
||||
t.Fatalf("\t%s\tShould be able to retrieve updated project : %s.", tests.Failed, err)
|
||||
}
|
||||
t.Logf("\t%s\tShould be able to retrieve updated project.", tests.Success)
|
||||
|
||||
if savedP.Name != *upd.Name {
|
||||
t.Fatalf("\t%s\tShould be able to see updated Name field : got %q want %q.", tests.Failed, savedP.Name, *upd.Name)
|
||||
} else {
|
||||
t.Logf("\t%s\tShould be able to see updated Name field.", tests.Success)
|
||||
}
|
||||
|
||||
if err := project.Delete(ctx, dbConn, p.ID.Hex()); err != nil {
|
||||
t.Fatalf("\t%s\tShould be able to delete project : %s.", tests.Failed, err)
|
||||
}
|
||||
t.Logf("\t%s\tShould be able to delete project.", tests.Success)
|
||||
|
||||
savedP, err = project.Retrieve(ctx, dbConn, p.ID.Hex())
|
||||
if errors.Cause(err) != project.ErrNotFound {
|
||||
t.Fatalf("\t%s\tShould NOT be able to retrieve deleted project : %s.", tests.Failed, err)
|
||||
}
|
||||
t.Logf("\t%s\tShould NOT be able to retrieve deleted project.", tests.Success)
|
||||
}
|
||||
}
|
||||
}
|
@ -65,8 +65,8 @@ func migrationList(db *sqlx.DB, log *log.Logger) []*sqlxmigrate.Migration {
|
||||
zipcode varchar(20) NOT NULL DEFAULT '',
|
||||
status account_status_t NOT NULL DEFAULT 'active',
|
||||
timezone varchar(128) NOT NULL DEFAULT 'America/Anchorage',
|
||||
signup_user_id char(36) DEFAULT NULL,
|
||||
billing_user_id char(36) DEFAULT NULL,
|
||||
signup_user_id char(36) DEFAULT NULL REFERENCES users(id),
|
||||
billing_user_id char(36) DEFAULT NULL REFERENCES users(id),
|
||||
created_at TIMESTAMP WITH TIME ZONE NOT NULL,
|
||||
updated_at TIMESTAMP WITH TIME ZONE DEFAULT NULL,
|
||||
archived_at TIMESTAMP WITH TIME ZONE DEFAULT NULL,
|
||||
@ -107,8 +107,8 @@ func migrationList(db *sqlx.DB, log *log.Logger) []*sqlxmigrate.Migration {
|
||||
|
||||
q3 := `CREATE TABLE IF NOT EXISTS users_accounts (
|
||||
id char(36) NOT NULL,
|
||||
account_id char(36) NOT NULL,
|
||||
user_id char(36) NOT NULL,
|
||||
account_id char(36) NOT NULL REFERENCES accounts(id),
|
||||
user_id char(36) NOT NULL REFERENCES users(id),
|
||||
roles user_account_role_t[] NOT NULL,
|
||||
status user_account_status_t NOT NULL DEFAULT 'active',
|
||||
created_at TIMESTAMP WITH TIME ZONE NOT NULL,
|
||||
@ -142,5 +142,42 @@ func migrationList(db *sqlx.DB, log *log.Logger) []*sqlxmigrate.Migration {
|
||||
return nil
|
||||
},
|
||||
},
|
||||
// create new table projects
|
||||
{
|
||||
ID: "20190622-01",
|
||||
Migrate: func(tx *sql.Tx) error {
|
||||
q1 := `CREATE TYPE project_status_t as enum('active','disabled')`
|
||||
if _, err := tx.Exec(q1); err != nil {
|
||||
return errors.WithMessagef(err, "Query failed %s", q1)
|
||||
}
|
||||
|
||||
q2 := `CREATE TABLE IF NOT EXISTS projects (
|
||||
id char(36) NOT NULL,
|
||||
account_id char(36) NOT NULL REFERENCES accounts(id),
|
||||
name varchar(255) NOT NULL,
|
||||
status project_status_t NOT NULL DEFAULT 'active',
|
||||
created_at TIMESTAMP WITH TIME ZONE NOT NULL,
|
||||
updated_at TIMESTAMP WITH TIME ZONE DEFAULT NULL,
|
||||
archived_at TIMESTAMP WITH TIME ZONE DEFAULT NULL,
|
||||
PRIMARY KEY (id)
|
||||
)`
|
||||
if _, err := tx.Exec(q2); err != nil {
|
||||
return errors.WithMessagef(err, "Query failed %s", q2)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Rollback: func(tx *sql.Tx) error {
|
||||
q1 := `DROP TYPE project_status_t`
|
||||
if _, err := tx.Exec(q1); err != nil {
|
||||
return errors.WithMessagef(err, "Query failed %s", q1)
|
||||
}
|
||||
|
||||
q2 := `DROP TABLE IF EXISTS projects`
|
||||
if _, err := tx.Exec(q2); err != nil {
|
||||
return errors.WithMessagef(err, "Query failed %s", q2)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
1
example-project/tools/truss/.gitignore
vendored
Normal file
1
example-project/tools/truss/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
truss
|
33
example-project/tools/truss/README.md
Normal file
33
example-project/tools/truss/README.md
Normal file
@ -0,0 +1,33 @@
|
||||
# SaaS Truss
|
||||
|
||||
Copyright 2019, Geeks Accelerator
|
||||
accelerator@geeksinthewoods.com.com
|
||||
|
||||
|
||||
## Description
|
||||
|
||||
Truss provides code generation to reduce copy/pasting.
|
||||
|
||||
|
||||
## Local Installation
|
||||
|
||||
### Build
|
||||
```bash
|
||||
go build .
|
||||
```
|
||||
|
||||
### Usage
|
||||
```bash
|
||||
./truss -h
|
||||
|
||||
Usage of ./truss
|
||||
--cmd string <dbtable2crud>
|
||||
--db_host string <127.0.0.1:5433>
|
||||
--db_user string <postgres>
|
||||
--db_pass string <postgres>
|
||||
--db_database string <shared>
|
||||
--db_driver string <postgres>
|
||||
--db_timezone string <utc>
|
||||
--db_disabletls bool <false>
|
||||
```
|
||||
|
149
example-project/tools/truss/cmd/dbtable2crud/db.go
Normal file
149
example-project/tools/truss/cmd/dbtable2crud/db.go
Normal file
@ -0,0 +1,149 @@
|
||||
package dbtable2crud
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/lib/pq"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type psqlColumn struct {
|
||||
Table string
|
||||
Column string
|
||||
ColumnId int64
|
||||
NotNull bool
|
||||
DataTypeFull string
|
||||
DataTypeName string
|
||||
DataTypeLength *int
|
||||
NumericPrecision *int
|
||||
NumericScale *int
|
||||
IsPrimaryKey bool
|
||||
PrimaryKeyName *string
|
||||
IsUniqueKey bool
|
||||
UniqueKeyName *string
|
||||
IsForeignKey bool
|
||||
ForeignKeyName *string
|
||||
ForeignKeyColumnId pq.Int64Array
|
||||
ForeignKeyTable *string
|
||||
ForeignKeyLocalColumnId pq.Int64Array
|
||||
DefaultFull *string
|
||||
DefaultValue *string
|
||||
IsEnum bool
|
||||
EnumTypeId *string
|
||||
EnumValues []string
|
||||
}
|
||||
|
||||
// descTable lists all the columns for a table.
|
||||
func descTable(db *sqlx.DB, dbName, dbTable string) ([]psqlColumn, error) {
|
||||
|
||||
queryStr := fmt.Sprintf(`SELECT
|
||||
c.relname as table,
|
||||
f.attname as column,
|
||||
f.attnum as columnId,
|
||||
f.attnotnull as not_null,
|
||||
pg_catalog.format_type(f.atttypid,f.atttypmod) AS data_type_full,
|
||||
t.typname AS data_type_name,
|
||||
CASE WHEN f.atttypmod >= 0 AND t.typname <> 'numeric'THEN (f.atttypmod - 4) --first 4 bytes are for storing actual length of data
|
||||
END AS data_type_length,
|
||||
CASE WHEN t.typname = 'numeric' THEN (((f.atttypmod - 4) >> 16) & 65535)
|
||||
END AS numeric_precision,
|
||||
CASE WHEN t.typname = 'numeric' THEN ((f.atttypmod - 4)& 65535 )
|
||||
END AS numeric_scale,
|
||||
CASE WHEN p.contype = 'p' THEN true ELSE false
|
||||
END AS is_primary_key,
|
||||
CASE WHEN p.contype = 'p' THEN p.conname
|
||||
END AS primary_key_name,
|
||||
CASE WHEN p.contype = 'u' THEN true ELSE false
|
||||
END AS is_unique_key,
|
||||
CASE WHEN p.contype = 'u' THEN p.conname
|
||||
END AS unique_key_name,
|
||||
CASE WHEN p.contype = 'f' THEN true ELSE false
|
||||
END AS is_foreign_key,
|
||||
CASE WHEN p.contype = 'f' THEN p.conname
|
||||
END AS foreignkey_name,
|
||||
CASE WHEN p.contype = 'f' THEN p.confkey
|
||||
END AS foreign_key_columnid,
|
||||
CASE WHEN p.contype = 'f' THEN g.relname
|
||||
END AS foreign_key_table,
|
||||
CASE WHEN p.contype = 'f' THEN p.conkey
|
||||
END AS foreign_key_local_column_id,
|
||||
CASE WHEN f.atthasdef = 't' THEN d.adsrc
|
||||
END AS default_value,
|
||||
CASE WHEN t.typtype = 'e' THEN true ELSE false
|
||||
END AS is_enum,
|
||||
CASE WHEN t.typtype = 'e' THEN t.oid
|
||||
END AS enum_type_id
|
||||
FROM pg_attribute f
|
||||
JOIN pg_class c ON c.oid = f.attrelid
|
||||
JOIN pg_type t ON t.oid = f.atttypid
|
||||
LEFT JOIN pg_attrdef d ON d.adrelid = c.oid AND d.adnum = f.attnum
|
||||
LEFT JOIN pg_namespace n ON n.oid = c.relnamespace
|
||||
LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey)
|
||||
LEFT JOIN pg_class AS g ON p.confrelid = g.oid
|
||||
WHERE c.relkind = 'r'::char
|
||||
AND f.attisdropped = false
|
||||
AND c.relname = '%s'
|
||||
AND f.attnum > 0
|
||||
ORDER BY f.attnum
|
||||
;`, dbTable) // AND n.nspname = '%s'
|
||||
|
||||
rows, err := db.Query(queryStr)
|
||||
if err != nil {
|
||||
err = errors.Wrapf(err, "query - %s", queryStr)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// iterate over each row
|
||||
var resp []psqlColumn
|
||||
for rows.Next() {
|
||||
var c psqlColumn
|
||||
err = rows.Scan(&c.Table, &c.Column, &c.ColumnId, &c.NotNull, &c.DataTypeFull, &c.DataTypeName, &c.DataTypeLength, &c.NumericPrecision, &c.NumericScale, &c.IsPrimaryKey, &c.PrimaryKeyName, &c.IsUniqueKey, &c.UniqueKeyName, &c.IsForeignKey, &c.ForeignKeyName, &c.ForeignKeyColumnId, &c.ForeignKeyTable, &c.ForeignKeyLocalColumnId, &c.DefaultFull, &c.IsEnum, &c.EnumTypeId)
|
||||
if err != nil {
|
||||
err = errors.Wrapf(err, "query - %s", queryStr)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if c.DefaultFull != nil {
|
||||
defaultValue := *c.DefaultFull
|
||||
|
||||
// "'active'::project_status_t"
|
||||
defaultValue = strings.Split(defaultValue, "::")[0]
|
||||
c.DefaultValue = &defaultValue
|
||||
}
|
||||
|
||||
resp = append(resp, c)
|
||||
}
|
||||
|
||||
for colIdx, dbCol := range resp {
|
||||
if !dbCol.IsEnum {
|
||||
continue
|
||||
}
|
||||
|
||||
queryStr := fmt.Sprintf(`SELECT e.enumlabel
|
||||
FROM pg_enum AS e
|
||||
WHERE e.enumtypid = '%s'
|
||||
ORDER BY e.enumsortorder`, *dbCol.EnumTypeId)
|
||||
|
||||
rows, err := db.Query(queryStr)
|
||||
if err != nil {
|
||||
err = errors.Wrapf(err, "query - %s", queryStr)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
var v string
|
||||
err = rows.Scan(&v)
|
||||
if err != nil {
|
||||
err = errors.Wrapf(err, "query - %s", queryStr)
|
||||
return nil, err
|
||||
}
|
||||
dbCol.EnumValues = append(dbCol.EnumValues, v)
|
||||
}
|
||||
|
||||
resp[colIdx] = dbCol
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
378
example-project/tools/truss/cmd/dbtable2crud/dbtable2crud.go
Normal file
378
example-project/tools/truss/cmd/dbtable2crud/dbtable2crud.go
Normal file
@ -0,0 +1,378 @@
|
||||
package dbtable2crud
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/schema"
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/tools/truss/internal/goparse"
|
||||
"github.com/dustin/go-humanize/english"
|
||||
"github.com/fatih/camelcase"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sergi/go-diff/diffmatchpatch"
|
||||
)
|
||||
|
||||
// Run in the main entry point for the dbtable2crud cmd.
|
||||
func Run(db *sqlx.DB, log *log.Logger, dbName, dbTable, modelFile, modelName, templateDir, goSrcPath string) error {
|
||||
log.SetPrefix(log.Prefix() + " : dbtable2crud")
|
||||
|
||||
// Ensure the schema is up to date
|
||||
if err := schema.Migrate(db, log); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// When dbTable is empty, lower case the model name
|
||||
if dbTable == "" {
|
||||
dbTable = strings.Join(camelcase.Split(modelName), " ")
|
||||
dbTable = english.PluralWord(2, dbTable, "")
|
||||
dbTable = strings.Replace(dbTable, " ", "_", -1)
|
||||
dbTable = strings.ToLower(dbTable)
|
||||
}
|
||||
|
||||
// Parse the model file and load the specified model struct.
|
||||
model, err := parseModelFile(db, log, dbName, dbTable, modelFile, modelName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Basic lint of the model struct.
|
||||
err = validateModel(log, model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tmplData := map[string]interface{}{
|
||||
"GoSrcPath": goSrcPath,
|
||||
}
|
||||
|
||||
// Update the model file with new or updated code.
|
||||
err = updateModel(log, model, templateDir, tmplData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update the model crud file with new or updated code.
|
||||
err = updateModelCrud(db, log, dbName, dbTable, modelFile, modelName, templateDir, model, tmplData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateModel performs a basic lint of the model struct to ensure
|
||||
// code gen output is correct.
|
||||
func validateModel(log *log.Logger, model *modelDef) error {
|
||||
for _, sf := range model.Fields {
|
||||
if sf.DbColumn == nil && sf.ColumnName != "-" {
|
||||
log.Printf("validateStruct : Unable to find struct field for db column %s\n", sf.ColumnName)
|
||||
}
|
||||
|
||||
var expectedType string
|
||||
switch sf.FieldName {
|
||||
case "ID":
|
||||
expectedType = "string"
|
||||
case "CreatedAt":
|
||||
expectedType = "time.Time"
|
||||
case "UpdatedAt":
|
||||
expectedType = "time.Time"
|
||||
case "ArchivedAt":
|
||||
expectedType = "pq.NullTime"
|
||||
}
|
||||
|
||||
if expectedType != "" && expectedType != sf.FieldType {
|
||||
log.Printf("validateStruct : Struct field %s should be of type %s not %s\n", sf.FieldName, expectedType, sf.FieldType)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateModel updated the parsed code file with the new code.
|
||||
func updateModel(log *log.Logger, model *modelDef, templateDir string, tmplData map[string]interface{}) error {
|
||||
|
||||
// Execute template and parse code to be used to compare against modelFile.
|
||||
tmplObjs, err := loadTemplateObjects(log, model, templateDir, "models.tmpl", tmplData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Store the current code as a string to produce a diff.
|
||||
curCode := model.String()
|
||||
|
||||
objHeaders := []*goparse.GoObject{}
|
||||
|
||||
for _, obj := range tmplObjs {
|
||||
if obj.Type == goparse.GoObjectType_Comment || obj.Type == goparse.GoObjectType_LineBreak {
|
||||
objHeaders = append(objHeaders, obj)
|
||||
continue
|
||||
}
|
||||
|
||||
if model.HasType(obj.Name, obj.Type) {
|
||||
cur := model.Objects().Get(obj.Name, obj.Type)
|
||||
|
||||
newObjs := []*goparse.GoObject{}
|
||||
if len(objHeaders) > 0 {
|
||||
// Remove any comments and linebreaks before the existing object so updates can be added.
|
||||
removeObjs := []*goparse.GoObject{}
|
||||
for idx := cur.Index - 1; idx > 0; idx-- {
|
||||
prevObj := model.Objects().List()[idx]
|
||||
if prevObj.Type == goparse.GoObjectType_Comment || prevObj.Type == goparse.GoObjectType_LineBreak {
|
||||
removeObjs = append(removeObjs, prevObj)
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(removeObjs) > 0 {
|
||||
err := model.Objects().Remove(removeObjs...)
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to update object %s %s for %s", obj.Type, obj.Name, model.Name)
|
||||
return err
|
||||
}
|
||||
|
||||
// Make sure the current index is correct.
|
||||
cur = model.Objects().Get(obj.Name, obj.Type)
|
||||
}
|
||||
|
||||
// Append comments and line breaks before adding the object
|
||||
for _, c := range objHeaders {
|
||||
newObjs = append(newObjs, c)
|
||||
}
|
||||
}
|
||||
|
||||
newObjs = append(newObjs, obj)
|
||||
|
||||
// Do the object replacement.
|
||||
err := model.Objects().Replace(cur, newObjs...)
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to update object %s %s for %s", obj.Type, obj.Name, model.Name)
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// Append comments and line breaks before adding the object
|
||||
for _, c := range objHeaders {
|
||||
err := model.Objects().Add(c)
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to add object %s %s for %s", c.Type, c.Name, model.Name)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
err := model.Objects().Add(obj)
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to add object %s %s for %s", obj.Type, obj.Name, model.Name)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
objHeaders = []*goparse.GoObject{}
|
||||
}
|
||||
|
||||
// Set some flags to determine additional imports and need to be added.
|
||||
var hasEnum bool
|
||||
var hasPq bool
|
||||
for _, f := range model.Fields {
|
||||
if f.DbColumn != nil && f.DbColumn.IsEnum {
|
||||
hasEnum = true
|
||||
}
|
||||
if strings.HasPrefix(strings.Trim(f.FieldType, "*"), "pq.") {
|
||||
hasPq = true
|
||||
}
|
||||
}
|
||||
|
||||
reqImports := []string{}
|
||||
if hasEnum {
|
||||
reqImports = append(reqImports, "database/sql/driver")
|
||||
reqImports = append(reqImports, "gopkg.in/go-playground/validator.v9")
|
||||
reqImports = append(reqImports, "github.com/pkg/errors")
|
||||
}
|
||||
|
||||
if hasPq {
|
||||
reqImports = append(reqImports, "github.com/lib/pq")
|
||||
}
|
||||
|
||||
for _, in := range reqImports {
|
||||
err := model.AddImport(goparse.GoImport{Name: in})
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to add import %s for %s", in, model.Name)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Produce a diff after the updates have been applied.
|
||||
dmp := diffmatchpatch.New()
|
||||
diffs := dmp.DiffMain(curCode, model.String(), true)
|
||||
|
||||
fmt.Println(dmp.DiffPrettyText(diffs))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateModelCrud updated the parsed code file with the new code.
|
||||
func updateModelCrud(db *sqlx.DB, log *log.Logger, dbName, dbTable, modelFile, modelName, templateDir string, baseModel *modelDef, tmplData map[string]interface{}) error {
|
||||
|
||||
modelDir := filepath.Dir(modelFile)
|
||||
crudFile := filepath.Join(modelDir, FormatCamelLowerUnderscore(baseModel.Name)+".go")
|
||||
|
||||
var crudDoc *goparse.GoDocument
|
||||
if _, err := os.Stat(crudFile); os.IsNotExist(err) {
|
||||
crudDoc, err = goparse.NewGoDocument(baseModel.Package)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// Parse the supplied model file.
|
||||
crudDoc, err = goparse.ParseFile(log, modelFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Load all the updated struct fields from the base model file.
|
||||
structFields := make(map[string]map[string]modelField)
|
||||
for _, obj := range baseModel.GoDocument.Objects().List() {
|
||||
if obj.Type != goparse.GoObjectType_Struct || obj.Name == baseModel.Name {
|
||||
continue
|
||||
}
|
||||
|
||||
objFields, err := parseModelFields(baseModel.GoDocument, obj.Name, baseModel)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
structFields[obj.Name] = make(map[string]modelField)
|
||||
for _, f := range objFields {
|
||||
structFields[obj.Name][f.FieldName] = f
|
||||
}
|
||||
}
|
||||
|
||||
// Append the struct fields to be used for template execution.
|
||||
if tmplData == nil {
|
||||
tmplData = make(map[string]interface{})
|
||||
}
|
||||
tmplData["StructFields"] = structFields
|
||||
|
||||
// Execute template and parse code to be used to compare against modelFile.
|
||||
tmplObjs, err := loadTemplateObjects(log, baseModel, templateDir, "model_crud.tmpl", tmplData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Store the current code as a string to produce a diff.
|
||||
curCode := crudDoc.String()
|
||||
|
||||
objHeaders := []*goparse.GoObject{}
|
||||
|
||||
for _, obj := range tmplObjs {
|
||||
if obj.Type == goparse.GoObjectType_Comment || obj.Type == goparse.GoObjectType_LineBreak {
|
||||
objHeaders = append(objHeaders, obj)
|
||||
continue
|
||||
}
|
||||
|
||||
if crudDoc.HasType(obj.Name, obj.Type) {
|
||||
cur := crudDoc.Objects().Get(obj.Name, obj.Type)
|
||||
|
||||
newObjs := []*goparse.GoObject{}
|
||||
if len(objHeaders) > 0 {
|
||||
// Remove any comments and linebreaks before the existing object so updates can be added.
|
||||
removeObjs := []*goparse.GoObject{}
|
||||
for idx := cur.Index - 1; idx > 0; idx-- {
|
||||
prevObj := crudDoc.Objects().List()[idx]
|
||||
if prevObj.Type == goparse.GoObjectType_Comment || prevObj.Type == goparse.GoObjectType_LineBreak {
|
||||
removeObjs = append(removeObjs, prevObj)
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(removeObjs) > 0 {
|
||||
err := crudDoc.Objects().Remove(removeObjs...)
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to update object %s %s for %s", obj.Type, obj.Name, baseModel.Name)
|
||||
return err
|
||||
}
|
||||
|
||||
// Make sure the current index is correct.
|
||||
cur = crudDoc.Objects().Get(obj.Name, obj.Type)
|
||||
}
|
||||
|
||||
// Append comments and line breaks before adding the object
|
||||
for _, c := range objHeaders {
|
||||
newObjs = append(newObjs, c)
|
||||
}
|
||||
}
|
||||
|
||||
newObjs = append(newObjs, obj)
|
||||
|
||||
// Do the object replacement.
|
||||
err := crudDoc.Objects().Replace(cur, newObjs...)
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to update object %s %s for %s", obj.Type, obj.Name, baseModel.Name)
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// Append comments and line breaks before adding the object
|
||||
for _, c := range objHeaders {
|
||||
err := crudDoc.Objects().Add(c)
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to add object %s %s for %s", c.Type, c.Name, baseModel.Name)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
err := crudDoc.Objects().Add(obj)
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to add object %s %s for %s", obj.Type, obj.Name, baseModel.Name)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
objHeaders = []*goparse.GoObject{}
|
||||
}
|
||||
|
||||
/*
|
||||
// Set some flags to determine additional imports and need to be added.
|
||||
var hasEnum bool
|
||||
var hasPq bool
|
||||
for _, f := range crudModel.Fields {
|
||||
if f.DbColumn != nil && f.DbColumn.IsEnum {
|
||||
hasEnum = true
|
||||
}
|
||||
if strings.HasPrefix(strings.Trim(f.FieldType, "*"), "pq.") {
|
||||
hasPq = true
|
||||
}
|
||||
}
|
||||
|
||||
reqImports := []string{}
|
||||
if hasEnum {
|
||||
reqImports = append(reqImports, "database/sql/driver")
|
||||
reqImports = append(reqImports, "gopkg.in/go-playground/validator.v9")
|
||||
reqImports = append(reqImports, "github.com/pkg/errors")
|
||||
}
|
||||
|
||||
if hasPq {
|
||||
reqImports = append(reqImports, "github.com/lib/pq")
|
||||
}
|
||||
|
||||
for _, in := range reqImports {
|
||||
err := model.AddImport(goparse.GoImport{Name: in})
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to add import %s for %s", in, crudModel.Name)
|
||||
return err
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
// Produce a diff after the updates have been applied.
|
||||
dmp := diffmatchpatch.New()
|
||||
diffs := dmp.DiffMain(curCode, crudDoc.String(), true)
|
||||
|
||||
fmt.Println(dmp.DiffPrettyText(diffs))
|
||||
|
||||
return nil
|
||||
}
|
229
example-project/tools/truss/cmd/dbtable2crud/models.go
Normal file
229
example-project/tools/truss/cmd/dbtable2crud/models.go
Normal file
@ -0,0 +1,229 @@
|
||||
package dbtable2crud
|
||||
|
||||
import (
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/tools/truss/internal/goparse"
|
||||
"github.com/fatih/structtag"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// modelDef defines info about the struct and associated db table.
|
||||
type modelDef struct {
|
||||
*goparse.GoDocument
|
||||
Name string
|
||||
TableName string
|
||||
PrimaryField string
|
||||
PrimaryColumn string
|
||||
PrimaryType string
|
||||
Fields []modelField
|
||||
FieldNames []string
|
||||
ColumnNames []string
|
||||
}
|
||||
|
||||
// modelField defines a struct field and associated db column.
|
||||
type modelField struct {
|
||||
ColumnName string
|
||||
DbColumn *psqlColumn
|
||||
FieldName string
|
||||
FieldType string
|
||||
FieldIsPtr bool
|
||||
Tags *structtag.Tags
|
||||
ApiHide bool
|
||||
ApiRead bool
|
||||
ApiCreate bool
|
||||
ApiUpdate bool
|
||||
DefaultValue string
|
||||
}
|
||||
|
||||
// parseModelFile parses the entire model file and then loads the specified model struct.
|
||||
func parseModelFile(db *sqlx.DB, log *log.Logger, dbName, dbTable, modelFile, modelName string) (*modelDef, error) {
|
||||
|
||||
// Parse the supplied model file.
|
||||
doc, err := goparse.ParseFile(log, modelFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Init new modelDef.
|
||||
model := &modelDef{
|
||||
GoDocument: doc,
|
||||
Name: modelName,
|
||||
TableName: dbTable,
|
||||
}
|
||||
|
||||
// Append the field the the model def.
|
||||
model.Fields, err = parseModelFields(doc, modelName, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, sf := range model.Fields {
|
||||
model.FieldNames = append(model.FieldNames, sf.FieldName)
|
||||
model.ColumnNames = append(model.ColumnNames, sf.ColumnName)
|
||||
}
|
||||
|
||||
// Query the database for a table definition.
|
||||
dbCols, err := descTable(db, dbName, dbTable)
|
||||
if err != nil {
|
||||
return model, err
|
||||
}
|
||||
|
||||
// Loop over all the database table columns and append to the associated
|
||||
// struct field. Don't force all database table columns to be defined in the
|
||||
// in the struct.
|
||||
for _, dbCol := range dbCols {
|
||||
for idx, sf := range model.Fields {
|
||||
if sf.ColumnName != dbCol.Column {
|
||||
continue
|
||||
}
|
||||
|
||||
if dbCol.IsPrimaryKey {
|
||||
model.PrimaryColumn = sf.ColumnName
|
||||
model.PrimaryField = sf.FieldName
|
||||
model.PrimaryType = sf.FieldType
|
||||
}
|
||||
|
||||
if dbCol.DefaultValue != nil {
|
||||
sf.DefaultValue = *dbCol.DefaultValue
|
||||
|
||||
if dbCol.IsEnum {
|
||||
sf.DefaultValue = strings.Trim(sf.DefaultValue, "'")
|
||||
sf.DefaultValue = sf.FieldType + "_" + FormatCamel(sf.DefaultValue)
|
||||
} else if strings.HasPrefix(sf.DefaultValue, "'") {
|
||||
sf.DefaultValue = strings.Trim(sf.DefaultValue, "'")
|
||||
sf.DefaultValue = "\"" + sf.DefaultValue + "\""
|
||||
}
|
||||
}
|
||||
|
||||
c := dbCol
|
||||
sf.DbColumn = &c
|
||||
model.Fields[idx] = sf
|
||||
}
|
||||
}
|
||||
|
||||
// Print out the model for debugging.
|
||||
//modelJSON, err := json.MarshalIndent(model, "", " ")
|
||||
//if err != nil {
|
||||
// return model, errors.WithStack(err )
|
||||
//}
|
||||
//log.Printf(string(modelJSON))
|
||||
|
||||
return model, nil
|
||||
}
|
||||
|
||||
// parseModelFields parses the fields from a struct.
|
||||
func parseModelFields(doc *goparse.GoDocument, modelName string, baseModel *modelDef) ([]modelField, error) {
|
||||
|
||||
// Ensure the model file has a struct with the model name supplied.
|
||||
if !doc.HasType(modelName, goparse.GoObjectType_Struct) {
|
||||
err := errors.Errorf("Struct with the name %s does not exist", modelName)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Load the struct from parsed go file.
|
||||
docModel := doc.Get(modelName, goparse.GoObjectType_Struct)
|
||||
|
||||
// Loop over all the objects contained between the struct definition start and end.
|
||||
// This should be a list of variables defined for model.
|
||||
resp := []modelField{}
|
||||
for _, l := range docModel.Objects().List() {
|
||||
|
||||
// Skip all lines that are not a var.
|
||||
if l.Type != goparse.GoObjectType_Line {
|
||||
log.Printf("parseModelFile : Model %s has line that is %s, not type line, skipping - %s\n", modelName, l.Type, l.String())
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract the var name, type and defined tags from the line.
|
||||
sv, err := goparse.ParseStructProp(l)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Init new modelField for the struct var.
|
||||
sf := modelField{
|
||||
FieldName: sv.Name,
|
||||
FieldType: sv.Type,
|
||||
FieldIsPtr: strings.HasPrefix(sv.Type, "*"),
|
||||
Tags: sv.Tags,
|
||||
}
|
||||
|
||||
// Extract the column name from the var tags.
|
||||
if sf.Tags != nil {
|
||||
// First try to get the column name from the db tag.
|
||||
dbt, err := sf.Tags.Get("db")
|
||||
if err != nil && !strings.Contains(err.Error(), "not exist") {
|
||||
err = errors.WithStack(err)
|
||||
return nil, err
|
||||
} else if dbt != nil {
|
||||
sf.ColumnName = dbt.Name
|
||||
}
|
||||
|
||||
// Second try to get the column name from the json tag.
|
||||
if sf.ColumnName == "" {
|
||||
jt, err := sf.Tags.Get("json")
|
||||
if err != nil && !strings.Contains(err.Error(), "not exist") {
|
||||
err = errors.WithStack(err)
|
||||
return nil, err
|
||||
} else if jt != nil && jt.Name != "-" {
|
||||
sf.ColumnName = jt.Name
|
||||
}
|
||||
}
|
||||
|
||||
var apiActionsSet bool
|
||||
tt, err := sf.Tags.Get("truss")
|
||||
if err != nil && !strings.Contains(err.Error(), "not exist") {
|
||||
err = errors.WithStack(err)
|
||||
return nil, err
|
||||
} else if tt != nil {
|
||||
if tt.Name == "api-create" || tt.HasOption("api-create") {
|
||||
sf.ApiCreate = true
|
||||
apiActionsSet = true
|
||||
}
|
||||
if tt.Name == "api-read" || tt.HasOption("api-read") {
|
||||
sf.ApiRead = true
|
||||
apiActionsSet = true
|
||||
}
|
||||
if tt.Name == "api-update" || tt.HasOption("api-update") {
|
||||
sf.ApiUpdate = true
|
||||
apiActionsSet = true
|
||||
}
|
||||
if tt.Name == "api-hide" || tt.HasOption("api-hide") {
|
||||
sf.ApiHide = true
|
||||
apiActionsSet = true
|
||||
}
|
||||
}
|
||||
|
||||
if !apiActionsSet {
|
||||
sf.ApiCreate = true
|
||||
sf.ApiRead = true
|
||||
sf.ApiUpdate = true
|
||||
}
|
||||
}
|
||||
|
||||
// Set the column name to the field name if empty and does not equal '-'.
|
||||
if sf.ColumnName == "" {
|
||||
sf.ColumnName = sf.FieldName
|
||||
}
|
||||
|
||||
// If a base model as already been parsed with the db columns,
|
||||
// append to the current field.
|
||||
if baseModel != nil {
|
||||
for _, baseSf := range baseModel.Fields {
|
||||
if baseSf.ColumnName == sf.ColumnName {
|
||||
sf.DefaultValue = baseSf.DefaultValue
|
||||
sf.DbColumn = baseSf.DbColumn
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Append the field the the model def.
|
||||
resp = append(resp, sf)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
345
example-project/tools/truss/cmd/dbtable2crud/templates.go
Normal file
345
example-project/tools/truss/cmd/dbtable2crud/templates.go
Normal file
@ -0,0 +1,345 @@
|
||||
package dbtable2crud
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"go/format"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/tools/truss/internal/goparse"
|
||||
"github.com/dustin/go-humanize/english"
|
||||
"github.com/fatih/camelcase"
|
||||
"github.com/iancoleman/strcase"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// loadTemplateObjects executes a template file based on the given model struct and
|
||||
// returns the parsed go objects.
|
||||
func loadTemplateObjects(log *log.Logger, model *modelDef, templateDir, filename string, tmptData map[string]interface{}) ([]*goparse.GoObject, error) {
|
||||
|
||||
// Data used to execute all the of defined code sections in the template file.
|
||||
if tmptData == nil {
|
||||
tmptData = make(map[string]interface{})
|
||||
}
|
||||
tmptData["Model"] = model
|
||||
|
||||
// geeks-accelerator/oss/saas-starter-kit/example-project
|
||||
|
||||
// Read the template file from the local file system.
|
||||
tempFilePath := filepath.Join(templateDir, filename)
|
||||
dat, err := ioutil.ReadFile(tempFilePath)
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to read template file %s", tempFilePath)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// New template with custom functions.
|
||||
baseTmpl := template.New("base")
|
||||
baseTmpl.Funcs(template.FuncMap{
|
||||
"Concat": func(vals ...string) string {
|
||||
return strings.Join(vals, "")
|
||||
},
|
||||
"JoinStrings": func(vals []string, sep string) string {
|
||||
return strings.Join(vals, sep)
|
||||
},
|
||||
"PrefixAndJoinStrings": func(vals []string, pre, sep string) string {
|
||||
l := []string{}
|
||||
for _, v := range vals {
|
||||
l = append(l, pre+v)
|
||||
}
|
||||
return strings.Join(l, sep)
|
||||
},
|
||||
"FmtAndJoinStrings": func(vals []string, fmtStr, sep string) string {
|
||||
l := []string{}
|
||||
for _, v := range vals {
|
||||
l = append(l, fmt.Sprintf(fmtStr, v))
|
||||
}
|
||||
return strings.Join(l, sep)
|
||||
},
|
||||
"FormatCamel": func(name string) string {
|
||||
return FormatCamel(name)
|
||||
},
|
||||
"FormatCamelTitle": func(name string) string {
|
||||
return FormatCamelTitle(name)
|
||||
},
|
||||
"FormatCamelLower": func(name string) string {
|
||||
if name == "ID" {
|
||||
return "id"
|
||||
}
|
||||
return FormatCamelLower(name)
|
||||
},
|
||||
"FormatCamelLowerTitle": func(name string) string {
|
||||
return FormatCamelLowerTitle(name)
|
||||
},
|
||||
"FormatCamelPluralTitle": func(name string) string {
|
||||
return FormatCamelPluralTitle(name)
|
||||
},
|
||||
"FormatCamelPluralTitleLower": func(name string) string {
|
||||
return FormatCamelPluralTitleLower(name)
|
||||
},
|
||||
"FormatCamelPluralCamel": func(name string) string {
|
||||
return FormatCamelPluralCamel(name)
|
||||
},
|
||||
"FormatCamelPluralLower": func(name string) string {
|
||||
return FormatCamelPluralLower(name)
|
||||
},
|
||||
"FormatCamelPluralUnderscore": func(name string) string {
|
||||
return FormatCamelPluralUnderscore(name)
|
||||
},
|
||||
"FormatCamelPluralLowerUnderscore": func(name string) string {
|
||||
return FormatCamelPluralLowerUnderscore(name)
|
||||
},
|
||||
"FormatCamelUnderscore": func(name string) string {
|
||||
return FormatCamelUnderscore(name)
|
||||
},
|
||||
"FormatCamelLowerUnderscore": func(name string) string {
|
||||
return FormatCamelLowerUnderscore(name)
|
||||
},
|
||||
"FieldTagHasOption": func(f modelField, tagName, optName string) bool {
|
||||
if f.Tags == nil {
|
||||
return false
|
||||
}
|
||||
ft, err := f.Tags.Get(tagName)
|
||||
if ft == nil || err != nil {
|
||||
return false
|
||||
}
|
||||
if ft.Name == optName || ft.HasOption(optName) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
},
|
||||
"FieldTag": func(f modelField, tagName string) string {
|
||||
if f.Tags == nil {
|
||||
return ""
|
||||
}
|
||||
ft, err := f.Tags.Get(tagName)
|
||||
if ft == nil || err != nil {
|
||||
return ""
|
||||
}
|
||||
return ft.String()
|
||||
},
|
||||
"FieldTagReplaceOrPrepend": func(f modelField, tagName, oldVal, newVal string) string {
|
||||
if f.Tags == nil {
|
||||
return ""
|
||||
}
|
||||
ft, err := f.Tags.Get(tagName)
|
||||
if ft == nil || err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
if ft.Name == oldVal || ft.Name == newVal {
|
||||
ft.Name = newVal
|
||||
} else if ft.HasOption(oldVal) {
|
||||
for idx, val := range ft.Options {
|
||||
if val == oldVal {
|
||||
ft.Options[idx] = newVal
|
||||
}
|
||||
}
|
||||
} else if !ft.HasOption(newVal) {
|
||||
if ft.Name == "" {
|
||||
ft.Name = newVal
|
||||
} else {
|
||||
ft.Options = append(ft.Options, newVal)
|
||||
}
|
||||
}
|
||||
|
||||
return ft.String()
|
||||
},
|
||||
"StringListHasValue": func(list []string, val string) bool {
|
||||
for _, v := range list {
|
||||
if v == val {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
},
|
||||
})
|
||||
|
||||
// Load the template file using the text/template package.
|
||||
tmpl, err := baseTmpl.Parse(string(dat))
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to parse template file %s", tempFilePath)
|
||||
log.Printf("loadTemplateObjects : %v\n%v", err, string(dat))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Generate a list of template names defined in the template file.
|
||||
tmplNames := []string{}
|
||||
for _, defTmpl := range tmpl.Templates() {
|
||||
tmplNames = append(tmplNames, defTmpl.Name())
|
||||
}
|
||||
|
||||
// Stupid hack to return template names the in order they are defined in the file.
|
||||
tmplNames, err = templateFileOrderedNames(tempFilePath, tmplNames)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Loop over all the defined templates, execute using the defined data, parse the
|
||||
// formatted code and append the parsed go objects to the result list.
|
||||
var resp []*goparse.GoObject
|
||||
for _, tmplName := range tmplNames {
|
||||
// Executed the defined template with the given data.
|
||||
var tpl bytes.Buffer
|
||||
if err := tmpl.Lookup(tmplName).Execute(&tpl, tmptData); err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to execute %s from template file %s", tmplName, tempFilePath)
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// Format the source code to ensure its valid and code to parsed consistently.
|
||||
codeBytes, err := format.Source(tpl.Bytes())
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to format source for %s in template file %s", tmplName, filename)
|
||||
|
||||
dl := []string{}
|
||||
for idx, l := range strings.Split(tpl.String(), "\n") {
|
||||
dl = append(dl, fmt.Sprintf("%d -> ", idx)+l)
|
||||
}
|
||||
|
||||
log.Printf("loadTemplateObjects : %v\n%v", err, strings.Join(dl, "\n"))
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// Remove extra white space from the code.
|
||||
codeStr := strings.TrimSpace(string(codeBytes))
|
||||
|
||||
// Split the code into a list of strings.
|
||||
codeLines := strings.Split(codeStr, "\n")
|
||||
|
||||
// Parse the code lines into a set of objects.
|
||||
objs, err := goparse.ParseLines(codeLines, 0)
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to parse %s in template file %s", tmplName, filename)
|
||||
log.Printf("loadTemplateObjects : %v\n%v", err, codeStr)
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// Append the parsed objects to the return result list.
|
||||
for _, obj := range objs.List() {
|
||||
if obj.Name == "" && obj.Type != goparse.GoObjectType_Import && obj.Type != goparse.GoObjectType_Var && obj.Type != goparse.GoObjectType_Const && obj.Type != goparse.GoObjectType_Comment && obj.Type != goparse.GoObjectType_LineBreak {
|
||||
// All objects should have a name except for multiline var/const declarations and comments.
|
||||
err = errors.Errorf("Failed to parse name with type %s from lines: %v", obj.Type, obj.Lines())
|
||||
return resp, err
|
||||
} else if string(obj.Type) == "" {
|
||||
err = errors.Errorf("Failed to parse type for %s from lines: %v", obj.Name, obj.Lines())
|
||||
return resp, err
|
||||
}
|
||||
|
||||
resp = append(resp, obj)
|
||||
}
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// FormatCamel formats Valdez mountain to ValdezMountain
|
||||
func FormatCamel(name string) string {
|
||||
return strcase.ToCamel(name)
|
||||
}
|
||||
|
||||
// FormatCamelLower formats ValdezMountain to valdezmountain
|
||||
func FormatCamelLower(name string) string {
|
||||
return strcase.ToLowerCamel(FormatCamel(name))
|
||||
}
|
||||
|
||||
// FormatCamelTitle formats ValdezMountain to Valdez Mountain
|
||||
func FormatCamelTitle(name string) string {
|
||||
return strings.Join(camelcase.Split(name), " ")
|
||||
}
|
||||
|
||||
// FormatCamelLowerTitle formats ValdezMountain to valdez mountain
|
||||
func FormatCamelLowerTitle(name string) string {
|
||||
return strings.ToLower(FormatCamelTitle(name))
|
||||
}
|
||||
|
||||
// FormatCamelPluralTitle formats ValdezMountain to Valdez Mountains
|
||||
func FormatCamelPluralTitle(name string) string {
|
||||
pts := camelcase.Split(name)
|
||||
lastIdx := len(pts) - 1
|
||||
pts[lastIdx] = english.PluralWord(2, pts[lastIdx], "")
|
||||
return strings.Join(pts, " ")
|
||||
}
|
||||
|
||||
// FormatCamelPluralTitleLower formats ValdezMountain to valdez mountains
|
||||
func FormatCamelPluralTitleLower(name string) string {
|
||||
return strings.ToLower(FormatCamelPluralTitle(name))
|
||||
}
|
||||
|
||||
// FormatCamelPluralCamel formats ValdezMountain to ValdezMountains
|
||||
func FormatCamelPluralCamel(name string) string {
|
||||
return strcase.ToCamel(FormatCamelPluralTitle(name))
|
||||
}
|
||||
|
||||
// FormatCamelPluralLower formats ValdezMountain to valdezmountains
|
||||
func FormatCamelPluralLower(name string) string {
|
||||
return strcase.ToLowerCamel(FormatCamelPluralTitle(name))
|
||||
}
|
||||
|
||||
// FormatCamelPluralUnderscore formats ValdezMountain to Valdez_Mountains
|
||||
func FormatCamelPluralUnderscore(name string) string {
|
||||
return strings.Replace(FormatCamelPluralTitle(name), " ", "_", -1)
|
||||
}
|
||||
|
||||
// FormatCamelPluralLowerUnderscore formats ValdezMountain to valdez_mountains
|
||||
func FormatCamelPluralLowerUnderscore(name string) string {
|
||||
return strings.ToLower(FormatCamelPluralUnderscore(name))
|
||||
}
|
||||
|
||||
// FormatCamelUnderscore formats ValdezMountain to Valdez_Mountain
|
||||
func FormatCamelUnderscore(name string) string {
|
||||
return strings.Replace(FormatCamelTitle(name), " ", "_", -1)
|
||||
}
|
||||
|
||||
// FormatCamelLowerUnderscore formats ValdezMountain to valdez_mountain
|
||||
func FormatCamelLowerUnderscore(name string) string {
|
||||
return strings.ToLower(FormatCamelUnderscore(name))
|
||||
}
|
||||
|
||||
// templateFileOrderedNames returns the template names the in order they are defined in the file.
|
||||
func templateFileOrderedNames(localPath string, names []string) (resp []string, err error) {
|
||||
file, err := os.Open(localPath)
|
||||
if err != nil {
|
||||
return resp, errors.WithStack(err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
idxList := []int{}
|
||||
idxNames := make(map[int]string)
|
||||
|
||||
idx := 0
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
if !strings.HasPrefix(scanner.Text(), "{{") || !strings.Contains(scanner.Text(), "define ") {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, name := range names {
|
||||
if strings.Contains(scanner.Text(), "\""+name+"\"") {
|
||||
idxList = append(idxList, idx)
|
||||
idxNames[idx] = name
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
idx = idx + 1
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return resp, errors.WithStack(err)
|
||||
}
|
||||
|
||||
sort.Ints(idxList)
|
||||
|
||||
for _, idx := range idxList {
|
||||
resp = append(resp, idxNames[idx])
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
301
example-project/tools/truss/internal/goparse/doc.go
Normal file
301
example-project/tools/truss/internal/goparse/doc.go
Normal file
@ -0,0 +1,301 @@
|
||||
package goparse
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"go/format"
|
||||
"io/ioutil"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// GoDocument defines a single go code file.
|
||||
type GoDocument struct {
|
||||
*GoObjects
|
||||
Package string
|
||||
imports GoImports
|
||||
}
|
||||
|
||||
// GoImport defines a single import line with optional alias.
|
||||
type GoImport struct {
|
||||
Name string
|
||||
Alias string
|
||||
}
|
||||
|
||||
// GoImports holds a list of import lines.
|
||||
type GoImports []GoImport
|
||||
|
||||
// NewGoDocument creates a new GoDocument with the package line set.
|
||||
func NewGoDocument(packageName string) (doc *GoDocument, err error) {
|
||||
doc = &GoDocument{
|
||||
GoObjects: &GoObjects{
|
||||
list: []*GoObject{},
|
||||
},
|
||||
}
|
||||
err = doc.SetPackage(packageName)
|
||||
return doc, err
|
||||
}
|
||||
|
||||
// Objects returns a list of root GoObject.
|
||||
func (doc *GoDocument) Objects() *GoObjects {
|
||||
if doc.GoObjects == nil {
|
||||
doc.GoObjects = &GoObjects{
|
||||
list: []*GoObject{},
|
||||
}
|
||||
}
|
||||
|
||||
return doc.GoObjects
|
||||
}
|
||||
|
||||
// NewObjectPackage returns a new GoObject with a single package definition line.
|
||||
func NewObjectPackage(packageName string) *GoObject {
|
||||
lines := []string{
|
||||
fmt.Sprintf("package %s", packageName),
|
||||
"",
|
||||
}
|
||||
|
||||
obj, _ := ParseGoObject(lines, 0)
|
||||
|
||||
return obj
|
||||
}
|
||||
|
||||
// SetPackage appends sets the package line for the code file.
|
||||
func (doc *GoDocument) SetPackage(packageName string) error {
|
||||
|
||||
var existing *GoObject
|
||||
for _, obj := range doc.Objects().List() {
|
||||
if obj.Type == GoObjectType_Package {
|
||||
existing = obj
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
new := NewObjectPackage(packageName)
|
||||
|
||||
var err error
|
||||
if existing != nil {
|
||||
err = doc.Objects().Replace(existing, new)
|
||||
} else if len(doc.Objects().List()) > 0 {
|
||||
|
||||
// Insert after any existing comments or line breaks.
|
||||
var insertPos int
|
||||
//for idx, obj := range doc.Objects().List() {
|
||||
// switch obj.Type {
|
||||
// case GoObjectType_Comment, GoObjectType_LineBreak:
|
||||
// insertPos = idx
|
||||
// default:
|
||||
// break
|
||||
// }
|
||||
//}
|
||||
|
||||
err = doc.Objects().Insert(insertPos, new)
|
||||
} else {
|
||||
err = doc.Objects().Add(new)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// AddObject appends a new GoObject to the doc root object list.
|
||||
func (doc *GoDocument) AddObject(newObj *GoObject) error {
|
||||
return doc.Objects().Add(newObj)
|
||||
}
|
||||
|
||||
// InsertObject inserts a new GoObject at the desired position to the doc root object list.
|
||||
func (doc *GoDocument) InsertObject(pos int, newObj *GoObject) error {
|
||||
return doc.Objects().Insert(pos, newObj)
|
||||
}
|
||||
|
||||
// Imports returns the GoDocument imports.
|
||||
func (doc *GoDocument) Imports() (GoImports, error) {
|
||||
// If the doc imports are empty, try to load them from the root objects.
|
||||
if len(doc.imports) == 0 {
|
||||
for _, obj := range doc.Objects().List() {
|
||||
if obj.Type != GoObjectType_Import {
|
||||
continue
|
||||
}
|
||||
|
||||
res, err := ParseImportObject(obj)
|
||||
if err != nil {
|
||||
return doc.imports, err
|
||||
}
|
||||
|
||||
// Combine all the imports into a single definition.
|
||||
for _, n := range res {
|
||||
doc.imports = append(doc.imports, n)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return doc.imports, nil
|
||||
}
|
||||
|
||||
// Lines returns all the code lines.
|
||||
func (doc *GoDocument) Lines() []string {
|
||||
l := []string{}
|
||||
|
||||
for _, ol := range doc.Objects().Lines() {
|
||||
l = append(l, ol)
|
||||
}
|
||||
return l
|
||||
}
|
||||
|
||||
// String returns a single value for all the code lines.
|
||||
func (doc *GoDocument) String() string {
|
||||
return strings.Join(doc.Lines(), "\n")
|
||||
}
|
||||
|
||||
// Print writes all the code lines to standard out.
|
||||
func (doc *GoDocument) Print() {
|
||||
for _, l := range doc.Lines() {
|
||||
fmt.Println(l)
|
||||
}
|
||||
}
|
||||
|
||||
// Save renders all the code lines for the document, formats the code
|
||||
// and then saves it to the supplied file path.
|
||||
func (doc *GoDocument) Save(localpath string) error {
|
||||
res, err := format.Source([]byte(doc.String()))
|
||||
if err != nil {
|
||||
err = errors.WithMessage(err, "Failed formatted source code")
|
||||
return err
|
||||
}
|
||||
|
||||
err = ioutil.WriteFile(localpath, res, 0644)
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed write source code to file %s", localpath)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddImport checks for any duplicate imports by name and adds it if not.
|
||||
func (doc *GoDocument) AddImport(impt GoImport) error {
|
||||
impt.Name = strings.Trim(impt.Name, "\"")
|
||||
|
||||
// Get a list of current imports for the document.
|
||||
impts, err := doc.Imports()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If the document has as the import, don't add it.
|
||||
if impts.Has(impt.Name) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Loop through all the document root objects for an object of type import.
|
||||
// If one exists, append the import to the existing list.
|
||||
for _, obj := range doc.Objects().List() {
|
||||
if obj.Type != GoObjectType_Import || len(obj.Lines()) == 1 {
|
||||
continue
|
||||
}
|
||||
obj.subLines = append(obj.subLines, impt.String())
|
||||
obj.goObjects.list = append(obj.goObjects.list, impt.Object())
|
||||
|
||||
doc.imports = append(doc.imports, impt)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Document does not have an existing import object, so need to create one and
|
||||
// then append to the document.
|
||||
newObj := NewObjectImports(impt)
|
||||
|
||||
// Insert after any package, any existing comments or line breaks should be included.
|
||||
var insertPos int
|
||||
for idx, obj := range doc.Objects().List() {
|
||||
switch obj.Type {
|
||||
case GoObjectType_Package, GoObjectType_Comment, GoObjectType_LineBreak:
|
||||
insertPos = idx
|
||||
default:
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Insert the new import object.
|
||||
err = doc.InsertObject(insertPos, newObj)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewObjectImports returns a new GoObject with a single import definition.
|
||||
func NewObjectImports(impt GoImport) *GoObject {
|
||||
lines := []string{
|
||||
"import (",
|
||||
impt.String(),
|
||||
")",
|
||||
"",
|
||||
}
|
||||
|
||||
obj, _ := ParseGoObject(lines, 0)
|
||||
children, err := ParseLines(obj.subLines, 1)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
for _, child := range children.List() {
|
||||
obj.Objects().Add(child)
|
||||
}
|
||||
|
||||
return obj
|
||||
}
|
||||
|
||||
// Has checks to see if an import exists by name or alias.
|
||||
func (impts GoImports) Has(name string) bool {
|
||||
for _, impt := range impts {
|
||||
if name == impt.Name || (impt.Alias != "" && name == impt.Alias) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Line formats an import as a string.
|
||||
func (impt GoImport) String() string {
|
||||
var imptLine string
|
||||
if impt.Alias != "" {
|
||||
imptLine = fmt.Sprintf("\t%s \"%s\"", impt.Alias, impt.Name)
|
||||
} else {
|
||||
imptLine = fmt.Sprintf("\t\"%s\"", impt.Name)
|
||||
}
|
||||
return imptLine
|
||||
}
|
||||
|
||||
// Object returns the first GoObject for an import.
|
||||
func (impt GoImport) Object() *GoObject {
|
||||
imptObj := NewObjectImports(impt)
|
||||
|
||||
return imptObj.Objects().List()[0]
|
||||
}
|
||||
|
||||
// ParseImportObject extracts all the import definitions.
|
||||
func ParseImportObject(obj *GoObject) (resp GoImports, err error) {
|
||||
if obj.Type != GoObjectType_Import {
|
||||
return resp, errors.Errorf("Invalid type %s", string(obj.Type))
|
||||
}
|
||||
|
||||
for _, l := range obj.Lines() {
|
||||
if !strings.Contains(l, "\"") {
|
||||
continue
|
||||
}
|
||||
l = strings.TrimSpace(l)
|
||||
|
||||
pts := strings.Split(l, "\"")
|
||||
|
||||
var impt GoImport
|
||||
if strings.HasPrefix(l, "\"") {
|
||||
impt.Name = pts[1]
|
||||
} else {
|
||||
impt.Alias = strings.TrimSpace(pts[0])
|
||||
impt.Name = pts[1]
|
||||
}
|
||||
|
||||
resp = append(resp, impt)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
458
example-project/tools/truss/internal/goparse/doc_object.go
Normal file
458
example-project/tools/truss/internal/goparse/doc_object.go
Normal file
@ -0,0 +1,458 @@
|
||||
package goparse
|
||||
|
||||
import (
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"github.com/fatih/structtag"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// GoEmptyLine defined a GoObject for a code line break.
|
||||
var GoEmptyLine = GoObject{
|
||||
Type: GoObjectType_LineBreak,
|
||||
goObjects: &GoObjects{
|
||||
list: []*GoObject{},
|
||||
},
|
||||
}
|
||||
|
||||
// GoObjectType defines a set of possible types to group
|
||||
// parsed code by.
|
||||
type GoObjectType = string
|
||||
|
||||
var (
|
||||
GoObjectType_Package = "package"
|
||||
GoObjectType_Import = "import"
|
||||
GoObjectType_Var = "var"
|
||||
GoObjectType_Const = "const"
|
||||
GoObjectType_Func = "func"
|
||||
GoObjectType_Struct = "struct"
|
||||
GoObjectType_Comment = "comment"
|
||||
GoObjectType_LineBreak = "linebreak"
|
||||
GoObjectType_Line = "line"
|
||||
GoObjectType_Type = "type"
|
||||
)
|
||||
|
||||
// GoObject defines a section of code with a nested set of children.
|
||||
type GoObject struct {
|
||||
Type GoObjectType
|
||||
Name string
|
||||
startLines []string
|
||||
endLines []string
|
||||
subLines []string
|
||||
goObjects *GoObjects
|
||||
Index int
|
||||
}
|
||||
|
||||
// GoObjects stores a list of GoObject.
|
||||
type GoObjects struct {
|
||||
list []*GoObject
|
||||
}
|
||||
|
||||
// Objects returns the list of *GoObject.
|
||||
func (obj *GoObject) Objects() *GoObjects {
|
||||
if obj.goObjects == nil {
|
||||
obj.goObjects = &GoObjects{
|
||||
list: []*GoObject{},
|
||||
}
|
||||
}
|
||||
return obj.goObjects
|
||||
}
|
||||
|
||||
// Clone performs a deep copy of the struct.
|
||||
func (obj *GoObject) Clone() *GoObject {
|
||||
n := &GoObject{
|
||||
Type: obj.Type,
|
||||
Name: obj.Name,
|
||||
startLines: obj.startLines,
|
||||
endLines: obj.endLines,
|
||||
subLines: obj.subLines,
|
||||
goObjects: &GoObjects{
|
||||
list: []*GoObject{},
|
||||
},
|
||||
Index: obj.Index,
|
||||
}
|
||||
for _, sub := range obj.Objects().List() {
|
||||
n.Objects().Add(sub.Clone())
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// IsComment returns whether an object is of type GoObjectType_Comment.
|
||||
func (obj *GoObject) IsComment() bool {
|
||||
if obj.Type != GoObjectType_Comment {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Contains searches all the lines for the object for a matching string.
|
||||
func (obj *GoObject) Contains(match string) bool {
|
||||
for _, l := range obj.Lines() {
|
||||
if strings.Contains(l, match) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// UpdateLines parses the new code and replaces the current GoObject.
|
||||
func (obj *GoObject) UpdateLines(newLines []string) error {
|
||||
|
||||
// Parse the new lines.
|
||||
objs, err := ParseLines(newLines, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var newObj *GoObject
|
||||
for _, obj := range objs.List() {
|
||||
if obj.Type == GoObjectType_LineBreak {
|
||||
continue
|
||||
}
|
||||
|
||||
if newObj == nil {
|
||||
newObj = obj
|
||||
}
|
||||
|
||||
// There should only be one resulting parsed object that is
|
||||
// not of type GoObjectType_LineBreak.
|
||||
return errors.New("Can only update single blocks of code")
|
||||
}
|
||||
|
||||
// No new code was parsed, return error.
|
||||
if newObj == nil {
|
||||
return errors.New("Failed to render replacement code")
|
||||
}
|
||||
|
||||
return obj.Update(newObj)
|
||||
}
|
||||
|
||||
// Update performs a deep copy that overwrites the existing values.
|
||||
func (obj *GoObject) Update(newObj *GoObject) error {
|
||||
obj.Type = newObj.Type
|
||||
obj.Name = newObj.Name
|
||||
obj.startLines = newObj.startLines
|
||||
obj.endLines = newObj.endLines
|
||||
obj.subLines = newObj.subLines
|
||||
obj.goObjects = newObj.goObjects
|
||||
return nil
|
||||
}
|
||||
|
||||
// Lines returns a list of strings for current object and all children.
|
||||
func (obj *GoObject) Lines() []string {
|
||||
l := []string{}
|
||||
|
||||
// First include any lines before the sub objects.
|
||||
for _, sl := range obj.startLines {
|
||||
l = append(l, sl)
|
||||
}
|
||||
|
||||
// If there are parsed sub objects include those lines else when
|
||||
// no sub objects, just use the sub lines.
|
||||
if len(obj.Objects().List()) > 0 {
|
||||
for _, sl := range obj.Objects().Lines() {
|
||||
l = append(l, sl)
|
||||
}
|
||||
} else {
|
||||
for _, sl := range obj.subLines {
|
||||
l = append(l, sl)
|
||||
}
|
||||
}
|
||||
|
||||
// Lastly include any other lines that are after all parsed sub objects.
|
||||
for _, sl := range obj.endLines {
|
||||
l = append(l, sl)
|
||||
}
|
||||
|
||||
return l
|
||||
}
|
||||
|
||||
// String returns the lines separated by line break.
|
||||
func (obj *GoObject) String() string {
|
||||
return strings.Join(obj.Lines(), "\n")
|
||||
}
|
||||
|
||||
// Lines returns a list of strings for all the list objects.
|
||||
func (objs *GoObjects) Lines() []string {
|
||||
l := []string{}
|
||||
for _, obj := range objs.List() {
|
||||
for _, oj := range obj.Lines() {
|
||||
l = append(l, oj)
|
||||
}
|
||||
}
|
||||
return l
|
||||
}
|
||||
|
||||
// String returns all the lines for the list objects.
|
||||
func (objs *GoObjects) String() string {
|
||||
lines := []string{}
|
||||
for _, obj := range objs.List() {
|
||||
lines = append(lines, obj.String())
|
||||
}
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
// List returns the list of GoObjects.
|
||||
func (objs *GoObjects) List() []*GoObject {
|
||||
return objs.list
|
||||
}
|
||||
|
||||
// HasFunc searches the current list of objects for a function object by name.
|
||||
func (objs *GoObjects) HasFunc(name string) bool {
|
||||
return objs.HasType(name, GoObjectType_Func)
|
||||
}
|
||||
|
||||
// Get returns the GoObject for the matching name and type.
|
||||
func (objs *GoObjects) Get(name string, objType GoObjectType) *GoObject {
|
||||
for _, obj := range objs.list {
|
||||
if obj.Name == name && (objType == "" || obj.Type == objType) {
|
||||
return obj
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// HasType checks is a GoObject exists for the matching name and type.
|
||||
func (objs *GoObjects) HasType(name string, objType GoObjectType) bool {
|
||||
for _, obj := range objs.list {
|
||||
if obj.Name == name && (objType == "" || obj.Type == objType) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// HasObject checks to see if the exact code block exists.
|
||||
func (objs *GoObjects) HasObject(src *GoObject) bool {
|
||||
if src == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Generate the code for the supplied object.
|
||||
srcLines := []string{}
|
||||
for _, l := range src.Lines() {
|
||||
// Exclude empty lines.
|
||||
l = strings.TrimSpace(l)
|
||||
if l != "" {
|
||||
srcLines = append(srcLines, l)
|
||||
}
|
||||
}
|
||||
srcStr := strings.Join(srcLines, "\n")
|
||||
|
||||
// Loop over all the objects and match with src code.
|
||||
for _, obj := range objs.list {
|
||||
objLines := []string{}
|
||||
for _, l := range obj.Lines() {
|
||||
// Exclude empty lines.
|
||||
l = strings.TrimSpace(l)
|
||||
if l != "" {
|
||||
objLines = append(objLines, l)
|
||||
}
|
||||
}
|
||||
objStr := strings.Join(objLines, "\n")
|
||||
|
||||
// Return true if the current object code matches src code.
|
||||
if srcStr == objStr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Add appends a new GoObject to the list.
|
||||
func (objs *GoObjects) Add(newObj *GoObject) error {
|
||||
newObj.Index = len(objs.list)
|
||||
objs.list = append(objs.list, newObj)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Insert appends a new GoObject at the desired position to the list.
|
||||
func (objs *GoObjects) Insert(pos int, newObj *GoObject) error {
|
||||
newList := []*GoObject{}
|
||||
|
||||
var newIdx int
|
||||
for _, obj := range objs.list {
|
||||
if obj.Index < pos {
|
||||
obj.Index = newIdx
|
||||
newList = append(newList, obj)
|
||||
} else {
|
||||
if obj.Index == pos {
|
||||
newObj.Index = newIdx
|
||||
newList = append(newList, newObj)
|
||||
newIdx++
|
||||
}
|
||||
obj.Index = newIdx
|
||||
newList = append(newList, obj)
|
||||
}
|
||||
|
||||
newIdx++
|
||||
}
|
||||
|
||||
objs.list = newList
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Remove deletes a GoObject from the list.
|
||||
func (objs *GoObjects) Remove(delObjs ...*GoObject) error {
|
||||
for _, delObj := range delObjs {
|
||||
oldList := objs.List()
|
||||
objs.list = []*GoObject{}
|
||||
|
||||
var newIdx int
|
||||
for _, obj := range oldList {
|
||||
if obj.Index == delObj.Index {
|
||||
continue
|
||||
}
|
||||
obj.Index = newIdx
|
||||
objs.list = append(objs.list, obj)
|
||||
newIdx++
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Replace updates an existing GoObject while maintaining is same position.
|
||||
func (objs *GoObjects) Replace(oldObj *GoObject, newObjs ...*GoObject) error {
|
||||
if oldObj.Index >= len(objs.list) {
|
||||
return errors.WithStack(errGoObjectNotExist)
|
||||
} else if len(newObjs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
oldList := objs.List()
|
||||
objs.list = []*GoObject{}
|
||||
|
||||
var newIdx int
|
||||
for _, obj := range oldList {
|
||||
if obj.Index < oldObj.Index {
|
||||
obj.Index = newIdx
|
||||
objs.list = append(objs.list, obj)
|
||||
newIdx++
|
||||
} else if obj.Index == oldObj.Index {
|
||||
for _, newObj := range newObjs {
|
||||
newObj.Index = newIdx
|
||||
objs.list = append(objs.list, newObj)
|
||||
newIdx++
|
||||
}
|
||||
} else {
|
||||
obj.Index = newIdx
|
||||
objs.list = append(objs.list, obj)
|
||||
newIdx++
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReplaceFuncByName finds an existing GoObject with type GoObjectType_Func by name
|
||||
// and then performs a replace with the supplied new GoObject.
|
||||
func (objs *GoObjects) ReplaceFuncByName(name string, fn *GoObject) error {
|
||||
return objs.ReplaceTypeByName(name, fn, GoObjectType_Func)
|
||||
}
|
||||
|
||||
// ReplaceTypeByName finds an existing GoObject with type by name
|
||||
// and then performs a replace with the supplied new GoObject.
|
||||
func (objs *GoObjects) ReplaceTypeByName(name string, newObj *GoObject, objType GoObjectType) error {
|
||||
if newObj.Name == "" {
|
||||
newObj.Name = name
|
||||
}
|
||||
if newObj.Type == "" && objType != "" {
|
||||
newObj.Type = objType
|
||||
}
|
||||
|
||||
for _, obj := range objs.list {
|
||||
if obj.Name == name && (objType == "" || objType == obj.Type) {
|
||||
return objs.Replace(obj, newObj)
|
||||
}
|
||||
}
|
||||
return errors.WithStack(errGoObjectNotExist)
|
||||
}
|
||||
|
||||
// Empty determines if all the GoObject in the list are line breaks.
|
||||
func (objs *GoObjects) Empty() bool {
|
||||
var hasStuff bool
|
||||
for _, obj := range objs.List() {
|
||||
switch obj.Type {
|
||||
case GoObjectType_LineBreak:
|
||||
//case GoObjectType_Comment:
|
||||
//case GoObjectType_Import:
|
||||
// do nothing
|
||||
default:
|
||||
hasStuff = true
|
||||
}
|
||||
}
|
||||
return hasStuff
|
||||
}
|
||||
|
||||
// Debug prints out the GoObject to logger.
|
||||
func (obj *GoObject) Debug(log *log.Logger) {
|
||||
log.Println(obj.Name)
|
||||
log.Println(" > type:", obj.Type)
|
||||
log.Println(" > start lines:")
|
||||
for _, l := range obj.startLines {
|
||||
log.Println(" ", l)
|
||||
}
|
||||
|
||||
log.Println(" > sub lines:")
|
||||
for _, l := range obj.subLines {
|
||||
log.Println(" ", l)
|
||||
}
|
||||
|
||||
log.Println(" > end lines:")
|
||||
for _, l := range obj.endLines {
|
||||
log.Println(" ", l)
|
||||
}
|
||||
}
|
||||
|
||||
// Defines a property of a struct.
|
||||
type structProp struct {
|
||||
Name string
|
||||
Type string
|
||||
Tags *structtag.Tags
|
||||
}
|
||||
|
||||
// ParseStructProp extracts the details for a struct property.
|
||||
func ParseStructProp(obj *GoObject) (structProp, error) {
|
||||
|
||||
if obj.Type != GoObjectType_Line {
|
||||
return structProp{}, errors.Errorf("Unable to parse object of type %s", obj.Type)
|
||||
}
|
||||
|
||||
// Remove any white space from the code line.
|
||||
ls := strings.TrimSpace(strings.Join(obj.Lines(), " "))
|
||||
|
||||
// Extract the property name and type for the line.
|
||||
// ie: ID string `json:"id"`
|
||||
var resp structProp
|
||||
for _, p := range strings.Split(ls, " ") {
|
||||
p = strings.TrimSpace(p)
|
||||
if p == "" {
|
||||
continue
|
||||
}
|
||||
if resp.Name == "" {
|
||||
resp.Name = p
|
||||
} else if resp.Type == "" {
|
||||
resp.Type = p
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// If the line contains tags, extract and parse them.
|
||||
if strings.Contains(ls, "`") {
|
||||
tagStr := strings.Split(ls, "`")[1]
|
||||
|
||||
var err error
|
||||
resp.Tags, err = structtag.Parse(tagStr)
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to parse struct tag for field %s: %s", resp.Name, tagStr)
|
||||
return structProp{}, err
|
||||
}
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
329
example-project/tools/truss/internal/goparse/goparse.go
Normal file
329
example-project/tools/truss/internal/goparse/goparse.go
Normal file
@ -0,0 +1,329 @@
|
||||
package goparse
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"go/format"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
var (
|
||||
errGoParseType = errors.New("Unable to determine type for line")
|
||||
errGoTypeMissingCodeTemplate = errors.New("No code defined for type")
|
||||
errGoObjectNotExist = errors.New("GoObject does not exist")
|
||||
)
|
||||
|
||||
// ParseFile reads a go code file and parses into a easily transformable set of objects.
|
||||
func ParseFile(log *log.Logger, localPath string) (*GoDocument, error) {
|
||||
|
||||
// Read the code file.
|
||||
src, err := ioutil.ReadFile(localPath)
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to read file %s", localPath)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Format the code file source to ensure parse works.
|
||||
dat, err := format.Source(src)
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to format source for file %s", localPath)
|
||||
log.Printf("ParseFile : %v\n%v", err, string(src))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Loop of the formatted source code and generate a list of code lines.
|
||||
lines := []string{}
|
||||
r := bytes.NewReader(dat)
|
||||
scanner := bufio.NewScanner(r)
|
||||
for scanner.Scan() {
|
||||
lines = append(lines, scanner.Text())
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
err = errors.WithMessagef(err, "Failed read formatted source code for file %s", localPath)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse the code lines into a set of objects.
|
||||
objs, err := ParseLines(lines, 0)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Append the resulting objects to the document.
|
||||
doc := &GoDocument{}
|
||||
for _, obj := range objs.List() {
|
||||
if obj.Type == GoObjectType_Package {
|
||||
doc.Package = obj.Name
|
||||
}
|
||||
doc.AddObject(obj)
|
||||
}
|
||||
|
||||
return doc, nil
|
||||
}
|
||||
|
||||
// ParseLines takes the list of formatted code lines and returns the GoObjects.
|
||||
func ParseLines(lines []string, depth int) (objs *GoObjects, err error) {
|
||||
objs = &GoObjects{
|
||||
list: []*GoObject{},
|
||||
}
|
||||
|
||||
var (
|
||||
multiLine bool
|
||||
multiComment bool
|
||||
muiliVar bool
|
||||
)
|
||||
curDepth := -1
|
||||
objLines := []string{}
|
||||
|
||||
for idx, l := range lines {
|
||||
ls := strings.TrimSpace(l)
|
||||
|
||||
ld := lineDepth(l)
|
||||
|
||||
if ld == depth {
|
||||
if strings.HasPrefix(ls, "/*") {
|
||||
multiLine = true
|
||||
multiComment = true
|
||||
} else if strings.HasSuffix(ls, "(") ||
|
||||
strings.HasSuffix(ls, "{") {
|
||||
|
||||
if !multiLine {
|
||||
multiLine = true
|
||||
}
|
||||
} else if strings.Contains(ls, "`") {
|
||||
if !multiLine && strings.Count(ls, "`")%2 != 0 {
|
||||
if muiliVar {
|
||||
muiliVar = false
|
||||
} else {
|
||||
muiliVar = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
objLines = append(objLines, l)
|
||||
|
||||
if multiComment {
|
||||
if strings.HasSuffix(ls, "*/") {
|
||||
multiComment = false
|
||||
multiLine = false
|
||||
}
|
||||
} else {
|
||||
if strings.HasPrefix(ls, ")") ||
|
||||
strings.HasPrefix(ls, "}") {
|
||||
multiLine = false
|
||||
}
|
||||
}
|
||||
|
||||
if !multiLine && !muiliVar {
|
||||
for eidx := idx + 1; eidx < len(lines); eidx++ {
|
||||
if el := lines[eidx]; strings.TrimSpace(el) == "" {
|
||||
objLines = append(objLines, el)
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
obj, err := ParseGoObject(objLines, depth)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return objs, err
|
||||
}
|
||||
err = objs.Add(obj)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return objs, err
|
||||
}
|
||||
|
||||
objLines = []string{}
|
||||
}
|
||||
|
||||
} else if (multiLine && ld >= curDepth && ld >= depth && len(objLines) > 0) || muiliVar {
|
||||
objLines = append(objLines, l)
|
||||
|
||||
if strings.Contains(ls, "`") {
|
||||
if !multiLine && strings.Count(ls, "`")%2 != 0 {
|
||||
if muiliVar {
|
||||
muiliVar = false
|
||||
} else {
|
||||
muiliVar = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, obj := range objs.List() {
|
||||
children, err := ParseLines(obj.subLines, depth+1)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return objs, err
|
||||
}
|
||||
for _, child := range children.List() {
|
||||
obj.Objects().Add(child)
|
||||
}
|
||||
}
|
||||
|
||||
return objs, nil
|
||||
}
|
||||
|
||||
// ParseGoObject generates a GoObjected for the given code lines.
|
||||
func ParseGoObject(lines []string, depth int) (obj *GoObject, err error) {
|
||||
|
||||
// If there are no lines, return a line break.
|
||||
if len(lines) == 0 {
|
||||
return &GoEmptyLine, nil
|
||||
}
|
||||
|
||||
firstLine := lines[0]
|
||||
firstStrip := strings.TrimSpace(firstLine)
|
||||
|
||||
if len(firstStrip) == 0 {
|
||||
return &GoEmptyLine, nil
|
||||
}
|
||||
|
||||
obj = &GoObject{
|
||||
goObjects: &GoObjects{
|
||||
list: []*GoObject{},
|
||||
},
|
||||
}
|
||||
|
||||
if strings.HasPrefix(firstStrip, "var") {
|
||||
obj.Type = GoObjectType_Var
|
||||
} else if strings.HasPrefix(firstStrip, "const") {
|
||||
obj.Type = GoObjectType_Const
|
||||
} else if strings.HasPrefix(firstStrip, "func") {
|
||||
obj.Type = GoObjectType_Func
|
||||
|
||||
if strings.HasPrefix(firstStrip, "func (") {
|
||||
funcLine := strings.TrimLeft(strings.TrimSpace(strings.TrimLeft(firstStrip, "func ")), "(")
|
||||
|
||||
var structName string
|
||||
pts := strings.Split(strings.Split(funcLine, ")")[0], " ")
|
||||
for i := len(pts) - 1; i >= 0; i-- {
|
||||
ptVal := strings.TrimSpace(pts[i])
|
||||
if ptVal != "" {
|
||||
structName = ptVal
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
var funcName string
|
||||
pts = strings.Split(strings.Split(funcLine, "(")[0], " ")
|
||||
for i := len(pts) - 1; i >= 0; i-- {
|
||||
ptVal := strings.TrimSpace(pts[i])
|
||||
if ptVal != "" {
|
||||
funcName = ptVal
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
obj.Name = fmt.Sprintf("%s.%s", structName, funcName)
|
||||
} else {
|
||||
obj.Name = strings.TrimLeft(firstStrip, "func ")
|
||||
obj.Name = strings.Split(obj.Name, "(")[0]
|
||||
}
|
||||
} else if strings.HasSuffix(firstStrip, "struct {") || strings.HasSuffix(firstStrip, "struct{") {
|
||||
obj.Type = GoObjectType_Struct
|
||||
|
||||
if strings.HasPrefix(firstStrip, "type ") {
|
||||
firstStrip = strings.TrimSpace(strings.Replace(firstStrip, "type ", "", 1))
|
||||
}
|
||||
obj.Name = strings.Split(firstStrip, " ")[0]
|
||||
} else if strings.HasPrefix(firstStrip, "type") {
|
||||
obj.Type = GoObjectType_Type
|
||||
firstStrip = strings.TrimSpace(strings.Replace(firstStrip, "type ", "", 1))
|
||||
obj.Name = strings.Split(firstStrip, " ")[0]
|
||||
} else if strings.HasPrefix(firstStrip, "package") {
|
||||
obj.Name = strings.TrimSpace(strings.TrimLeft(firstStrip, "package "))
|
||||
|
||||
obj.Type = GoObjectType_Package
|
||||
} else if strings.HasPrefix(firstStrip, "import") {
|
||||
obj.Type = GoObjectType_Import
|
||||
} else if strings.HasPrefix(firstStrip, "//") {
|
||||
obj.Type = GoObjectType_Comment
|
||||
} else if strings.HasPrefix(firstStrip, "/*") {
|
||||
obj.Type = GoObjectType_Comment
|
||||
} else {
|
||||
if depth > 0 {
|
||||
obj.Type = GoObjectType_Line
|
||||
} else {
|
||||
err = errors.WithStack(errGoParseType)
|
||||
return obj, err
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
hasSub bool
|
||||
muiliVarStart bool
|
||||
muiliVarSub bool
|
||||
muiliVarEnd bool
|
||||
)
|
||||
for _, l := range lines {
|
||||
ld := lineDepth(l)
|
||||
if (ld == depth && !muiliVarSub) || muiliVarStart || muiliVarEnd {
|
||||
if hasSub && !muiliVarStart {
|
||||
if strings.TrimSpace(l) != "" {
|
||||
obj.endLines = append(obj.endLines, l)
|
||||
}
|
||||
|
||||
if strings.Count(l, "`")%2 != 0 {
|
||||
if muiliVarEnd {
|
||||
muiliVarEnd = false
|
||||
} else {
|
||||
muiliVarEnd = true
|
||||
}
|
||||
}
|
||||
} else {
|
||||
obj.startLines = append(obj.startLines, l)
|
||||
|
||||
if strings.Count(l, "`")%2 != 0 {
|
||||
if muiliVarStart {
|
||||
muiliVarStart = false
|
||||
} else {
|
||||
muiliVarStart = true
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if ld > depth || muiliVarSub {
|
||||
obj.subLines = append(obj.subLines, l)
|
||||
hasSub = true
|
||||
|
||||
if strings.Count(l, "`")%2 != 0 {
|
||||
if muiliVarSub {
|
||||
muiliVarSub = false
|
||||
} else {
|
||||
muiliVarSub = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// add trailing linebreak
|
||||
if len(obj.endLines) > 0 {
|
||||
obj.endLines = append(obj.endLines, "")
|
||||
}
|
||||
|
||||
return obj, err
|
||||
}
|
||||
|
||||
// lineDepth returns the number of spaces for the given code line
|
||||
// used to determine the code level for nesting objects.
|
||||
func lineDepth(l string) int {
|
||||
depth := len(l) - len(strings.TrimLeftFunc(l, unicode.IsSpace))
|
||||
|
||||
ls := strings.TrimSpace(l)
|
||||
if strings.HasPrefix(ls, "}") && strings.Contains(ls, " else ") {
|
||||
depth = depth + 1
|
||||
} else if strings.HasPrefix(ls, "case ") {
|
||||
depth = depth + 1
|
||||
}
|
||||
return depth
|
||||
}
|
195
example-project/tools/truss/internal/goparse/goparse_test.go
Normal file
195
example-project/tools/truss/internal/goparse/goparse_test.go
Normal file
@ -0,0 +1,195 @@
|
||||
package goparse
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var logger *log.Logger
|
||||
|
||||
func init() {
|
||||
logger = log.New(os.Stdout, "", log.LstdFlags|log.Lmicroseconds|log.Lshortfile)
|
||||
}
|
||||
|
||||
func TestParseFileModel1(t *testing.T) {
|
||||
|
||||
_, err := ParseFile(logger, "test_gofile_model1.txt")
|
||||
if err != nil {
|
||||
t.Fatalf("got error %v", err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestMultilineVar(t *testing.T) {
|
||||
g := gomega.NewGomegaWithT(t)
|
||||
|
||||
code := `func ContextAllowedAccountIds(ctx context.Context, db *gorm.DB) (resp akdatamodels.Uint32List, err error) {
|
||||
resp = []uint32{}
|
||||
accountId := akcontext.ContextAccountId(ctx)
|
||||
m := datamodels.UserAccount{}
|
||||
q := fmt.Sprintf("select
|
||||
distinct account_id
|
||||
from %s where account_id = ?", m.TableName())
|
||||
db = db.Raw(q, accountId)
|
||||
}
|
||||
`
|
||||
code = strings.Replace(code, "\"", "`", -1)
|
||||
lines := strings.Split(code, "\n")
|
||||
|
||||
objs, err := ParseLines(lines, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("got error %v", err)
|
||||
}
|
||||
|
||||
g.Expect(objs.Lines()).Should(gomega.Equal(lines))
|
||||
}
|
||||
|
||||
func TestNewDocImports(t *testing.T) {
|
||||
g := gomega.NewGomegaWithT(t)
|
||||
|
||||
expected := []string{
|
||||
"package goparse",
|
||||
"",
|
||||
"import (",
|
||||
" \"github.com/go/pkg1\"",
|
||||
" \"github.com/go/pkg2\"",
|
||||
")",
|
||||
"",
|
||||
}
|
||||
|
||||
doc := &GoDocument{}
|
||||
doc.SetPackage("goparse")
|
||||
|
||||
doc.AddImport(GoImport{Name: "github.com/go/pkg1"})
|
||||
doc.AddImport(GoImport{Name: "github.com/go/pkg2"})
|
||||
|
||||
g.Expect(doc.Lines()).Should(gomega.Equal(expected))
|
||||
}
|
||||
|
||||
func TestParseLines1(t *testing.T) {
|
||||
g := gomega.NewGomegaWithT(t)
|
||||
|
||||
code := `func testCreate(t *testing.T, ctx context.Context, sess *datamodels.Session) *datamodels.Model {
|
||||
g := gomega.NewGomegaWithT(t)
|
||||
obj := datamodels.MockModelNew()
|
||||
resp, err := ModelCreate(ctx, DB, &obj)
|
||||
if err != nil {
|
||||
t.Fatalf("got error %v", err)
|
||||
}
|
||||
|
||||
g.Expect(resp.Name).Should(gomega.Equal(obj.Name))
|
||||
g.Expect(resp.Status).Should(gomega.Equal(datamodels.{{ .Name }}Status_Active))
|
||||
return resp
|
||||
}
|
||||
`
|
||||
lines := strings.Split(code, "\n")
|
||||
|
||||
objs, err := ParseLines(lines, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("got error %v", err)
|
||||
}
|
||||
|
||||
g.Expect(objs.Lines()).Should(gomega.Equal(lines))
|
||||
}
|
||||
|
||||
func TestParseLines2(t *testing.T) {
|
||||
code := `func structToMap(s interface{}) (resp map[string]interface{}) {
|
||||
dat, _ := json.Marshal(s)
|
||||
_ = json.Unmarshal(dat, &resp)
|
||||
for k, x := range resp {
|
||||
switch v := x.(type) {
|
||||
case time.Time:
|
||||
if v.IsZero() {
|
||||
delete(resp, k)
|
||||
}
|
||||
|
||||
case *time.Time:
|
||||
if v == nil || v.IsZero() {
|
||||
delete(resp, k)
|
||||
}
|
||||
|
||||
case nil:
|
||||
delete(resp, k)
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
`
|
||||
lines := strings.Split(code, "\n")
|
||||
|
||||
objs, err := ParseLines(lines, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("got error %v", err)
|
||||
}
|
||||
|
||||
testLineTextMatches(t, objs.Lines(), lines)
|
||||
}
|
||||
|
||||
func TestParseLines3(t *testing.T) {
|
||||
g := gomega.NewGomegaWithT(t)
|
||||
|
||||
code := `type UserAccountRoleName = string
|
||||
|
||||
const (
|
||||
UserAccountRoleName_None UserAccountRoleName = ""
|
||||
UserAccountRoleName_Admin UserAccountRoleName = "admin"
|
||||
UserAccountRoleName_User UserAccountRoleName = "user"
|
||||
)
|
||||
|
||||
type UserAccountRole struct {
|
||||
Id uint32 ^gorm:"column:id;type:int(10) unsigned AUTO_INCREMENT;primary_key;not null;auto_increment;" truss:"internal:true"^
|
||||
CreatedAt time.Time ^gorm:"column:created_at;type:datetime;default:CURRENT_TIMESTAMP;not null;" truss:"internal:true"^
|
||||
UpdatedAt time.Time ^gorm:"column:updated_at;type:datetime;" truss:"internal:true"^
|
||||
DeletedAt *time.Time ^gorm:"column:deleted_at;type:datetime;" truss:"internal:true"^
|
||||
Role UserAccountRoleName ^gorm:"unique_index:user_account_role;column:role;type:enum('admin', 'user')"^
|
||||
// belongs to User
|
||||
User *User ^gorm:"foreignkey:UserId;association_foreignkey:Id;association_autoupdate:false;association_autocreate:false;association_save_reference:false;preload:false;" truss:"internal:true"^
|
||||
UserId uint32 ^gorm:"unique_index:user_account_role;"^
|
||||
// belongs to Account
|
||||
Account *Account ^gorm:"foreignkey:AccountId;association_foreignkey:Id;association_autoupdate:false;association_autocreate:false;association_save_reference:false;preload:false;" truss:"internal:true;api_ro:true;"^
|
||||
AccountId uint32 ^gorm:"unique_index:user_account_role;" truss:"internal:true;api_ro:true;"^
|
||||
}
|
||||
|
||||
func (UserAccountRole) TableName() string {
|
||||
return "user_account_roles"
|
||||
}
|
||||
`
|
||||
code = strings.Replace(code, "^", "'", -1)
|
||||
lines := strings.Split(code, "\n")
|
||||
|
||||
objs, err := ParseLines(lines, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("got error %v", err)
|
||||
}
|
||||
|
||||
g.Expect(objs.Lines()).Should(gomega.Equal(lines))
|
||||
}
|
||||
|
||||
func testLineTextMatches(t *testing.T, l1, l2 []string) {
|
||||
g := gomega.NewGomegaWithT(t)
|
||||
|
||||
m1 := []string{}
|
||||
for _, l := range l1 {
|
||||
l = strings.TrimSpace(l)
|
||||
if l != "" {
|
||||
m1 = append(m1, l)
|
||||
}
|
||||
}
|
||||
|
||||
m2 := []string{}
|
||||
for _, l := range l2 {
|
||||
l = strings.TrimSpace(l)
|
||||
if l != "" {
|
||||
m2 = append(m2, l)
|
||||
}
|
||||
}
|
||||
|
||||
g.Expect(m1).Should(gomega.Equal(m2))
|
||||
}
|
@ -0,0 +1,126 @@
|
||||
package account
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/pkg/errors"
|
||||
"gopkg.in/go-playground/validator.v9"
|
||||
)
|
||||
|
||||
// Account represents someone with access to our system.
|
||||
type Account struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Address1 string `json:"address1"`
|
||||
Address2 string `json:"address2"`
|
||||
City string `json:"city"`
|
||||
Region string `json:"region"`
|
||||
Country string `json:"country"`
|
||||
Zipcode string `json:"zipcode"`
|
||||
Status AccountStatus `json:"status"`
|
||||
Timezone string `json:"timezone"`
|
||||
SignupUserID sql.NullString `json:"signup_user_id"`
|
||||
BillingUserID sql.NullString `json:"billing_user_id"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
ArchivedAt pq.NullTime `json:"archived_at"`
|
||||
}
|
||||
|
||||
// CreateAccountRequest contains information needed to create a new Account.
|
||||
type CreateAccountRequest struct {
|
||||
Name string `json:"name" validate:"required,unique"`
|
||||
Address1 string `json:"address1" validate:"required"`
|
||||
Address2 string `json:"address2" validate:"omitempty"`
|
||||
City string `json:"city" validate:"required"`
|
||||
Region string `json:"region" validate:"required"`
|
||||
Country string `json:"country" validate:"required"`
|
||||
Zipcode string `json:"zipcode" validate:"required"`
|
||||
Status *AccountStatus `json:"status" validate:"omitempty,oneof=active pending disabled"`
|
||||
Timezone *string `json:"timezone" validate:"omitempty"`
|
||||
SignupUserID *string `json:"signup_user_id" validate:"omitempty,uuid"`
|
||||
BillingUserID *string `json:"billing_user_id" validate:"omitempty,uuid"`
|
||||
}
|
||||
|
||||
// UpdateAccountRequest defines what information may be provided to modify an existing
|
||||
// Account. All fields are optional so clients can send just the fields they want
|
||||
// changed. It uses pointer fields so we can differentiate between a field that
|
||||
// was not provided and a field that was provided as explicitly blank. Normally
|
||||
// we do not want to use pointers to basic types but we make exceptions around
|
||||
// marshalling/unmarshalling.
|
||||
type UpdateAccountRequest struct {
|
||||
ID string `validate:"required,uuid"`
|
||||
Name *string `json:"name" validate:"omitempty,unique"`
|
||||
Address1 *string `json:"address1" validate:"omitempty"`
|
||||
Address2 *string `json:"address2" validate:"omitempty"`
|
||||
City *string `json:"city" validate:"omitempty"`
|
||||
Region *string `json:"region" validate:"omitempty"`
|
||||
Country *string `json:"country" validate:"omitempty"`
|
||||
Zipcode *string `json:"zipcode" validate:"omitempty"`
|
||||
Status *AccountStatus `json:"status" validate:"omitempty,oneof=active pending disabled"`
|
||||
Timezone *string `json:"timezone" validate:"omitempty"`
|
||||
SignupUserID *string `json:"signup_user_id" validate:"omitempty,uuid"`
|
||||
BillingUserID *string `json:"billing_user_id" validate:"omitempty,uuid"`
|
||||
}
|
||||
|
||||
// AccountFindRequest defines the possible options to search for accounts. By default
|
||||
// archived accounts will be excluded from response.
|
||||
type AccountFindRequest struct {
|
||||
Where *string
|
||||
Args []interface{}
|
||||
Order []string
|
||||
Limit *uint
|
||||
Offset *uint
|
||||
IncludedArchived bool
|
||||
}
|
||||
|
||||
// AccountStatus represents the status of an account.
|
||||
type AccountStatus string
|
||||
|
||||
// AccountStatus values define the status field of a user account.
|
||||
const (
|
||||
// AccountStatus_Active defines the state when a user can access an account.
|
||||
AccountStatus_Active AccountStatus = "active"
|
||||
// AccountStatus_Pending defined the state when an account was created but
|
||||
// not activated.
|
||||
AccountStatus_Pending AccountStatus = "pending"
|
||||
// AccountStatus_Disabled defines the state when a user has been disabled from
|
||||
// accessing an account.
|
||||
AccountStatus_Disabled AccountStatus = "disabled"
|
||||
)
|
||||
|
||||
// AccountStatus_Values provides list of valid AccountStatus values.
|
||||
var AccountStatus_Values = []AccountStatus{
|
||||
AccountStatus_Active,
|
||||
AccountStatus_Pending,
|
||||
AccountStatus_Disabled,
|
||||
}
|
||||
|
||||
// Scan supports reading the AccountStatus value from the database.
|
||||
func (s *AccountStatus) Scan(value interface{}) error {
|
||||
asBytes, ok := value.([]byte)
|
||||
if !ok {
|
||||
return errors.New("Scan source is not []byte")
|
||||
}
|
||||
*s = AccountStatus(string(asBytes))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value converts the AccountStatus value to be stored in the database.
|
||||
func (s AccountStatus) Value() (driver.Value, error) {
|
||||
v := validator.New()
|
||||
|
||||
errs := v.Var(s, "required,oneof=active invited disabled")
|
||||
if errs != nil {
|
||||
return nil, errs
|
||||
}
|
||||
|
||||
return string(s), nil
|
||||
}
|
||||
|
||||
// String converts the AccountStatus value to a string.
|
||||
func (s AccountStatus) String() string {
|
||||
return string(s)
|
||||
}
|
227
example-project/tools/truss/main.go
Normal file
227
example-project/tools/truss/main.go
Normal file
@ -0,0 +1,227 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"expvar"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/tools/truss/cmd/dbtable2crud"
|
||||
"github.com/kelseyhightower/envconfig"
|
||||
"github.com/lib/pq"
|
||||
_ "github.com/lib/pq"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/urfave/cli"
|
||||
sqltrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/database/sql"
|
||||
sqlxtrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
// build is the git version of this program. It is set using build flags in the makefile.
|
||||
var build = "develop"
|
||||
|
||||
// service is the name of the program used for logging, tracing and the
|
||||
// the prefix used for loading env variables
|
||||
// ie: export TRUSS_ENV=dev
|
||||
var service = "TRUSS"
|
||||
|
||||
func main() {
|
||||
// =========================================================================
|
||||
// Logging
|
||||
|
||||
log := log.New(os.Stdout, service+" : ", log.LstdFlags|log.Lmicroseconds|log.Lshortfile)
|
||||
|
||||
// =========================================================================
|
||||
// Configuration
|
||||
var cfg struct {
|
||||
DB struct {
|
||||
Host string `default:"127.0.0.1:5433" envconfig:"HOST"`
|
||||
User string `default:"postgres" envconfig:"USER"`
|
||||
Pass string `default:"postgres" envconfig:"PASS" json:"-"` // don't print
|
||||
Database string `default:"shared" envconfig:"DATABASE"`
|
||||
Driver string `default:"postgres" envconfig:"DRIVER"`
|
||||
Timezone string `default:"utc" envconfig:"TIMEZONE"`
|
||||
DisableTLS bool `default:"false" envconfig:"DISABLE_TLS"`
|
||||
}
|
||||
}
|
||||
|
||||
// For additional details refer to https://github.com/kelseyhightower/envconfig
|
||||
if err := envconfig.Process(service, &cfg); err != nil {
|
||||
log.Fatalf("main : Parsing Config : %v", err)
|
||||
}
|
||||
|
||||
// TODO: can't use flag.Process here since it doesn't support nested arg options
|
||||
//if err := flag.Process(&cfg); err != nil {
|
||||
/// if err != flag.ErrHelp {
|
||||
// log.Fatalf("main : Parsing Command Line : %v", err)
|
||||
// }
|
||||
// return // We displayed help.
|
||||
//}
|
||||
|
||||
// =========================================================================
|
||||
// Log App Info
|
||||
|
||||
// Print the build version for our logs. Also expose it under /debug/vars.
|
||||
expvar.NewString("build").Set(build)
|
||||
log.Printf("main : Started : Application Initializing version %q", build)
|
||||
defer log.Println("main : Completed")
|
||||
|
||||
// Print the config for our logs. It's important to any credentials in the config
|
||||
// that could expose a security risk are excluded from being json encoded by
|
||||
// applying the tag `json:"-"` to the struct var.
|
||||
{
|
||||
cfgJSON, err := json.MarshalIndent(cfg, "", " ")
|
||||
if err != nil {
|
||||
log.Fatalf("main : Marshalling Config to JSON : %v", err)
|
||||
}
|
||||
log.Printf("main : Config : %v\n", string(cfgJSON))
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Start Database
|
||||
var dbUrl url.URL
|
||||
{
|
||||
// Query parameters.
|
||||
var q url.Values = make(map[string][]string)
|
||||
|
||||
// Handle SSL Mode
|
||||
if cfg.DB.DisableTLS {
|
||||
q.Set("sslmode", "disable")
|
||||
} else {
|
||||
q.Set("sslmode", "require")
|
||||
}
|
||||
|
||||
q.Set("timezone", cfg.DB.Timezone)
|
||||
|
||||
// Construct url.
|
||||
dbUrl = url.URL{
|
||||
Scheme: cfg.DB.Driver,
|
||||
User: url.UserPassword(cfg.DB.User, cfg.DB.Pass),
|
||||
Host: cfg.DB.Host,
|
||||
Path: cfg.DB.Database,
|
||||
RawQuery: q.Encode(),
|
||||
}
|
||||
}
|
||||
|
||||
// Register informs the sqlxtrace package of the driver that we will be using in our program.
|
||||
// It uses a default service name, in the below case "postgres.db". To use a custom service
|
||||
// name use RegisterWithServiceName.
|
||||
sqltrace.Register(cfg.DB.Driver, &pq.Driver{}, sqltrace.WithServiceName(service))
|
||||
masterDb, err := sqlxtrace.Open(cfg.DB.Driver, dbUrl.String())
|
||||
if err != nil {
|
||||
log.Fatalf("main : Register DB : %s : %v", cfg.DB.Driver, err)
|
||||
}
|
||||
defer masterDb.Close()
|
||||
|
||||
// =========================================================================
|
||||
// Start Truss
|
||||
|
||||
app := cli.NewApp()
|
||||
app.Commands = []cli.Command{
|
||||
{
|
||||
Name: "dbtable2crud",
|
||||
Aliases: []string{"dbtable2crud"},
|
||||
Usage: "dbtable2crud -table=projects -file=../../internal/project/models.go -model=Project",
|
||||
Flags: []cli.Flag{
|
||||
cli.StringFlag{Name: "dbtable, table"},
|
||||
cli.StringFlag{Name: "modelFile, modelfile, file"},
|
||||
cli.StringFlag{Name: "modelName, modelname, model"},
|
||||
cli.StringFlag{Name: "templateDir, templates", Value: "./templates/dbtable2crud"},
|
||||
cli.StringFlag{Name: "projectPath", Value: ""},
|
||||
},
|
||||
Action: func(c *cli.Context) error {
|
||||
dbTable := strings.TrimSpace(c.String("dbtable"))
|
||||
modelFile := strings.TrimSpace(c.String("modelFile"))
|
||||
modelName := strings.TrimSpace(c.String("modelName"))
|
||||
templateDir := strings.TrimSpace(c.String("templateDir"))
|
||||
projectPath := strings.TrimSpace(c.String("projectPath"))
|
||||
|
||||
pwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return errors.WithMessage(err, "Failed to get current working directory")
|
||||
}
|
||||
|
||||
if !path.IsAbs(templateDir) {
|
||||
templateDir = filepath.Join(pwd, templateDir)
|
||||
}
|
||||
ok, err := exists(templateDir)
|
||||
if err != nil {
|
||||
return errors.WithMessage(err, "Failed to load template directory")
|
||||
} else if !ok {
|
||||
return errors.Errorf("Template directory %s does not exist", templateDir)
|
||||
}
|
||||
|
||||
if modelFile == "" {
|
||||
return errors.Errorf("Model file path is required")
|
||||
}
|
||||
|
||||
if !path.IsAbs(modelFile) {
|
||||
modelFile = filepath.Join(pwd, modelFile)
|
||||
}
|
||||
ok, err = exists(modelFile)
|
||||
if err != nil {
|
||||
return errors.WithMessage(err, "Failed to load model file")
|
||||
} else if !ok {
|
||||
return errors.Errorf("Model file %s does not exist", modelFile)
|
||||
}
|
||||
|
||||
// Load the project path from go.mod if not set.
|
||||
if projectPath == "" {
|
||||
goModFile := filepath.Join(pwd, "../../go.mod")
|
||||
ok, err = exists(goModFile)
|
||||
if err != nil {
|
||||
return errors.WithMessage(err, "Failed to load go.mod for project")
|
||||
} else if !ok {
|
||||
return errors.Errorf("Failed to locate project go.mod at %s", goModFile)
|
||||
}
|
||||
|
||||
b, err := ioutil.ReadFile(goModFile)
|
||||
if err != nil {
|
||||
return errors.WithMessagef(err, "Failed to read go.mod at %s", goModFile)
|
||||
}
|
||||
|
||||
lines := strings.Split(string(b), "\n")
|
||||
for _, l := range lines {
|
||||
if strings.HasPrefix(l, "module ") {
|
||||
projectPath = strings.TrimSpace(strings.Split(l, " ")[1])
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if modelName == "" {
|
||||
modelName = strings.Split(filepath.Base(modelFile), ".")[0]
|
||||
modelName = strings.Replace(modelName, "_", " ", -1)
|
||||
modelName = strings.Replace(modelName, "-", " ", -1)
|
||||
modelName = strings.Title(modelName)
|
||||
modelName = strings.Replace(modelName, " ", "", -1)
|
||||
}
|
||||
|
||||
return dbtable2crud.Run(masterDb, log, cfg.DB.Database, dbTable, modelFile, modelName, templateDir, projectPath)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err = app.Run(os.Args)
|
||||
if err != nil {
|
||||
log.Fatalf("main : Truss : %+v", err)
|
||||
}
|
||||
|
||||
log.Printf("main : Truss : Completed")
|
||||
}
|
||||
|
||||
// exists returns a bool as to whether a file path exists.
|
||||
func exists(path string) (bool, error) {
|
||||
_, err := os.Stat(path)
|
||||
if err == nil {
|
||||
return true, nil
|
||||
}
|
||||
if os.IsNotExist(err) {
|
||||
return false, nil
|
||||
}
|
||||
return true, err
|
||||
}
|
4
example-project/tools/truss/sample.env
Normal file
4
example-project/tools/truss/sample.env
Normal file
@ -0,0 +1,4 @@
|
||||
export TRUSS_DB_HOST=127.0.0.1:5433
|
||||
export TRUSS_DB_USER=postgres
|
||||
export TRUSS_DB_PASS=postgres
|
||||
export TRUSS_DB_DISABLE_TLS=true
|
@ -0,0 +1,503 @@
|
||||
{{ define "imports"}}
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"{{ $.GoSrcPath }}/internal/platform/auth"
|
||||
"github.com/huandu/go-sqlbuilder"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/pborman/uuid"
|
||||
"github.com/pkg/errors"
|
||||
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
|
||||
"gopkg.in/go-playground/validator.v9"
|
||||
)
|
||||
{{ end }}
|
||||
{{ define "Globals"}}
|
||||
const (
|
||||
// The database table for {{ $.Model.Name }}
|
||||
{{ FormatCamelLower $.Model.Name }}TableName = "{{ $.Model.TableName }}"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrNotFound abstracts the postgres not found error.
|
||||
ErrNotFound = errors.New("Entity not found")
|
||||
|
||||
// ErrInvalidID occurs when an ID is not in a valid form.
|
||||
ErrInvalidID = errors.New("ID is not in its proper form")
|
||||
|
||||
// ErrForbidden occurs when a user tries to do something that is forbidden to them according to our access control policies.
|
||||
ErrForbidden = errors.New("Attempted action is not allowed")
|
||||
)
|
||||
{{ end }}
|
||||
{{ define "Helpers"}}
|
||||
|
||||
// {{ FormatCamelLower $.Model.Name }}MapColumns is the list of columns needed for mapRowsTo{{ $.Model.Name }}
|
||||
var {{ FormatCamelLower $.Model.Name }}MapColumns = "{{ JoinStrings $.Model.ColumnNames "," }}"
|
||||
|
||||
// mapRowsTo{{ $.Model.Name }} takes the SQL rows and maps it to the {{ $.Model.Name }} struct
|
||||
// with the columns defined by {{ FormatCamelLower $.Model.Name }}MapColumns
|
||||
func mapRowsTo{{ $.Model.Name }}(rows *sql.Rows) (*{{ $.Model.Name }}, error) {
|
||||
var (
|
||||
m {{ $.Model.Name }}
|
||||
err error
|
||||
)
|
||||
err = rows.Scan({{ PrefixAndJoinStrings $.Model.FieldNames "&m." "," }})
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
return &a, nil
|
||||
}
|
||||
{{ end }}
|
||||
{{ define "ACL"}}
|
||||
{{ $hasAccountId := (StringListHasValue $.Model.ColumnNames "account_id") }}
|
||||
// CanRead{{ $.Model.Name }} determines if claims has the authority to access the specified {{ FormatCamelLowerTitle $.Model.Name}} by {{ $.Model.PrimaryColumn }}.
|
||||
func CanRead{{ $.Model.Name }}(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, {{ FormatCamelLower $.Model.PrimaryField }} string) error {
|
||||
|
||||
{{ if $hasAccountId }}
|
||||
// If the request has claims from a specific {{ FormatCamelLower $.Model.Name }}, ensure that the claims
|
||||
// has the correct access to the {{ FormatCamelLower $.Model.Name }}.
|
||||
if claims.Audience != "" {
|
||||
// select {{ $.Model.PrimaryColumn }} from {{ $.Model.TableName }} where account_id = [accountID]
|
||||
query := sqlbuilder.NewSelectBuilder().Select("{{ $.Model.PrimaryColumn }}").From({{ FormatCamelLower $.Model.Name }}TableName)
|
||||
query.Where(query.And(
|
||||
query.Equal("account_id", claims.Audience),
|
||||
query.Equal("{{ $.Model.PrimaryField }}", {{ FormatCamelLower $.Model.PrimaryField }}),
|
||||
))
|
||||
queryStr, args := query.Build()
|
||||
queryStr = dbConn.Rebind(queryStr)
|
||||
|
||||
var {{ FormatCamelLower $.Model.PrimaryField }} string
|
||||
err := dbConn.QueryRowContext(ctx, queryStr, args...).Scan(&{{ FormatCamelLower $.Model.PrimaryField }})
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
err = errors.Wrapf(err, "query - %s", query.String())
|
||||
return err
|
||||
}
|
||||
|
||||
// When there is no {{ $.Model.PrimaryColumn }} returned, then the current claim user does not have access
|
||||
// to the specified {{ FormatCamelLowerTitle $.Model.Name }}.
|
||||
if {{ FormatCamelLower $.Model.PrimaryField }} == "" {
|
||||
return errors.WithStack(ErrForbidden)
|
||||
}
|
||||
}
|
||||
{{ else }}
|
||||
// TODO: Unable to auto generate sql statement, update accordingly.
|
||||
panic("Not implemented!")
|
||||
{{ end }}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CanModify{{ $.Model.Name }} determines if claims has the authority to modify the specified {{ FormatCamelLowerTitle $.Model.Name}} by {{ $.Model.PrimaryColumn }}.
|
||||
func CanModify{{ $.Model.Name }}(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, {{ FormatCamelLower $.Model.PrimaryField }} string) error {
|
||||
err = CanRead{{ $.Model.Name }}(ctx, claims, dbConn, {{ FormatCamelLower $.Model.PrimaryField }})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Admin users can update {{ FormatCamelPluralTitleLower $.Model.Name }} they have access to.
|
||||
if !claims.HasRole(auth.RoleAdmin) {
|
||||
return errors.WithStack(ErrForbidden)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyClaimsSelect applies a sub-query to the provided query to enforce ACL based on the claims provided.
|
||||
// 1. No claims, request is internal, no ACL applied
|
||||
{{ if $hasAccountId }}
|
||||
// 2. All role types can access their user ID
|
||||
{{ end }}
|
||||
func applyClaimsSelect(ctx context.Context, claims auth.Claims, query *sqlbuilder.SelectBuilder) error {
|
||||
// Claims are empty, don't apply any ACL
|
||||
if claims.Audience == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
{{ if $hasAccountId }}
|
||||
query.Where(query.Equal("account_id", claims.Audience))
|
||||
{{ end }}
|
||||
|
||||
return nil
|
||||
}
|
||||
{{ end }}
|
||||
{{ define "Find"}}
|
||||
{{ $hasArchived := (StringListHasValue $.Model.ColumnNames "archived_at") }}
|
||||
|
||||
// selectQuery constructs a base select query for {{ $.Model.Name }}
|
||||
func selectQuery() *sqlbuilder.SelectBuilder {
|
||||
query := sqlbuilder.NewSelectBuilder()
|
||||
query.Select({{ FormatCamelLower $.Model.Name }}MapColumns)
|
||||
query.From({{ FormatCamelLower $.Model.Name }}TableName)
|
||||
return query
|
||||
}
|
||||
|
||||
// findRequestQuery generates the select query for the given find request.
|
||||
// TODO: Need to figure out why can't parse the args when appending the where
|
||||
// to the query.
|
||||
func findRequestQuery(req {{ $.Model.Name }}FindRequest) (*sqlbuilder.SelectBuilder, []interface{}) {
|
||||
query := selectQuery()
|
||||
if req.Where != nil {
|
||||
query.Where(query.And(*req.Where))
|
||||
}
|
||||
if len(req.Order) > 0 {
|
||||
query.OrderBy(req.Order...)
|
||||
}
|
||||
if req.Limit != nil {
|
||||
query.Limit(int(*req.Limit))
|
||||
}
|
||||
if req.Offset != nil {
|
||||
query.Offset(int(*req.Offset))
|
||||
}
|
||||
|
||||
return query, req.Args
|
||||
}
|
||||
|
||||
// Find gets all the {{ FormatCamelPluralTitleLower $.Model.Name }} from the database based on the request params.
|
||||
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req {{ $.Model.Name }}FindRequest) ([]*{{ $.Model.Name }}, error) {
|
||||
query, args := findRequestQuery(req)
|
||||
return find(ctx, claims, dbConn, query, args{{ if $hasArchived }}, req.IncludedArchived {{ end }})
|
||||
}
|
||||
|
||||
// find internal method for getting all the {{ FormatCamelPluralTitleLower $.Model.Name }} from the database using a select query.
|
||||
func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}{{ if $hasArchived }}, includedArchived bool{{ end }}) ([]*{{ $.Model.Name }}, error) {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.{{ FormatCamelLowerUnderscore $.Model.Name }}.Find")
|
||||
defer span.Finish()
|
||||
|
||||
query.Select({{ FormatCamelLower $.Model.Name }}MapColumns)
|
||||
query.From({{ FormatCamelLower $.Model.Name }}TableName)
|
||||
|
||||
{{ if $hasArchived }}
|
||||
if !includedArchived {
|
||||
query.Where(query.IsNull("archived_at"))
|
||||
}
|
||||
{{ end }}
|
||||
|
||||
// Check to see if a sub query needs to be applied for the claims.
|
||||
err := applyClaimsSelect(ctx, claims, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
queryStr, queryArgs := query.Build()
|
||||
queryStr = dbConn.Rebind(queryStr)
|
||||
args = append(args, queryArgs...)
|
||||
|
||||
// Fetch all entries from the db.
|
||||
rows, err := dbConn.QueryContext(ctx, queryStr, args...)
|
||||
if err != nil {
|
||||
err = errors.Wrapf(err, "query - %s", query.String())
|
||||
err = errors.WithMessage(err, "find {{ FormatCamelPluralTitleLower $.Model.Name }} failed")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Iterate over each row.
|
||||
resp := []*{{ $.Model.Name }}{}
|
||||
for rows.Next() {
|
||||
u, err := mapRowsTo{{ $.Model.Name }}(rows)
|
||||
if err != nil {
|
||||
err = errors.Wrapf(err, "query - %s", query.String())
|
||||
return nil, err
|
||||
}
|
||||
resp = append(resp, u)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// Read gets the specified {{ FormatCamelLowerTitle $.Model.Name }} from the database.
|
||||
func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, {{ FormatCamelLower $.Model.PrimaryField }} {{ $.Model.PrimaryType }}{{ if $hasArchived }}, includedArchived bool{{ end }}) (*{{ $.Model.Name }}, error) {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.{{ FormatCamelLowerUnderscore $.Model.Name }}.Read")
|
||||
defer span.Finish()
|
||||
|
||||
// Filter base select query by {{ FormatCamelLower $.Model.PrimaryField }}
|
||||
query := selectQuery()
|
||||
query.Where(query.Equal("{{ $.Model.PrimaryColumn }}", {{ FormatCamelLower $.Model.PrimaryField }}))
|
||||
|
||||
res, err := find(ctx, claims, dbConn, query, []interface{}{} {{ if $hasArchived }}, includedArchived{{ end }})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if res == nil || len(res) == 0 {
|
||||
err = errors.WithMessagef(ErrNotFound, "{{ FormatCamelLowerTitle $.Model.Name }} %s not found", id)
|
||||
return nil, err
|
||||
}
|
||||
u := res[0]
|
||||
|
||||
return u, nil
|
||||
}
|
||||
{{ end }}
|
||||
{{ define "Create"}}
|
||||
{{ $hasAccountId := (StringListHasValue $.Model.ColumnNames "account_id") }}
|
||||
{{ $reqName := (Concat $.Model.Name "CreateRequest") }}
|
||||
{{ $createFields := (index $.StructFields $reqName) }}
|
||||
// Create inserts a new {{ FormatCamelLowerTitle $.Model.Name }} into the database.
|
||||
func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req {{ $reqName }}, now time.Time) (*{{ $.Model.Name }}, error) {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.{{ FormatCamelLowerUnderscore $.Model.Name }}.Create")
|
||||
defer span.Finish()
|
||||
|
||||
if claims.Audience != "" {
|
||||
// Admin users can update {{ FormatCamelPluralTitleLower $.Model.Name }} they have access to.
|
||||
if !claims.HasRole(auth.RoleAdmin) {
|
||||
return errors.WithStack(ErrForbidden)
|
||||
}
|
||||
|
||||
{{ if $hasAccountId }}
|
||||
if req.AccountId != "" {
|
||||
// Request accountId must match claims.
|
||||
if req.AccountId != claims.Audience {
|
||||
return errors.WithStack(ErrForbidden)
|
||||
}
|
||||
} else {
|
||||
// Set the accountId from claims.
|
||||
req.AccountId = claims.Audience
|
||||
}
|
||||
{{ end }}
|
||||
}
|
||||
|
||||
v := validator.New()
|
||||
|
||||
// Validate the request.
|
||||
err = v.Struct(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If now empty set it to the current time.
|
||||
if now.IsZero() {
|
||||
now = time.Now()
|
||||
}
|
||||
|
||||
// Always store the time as UTC.
|
||||
now = now.UTC()
|
||||
|
||||
// Postgres truncates times to milliseconds when storing. We and do the same
|
||||
// here so the value we return is consistent with what we store.
|
||||
now = now.Truncate(time.Millisecond)
|
||||
|
||||
m := {{ $.Model.Name }}{
|
||||
{{ range $mk, $mf := $.Model.Fields }}{{ $cf := (index $createFields $mf.FieldName) }}
|
||||
{{ if eq $mf.FieldName $.Model.PrimaryField }}{{ $isUuid := (FieldTagHasOption $mf "validate" "uuid") }}{{ $mf.FieldName }}: {{ if $isUuid }}uuid.NewRandom().String(){{ else }}req.{{ $mf.FieldName }}{{ end }},
|
||||
{{ else if or (eq $mf.FieldName "CreatedAt") (eq $mf.FieldName "UpdatedAt") }}{{ $mf.FieldName }}: now,
|
||||
{{ else if $cf }}{{ $required := (FieldTagHasOption $cf "validate" "required") }}{{ if $required }}{{ $cf.FieldName }}: req.{{ $cf.FieldName }},{{ else if ne $cf.DefaultValue "" }}{{ $cf.FieldName }}: {{ $cf.DefaultValue }},{{ end }}
|
||||
{{ end }}{{ end }}
|
||||
}
|
||||
|
||||
{{ range $fk, $f := $createFields }}{{ $required := (FieldTagHasOption $f "validate" "required") }}{{ if not $required }}
|
||||
if req.{{ $f.FieldName }} != nil {
|
||||
{{ if eq $f.FieldType "sql.NullString" }}
|
||||
m.{{ $f.FieldName }} = sql.NullString{String: *req.{{ $f.FieldName }}, Valid: true}
|
||||
{{ else if eq $f.FieldType "*sql.NullString" }}
|
||||
m.{{ $f.FieldName }} = &sql.NullString{String: *req.{{ $f.FieldName }}, Valid: true}
|
||||
{{ else }}
|
||||
m.{{ $f.FieldName }} = *req.{{ $f.FieldName }}
|
||||
{{ end }}
|
||||
}
|
||||
{{ end }}{{ end }}
|
||||
|
||||
// Build the insert SQL statement.
|
||||
query := sqlbuilder.NewInsertBuilder()
|
||||
query.InsertInto({{ FormatCamelLower $.Model.Name }}TableName)
|
||||
query.Cols(
|
||||
{{ range $mk, $mf := $.Model.Fields }}{{ $cf := (index $createFields $mf.FieldName) }}{{ if or (eq $mf.FieldName $.Model.PrimaryField) ($cf) (eq $mf.FieldName "CreatedAt") (eq $mf.FieldName "UpdatedAt") }}"{{ $mf.ColumnName }}",
|
||||
{{ end }}{{ end }}
|
||||
)
|
||||
query.Values(
|
||||
{{ range $mk, $mf := $.Model.Fields }}{{ $cf := (index $createFields $mf.FieldName) }}{{ if or (eq $mf.FieldName $.Model.PrimaryField) ($cf) (eq $mf.FieldName "CreatedAt") (eq $mf.FieldName "UpdatedAt") }}m.{{ $mf.FieldName }},
|
||||
{{ end }}{{ end }}
|
||||
)
|
||||
|
||||
// Execute the query with the provided context.
|
||||
sql, args := query.Build()
|
||||
sql = dbConn.Rebind(sql)
|
||||
_, err = dbConn.ExecContext(ctx, sql, args...)
|
||||
if err != nil {
|
||||
err = errors.Wrapf(err, "query - %s", query.String())
|
||||
err = errors.WithMessage(err, "create {{ FormatCamelLowerTitle $.Model.Name }} failed")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &a, nil
|
||||
}
|
||||
{{ end }}
|
||||
{{ define "Update"}}
|
||||
{{ $reqName := (Concat $.Model.Name "UpdateRequest") }}
|
||||
{{ $updateFields := (index $.StructFields $reqName) }}
|
||||
// Update replaces an {{ FormatCamelLowerTitle $.Model.Name }} in the database.
|
||||
func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req {{ $reqName }}, now time.Time) error {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.{{ FormatCamelLowerUnderscore $.Model.Name }}.Update")
|
||||
defer span.Finish()
|
||||
|
||||
v := validator.New()
|
||||
|
||||
// Validate the request.
|
||||
err := v.Struct(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Ensure the claims can modify the {{ FormatCamelLowerTitle $.Model.Name }} specified in the request.
|
||||
err = CanModify{{ $.Model.Name }}(ctx, claims, dbConn, req.{{ $.Model.PrimaryField }})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If now empty set it to the current time.
|
||||
if now.IsZero() {
|
||||
now = time.Now()
|
||||
}
|
||||
|
||||
// Always store the time as UTC.
|
||||
now = now.UTC()
|
||||
|
||||
// Postgres truncates times to milliseconds when storing. We and do the same
|
||||
// here so the value we return is consistent with what we store.
|
||||
now = now.Truncate(time.Millisecond)
|
||||
|
||||
// Build the update SQL statement.
|
||||
query := sqlbuilder.NewUpdateBuilder()
|
||||
query.Update({{ FormatCamelLower $.Model.Name }}TableName)
|
||||
|
||||
var fields []string
|
||||
|
||||
{{ range $mk, $mf := $.Model.Fields }}{{ $uf := (index $updateFields $mf.FieldName) }}{{ if and ($uf.FieldName) (ne $uf.FieldName $.Model.PrimaryField) }}
|
||||
{{ $optional := (FieldTagHasOption $uf "validate" "omitempty") }}{{ $isUuid := (FieldTagHasOption $uf "validate" "uuid") }}
|
||||
if req.{{ $uf.FieldName }} != nil {
|
||||
{{ if and ($optional) ($isUuid) }}
|
||||
if *req.{{ $uf.FieldName }} != "" {
|
||||
fields = append(fields, query.Assign("{{ $uf.ColumnName }}", req.{{ $uf.FieldName }}))
|
||||
} else {
|
||||
fields = append(fields, query.Assign("{{ $uf.ColumnName }}", nil))
|
||||
}
|
||||
{{ else }}
|
||||
fields = append(fields, query.Assign("{{ $uf.ColumnName }}", req.{{ $uf.FieldName }}))
|
||||
{{ end }}
|
||||
}
|
||||
{{ end }}{{ end }}
|
||||
|
||||
// If there's nothing to update we can quit early.
|
||||
if len(fields) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
{{ $hasUpdatedAt := (StringListHasValue $.Model.ColumnNames "updated_at") }}{{ if $hasUpdatedAt }}
|
||||
// Append the updated_at field
|
||||
fields = append(fields, query.Assign("updated_at", now))
|
||||
{{ end }}
|
||||
|
||||
query.Set(fields...)
|
||||
query.Where(query.Equal("{{ $.Model.PrimaryColumn }}", req.{{ $.Model.PrimaryField }}))
|
||||
|
||||
// Execute the query with the provided context.
|
||||
sql, args := query.Build()
|
||||
sql = dbConn.Rebind(sql)
|
||||
_, err = dbConn.ExecContext(ctx, sql, args...)
|
||||
if err != nil {
|
||||
err = errors.Wrapf(err, "query - %s", query.String())
|
||||
err = errors.WithMessagef(err, "update {{ FormatCamelLowerTitle $.Model.Name }} %s failed", req.{{ $.Model.PrimaryField }})
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
{{ end }}
|
||||
{{ define "Archive"}}
|
||||
// Archive soft deleted the {{ FormatCamelLowerTitle $.Model.Name }} from the database.
|
||||
func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, {{ FormatCamelLower $.Model.PrimaryField }} {{ $.Model.PrimaryType }}, now time.Time) error {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.{{ FormatCamelLowerUnderscore $.Model.Name }}.Archive")
|
||||
defer span.Finish()
|
||||
|
||||
// Defines the struct to apply validation
|
||||
req := struct {
|
||||
{{ $.Model.PrimaryField }} {{ $.Model.PrimaryType }} `validate:"required,uuid"`
|
||||
}{
|
||||
{{ $.Model.PrimaryField }}: {{ FormatCamelLower $.Model.PrimaryField }},
|
||||
}
|
||||
|
||||
// Validate the request.
|
||||
err := validator.New().Struct(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Ensure the claims can modify the {{ FormatCamelLowerTitle $.Model.Name }} specified in the request.
|
||||
err = CanModify{{ $.Model.Name }}(ctx, claims, dbConn, req.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If now empty set it to the current time.
|
||||
if now.IsZero() {
|
||||
now = time.Now()
|
||||
}
|
||||
|
||||
// Always store the time as UTC.
|
||||
now = now.UTC()
|
||||
|
||||
// Postgres truncates times to milliseconds when storing. We and do the same
|
||||
// here so the value we return is consistent with what we store.
|
||||
now = now.Truncate(time.Millisecond)
|
||||
|
||||
// Build the update SQL statement.
|
||||
query := sqlbuilder.NewUpdateBuilder()
|
||||
query.Update({{ FormatCamelLower $.Model.Name }}TableName)
|
||||
query.Set(
|
||||
query.Assign("archived_at", now),
|
||||
)
|
||||
query.Where(query.Equal("{{ $.Model.PrimaryColumn }}", req.{{ $.Model.PrimaryField }}))
|
||||
|
||||
// Execute the query with the provided context.
|
||||
sql, args := query.Build()
|
||||
sql = dbConn.Rebind(sql)
|
||||
_, err = dbConn.ExecContext(ctx, sql, args...)
|
||||
if err != nil {
|
||||
err = errors.Wrapf(err, "query - %s", query.String())
|
||||
err = errors.WithMessagef(err, "archive {{ FormatCamelLowerTitle $.Model.Name }} %s failed", req.{{ $.Model.PrimaryField }})
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
{{ end }}
|
||||
{{ define "Delete"}}
|
||||
// Delete removes an {{ FormatCamelLowerTitle $.Model.Name }} from the database.
|
||||
func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, {{ FormatCamelLower $.Model.PrimaryField }} {{ $.Model.PrimaryType }}) error {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.{{ FormatCamelLowerUnderscore $.Model.Name }}.Delete")
|
||||
defer span.Finish()
|
||||
|
||||
// Defines the struct to apply validation
|
||||
req := struct {
|
||||
{{ $.Model.PrimaryField }} {{ $.Model.PrimaryType }} `validate:"required,uuid"`
|
||||
}{
|
||||
{{ $.Model.PrimaryField }}: {{ FormatCamelLower $.Model.PrimaryField }},
|
||||
}
|
||||
|
||||
// Validate the request.
|
||||
err := validator.New().Struct(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Ensure the claims can modify the {{ FormatCamelLowerTitle $.Model.Name }} specified in the request.
|
||||
err = CanModify{{ $.Model.Name }}(ctx, claims, dbConn, req.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Build the delete SQL statement.
|
||||
query := sqlbuilder.NewDeleteBuilder()
|
||||
query.DeleteFrom({{ FormatCamelLower $.Model.Name }}TableName)
|
||||
query.Where(query.Equal("{{ $.Model.PrimaryColumn }}", req.{{ $.Model.PrimaryField }}))
|
||||
|
||||
// Execute the query with the provided context.
|
||||
sql, args := query.Build()
|
||||
sql = dbConn.Rebind(sql)
|
||||
_, err = dbConn.ExecContext(ctx, sql, args...)
|
||||
if err != nil {
|
||||
err = errors.Wrapf(err, "query - %s", query.String())
|
||||
err = errors.WithMessagef(err, "delete {{ FormatCamelLowerTitle $.Model.Name }} %s failed", req.{{ $.Model.PrimaryField }})
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
{{ end }}
|
@ -0,0 +1,79 @@
|
||||
{{ define "CreateRequest"}}
|
||||
// {{ FormatCamel $.Model.Name }}CreateRequest contains information needed to create a new {{ FormatCamel $.Model.Name }}.
|
||||
type {{ FormatCamel $.Model.Name }}CreateRequest struct {
|
||||
{{ range $fk, $f := .Model.Fields }}{{ if and ($f.ApiCreate) (ne $f.FieldName $.Model.PrimaryField) }}{{ $optional := (FieldTagHasOption $f "validate" "omitempty") }}
|
||||
{{ $f.FieldName }} {{ if and ($optional) (not $f.FieldIsPtr) }}*{{ end }}{{ $f.FieldType }} `json:"{{ $f.ColumnName }}" {{ FieldTag $f "validate" }}`
|
||||
{{ end }}{{ end }}
|
||||
}
|
||||
{{ end }}
|
||||
{{ define "UpdateRequest"}}
|
||||
// {{ FormatCamel $.Model.Name }}UpdateRequest defines what information may be provided to modify an existing
|
||||
// {{ FormatCamel $.Model.Name }}. All fields are optional so clients can send just the fields they want
|
||||
// changed. It uses pointer fields so we can differentiate between a field that
|
||||
// was not provided and a field that was provided as explicitly blank.
|
||||
type {{ FormatCamel $.Model.Name }}UpdateRequest struct {
|
||||
{{ range $fk, $f := .Model.Fields }}{{ if $f.ApiUpdate }}
|
||||
{{ $f.FieldName }} {{ if and (ne $f.FieldName $.Model.PrimaryField) (not $f.FieldIsPtr) }}*{{ end }}{{ $f.FieldType }} `json:"{{ $f.ColumnName }}" {{ if ne $f.FieldName $.Model.PrimaryField }}{{ FieldTagReplaceOrPrepend $f "validate" "required" "omitempty" }}{{ else }}{{ FieldTagReplaceOrPrepend $f "validate" "omitempty" "required" }}{{ end }}`
|
||||
{{ end }}{{ end }}
|
||||
}
|
||||
{{ end }}
|
||||
{{ define "FindRequest"}}
|
||||
// {{ FormatCamel $.Model.Name }}FindRequest defines the possible options to search for {{ FormatCamelPluralTitleLower $.Model.Name }}. By default
|
||||
// archived {{ FormatCamelLowerTitle $.Model.Name }} will be excluded from response.
|
||||
type {{ FormatCamel $.Model.Name }}FindRequest struct {
|
||||
Where *string
|
||||
Args []interface{}
|
||||
Order []string
|
||||
Limit *uint
|
||||
Offset *uint
|
||||
{{ $hasArchived := (StringListHasValue $.Model.ColumnNames "archived_at") }}{{ if $hasArchived }}IncludedArchived bool{{ end }}
|
||||
}
|
||||
{{ end }}
|
||||
{{ define "Enums"}}
|
||||
{{ range $fk, $f := .Model.Fields }}{{ if $f.DbColumn }}{{ if $f.DbColumn.IsEnum }}
|
||||
// {{ $f.FieldType }} represents the {{ $f.ColumnName }} of {{ FormatCamelLowerTitle $.Model.Name }}.
|
||||
type {{ $f.FieldType }} string
|
||||
|
||||
// {{ $f.FieldType }} values define the {{ $f.ColumnName }} field of {{ FormatCamelLowerTitle $.Model.Name }}.
|
||||
const (
|
||||
{{ range $evk, $ev := $f.DbColumn.EnumValues }}
|
||||
// {{ $f.FieldType }}_{{ FormatCamel $ev }} defines the {{ $f.ColumnName }} of {{ $ev }} for {{ FormatCamelLowerTitle $.Model.Name }}.
|
||||
{{ $f.FieldType }}_{{ FormatCamel $ev }}{{ $f.FieldType }} = "{{ $ev }}"
|
||||
{{ end }}
|
||||
)
|
||||
|
||||
// {{ $f.FieldType }}_Values provides list of valid {{ $f.FieldType }} values.
|
||||
var {{ $f.FieldType }}_Values = []{{ $f.FieldType }}{
|
||||
{{ range $evk, $ev := $f.DbColumn.EnumValues }}
|
||||
{{ $f.FieldType }}_{{ FormatCamel $ev }},
|
||||
{{ end }}
|
||||
}
|
||||
|
||||
// Scan supports reading the {{ $f.FieldType }} value from the database.
|
||||
func (s *{{ $f.FieldType }}) Scan(value interface{}) error {
|
||||
asBytes, ok := value.([]byte)
|
||||
if !ok {
|
||||
return errors.New("Scan source is not []byte")
|
||||
}
|
||||
*s = {{ $f.FieldType }}(string(asBytes))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value converts the {{ $f.FieldType }} value to be stored in the database.
|
||||
func (s {{ $f.FieldType }}) Value() (driver.Value, error) {
|
||||
v := validator.New()
|
||||
|
||||
errs := v.Var(s, "required,oneof={{ JoinStrings $f.DbColumn.EnumValues " " }}")
|
||||
if errs != nil {
|
||||
return nil, errs
|
||||
}
|
||||
|
||||
return string(s), nil
|
||||
}
|
||||
|
||||
// String converts the {{ $f.FieldType }} value to a string.
|
||||
func (s {{ $f.FieldType }}) String() string {
|
||||
return string(s)
|
||||
}
|
||||
{{ end }}{{ end }}{{ end }}
|
||||
{{ end }}
|
Loading…
x
Reference in New Issue
Block a user