1
0
mirror of https://github.com/raseels-repos/golang-saas-starter-kit.git synced 2025-06-13 00:07:30 +02:00

Completed truss code gen for generating model requests and crud.

This commit is contained in:
Lee Brown 2019-06-24 01:30:18 -08:00
parent efaeeb7103
commit bdbe3c587a
25 changed files with 3554 additions and 351 deletions

View File

@ -120,6 +120,29 @@ If you have a lot of migrations, it can be a pain to run all them, as an example
Another bonus with the globally defined schema allows testing to spin up database containers on demand include all the migrations. The testing package enables unit tests to programmatically execute schema migrations before running any unit tests.
### Accessing Postgres
Login to the local postgres container
```bash
docker exec -it example-project_postgres_1 /bin/bash
bash-4.4# psql -u postgres shared
```
Show tables
```commandline
shared=# \dt
List of relations
Schema | Name | Type | Owner
--------+----------------+-------+----------
public | accounts | table | postgres
public | migrations | table | postgres
public | projects | table | postgres
public | users | table | postgres
public | users_accounts | table | postgres
(5 rows)
```
## Development Notes

View File

@ -3,12 +3,12 @@ package main
import (
"encoding/json"
"expvar"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/schema"
"github.com/lib/pq"
"log"
"net/url"
"os"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/schema"
"github.com/lib/pq"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/flag"
"github.com/kelseyhightower/envconfig"
_ "github.com/lib/pq"
@ -19,11 +19,16 @@ import (
// build is the git version of this program. It is set using build flags in the makefile.
var build = "develop"
// service is the name of the program used for logging, tracing and the
// the prefix used for loading env variables
// ie: export SCHEMA_ENV=dev
var service = "SCHEMA"
func main() {
// =========================================================================
// Logging
log := log.New(os.Stdout, "Schema : ", log.LstdFlags|log.Lmicroseconds|log.Lshortfile)
log := log.New(os.Stdout, service+" : ", log.LstdFlags|log.Lmicroseconds|log.Lshortfile)
// =========================================================================
// Configuration
@ -40,12 +45,8 @@ func main() {
}
}
// The prefix used for loading env variables.
// ie: export SCHEMA_ENV=dev
envKeyPrefix := "SCHEMA"
// For additional details refer to https://github.com/kelseyhightower/envconfig
if err := envconfig.Process(envKeyPrefix, &cfg); err != nil {
if err := envconfig.Process(service, &cfg); err != nil {
log.Fatalf("main : Parsing Config : %v", err)
}
@ -104,7 +105,7 @@ func main() {
// Register informs the sqlxtrace package of the driver that we will be using in our program.
// It uses a default service name, in the below case "postgres.db". To use a custom service
// name use RegisterWithServiceName.
sqltrace.Register(cfg.DB.Driver, &pq.Driver{}, sqltrace.WithServiceName("my-service"))
sqltrace.Register(cfg.DB.Driver, &pq.Driver{}, sqltrace.WithServiceName(service))
masterDb, err := sqlxtrace.Open(cfg.DB.Driver, dbUrl.String())
if err != nil {
log.Fatalf("main : Register DB : %s : %v", cfg.DB.Driver, err)

View File

@ -34,12 +34,17 @@ import (
// build is the git version of this program. It is set using build flags in the makefile.
var build = "develop"
// service is the name of the program used for logging, tracing and the
// the prefix used for loading env variables
// ie: export WEB_API_ENV=dev
var service = "WEB_API"
func main() {
// =========================================================================
// Logging
log := log.New(os.Stdout, "WEB_API : ", log.LstdFlags|log.Lmicroseconds|log.Lshortfile)
log := log.New(os.Stdout, service+" : ", log.LstdFlags|log.Lmicroseconds|log.Lshortfile)
// =========================================================================
// Configuration
@ -110,12 +115,8 @@ func main() {
}
}
// The prefix used for loading env variables.
// ie: export WEB_API_ENV=dev
envKeyPrefix := "WEB_API"
// For additional details refer to https://github.com/kelseyhightower/envconfig
if err := envconfig.Process(envKeyPrefix, &cfg); err != nil {
if err := envconfig.Process(service, &cfg); err != nil {
log.Fatalf("main : Parsing Config : %v", err)
}
@ -243,7 +244,7 @@ func main() {
// Register informs the sqlxtrace package of the driver that we will be using in our program.
// It uses a default service name, in the below case "postgres.db". To use a custom service
// name use RegisterWithServiceName.
sqltrace.Register(cfg.DB.Driver, &pq.Driver{}, sqltrace.WithServiceName("my-service"))
sqltrace.Register(cfg.DB.Driver, &pq.Driver{}, sqltrace.WithServiceName(service))
masterDb, err := sqlxtrace.Open(cfg.DB.Driver, dbUrl.String())
if err != nil {
log.Fatalf("main : Register DB : %s : %v", cfg.DB.Driver, err)

View File

@ -40,12 +40,17 @@ import (
// build is the git version of this program. It is set using build flags in the makefile.
var build = "develop"
// service is the name of the program used for logging, tracing and the
// the prefix used for loading env variables
// ie: export WEB_APP_ENV=dev
var service = "WEB_APP"
func main() {
// =========================================================================
// Logging
log := log.New(os.Stdout, "WEB_APP : ", log.LstdFlags|log.Lmicroseconds|log.Lshortfile)
log := log.New(os.Stdout, service+" : ", log.LstdFlags|log.Lmicroseconds|log.Lshortfile)
// =========================================================================
// Configuration
@ -125,12 +130,8 @@ func main() {
CMD string `envconfig:"CMD"`
}
// The prefix used for loading env variables.
// ie: export WEB_APP_ENV=dev
envKeyPrefix := "WEB_APP"
// For additional details refer to https://github.com/kelseyhightower/envconfig
if err := envconfig.Process(envKeyPrefix, &cfg); err != nil {
if err := envconfig.Process(service, &cfg); err != nil {
log.Fatalf("main : Parsing Config : %v", err)
}
@ -258,7 +259,7 @@ func main() {
// Register informs the sqlxtrace package of the driver that we will be using in our program.
// It uses a default service name, in the below case "postgres.db". To use a custom service
// name use RegisterWithServiceName.
sqltrace.Register(cfg.DB.Driver, &pq.Driver{}, sqltrace.WithServiceName("my-service"))
sqltrace.Register(cfg.DB.Driver, &pq.Driver{}, sqltrace.WithServiceName(service))
masterDb, err := sqlxtrace.Open(cfg.DB.Driver, dbUrl.String())
if err != nil {
log.Fatalf("main : Register DB : %s : %v", cfg.DB.Driver, err)
@ -522,7 +523,7 @@ func main() {
shutdown := make(chan os.Signal, 1)
signal.Notify(shutdown, os.Interrupt, syscall.SIGTERM)
api := http.Server{
app := http.Server{
Addr: cfg.HTTP.Host,
Handler: handlers.APP(shutdown, log, cfg.App.StaticDir, cfg.App.TemplateDir, masterDb, nil, renderer),
ReadTimeout: cfg.HTTP.ReadTimeout,
@ -537,7 +538,7 @@ func main() {
// Start the service listening for requests.
go func() {
log.Printf("main : APP Listening %s", cfg.HTTP.Host)
serverErrors <- api.ListenAndServe()
serverErrors <- app.ListenAndServe()
}()
// =========================================================================
@ -556,10 +557,10 @@ func main() {
defer cancel()
// Asking listener to shutdown and load shed.
err := api.Shutdown(ctx)
err := app.Shutdown(ctx)
if err != nil {
log.Printf("main : Graceful shutdown did not complete in %v : %v", cfg.App.ShutdownTimeout, err)
err = api.Close()
err = app.Close()
}
// Log the status of this shutdown.

View File

@ -5,6 +5,9 @@ require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgrijalva/jwt-go v3.2.0+incompatible
github.com/dimfeld/httptreemux v5.0.1+incompatible
github.com/dustin/go-humanize v1.0.0
github.com/fatih/camelcase v1.0.0
github.com/fatih/structtag v1.0.0
github.com/geeks-accelerator/sqlxmigrate v0.0.0-20190527223850-4a863a2d30db
github.com/go-playground/locales v0.12.1
github.com/go-playground/universal-translator v0.16.0
@ -12,6 +15,7 @@ require (
github.com/golang/protobuf v1.3.1 // indirect
github.com/google/go-cmp v0.2.0
github.com/huandu/go-sqlbuilder v1.4.0
github.com/iancoleman/strcase v0.0.0-20190422225806-e506e3ef7365
github.com/jmoiron/sqlx v1.2.0
github.com/kelseyhightower/envconfig v1.3.0
github.com/kr/pretty v0.1.0 // indirect
@ -19,13 +23,16 @@ require (
github.com/lib/pq v1.1.2-0.20190507191818-2ff3cb3adc01
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
github.com/onsi/ginkgo v1.8.0 // indirect
github.com/onsi/gomega v1.5.0 // indirect
github.com/onsi/gomega v1.5.0
github.com/opentracing/opentracing-go v1.1.0 // indirect
github.com/pborman/uuid v0.0.0-20180122190007-c65b2f87fee3
github.com/philhofer/fwd v1.0.0 // indirect
github.com/pkg/errors v0.8.1
github.com/sergi/go-diff v1.0.0
github.com/sethgrid/pester v0.0.0-20190127155807-68a33a018ad0
github.com/tinylib/msgp v1.1.0 // indirect
github.com/urfave/cli v1.20.0
github.com/uudashr/go-module v0.0.0-20180827225833-c0ca9c3a4966 // indirect
golang.org/x/crypto v0.0.0-20190513172903-22d7a77e9e5f
golang.org/x/net v0.0.0-20190522155817-f3200d17e092 // indirect
golang.org/x/sys v0.0.0-20190526052359-791d8a0f4d09 // indirect

View File

@ -8,6 +8,12 @@ github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumC
github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ=
github.com/dimfeld/httptreemux v5.0.1+incompatible h1:Qj3gVcDNoOthBAqftuD596rm4wg/adLLz5xh5CmpiCA=
github.com/dimfeld/httptreemux v5.0.1+incompatible/go.mod h1:rbUlSV+CCpv/SuqUTP/8Bk2O3LyUV436/yaRGkhP6Z0=
github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo=
github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
github.com/fatih/camelcase v1.0.0 h1:hxNvNX/xYBp0ovncs8WyWZrOrpBNub/JfaMvbURyft8=
github.com/fatih/camelcase v1.0.0/go.mod h1:yN2Sb0lFhZJUdVvtELVWefmrXpuZESvPmqwoZc+/fpc=
github.com/fatih/structtag v1.0.0 h1:pTHj65+u3RKWYPSGaU290FpI/dXxTaHdVwVwbcPKmEc=
github.com/fatih/structtag v1.0.0/go.mod h1:IKitwq45uXL/yqi5mYghiD3w9H6eTOvI9vnk8tXMphA=
github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I=
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/geeks-accelerator/sqlxmigrate v0.0.0-20190527223850-4a863a2d30db h1:mjErP7mTFHQ3cw/ibAkW3CvQ8gM4k19EkfzRzRINDAE=
@ -30,6 +36,8 @@ github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/huandu/go-sqlbuilder v1.4.0 h1:2LIlTDOz63lOETLOIiKBPEu4PUbikmS5LUc3EekwYqM=
github.com/huandu/go-sqlbuilder v1.4.0/go.mod h1:mYfGcZTUS6yJsahUQ3imkYSkGGT3A+owd54+79kkW+U=
github.com/iancoleman/strcase v0.0.0-20190422225806-e506e3ef7365 h1:ECW73yc9MY7935nNYXUkK7Dz17YuSUI9yqRqYS8aBww=
github.com/iancoleman/strcase v0.0.0-20190422225806-e506e3ef7365/go.mod h1:SK73tn/9oHe+/Y0h39VT4UCxmurVJkR5NA7kMEAOgSE=
github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af h1:pmfjZENx5imkbgOkpRUYLnmbU7UEFbjtDA2hxJ1ichM=
github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k=
github.com/jmoiron/sqlx v1.2.0 h1:41Ip0zITnmWNR/vHV+S4m+VoUivnWY5E4OJfLZjCJMA=
@ -69,6 +77,8 @@ github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/sergi/go-diff v1.0.0 h1:Kpca3qRNrduNnOQeazBd0ysaKrUJiIuISHxogkT9RPQ=
github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo=
github.com/sethgrid/pester v0.0.0-20190127155807-68a33a018ad0 h1:X9XMOYjxEfAYSy3xK1DzO5dMkkWhs9E9UCcS1IERx2k=
github.com/sethgrid/pester v0.0.0-20190127155807-68a33a018ad0/go.mod h1:Ad7IjTpvzZO8Fl0vh9AzQ+j/jYZfyp2diGwI8m5q+ns=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
@ -76,6 +86,10 @@ github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/tinylib/msgp v1.1.0 h1:9fQd+ICuRIu/ue4vxJZu6/LzxN0HwMds2nq/0cFvxHU=
github.com/tinylib/msgp v1.1.0/go.mod h1:+d+yLhGm8mzTaHzB+wgMYrodPfmZrzkirds8fDWklFE=
github.com/urfave/cli v1.20.0 h1:fDqGv3UG/4jbVl/QkFwEdddtEDjh/5Ov6X+0B/3bPaw=
github.com/urfave/cli v1.20.0/go.mod h1:70zkFmudgCuE/ngEzBv17Jvp/497gISqfk5gWijbERA=
github.com/uudashr/go-module v0.0.0-20180827225833-c0ca9c3a4966 h1:7dS/ZO0dIwrtj/FGTt9I6urVpx7LEHzucegv4ORYK3M=
github.com/uudashr/go-module v0.0.0-20180827225833-c0ca9c3a4966/go.mod h1:P6Nk1sQWL6jcdBIxnLVlqCsOl0arao7gg7sPoM6gx4A=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190513172903-22d7a77e9e5f h1:R423Cnkcp5JABoeemiGEPlt9tHXFfw5kvc0yqlxRPWo=
golang.org/x/crypto v0.0.0-20190513172903-22d7a77e9e5f/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=

View File

@ -1,43 +1,96 @@
package project
import (
"database/sql/driver"
"time"
"gopkg.in/mgo.v2/bson"
"github.com/lib/pq"
"github.com/pkg/errors"
"gopkg.in/go-playground/validator.v9"
)
// Project is an item we sell.
// Project represents a workflow.
type Project struct {
ID bson.ObjectId `bson:"_id" json:"id"` // Unique identifier.
Name string `bson:"name" json:"name"` // Display name of the project.
Cost int `bson:"cost" json:"cost"` // Price for one item in cents.
Quantity int `bson:"quantity" json:"quantity"` // Original number of items available.
DateCreated time.Time `bson:"date_created" json:"date_created"` // When the project was added.
DateModified time.Time `bson:"date_modified" json:"date_modified"` // When the project record was lost modified.
ID string `json:"id" validate:"required,uuid"`
AccountID string `json:"account_id" validate:"required,uuid" truss:"api-create"`
Name string `json:"name" validate:"required"`
Status ProjectStatus `json:"status" validate:"omitempty,oneof=active disabled"`
CreatedAt time.Time `json:"created_at" truss:"api-read"`
UpdatedAt time.Time `json:"updated_at" truss:"api-read"`
ArchivedAt pq.NullTime `json:"archived_at" truss:"api-hide"`
}
// NewProject is what we require from clients when adding a Project.
type NewProject struct {
Name string `json:"name" validate:"required"`
Cost int `json:"cost" validate:"required,gte=0"`
Quantity int `json:"quantity" validate:"required,gte=1"`
// CreateProjectRequest contains information needed to create a new Project.
type ProjectCreateRequest struct {
AccountID string `json:"account_id" validate:"required,uuid"`
Name string `json:"name" validate:"required"`
Status *ProjectStatus `json:"status" validate:"omitempty,oneof=active disabled"`
}
// UpdateProject defines what information may be provided to modify an
// existing Project. All fields are optional so clients can send just the
// fields they want changed. It uses pointer fields so we can differentiate
// between a field that was not provided and a field that was provided as
// explicitly blank. Normally we do not want to use pointers to basic types but
// we make exceptions around marshalling/unmarshalling.
type UpdateProject struct {
Name *string `json:"name"`
Cost *int `json:"cost" validate:"omitempty,gte=0"`
Quantity *int `json:"quantity" validate:"omitempty,gte=1"`
// UpdateProjectRequest defines what information may be provided to modify an existing
// Project. All fields are optional so clients can send just the fields they want
// changed. It uses pointer fields so we can differentiate between a field that
// was not provided and a field that was provided as explicitly blank. Normally
// we do not want to use pointers to basic types but we make exceptions around
// marshalling/unmarshalling.
type ProjectUpdateRequest struct {
ID string `validate:"required,uuid"`
Name *string `json:"name" validate:"omitempty"`
Status *ProjectStatus `json:"status" validate:"omitempty,oneof=active pending disabled"`
}
// Sale represents a transaction where we sold some quantity of a
// Project.
type Sale struct{}
// ProjectFindRequest defines the possible options to search for projects. By default
// archived projects will be excluded from response.
type ProjectFindRequest struct {
Where *string
Args []interface{}
Order []string
Limit *uint
Offset *uint
IncludedArchived bool
}
// NewSale defines what we require when creating a Sale record.
type NewSale struct{}
// ProjectStatus represents the status of an project.
type ProjectStatus string
// ProjectStatus values define the status field of a user project.
const (
// ProjectStatus_Active defines the state when a user can access an project.
ProjectStatus_Active ProjectStatus = "active"
// ProjectStatus_Disabled defines the state when a user has been disabled from
// accessing an project.
ProjectStatus_Disabled ProjectStatus = "disabled"
)
// ProjectStatus_Values provides list of valid ProjectStatus values.
var ProjectStatus_Values = []ProjectStatus{
ProjectStatus_Active,
ProjectStatus_Disabled,
}
// Scan supports reading the ProjectStatus value from the database.
func (s *ProjectStatus) Scan(value interface{}) error {
asBytes, ok := value.([]byte)
if !ok {
return errors.New("Scan source is not []byte")
}
*s = ProjectStatus(string(asBytes))
return nil
}
// Value converts the ProjectStatus value to be stored in the database.
func (s ProjectStatus) Value() (driver.Value, error) {
v := validator.New()
errs := v.Var(s, "required,oneof=active disabled")
if errs != nil {
return nil, errs
}
return string(s), nil
}
// String converts the ProjectStatus value to a string.
func (s ProjectStatus) String() string {
return string(s)
}

View File

@ -1,162 +0,0 @@
package project
import (
"context"
"fmt"
"time"
"github.com/jmoiron/sqlx"
"github.com/pkg/errors"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
"gopkg.in/mgo.v2"
"gopkg.in/mgo.v2/bson"
)
const projectsCollection = "projects"
var (
// ErrNotFound abstracts the mgo not found error.
ErrNotFound = errors.New("Entity not found")
// ErrInvalidID occurs when an ID is not in a valid form.
ErrInvalidID = errors.New("ID is not in its proper form")
)
// List retrieves a list of existing projects from the database.
func List(ctx context.Context, dbConn *sqlx.DB) ([]Project, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.List")
defer span.Finish()
p := []Project{}
f := func(collection *mgo.Collection) error {
return collection.Find(nil).All(&p)
}
if _, err := dbConn.ExecContext(ctx, projectsCollection, f); err != nil {
return nil, errors.Wrap(err, "db.projects.find()")
}
return p, nil
}
// Retrieve gets the specified project from the database.
func Retrieve(ctx context.Context, dbConn *sqlx.DB, id string) (*Project, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Retrieve")
defer span.Finish()
if !bson.IsObjectIdHex(id) {
return nil, ErrInvalidID
}
q := bson.M{"_id": bson.ObjectIdHex(id)}
var p *Project
f := func(collection *mgo.Collection) error {
return collection.Find(q).One(&p)
}
if _, err := dbConn.ExecContext(ctx, projectsCollection, f); err != nil {
if err == mgo.ErrNotFound {
return nil, ErrNotFound
}
return nil, errors.Wrap(err, fmt.Sprintf("db.projects.find(%s)", q))
}
return p, nil
}
// Create inserts a new project into the database.
func Create(ctx context.Context, dbConn *sqlx.DB, cp *NewProject, now time.Time) (*Project, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Create")
defer span.Finish()
// Mongo truncates times to milliseconds when storing. We and do the same
// here so the value we return is consistent with what we store.
now = now.Truncate(time.Millisecond)
p := Project{
ID: bson.NewObjectId(),
Name: cp.Name,
Cost: cp.Cost,
Quantity: cp.Quantity,
DateCreated: now,
DateModified: now,
}
f := func(collection *mgo.Collection) error {
return collection.Insert(&p)
}
if _, err := dbConn.ExecContext(ctx, projectsCollection, f); err != nil {
return nil, errors.Wrap(err, fmt.Sprintf("db.projects.insert(%v)", &p))
}
return &p, nil
}
// Update replaces a project document in the database.
func Update(ctx context.Context, dbConn *sqlx.DB, id string, upd UpdateProject, now time.Time) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Update")
defer span.Finish()
if !bson.IsObjectIdHex(id) {
return ErrInvalidID
}
fields := make(bson.M)
if upd.Name != nil {
fields["name"] = *upd.Name
}
if upd.Cost != nil {
fields["cost"] = *upd.Cost
}
if upd.Quantity != nil {
fields["quantity"] = *upd.Quantity
}
// If there's nothing to update we can quit early.
if len(fields) == 0 {
return nil
}
fields["date_modified"] = now
m := bson.M{"$set": fields}
q := bson.M{"_id": bson.ObjectIdHex(id)}
f := func(collection *mgo.Collection) error {
return collection.Update(q, m)
}
if _, err := dbConn.ExecContext(ctx, projectsCollection, f); err != nil {
if err == mgo.ErrNotFound {
return ErrNotFound
}
return errors.Wrap(err, fmt.Sprintf("db.customers.update(%s, %s)", q, m))
}
return nil
}
// Delete removes a project from the database.
func Delete(ctx context.Context, dbConn *sqlx.DB, id string) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Delete")
defer span.Finish()
if !bson.IsObjectIdHex(id) {
return ErrInvalidID
}
q := bson.M{"_id": bson.ObjectIdHex(id)}
f := func(collection *mgo.Collection) error {
return collection.Remove(q)
}
if _, err := dbConn.ExecContext(ctx, projectsCollection, f); err != nil {
if err == mgo.ErrNotFound {
return ErrNotFound
}
return errors.Wrap(err, fmt.Sprintf("db.projects.remove(%v)", q))
}
return nil
}

View File

@ -1,129 +0,0 @@
package project_test
import (
"os"
"testing"
"time"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/tests"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/project"
"github.com/google/go-cmp/cmp"
"github.com/pkg/errors"
)
var test *tests.Test
// TestMain is the entry point for testing.
func TestMain(m *testing.M) {
os.Exit(testMain(m))
}
func testMain(m *testing.M) int {
test = tests.New()
defer test.TearDown()
return m.Run()
}
// TestProject validates the full set of CRUD operations on Project values.
func TestProject(t *testing.T) {
defer tests.Recover(t)
t.Log("Given the need to work with Project records.")
{
t.Log("\tWhen handling a single Project.")
{
ctx := tests.Context()
dbConn := test.MasterDB.Copy()
defer dbConn.Close()
np := project.NewProject{
Name: "Comic Books",
Cost: 25,
Quantity: 60,
}
p, err := project.Create(ctx, dbConn, &np, time.Now().UTC())
if err != nil {
t.Fatalf("\t%s\tShould be able to create a project : %s.", tests.Failed, err)
}
t.Logf("\t%s\tShould be able to create a project.", tests.Success)
savedP, err := project.Retrieve(ctx, dbConn, p.ID.Hex())
if err != nil {
t.Fatalf("\t%s\tShould be able to retrieve project by ID: %s.", tests.Failed, err)
}
t.Logf("\t%s\tShould be able to retrieve project by ID.", tests.Success)
if diff := cmp.Diff(p, savedP); diff != "" {
t.Fatalf("\t%s\tShould get back the same project. Diff:\n%s", tests.Failed, diff)
}
t.Logf("\t%s\tShould get back the same project.", tests.Success)
upd := project.UpdateProject{
Name: tests.StringPointer("Comics"),
Cost: tests.IntPointer(50),
Quantity: tests.IntPointer(40),
}
if err := project.Update(ctx, dbConn, p.ID.Hex(), upd, time.Now().UTC()); err != nil {
t.Fatalf("\t%s\tShould be able to update project : %s.", tests.Failed, err)
}
t.Logf("\t%s\tShould be able to update project.", tests.Success)
savedP, err = project.Retrieve(ctx, dbConn, p.ID.Hex())
if err != nil {
t.Fatalf("\t%s\tShould be able to retrieve updated project : %s.", tests.Failed, err)
}
t.Logf("\t%s\tShould be able to retrieve updated project.", tests.Success)
// Build a project matching what we expect to see. We just use the
// modified time from the database.
want := &project.Project{
ID: p.ID,
Name: *upd.Name,
Cost: *upd.Cost,
Quantity: *upd.Quantity,
DateCreated: p.DateCreated,
DateModified: savedP.DateModified,
}
if diff := cmp.Diff(want, savedP); diff != "" {
t.Fatalf("\t%s\tShould get back the same project. Diff:\n%s", tests.Failed, diff)
}
t.Logf("\t%s\tShould get back the same project.", tests.Success)
upd = project.UpdateProject{
Name: tests.StringPointer("Graphic Novels"),
}
if err := project.Update(ctx, dbConn, p.ID.Hex(), upd, time.Now().UTC()); err != nil {
t.Fatalf("\t%s\tShould be able to update just some fields of project : %s.", tests.Failed, err)
}
t.Logf("\t%s\tShould be able to update just some fields of project.", tests.Success)
savedP, err = project.Retrieve(ctx, dbConn, p.ID.Hex())
if err != nil {
t.Fatalf("\t%s\tShould be able to retrieve updated project : %s.", tests.Failed, err)
}
t.Logf("\t%s\tShould be able to retrieve updated project.", tests.Success)
if savedP.Name != *upd.Name {
t.Fatalf("\t%s\tShould be able to see updated Name field : got %q want %q.", tests.Failed, savedP.Name, *upd.Name)
} else {
t.Logf("\t%s\tShould be able to see updated Name field.", tests.Success)
}
if err := project.Delete(ctx, dbConn, p.ID.Hex()); err != nil {
t.Fatalf("\t%s\tShould be able to delete project : %s.", tests.Failed, err)
}
t.Logf("\t%s\tShould be able to delete project.", tests.Success)
savedP, err = project.Retrieve(ctx, dbConn, p.ID.Hex())
if errors.Cause(err) != project.ErrNotFound {
t.Fatalf("\t%s\tShould NOT be able to retrieve deleted project : %s.", tests.Failed, err)
}
t.Logf("\t%s\tShould NOT be able to retrieve deleted project.", tests.Success)
}
}
}

View File

@ -65,8 +65,8 @@ func migrationList(db *sqlx.DB, log *log.Logger) []*sqlxmigrate.Migration {
zipcode varchar(20) NOT NULL DEFAULT '',
status account_status_t NOT NULL DEFAULT 'active',
timezone varchar(128) NOT NULL DEFAULT 'America/Anchorage',
signup_user_id char(36) DEFAULT NULL,
billing_user_id char(36) DEFAULT NULL,
signup_user_id char(36) DEFAULT NULL REFERENCES users(id),
billing_user_id char(36) DEFAULT NULL REFERENCES users(id),
created_at TIMESTAMP WITH TIME ZONE NOT NULL,
updated_at TIMESTAMP WITH TIME ZONE DEFAULT NULL,
archived_at TIMESTAMP WITH TIME ZONE DEFAULT NULL,
@ -107,8 +107,8 @@ func migrationList(db *sqlx.DB, log *log.Logger) []*sqlxmigrate.Migration {
q3 := `CREATE TABLE IF NOT EXISTS users_accounts (
id char(36) NOT NULL,
account_id char(36) NOT NULL,
user_id char(36) NOT NULL,
account_id char(36) NOT NULL REFERENCES accounts(id),
user_id char(36) NOT NULL REFERENCES users(id),
roles user_account_role_t[] NOT NULL,
status user_account_status_t NOT NULL DEFAULT 'active',
created_at TIMESTAMP WITH TIME ZONE NOT NULL,
@ -142,5 +142,42 @@ func migrationList(db *sqlx.DB, log *log.Logger) []*sqlxmigrate.Migration {
return nil
},
},
// create new table projects
{
ID: "20190622-01",
Migrate: func(tx *sql.Tx) error {
q1 := `CREATE TYPE project_status_t as enum('active','disabled')`
if _, err := tx.Exec(q1); err != nil {
return errors.WithMessagef(err, "Query failed %s", q1)
}
q2 := `CREATE TABLE IF NOT EXISTS projects (
id char(36) NOT NULL,
account_id char(36) NOT NULL REFERENCES accounts(id),
name varchar(255) NOT NULL,
status project_status_t NOT NULL DEFAULT 'active',
created_at TIMESTAMP WITH TIME ZONE NOT NULL,
updated_at TIMESTAMP WITH TIME ZONE DEFAULT NULL,
archived_at TIMESTAMP WITH TIME ZONE DEFAULT NULL,
PRIMARY KEY (id)
)`
if _, err := tx.Exec(q2); err != nil {
return errors.WithMessagef(err, "Query failed %s", q2)
}
return nil
},
Rollback: func(tx *sql.Tx) error {
q1 := `DROP TYPE project_status_t`
if _, err := tx.Exec(q1); err != nil {
return errors.WithMessagef(err, "Query failed %s", q1)
}
q2 := `DROP TABLE IF EXISTS projects`
if _, err := tx.Exec(q2); err != nil {
return errors.WithMessagef(err, "Query failed %s", q2)
}
return nil
},
},
}
}

View File

@ -0,0 +1 @@
truss

View File

@ -0,0 +1,33 @@
# SaaS Truss
Copyright 2019, Geeks Accelerator
accelerator@geeksinthewoods.com.com
## Description
Truss provides code generation to reduce copy/pasting.
## Local Installation
### Build
```bash
go build .
```
### Usage
```bash
./truss -h
Usage of ./truss
--cmd string <dbtable2crud>
--db_host string <127.0.0.1:5433>
--db_user string <postgres>
--db_pass string <postgres>
--db_database string <shared>
--db_driver string <postgres>
--db_timezone string <utc>
--db_disabletls bool <false>
```

View File

@ -0,0 +1,149 @@
package dbtable2crud
import (
"fmt"
"strings"
"github.com/jmoiron/sqlx"
"github.com/lib/pq"
"github.com/pkg/errors"
)
type psqlColumn struct {
Table string
Column string
ColumnId int64
NotNull bool
DataTypeFull string
DataTypeName string
DataTypeLength *int
NumericPrecision *int
NumericScale *int
IsPrimaryKey bool
PrimaryKeyName *string
IsUniqueKey bool
UniqueKeyName *string
IsForeignKey bool
ForeignKeyName *string
ForeignKeyColumnId pq.Int64Array
ForeignKeyTable *string
ForeignKeyLocalColumnId pq.Int64Array
DefaultFull *string
DefaultValue *string
IsEnum bool
EnumTypeId *string
EnumValues []string
}
// descTable lists all the columns for a table.
func descTable(db *sqlx.DB, dbName, dbTable string) ([]psqlColumn, error) {
queryStr := fmt.Sprintf(`SELECT
c.relname as table,
f.attname as column,
f.attnum as columnId,
f.attnotnull as not_null,
pg_catalog.format_type(f.atttypid,f.atttypmod) AS data_type_full,
t.typname AS data_type_name,
CASE WHEN f.atttypmod >= 0 AND t.typname <> 'numeric'THEN (f.atttypmod - 4) --first 4 bytes are for storing actual length of data
END AS data_type_length,
CASE WHEN t.typname = 'numeric' THEN (((f.atttypmod - 4) >> 16) & 65535)
END AS numeric_precision,
CASE WHEN t.typname = 'numeric' THEN ((f.atttypmod - 4)& 65535 )
END AS numeric_scale,
CASE WHEN p.contype = 'p' THEN true ELSE false
END AS is_primary_key,
CASE WHEN p.contype = 'p' THEN p.conname
END AS primary_key_name,
CASE WHEN p.contype = 'u' THEN true ELSE false
END AS is_unique_key,
CASE WHEN p.contype = 'u' THEN p.conname
END AS unique_key_name,
CASE WHEN p.contype = 'f' THEN true ELSE false
END AS is_foreign_key,
CASE WHEN p.contype = 'f' THEN p.conname
END AS foreignkey_name,
CASE WHEN p.contype = 'f' THEN p.confkey
END AS foreign_key_columnid,
CASE WHEN p.contype = 'f' THEN g.relname
END AS foreign_key_table,
CASE WHEN p.contype = 'f' THEN p.conkey
END AS foreign_key_local_column_id,
CASE WHEN f.atthasdef = 't' THEN d.adsrc
END AS default_value,
CASE WHEN t.typtype = 'e' THEN true ELSE false
END AS is_enum,
CASE WHEN t.typtype = 'e' THEN t.oid
END AS enum_type_id
FROM pg_attribute f
JOIN pg_class c ON c.oid = f.attrelid
JOIN pg_type t ON t.oid = f.atttypid
LEFT JOIN pg_attrdef d ON d.adrelid = c.oid AND d.adnum = f.attnum
LEFT JOIN pg_namespace n ON n.oid = c.relnamespace
LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey)
LEFT JOIN pg_class AS g ON p.confrelid = g.oid
WHERE c.relkind = 'r'::char
AND f.attisdropped = false
AND c.relname = '%s'
AND f.attnum > 0
ORDER BY f.attnum
;`, dbTable) // AND n.nspname = '%s'
rows, err := db.Query(queryStr)
if err != nil {
err = errors.Wrapf(err, "query - %s", queryStr)
return nil, err
}
// iterate over each row
var resp []psqlColumn
for rows.Next() {
var c psqlColumn
err = rows.Scan(&c.Table, &c.Column, &c.ColumnId, &c.NotNull, &c.DataTypeFull, &c.DataTypeName, &c.DataTypeLength, &c.NumericPrecision, &c.NumericScale, &c.IsPrimaryKey, &c.PrimaryKeyName, &c.IsUniqueKey, &c.UniqueKeyName, &c.IsForeignKey, &c.ForeignKeyName, &c.ForeignKeyColumnId, &c.ForeignKeyTable, &c.ForeignKeyLocalColumnId, &c.DefaultFull, &c.IsEnum, &c.EnumTypeId)
if err != nil {
err = errors.Wrapf(err, "query - %s", queryStr)
return nil, err
}
if c.DefaultFull != nil {
defaultValue := *c.DefaultFull
// "'active'::project_status_t"
defaultValue = strings.Split(defaultValue, "::")[0]
c.DefaultValue = &defaultValue
}
resp = append(resp, c)
}
for colIdx, dbCol := range resp {
if !dbCol.IsEnum {
continue
}
queryStr := fmt.Sprintf(`SELECT e.enumlabel
FROM pg_enum AS e
WHERE e.enumtypid = '%s'
ORDER BY e.enumsortorder`, *dbCol.EnumTypeId)
rows, err := db.Query(queryStr)
if err != nil {
err = errors.Wrapf(err, "query - %s", queryStr)
return nil, err
}
for rows.Next() {
var v string
err = rows.Scan(&v)
if err != nil {
err = errors.Wrapf(err, "query - %s", queryStr)
return nil, err
}
dbCol.EnumValues = append(dbCol.EnumValues, v)
}
resp[colIdx] = dbCol
}
return resp, nil
}

View File

@ -0,0 +1,378 @@
package dbtable2crud
import (
"fmt"
"log"
"os"
"path/filepath"
"strings"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/schema"
"geeks-accelerator/oss/saas-starter-kit/example-project/tools/truss/internal/goparse"
"github.com/dustin/go-humanize/english"
"github.com/fatih/camelcase"
"github.com/jmoiron/sqlx"
"github.com/pkg/errors"
"github.com/sergi/go-diff/diffmatchpatch"
)
// Run in the main entry point for the dbtable2crud cmd.
func Run(db *sqlx.DB, log *log.Logger, dbName, dbTable, modelFile, modelName, templateDir, goSrcPath string) error {
log.SetPrefix(log.Prefix() + " : dbtable2crud")
// Ensure the schema is up to date
if err := schema.Migrate(db, log); err != nil {
return err
}
// When dbTable is empty, lower case the model name
if dbTable == "" {
dbTable = strings.Join(camelcase.Split(modelName), " ")
dbTable = english.PluralWord(2, dbTable, "")
dbTable = strings.Replace(dbTable, " ", "_", -1)
dbTable = strings.ToLower(dbTable)
}
// Parse the model file and load the specified model struct.
model, err := parseModelFile(db, log, dbName, dbTable, modelFile, modelName)
if err != nil {
return err
}
// Basic lint of the model struct.
err = validateModel(log, model)
if err != nil {
return err
}
tmplData := map[string]interface{}{
"GoSrcPath": goSrcPath,
}
// Update the model file with new or updated code.
err = updateModel(log, model, templateDir, tmplData)
if err != nil {
return err
}
// Update the model crud file with new or updated code.
err = updateModelCrud(db, log, dbName, dbTable, modelFile, modelName, templateDir, model, tmplData)
if err != nil {
return err
}
return nil
}
// validateModel performs a basic lint of the model struct to ensure
// code gen output is correct.
func validateModel(log *log.Logger, model *modelDef) error {
for _, sf := range model.Fields {
if sf.DbColumn == nil && sf.ColumnName != "-" {
log.Printf("validateStruct : Unable to find struct field for db column %s\n", sf.ColumnName)
}
var expectedType string
switch sf.FieldName {
case "ID":
expectedType = "string"
case "CreatedAt":
expectedType = "time.Time"
case "UpdatedAt":
expectedType = "time.Time"
case "ArchivedAt":
expectedType = "pq.NullTime"
}
if expectedType != "" && expectedType != sf.FieldType {
log.Printf("validateStruct : Struct field %s should be of type %s not %s\n", sf.FieldName, expectedType, sf.FieldType)
}
}
return nil
}
// updateModel updated the parsed code file with the new code.
func updateModel(log *log.Logger, model *modelDef, templateDir string, tmplData map[string]interface{}) error {
// Execute template and parse code to be used to compare against modelFile.
tmplObjs, err := loadTemplateObjects(log, model, templateDir, "models.tmpl", tmplData)
if err != nil {
return err
}
// Store the current code as a string to produce a diff.
curCode := model.String()
objHeaders := []*goparse.GoObject{}
for _, obj := range tmplObjs {
if obj.Type == goparse.GoObjectType_Comment || obj.Type == goparse.GoObjectType_LineBreak {
objHeaders = append(objHeaders, obj)
continue
}
if model.HasType(obj.Name, obj.Type) {
cur := model.Objects().Get(obj.Name, obj.Type)
newObjs := []*goparse.GoObject{}
if len(objHeaders) > 0 {
// Remove any comments and linebreaks before the existing object so updates can be added.
removeObjs := []*goparse.GoObject{}
for idx := cur.Index - 1; idx > 0; idx-- {
prevObj := model.Objects().List()[idx]
if prevObj.Type == goparse.GoObjectType_Comment || prevObj.Type == goparse.GoObjectType_LineBreak {
removeObjs = append(removeObjs, prevObj)
} else {
break
}
}
if len(removeObjs) > 0 {
err := model.Objects().Remove(removeObjs...)
if err != nil {
err = errors.WithMessagef(err, "Failed to update object %s %s for %s", obj.Type, obj.Name, model.Name)
return err
}
// Make sure the current index is correct.
cur = model.Objects().Get(obj.Name, obj.Type)
}
// Append comments and line breaks before adding the object
for _, c := range objHeaders {
newObjs = append(newObjs, c)
}
}
newObjs = append(newObjs, obj)
// Do the object replacement.
err := model.Objects().Replace(cur, newObjs...)
if err != nil {
err = errors.WithMessagef(err, "Failed to update object %s %s for %s", obj.Type, obj.Name, model.Name)
return err
}
} else {
// Append comments and line breaks before adding the object
for _, c := range objHeaders {
err := model.Objects().Add(c)
if err != nil {
err = errors.WithMessagef(err, "Failed to add object %s %s for %s", c.Type, c.Name, model.Name)
return err
}
}
err := model.Objects().Add(obj)
if err != nil {
err = errors.WithMessagef(err, "Failed to add object %s %s for %s", obj.Type, obj.Name, model.Name)
return err
}
}
objHeaders = []*goparse.GoObject{}
}
// Set some flags to determine additional imports and need to be added.
var hasEnum bool
var hasPq bool
for _, f := range model.Fields {
if f.DbColumn != nil && f.DbColumn.IsEnum {
hasEnum = true
}
if strings.HasPrefix(strings.Trim(f.FieldType, "*"), "pq.") {
hasPq = true
}
}
reqImports := []string{}
if hasEnum {
reqImports = append(reqImports, "database/sql/driver")
reqImports = append(reqImports, "gopkg.in/go-playground/validator.v9")
reqImports = append(reqImports, "github.com/pkg/errors")
}
if hasPq {
reqImports = append(reqImports, "github.com/lib/pq")
}
for _, in := range reqImports {
err := model.AddImport(goparse.GoImport{Name: in})
if err != nil {
err = errors.WithMessagef(err, "Failed to add import %s for %s", in, model.Name)
return err
}
}
// Produce a diff after the updates have been applied.
dmp := diffmatchpatch.New()
diffs := dmp.DiffMain(curCode, model.String(), true)
fmt.Println(dmp.DiffPrettyText(diffs))
return nil
}
// updateModelCrud updated the parsed code file with the new code.
func updateModelCrud(db *sqlx.DB, log *log.Logger, dbName, dbTable, modelFile, modelName, templateDir string, baseModel *modelDef, tmplData map[string]interface{}) error {
modelDir := filepath.Dir(modelFile)
crudFile := filepath.Join(modelDir, FormatCamelLowerUnderscore(baseModel.Name)+".go")
var crudDoc *goparse.GoDocument
if _, err := os.Stat(crudFile); os.IsNotExist(err) {
crudDoc, err = goparse.NewGoDocument(baseModel.Package)
if err != nil {
return err
}
} else {
// Parse the supplied model file.
crudDoc, err = goparse.ParseFile(log, modelFile)
if err != nil {
return err
}
}
// Load all the updated struct fields from the base model file.
structFields := make(map[string]map[string]modelField)
for _, obj := range baseModel.GoDocument.Objects().List() {
if obj.Type != goparse.GoObjectType_Struct || obj.Name == baseModel.Name {
continue
}
objFields, err := parseModelFields(baseModel.GoDocument, obj.Name, baseModel)
if err != nil {
return err
}
structFields[obj.Name] = make(map[string]modelField)
for _, f := range objFields {
structFields[obj.Name][f.FieldName] = f
}
}
// Append the struct fields to be used for template execution.
if tmplData == nil {
tmplData = make(map[string]interface{})
}
tmplData["StructFields"] = structFields
// Execute template and parse code to be used to compare against modelFile.
tmplObjs, err := loadTemplateObjects(log, baseModel, templateDir, "model_crud.tmpl", tmplData)
if err != nil {
return err
}
// Store the current code as a string to produce a diff.
curCode := crudDoc.String()
objHeaders := []*goparse.GoObject{}
for _, obj := range tmplObjs {
if obj.Type == goparse.GoObjectType_Comment || obj.Type == goparse.GoObjectType_LineBreak {
objHeaders = append(objHeaders, obj)
continue
}
if crudDoc.HasType(obj.Name, obj.Type) {
cur := crudDoc.Objects().Get(obj.Name, obj.Type)
newObjs := []*goparse.GoObject{}
if len(objHeaders) > 0 {
// Remove any comments and linebreaks before the existing object so updates can be added.
removeObjs := []*goparse.GoObject{}
for idx := cur.Index - 1; idx > 0; idx-- {
prevObj := crudDoc.Objects().List()[idx]
if prevObj.Type == goparse.GoObjectType_Comment || prevObj.Type == goparse.GoObjectType_LineBreak {
removeObjs = append(removeObjs, prevObj)
} else {
break
}
}
if len(removeObjs) > 0 {
err := crudDoc.Objects().Remove(removeObjs...)
if err != nil {
err = errors.WithMessagef(err, "Failed to update object %s %s for %s", obj.Type, obj.Name, baseModel.Name)
return err
}
// Make sure the current index is correct.
cur = crudDoc.Objects().Get(obj.Name, obj.Type)
}
// Append comments and line breaks before adding the object
for _, c := range objHeaders {
newObjs = append(newObjs, c)
}
}
newObjs = append(newObjs, obj)
// Do the object replacement.
err := crudDoc.Objects().Replace(cur, newObjs...)
if err != nil {
err = errors.WithMessagef(err, "Failed to update object %s %s for %s", obj.Type, obj.Name, baseModel.Name)
return err
}
} else {
// Append comments and line breaks before adding the object
for _, c := range objHeaders {
err := crudDoc.Objects().Add(c)
if err != nil {
err = errors.WithMessagef(err, "Failed to add object %s %s for %s", c.Type, c.Name, baseModel.Name)
return err
}
}
err := crudDoc.Objects().Add(obj)
if err != nil {
err = errors.WithMessagef(err, "Failed to add object %s %s for %s", obj.Type, obj.Name, baseModel.Name)
return err
}
}
objHeaders = []*goparse.GoObject{}
}
/*
// Set some flags to determine additional imports and need to be added.
var hasEnum bool
var hasPq bool
for _, f := range crudModel.Fields {
if f.DbColumn != nil && f.DbColumn.IsEnum {
hasEnum = true
}
if strings.HasPrefix(strings.Trim(f.FieldType, "*"), "pq.") {
hasPq = true
}
}
reqImports := []string{}
if hasEnum {
reqImports = append(reqImports, "database/sql/driver")
reqImports = append(reqImports, "gopkg.in/go-playground/validator.v9")
reqImports = append(reqImports, "github.com/pkg/errors")
}
if hasPq {
reqImports = append(reqImports, "github.com/lib/pq")
}
for _, in := range reqImports {
err := model.AddImport(goparse.GoImport{Name: in})
if err != nil {
err = errors.WithMessagef(err, "Failed to add import %s for %s", in, crudModel.Name)
return err
}
}
*/
// Produce a diff after the updates have been applied.
dmp := diffmatchpatch.New()
diffs := dmp.DiffMain(curCode, crudDoc.String(), true)
fmt.Println(dmp.DiffPrettyText(diffs))
return nil
}

View File

@ -0,0 +1,229 @@
package dbtable2crud
import (
"log"
"strings"
"geeks-accelerator/oss/saas-starter-kit/example-project/tools/truss/internal/goparse"
"github.com/fatih/structtag"
"github.com/jmoiron/sqlx"
"github.com/pkg/errors"
)
// modelDef defines info about the struct and associated db table.
type modelDef struct {
*goparse.GoDocument
Name string
TableName string
PrimaryField string
PrimaryColumn string
PrimaryType string
Fields []modelField
FieldNames []string
ColumnNames []string
}
// modelField defines a struct field and associated db column.
type modelField struct {
ColumnName string
DbColumn *psqlColumn
FieldName string
FieldType string
FieldIsPtr bool
Tags *structtag.Tags
ApiHide bool
ApiRead bool
ApiCreate bool
ApiUpdate bool
DefaultValue string
}
// parseModelFile parses the entire model file and then loads the specified model struct.
func parseModelFile(db *sqlx.DB, log *log.Logger, dbName, dbTable, modelFile, modelName string) (*modelDef, error) {
// Parse the supplied model file.
doc, err := goparse.ParseFile(log, modelFile)
if err != nil {
return nil, err
}
// Init new modelDef.
model := &modelDef{
GoDocument: doc,
Name: modelName,
TableName: dbTable,
}
// Append the field the the model def.
model.Fields, err = parseModelFields(doc, modelName, nil)
if err != nil {
return nil, err
}
for _, sf := range model.Fields {
model.FieldNames = append(model.FieldNames, sf.FieldName)
model.ColumnNames = append(model.ColumnNames, sf.ColumnName)
}
// Query the database for a table definition.
dbCols, err := descTable(db, dbName, dbTable)
if err != nil {
return model, err
}
// Loop over all the database table columns and append to the associated
// struct field. Don't force all database table columns to be defined in the
// in the struct.
for _, dbCol := range dbCols {
for idx, sf := range model.Fields {
if sf.ColumnName != dbCol.Column {
continue
}
if dbCol.IsPrimaryKey {
model.PrimaryColumn = sf.ColumnName
model.PrimaryField = sf.FieldName
model.PrimaryType = sf.FieldType
}
if dbCol.DefaultValue != nil {
sf.DefaultValue = *dbCol.DefaultValue
if dbCol.IsEnum {
sf.DefaultValue = strings.Trim(sf.DefaultValue, "'")
sf.DefaultValue = sf.FieldType + "_" + FormatCamel(sf.DefaultValue)
} else if strings.HasPrefix(sf.DefaultValue, "'") {
sf.DefaultValue = strings.Trim(sf.DefaultValue, "'")
sf.DefaultValue = "\"" + sf.DefaultValue + "\""
}
}
c := dbCol
sf.DbColumn = &c
model.Fields[idx] = sf
}
}
// Print out the model for debugging.
//modelJSON, err := json.MarshalIndent(model, "", " ")
//if err != nil {
// return model, errors.WithStack(err )
//}
//log.Printf(string(modelJSON))
return model, nil
}
// parseModelFields parses the fields from a struct.
func parseModelFields(doc *goparse.GoDocument, modelName string, baseModel *modelDef) ([]modelField, error) {
// Ensure the model file has a struct with the model name supplied.
if !doc.HasType(modelName, goparse.GoObjectType_Struct) {
err := errors.Errorf("Struct with the name %s does not exist", modelName)
return nil, err
}
// Load the struct from parsed go file.
docModel := doc.Get(modelName, goparse.GoObjectType_Struct)
// Loop over all the objects contained between the struct definition start and end.
// This should be a list of variables defined for model.
resp := []modelField{}
for _, l := range docModel.Objects().List() {
// Skip all lines that are not a var.
if l.Type != goparse.GoObjectType_Line {
log.Printf("parseModelFile : Model %s has line that is %s, not type line, skipping - %s\n", modelName, l.Type, l.String())
continue
}
// Extract the var name, type and defined tags from the line.
sv, err := goparse.ParseStructProp(l)
if err != nil {
return nil, err
}
// Init new modelField for the struct var.
sf := modelField{
FieldName: sv.Name,
FieldType: sv.Type,
FieldIsPtr: strings.HasPrefix(sv.Type, "*"),
Tags: sv.Tags,
}
// Extract the column name from the var tags.
if sf.Tags != nil {
// First try to get the column name from the db tag.
dbt, err := sf.Tags.Get("db")
if err != nil && !strings.Contains(err.Error(), "not exist") {
err = errors.WithStack(err)
return nil, err
} else if dbt != nil {
sf.ColumnName = dbt.Name
}
// Second try to get the column name from the json tag.
if sf.ColumnName == "" {
jt, err := sf.Tags.Get("json")
if err != nil && !strings.Contains(err.Error(), "not exist") {
err = errors.WithStack(err)
return nil, err
} else if jt != nil && jt.Name != "-" {
sf.ColumnName = jt.Name
}
}
var apiActionsSet bool
tt, err := sf.Tags.Get("truss")
if err != nil && !strings.Contains(err.Error(), "not exist") {
err = errors.WithStack(err)
return nil, err
} else if tt != nil {
if tt.Name == "api-create" || tt.HasOption("api-create") {
sf.ApiCreate = true
apiActionsSet = true
}
if tt.Name == "api-read" || tt.HasOption("api-read") {
sf.ApiRead = true
apiActionsSet = true
}
if tt.Name == "api-update" || tt.HasOption("api-update") {
sf.ApiUpdate = true
apiActionsSet = true
}
if tt.Name == "api-hide" || tt.HasOption("api-hide") {
sf.ApiHide = true
apiActionsSet = true
}
}
if !apiActionsSet {
sf.ApiCreate = true
sf.ApiRead = true
sf.ApiUpdate = true
}
}
// Set the column name to the field name if empty and does not equal '-'.
if sf.ColumnName == "" {
sf.ColumnName = sf.FieldName
}
// If a base model as already been parsed with the db columns,
// append to the current field.
if baseModel != nil {
for _, baseSf := range baseModel.Fields {
if baseSf.ColumnName == sf.ColumnName {
sf.DefaultValue = baseSf.DefaultValue
sf.DbColumn = baseSf.DbColumn
break
}
}
}
// Append the field the the model def.
resp = append(resp, sf)
}
return resp, nil
}

View File

@ -0,0 +1,345 @@
package dbtable2crud
import (
"bufio"
"bytes"
"fmt"
"go/format"
"io/ioutil"
"log"
"os"
"path/filepath"
"sort"
"strings"
"text/template"
"geeks-accelerator/oss/saas-starter-kit/example-project/tools/truss/internal/goparse"
"github.com/dustin/go-humanize/english"
"github.com/fatih/camelcase"
"github.com/iancoleman/strcase"
"github.com/pkg/errors"
)
// loadTemplateObjects executes a template file based on the given model struct and
// returns the parsed go objects.
func loadTemplateObjects(log *log.Logger, model *modelDef, templateDir, filename string, tmptData map[string]interface{}) ([]*goparse.GoObject, error) {
// Data used to execute all the of defined code sections in the template file.
if tmptData == nil {
tmptData = make(map[string]interface{})
}
tmptData["Model"] = model
// geeks-accelerator/oss/saas-starter-kit/example-project
// Read the template file from the local file system.
tempFilePath := filepath.Join(templateDir, filename)
dat, err := ioutil.ReadFile(tempFilePath)
if err != nil {
err = errors.WithMessagef(err, "Failed to read template file %s", tempFilePath)
return nil, err
}
// New template with custom functions.
baseTmpl := template.New("base")
baseTmpl.Funcs(template.FuncMap{
"Concat": func(vals ...string) string {
return strings.Join(vals, "")
},
"JoinStrings": func(vals []string, sep string) string {
return strings.Join(vals, sep)
},
"PrefixAndJoinStrings": func(vals []string, pre, sep string) string {
l := []string{}
for _, v := range vals {
l = append(l, pre+v)
}
return strings.Join(l, sep)
},
"FmtAndJoinStrings": func(vals []string, fmtStr, sep string) string {
l := []string{}
for _, v := range vals {
l = append(l, fmt.Sprintf(fmtStr, v))
}
return strings.Join(l, sep)
},
"FormatCamel": func(name string) string {
return FormatCamel(name)
},
"FormatCamelTitle": func(name string) string {
return FormatCamelTitle(name)
},
"FormatCamelLower": func(name string) string {
if name == "ID" {
return "id"
}
return FormatCamelLower(name)
},
"FormatCamelLowerTitle": func(name string) string {
return FormatCamelLowerTitle(name)
},
"FormatCamelPluralTitle": func(name string) string {
return FormatCamelPluralTitle(name)
},
"FormatCamelPluralTitleLower": func(name string) string {
return FormatCamelPluralTitleLower(name)
},
"FormatCamelPluralCamel": func(name string) string {
return FormatCamelPluralCamel(name)
},
"FormatCamelPluralLower": func(name string) string {
return FormatCamelPluralLower(name)
},
"FormatCamelPluralUnderscore": func(name string) string {
return FormatCamelPluralUnderscore(name)
},
"FormatCamelPluralLowerUnderscore": func(name string) string {
return FormatCamelPluralLowerUnderscore(name)
},
"FormatCamelUnderscore": func(name string) string {
return FormatCamelUnderscore(name)
},
"FormatCamelLowerUnderscore": func(name string) string {
return FormatCamelLowerUnderscore(name)
},
"FieldTagHasOption": func(f modelField, tagName, optName string) bool {
if f.Tags == nil {
return false
}
ft, err := f.Tags.Get(tagName)
if ft == nil || err != nil {
return false
}
if ft.Name == optName || ft.HasOption(optName) {
return true
}
return false
},
"FieldTag": func(f modelField, tagName string) string {
if f.Tags == nil {
return ""
}
ft, err := f.Tags.Get(tagName)
if ft == nil || err != nil {
return ""
}
return ft.String()
},
"FieldTagReplaceOrPrepend": func(f modelField, tagName, oldVal, newVal string) string {
if f.Tags == nil {
return ""
}
ft, err := f.Tags.Get(tagName)
if ft == nil || err != nil {
return ""
}
if ft.Name == oldVal || ft.Name == newVal {
ft.Name = newVal
} else if ft.HasOption(oldVal) {
for idx, val := range ft.Options {
if val == oldVal {
ft.Options[idx] = newVal
}
}
} else if !ft.HasOption(newVal) {
if ft.Name == "" {
ft.Name = newVal
} else {
ft.Options = append(ft.Options, newVal)
}
}
return ft.String()
},
"StringListHasValue": func(list []string, val string) bool {
for _, v := range list {
if v == val {
return true
}
}
return false
},
})
// Load the template file using the text/template package.
tmpl, err := baseTmpl.Parse(string(dat))
if err != nil {
err = errors.WithMessagef(err, "Failed to parse template file %s", tempFilePath)
log.Printf("loadTemplateObjects : %v\n%v", err, string(dat))
return nil, err
}
// Generate a list of template names defined in the template file.
tmplNames := []string{}
for _, defTmpl := range tmpl.Templates() {
tmplNames = append(tmplNames, defTmpl.Name())
}
// Stupid hack to return template names the in order they are defined in the file.
tmplNames, err = templateFileOrderedNames(tempFilePath, tmplNames)
if err != nil {
return nil, err
}
// Loop over all the defined templates, execute using the defined data, parse the
// formatted code and append the parsed go objects to the result list.
var resp []*goparse.GoObject
for _, tmplName := range tmplNames {
// Executed the defined template with the given data.
var tpl bytes.Buffer
if err := tmpl.Lookup(tmplName).Execute(&tpl, tmptData); err != nil {
err = errors.WithMessagef(err, "Failed to execute %s from template file %s", tmplName, tempFilePath)
return resp, err
}
// Format the source code to ensure its valid and code to parsed consistently.
codeBytes, err := format.Source(tpl.Bytes())
if err != nil {
err = errors.WithMessagef(err, "Failed to format source for %s in template file %s", tmplName, filename)
dl := []string{}
for idx, l := range strings.Split(tpl.String(), "\n") {
dl = append(dl, fmt.Sprintf("%d -> ", idx)+l)
}
log.Printf("loadTemplateObjects : %v\n%v", err, strings.Join(dl, "\n"))
return resp, err
}
// Remove extra white space from the code.
codeStr := strings.TrimSpace(string(codeBytes))
// Split the code into a list of strings.
codeLines := strings.Split(codeStr, "\n")
// Parse the code lines into a set of objects.
objs, err := goparse.ParseLines(codeLines, 0)
if err != nil {
err = errors.WithMessagef(err, "Failed to parse %s in template file %s", tmplName, filename)
log.Printf("loadTemplateObjects : %v\n%v", err, codeStr)
return resp, err
}
// Append the parsed objects to the return result list.
for _, obj := range objs.List() {
if obj.Name == "" && obj.Type != goparse.GoObjectType_Import && obj.Type != goparse.GoObjectType_Var && obj.Type != goparse.GoObjectType_Const && obj.Type != goparse.GoObjectType_Comment && obj.Type != goparse.GoObjectType_LineBreak {
// All objects should have a name except for multiline var/const declarations and comments.
err = errors.Errorf("Failed to parse name with type %s from lines: %v", obj.Type, obj.Lines())
return resp, err
} else if string(obj.Type) == "" {
err = errors.Errorf("Failed to parse type for %s from lines: %v", obj.Name, obj.Lines())
return resp, err
}
resp = append(resp, obj)
}
}
return resp, nil
}
// FormatCamel formats Valdez mountain to ValdezMountain
func FormatCamel(name string) string {
return strcase.ToCamel(name)
}
// FormatCamelLower formats ValdezMountain to valdezmountain
func FormatCamelLower(name string) string {
return strcase.ToLowerCamel(FormatCamel(name))
}
// FormatCamelTitle formats ValdezMountain to Valdez Mountain
func FormatCamelTitle(name string) string {
return strings.Join(camelcase.Split(name), " ")
}
// FormatCamelLowerTitle formats ValdezMountain to valdez mountain
func FormatCamelLowerTitle(name string) string {
return strings.ToLower(FormatCamelTitle(name))
}
// FormatCamelPluralTitle formats ValdezMountain to Valdez Mountains
func FormatCamelPluralTitle(name string) string {
pts := camelcase.Split(name)
lastIdx := len(pts) - 1
pts[lastIdx] = english.PluralWord(2, pts[lastIdx], "")
return strings.Join(pts, " ")
}
// FormatCamelPluralTitleLower formats ValdezMountain to valdez mountains
func FormatCamelPluralTitleLower(name string) string {
return strings.ToLower(FormatCamelPluralTitle(name))
}
// FormatCamelPluralCamel formats ValdezMountain to ValdezMountains
func FormatCamelPluralCamel(name string) string {
return strcase.ToCamel(FormatCamelPluralTitle(name))
}
// FormatCamelPluralLower formats ValdezMountain to valdezmountains
func FormatCamelPluralLower(name string) string {
return strcase.ToLowerCamel(FormatCamelPluralTitle(name))
}
// FormatCamelPluralUnderscore formats ValdezMountain to Valdez_Mountains
func FormatCamelPluralUnderscore(name string) string {
return strings.Replace(FormatCamelPluralTitle(name), " ", "_", -1)
}
// FormatCamelPluralLowerUnderscore formats ValdezMountain to valdez_mountains
func FormatCamelPluralLowerUnderscore(name string) string {
return strings.ToLower(FormatCamelPluralUnderscore(name))
}
// FormatCamelUnderscore formats ValdezMountain to Valdez_Mountain
func FormatCamelUnderscore(name string) string {
return strings.Replace(FormatCamelTitle(name), " ", "_", -1)
}
// FormatCamelLowerUnderscore formats ValdezMountain to valdez_mountain
func FormatCamelLowerUnderscore(name string) string {
return strings.ToLower(FormatCamelUnderscore(name))
}
// templateFileOrderedNames returns the template names the in order they are defined in the file.
func templateFileOrderedNames(localPath string, names []string) (resp []string, err error) {
file, err := os.Open(localPath)
if err != nil {
return resp, errors.WithStack(err)
}
defer file.Close()
idxList := []int{}
idxNames := make(map[int]string)
idx := 0
scanner := bufio.NewScanner(file)
for scanner.Scan() {
if !strings.HasPrefix(scanner.Text(), "{{") || !strings.Contains(scanner.Text(), "define ") {
continue
}
for _, name := range names {
if strings.Contains(scanner.Text(), "\""+name+"\"") {
idxList = append(idxList, idx)
idxNames[idx] = name
break
}
}
idx = idx + 1
}
if err := scanner.Err(); err != nil {
return resp, errors.WithStack(err)
}
sort.Ints(idxList)
for _, idx := range idxList {
resp = append(resp, idxNames[idx])
}
return resp, nil
}

View File

@ -0,0 +1,301 @@
package goparse
import (
"fmt"
"go/format"
"io/ioutil"
"strings"
"github.com/pkg/errors"
)
// GoDocument defines a single go code file.
type GoDocument struct {
*GoObjects
Package string
imports GoImports
}
// GoImport defines a single import line with optional alias.
type GoImport struct {
Name string
Alias string
}
// GoImports holds a list of import lines.
type GoImports []GoImport
// NewGoDocument creates a new GoDocument with the package line set.
func NewGoDocument(packageName string) (doc *GoDocument, err error) {
doc = &GoDocument{
GoObjects: &GoObjects{
list: []*GoObject{},
},
}
err = doc.SetPackage(packageName)
return doc, err
}
// Objects returns a list of root GoObject.
func (doc *GoDocument) Objects() *GoObjects {
if doc.GoObjects == nil {
doc.GoObjects = &GoObjects{
list: []*GoObject{},
}
}
return doc.GoObjects
}
// NewObjectPackage returns a new GoObject with a single package definition line.
func NewObjectPackage(packageName string) *GoObject {
lines := []string{
fmt.Sprintf("package %s", packageName),
"",
}
obj, _ := ParseGoObject(lines, 0)
return obj
}
// SetPackage appends sets the package line for the code file.
func (doc *GoDocument) SetPackage(packageName string) error {
var existing *GoObject
for _, obj := range doc.Objects().List() {
if obj.Type == GoObjectType_Package {
existing = obj
break
}
}
new := NewObjectPackage(packageName)
var err error
if existing != nil {
err = doc.Objects().Replace(existing, new)
} else if len(doc.Objects().List()) > 0 {
// Insert after any existing comments or line breaks.
var insertPos int
//for idx, obj := range doc.Objects().List() {
// switch obj.Type {
// case GoObjectType_Comment, GoObjectType_LineBreak:
// insertPos = idx
// default:
// break
// }
//}
err = doc.Objects().Insert(insertPos, new)
} else {
err = doc.Objects().Add(new)
}
return err
}
// AddObject appends a new GoObject to the doc root object list.
func (doc *GoDocument) AddObject(newObj *GoObject) error {
return doc.Objects().Add(newObj)
}
// InsertObject inserts a new GoObject at the desired position to the doc root object list.
func (doc *GoDocument) InsertObject(pos int, newObj *GoObject) error {
return doc.Objects().Insert(pos, newObj)
}
// Imports returns the GoDocument imports.
func (doc *GoDocument) Imports() (GoImports, error) {
// If the doc imports are empty, try to load them from the root objects.
if len(doc.imports) == 0 {
for _, obj := range doc.Objects().List() {
if obj.Type != GoObjectType_Import {
continue
}
res, err := ParseImportObject(obj)
if err != nil {
return doc.imports, err
}
// Combine all the imports into a single definition.
for _, n := range res {
doc.imports = append(doc.imports, n)
}
}
}
return doc.imports, nil
}
// Lines returns all the code lines.
func (doc *GoDocument) Lines() []string {
l := []string{}
for _, ol := range doc.Objects().Lines() {
l = append(l, ol)
}
return l
}
// String returns a single value for all the code lines.
func (doc *GoDocument) String() string {
return strings.Join(doc.Lines(), "\n")
}
// Print writes all the code lines to standard out.
func (doc *GoDocument) Print() {
for _, l := range doc.Lines() {
fmt.Println(l)
}
}
// Save renders all the code lines for the document, formats the code
// and then saves it to the supplied file path.
func (doc *GoDocument) Save(localpath string) error {
res, err := format.Source([]byte(doc.String()))
if err != nil {
err = errors.WithMessage(err, "Failed formatted source code")
return err
}
err = ioutil.WriteFile(localpath, res, 0644)
if err != nil {
err = errors.WithMessagef(err, "Failed write source code to file %s", localpath)
return err
}
return nil
}
// AddImport checks for any duplicate imports by name and adds it if not.
func (doc *GoDocument) AddImport(impt GoImport) error {
impt.Name = strings.Trim(impt.Name, "\"")
// Get a list of current imports for the document.
impts, err := doc.Imports()
if err != nil {
return err
}
// If the document has as the import, don't add it.
if impts.Has(impt.Name) {
return nil
}
// Loop through all the document root objects for an object of type import.
// If one exists, append the import to the existing list.
for _, obj := range doc.Objects().List() {
if obj.Type != GoObjectType_Import || len(obj.Lines()) == 1 {
continue
}
obj.subLines = append(obj.subLines, impt.String())
obj.goObjects.list = append(obj.goObjects.list, impt.Object())
doc.imports = append(doc.imports, impt)
return nil
}
// Document does not have an existing import object, so need to create one and
// then append to the document.
newObj := NewObjectImports(impt)
// Insert after any package, any existing comments or line breaks should be included.
var insertPos int
for idx, obj := range doc.Objects().List() {
switch obj.Type {
case GoObjectType_Package, GoObjectType_Comment, GoObjectType_LineBreak:
insertPos = idx
default:
break
}
}
// Insert the new import object.
err = doc.InsertObject(insertPos, newObj)
if err != nil {
return err
}
return nil
}
// NewObjectImports returns a new GoObject with a single import definition.
func NewObjectImports(impt GoImport) *GoObject {
lines := []string{
"import (",
impt.String(),
")",
"",
}
obj, _ := ParseGoObject(lines, 0)
children, err := ParseLines(obj.subLines, 1)
if err != nil {
return nil
}
for _, child := range children.List() {
obj.Objects().Add(child)
}
return obj
}
// Has checks to see if an import exists by name or alias.
func (impts GoImports) Has(name string) bool {
for _, impt := range impts {
if name == impt.Name || (impt.Alias != "" && name == impt.Alias) {
return true
}
}
return false
}
// Line formats an import as a string.
func (impt GoImport) String() string {
var imptLine string
if impt.Alias != "" {
imptLine = fmt.Sprintf("\t%s \"%s\"", impt.Alias, impt.Name)
} else {
imptLine = fmt.Sprintf("\t\"%s\"", impt.Name)
}
return imptLine
}
// Object returns the first GoObject for an import.
func (impt GoImport) Object() *GoObject {
imptObj := NewObjectImports(impt)
return imptObj.Objects().List()[0]
}
// ParseImportObject extracts all the import definitions.
func ParseImportObject(obj *GoObject) (resp GoImports, err error) {
if obj.Type != GoObjectType_Import {
return resp, errors.Errorf("Invalid type %s", string(obj.Type))
}
for _, l := range obj.Lines() {
if !strings.Contains(l, "\"") {
continue
}
l = strings.TrimSpace(l)
pts := strings.Split(l, "\"")
var impt GoImport
if strings.HasPrefix(l, "\"") {
impt.Name = pts[1]
} else {
impt.Alias = strings.TrimSpace(pts[0])
impt.Name = pts[1]
}
resp = append(resp, impt)
}
return resp, nil
}

View File

@ -0,0 +1,458 @@
package goparse
import (
"log"
"strings"
"github.com/fatih/structtag"
"github.com/pkg/errors"
)
// GoEmptyLine defined a GoObject for a code line break.
var GoEmptyLine = GoObject{
Type: GoObjectType_LineBreak,
goObjects: &GoObjects{
list: []*GoObject{},
},
}
// GoObjectType defines a set of possible types to group
// parsed code by.
type GoObjectType = string
var (
GoObjectType_Package = "package"
GoObjectType_Import = "import"
GoObjectType_Var = "var"
GoObjectType_Const = "const"
GoObjectType_Func = "func"
GoObjectType_Struct = "struct"
GoObjectType_Comment = "comment"
GoObjectType_LineBreak = "linebreak"
GoObjectType_Line = "line"
GoObjectType_Type = "type"
)
// GoObject defines a section of code with a nested set of children.
type GoObject struct {
Type GoObjectType
Name string
startLines []string
endLines []string
subLines []string
goObjects *GoObjects
Index int
}
// GoObjects stores a list of GoObject.
type GoObjects struct {
list []*GoObject
}
// Objects returns the list of *GoObject.
func (obj *GoObject) Objects() *GoObjects {
if obj.goObjects == nil {
obj.goObjects = &GoObjects{
list: []*GoObject{},
}
}
return obj.goObjects
}
// Clone performs a deep copy of the struct.
func (obj *GoObject) Clone() *GoObject {
n := &GoObject{
Type: obj.Type,
Name: obj.Name,
startLines: obj.startLines,
endLines: obj.endLines,
subLines: obj.subLines,
goObjects: &GoObjects{
list: []*GoObject{},
},
Index: obj.Index,
}
for _, sub := range obj.Objects().List() {
n.Objects().Add(sub.Clone())
}
return n
}
// IsComment returns whether an object is of type GoObjectType_Comment.
func (obj *GoObject) IsComment() bool {
if obj.Type != GoObjectType_Comment {
return false
}
return true
}
// Contains searches all the lines for the object for a matching string.
func (obj *GoObject) Contains(match string) bool {
for _, l := range obj.Lines() {
if strings.Contains(l, match) {
return true
}
}
return false
}
// UpdateLines parses the new code and replaces the current GoObject.
func (obj *GoObject) UpdateLines(newLines []string) error {
// Parse the new lines.
objs, err := ParseLines(newLines, 0)
if err != nil {
return err
}
var newObj *GoObject
for _, obj := range objs.List() {
if obj.Type == GoObjectType_LineBreak {
continue
}
if newObj == nil {
newObj = obj
}
// There should only be one resulting parsed object that is
// not of type GoObjectType_LineBreak.
return errors.New("Can only update single blocks of code")
}
// No new code was parsed, return error.
if newObj == nil {
return errors.New("Failed to render replacement code")
}
return obj.Update(newObj)
}
// Update performs a deep copy that overwrites the existing values.
func (obj *GoObject) Update(newObj *GoObject) error {
obj.Type = newObj.Type
obj.Name = newObj.Name
obj.startLines = newObj.startLines
obj.endLines = newObj.endLines
obj.subLines = newObj.subLines
obj.goObjects = newObj.goObjects
return nil
}
// Lines returns a list of strings for current object and all children.
func (obj *GoObject) Lines() []string {
l := []string{}
// First include any lines before the sub objects.
for _, sl := range obj.startLines {
l = append(l, sl)
}
// If there are parsed sub objects include those lines else when
// no sub objects, just use the sub lines.
if len(obj.Objects().List()) > 0 {
for _, sl := range obj.Objects().Lines() {
l = append(l, sl)
}
} else {
for _, sl := range obj.subLines {
l = append(l, sl)
}
}
// Lastly include any other lines that are after all parsed sub objects.
for _, sl := range obj.endLines {
l = append(l, sl)
}
return l
}
// String returns the lines separated by line break.
func (obj *GoObject) String() string {
return strings.Join(obj.Lines(), "\n")
}
// Lines returns a list of strings for all the list objects.
func (objs *GoObjects) Lines() []string {
l := []string{}
for _, obj := range objs.List() {
for _, oj := range obj.Lines() {
l = append(l, oj)
}
}
return l
}
// String returns all the lines for the list objects.
func (objs *GoObjects) String() string {
lines := []string{}
for _, obj := range objs.List() {
lines = append(lines, obj.String())
}
return strings.Join(lines, "\n")
}
// List returns the list of GoObjects.
func (objs *GoObjects) List() []*GoObject {
return objs.list
}
// HasFunc searches the current list of objects for a function object by name.
func (objs *GoObjects) HasFunc(name string) bool {
return objs.HasType(name, GoObjectType_Func)
}
// Get returns the GoObject for the matching name and type.
func (objs *GoObjects) Get(name string, objType GoObjectType) *GoObject {
for _, obj := range objs.list {
if obj.Name == name && (objType == "" || obj.Type == objType) {
return obj
}
}
return nil
}
// HasType checks is a GoObject exists for the matching name and type.
func (objs *GoObjects) HasType(name string, objType GoObjectType) bool {
for _, obj := range objs.list {
if obj.Name == name && (objType == "" || obj.Type == objType) {
return true
}
}
return false
}
// HasObject checks to see if the exact code block exists.
func (objs *GoObjects) HasObject(src *GoObject) bool {
if src == nil {
return false
}
// Generate the code for the supplied object.
srcLines := []string{}
for _, l := range src.Lines() {
// Exclude empty lines.
l = strings.TrimSpace(l)
if l != "" {
srcLines = append(srcLines, l)
}
}
srcStr := strings.Join(srcLines, "\n")
// Loop over all the objects and match with src code.
for _, obj := range objs.list {
objLines := []string{}
for _, l := range obj.Lines() {
// Exclude empty lines.
l = strings.TrimSpace(l)
if l != "" {
objLines = append(objLines, l)
}
}
objStr := strings.Join(objLines, "\n")
// Return true if the current object code matches src code.
if srcStr == objStr {
return true
}
}
return false
}
// Add appends a new GoObject to the list.
func (objs *GoObjects) Add(newObj *GoObject) error {
newObj.Index = len(objs.list)
objs.list = append(objs.list, newObj)
return nil
}
// Insert appends a new GoObject at the desired position to the list.
func (objs *GoObjects) Insert(pos int, newObj *GoObject) error {
newList := []*GoObject{}
var newIdx int
for _, obj := range objs.list {
if obj.Index < pos {
obj.Index = newIdx
newList = append(newList, obj)
} else {
if obj.Index == pos {
newObj.Index = newIdx
newList = append(newList, newObj)
newIdx++
}
obj.Index = newIdx
newList = append(newList, obj)
}
newIdx++
}
objs.list = newList
return nil
}
// Remove deletes a GoObject from the list.
func (objs *GoObjects) Remove(delObjs ...*GoObject) error {
for _, delObj := range delObjs {
oldList := objs.List()
objs.list = []*GoObject{}
var newIdx int
for _, obj := range oldList {
if obj.Index == delObj.Index {
continue
}
obj.Index = newIdx
objs.list = append(objs.list, obj)
newIdx++
}
}
return nil
}
// Replace updates an existing GoObject while maintaining is same position.
func (objs *GoObjects) Replace(oldObj *GoObject, newObjs ...*GoObject) error {
if oldObj.Index >= len(objs.list) {
return errors.WithStack(errGoObjectNotExist)
} else if len(newObjs) == 0 {
return nil
}
oldList := objs.List()
objs.list = []*GoObject{}
var newIdx int
for _, obj := range oldList {
if obj.Index < oldObj.Index {
obj.Index = newIdx
objs.list = append(objs.list, obj)
newIdx++
} else if obj.Index == oldObj.Index {
for _, newObj := range newObjs {
newObj.Index = newIdx
objs.list = append(objs.list, newObj)
newIdx++
}
} else {
obj.Index = newIdx
objs.list = append(objs.list, obj)
newIdx++
}
}
return nil
}
// ReplaceFuncByName finds an existing GoObject with type GoObjectType_Func by name
// and then performs a replace with the supplied new GoObject.
func (objs *GoObjects) ReplaceFuncByName(name string, fn *GoObject) error {
return objs.ReplaceTypeByName(name, fn, GoObjectType_Func)
}
// ReplaceTypeByName finds an existing GoObject with type by name
// and then performs a replace with the supplied new GoObject.
func (objs *GoObjects) ReplaceTypeByName(name string, newObj *GoObject, objType GoObjectType) error {
if newObj.Name == "" {
newObj.Name = name
}
if newObj.Type == "" && objType != "" {
newObj.Type = objType
}
for _, obj := range objs.list {
if obj.Name == name && (objType == "" || objType == obj.Type) {
return objs.Replace(obj, newObj)
}
}
return errors.WithStack(errGoObjectNotExist)
}
// Empty determines if all the GoObject in the list are line breaks.
func (objs *GoObjects) Empty() bool {
var hasStuff bool
for _, obj := range objs.List() {
switch obj.Type {
case GoObjectType_LineBreak:
//case GoObjectType_Comment:
//case GoObjectType_Import:
// do nothing
default:
hasStuff = true
}
}
return hasStuff
}
// Debug prints out the GoObject to logger.
func (obj *GoObject) Debug(log *log.Logger) {
log.Println(obj.Name)
log.Println(" > type:", obj.Type)
log.Println(" > start lines:")
for _, l := range obj.startLines {
log.Println(" ", l)
}
log.Println(" > sub lines:")
for _, l := range obj.subLines {
log.Println(" ", l)
}
log.Println(" > end lines:")
for _, l := range obj.endLines {
log.Println(" ", l)
}
}
// Defines a property of a struct.
type structProp struct {
Name string
Type string
Tags *structtag.Tags
}
// ParseStructProp extracts the details for a struct property.
func ParseStructProp(obj *GoObject) (structProp, error) {
if obj.Type != GoObjectType_Line {
return structProp{}, errors.Errorf("Unable to parse object of type %s", obj.Type)
}
// Remove any white space from the code line.
ls := strings.TrimSpace(strings.Join(obj.Lines(), " "))
// Extract the property name and type for the line.
// ie: ID string `json:"id"`
var resp structProp
for _, p := range strings.Split(ls, " ") {
p = strings.TrimSpace(p)
if p == "" {
continue
}
if resp.Name == "" {
resp.Name = p
} else if resp.Type == "" {
resp.Type = p
} else {
break
}
}
// If the line contains tags, extract and parse them.
if strings.Contains(ls, "`") {
tagStr := strings.Split(ls, "`")[1]
var err error
resp.Tags, err = structtag.Parse(tagStr)
if err != nil {
err = errors.WithMessagef(err, "Failed to parse struct tag for field %s: %s", resp.Name, tagStr)
return structProp{}, err
}
}
return resp, nil
}

View File

@ -0,0 +1,329 @@
package goparse
import (
"bufio"
"bytes"
"fmt"
"go/format"
"io/ioutil"
"log"
"strings"
"unicode"
"github.com/pkg/errors"
)
var (
errGoParseType = errors.New("Unable to determine type for line")
errGoTypeMissingCodeTemplate = errors.New("No code defined for type")
errGoObjectNotExist = errors.New("GoObject does not exist")
)
// ParseFile reads a go code file and parses into a easily transformable set of objects.
func ParseFile(log *log.Logger, localPath string) (*GoDocument, error) {
// Read the code file.
src, err := ioutil.ReadFile(localPath)
if err != nil {
err = errors.WithMessagef(err, "Failed to read file %s", localPath)
return nil, err
}
// Format the code file source to ensure parse works.
dat, err := format.Source(src)
if err != nil {
err = errors.WithMessagef(err, "Failed to format source for file %s", localPath)
log.Printf("ParseFile : %v\n%v", err, string(src))
return nil, err
}
// Loop of the formatted source code and generate a list of code lines.
lines := []string{}
r := bytes.NewReader(dat)
scanner := bufio.NewScanner(r)
for scanner.Scan() {
lines = append(lines, scanner.Text())
}
if err := scanner.Err(); err != nil {
err = errors.WithMessagef(err, "Failed read formatted source code for file %s", localPath)
return nil, err
}
// Parse the code lines into a set of objects.
objs, err := ParseLines(lines, 0)
if err != nil {
log.Println(err)
return nil, err
}
// Append the resulting objects to the document.
doc := &GoDocument{}
for _, obj := range objs.List() {
if obj.Type == GoObjectType_Package {
doc.Package = obj.Name
}
doc.AddObject(obj)
}
return doc, nil
}
// ParseLines takes the list of formatted code lines and returns the GoObjects.
func ParseLines(lines []string, depth int) (objs *GoObjects, err error) {
objs = &GoObjects{
list: []*GoObject{},
}
var (
multiLine bool
multiComment bool
muiliVar bool
)
curDepth := -1
objLines := []string{}
for idx, l := range lines {
ls := strings.TrimSpace(l)
ld := lineDepth(l)
if ld == depth {
if strings.HasPrefix(ls, "/*") {
multiLine = true
multiComment = true
} else if strings.HasSuffix(ls, "(") ||
strings.HasSuffix(ls, "{") {
if !multiLine {
multiLine = true
}
} else if strings.Contains(ls, "`") {
if !multiLine && strings.Count(ls, "`")%2 != 0 {
if muiliVar {
muiliVar = false
} else {
muiliVar = true
}
}
}
objLines = append(objLines, l)
if multiComment {
if strings.HasSuffix(ls, "*/") {
multiComment = false
multiLine = false
}
} else {
if strings.HasPrefix(ls, ")") ||
strings.HasPrefix(ls, "}") {
multiLine = false
}
}
if !multiLine && !muiliVar {
for eidx := idx + 1; eidx < len(lines); eidx++ {
if el := lines[eidx]; strings.TrimSpace(el) == "" {
objLines = append(objLines, el)
} else {
break
}
}
obj, err := ParseGoObject(objLines, depth)
if err != nil {
log.Println(err)
return objs, err
}
err = objs.Add(obj)
if err != nil {
log.Println(err)
return objs, err
}
objLines = []string{}
}
} else if (multiLine && ld >= curDepth && ld >= depth && len(objLines) > 0) || muiliVar {
objLines = append(objLines, l)
if strings.Contains(ls, "`") {
if !multiLine && strings.Count(ls, "`")%2 != 0 {
if muiliVar {
muiliVar = false
} else {
muiliVar = true
}
}
}
}
}
for _, obj := range objs.List() {
children, err := ParseLines(obj.subLines, depth+1)
if err != nil {
log.Println(err)
return objs, err
}
for _, child := range children.List() {
obj.Objects().Add(child)
}
}
return objs, nil
}
// ParseGoObject generates a GoObjected for the given code lines.
func ParseGoObject(lines []string, depth int) (obj *GoObject, err error) {
// If there are no lines, return a line break.
if len(lines) == 0 {
return &GoEmptyLine, nil
}
firstLine := lines[0]
firstStrip := strings.TrimSpace(firstLine)
if len(firstStrip) == 0 {
return &GoEmptyLine, nil
}
obj = &GoObject{
goObjects: &GoObjects{
list: []*GoObject{},
},
}
if strings.HasPrefix(firstStrip, "var") {
obj.Type = GoObjectType_Var
} else if strings.HasPrefix(firstStrip, "const") {
obj.Type = GoObjectType_Const
} else if strings.HasPrefix(firstStrip, "func") {
obj.Type = GoObjectType_Func
if strings.HasPrefix(firstStrip, "func (") {
funcLine := strings.TrimLeft(strings.TrimSpace(strings.TrimLeft(firstStrip, "func ")), "(")
var structName string
pts := strings.Split(strings.Split(funcLine, ")")[0], " ")
for i := len(pts) - 1; i >= 0; i-- {
ptVal := strings.TrimSpace(pts[i])
if ptVal != "" {
structName = ptVal
break
}
}
var funcName string
pts = strings.Split(strings.Split(funcLine, "(")[0], " ")
for i := len(pts) - 1; i >= 0; i-- {
ptVal := strings.TrimSpace(pts[i])
if ptVal != "" {
funcName = ptVal
break
}
}
obj.Name = fmt.Sprintf("%s.%s", structName, funcName)
} else {
obj.Name = strings.TrimLeft(firstStrip, "func ")
obj.Name = strings.Split(obj.Name, "(")[0]
}
} else if strings.HasSuffix(firstStrip, "struct {") || strings.HasSuffix(firstStrip, "struct{") {
obj.Type = GoObjectType_Struct
if strings.HasPrefix(firstStrip, "type ") {
firstStrip = strings.TrimSpace(strings.Replace(firstStrip, "type ", "", 1))
}
obj.Name = strings.Split(firstStrip, " ")[0]
} else if strings.HasPrefix(firstStrip, "type") {
obj.Type = GoObjectType_Type
firstStrip = strings.TrimSpace(strings.Replace(firstStrip, "type ", "", 1))
obj.Name = strings.Split(firstStrip, " ")[0]
} else if strings.HasPrefix(firstStrip, "package") {
obj.Name = strings.TrimSpace(strings.TrimLeft(firstStrip, "package "))
obj.Type = GoObjectType_Package
} else if strings.HasPrefix(firstStrip, "import") {
obj.Type = GoObjectType_Import
} else if strings.HasPrefix(firstStrip, "//") {
obj.Type = GoObjectType_Comment
} else if strings.HasPrefix(firstStrip, "/*") {
obj.Type = GoObjectType_Comment
} else {
if depth > 0 {
obj.Type = GoObjectType_Line
} else {
err = errors.WithStack(errGoParseType)
return obj, err
}
}
var (
hasSub bool
muiliVarStart bool
muiliVarSub bool
muiliVarEnd bool
)
for _, l := range lines {
ld := lineDepth(l)
if (ld == depth && !muiliVarSub) || muiliVarStart || muiliVarEnd {
if hasSub && !muiliVarStart {
if strings.TrimSpace(l) != "" {
obj.endLines = append(obj.endLines, l)
}
if strings.Count(l, "`")%2 != 0 {
if muiliVarEnd {
muiliVarEnd = false
} else {
muiliVarEnd = true
}
}
} else {
obj.startLines = append(obj.startLines, l)
if strings.Count(l, "`")%2 != 0 {
if muiliVarStart {
muiliVarStart = false
} else {
muiliVarStart = true
}
}
}
} else if ld > depth || muiliVarSub {
obj.subLines = append(obj.subLines, l)
hasSub = true
if strings.Count(l, "`")%2 != 0 {
if muiliVarSub {
muiliVarSub = false
} else {
muiliVarSub = true
}
}
}
}
// add trailing linebreak
if len(obj.endLines) > 0 {
obj.endLines = append(obj.endLines, "")
}
return obj, err
}
// lineDepth returns the number of spaces for the given code line
// used to determine the code level for nesting objects.
func lineDepth(l string) int {
depth := len(l) - len(strings.TrimLeftFunc(l, unicode.IsSpace))
ls := strings.TrimSpace(l)
if strings.HasPrefix(ls, "}") && strings.Contains(ls, " else ") {
depth = depth + 1
} else if strings.HasPrefix(ls, "case ") {
depth = depth + 1
}
return depth
}

View File

@ -0,0 +1,195 @@
package goparse
import (
"log"
"os"
"strings"
"testing"
"github.com/onsi/gomega"
)
var logger *log.Logger
func init() {
logger = log.New(os.Stdout, "", log.LstdFlags|log.Lmicroseconds|log.Lshortfile)
}
func TestParseFileModel1(t *testing.T) {
_, err := ParseFile(logger, "test_gofile_model1.txt")
if err != nil {
t.Fatalf("got error %v", err)
}
}
func TestMultilineVar(t *testing.T) {
g := gomega.NewGomegaWithT(t)
code := `func ContextAllowedAccountIds(ctx context.Context, db *gorm.DB) (resp akdatamodels.Uint32List, err error) {
resp = []uint32{}
accountId := akcontext.ContextAccountId(ctx)
m := datamodels.UserAccount{}
q := fmt.Sprintf("select
distinct account_id
from %s where account_id = ?", m.TableName())
db = db.Raw(q, accountId)
}
`
code = strings.Replace(code, "\"", "`", -1)
lines := strings.Split(code, "\n")
objs, err := ParseLines(lines, 0)
if err != nil {
t.Fatalf("got error %v", err)
}
g.Expect(objs.Lines()).Should(gomega.Equal(lines))
}
func TestNewDocImports(t *testing.T) {
g := gomega.NewGomegaWithT(t)
expected := []string{
"package goparse",
"",
"import (",
" \"github.com/go/pkg1\"",
" \"github.com/go/pkg2\"",
")",
"",
}
doc := &GoDocument{}
doc.SetPackage("goparse")
doc.AddImport(GoImport{Name: "github.com/go/pkg1"})
doc.AddImport(GoImport{Name: "github.com/go/pkg2"})
g.Expect(doc.Lines()).Should(gomega.Equal(expected))
}
func TestParseLines1(t *testing.T) {
g := gomega.NewGomegaWithT(t)
code := `func testCreate(t *testing.T, ctx context.Context, sess *datamodels.Session) *datamodels.Model {
g := gomega.NewGomegaWithT(t)
obj := datamodels.MockModelNew()
resp, err := ModelCreate(ctx, DB, &obj)
if err != nil {
t.Fatalf("got error %v", err)
}
g.Expect(resp.Name).Should(gomega.Equal(obj.Name))
g.Expect(resp.Status).Should(gomega.Equal(datamodels.{{ .Name }}Status_Active))
return resp
}
`
lines := strings.Split(code, "\n")
objs, err := ParseLines(lines, 0)
if err != nil {
t.Fatalf("got error %v", err)
}
g.Expect(objs.Lines()).Should(gomega.Equal(lines))
}
func TestParseLines2(t *testing.T) {
code := `func structToMap(s interface{}) (resp map[string]interface{}) {
dat, _ := json.Marshal(s)
_ = json.Unmarshal(dat, &resp)
for k, x := range resp {
switch v := x.(type) {
case time.Time:
if v.IsZero() {
delete(resp, k)
}
case *time.Time:
if v == nil || v.IsZero() {
delete(resp, k)
}
case nil:
delete(resp, k)
}
}
return resp
}
`
lines := strings.Split(code, "\n")
objs, err := ParseLines(lines, 0)
if err != nil {
t.Fatalf("got error %v", err)
}
testLineTextMatches(t, objs.Lines(), lines)
}
func TestParseLines3(t *testing.T) {
g := gomega.NewGomegaWithT(t)
code := `type UserAccountRoleName = string
const (
UserAccountRoleName_None UserAccountRoleName = ""
UserAccountRoleName_Admin UserAccountRoleName = "admin"
UserAccountRoleName_User UserAccountRoleName = "user"
)
type UserAccountRole struct {
Id uint32 ^gorm:"column:id;type:int(10) unsigned AUTO_INCREMENT;primary_key;not null;auto_increment;" truss:"internal:true"^
CreatedAt time.Time ^gorm:"column:created_at;type:datetime;default:CURRENT_TIMESTAMP;not null;" truss:"internal:true"^
UpdatedAt time.Time ^gorm:"column:updated_at;type:datetime;" truss:"internal:true"^
DeletedAt *time.Time ^gorm:"column:deleted_at;type:datetime;" truss:"internal:true"^
Role UserAccountRoleName ^gorm:"unique_index:user_account_role;column:role;type:enum('admin', 'user')"^
// belongs to User
User *User ^gorm:"foreignkey:UserId;association_foreignkey:Id;association_autoupdate:false;association_autocreate:false;association_save_reference:false;preload:false;" truss:"internal:true"^
UserId uint32 ^gorm:"unique_index:user_account_role;"^
// belongs to Account
Account *Account ^gorm:"foreignkey:AccountId;association_foreignkey:Id;association_autoupdate:false;association_autocreate:false;association_save_reference:false;preload:false;" truss:"internal:true;api_ro:true;"^
AccountId uint32 ^gorm:"unique_index:user_account_role;" truss:"internal:true;api_ro:true;"^
}
func (UserAccountRole) TableName() string {
return "user_account_roles"
}
`
code = strings.Replace(code, "^", "'", -1)
lines := strings.Split(code, "\n")
objs, err := ParseLines(lines, 0)
if err != nil {
t.Fatalf("got error %v", err)
}
g.Expect(objs.Lines()).Should(gomega.Equal(lines))
}
func testLineTextMatches(t *testing.T, l1, l2 []string) {
g := gomega.NewGomegaWithT(t)
m1 := []string{}
for _, l := range l1 {
l = strings.TrimSpace(l)
if l != "" {
m1 = append(m1, l)
}
}
m2 := []string{}
for _, l := range l2 {
l = strings.TrimSpace(l)
if l != "" {
m2 = append(m2, l)
}
}
g.Expect(m1).Should(gomega.Equal(m2))
}

View File

@ -0,0 +1,126 @@
package account
import (
"database/sql"
"database/sql/driver"
"time"
"github.com/lib/pq"
"github.com/pkg/errors"
"gopkg.in/go-playground/validator.v9"
)
// Account represents someone with access to our system.
type Account struct {
ID string `json:"id"`
Name string `json:"name"`
Address1 string `json:"address1"`
Address2 string `json:"address2"`
City string `json:"city"`
Region string `json:"region"`
Country string `json:"country"`
Zipcode string `json:"zipcode"`
Status AccountStatus `json:"status"`
Timezone string `json:"timezone"`
SignupUserID sql.NullString `json:"signup_user_id"`
BillingUserID sql.NullString `json:"billing_user_id"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
ArchivedAt pq.NullTime `json:"archived_at"`
}
// CreateAccountRequest contains information needed to create a new Account.
type CreateAccountRequest struct {
Name string `json:"name" validate:"required,unique"`
Address1 string `json:"address1" validate:"required"`
Address2 string `json:"address2" validate:"omitempty"`
City string `json:"city" validate:"required"`
Region string `json:"region" validate:"required"`
Country string `json:"country" validate:"required"`
Zipcode string `json:"zipcode" validate:"required"`
Status *AccountStatus `json:"status" validate:"omitempty,oneof=active pending disabled"`
Timezone *string `json:"timezone" validate:"omitempty"`
SignupUserID *string `json:"signup_user_id" validate:"omitempty,uuid"`
BillingUserID *string `json:"billing_user_id" validate:"omitempty,uuid"`
}
// UpdateAccountRequest defines what information may be provided to modify an existing
// Account. All fields are optional so clients can send just the fields they want
// changed. It uses pointer fields so we can differentiate between a field that
// was not provided and a field that was provided as explicitly blank. Normally
// we do not want to use pointers to basic types but we make exceptions around
// marshalling/unmarshalling.
type UpdateAccountRequest struct {
ID string `validate:"required,uuid"`
Name *string `json:"name" validate:"omitempty,unique"`
Address1 *string `json:"address1" validate:"omitempty"`
Address2 *string `json:"address2" validate:"omitempty"`
City *string `json:"city" validate:"omitempty"`
Region *string `json:"region" validate:"omitempty"`
Country *string `json:"country" validate:"omitempty"`
Zipcode *string `json:"zipcode" validate:"omitempty"`
Status *AccountStatus `json:"status" validate:"omitempty,oneof=active pending disabled"`
Timezone *string `json:"timezone" validate:"omitempty"`
SignupUserID *string `json:"signup_user_id" validate:"omitempty,uuid"`
BillingUserID *string `json:"billing_user_id" validate:"omitempty,uuid"`
}
// AccountFindRequest defines the possible options to search for accounts. By default
// archived accounts will be excluded from response.
type AccountFindRequest struct {
Where *string
Args []interface{}
Order []string
Limit *uint
Offset *uint
IncludedArchived bool
}
// AccountStatus represents the status of an account.
type AccountStatus string
// AccountStatus values define the status field of a user account.
const (
// AccountStatus_Active defines the state when a user can access an account.
AccountStatus_Active AccountStatus = "active"
// AccountStatus_Pending defined the state when an account was created but
// not activated.
AccountStatus_Pending AccountStatus = "pending"
// AccountStatus_Disabled defines the state when a user has been disabled from
// accessing an account.
AccountStatus_Disabled AccountStatus = "disabled"
)
// AccountStatus_Values provides list of valid AccountStatus values.
var AccountStatus_Values = []AccountStatus{
AccountStatus_Active,
AccountStatus_Pending,
AccountStatus_Disabled,
}
// Scan supports reading the AccountStatus value from the database.
func (s *AccountStatus) Scan(value interface{}) error {
asBytes, ok := value.([]byte)
if !ok {
return errors.New("Scan source is not []byte")
}
*s = AccountStatus(string(asBytes))
return nil
}
// Value converts the AccountStatus value to be stored in the database.
func (s AccountStatus) Value() (driver.Value, error) {
v := validator.New()
errs := v.Var(s, "required,oneof=active invited disabled")
if errs != nil {
return nil, errs
}
return string(s), nil
}
// String converts the AccountStatus value to a string.
func (s AccountStatus) String() string {
return string(s)
}

View File

@ -0,0 +1,227 @@
package main
import (
"encoding/json"
"expvar"
"io/ioutil"
"log"
"net/url"
"os"
"path"
"path/filepath"
"strings"
"geeks-accelerator/oss/saas-starter-kit/example-project/tools/truss/cmd/dbtable2crud"
"github.com/kelseyhightower/envconfig"
"github.com/lib/pq"
_ "github.com/lib/pq"
"github.com/pkg/errors"
"github.com/urfave/cli"
sqltrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/database/sql"
sqlxtrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/jmoiron/sqlx"
)
// build is the git version of this program. It is set using build flags in the makefile.
var build = "develop"
// service is the name of the program used for logging, tracing and the
// the prefix used for loading env variables
// ie: export TRUSS_ENV=dev
var service = "TRUSS"
func main() {
// =========================================================================
// Logging
log := log.New(os.Stdout, service+" : ", log.LstdFlags|log.Lmicroseconds|log.Lshortfile)
// =========================================================================
// Configuration
var cfg struct {
DB struct {
Host string `default:"127.0.0.1:5433" envconfig:"HOST"`
User string `default:"postgres" envconfig:"USER"`
Pass string `default:"postgres" envconfig:"PASS" json:"-"` // don't print
Database string `default:"shared" envconfig:"DATABASE"`
Driver string `default:"postgres" envconfig:"DRIVER"`
Timezone string `default:"utc" envconfig:"TIMEZONE"`
DisableTLS bool `default:"false" envconfig:"DISABLE_TLS"`
}
}
// For additional details refer to https://github.com/kelseyhightower/envconfig
if err := envconfig.Process(service, &cfg); err != nil {
log.Fatalf("main : Parsing Config : %v", err)
}
// TODO: can't use flag.Process here since it doesn't support nested arg options
//if err := flag.Process(&cfg); err != nil {
/// if err != flag.ErrHelp {
// log.Fatalf("main : Parsing Command Line : %v", err)
// }
// return // We displayed help.
//}
// =========================================================================
// Log App Info
// Print the build version for our logs. Also expose it under /debug/vars.
expvar.NewString("build").Set(build)
log.Printf("main : Started : Application Initializing version %q", build)
defer log.Println("main : Completed")
// Print the config for our logs. It's important to any credentials in the config
// that could expose a security risk are excluded from being json encoded by
// applying the tag `json:"-"` to the struct var.
{
cfgJSON, err := json.MarshalIndent(cfg, "", " ")
if err != nil {
log.Fatalf("main : Marshalling Config to JSON : %v", err)
}
log.Printf("main : Config : %v\n", string(cfgJSON))
}
// =========================================================================
// Start Database
var dbUrl url.URL
{
// Query parameters.
var q url.Values = make(map[string][]string)
// Handle SSL Mode
if cfg.DB.DisableTLS {
q.Set("sslmode", "disable")
} else {
q.Set("sslmode", "require")
}
q.Set("timezone", cfg.DB.Timezone)
// Construct url.
dbUrl = url.URL{
Scheme: cfg.DB.Driver,
User: url.UserPassword(cfg.DB.User, cfg.DB.Pass),
Host: cfg.DB.Host,
Path: cfg.DB.Database,
RawQuery: q.Encode(),
}
}
// Register informs the sqlxtrace package of the driver that we will be using in our program.
// It uses a default service name, in the below case "postgres.db". To use a custom service
// name use RegisterWithServiceName.
sqltrace.Register(cfg.DB.Driver, &pq.Driver{}, sqltrace.WithServiceName(service))
masterDb, err := sqlxtrace.Open(cfg.DB.Driver, dbUrl.String())
if err != nil {
log.Fatalf("main : Register DB : %s : %v", cfg.DB.Driver, err)
}
defer masterDb.Close()
// =========================================================================
// Start Truss
app := cli.NewApp()
app.Commands = []cli.Command{
{
Name: "dbtable2crud",
Aliases: []string{"dbtable2crud"},
Usage: "dbtable2crud -table=projects -file=../../internal/project/models.go -model=Project",
Flags: []cli.Flag{
cli.StringFlag{Name: "dbtable, table"},
cli.StringFlag{Name: "modelFile, modelfile, file"},
cli.StringFlag{Name: "modelName, modelname, model"},
cli.StringFlag{Name: "templateDir, templates", Value: "./templates/dbtable2crud"},
cli.StringFlag{Name: "projectPath", Value: ""},
},
Action: func(c *cli.Context) error {
dbTable := strings.TrimSpace(c.String("dbtable"))
modelFile := strings.TrimSpace(c.String("modelFile"))
modelName := strings.TrimSpace(c.String("modelName"))
templateDir := strings.TrimSpace(c.String("templateDir"))
projectPath := strings.TrimSpace(c.String("projectPath"))
pwd, err := os.Getwd()
if err != nil {
return errors.WithMessage(err, "Failed to get current working directory")
}
if !path.IsAbs(templateDir) {
templateDir = filepath.Join(pwd, templateDir)
}
ok, err := exists(templateDir)
if err != nil {
return errors.WithMessage(err, "Failed to load template directory")
} else if !ok {
return errors.Errorf("Template directory %s does not exist", templateDir)
}
if modelFile == "" {
return errors.Errorf("Model file path is required")
}
if !path.IsAbs(modelFile) {
modelFile = filepath.Join(pwd, modelFile)
}
ok, err = exists(modelFile)
if err != nil {
return errors.WithMessage(err, "Failed to load model file")
} else if !ok {
return errors.Errorf("Model file %s does not exist", modelFile)
}
// Load the project path from go.mod if not set.
if projectPath == "" {
goModFile := filepath.Join(pwd, "../../go.mod")
ok, err = exists(goModFile)
if err != nil {
return errors.WithMessage(err, "Failed to load go.mod for project")
} else if !ok {
return errors.Errorf("Failed to locate project go.mod at %s", goModFile)
}
b, err := ioutil.ReadFile(goModFile)
if err != nil {
return errors.WithMessagef(err, "Failed to read go.mod at %s", goModFile)
}
lines := strings.Split(string(b), "\n")
for _, l := range lines {
if strings.HasPrefix(l, "module ") {
projectPath = strings.TrimSpace(strings.Split(l, " ")[1])
break
}
}
}
if modelName == "" {
modelName = strings.Split(filepath.Base(modelFile), ".")[0]
modelName = strings.Replace(modelName, "_", " ", -1)
modelName = strings.Replace(modelName, "-", " ", -1)
modelName = strings.Title(modelName)
modelName = strings.Replace(modelName, " ", "", -1)
}
return dbtable2crud.Run(masterDb, log, cfg.DB.Database, dbTable, modelFile, modelName, templateDir, projectPath)
},
},
}
err = app.Run(os.Args)
if err != nil {
log.Fatalf("main : Truss : %+v", err)
}
log.Printf("main : Truss : Completed")
}
// exists returns a bool as to whether a file path exists.
func exists(path string) (bool, error) {
_, err := os.Stat(path)
if err == nil {
return true, nil
}
if os.IsNotExist(err) {
return false, nil
}
return true, err
}

View File

@ -0,0 +1,4 @@
export TRUSS_DB_HOST=127.0.0.1:5433
export TRUSS_DB_USER=postgres
export TRUSS_DB_PASS=postgres
export TRUSS_DB_DISABLE_TLS=true

View File

@ -0,0 +1,503 @@
{{ define "imports"}}
import (
"context"
"database/sql"
"time"
"{{ $.GoSrcPath }}/internal/platform/auth"
"github.com/huandu/go-sqlbuilder"
"github.com/jmoiron/sqlx"
"github.com/pborman/uuid"
"github.com/pkg/errors"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
"gopkg.in/go-playground/validator.v9"
)
{{ end }}
{{ define "Globals"}}
const (
// The database table for {{ $.Model.Name }}
{{ FormatCamelLower $.Model.Name }}TableName = "{{ $.Model.TableName }}"
)
var (
// ErrNotFound abstracts the postgres not found error.
ErrNotFound = errors.New("Entity not found")
// ErrInvalidID occurs when an ID is not in a valid form.
ErrInvalidID = errors.New("ID is not in its proper form")
// ErrForbidden occurs when a user tries to do something that is forbidden to them according to our access control policies.
ErrForbidden = errors.New("Attempted action is not allowed")
)
{{ end }}
{{ define "Helpers"}}
// {{ FormatCamelLower $.Model.Name }}MapColumns is the list of columns needed for mapRowsTo{{ $.Model.Name }}
var {{ FormatCamelLower $.Model.Name }}MapColumns = "{{ JoinStrings $.Model.ColumnNames "," }}"
// mapRowsTo{{ $.Model.Name }} takes the SQL rows and maps it to the {{ $.Model.Name }} struct
// with the columns defined by {{ FormatCamelLower $.Model.Name }}MapColumns
func mapRowsTo{{ $.Model.Name }}(rows *sql.Rows) (*{{ $.Model.Name }}, error) {
var (
m {{ $.Model.Name }}
err error
)
err = rows.Scan({{ PrefixAndJoinStrings $.Model.FieldNames "&m." "," }})
if err != nil {
return nil, errors.WithStack(err)
}
return &a, nil
}
{{ end }}
{{ define "ACL"}}
{{ $hasAccountId := (StringListHasValue $.Model.ColumnNames "account_id") }}
// CanRead{{ $.Model.Name }} determines if claims has the authority to access the specified {{ FormatCamelLowerTitle $.Model.Name}} by {{ $.Model.PrimaryColumn }}.
func CanRead{{ $.Model.Name }}(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, {{ FormatCamelLower $.Model.PrimaryField }} string) error {
{{ if $hasAccountId }}
// If the request has claims from a specific {{ FormatCamelLower $.Model.Name }}, ensure that the claims
// has the correct access to the {{ FormatCamelLower $.Model.Name }}.
if claims.Audience != "" {
// select {{ $.Model.PrimaryColumn }} from {{ $.Model.TableName }} where account_id = [accountID]
query := sqlbuilder.NewSelectBuilder().Select("{{ $.Model.PrimaryColumn }}").From({{ FormatCamelLower $.Model.Name }}TableName)
query.Where(query.And(
query.Equal("account_id", claims.Audience),
query.Equal("{{ $.Model.PrimaryField }}", {{ FormatCamelLower $.Model.PrimaryField }}),
))
queryStr, args := query.Build()
queryStr = dbConn.Rebind(queryStr)
var {{ FormatCamelLower $.Model.PrimaryField }} string
err := dbConn.QueryRowContext(ctx, queryStr, args...).Scan(&{{ FormatCamelLower $.Model.PrimaryField }})
if err != nil && err != sql.ErrNoRows {
err = errors.Wrapf(err, "query - %s", query.String())
return err
}
// When there is no {{ $.Model.PrimaryColumn }} returned, then the current claim user does not have access
// to the specified {{ FormatCamelLowerTitle $.Model.Name }}.
if {{ FormatCamelLower $.Model.PrimaryField }} == "" {
return errors.WithStack(ErrForbidden)
}
}
{{ else }}
// TODO: Unable to auto generate sql statement, update accordingly.
panic("Not implemented!")
{{ end }}
return nil
}
// CanModify{{ $.Model.Name }} determines if claims has the authority to modify the specified {{ FormatCamelLowerTitle $.Model.Name}} by {{ $.Model.PrimaryColumn }}.
func CanModify{{ $.Model.Name }}(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, {{ FormatCamelLower $.Model.PrimaryField }} string) error {
err = CanRead{{ $.Model.Name }}(ctx, claims, dbConn, {{ FormatCamelLower $.Model.PrimaryField }})
if err != nil {
return err
}
// Admin users can update {{ FormatCamelPluralTitleLower $.Model.Name }} they have access to.
if !claims.HasRole(auth.RoleAdmin) {
return errors.WithStack(ErrForbidden)
}
return nil
}
// applyClaimsSelect applies a sub-query to the provided query to enforce ACL based on the claims provided.
// 1. No claims, request is internal, no ACL applied
{{ if $hasAccountId }}
// 2. All role types can access their user ID
{{ end }}
func applyClaimsSelect(ctx context.Context, claims auth.Claims, query *sqlbuilder.SelectBuilder) error {
// Claims are empty, don't apply any ACL
if claims.Audience == "" {
return nil
}
{{ if $hasAccountId }}
query.Where(query.Equal("account_id", claims.Audience))
{{ end }}
return nil
}
{{ end }}
{{ define "Find"}}
{{ $hasArchived := (StringListHasValue $.Model.ColumnNames "archived_at") }}
// selectQuery constructs a base select query for {{ $.Model.Name }}
func selectQuery() *sqlbuilder.SelectBuilder {
query := sqlbuilder.NewSelectBuilder()
query.Select({{ FormatCamelLower $.Model.Name }}MapColumns)
query.From({{ FormatCamelLower $.Model.Name }}TableName)
return query
}
// findRequestQuery generates the select query for the given find request.
// TODO: Need to figure out why can't parse the args when appending the where
// to the query.
func findRequestQuery(req {{ $.Model.Name }}FindRequest) (*sqlbuilder.SelectBuilder, []interface{}) {
query := selectQuery()
if req.Where != nil {
query.Where(query.And(*req.Where))
}
if len(req.Order) > 0 {
query.OrderBy(req.Order...)
}
if req.Limit != nil {
query.Limit(int(*req.Limit))
}
if req.Offset != nil {
query.Offset(int(*req.Offset))
}
return query, req.Args
}
// Find gets all the {{ FormatCamelPluralTitleLower $.Model.Name }} from the database based on the request params.
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req {{ $.Model.Name }}FindRequest) ([]*{{ $.Model.Name }}, error) {
query, args := findRequestQuery(req)
return find(ctx, claims, dbConn, query, args{{ if $hasArchived }}, req.IncludedArchived {{ end }})
}
// find internal method for getting all the {{ FormatCamelPluralTitleLower $.Model.Name }} from the database using a select query.
func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}{{ if $hasArchived }}, includedArchived bool{{ end }}) ([]*{{ $.Model.Name }}, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.{{ FormatCamelLowerUnderscore $.Model.Name }}.Find")
defer span.Finish()
query.Select({{ FormatCamelLower $.Model.Name }}MapColumns)
query.From({{ FormatCamelLower $.Model.Name }}TableName)
{{ if $hasArchived }}
if !includedArchived {
query.Where(query.IsNull("archived_at"))
}
{{ end }}
// Check to see if a sub query needs to be applied for the claims.
err := applyClaimsSelect(ctx, claims, query)
if err != nil {
return nil, err
}
queryStr, queryArgs := query.Build()
queryStr = dbConn.Rebind(queryStr)
args = append(args, queryArgs...)
// Fetch all entries from the db.
rows, err := dbConn.QueryContext(ctx, queryStr, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessage(err, "find {{ FormatCamelPluralTitleLower $.Model.Name }} failed")
return nil, err
}
// Iterate over each row.
resp := []*{{ $.Model.Name }}{}
for rows.Next() {
u, err := mapRowsTo{{ $.Model.Name }}(rows)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
return nil, err
}
resp = append(resp, u)
}
return resp, nil
}
// Read gets the specified {{ FormatCamelLowerTitle $.Model.Name }} from the database.
func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, {{ FormatCamelLower $.Model.PrimaryField }} {{ $.Model.PrimaryType }}{{ if $hasArchived }}, includedArchived bool{{ end }}) (*{{ $.Model.Name }}, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.{{ FormatCamelLowerUnderscore $.Model.Name }}.Read")
defer span.Finish()
// Filter base select query by {{ FormatCamelLower $.Model.PrimaryField }}
query := selectQuery()
query.Where(query.Equal("{{ $.Model.PrimaryColumn }}", {{ FormatCamelLower $.Model.PrimaryField }}))
res, err := find(ctx, claims, dbConn, query, []interface{}{} {{ if $hasArchived }}, includedArchived{{ end }})
if err != nil {
return nil, err
} else if res == nil || len(res) == 0 {
err = errors.WithMessagef(ErrNotFound, "{{ FormatCamelLowerTitle $.Model.Name }} %s not found", id)
return nil, err
}
u := res[0]
return u, nil
}
{{ end }}
{{ define "Create"}}
{{ $hasAccountId := (StringListHasValue $.Model.ColumnNames "account_id") }}
{{ $reqName := (Concat $.Model.Name "CreateRequest") }}
{{ $createFields := (index $.StructFields $reqName) }}
// Create inserts a new {{ FormatCamelLowerTitle $.Model.Name }} into the database.
func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req {{ $reqName }}, now time.Time) (*{{ $.Model.Name }}, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.{{ FormatCamelLowerUnderscore $.Model.Name }}.Create")
defer span.Finish()
if claims.Audience != "" {
// Admin users can update {{ FormatCamelPluralTitleLower $.Model.Name }} they have access to.
if !claims.HasRole(auth.RoleAdmin) {
return errors.WithStack(ErrForbidden)
}
{{ if $hasAccountId }}
if req.AccountId != "" {
// Request accountId must match claims.
if req.AccountId != claims.Audience {
return errors.WithStack(ErrForbidden)
}
} else {
// Set the accountId from claims.
req.AccountId = claims.Audience
}
{{ end }}
}
v := validator.New()
// Validate the request.
err = v.Struct(req)
if err != nil {
return nil, err
}
// If now empty set it to the current time.
if now.IsZero() {
now = time.Now()
}
// Always store the time as UTC.
now = now.UTC()
// Postgres truncates times to milliseconds when storing. We and do the same
// here so the value we return is consistent with what we store.
now = now.Truncate(time.Millisecond)
m := {{ $.Model.Name }}{
{{ range $mk, $mf := $.Model.Fields }}{{ $cf := (index $createFields $mf.FieldName) }}
{{ if eq $mf.FieldName $.Model.PrimaryField }}{{ $isUuid := (FieldTagHasOption $mf "validate" "uuid") }}{{ $mf.FieldName }}: {{ if $isUuid }}uuid.NewRandom().String(){{ else }}req.{{ $mf.FieldName }}{{ end }},
{{ else if or (eq $mf.FieldName "CreatedAt") (eq $mf.FieldName "UpdatedAt") }}{{ $mf.FieldName }}: now,
{{ else if $cf }}{{ $required := (FieldTagHasOption $cf "validate" "required") }}{{ if $required }}{{ $cf.FieldName }}: req.{{ $cf.FieldName }},{{ else if ne $cf.DefaultValue "" }}{{ $cf.FieldName }}: {{ $cf.DefaultValue }},{{ end }}
{{ end }}{{ end }}
}
{{ range $fk, $f := $createFields }}{{ $required := (FieldTagHasOption $f "validate" "required") }}{{ if not $required }}
if req.{{ $f.FieldName }} != nil {
{{ if eq $f.FieldType "sql.NullString" }}
m.{{ $f.FieldName }} = sql.NullString{String: *req.{{ $f.FieldName }}, Valid: true}
{{ else if eq $f.FieldType "*sql.NullString" }}
m.{{ $f.FieldName }} = &sql.NullString{String: *req.{{ $f.FieldName }}, Valid: true}
{{ else }}
m.{{ $f.FieldName }} = *req.{{ $f.FieldName }}
{{ end }}
}
{{ end }}{{ end }}
// Build the insert SQL statement.
query := sqlbuilder.NewInsertBuilder()
query.InsertInto({{ FormatCamelLower $.Model.Name }}TableName)
query.Cols(
{{ range $mk, $mf := $.Model.Fields }}{{ $cf := (index $createFields $mf.FieldName) }}{{ if or (eq $mf.FieldName $.Model.PrimaryField) ($cf) (eq $mf.FieldName "CreatedAt") (eq $mf.FieldName "UpdatedAt") }}"{{ $mf.ColumnName }}",
{{ end }}{{ end }}
)
query.Values(
{{ range $mk, $mf := $.Model.Fields }}{{ $cf := (index $createFields $mf.FieldName) }}{{ if or (eq $mf.FieldName $.Model.PrimaryField) ($cf) (eq $mf.FieldName "CreatedAt") (eq $mf.FieldName "UpdatedAt") }}m.{{ $mf.FieldName }},
{{ end }}{{ end }}
)
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessage(err, "create {{ FormatCamelLowerTitle $.Model.Name }} failed")
return nil, err
}
return &a, nil
}
{{ end }}
{{ define "Update"}}
{{ $reqName := (Concat $.Model.Name "UpdateRequest") }}
{{ $updateFields := (index $.StructFields $reqName) }}
// Update replaces an {{ FormatCamelLowerTitle $.Model.Name }} in the database.
func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req {{ $reqName }}, now time.Time) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.{{ FormatCamelLowerUnderscore $.Model.Name }}.Update")
defer span.Finish()
v := validator.New()
// Validate the request.
err := v.Struct(req)
if err != nil {
return err
}
// Ensure the claims can modify the {{ FormatCamelLowerTitle $.Model.Name }} specified in the request.
err = CanModify{{ $.Model.Name }}(ctx, claims, dbConn, req.{{ $.Model.PrimaryField }})
if err != nil {
return err
}
// If now empty set it to the current time.
if now.IsZero() {
now = time.Now()
}
// Always store the time as UTC.
now = now.UTC()
// Postgres truncates times to milliseconds when storing. We and do the same
// here so the value we return is consistent with what we store.
now = now.Truncate(time.Millisecond)
// Build the update SQL statement.
query := sqlbuilder.NewUpdateBuilder()
query.Update({{ FormatCamelLower $.Model.Name }}TableName)
var fields []string
{{ range $mk, $mf := $.Model.Fields }}{{ $uf := (index $updateFields $mf.FieldName) }}{{ if and ($uf.FieldName) (ne $uf.FieldName $.Model.PrimaryField) }}
{{ $optional := (FieldTagHasOption $uf "validate" "omitempty") }}{{ $isUuid := (FieldTagHasOption $uf "validate" "uuid") }}
if req.{{ $uf.FieldName }} != nil {
{{ if and ($optional) ($isUuid) }}
if *req.{{ $uf.FieldName }} != "" {
fields = append(fields, query.Assign("{{ $uf.ColumnName }}", req.{{ $uf.FieldName }}))
} else {
fields = append(fields, query.Assign("{{ $uf.ColumnName }}", nil))
}
{{ else }}
fields = append(fields, query.Assign("{{ $uf.ColumnName }}", req.{{ $uf.FieldName }}))
{{ end }}
}
{{ end }}{{ end }}
// If there's nothing to update we can quit early.
if len(fields) == 0 {
return nil
}
{{ $hasUpdatedAt := (StringListHasValue $.Model.ColumnNames "updated_at") }}{{ if $hasUpdatedAt }}
// Append the updated_at field
fields = append(fields, query.Assign("updated_at", now))
{{ end }}
query.Set(fields...)
query.Where(query.Equal("{{ $.Model.PrimaryColumn }}", req.{{ $.Model.PrimaryField }}))
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "update {{ FormatCamelLowerTitle $.Model.Name }} %s failed", req.{{ $.Model.PrimaryField }})
return err
}
return nil
}
{{ end }}
{{ define "Archive"}}
// Archive soft deleted the {{ FormatCamelLowerTitle $.Model.Name }} from the database.
func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, {{ FormatCamelLower $.Model.PrimaryField }} {{ $.Model.PrimaryType }}, now time.Time) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.{{ FormatCamelLowerUnderscore $.Model.Name }}.Archive")
defer span.Finish()
// Defines the struct to apply validation
req := struct {
{{ $.Model.PrimaryField }} {{ $.Model.PrimaryType }} `validate:"required,uuid"`
}{
{{ $.Model.PrimaryField }}: {{ FormatCamelLower $.Model.PrimaryField }},
}
// Validate the request.
err := validator.New().Struct(req)
if err != nil {
return err
}
// Ensure the claims can modify the {{ FormatCamelLowerTitle $.Model.Name }} specified in the request.
err = CanModify{{ $.Model.Name }}(ctx, claims, dbConn, req.ID)
if err != nil {
return err
}
// If now empty set it to the current time.
if now.IsZero() {
now = time.Now()
}
// Always store the time as UTC.
now = now.UTC()
// Postgres truncates times to milliseconds when storing. We and do the same
// here so the value we return is consistent with what we store.
now = now.Truncate(time.Millisecond)
// Build the update SQL statement.
query := sqlbuilder.NewUpdateBuilder()
query.Update({{ FormatCamelLower $.Model.Name }}TableName)
query.Set(
query.Assign("archived_at", now),
)
query.Where(query.Equal("{{ $.Model.PrimaryColumn }}", req.{{ $.Model.PrimaryField }}))
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "archive {{ FormatCamelLowerTitle $.Model.Name }} %s failed", req.{{ $.Model.PrimaryField }})
return err
}
return nil
}
{{ end }}
{{ define "Delete"}}
// Delete removes an {{ FormatCamelLowerTitle $.Model.Name }} from the database.
func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, {{ FormatCamelLower $.Model.PrimaryField }} {{ $.Model.PrimaryType }}) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.{{ FormatCamelLowerUnderscore $.Model.Name }}.Delete")
defer span.Finish()
// Defines the struct to apply validation
req := struct {
{{ $.Model.PrimaryField }} {{ $.Model.PrimaryType }} `validate:"required,uuid"`
}{
{{ $.Model.PrimaryField }}: {{ FormatCamelLower $.Model.PrimaryField }},
}
// Validate the request.
err := validator.New().Struct(req)
if err != nil {
return err
}
// Ensure the claims can modify the {{ FormatCamelLowerTitle $.Model.Name }} specified in the request.
err = CanModify{{ $.Model.Name }}(ctx, claims, dbConn, req.ID)
if err != nil {
return err
}
// Build the delete SQL statement.
query := sqlbuilder.NewDeleteBuilder()
query.DeleteFrom({{ FormatCamelLower $.Model.Name }}TableName)
query.Where(query.Equal("{{ $.Model.PrimaryColumn }}", req.{{ $.Model.PrimaryField }}))
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "delete {{ FormatCamelLowerTitle $.Model.Name }} %s failed", req.{{ $.Model.PrimaryField }})
return err
}
return nil
}
{{ end }}

View File

@ -0,0 +1,79 @@
{{ define "CreateRequest"}}
// {{ FormatCamel $.Model.Name }}CreateRequest contains information needed to create a new {{ FormatCamel $.Model.Name }}.
type {{ FormatCamel $.Model.Name }}CreateRequest struct {
{{ range $fk, $f := .Model.Fields }}{{ if and ($f.ApiCreate) (ne $f.FieldName $.Model.PrimaryField) }}{{ $optional := (FieldTagHasOption $f "validate" "omitempty") }}
{{ $f.FieldName }} {{ if and ($optional) (not $f.FieldIsPtr) }}*{{ end }}{{ $f.FieldType }} `json:"{{ $f.ColumnName }}" {{ FieldTag $f "validate" }}`
{{ end }}{{ end }}
}
{{ end }}
{{ define "UpdateRequest"}}
// {{ FormatCamel $.Model.Name }}UpdateRequest defines what information may be provided to modify an existing
// {{ FormatCamel $.Model.Name }}. All fields are optional so clients can send just the fields they want
// changed. It uses pointer fields so we can differentiate between a field that
// was not provided and a field that was provided as explicitly blank.
type {{ FormatCamel $.Model.Name }}UpdateRequest struct {
{{ range $fk, $f := .Model.Fields }}{{ if $f.ApiUpdate }}
{{ $f.FieldName }} {{ if and (ne $f.FieldName $.Model.PrimaryField) (not $f.FieldIsPtr) }}*{{ end }}{{ $f.FieldType }} `json:"{{ $f.ColumnName }}" {{ if ne $f.FieldName $.Model.PrimaryField }}{{ FieldTagReplaceOrPrepend $f "validate" "required" "omitempty" }}{{ else }}{{ FieldTagReplaceOrPrepend $f "validate" "omitempty" "required" }}{{ end }}`
{{ end }}{{ end }}
}
{{ end }}
{{ define "FindRequest"}}
// {{ FormatCamel $.Model.Name }}FindRequest defines the possible options to search for {{ FormatCamelPluralTitleLower $.Model.Name }}. By default
// archived {{ FormatCamelLowerTitle $.Model.Name }} will be excluded from response.
type {{ FormatCamel $.Model.Name }}FindRequest struct {
Where *string
Args []interface{}
Order []string
Limit *uint
Offset *uint
{{ $hasArchived := (StringListHasValue $.Model.ColumnNames "archived_at") }}{{ if $hasArchived }}IncludedArchived bool{{ end }}
}
{{ end }}
{{ define "Enums"}}
{{ range $fk, $f := .Model.Fields }}{{ if $f.DbColumn }}{{ if $f.DbColumn.IsEnum }}
// {{ $f.FieldType }} represents the {{ $f.ColumnName }} of {{ FormatCamelLowerTitle $.Model.Name }}.
type {{ $f.FieldType }} string
// {{ $f.FieldType }} values define the {{ $f.ColumnName }} field of {{ FormatCamelLowerTitle $.Model.Name }}.
const (
{{ range $evk, $ev := $f.DbColumn.EnumValues }}
// {{ $f.FieldType }}_{{ FormatCamel $ev }} defines the {{ $f.ColumnName }} of {{ $ev }} for {{ FormatCamelLowerTitle $.Model.Name }}.
{{ $f.FieldType }}_{{ FormatCamel $ev }}{{ $f.FieldType }} = "{{ $ev }}"
{{ end }}
)
// {{ $f.FieldType }}_Values provides list of valid {{ $f.FieldType }} values.
var {{ $f.FieldType }}_Values = []{{ $f.FieldType }}{
{{ range $evk, $ev := $f.DbColumn.EnumValues }}
{{ $f.FieldType }}_{{ FormatCamel $ev }},
{{ end }}
}
// Scan supports reading the {{ $f.FieldType }} value from the database.
func (s *{{ $f.FieldType }}) Scan(value interface{}) error {
asBytes, ok := value.([]byte)
if !ok {
return errors.New("Scan source is not []byte")
}
*s = {{ $f.FieldType }}(string(asBytes))
return nil
}
// Value converts the {{ $f.FieldType }} value to be stored in the database.
func (s {{ $f.FieldType }}) Value() (driver.Value, error) {
v := validator.New()
errs := v.Var(s, "required,oneof={{ JoinStrings $f.DbColumn.EnumValues " " }}")
if errs != nil {
return nil, errs
}
return string(s), nil
}
// String converts the {{ $f.FieldType }} value to a string.
func (s {{ $f.FieldType }}) String() string {
return string(s)
}
{{ end }}{{ end }}{{ end }}
{{ end }}