You've already forked golang-saas-starter-kit
mirror of
https://github.com/raseels-repos/golang-saas-starter-kit.git
synced 2025-06-15 00:15:15 +02:00
Completed truss code gen for generating model requests and crud.
This commit is contained in:
1
example-project/tools/truss/.gitignore
vendored
Normal file
1
example-project/tools/truss/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
truss
|
33
example-project/tools/truss/README.md
Normal file
33
example-project/tools/truss/README.md
Normal file
@ -0,0 +1,33 @@
|
||||
# SaaS Truss
|
||||
|
||||
Copyright 2019, Geeks Accelerator
|
||||
accelerator@geeksinthewoods.com.com
|
||||
|
||||
|
||||
## Description
|
||||
|
||||
Truss provides code generation to reduce copy/pasting.
|
||||
|
||||
|
||||
## Local Installation
|
||||
|
||||
### Build
|
||||
```bash
|
||||
go build .
|
||||
```
|
||||
|
||||
### Usage
|
||||
```bash
|
||||
./truss -h
|
||||
|
||||
Usage of ./truss
|
||||
--cmd string <dbtable2crud>
|
||||
--db_host string <127.0.0.1:5433>
|
||||
--db_user string <postgres>
|
||||
--db_pass string <postgres>
|
||||
--db_database string <shared>
|
||||
--db_driver string <postgres>
|
||||
--db_timezone string <utc>
|
||||
--db_disabletls bool <false>
|
||||
```
|
||||
|
149
example-project/tools/truss/cmd/dbtable2crud/db.go
Normal file
149
example-project/tools/truss/cmd/dbtable2crud/db.go
Normal file
@ -0,0 +1,149 @@
|
||||
package dbtable2crud
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/lib/pq"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type psqlColumn struct {
|
||||
Table string
|
||||
Column string
|
||||
ColumnId int64
|
||||
NotNull bool
|
||||
DataTypeFull string
|
||||
DataTypeName string
|
||||
DataTypeLength *int
|
||||
NumericPrecision *int
|
||||
NumericScale *int
|
||||
IsPrimaryKey bool
|
||||
PrimaryKeyName *string
|
||||
IsUniqueKey bool
|
||||
UniqueKeyName *string
|
||||
IsForeignKey bool
|
||||
ForeignKeyName *string
|
||||
ForeignKeyColumnId pq.Int64Array
|
||||
ForeignKeyTable *string
|
||||
ForeignKeyLocalColumnId pq.Int64Array
|
||||
DefaultFull *string
|
||||
DefaultValue *string
|
||||
IsEnum bool
|
||||
EnumTypeId *string
|
||||
EnumValues []string
|
||||
}
|
||||
|
||||
// descTable lists all the columns for a table.
|
||||
func descTable(db *sqlx.DB, dbName, dbTable string) ([]psqlColumn, error) {
|
||||
|
||||
queryStr := fmt.Sprintf(`SELECT
|
||||
c.relname as table,
|
||||
f.attname as column,
|
||||
f.attnum as columnId,
|
||||
f.attnotnull as not_null,
|
||||
pg_catalog.format_type(f.atttypid,f.atttypmod) AS data_type_full,
|
||||
t.typname AS data_type_name,
|
||||
CASE WHEN f.atttypmod >= 0 AND t.typname <> 'numeric'THEN (f.atttypmod - 4) --first 4 bytes are for storing actual length of data
|
||||
END AS data_type_length,
|
||||
CASE WHEN t.typname = 'numeric' THEN (((f.atttypmod - 4) >> 16) & 65535)
|
||||
END AS numeric_precision,
|
||||
CASE WHEN t.typname = 'numeric' THEN ((f.atttypmod - 4)& 65535 )
|
||||
END AS numeric_scale,
|
||||
CASE WHEN p.contype = 'p' THEN true ELSE false
|
||||
END AS is_primary_key,
|
||||
CASE WHEN p.contype = 'p' THEN p.conname
|
||||
END AS primary_key_name,
|
||||
CASE WHEN p.contype = 'u' THEN true ELSE false
|
||||
END AS is_unique_key,
|
||||
CASE WHEN p.contype = 'u' THEN p.conname
|
||||
END AS unique_key_name,
|
||||
CASE WHEN p.contype = 'f' THEN true ELSE false
|
||||
END AS is_foreign_key,
|
||||
CASE WHEN p.contype = 'f' THEN p.conname
|
||||
END AS foreignkey_name,
|
||||
CASE WHEN p.contype = 'f' THEN p.confkey
|
||||
END AS foreign_key_columnid,
|
||||
CASE WHEN p.contype = 'f' THEN g.relname
|
||||
END AS foreign_key_table,
|
||||
CASE WHEN p.contype = 'f' THEN p.conkey
|
||||
END AS foreign_key_local_column_id,
|
||||
CASE WHEN f.atthasdef = 't' THEN d.adsrc
|
||||
END AS default_value,
|
||||
CASE WHEN t.typtype = 'e' THEN true ELSE false
|
||||
END AS is_enum,
|
||||
CASE WHEN t.typtype = 'e' THEN t.oid
|
||||
END AS enum_type_id
|
||||
FROM pg_attribute f
|
||||
JOIN pg_class c ON c.oid = f.attrelid
|
||||
JOIN pg_type t ON t.oid = f.atttypid
|
||||
LEFT JOIN pg_attrdef d ON d.adrelid = c.oid AND d.adnum = f.attnum
|
||||
LEFT JOIN pg_namespace n ON n.oid = c.relnamespace
|
||||
LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey)
|
||||
LEFT JOIN pg_class AS g ON p.confrelid = g.oid
|
||||
WHERE c.relkind = 'r'::char
|
||||
AND f.attisdropped = false
|
||||
AND c.relname = '%s'
|
||||
AND f.attnum > 0
|
||||
ORDER BY f.attnum
|
||||
;`, dbTable) // AND n.nspname = '%s'
|
||||
|
||||
rows, err := db.Query(queryStr)
|
||||
if err != nil {
|
||||
err = errors.Wrapf(err, "query - %s", queryStr)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// iterate over each row
|
||||
var resp []psqlColumn
|
||||
for rows.Next() {
|
||||
var c psqlColumn
|
||||
err = rows.Scan(&c.Table, &c.Column, &c.ColumnId, &c.NotNull, &c.DataTypeFull, &c.DataTypeName, &c.DataTypeLength, &c.NumericPrecision, &c.NumericScale, &c.IsPrimaryKey, &c.PrimaryKeyName, &c.IsUniqueKey, &c.UniqueKeyName, &c.IsForeignKey, &c.ForeignKeyName, &c.ForeignKeyColumnId, &c.ForeignKeyTable, &c.ForeignKeyLocalColumnId, &c.DefaultFull, &c.IsEnum, &c.EnumTypeId)
|
||||
if err != nil {
|
||||
err = errors.Wrapf(err, "query - %s", queryStr)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if c.DefaultFull != nil {
|
||||
defaultValue := *c.DefaultFull
|
||||
|
||||
// "'active'::project_status_t"
|
||||
defaultValue = strings.Split(defaultValue, "::")[0]
|
||||
c.DefaultValue = &defaultValue
|
||||
}
|
||||
|
||||
resp = append(resp, c)
|
||||
}
|
||||
|
||||
for colIdx, dbCol := range resp {
|
||||
if !dbCol.IsEnum {
|
||||
continue
|
||||
}
|
||||
|
||||
queryStr := fmt.Sprintf(`SELECT e.enumlabel
|
||||
FROM pg_enum AS e
|
||||
WHERE e.enumtypid = '%s'
|
||||
ORDER BY e.enumsortorder`, *dbCol.EnumTypeId)
|
||||
|
||||
rows, err := db.Query(queryStr)
|
||||
if err != nil {
|
||||
err = errors.Wrapf(err, "query - %s", queryStr)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
var v string
|
||||
err = rows.Scan(&v)
|
||||
if err != nil {
|
||||
err = errors.Wrapf(err, "query - %s", queryStr)
|
||||
return nil, err
|
||||
}
|
||||
dbCol.EnumValues = append(dbCol.EnumValues, v)
|
||||
}
|
||||
|
||||
resp[colIdx] = dbCol
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
378
example-project/tools/truss/cmd/dbtable2crud/dbtable2crud.go
Normal file
378
example-project/tools/truss/cmd/dbtable2crud/dbtable2crud.go
Normal file
@ -0,0 +1,378 @@
|
||||
package dbtable2crud
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/schema"
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/tools/truss/internal/goparse"
|
||||
"github.com/dustin/go-humanize/english"
|
||||
"github.com/fatih/camelcase"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sergi/go-diff/diffmatchpatch"
|
||||
)
|
||||
|
||||
// Run in the main entry point for the dbtable2crud cmd.
|
||||
func Run(db *sqlx.DB, log *log.Logger, dbName, dbTable, modelFile, modelName, templateDir, goSrcPath string) error {
|
||||
log.SetPrefix(log.Prefix() + " : dbtable2crud")
|
||||
|
||||
// Ensure the schema is up to date
|
||||
if err := schema.Migrate(db, log); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// When dbTable is empty, lower case the model name
|
||||
if dbTable == "" {
|
||||
dbTable = strings.Join(camelcase.Split(modelName), " ")
|
||||
dbTable = english.PluralWord(2, dbTable, "")
|
||||
dbTable = strings.Replace(dbTable, " ", "_", -1)
|
||||
dbTable = strings.ToLower(dbTable)
|
||||
}
|
||||
|
||||
// Parse the model file and load the specified model struct.
|
||||
model, err := parseModelFile(db, log, dbName, dbTable, modelFile, modelName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Basic lint of the model struct.
|
||||
err = validateModel(log, model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tmplData := map[string]interface{}{
|
||||
"GoSrcPath": goSrcPath,
|
||||
}
|
||||
|
||||
// Update the model file with new or updated code.
|
||||
err = updateModel(log, model, templateDir, tmplData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update the model crud file with new or updated code.
|
||||
err = updateModelCrud(db, log, dbName, dbTable, modelFile, modelName, templateDir, model, tmplData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateModel performs a basic lint of the model struct to ensure
|
||||
// code gen output is correct.
|
||||
func validateModel(log *log.Logger, model *modelDef) error {
|
||||
for _, sf := range model.Fields {
|
||||
if sf.DbColumn == nil && sf.ColumnName != "-" {
|
||||
log.Printf("validateStruct : Unable to find struct field for db column %s\n", sf.ColumnName)
|
||||
}
|
||||
|
||||
var expectedType string
|
||||
switch sf.FieldName {
|
||||
case "ID":
|
||||
expectedType = "string"
|
||||
case "CreatedAt":
|
||||
expectedType = "time.Time"
|
||||
case "UpdatedAt":
|
||||
expectedType = "time.Time"
|
||||
case "ArchivedAt":
|
||||
expectedType = "pq.NullTime"
|
||||
}
|
||||
|
||||
if expectedType != "" && expectedType != sf.FieldType {
|
||||
log.Printf("validateStruct : Struct field %s should be of type %s not %s\n", sf.FieldName, expectedType, sf.FieldType)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateModel updated the parsed code file with the new code.
|
||||
func updateModel(log *log.Logger, model *modelDef, templateDir string, tmplData map[string]interface{}) error {
|
||||
|
||||
// Execute template and parse code to be used to compare against modelFile.
|
||||
tmplObjs, err := loadTemplateObjects(log, model, templateDir, "models.tmpl", tmplData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Store the current code as a string to produce a diff.
|
||||
curCode := model.String()
|
||||
|
||||
objHeaders := []*goparse.GoObject{}
|
||||
|
||||
for _, obj := range tmplObjs {
|
||||
if obj.Type == goparse.GoObjectType_Comment || obj.Type == goparse.GoObjectType_LineBreak {
|
||||
objHeaders = append(objHeaders, obj)
|
||||
continue
|
||||
}
|
||||
|
||||
if model.HasType(obj.Name, obj.Type) {
|
||||
cur := model.Objects().Get(obj.Name, obj.Type)
|
||||
|
||||
newObjs := []*goparse.GoObject{}
|
||||
if len(objHeaders) > 0 {
|
||||
// Remove any comments and linebreaks before the existing object so updates can be added.
|
||||
removeObjs := []*goparse.GoObject{}
|
||||
for idx := cur.Index - 1; idx > 0; idx-- {
|
||||
prevObj := model.Objects().List()[idx]
|
||||
if prevObj.Type == goparse.GoObjectType_Comment || prevObj.Type == goparse.GoObjectType_LineBreak {
|
||||
removeObjs = append(removeObjs, prevObj)
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(removeObjs) > 0 {
|
||||
err := model.Objects().Remove(removeObjs...)
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to update object %s %s for %s", obj.Type, obj.Name, model.Name)
|
||||
return err
|
||||
}
|
||||
|
||||
// Make sure the current index is correct.
|
||||
cur = model.Objects().Get(obj.Name, obj.Type)
|
||||
}
|
||||
|
||||
// Append comments and line breaks before adding the object
|
||||
for _, c := range objHeaders {
|
||||
newObjs = append(newObjs, c)
|
||||
}
|
||||
}
|
||||
|
||||
newObjs = append(newObjs, obj)
|
||||
|
||||
// Do the object replacement.
|
||||
err := model.Objects().Replace(cur, newObjs...)
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to update object %s %s for %s", obj.Type, obj.Name, model.Name)
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// Append comments and line breaks before adding the object
|
||||
for _, c := range objHeaders {
|
||||
err := model.Objects().Add(c)
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to add object %s %s for %s", c.Type, c.Name, model.Name)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
err := model.Objects().Add(obj)
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to add object %s %s for %s", obj.Type, obj.Name, model.Name)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
objHeaders = []*goparse.GoObject{}
|
||||
}
|
||||
|
||||
// Set some flags to determine additional imports and need to be added.
|
||||
var hasEnum bool
|
||||
var hasPq bool
|
||||
for _, f := range model.Fields {
|
||||
if f.DbColumn != nil && f.DbColumn.IsEnum {
|
||||
hasEnum = true
|
||||
}
|
||||
if strings.HasPrefix(strings.Trim(f.FieldType, "*"), "pq.") {
|
||||
hasPq = true
|
||||
}
|
||||
}
|
||||
|
||||
reqImports := []string{}
|
||||
if hasEnum {
|
||||
reqImports = append(reqImports, "database/sql/driver")
|
||||
reqImports = append(reqImports, "gopkg.in/go-playground/validator.v9")
|
||||
reqImports = append(reqImports, "github.com/pkg/errors")
|
||||
}
|
||||
|
||||
if hasPq {
|
||||
reqImports = append(reqImports, "github.com/lib/pq")
|
||||
}
|
||||
|
||||
for _, in := range reqImports {
|
||||
err := model.AddImport(goparse.GoImport{Name: in})
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to add import %s for %s", in, model.Name)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Produce a diff after the updates have been applied.
|
||||
dmp := diffmatchpatch.New()
|
||||
diffs := dmp.DiffMain(curCode, model.String(), true)
|
||||
|
||||
fmt.Println(dmp.DiffPrettyText(diffs))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateModelCrud updated the parsed code file with the new code.
|
||||
func updateModelCrud(db *sqlx.DB, log *log.Logger, dbName, dbTable, modelFile, modelName, templateDir string, baseModel *modelDef, tmplData map[string]interface{}) error {
|
||||
|
||||
modelDir := filepath.Dir(modelFile)
|
||||
crudFile := filepath.Join(modelDir, FormatCamelLowerUnderscore(baseModel.Name)+".go")
|
||||
|
||||
var crudDoc *goparse.GoDocument
|
||||
if _, err := os.Stat(crudFile); os.IsNotExist(err) {
|
||||
crudDoc, err = goparse.NewGoDocument(baseModel.Package)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// Parse the supplied model file.
|
||||
crudDoc, err = goparse.ParseFile(log, modelFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Load all the updated struct fields from the base model file.
|
||||
structFields := make(map[string]map[string]modelField)
|
||||
for _, obj := range baseModel.GoDocument.Objects().List() {
|
||||
if obj.Type != goparse.GoObjectType_Struct || obj.Name == baseModel.Name {
|
||||
continue
|
||||
}
|
||||
|
||||
objFields, err := parseModelFields(baseModel.GoDocument, obj.Name, baseModel)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
structFields[obj.Name] = make(map[string]modelField)
|
||||
for _, f := range objFields {
|
||||
structFields[obj.Name][f.FieldName] = f
|
||||
}
|
||||
}
|
||||
|
||||
// Append the struct fields to be used for template execution.
|
||||
if tmplData == nil {
|
||||
tmplData = make(map[string]interface{})
|
||||
}
|
||||
tmplData["StructFields"] = structFields
|
||||
|
||||
// Execute template and parse code to be used to compare against modelFile.
|
||||
tmplObjs, err := loadTemplateObjects(log, baseModel, templateDir, "model_crud.tmpl", tmplData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Store the current code as a string to produce a diff.
|
||||
curCode := crudDoc.String()
|
||||
|
||||
objHeaders := []*goparse.GoObject{}
|
||||
|
||||
for _, obj := range tmplObjs {
|
||||
if obj.Type == goparse.GoObjectType_Comment || obj.Type == goparse.GoObjectType_LineBreak {
|
||||
objHeaders = append(objHeaders, obj)
|
||||
continue
|
||||
}
|
||||
|
||||
if crudDoc.HasType(obj.Name, obj.Type) {
|
||||
cur := crudDoc.Objects().Get(obj.Name, obj.Type)
|
||||
|
||||
newObjs := []*goparse.GoObject{}
|
||||
if len(objHeaders) > 0 {
|
||||
// Remove any comments and linebreaks before the existing object so updates can be added.
|
||||
removeObjs := []*goparse.GoObject{}
|
||||
for idx := cur.Index - 1; idx > 0; idx-- {
|
||||
prevObj := crudDoc.Objects().List()[idx]
|
||||
if prevObj.Type == goparse.GoObjectType_Comment || prevObj.Type == goparse.GoObjectType_LineBreak {
|
||||
removeObjs = append(removeObjs, prevObj)
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(removeObjs) > 0 {
|
||||
err := crudDoc.Objects().Remove(removeObjs...)
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to update object %s %s for %s", obj.Type, obj.Name, baseModel.Name)
|
||||
return err
|
||||
}
|
||||
|
||||
// Make sure the current index is correct.
|
||||
cur = crudDoc.Objects().Get(obj.Name, obj.Type)
|
||||
}
|
||||
|
||||
// Append comments and line breaks before adding the object
|
||||
for _, c := range objHeaders {
|
||||
newObjs = append(newObjs, c)
|
||||
}
|
||||
}
|
||||
|
||||
newObjs = append(newObjs, obj)
|
||||
|
||||
// Do the object replacement.
|
||||
err := crudDoc.Objects().Replace(cur, newObjs...)
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to update object %s %s for %s", obj.Type, obj.Name, baseModel.Name)
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// Append comments and line breaks before adding the object
|
||||
for _, c := range objHeaders {
|
||||
err := crudDoc.Objects().Add(c)
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to add object %s %s for %s", c.Type, c.Name, baseModel.Name)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
err := crudDoc.Objects().Add(obj)
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to add object %s %s for %s", obj.Type, obj.Name, baseModel.Name)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
objHeaders = []*goparse.GoObject{}
|
||||
}
|
||||
|
||||
/*
|
||||
// Set some flags to determine additional imports and need to be added.
|
||||
var hasEnum bool
|
||||
var hasPq bool
|
||||
for _, f := range crudModel.Fields {
|
||||
if f.DbColumn != nil && f.DbColumn.IsEnum {
|
||||
hasEnum = true
|
||||
}
|
||||
if strings.HasPrefix(strings.Trim(f.FieldType, "*"), "pq.") {
|
||||
hasPq = true
|
||||
}
|
||||
}
|
||||
|
||||
reqImports := []string{}
|
||||
if hasEnum {
|
||||
reqImports = append(reqImports, "database/sql/driver")
|
||||
reqImports = append(reqImports, "gopkg.in/go-playground/validator.v9")
|
||||
reqImports = append(reqImports, "github.com/pkg/errors")
|
||||
}
|
||||
|
||||
if hasPq {
|
||||
reqImports = append(reqImports, "github.com/lib/pq")
|
||||
}
|
||||
|
||||
for _, in := range reqImports {
|
||||
err := model.AddImport(goparse.GoImport{Name: in})
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to add import %s for %s", in, crudModel.Name)
|
||||
return err
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
// Produce a diff after the updates have been applied.
|
||||
dmp := diffmatchpatch.New()
|
||||
diffs := dmp.DiffMain(curCode, crudDoc.String(), true)
|
||||
|
||||
fmt.Println(dmp.DiffPrettyText(diffs))
|
||||
|
||||
return nil
|
||||
}
|
229
example-project/tools/truss/cmd/dbtable2crud/models.go
Normal file
229
example-project/tools/truss/cmd/dbtable2crud/models.go
Normal file
@ -0,0 +1,229 @@
|
||||
package dbtable2crud
|
||||
|
||||
import (
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/tools/truss/internal/goparse"
|
||||
"github.com/fatih/structtag"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// modelDef defines info about the struct and associated db table.
|
||||
type modelDef struct {
|
||||
*goparse.GoDocument
|
||||
Name string
|
||||
TableName string
|
||||
PrimaryField string
|
||||
PrimaryColumn string
|
||||
PrimaryType string
|
||||
Fields []modelField
|
||||
FieldNames []string
|
||||
ColumnNames []string
|
||||
}
|
||||
|
||||
// modelField defines a struct field and associated db column.
|
||||
type modelField struct {
|
||||
ColumnName string
|
||||
DbColumn *psqlColumn
|
||||
FieldName string
|
||||
FieldType string
|
||||
FieldIsPtr bool
|
||||
Tags *structtag.Tags
|
||||
ApiHide bool
|
||||
ApiRead bool
|
||||
ApiCreate bool
|
||||
ApiUpdate bool
|
||||
DefaultValue string
|
||||
}
|
||||
|
||||
// parseModelFile parses the entire model file and then loads the specified model struct.
|
||||
func parseModelFile(db *sqlx.DB, log *log.Logger, dbName, dbTable, modelFile, modelName string) (*modelDef, error) {
|
||||
|
||||
// Parse the supplied model file.
|
||||
doc, err := goparse.ParseFile(log, modelFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Init new modelDef.
|
||||
model := &modelDef{
|
||||
GoDocument: doc,
|
||||
Name: modelName,
|
||||
TableName: dbTable,
|
||||
}
|
||||
|
||||
// Append the field the the model def.
|
||||
model.Fields, err = parseModelFields(doc, modelName, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, sf := range model.Fields {
|
||||
model.FieldNames = append(model.FieldNames, sf.FieldName)
|
||||
model.ColumnNames = append(model.ColumnNames, sf.ColumnName)
|
||||
}
|
||||
|
||||
// Query the database for a table definition.
|
||||
dbCols, err := descTable(db, dbName, dbTable)
|
||||
if err != nil {
|
||||
return model, err
|
||||
}
|
||||
|
||||
// Loop over all the database table columns and append to the associated
|
||||
// struct field. Don't force all database table columns to be defined in the
|
||||
// in the struct.
|
||||
for _, dbCol := range dbCols {
|
||||
for idx, sf := range model.Fields {
|
||||
if sf.ColumnName != dbCol.Column {
|
||||
continue
|
||||
}
|
||||
|
||||
if dbCol.IsPrimaryKey {
|
||||
model.PrimaryColumn = sf.ColumnName
|
||||
model.PrimaryField = sf.FieldName
|
||||
model.PrimaryType = sf.FieldType
|
||||
}
|
||||
|
||||
if dbCol.DefaultValue != nil {
|
||||
sf.DefaultValue = *dbCol.DefaultValue
|
||||
|
||||
if dbCol.IsEnum {
|
||||
sf.DefaultValue = strings.Trim(sf.DefaultValue, "'")
|
||||
sf.DefaultValue = sf.FieldType + "_" + FormatCamel(sf.DefaultValue)
|
||||
} else if strings.HasPrefix(sf.DefaultValue, "'") {
|
||||
sf.DefaultValue = strings.Trim(sf.DefaultValue, "'")
|
||||
sf.DefaultValue = "\"" + sf.DefaultValue + "\""
|
||||
}
|
||||
}
|
||||
|
||||
c := dbCol
|
||||
sf.DbColumn = &c
|
||||
model.Fields[idx] = sf
|
||||
}
|
||||
}
|
||||
|
||||
// Print out the model for debugging.
|
||||
//modelJSON, err := json.MarshalIndent(model, "", " ")
|
||||
//if err != nil {
|
||||
// return model, errors.WithStack(err )
|
||||
//}
|
||||
//log.Printf(string(modelJSON))
|
||||
|
||||
return model, nil
|
||||
}
|
||||
|
||||
// parseModelFields parses the fields from a struct.
|
||||
func parseModelFields(doc *goparse.GoDocument, modelName string, baseModel *modelDef) ([]modelField, error) {
|
||||
|
||||
// Ensure the model file has a struct with the model name supplied.
|
||||
if !doc.HasType(modelName, goparse.GoObjectType_Struct) {
|
||||
err := errors.Errorf("Struct with the name %s does not exist", modelName)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Load the struct from parsed go file.
|
||||
docModel := doc.Get(modelName, goparse.GoObjectType_Struct)
|
||||
|
||||
// Loop over all the objects contained between the struct definition start and end.
|
||||
// This should be a list of variables defined for model.
|
||||
resp := []modelField{}
|
||||
for _, l := range docModel.Objects().List() {
|
||||
|
||||
// Skip all lines that are not a var.
|
||||
if l.Type != goparse.GoObjectType_Line {
|
||||
log.Printf("parseModelFile : Model %s has line that is %s, not type line, skipping - %s\n", modelName, l.Type, l.String())
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract the var name, type and defined tags from the line.
|
||||
sv, err := goparse.ParseStructProp(l)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Init new modelField for the struct var.
|
||||
sf := modelField{
|
||||
FieldName: sv.Name,
|
||||
FieldType: sv.Type,
|
||||
FieldIsPtr: strings.HasPrefix(sv.Type, "*"),
|
||||
Tags: sv.Tags,
|
||||
}
|
||||
|
||||
// Extract the column name from the var tags.
|
||||
if sf.Tags != nil {
|
||||
// First try to get the column name from the db tag.
|
||||
dbt, err := sf.Tags.Get("db")
|
||||
if err != nil && !strings.Contains(err.Error(), "not exist") {
|
||||
err = errors.WithStack(err)
|
||||
return nil, err
|
||||
} else if dbt != nil {
|
||||
sf.ColumnName = dbt.Name
|
||||
}
|
||||
|
||||
// Second try to get the column name from the json tag.
|
||||
if sf.ColumnName == "" {
|
||||
jt, err := sf.Tags.Get("json")
|
||||
if err != nil && !strings.Contains(err.Error(), "not exist") {
|
||||
err = errors.WithStack(err)
|
||||
return nil, err
|
||||
} else if jt != nil && jt.Name != "-" {
|
||||
sf.ColumnName = jt.Name
|
||||
}
|
||||
}
|
||||
|
||||
var apiActionsSet bool
|
||||
tt, err := sf.Tags.Get("truss")
|
||||
if err != nil && !strings.Contains(err.Error(), "not exist") {
|
||||
err = errors.WithStack(err)
|
||||
return nil, err
|
||||
} else if tt != nil {
|
||||
if tt.Name == "api-create" || tt.HasOption("api-create") {
|
||||
sf.ApiCreate = true
|
||||
apiActionsSet = true
|
||||
}
|
||||
if tt.Name == "api-read" || tt.HasOption("api-read") {
|
||||
sf.ApiRead = true
|
||||
apiActionsSet = true
|
||||
}
|
||||
if tt.Name == "api-update" || tt.HasOption("api-update") {
|
||||
sf.ApiUpdate = true
|
||||
apiActionsSet = true
|
||||
}
|
||||
if tt.Name == "api-hide" || tt.HasOption("api-hide") {
|
||||
sf.ApiHide = true
|
||||
apiActionsSet = true
|
||||
}
|
||||
}
|
||||
|
||||
if !apiActionsSet {
|
||||
sf.ApiCreate = true
|
||||
sf.ApiRead = true
|
||||
sf.ApiUpdate = true
|
||||
}
|
||||
}
|
||||
|
||||
// Set the column name to the field name if empty and does not equal '-'.
|
||||
if sf.ColumnName == "" {
|
||||
sf.ColumnName = sf.FieldName
|
||||
}
|
||||
|
||||
// If a base model as already been parsed with the db columns,
|
||||
// append to the current field.
|
||||
if baseModel != nil {
|
||||
for _, baseSf := range baseModel.Fields {
|
||||
if baseSf.ColumnName == sf.ColumnName {
|
||||
sf.DefaultValue = baseSf.DefaultValue
|
||||
sf.DbColumn = baseSf.DbColumn
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Append the field the the model def.
|
||||
resp = append(resp, sf)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
345
example-project/tools/truss/cmd/dbtable2crud/templates.go
Normal file
345
example-project/tools/truss/cmd/dbtable2crud/templates.go
Normal file
@ -0,0 +1,345 @@
|
||||
package dbtable2crud
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"go/format"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/tools/truss/internal/goparse"
|
||||
"github.com/dustin/go-humanize/english"
|
||||
"github.com/fatih/camelcase"
|
||||
"github.com/iancoleman/strcase"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// loadTemplateObjects executes a template file based on the given model struct and
|
||||
// returns the parsed go objects.
|
||||
func loadTemplateObjects(log *log.Logger, model *modelDef, templateDir, filename string, tmptData map[string]interface{}) ([]*goparse.GoObject, error) {
|
||||
|
||||
// Data used to execute all the of defined code sections in the template file.
|
||||
if tmptData == nil {
|
||||
tmptData = make(map[string]interface{})
|
||||
}
|
||||
tmptData["Model"] = model
|
||||
|
||||
// geeks-accelerator/oss/saas-starter-kit/example-project
|
||||
|
||||
// Read the template file from the local file system.
|
||||
tempFilePath := filepath.Join(templateDir, filename)
|
||||
dat, err := ioutil.ReadFile(tempFilePath)
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to read template file %s", tempFilePath)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// New template with custom functions.
|
||||
baseTmpl := template.New("base")
|
||||
baseTmpl.Funcs(template.FuncMap{
|
||||
"Concat": func(vals ...string) string {
|
||||
return strings.Join(vals, "")
|
||||
},
|
||||
"JoinStrings": func(vals []string, sep string) string {
|
||||
return strings.Join(vals, sep)
|
||||
},
|
||||
"PrefixAndJoinStrings": func(vals []string, pre, sep string) string {
|
||||
l := []string{}
|
||||
for _, v := range vals {
|
||||
l = append(l, pre+v)
|
||||
}
|
||||
return strings.Join(l, sep)
|
||||
},
|
||||
"FmtAndJoinStrings": func(vals []string, fmtStr, sep string) string {
|
||||
l := []string{}
|
||||
for _, v := range vals {
|
||||
l = append(l, fmt.Sprintf(fmtStr, v))
|
||||
}
|
||||
return strings.Join(l, sep)
|
||||
},
|
||||
"FormatCamel": func(name string) string {
|
||||
return FormatCamel(name)
|
||||
},
|
||||
"FormatCamelTitle": func(name string) string {
|
||||
return FormatCamelTitle(name)
|
||||
},
|
||||
"FormatCamelLower": func(name string) string {
|
||||
if name == "ID" {
|
||||
return "id"
|
||||
}
|
||||
return FormatCamelLower(name)
|
||||
},
|
||||
"FormatCamelLowerTitle": func(name string) string {
|
||||
return FormatCamelLowerTitle(name)
|
||||
},
|
||||
"FormatCamelPluralTitle": func(name string) string {
|
||||
return FormatCamelPluralTitle(name)
|
||||
},
|
||||
"FormatCamelPluralTitleLower": func(name string) string {
|
||||
return FormatCamelPluralTitleLower(name)
|
||||
},
|
||||
"FormatCamelPluralCamel": func(name string) string {
|
||||
return FormatCamelPluralCamel(name)
|
||||
},
|
||||
"FormatCamelPluralLower": func(name string) string {
|
||||
return FormatCamelPluralLower(name)
|
||||
},
|
||||
"FormatCamelPluralUnderscore": func(name string) string {
|
||||
return FormatCamelPluralUnderscore(name)
|
||||
},
|
||||
"FormatCamelPluralLowerUnderscore": func(name string) string {
|
||||
return FormatCamelPluralLowerUnderscore(name)
|
||||
},
|
||||
"FormatCamelUnderscore": func(name string) string {
|
||||
return FormatCamelUnderscore(name)
|
||||
},
|
||||
"FormatCamelLowerUnderscore": func(name string) string {
|
||||
return FormatCamelLowerUnderscore(name)
|
||||
},
|
||||
"FieldTagHasOption": func(f modelField, tagName, optName string) bool {
|
||||
if f.Tags == nil {
|
||||
return false
|
||||
}
|
||||
ft, err := f.Tags.Get(tagName)
|
||||
if ft == nil || err != nil {
|
||||
return false
|
||||
}
|
||||
if ft.Name == optName || ft.HasOption(optName) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
},
|
||||
"FieldTag": func(f modelField, tagName string) string {
|
||||
if f.Tags == nil {
|
||||
return ""
|
||||
}
|
||||
ft, err := f.Tags.Get(tagName)
|
||||
if ft == nil || err != nil {
|
||||
return ""
|
||||
}
|
||||
return ft.String()
|
||||
},
|
||||
"FieldTagReplaceOrPrepend": func(f modelField, tagName, oldVal, newVal string) string {
|
||||
if f.Tags == nil {
|
||||
return ""
|
||||
}
|
||||
ft, err := f.Tags.Get(tagName)
|
||||
if ft == nil || err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
if ft.Name == oldVal || ft.Name == newVal {
|
||||
ft.Name = newVal
|
||||
} else if ft.HasOption(oldVal) {
|
||||
for idx, val := range ft.Options {
|
||||
if val == oldVal {
|
||||
ft.Options[idx] = newVal
|
||||
}
|
||||
}
|
||||
} else if !ft.HasOption(newVal) {
|
||||
if ft.Name == "" {
|
||||
ft.Name = newVal
|
||||
} else {
|
||||
ft.Options = append(ft.Options, newVal)
|
||||
}
|
||||
}
|
||||
|
||||
return ft.String()
|
||||
},
|
||||
"StringListHasValue": func(list []string, val string) bool {
|
||||
for _, v := range list {
|
||||
if v == val {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
},
|
||||
})
|
||||
|
||||
// Load the template file using the text/template package.
|
||||
tmpl, err := baseTmpl.Parse(string(dat))
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to parse template file %s", tempFilePath)
|
||||
log.Printf("loadTemplateObjects : %v\n%v", err, string(dat))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Generate a list of template names defined in the template file.
|
||||
tmplNames := []string{}
|
||||
for _, defTmpl := range tmpl.Templates() {
|
||||
tmplNames = append(tmplNames, defTmpl.Name())
|
||||
}
|
||||
|
||||
// Stupid hack to return template names the in order they are defined in the file.
|
||||
tmplNames, err = templateFileOrderedNames(tempFilePath, tmplNames)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Loop over all the defined templates, execute using the defined data, parse the
|
||||
// formatted code and append the parsed go objects to the result list.
|
||||
var resp []*goparse.GoObject
|
||||
for _, tmplName := range tmplNames {
|
||||
// Executed the defined template with the given data.
|
||||
var tpl bytes.Buffer
|
||||
if err := tmpl.Lookup(tmplName).Execute(&tpl, tmptData); err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to execute %s from template file %s", tmplName, tempFilePath)
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// Format the source code to ensure its valid and code to parsed consistently.
|
||||
codeBytes, err := format.Source(tpl.Bytes())
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to format source for %s in template file %s", tmplName, filename)
|
||||
|
||||
dl := []string{}
|
||||
for idx, l := range strings.Split(tpl.String(), "\n") {
|
||||
dl = append(dl, fmt.Sprintf("%d -> ", idx)+l)
|
||||
}
|
||||
|
||||
log.Printf("loadTemplateObjects : %v\n%v", err, strings.Join(dl, "\n"))
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// Remove extra white space from the code.
|
||||
codeStr := strings.TrimSpace(string(codeBytes))
|
||||
|
||||
// Split the code into a list of strings.
|
||||
codeLines := strings.Split(codeStr, "\n")
|
||||
|
||||
// Parse the code lines into a set of objects.
|
||||
objs, err := goparse.ParseLines(codeLines, 0)
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to parse %s in template file %s", tmplName, filename)
|
||||
log.Printf("loadTemplateObjects : %v\n%v", err, codeStr)
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// Append the parsed objects to the return result list.
|
||||
for _, obj := range objs.List() {
|
||||
if obj.Name == "" && obj.Type != goparse.GoObjectType_Import && obj.Type != goparse.GoObjectType_Var && obj.Type != goparse.GoObjectType_Const && obj.Type != goparse.GoObjectType_Comment && obj.Type != goparse.GoObjectType_LineBreak {
|
||||
// All objects should have a name except for multiline var/const declarations and comments.
|
||||
err = errors.Errorf("Failed to parse name with type %s from lines: %v", obj.Type, obj.Lines())
|
||||
return resp, err
|
||||
} else if string(obj.Type) == "" {
|
||||
err = errors.Errorf("Failed to parse type for %s from lines: %v", obj.Name, obj.Lines())
|
||||
return resp, err
|
||||
}
|
||||
|
||||
resp = append(resp, obj)
|
||||
}
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// FormatCamel formats Valdez mountain to ValdezMountain
|
||||
func FormatCamel(name string) string {
|
||||
return strcase.ToCamel(name)
|
||||
}
|
||||
|
||||
// FormatCamelLower formats ValdezMountain to valdezmountain
|
||||
func FormatCamelLower(name string) string {
|
||||
return strcase.ToLowerCamel(FormatCamel(name))
|
||||
}
|
||||
|
||||
// FormatCamelTitle formats ValdezMountain to Valdez Mountain
|
||||
func FormatCamelTitle(name string) string {
|
||||
return strings.Join(camelcase.Split(name), " ")
|
||||
}
|
||||
|
||||
// FormatCamelLowerTitle formats ValdezMountain to valdez mountain
|
||||
func FormatCamelLowerTitle(name string) string {
|
||||
return strings.ToLower(FormatCamelTitle(name))
|
||||
}
|
||||
|
||||
// FormatCamelPluralTitle formats ValdezMountain to Valdez Mountains
|
||||
func FormatCamelPluralTitle(name string) string {
|
||||
pts := camelcase.Split(name)
|
||||
lastIdx := len(pts) - 1
|
||||
pts[lastIdx] = english.PluralWord(2, pts[lastIdx], "")
|
||||
return strings.Join(pts, " ")
|
||||
}
|
||||
|
||||
// FormatCamelPluralTitleLower formats ValdezMountain to valdez mountains
|
||||
func FormatCamelPluralTitleLower(name string) string {
|
||||
return strings.ToLower(FormatCamelPluralTitle(name))
|
||||
}
|
||||
|
||||
// FormatCamelPluralCamel formats ValdezMountain to ValdezMountains
|
||||
func FormatCamelPluralCamel(name string) string {
|
||||
return strcase.ToCamel(FormatCamelPluralTitle(name))
|
||||
}
|
||||
|
||||
// FormatCamelPluralLower formats ValdezMountain to valdezmountains
|
||||
func FormatCamelPluralLower(name string) string {
|
||||
return strcase.ToLowerCamel(FormatCamelPluralTitle(name))
|
||||
}
|
||||
|
||||
// FormatCamelPluralUnderscore formats ValdezMountain to Valdez_Mountains
|
||||
func FormatCamelPluralUnderscore(name string) string {
|
||||
return strings.Replace(FormatCamelPluralTitle(name), " ", "_", -1)
|
||||
}
|
||||
|
||||
// FormatCamelPluralLowerUnderscore formats ValdezMountain to valdez_mountains
|
||||
func FormatCamelPluralLowerUnderscore(name string) string {
|
||||
return strings.ToLower(FormatCamelPluralUnderscore(name))
|
||||
}
|
||||
|
||||
// FormatCamelUnderscore formats ValdezMountain to Valdez_Mountain
|
||||
func FormatCamelUnderscore(name string) string {
|
||||
return strings.Replace(FormatCamelTitle(name), " ", "_", -1)
|
||||
}
|
||||
|
||||
// FormatCamelLowerUnderscore formats ValdezMountain to valdez_mountain
|
||||
func FormatCamelLowerUnderscore(name string) string {
|
||||
return strings.ToLower(FormatCamelUnderscore(name))
|
||||
}
|
||||
|
||||
// templateFileOrderedNames returns the template names the in order they are defined in the file.
|
||||
func templateFileOrderedNames(localPath string, names []string) (resp []string, err error) {
|
||||
file, err := os.Open(localPath)
|
||||
if err != nil {
|
||||
return resp, errors.WithStack(err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
idxList := []int{}
|
||||
idxNames := make(map[int]string)
|
||||
|
||||
idx := 0
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
if !strings.HasPrefix(scanner.Text(), "{{") || !strings.Contains(scanner.Text(), "define ") {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, name := range names {
|
||||
if strings.Contains(scanner.Text(), "\""+name+"\"") {
|
||||
idxList = append(idxList, idx)
|
||||
idxNames[idx] = name
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
idx = idx + 1
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return resp, errors.WithStack(err)
|
||||
}
|
||||
|
||||
sort.Ints(idxList)
|
||||
|
||||
for _, idx := range idxList {
|
||||
resp = append(resp, idxNames[idx])
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
301
example-project/tools/truss/internal/goparse/doc.go
Normal file
301
example-project/tools/truss/internal/goparse/doc.go
Normal file
@ -0,0 +1,301 @@
|
||||
package goparse
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"go/format"
|
||||
"io/ioutil"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// GoDocument defines a single go code file.
|
||||
type GoDocument struct {
|
||||
*GoObjects
|
||||
Package string
|
||||
imports GoImports
|
||||
}
|
||||
|
||||
// GoImport defines a single import line with optional alias.
|
||||
type GoImport struct {
|
||||
Name string
|
||||
Alias string
|
||||
}
|
||||
|
||||
// GoImports holds a list of import lines.
|
||||
type GoImports []GoImport
|
||||
|
||||
// NewGoDocument creates a new GoDocument with the package line set.
|
||||
func NewGoDocument(packageName string) (doc *GoDocument, err error) {
|
||||
doc = &GoDocument{
|
||||
GoObjects: &GoObjects{
|
||||
list: []*GoObject{},
|
||||
},
|
||||
}
|
||||
err = doc.SetPackage(packageName)
|
||||
return doc, err
|
||||
}
|
||||
|
||||
// Objects returns a list of root GoObject.
|
||||
func (doc *GoDocument) Objects() *GoObjects {
|
||||
if doc.GoObjects == nil {
|
||||
doc.GoObjects = &GoObjects{
|
||||
list: []*GoObject{},
|
||||
}
|
||||
}
|
||||
|
||||
return doc.GoObjects
|
||||
}
|
||||
|
||||
// NewObjectPackage returns a new GoObject with a single package definition line.
|
||||
func NewObjectPackage(packageName string) *GoObject {
|
||||
lines := []string{
|
||||
fmt.Sprintf("package %s", packageName),
|
||||
"",
|
||||
}
|
||||
|
||||
obj, _ := ParseGoObject(lines, 0)
|
||||
|
||||
return obj
|
||||
}
|
||||
|
||||
// SetPackage appends sets the package line for the code file.
|
||||
func (doc *GoDocument) SetPackage(packageName string) error {
|
||||
|
||||
var existing *GoObject
|
||||
for _, obj := range doc.Objects().List() {
|
||||
if obj.Type == GoObjectType_Package {
|
||||
existing = obj
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
new := NewObjectPackage(packageName)
|
||||
|
||||
var err error
|
||||
if existing != nil {
|
||||
err = doc.Objects().Replace(existing, new)
|
||||
} else if len(doc.Objects().List()) > 0 {
|
||||
|
||||
// Insert after any existing comments or line breaks.
|
||||
var insertPos int
|
||||
//for idx, obj := range doc.Objects().List() {
|
||||
// switch obj.Type {
|
||||
// case GoObjectType_Comment, GoObjectType_LineBreak:
|
||||
// insertPos = idx
|
||||
// default:
|
||||
// break
|
||||
// }
|
||||
//}
|
||||
|
||||
err = doc.Objects().Insert(insertPos, new)
|
||||
} else {
|
||||
err = doc.Objects().Add(new)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// AddObject appends a new GoObject to the doc root object list.
|
||||
func (doc *GoDocument) AddObject(newObj *GoObject) error {
|
||||
return doc.Objects().Add(newObj)
|
||||
}
|
||||
|
||||
// InsertObject inserts a new GoObject at the desired position to the doc root object list.
|
||||
func (doc *GoDocument) InsertObject(pos int, newObj *GoObject) error {
|
||||
return doc.Objects().Insert(pos, newObj)
|
||||
}
|
||||
|
||||
// Imports returns the GoDocument imports.
|
||||
func (doc *GoDocument) Imports() (GoImports, error) {
|
||||
// If the doc imports are empty, try to load them from the root objects.
|
||||
if len(doc.imports) == 0 {
|
||||
for _, obj := range doc.Objects().List() {
|
||||
if obj.Type != GoObjectType_Import {
|
||||
continue
|
||||
}
|
||||
|
||||
res, err := ParseImportObject(obj)
|
||||
if err != nil {
|
||||
return doc.imports, err
|
||||
}
|
||||
|
||||
// Combine all the imports into a single definition.
|
||||
for _, n := range res {
|
||||
doc.imports = append(doc.imports, n)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return doc.imports, nil
|
||||
}
|
||||
|
||||
// Lines returns all the code lines.
|
||||
func (doc *GoDocument) Lines() []string {
|
||||
l := []string{}
|
||||
|
||||
for _, ol := range doc.Objects().Lines() {
|
||||
l = append(l, ol)
|
||||
}
|
||||
return l
|
||||
}
|
||||
|
||||
// String returns a single value for all the code lines.
|
||||
func (doc *GoDocument) String() string {
|
||||
return strings.Join(doc.Lines(), "\n")
|
||||
}
|
||||
|
||||
// Print writes all the code lines to standard out.
|
||||
func (doc *GoDocument) Print() {
|
||||
for _, l := range doc.Lines() {
|
||||
fmt.Println(l)
|
||||
}
|
||||
}
|
||||
|
||||
// Save renders all the code lines for the document, formats the code
|
||||
// and then saves it to the supplied file path.
|
||||
func (doc *GoDocument) Save(localpath string) error {
|
||||
res, err := format.Source([]byte(doc.String()))
|
||||
if err != nil {
|
||||
err = errors.WithMessage(err, "Failed formatted source code")
|
||||
return err
|
||||
}
|
||||
|
||||
err = ioutil.WriteFile(localpath, res, 0644)
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed write source code to file %s", localpath)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddImport checks for any duplicate imports by name and adds it if not.
|
||||
func (doc *GoDocument) AddImport(impt GoImport) error {
|
||||
impt.Name = strings.Trim(impt.Name, "\"")
|
||||
|
||||
// Get a list of current imports for the document.
|
||||
impts, err := doc.Imports()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If the document has as the import, don't add it.
|
||||
if impts.Has(impt.Name) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Loop through all the document root objects for an object of type import.
|
||||
// If one exists, append the import to the existing list.
|
||||
for _, obj := range doc.Objects().List() {
|
||||
if obj.Type != GoObjectType_Import || len(obj.Lines()) == 1 {
|
||||
continue
|
||||
}
|
||||
obj.subLines = append(obj.subLines, impt.String())
|
||||
obj.goObjects.list = append(obj.goObjects.list, impt.Object())
|
||||
|
||||
doc.imports = append(doc.imports, impt)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Document does not have an existing import object, so need to create one and
|
||||
// then append to the document.
|
||||
newObj := NewObjectImports(impt)
|
||||
|
||||
// Insert after any package, any existing comments or line breaks should be included.
|
||||
var insertPos int
|
||||
for idx, obj := range doc.Objects().List() {
|
||||
switch obj.Type {
|
||||
case GoObjectType_Package, GoObjectType_Comment, GoObjectType_LineBreak:
|
||||
insertPos = idx
|
||||
default:
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Insert the new import object.
|
||||
err = doc.InsertObject(insertPos, newObj)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewObjectImports returns a new GoObject with a single import definition.
|
||||
func NewObjectImports(impt GoImport) *GoObject {
|
||||
lines := []string{
|
||||
"import (",
|
||||
impt.String(),
|
||||
")",
|
||||
"",
|
||||
}
|
||||
|
||||
obj, _ := ParseGoObject(lines, 0)
|
||||
children, err := ParseLines(obj.subLines, 1)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
for _, child := range children.List() {
|
||||
obj.Objects().Add(child)
|
||||
}
|
||||
|
||||
return obj
|
||||
}
|
||||
|
||||
// Has checks to see if an import exists by name or alias.
|
||||
func (impts GoImports) Has(name string) bool {
|
||||
for _, impt := range impts {
|
||||
if name == impt.Name || (impt.Alias != "" && name == impt.Alias) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Line formats an import as a string.
|
||||
func (impt GoImport) String() string {
|
||||
var imptLine string
|
||||
if impt.Alias != "" {
|
||||
imptLine = fmt.Sprintf("\t%s \"%s\"", impt.Alias, impt.Name)
|
||||
} else {
|
||||
imptLine = fmt.Sprintf("\t\"%s\"", impt.Name)
|
||||
}
|
||||
return imptLine
|
||||
}
|
||||
|
||||
// Object returns the first GoObject for an import.
|
||||
func (impt GoImport) Object() *GoObject {
|
||||
imptObj := NewObjectImports(impt)
|
||||
|
||||
return imptObj.Objects().List()[0]
|
||||
}
|
||||
|
||||
// ParseImportObject extracts all the import definitions.
|
||||
func ParseImportObject(obj *GoObject) (resp GoImports, err error) {
|
||||
if obj.Type != GoObjectType_Import {
|
||||
return resp, errors.Errorf("Invalid type %s", string(obj.Type))
|
||||
}
|
||||
|
||||
for _, l := range obj.Lines() {
|
||||
if !strings.Contains(l, "\"") {
|
||||
continue
|
||||
}
|
||||
l = strings.TrimSpace(l)
|
||||
|
||||
pts := strings.Split(l, "\"")
|
||||
|
||||
var impt GoImport
|
||||
if strings.HasPrefix(l, "\"") {
|
||||
impt.Name = pts[1]
|
||||
} else {
|
||||
impt.Alias = strings.TrimSpace(pts[0])
|
||||
impt.Name = pts[1]
|
||||
}
|
||||
|
||||
resp = append(resp, impt)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
458
example-project/tools/truss/internal/goparse/doc_object.go
Normal file
458
example-project/tools/truss/internal/goparse/doc_object.go
Normal file
@ -0,0 +1,458 @@
|
||||
package goparse
|
||||
|
||||
import (
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"github.com/fatih/structtag"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// GoEmptyLine defined a GoObject for a code line break.
|
||||
var GoEmptyLine = GoObject{
|
||||
Type: GoObjectType_LineBreak,
|
||||
goObjects: &GoObjects{
|
||||
list: []*GoObject{},
|
||||
},
|
||||
}
|
||||
|
||||
// GoObjectType defines a set of possible types to group
|
||||
// parsed code by.
|
||||
type GoObjectType = string
|
||||
|
||||
var (
|
||||
GoObjectType_Package = "package"
|
||||
GoObjectType_Import = "import"
|
||||
GoObjectType_Var = "var"
|
||||
GoObjectType_Const = "const"
|
||||
GoObjectType_Func = "func"
|
||||
GoObjectType_Struct = "struct"
|
||||
GoObjectType_Comment = "comment"
|
||||
GoObjectType_LineBreak = "linebreak"
|
||||
GoObjectType_Line = "line"
|
||||
GoObjectType_Type = "type"
|
||||
)
|
||||
|
||||
// GoObject defines a section of code with a nested set of children.
|
||||
type GoObject struct {
|
||||
Type GoObjectType
|
||||
Name string
|
||||
startLines []string
|
||||
endLines []string
|
||||
subLines []string
|
||||
goObjects *GoObjects
|
||||
Index int
|
||||
}
|
||||
|
||||
// GoObjects stores a list of GoObject.
|
||||
type GoObjects struct {
|
||||
list []*GoObject
|
||||
}
|
||||
|
||||
// Objects returns the list of *GoObject.
|
||||
func (obj *GoObject) Objects() *GoObjects {
|
||||
if obj.goObjects == nil {
|
||||
obj.goObjects = &GoObjects{
|
||||
list: []*GoObject{},
|
||||
}
|
||||
}
|
||||
return obj.goObjects
|
||||
}
|
||||
|
||||
// Clone performs a deep copy of the struct.
|
||||
func (obj *GoObject) Clone() *GoObject {
|
||||
n := &GoObject{
|
||||
Type: obj.Type,
|
||||
Name: obj.Name,
|
||||
startLines: obj.startLines,
|
||||
endLines: obj.endLines,
|
||||
subLines: obj.subLines,
|
||||
goObjects: &GoObjects{
|
||||
list: []*GoObject{},
|
||||
},
|
||||
Index: obj.Index,
|
||||
}
|
||||
for _, sub := range obj.Objects().List() {
|
||||
n.Objects().Add(sub.Clone())
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// IsComment returns whether an object is of type GoObjectType_Comment.
|
||||
func (obj *GoObject) IsComment() bool {
|
||||
if obj.Type != GoObjectType_Comment {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Contains searches all the lines for the object for a matching string.
|
||||
func (obj *GoObject) Contains(match string) bool {
|
||||
for _, l := range obj.Lines() {
|
||||
if strings.Contains(l, match) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// UpdateLines parses the new code and replaces the current GoObject.
|
||||
func (obj *GoObject) UpdateLines(newLines []string) error {
|
||||
|
||||
// Parse the new lines.
|
||||
objs, err := ParseLines(newLines, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var newObj *GoObject
|
||||
for _, obj := range objs.List() {
|
||||
if obj.Type == GoObjectType_LineBreak {
|
||||
continue
|
||||
}
|
||||
|
||||
if newObj == nil {
|
||||
newObj = obj
|
||||
}
|
||||
|
||||
// There should only be one resulting parsed object that is
|
||||
// not of type GoObjectType_LineBreak.
|
||||
return errors.New("Can only update single blocks of code")
|
||||
}
|
||||
|
||||
// No new code was parsed, return error.
|
||||
if newObj == nil {
|
||||
return errors.New("Failed to render replacement code")
|
||||
}
|
||||
|
||||
return obj.Update(newObj)
|
||||
}
|
||||
|
||||
// Update performs a deep copy that overwrites the existing values.
|
||||
func (obj *GoObject) Update(newObj *GoObject) error {
|
||||
obj.Type = newObj.Type
|
||||
obj.Name = newObj.Name
|
||||
obj.startLines = newObj.startLines
|
||||
obj.endLines = newObj.endLines
|
||||
obj.subLines = newObj.subLines
|
||||
obj.goObjects = newObj.goObjects
|
||||
return nil
|
||||
}
|
||||
|
||||
// Lines returns a list of strings for current object and all children.
|
||||
func (obj *GoObject) Lines() []string {
|
||||
l := []string{}
|
||||
|
||||
// First include any lines before the sub objects.
|
||||
for _, sl := range obj.startLines {
|
||||
l = append(l, sl)
|
||||
}
|
||||
|
||||
// If there are parsed sub objects include those lines else when
|
||||
// no sub objects, just use the sub lines.
|
||||
if len(obj.Objects().List()) > 0 {
|
||||
for _, sl := range obj.Objects().Lines() {
|
||||
l = append(l, sl)
|
||||
}
|
||||
} else {
|
||||
for _, sl := range obj.subLines {
|
||||
l = append(l, sl)
|
||||
}
|
||||
}
|
||||
|
||||
// Lastly include any other lines that are after all parsed sub objects.
|
||||
for _, sl := range obj.endLines {
|
||||
l = append(l, sl)
|
||||
}
|
||||
|
||||
return l
|
||||
}
|
||||
|
||||
// String returns the lines separated by line break.
|
||||
func (obj *GoObject) String() string {
|
||||
return strings.Join(obj.Lines(), "\n")
|
||||
}
|
||||
|
||||
// Lines returns a list of strings for all the list objects.
|
||||
func (objs *GoObjects) Lines() []string {
|
||||
l := []string{}
|
||||
for _, obj := range objs.List() {
|
||||
for _, oj := range obj.Lines() {
|
||||
l = append(l, oj)
|
||||
}
|
||||
}
|
||||
return l
|
||||
}
|
||||
|
||||
// String returns all the lines for the list objects.
|
||||
func (objs *GoObjects) String() string {
|
||||
lines := []string{}
|
||||
for _, obj := range objs.List() {
|
||||
lines = append(lines, obj.String())
|
||||
}
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
// List returns the list of GoObjects.
|
||||
func (objs *GoObjects) List() []*GoObject {
|
||||
return objs.list
|
||||
}
|
||||
|
||||
// HasFunc searches the current list of objects for a function object by name.
|
||||
func (objs *GoObjects) HasFunc(name string) bool {
|
||||
return objs.HasType(name, GoObjectType_Func)
|
||||
}
|
||||
|
||||
// Get returns the GoObject for the matching name and type.
|
||||
func (objs *GoObjects) Get(name string, objType GoObjectType) *GoObject {
|
||||
for _, obj := range objs.list {
|
||||
if obj.Name == name && (objType == "" || obj.Type == objType) {
|
||||
return obj
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// HasType checks is a GoObject exists for the matching name and type.
|
||||
func (objs *GoObjects) HasType(name string, objType GoObjectType) bool {
|
||||
for _, obj := range objs.list {
|
||||
if obj.Name == name && (objType == "" || obj.Type == objType) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// HasObject checks to see if the exact code block exists.
|
||||
func (objs *GoObjects) HasObject(src *GoObject) bool {
|
||||
if src == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Generate the code for the supplied object.
|
||||
srcLines := []string{}
|
||||
for _, l := range src.Lines() {
|
||||
// Exclude empty lines.
|
||||
l = strings.TrimSpace(l)
|
||||
if l != "" {
|
||||
srcLines = append(srcLines, l)
|
||||
}
|
||||
}
|
||||
srcStr := strings.Join(srcLines, "\n")
|
||||
|
||||
// Loop over all the objects and match with src code.
|
||||
for _, obj := range objs.list {
|
||||
objLines := []string{}
|
||||
for _, l := range obj.Lines() {
|
||||
// Exclude empty lines.
|
||||
l = strings.TrimSpace(l)
|
||||
if l != "" {
|
||||
objLines = append(objLines, l)
|
||||
}
|
||||
}
|
||||
objStr := strings.Join(objLines, "\n")
|
||||
|
||||
// Return true if the current object code matches src code.
|
||||
if srcStr == objStr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Add appends a new GoObject to the list.
|
||||
func (objs *GoObjects) Add(newObj *GoObject) error {
|
||||
newObj.Index = len(objs.list)
|
||||
objs.list = append(objs.list, newObj)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Insert appends a new GoObject at the desired position to the list.
|
||||
func (objs *GoObjects) Insert(pos int, newObj *GoObject) error {
|
||||
newList := []*GoObject{}
|
||||
|
||||
var newIdx int
|
||||
for _, obj := range objs.list {
|
||||
if obj.Index < pos {
|
||||
obj.Index = newIdx
|
||||
newList = append(newList, obj)
|
||||
} else {
|
||||
if obj.Index == pos {
|
||||
newObj.Index = newIdx
|
||||
newList = append(newList, newObj)
|
||||
newIdx++
|
||||
}
|
||||
obj.Index = newIdx
|
||||
newList = append(newList, obj)
|
||||
}
|
||||
|
||||
newIdx++
|
||||
}
|
||||
|
||||
objs.list = newList
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Remove deletes a GoObject from the list.
|
||||
func (objs *GoObjects) Remove(delObjs ...*GoObject) error {
|
||||
for _, delObj := range delObjs {
|
||||
oldList := objs.List()
|
||||
objs.list = []*GoObject{}
|
||||
|
||||
var newIdx int
|
||||
for _, obj := range oldList {
|
||||
if obj.Index == delObj.Index {
|
||||
continue
|
||||
}
|
||||
obj.Index = newIdx
|
||||
objs.list = append(objs.list, obj)
|
||||
newIdx++
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Replace updates an existing GoObject while maintaining is same position.
|
||||
func (objs *GoObjects) Replace(oldObj *GoObject, newObjs ...*GoObject) error {
|
||||
if oldObj.Index >= len(objs.list) {
|
||||
return errors.WithStack(errGoObjectNotExist)
|
||||
} else if len(newObjs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
oldList := objs.List()
|
||||
objs.list = []*GoObject{}
|
||||
|
||||
var newIdx int
|
||||
for _, obj := range oldList {
|
||||
if obj.Index < oldObj.Index {
|
||||
obj.Index = newIdx
|
||||
objs.list = append(objs.list, obj)
|
||||
newIdx++
|
||||
} else if obj.Index == oldObj.Index {
|
||||
for _, newObj := range newObjs {
|
||||
newObj.Index = newIdx
|
||||
objs.list = append(objs.list, newObj)
|
||||
newIdx++
|
||||
}
|
||||
} else {
|
||||
obj.Index = newIdx
|
||||
objs.list = append(objs.list, obj)
|
||||
newIdx++
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReplaceFuncByName finds an existing GoObject with type GoObjectType_Func by name
|
||||
// and then performs a replace with the supplied new GoObject.
|
||||
func (objs *GoObjects) ReplaceFuncByName(name string, fn *GoObject) error {
|
||||
return objs.ReplaceTypeByName(name, fn, GoObjectType_Func)
|
||||
}
|
||||
|
||||
// ReplaceTypeByName finds an existing GoObject with type by name
|
||||
// and then performs a replace with the supplied new GoObject.
|
||||
func (objs *GoObjects) ReplaceTypeByName(name string, newObj *GoObject, objType GoObjectType) error {
|
||||
if newObj.Name == "" {
|
||||
newObj.Name = name
|
||||
}
|
||||
if newObj.Type == "" && objType != "" {
|
||||
newObj.Type = objType
|
||||
}
|
||||
|
||||
for _, obj := range objs.list {
|
||||
if obj.Name == name && (objType == "" || objType == obj.Type) {
|
||||
return objs.Replace(obj, newObj)
|
||||
}
|
||||
}
|
||||
return errors.WithStack(errGoObjectNotExist)
|
||||
}
|
||||
|
||||
// Empty determines if all the GoObject in the list are line breaks.
|
||||
func (objs *GoObjects) Empty() bool {
|
||||
var hasStuff bool
|
||||
for _, obj := range objs.List() {
|
||||
switch obj.Type {
|
||||
case GoObjectType_LineBreak:
|
||||
//case GoObjectType_Comment:
|
||||
//case GoObjectType_Import:
|
||||
// do nothing
|
||||
default:
|
||||
hasStuff = true
|
||||
}
|
||||
}
|
||||
return hasStuff
|
||||
}
|
||||
|
||||
// Debug prints out the GoObject to logger.
|
||||
func (obj *GoObject) Debug(log *log.Logger) {
|
||||
log.Println(obj.Name)
|
||||
log.Println(" > type:", obj.Type)
|
||||
log.Println(" > start lines:")
|
||||
for _, l := range obj.startLines {
|
||||
log.Println(" ", l)
|
||||
}
|
||||
|
||||
log.Println(" > sub lines:")
|
||||
for _, l := range obj.subLines {
|
||||
log.Println(" ", l)
|
||||
}
|
||||
|
||||
log.Println(" > end lines:")
|
||||
for _, l := range obj.endLines {
|
||||
log.Println(" ", l)
|
||||
}
|
||||
}
|
||||
|
||||
// Defines a property of a struct.
|
||||
type structProp struct {
|
||||
Name string
|
||||
Type string
|
||||
Tags *structtag.Tags
|
||||
}
|
||||
|
||||
// ParseStructProp extracts the details for a struct property.
|
||||
func ParseStructProp(obj *GoObject) (structProp, error) {
|
||||
|
||||
if obj.Type != GoObjectType_Line {
|
||||
return structProp{}, errors.Errorf("Unable to parse object of type %s", obj.Type)
|
||||
}
|
||||
|
||||
// Remove any white space from the code line.
|
||||
ls := strings.TrimSpace(strings.Join(obj.Lines(), " "))
|
||||
|
||||
// Extract the property name and type for the line.
|
||||
// ie: ID string `json:"id"`
|
||||
var resp structProp
|
||||
for _, p := range strings.Split(ls, " ") {
|
||||
p = strings.TrimSpace(p)
|
||||
if p == "" {
|
||||
continue
|
||||
}
|
||||
if resp.Name == "" {
|
||||
resp.Name = p
|
||||
} else if resp.Type == "" {
|
||||
resp.Type = p
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// If the line contains tags, extract and parse them.
|
||||
if strings.Contains(ls, "`") {
|
||||
tagStr := strings.Split(ls, "`")[1]
|
||||
|
||||
var err error
|
||||
resp.Tags, err = structtag.Parse(tagStr)
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to parse struct tag for field %s: %s", resp.Name, tagStr)
|
||||
return structProp{}, err
|
||||
}
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
329
example-project/tools/truss/internal/goparse/goparse.go
Normal file
329
example-project/tools/truss/internal/goparse/goparse.go
Normal file
@ -0,0 +1,329 @@
|
||||
package goparse
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"go/format"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
var (
|
||||
errGoParseType = errors.New("Unable to determine type for line")
|
||||
errGoTypeMissingCodeTemplate = errors.New("No code defined for type")
|
||||
errGoObjectNotExist = errors.New("GoObject does not exist")
|
||||
)
|
||||
|
||||
// ParseFile reads a go code file and parses into a easily transformable set of objects.
|
||||
func ParseFile(log *log.Logger, localPath string) (*GoDocument, error) {
|
||||
|
||||
// Read the code file.
|
||||
src, err := ioutil.ReadFile(localPath)
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to read file %s", localPath)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Format the code file source to ensure parse works.
|
||||
dat, err := format.Source(src)
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to format source for file %s", localPath)
|
||||
log.Printf("ParseFile : %v\n%v", err, string(src))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Loop of the formatted source code and generate a list of code lines.
|
||||
lines := []string{}
|
||||
r := bytes.NewReader(dat)
|
||||
scanner := bufio.NewScanner(r)
|
||||
for scanner.Scan() {
|
||||
lines = append(lines, scanner.Text())
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
err = errors.WithMessagef(err, "Failed read formatted source code for file %s", localPath)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse the code lines into a set of objects.
|
||||
objs, err := ParseLines(lines, 0)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Append the resulting objects to the document.
|
||||
doc := &GoDocument{}
|
||||
for _, obj := range objs.List() {
|
||||
if obj.Type == GoObjectType_Package {
|
||||
doc.Package = obj.Name
|
||||
}
|
||||
doc.AddObject(obj)
|
||||
}
|
||||
|
||||
return doc, nil
|
||||
}
|
||||
|
||||
// ParseLines takes the list of formatted code lines and returns the GoObjects.
|
||||
func ParseLines(lines []string, depth int) (objs *GoObjects, err error) {
|
||||
objs = &GoObjects{
|
||||
list: []*GoObject{},
|
||||
}
|
||||
|
||||
var (
|
||||
multiLine bool
|
||||
multiComment bool
|
||||
muiliVar bool
|
||||
)
|
||||
curDepth := -1
|
||||
objLines := []string{}
|
||||
|
||||
for idx, l := range lines {
|
||||
ls := strings.TrimSpace(l)
|
||||
|
||||
ld := lineDepth(l)
|
||||
|
||||
if ld == depth {
|
||||
if strings.HasPrefix(ls, "/*") {
|
||||
multiLine = true
|
||||
multiComment = true
|
||||
} else if strings.HasSuffix(ls, "(") ||
|
||||
strings.HasSuffix(ls, "{") {
|
||||
|
||||
if !multiLine {
|
||||
multiLine = true
|
||||
}
|
||||
} else if strings.Contains(ls, "`") {
|
||||
if !multiLine && strings.Count(ls, "`")%2 != 0 {
|
||||
if muiliVar {
|
||||
muiliVar = false
|
||||
} else {
|
||||
muiliVar = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
objLines = append(objLines, l)
|
||||
|
||||
if multiComment {
|
||||
if strings.HasSuffix(ls, "*/") {
|
||||
multiComment = false
|
||||
multiLine = false
|
||||
}
|
||||
} else {
|
||||
if strings.HasPrefix(ls, ")") ||
|
||||
strings.HasPrefix(ls, "}") {
|
||||
multiLine = false
|
||||
}
|
||||
}
|
||||
|
||||
if !multiLine && !muiliVar {
|
||||
for eidx := idx + 1; eidx < len(lines); eidx++ {
|
||||
if el := lines[eidx]; strings.TrimSpace(el) == "" {
|
||||
objLines = append(objLines, el)
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
obj, err := ParseGoObject(objLines, depth)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return objs, err
|
||||
}
|
||||
err = objs.Add(obj)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return objs, err
|
||||
}
|
||||
|
||||
objLines = []string{}
|
||||
}
|
||||
|
||||
} else if (multiLine && ld >= curDepth && ld >= depth && len(objLines) > 0) || muiliVar {
|
||||
objLines = append(objLines, l)
|
||||
|
||||
if strings.Contains(ls, "`") {
|
||||
if !multiLine && strings.Count(ls, "`")%2 != 0 {
|
||||
if muiliVar {
|
||||
muiliVar = false
|
||||
} else {
|
||||
muiliVar = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, obj := range objs.List() {
|
||||
children, err := ParseLines(obj.subLines, depth+1)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return objs, err
|
||||
}
|
||||
for _, child := range children.List() {
|
||||
obj.Objects().Add(child)
|
||||
}
|
||||
}
|
||||
|
||||
return objs, nil
|
||||
}
|
||||
|
||||
// ParseGoObject generates a GoObjected for the given code lines.
|
||||
func ParseGoObject(lines []string, depth int) (obj *GoObject, err error) {
|
||||
|
||||
// If there are no lines, return a line break.
|
||||
if len(lines) == 0 {
|
||||
return &GoEmptyLine, nil
|
||||
}
|
||||
|
||||
firstLine := lines[0]
|
||||
firstStrip := strings.TrimSpace(firstLine)
|
||||
|
||||
if len(firstStrip) == 0 {
|
||||
return &GoEmptyLine, nil
|
||||
}
|
||||
|
||||
obj = &GoObject{
|
||||
goObjects: &GoObjects{
|
||||
list: []*GoObject{},
|
||||
},
|
||||
}
|
||||
|
||||
if strings.HasPrefix(firstStrip, "var") {
|
||||
obj.Type = GoObjectType_Var
|
||||
} else if strings.HasPrefix(firstStrip, "const") {
|
||||
obj.Type = GoObjectType_Const
|
||||
} else if strings.HasPrefix(firstStrip, "func") {
|
||||
obj.Type = GoObjectType_Func
|
||||
|
||||
if strings.HasPrefix(firstStrip, "func (") {
|
||||
funcLine := strings.TrimLeft(strings.TrimSpace(strings.TrimLeft(firstStrip, "func ")), "(")
|
||||
|
||||
var structName string
|
||||
pts := strings.Split(strings.Split(funcLine, ")")[0], " ")
|
||||
for i := len(pts) - 1; i >= 0; i-- {
|
||||
ptVal := strings.TrimSpace(pts[i])
|
||||
if ptVal != "" {
|
||||
structName = ptVal
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
var funcName string
|
||||
pts = strings.Split(strings.Split(funcLine, "(")[0], " ")
|
||||
for i := len(pts) - 1; i >= 0; i-- {
|
||||
ptVal := strings.TrimSpace(pts[i])
|
||||
if ptVal != "" {
|
||||
funcName = ptVal
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
obj.Name = fmt.Sprintf("%s.%s", structName, funcName)
|
||||
} else {
|
||||
obj.Name = strings.TrimLeft(firstStrip, "func ")
|
||||
obj.Name = strings.Split(obj.Name, "(")[0]
|
||||
}
|
||||
} else if strings.HasSuffix(firstStrip, "struct {") || strings.HasSuffix(firstStrip, "struct{") {
|
||||
obj.Type = GoObjectType_Struct
|
||||
|
||||
if strings.HasPrefix(firstStrip, "type ") {
|
||||
firstStrip = strings.TrimSpace(strings.Replace(firstStrip, "type ", "", 1))
|
||||
}
|
||||
obj.Name = strings.Split(firstStrip, " ")[0]
|
||||
} else if strings.HasPrefix(firstStrip, "type") {
|
||||
obj.Type = GoObjectType_Type
|
||||
firstStrip = strings.TrimSpace(strings.Replace(firstStrip, "type ", "", 1))
|
||||
obj.Name = strings.Split(firstStrip, " ")[0]
|
||||
} else if strings.HasPrefix(firstStrip, "package") {
|
||||
obj.Name = strings.TrimSpace(strings.TrimLeft(firstStrip, "package "))
|
||||
|
||||
obj.Type = GoObjectType_Package
|
||||
} else if strings.HasPrefix(firstStrip, "import") {
|
||||
obj.Type = GoObjectType_Import
|
||||
} else if strings.HasPrefix(firstStrip, "//") {
|
||||
obj.Type = GoObjectType_Comment
|
||||
} else if strings.HasPrefix(firstStrip, "/*") {
|
||||
obj.Type = GoObjectType_Comment
|
||||
} else {
|
||||
if depth > 0 {
|
||||
obj.Type = GoObjectType_Line
|
||||
} else {
|
||||
err = errors.WithStack(errGoParseType)
|
||||
return obj, err
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
hasSub bool
|
||||
muiliVarStart bool
|
||||
muiliVarSub bool
|
||||
muiliVarEnd bool
|
||||
)
|
||||
for _, l := range lines {
|
||||
ld := lineDepth(l)
|
||||
if (ld == depth && !muiliVarSub) || muiliVarStart || muiliVarEnd {
|
||||
if hasSub && !muiliVarStart {
|
||||
if strings.TrimSpace(l) != "" {
|
||||
obj.endLines = append(obj.endLines, l)
|
||||
}
|
||||
|
||||
if strings.Count(l, "`")%2 != 0 {
|
||||
if muiliVarEnd {
|
||||
muiliVarEnd = false
|
||||
} else {
|
||||
muiliVarEnd = true
|
||||
}
|
||||
}
|
||||
} else {
|
||||
obj.startLines = append(obj.startLines, l)
|
||||
|
||||
if strings.Count(l, "`")%2 != 0 {
|
||||
if muiliVarStart {
|
||||
muiliVarStart = false
|
||||
} else {
|
||||
muiliVarStart = true
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if ld > depth || muiliVarSub {
|
||||
obj.subLines = append(obj.subLines, l)
|
||||
hasSub = true
|
||||
|
||||
if strings.Count(l, "`")%2 != 0 {
|
||||
if muiliVarSub {
|
||||
muiliVarSub = false
|
||||
} else {
|
||||
muiliVarSub = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// add trailing linebreak
|
||||
if len(obj.endLines) > 0 {
|
||||
obj.endLines = append(obj.endLines, "")
|
||||
}
|
||||
|
||||
return obj, err
|
||||
}
|
||||
|
||||
// lineDepth returns the number of spaces for the given code line
|
||||
// used to determine the code level for nesting objects.
|
||||
func lineDepth(l string) int {
|
||||
depth := len(l) - len(strings.TrimLeftFunc(l, unicode.IsSpace))
|
||||
|
||||
ls := strings.TrimSpace(l)
|
||||
if strings.HasPrefix(ls, "}") && strings.Contains(ls, " else ") {
|
||||
depth = depth + 1
|
||||
} else if strings.HasPrefix(ls, "case ") {
|
||||
depth = depth + 1
|
||||
}
|
||||
return depth
|
||||
}
|
195
example-project/tools/truss/internal/goparse/goparse_test.go
Normal file
195
example-project/tools/truss/internal/goparse/goparse_test.go
Normal file
@ -0,0 +1,195 @@
|
||||
package goparse
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var logger *log.Logger
|
||||
|
||||
func init() {
|
||||
logger = log.New(os.Stdout, "", log.LstdFlags|log.Lmicroseconds|log.Lshortfile)
|
||||
}
|
||||
|
||||
func TestParseFileModel1(t *testing.T) {
|
||||
|
||||
_, err := ParseFile(logger, "test_gofile_model1.txt")
|
||||
if err != nil {
|
||||
t.Fatalf("got error %v", err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestMultilineVar(t *testing.T) {
|
||||
g := gomega.NewGomegaWithT(t)
|
||||
|
||||
code := `func ContextAllowedAccountIds(ctx context.Context, db *gorm.DB) (resp akdatamodels.Uint32List, err error) {
|
||||
resp = []uint32{}
|
||||
accountId := akcontext.ContextAccountId(ctx)
|
||||
m := datamodels.UserAccount{}
|
||||
q := fmt.Sprintf("select
|
||||
distinct account_id
|
||||
from %s where account_id = ?", m.TableName())
|
||||
db = db.Raw(q, accountId)
|
||||
}
|
||||
`
|
||||
code = strings.Replace(code, "\"", "`", -1)
|
||||
lines := strings.Split(code, "\n")
|
||||
|
||||
objs, err := ParseLines(lines, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("got error %v", err)
|
||||
}
|
||||
|
||||
g.Expect(objs.Lines()).Should(gomega.Equal(lines))
|
||||
}
|
||||
|
||||
func TestNewDocImports(t *testing.T) {
|
||||
g := gomega.NewGomegaWithT(t)
|
||||
|
||||
expected := []string{
|
||||
"package goparse",
|
||||
"",
|
||||
"import (",
|
||||
" \"github.com/go/pkg1\"",
|
||||
" \"github.com/go/pkg2\"",
|
||||
")",
|
||||
"",
|
||||
}
|
||||
|
||||
doc := &GoDocument{}
|
||||
doc.SetPackage("goparse")
|
||||
|
||||
doc.AddImport(GoImport{Name: "github.com/go/pkg1"})
|
||||
doc.AddImport(GoImport{Name: "github.com/go/pkg2"})
|
||||
|
||||
g.Expect(doc.Lines()).Should(gomega.Equal(expected))
|
||||
}
|
||||
|
||||
func TestParseLines1(t *testing.T) {
|
||||
g := gomega.NewGomegaWithT(t)
|
||||
|
||||
code := `func testCreate(t *testing.T, ctx context.Context, sess *datamodels.Session) *datamodels.Model {
|
||||
g := gomega.NewGomegaWithT(t)
|
||||
obj := datamodels.MockModelNew()
|
||||
resp, err := ModelCreate(ctx, DB, &obj)
|
||||
if err != nil {
|
||||
t.Fatalf("got error %v", err)
|
||||
}
|
||||
|
||||
g.Expect(resp.Name).Should(gomega.Equal(obj.Name))
|
||||
g.Expect(resp.Status).Should(gomega.Equal(datamodels.{{ .Name }}Status_Active))
|
||||
return resp
|
||||
}
|
||||
`
|
||||
lines := strings.Split(code, "\n")
|
||||
|
||||
objs, err := ParseLines(lines, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("got error %v", err)
|
||||
}
|
||||
|
||||
g.Expect(objs.Lines()).Should(gomega.Equal(lines))
|
||||
}
|
||||
|
||||
func TestParseLines2(t *testing.T) {
|
||||
code := `func structToMap(s interface{}) (resp map[string]interface{}) {
|
||||
dat, _ := json.Marshal(s)
|
||||
_ = json.Unmarshal(dat, &resp)
|
||||
for k, x := range resp {
|
||||
switch v := x.(type) {
|
||||
case time.Time:
|
||||
if v.IsZero() {
|
||||
delete(resp, k)
|
||||
}
|
||||
|
||||
case *time.Time:
|
||||
if v == nil || v.IsZero() {
|
||||
delete(resp, k)
|
||||
}
|
||||
|
||||
case nil:
|
||||
delete(resp, k)
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
`
|
||||
lines := strings.Split(code, "\n")
|
||||
|
||||
objs, err := ParseLines(lines, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("got error %v", err)
|
||||
}
|
||||
|
||||
testLineTextMatches(t, objs.Lines(), lines)
|
||||
}
|
||||
|
||||
func TestParseLines3(t *testing.T) {
|
||||
g := gomega.NewGomegaWithT(t)
|
||||
|
||||
code := `type UserAccountRoleName = string
|
||||
|
||||
const (
|
||||
UserAccountRoleName_None UserAccountRoleName = ""
|
||||
UserAccountRoleName_Admin UserAccountRoleName = "admin"
|
||||
UserAccountRoleName_User UserAccountRoleName = "user"
|
||||
)
|
||||
|
||||
type UserAccountRole struct {
|
||||
Id uint32 ^gorm:"column:id;type:int(10) unsigned AUTO_INCREMENT;primary_key;not null;auto_increment;" truss:"internal:true"^
|
||||
CreatedAt time.Time ^gorm:"column:created_at;type:datetime;default:CURRENT_TIMESTAMP;not null;" truss:"internal:true"^
|
||||
UpdatedAt time.Time ^gorm:"column:updated_at;type:datetime;" truss:"internal:true"^
|
||||
DeletedAt *time.Time ^gorm:"column:deleted_at;type:datetime;" truss:"internal:true"^
|
||||
Role UserAccountRoleName ^gorm:"unique_index:user_account_role;column:role;type:enum('admin', 'user')"^
|
||||
// belongs to User
|
||||
User *User ^gorm:"foreignkey:UserId;association_foreignkey:Id;association_autoupdate:false;association_autocreate:false;association_save_reference:false;preload:false;" truss:"internal:true"^
|
||||
UserId uint32 ^gorm:"unique_index:user_account_role;"^
|
||||
// belongs to Account
|
||||
Account *Account ^gorm:"foreignkey:AccountId;association_foreignkey:Id;association_autoupdate:false;association_autocreate:false;association_save_reference:false;preload:false;" truss:"internal:true;api_ro:true;"^
|
||||
AccountId uint32 ^gorm:"unique_index:user_account_role;" truss:"internal:true;api_ro:true;"^
|
||||
}
|
||||
|
||||
func (UserAccountRole) TableName() string {
|
||||
return "user_account_roles"
|
||||
}
|
||||
`
|
||||
code = strings.Replace(code, "^", "'", -1)
|
||||
lines := strings.Split(code, "\n")
|
||||
|
||||
objs, err := ParseLines(lines, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("got error %v", err)
|
||||
}
|
||||
|
||||
g.Expect(objs.Lines()).Should(gomega.Equal(lines))
|
||||
}
|
||||
|
||||
func testLineTextMatches(t *testing.T, l1, l2 []string) {
|
||||
g := gomega.NewGomegaWithT(t)
|
||||
|
||||
m1 := []string{}
|
||||
for _, l := range l1 {
|
||||
l = strings.TrimSpace(l)
|
||||
if l != "" {
|
||||
m1 = append(m1, l)
|
||||
}
|
||||
}
|
||||
|
||||
m2 := []string{}
|
||||
for _, l := range l2 {
|
||||
l = strings.TrimSpace(l)
|
||||
if l != "" {
|
||||
m2 = append(m2, l)
|
||||
}
|
||||
}
|
||||
|
||||
g.Expect(m1).Should(gomega.Equal(m2))
|
||||
}
|
@ -0,0 +1,126 @@
|
||||
package account
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/pkg/errors"
|
||||
"gopkg.in/go-playground/validator.v9"
|
||||
)
|
||||
|
||||
// Account represents someone with access to our system.
|
||||
type Account struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Address1 string `json:"address1"`
|
||||
Address2 string `json:"address2"`
|
||||
City string `json:"city"`
|
||||
Region string `json:"region"`
|
||||
Country string `json:"country"`
|
||||
Zipcode string `json:"zipcode"`
|
||||
Status AccountStatus `json:"status"`
|
||||
Timezone string `json:"timezone"`
|
||||
SignupUserID sql.NullString `json:"signup_user_id"`
|
||||
BillingUserID sql.NullString `json:"billing_user_id"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
ArchivedAt pq.NullTime `json:"archived_at"`
|
||||
}
|
||||
|
||||
// CreateAccountRequest contains information needed to create a new Account.
|
||||
type CreateAccountRequest struct {
|
||||
Name string `json:"name" validate:"required,unique"`
|
||||
Address1 string `json:"address1" validate:"required"`
|
||||
Address2 string `json:"address2" validate:"omitempty"`
|
||||
City string `json:"city" validate:"required"`
|
||||
Region string `json:"region" validate:"required"`
|
||||
Country string `json:"country" validate:"required"`
|
||||
Zipcode string `json:"zipcode" validate:"required"`
|
||||
Status *AccountStatus `json:"status" validate:"omitempty,oneof=active pending disabled"`
|
||||
Timezone *string `json:"timezone" validate:"omitempty"`
|
||||
SignupUserID *string `json:"signup_user_id" validate:"omitempty,uuid"`
|
||||
BillingUserID *string `json:"billing_user_id" validate:"omitempty,uuid"`
|
||||
}
|
||||
|
||||
// UpdateAccountRequest defines what information may be provided to modify an existing
|
||||
// Account. All fields are optional so clients can send just the fields they want
|
||||
// changed. It uses pointer fields so we can differentiate between a field that
|
||||
// was not provided and a field that was provided as explicitly blank. Normally
|
||||
// we do not want to use pointers to basic types but we make exceptions around
|
||||
// marshalling/unmarshalling.
|
||||
type UpdateAccountRequest struct {
|
||||
ID string `validate:"required,uuid"`
|
||||
Name *string `json:"name" validate:"omitempty,unique"`
|
||||
Address1 *string `json:"address1" validate:"omitempty"`
|
||||
Address2 *string `json:"address2" validate:"omitempty"`
|
||||
City *string `json:"city" validate:"omitempty"`
|
||||
Region *string `json:"region" validate:"omitempty"`
|
||||
Country *string `json:"country" validate:"omitempty"`
|
||||
Zipcode *string `json:"zipcode" validate:"omitempty"`
|
||||
Status *AccountStatus `json:"status" validate:"omitempty,oneof=active pending disabled"`
|
||||
Timezone *string `json:"timezone" validate:"omitempty"`
|
||||
SignupUserID *string `json:"signup_user_id" validate:"omitempty,uuid"`
|
||||
BillingUserID *string `json:"billing_user_id" validate:"omitempty,uuid"`
|
||||
}
|
||||
|
||||
// AccountFindRequest defines the possible options to search for accounts. By default
|
||||
// archived accounts will be excluded from response.
|
||||
type AccountFindRequest struct {
|
||||
Where *string
|
||||
Args []interface{}
|
||||
Order []string
|
||||
Limit *uint
|
||||
Offset *uint
|
||||
IncludedArchived bool
|
||||
}
|
||||
|
||||
// AccountStatus represents the status of an account.
|
||||
type AccountStatus string
|
||||
|
||||
// AccountStatus values define the status field of a user account.
|
||||
const (
|
||||
// AccountStatus_Active defines the state when a user can access an account.
|
||||
AccountStatus_Active AccountStatus = "active"
|
||||
// AccountStatus_Pending defined the state when an account was created but
|
||||
// not activated.
|
||||
AccountStatus_Pending AccountStatus = "pending"
|
||||
// AccountStatus_Disabled defines the state when a user has been disabled from
|
||||
// accessing an account.
|
||||
AccountStatus_Disabled AccountStatus = "disabled"
|
||||
)
|
||||
|
||||
// AccountStatus_Values provides list of valid AccountStatus values.
|
||||
var AccountStatus_Values = []AccountStatus{
|
||||
AccountStatus_Active,
|
||||
AccountStatus_Pending,
|
||||
AccountStatus_Disabled,
|
||||
}
|
||||
|
||||
// Scan supports reading the AccountStatus value from the database.
|
||||
func (s *AccountStatus) Scan(value interface{}) error {
|
||||
asBytes, ok := value.([]byte)
|
||||
if !ok {
|
||||
return errors.New("Scan source is not []byte")
|
||||
}
|
||||
*s = AccountStatus(string(asBytes))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value converts the AccountStatus value to be stored in the database.
|
||||
func (s AccountStatus) Value() (driver.Value, error) {
|
||||
v := validator.New()
|
||||
|
||||
errs := v.Var(s, "required,oneof=active invited disabled")
|
||||
if errs != nil {
|
||||
return nil, errs
|
||||
}
|
||||
|
||||
return string(s), nil
|
||||
}
|
||||
|
||||
// String converts the AccountStatus value to a string.
|
||||
func (s AccountStatus) String() string {
|
||||
return string(s)
|
||||
}
|
227
example-project/tools/truss/main.go
Normal file
227
example-project/tools/truss/main.go
Normal file
@ -0,0 +1,227 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"expvar"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/example-project/tools/truss/cmd/dbtable2crud"
|
||||
"github.com/kelseyhightower/envconfig"
|
||||
"github.com/lib/pq"
|
||||
_ "github.com/lib/pq"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/urfave/cli"
|
||||
sqltrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/database/sql"
|
||||
sqlxtrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
// build is the git version of this program. It is set using build flags in the makefile.
|
||||
var build = "develop"
|
||||
|
||||
// service is the name of the program used for logging, tracing and the
|
||||
// the prefix used for loading env variables
|
||||
// ie: export TRUSS_ENV=dev
|
||||
var service = "TRUSS"
|
||||
|
||||
func main() {
|
||||
// =========================================================================
|
||||
// Logging
|
||||
|
||||
log := log.New(os.Stdout, service+" : ", log.LstdFlags|log.Lmicroseconds|log.Lshortfile)
|
||||
|
||||
// =========================================================================
|
||||
// Configuration
|
||||
var cfg struct {
|
||||
DB struct {
|
||||
Host string `default:"127.0.0.1:5433" envconfig:"HOST"`
|
||||
User string `default:"postgres" envconfig:"USER"`
|
||||
Pass string `default:"postgres" envconfig:"PASS" json:"-"` // don't print
|
||||
Database string `default:"shared" envconfig:"DATABASE"`
|
||||
Driver string `default:"postgres" envconfig:"DRIVER"`
|
||||
Timezone string `default:"utc" envconfig:"TIMEZONE"`
|
||||
DisableTLS bool `default:"false" envconfig:"DISABLE_TLS"`
|
||||
}
|
||||
}
|
||||
|
||||
// For additional details refer to https://github.com/kelseyhightower/envconfig
|
||||
if err := envconfig.Process(service, &cfg); err != nil {
|
||||
log.Fatalf("main : Parsing Config : %v", err)
|
||||
}
|
||||
|
||||
// TODO: can't use flag.Process here since it doesn't support nested arg options
|
||||
//if err := flag.Process(&cfg); err != nil {
|
||||
/// if err != flag.ErrHelp {
|
||||
// log.Fatalf("main : Parsing Command Line : %v", err)
|
||||
// }
|
||||
// return // We displayed help.
|
||||
//}
|
||||
|
||||
// =========================================================================
|
||||
// Log App Info
|
||||
|
||||
// Print the build version for our logs. Also expose it under /debug/vars.
|
||||
expvar.NewString("build").Set(build)
|
||||
log.Printf("main : Started : Application Initializing version %q", build)
|
||||
defer log.Println("main : Completed")
|
||||
|
||||
// Print the config for our logs. It's important to any credentials in the config
|
||||
// that could expose a security risk are excluded from being json encoded by
|
||||
// applying the tag `json:"-"` to the struct var.
|
||||
{
|
||||
cfgJSON, err := json.MarshalIndent(cfg, "", " ")
|
||||
if err != nil {
|
||||
log.Fatalf("main : Marshalling Config to JSON : %v", err)
|
||||
}
|
||||
log.Printf("main : Config : %v\n", string(cfgJSON))
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Start Database
|
||||
var dbUrl url.URL
|
||||
{
|
||||
// Query parameters.
|
||||
var q url.Values = make(map[string][]string)
|
||||
|
||||
// Handle SSL Mode
|
||||
if cfg.DB.DisableTLS {
|
||||
q.Set("sslmode", "disable")
|
||||
} else {
|
||||
q.Set("sslmode", "require")
|
||||
}
|
||||
|
||||
q.Set("timezone", cfg.DB.Timezone)
|
||||
|
||||
// Construct url.
|
||||
dbUrl = url.URL{
|
||||
Scheme: cfg.DB.Driver,
|
||||
User: url.UserPassword(cfg.DB.User, cfg.DB.Pass),
|
||||
Host: cfg.DB.Host,
|
||||
Path: cfg.DB.Database,
|
||||
RawQuery: q.Encode(),
|
||||
}
|
||||
}
|
||||
|
||||
// Register informs the sqlxtrace package of the driver that we will be using in our program.
|
||||
// It uses a default service name, in the below case "postgres.db". To use a custom service
|
||||
// name use RegisterWithServiceName.
|
||||
sqltrace.Register(cfg.DB.Driver, &pq.Driver{}, sqltrace.WithServiceName(service))
|
||||
masterDb, err := sqlxtrace.Open(cfg.DB.Driver, dbUrl.String())
|
||||
if err != nil {
|
||||
log.Fatalf("main : Register DB : %s : %v", cfg.DB.Driver, err)
|
||||
}
|
||||
defer masterDb.Close()
|
||||
|
||||
// =========================================================================
|
||||
// Start Truss
|
||||
|
||||
app := cli.NewApp()
|
||||
app.Commands = []cli.Command{
|
||||
{
|
||||
Name: "dbtable2crud",
|
||||
Aliases: []string{"dbtable2crud"},
|
||||
Usage: "dbtable2crud -table=projects -file=../../internal/project/models.go -model=Project",
|
||||
Flags: []cli.Flag{
|
||||
cli.StringFlag{Name: "dbtable, table"},
|
||||
cli.StringFlag{Name: "modelFile, modelfile, file"},
|
||||
cli.StringFlag{Name: "modelName, modelname, model"},
|
||||
cli.StringFlag{Name: "templateDir, templates", Value: "./templates/dbtable2crud"},
|
||||
cli.StringFlag{Name: "projectPath", Value: ""},
|
||||
},
|
||||
Action: func(c *cli.Context) error {
|
||||
dbTable := strings.TrimSpace(c.String("dbtable"))
|
||||
modelFile := strings.TrimSpace(c.String("modelFile"))
|
||||
modelName := strings.TrimSpace(c.String("modelName"))
|
||||
templateDir := strings.TrimSpace(c.String("templateDir"))
|
||||
projectPath := strings.TrimSpace(c.String("projectPath"))
|
||||
|
||||
pwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return errors.WithMessage(err, "Failed to get current working directory")
|
||||
}
|
||||
|
||||
if !path.IsAbs(templateDir) {
|
||||
templateDir = filepath.Join(pwd, templateDir)
|
||||
}
|
||||
ok, err := exists(templateDir)
|
||||
if err != nil {
|
||||
return errors.WithMessage(err, "Failed to load template directory")
|
||||
} else if !ok {
|
||||
return errors.Errorf("Template directory %s does not exist", templateDir)
|
||||
}
|
||||
|
||||
if modelFile == "" {
|
||||
return errors.Errorf("Model file path is required")
|
||||
}
|
||||
|
||||
if !path.IsAbs(modelFile) {
|
||||
modelFile = filepath.Join(pwd, modelFile)
|
||||
}
|
||||
ok, err = exists(modelFile)
|
||||
if err != nil {
|
||||
return errors.WithMessage(err, "Failed to load model file")
|
||||
} else if !ok {
|
||||
return errors.Errorf("Model file %s does not exist", modelFile)
|
||||
}
|
||||
|
||||
// Load the project path from go.mod if not set.
|
||||
if projectPath == "" {
|
||||
goModFile := filepath.Join(pwd, "../../go.mod")
|
||||
ok, err = exists(goModFile)
|
||||
if err != nil {
|
||||
return errors.WithMessage(err, "Failed to load go.mod for project")
|
||||
} else if !ok {
|
||||
return errors.Errorf("Failed to locate project go.mod at %s", goModFile)
|
||||
}
|
||||
|
||||
b, err := ioutil.ReadFile(goModFile)
|
||||
if err != nil {
|
||||
return errors.WithMessagef(err, "Failed to read go.mod at %s", goModFile)
|
||||
}
|
||||
|
||||
lines := strings.Split(string(b), "\n")
|
||||
for _, l := range lines {
|
||||
if strings.HasPrefix(l, "module ") {
|
||||
projectPath = strings.TrimSpace(strings.Split(l, " ")[1])
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if modelName == "" {
|
||||
modelName = strings.Split(filepath.Base(modelFile), ".")[0]
|
||||
modelName = strings.Replace(modelName, "_", " ", -1)
|
||||
modelName = strings.Replace(modelName, "-", " ", -1)
|
||||
modelName = strings.Title(modelName)
|
||||
modelName = strings.Replace(modelName, " ", "", -1)
|
||||
}
|
||||
|
||||
return dbtable2crud.Run(masterDb, log, cfg.DB.Database, dbTable, modelFile, modelName, templateDir, projectPath)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err = app.Run(os.Args)
|
||||
if err != nil {
|
||||
log.Fatalf("main : Truss : %+v", err)
|
||||
}
|
||||
|
||||
log.Printf("main : Truss : Completed")
|
||||
}
|
||||
|
||||
// exists returns a bool as to whether a file path exists.
|
||||
func exists(path string) (bool, error) {
|
||||
_, err := os.Stat(path)
|
||||
if err == nil {
|
||||
return true, nil
|
||||
}
|
||||
if os.IsNotExist(err) {
|
||||
return false, nil
|
||||
}
|
||||
return true, err
|
||||
}
|
4
example-project/tools/truss/sample.env
Normal file
4
example-project/tools/truss/sample.env
Normal file
@ -0,0 +1,4 @@
|
||||
export TRUSS_DB_HOST=127.0.0.1:5433
|
||||
export TRUSS_DB_USER=postgres
|
||||
export TRUSS_DB_PASS=postgres
|
||||
export TRUSS_DB_DISABLE_TLS=true
|
@ -0,0 +1,503 @@
|
||||
{{ define "imports"}}
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"{{ $.GoSrcPath }}/internal/platform/auth"
|
||||
"github.com/huandu/go-sqlbuilder"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/pborman/uuid"
|
||||
"github.com/pkg/errors"
|
||||
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
|
||||
"gopkg.in/go-playground/validator.v9"
|
||||
)
|
||||
{{ end }}
|
||||
{{ define "Globals"}}
|
||||
const (
|
||||
// The database table for {{ $.Model.Name }}
|
||||
{{ FormatCamelLower $.Model.Name }}TableName = "{{ $.Model.TableName }}"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrNotFound abstracts the postgres not found error.
|
||||
ErrNotFound = errors.New("Entity not found")
|
||||
|
||||
// ErrInvalidID occurs when an ID is not in a valid form.
|
||||
ErrInvalidID = errors.New("ID is not in its proper form")
|
||||
|
||||
// ErrForbidden occurs when a user tries to do something that is forbidden to them according to our access control policies.
|
||||
ErrForbidden = errors.New("Attempted action is not allowed")
|
||||
)
|
||||
{{ end }}
|
||||
{{ define "Helpers"}}
|
||||
|
||||
// {{ FormatCamelLower $.Model.Name }}MapColumns is the list of columns needed for mapRowsTo{{ $.Model.Name }}
|
||||
var {{ FormatCamelLower $.Model.Name }}MapColumns = "{{ JoinStrings $.Model.ColumnNames "," }}"
|
||||
|
||||
// mapRowsTo{{ $.Model.Name }} takes the SQL rows and maps it to the {{ $.Model.Name }} struct
|
||||
// with the columns defined by {{ FormatCamelLower $.Model.Name }}MapColumns
|
||||
func mapRowsTo{{ $.Model.Name }}(rows *sql.Rows) (*{{ $.Model.Name }}, error) {
|
||||
var (
|
||||
m {{ $.Model.Name }}
|
||||
err error
|
||||
)
|
||||
err = rows.Scan({{ PrefixAndJoinStrings $.Model.FieldNames "&m." "," }})
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
return &a, nil
|
||||
}
|
||||
{{ end }}
|
||||
{{ define "ACL"}}
|
||||
{{ $hasAccountId := (StringListHasValue $.Model.ColumnNames "account_id") }}
|
||||
// CanRead{{ $.Model.Name }} determines if claims has the authority to access the specified {{ FormatCamelLowerTitle $.Model.Name}} by {{ $.Model.PrimaryColumn }}.
|
||||
func CanRead{{ $.Model.Name }}(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, {{ FormatCamelLower $.Model.PrimaryField }} string) error {
|
||||
|
||||
{{ if $hasAccountId }}
|
||||
// If the request has claims from a specific {{ FormatCamelLower $.Model.Name }}, ensure that the claims
|
||||
// has the correct access to the {{ FormatCamelLower $.Model.Name }}.
|
||||
if claims.Audience != "" {
|
||||
// select {{ $.Model.PrimaryColumn }} from {{ $.Model.TableName }} where account_id = [accountID]
|
||||
query := sqlbuilder.NewSelectBuilder().Select("{{ $.Model.PrimaryColumn }}").From({{ FormatCamelLower $.Model.Name }}TableName)
|
||||
query.Where(query.And(
|
||||
query.Equal("account_id", claims.Audience),
|
||||
query.Equal("{{ $.Model.PrimaryField }}", {{ FormatCamelLower $.Model.PrimaryField }}),
|
||||
))
|
||||
queryStr, args := query.Build()
|
||||
queryStr = dbConn.Rebind(queryStr)
|
||||
|
||||
var {{ FormatCamelLower $.Model.PrimaryField }} string
|
||||
err := dbConn.QueryRowContext(ctx, queryStr, args...).Scan(&{{ FormatCamelLower $.Model.PrimaryField }})
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
err = errors.Wrapf(err, "query - %s", query.String())
|
||||
return err
|
||||
}
|
||||
|
||||
// When there is no {{ $.Model.PrimaryColumn }} returned, then the current claim user does not have access
|
||||
// to the specified {{ FormatCamelLowerTitle $.Model.Name }}.
|
||||
if {{ FormatCamelLower $.Model.PrimaryField }} == "" {
|
||||
return errors.WithStack(ErrForbidden)
|
||||
}
|
||||
}
|
||||
{{ else }}
|
||||
// TODO: Unable to auto generate sql statement, update accordingly.
|
||||
panic("Not implemented!")
|
||||
{{ end }}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CanModify{{ $.Model.Name }} determines if claims has the authority to modify the specified {{ FormatCamelLowerTitle $.Model.Name}} by {{ $.Model.PrimaryColumn }}.
|
||||
func CanModify{{ $.Model.Name }}(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, {{ FormatCamelLower $.Model.PrimaryField }} string) error {
|
||||
err = CanRead{{ $.Model.Name }}(ctx, claims, dbConn, {{ FormatCamelLower $.Model.PrimaryField }})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Admin users can update {{ FormatCamelPluralTitleLower $.Model.Name }} they have access to.
|
||||
if !claims.HasRole(auth.RoleAdmin) {
|
||||
return errors.WithStack(ErrForbidden)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyClaimsSelect applies a sub-query to the provided query to enforce ACL based on the claims provided.
|
||||
// 1. No claims, request is internal, no ACL applied
|
||||
{{ if $hasAccountId }}
|
||||
// 2. All role types can access their user ID
|
||||
{{ end }}
|
||||
func applyClaimsSelect(ctx context.Context, claims auth.Claims, query *sqlbuilder.SelectBuilder) error {
|
||||
// Claims are empty, don't apply any ACL
|
||||
if claims.Audience == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
{{ if $hasAccountId }}
|
||||
query.Where(query.Equal("account_id", claims.Audience))
|
||||
{{ end }}
|
||||
|
||||
return nil
|
||||
}
|
||||
{{ end }}
|
||||
{{ define "Find"}}
|
||||
{{ $hasArchived := (StringListHasValue $.Model.ColumnNames "archived_at") }}
|
||||
|
||||
// selectQuery constructs a base select query for {{ $.Model.Name }}
|
||||
func selectQuery() *sqlbuilder.SelectBuilder {
|
||||
query := sqlbuilder.NewSelectBuilder()
|
||||
query.Select({{ FormatCamelLower $.Model.Name }}MapColumns)
|
||||
query.From({{ FormatCamelLower $.Model.Name }}TableName)
|
||||
return query
|
||||
}
|
||||
|
||||
// findRequestQuery generates the select query for the given find request.
|
||||
// TODO: Need to figure out why can't parse the args when appending the where
|
||||
// to the query.
|
||||
func findRequestQuery(req {{ $.Model.Name }}FindRequest) (*sqlbuilder.SelectBuilder, []interface{}) {
|
||||
query := selectQuery()
|
||||
if req.Where != nil {
|
||||
query.Where(query.And(*req.Where))
|
||||
}
|
||||
if len(req.Order) > 0 {
|
||||
query.OrderBy(req.Order...)
|
||||
}
|
||||
if req.Limit != nil {
|
||||
query.Limit(int(*req.Limit))
|
||||
}
|
||||
if req.Offset != nil {
|
||||
query.Offset(int(*req.Offset))
|
||||
}
|
||||
|
||||
return query, req.Args
|
||||
}
|
||||
|
||||
// Find gets all the {{ FormatCamelPluralTitleLower $.Model.Name }} from the database based on the request params.
|
||||
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req {{ $.Model.Name }}FindRequest) ([]*{{ $.Model.Name }}, error) {
|
||||
query, args := findRequestQuery(req)
|
||||
return find(ctx, claims, dbConn, query, args{{ if $hasArchived }}, req.IncludedArchived {{ end }})
|
||||
}
|
||||
|
||||
// find internal method for getting all the {{ FormatCamelPluralTitleLower $.Model.Name }} from the database using a select query.
|
||||
func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}{{ if $hasArchived }}, includedArchived bool{{ end }}) ([]*{{ $.Model.Name }}, error) {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.{{ FormatCamelLowerUnderscore $.Model.Name }}.Find")
|
||||
defer span.Finish()
|
||||
|
||||
query.Select({{ FormatCamelLower $.Model.Name }}MapColumns)
|
||||
query.From({{ FormatCamelLower $.Model.Name }}TableName)
|
||||
|
||||
{{ if $hasArchived }}
|
||||
if !includedArchived {
|
||||
query.Where(query.IsNull("archived_at"))
|
||||
}
|
||||
{{ end }}
|
||||
|
||||
// Check to see if a sub query needs to be applied for the claims.
|
||||
err := applyClaimsSelect(ctx, claims, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
queryStr, queryArgs := query.Build()
|
||||
queryStr = dbConn.Rebind(queryStr)
|
||||
args = append(args, queryArgs...)
|
||||
|
||||
// Fetch all entries from the db.
|
||||
rows, err := dbConn.QueryContext(ctx, queryStr, args...)
|
||||
if err != nil {
|
||||
err = errors.Wrapf(err, "query - %s", query.String())
|
||||
err = errors.WithMessage(err, "find {{ FormatCamelPluralTitleLower $.Model.Name }} failed")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Iterate over each row.
|
||||
resp := []*{{ $.Model.Name }}{}
|
||||
for rows.Next() {
|
||||
u, err := mapRowsTo{{ $.Model.Name }}(rows)
|
||||
if err != nil {
|
||||
err = errors.Wrapf(err, "query - %s", query.String())
|
||||
return nil, err
|
||||
}
|
||||
resp = append(resp, u)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// Read gets the specified {{ FormatCamelLowerTitle $.Model.Name }} from the database.
|
||||
func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, {{ FormatCamelLower $.Model.PrimaryField }} {{ $.Model.PrimaryType }}{{ if $hasArchived }}, includedArchived bool{{ end }}) (*{{ $.Model.Name }}, error) {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.{{ FormatCamelLowerUnderscore $.Model.Name }}.Read")
|
||||
defer span.Finish()
|
||||
|
||||
// Filter base select query by {{ FormatCamelLower $.Model.PrimaryField }}
|
||||
query := selectQuery()
|
||||
query.Where(query.Equal("{{ $.Model.PrimaryColumn }}", {{ FormatCamelLower $.Model.PrimaryField }}))
|
||||
|
||||
res, err := find(ctx, claims, dbConn, query, []interface{}{} {{ if $hasArchived }}, includedArchived{{ end }})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if res == nil || len(res) == 0 {
|
||||
err = errors.WithMessagef(ErrNotFound, "{{ FormatCamelLowerTitle $.Model.Name }} %s not found", id)
|
||||
return nil, err
|
||||
}
|
||||
u := res[0]
|
||||
|
||||
return u, nil
|
||||
}
|
||||
{{ end }}
|
||||
{{ define "Create"}}
|
||||
{{ $hasAccountId := (StringListHasValue $.Model.ColumnNames "account_id") }}
|
||||
{{ $reqName := (Concat $.Model.Name "CreateRequest") }}
|
||||
{{ $createFields := (index $.StructFields $reqName) }}
|
||||
// Create inserts a new {{ FormatCamelLowerTitle $.Model.Name }} into the database.
|
||||
func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req {{ $reqName }}, now time.Time) (*{{ $.Model.Name }}, error) {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.{{ FormatCamelLowerUnderscore $.Model.Name }}.Create")
|
||||
defer span.Finish()
|
||||
|
||||
if claims.Audience != "" {
|
||||
// Admin users can update {{ FormatCamelPluralTitleLower $.Model.Name }} they have access to.
|
||||
if !claims.HasRole(auth.RoleAdmin) {
|
||||
return errors.WithStack(ErrForbidden)
|
||||
}
|
||||
|
||||
{{ if $hasAccountId }}
|
||||
if req.AccountId != "" {
|
||||
// Request accountId must match claims.
|
||||
if req.AccountId != claims.Audience {
|
||||
return errors.WithStack(ErrForbidden)
|
||||
}
|
||||
} else {
|
||||
// Set the accountId from claims.
|
||||
req.AccountId = claims.Audience
|
||||
}
|
||||
{{ end }}
|
||||
}
|
||||
|
||||
v := validator.New()
|
||||
|
||||
// Validate the request.
|
||||
err = v.Struct(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If now empty set it to the current time.
|
||||
if now.IsZero() {
|
||||
now = time.Now()
|
||||
}
|
||||
|
||||
// Always store the time as UTC.
|
||||
now = now.UTC()
|
||||
|
||||
// Postgres truncates times to milliseconds when storing. We and do the same
|
||||
// here so the value we return is consistent with what we store.
|
||||
now = now.Truncate(time.Millisecond)
|
||||
|
||||
m := {{ $.Model.Name }}{
|
||||
{{ range $mk, $mf := $.Model.Fields }}{{ $cf := (index $createFields $mf.FieldName) }}
|
||||
{{ if eq $mf.FieldName $.Model.PrimaryField }}{{ $isUuid := (FieldTagHasOption $mf "validate" "uuid") }}{{ $mf.FieldName }}: {{ if $isUuid }}uuid.NewRandom().String(){{ else }}req.{{ $mf.FieldName }}{{ end }},
|
||||
{{ else if or (eq $mf.FieldName "CreatedAt") (eq $mf.FieldName "UpdatedAt") }}{{ $mf.FieldName }}: now,
|
||||
{{ else if $cf }}{{ $required := (FieldTagHasOption $cf "validate" "required") }}{{ if $required }}{{ $cf.FieldName }}: req.{{ $cf.FieldName }},{{ else if ne $cf.DefaultValue "" }}{{ $cf.FieldName }}: {{ $cf.DefaultValue }},{{ end }}
|
||||
{{ end }}{{ end }}
|
||||
}
|
||||
|
||||
{{ range $fk, $f := $createFields }}{{ $required := (FieldTagHasOption $f "validate" "required") }}{{ if not $required }}
|
||||
if req.{{ $f.FieldName }} != nil {
|
||||
{{ if eq $f.FieldType "sql.NullString" }}
|
||||
m.{{ $f.FieldName }} = sql.NullString{String: *req.{{ $f.FieldName }}, Valid: true}
|
||||
{{ else if eq $f.FieldType "*sql.NullString" }}
|
||||
m.{{ $f.FieldName }} = &sql.NullString{String: *req.{{ $f.FieldName }}, Valid: true}
|
||||
{{ else }}
|
||||
m.{{ $f.FieldName }} = *req.{{ $f.FieldName }}
|
||||
{{ end }}
|
||||
}
|
||||
{{ end }}{{ end }}
|
||||
|
||||
// Build the insert SQL statement.
|
||||
query := sqlbuilder.NewInsertBuilder()
|
||||
query.InsertInto({{ FormatCamelLower $.Model.Name }}TableName)
|
||||
query.Cols(
|
||||
{{ range $mk, $mf := $.Model.Fields }}{{ $cf := (index $createFields $mf.FieldName) }}{{ if or (eq $mf.FieldName $.Model.PrimaryField) ($cf) (eq $mf.FieldName "CreatedAt") (eq $mf.FieldName "UpdatedAt") }}"{{ $mf.ColumnName }}",
|
||||
{{ end }}{{ end }}
|
||||
)
|
||||
query.Values(
|
||||
{{ range $mk, $mf := $.Model.Fields }}{{ $cf := (index $createFields $mf.FieldName) }}{{ if or (eq $mf.FieldName $.Model.PrimaryField) ($cf) (eq $mf.FieldName "CreatedAt") (eq $mf.FieldName "UpdatedAt") }}m.{{ $mf.FieldName }},
|
||||
{{ end }}{{ end }}
|
||||
)
|
||||
|
||||
// Execute the query with the provided context.
|
||||
sql, args := query.Build()
|
||||
sql = dbConn.Rebind(sql)
|
||||
_, err = dbConn.ExecContext(ctx, sql, args...)
|
||||
if err != nil {
|
||||
err = errors.Wrapf(err, "query - %s", query.String())
|
||||
err = errors.WithMessage(err, "create {{ FormatCamelLowerTitle $.Model.Name }} failed")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &a, nil
|
||||
}
|
||||
{{ end }}
|
||||
{{ define "Update"}}
|
||||
{{ $reqName := (Concat $.Model.Name "UpdateRequest") }}
|
||||
{{ $updateFields := (index $.StructFields $reqName) }}
|
||||
// Update replaces an {{ FormatCamelLowerTitle $.Model.Name }} in the database.
|
||||
func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req {{ $reqName }}, now time.Time) error {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.{{ FormatCamelLowerUnderscore $.Model.Name }}.Update")
|
||||
defer span.Finish()
|
||||
|
||||
v := validator.New()
|
||||
|
||||
// Validate the request.
|
||||
err := v.Struct(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Ensure the claims can modify the {{ FormatCamelLowerTitle $.Model.Name }} specified in the request.
|
||||
err = CanModify{{ $.Model.Name }}(ctx, claims, dbConn, req.{{ $.Model.PrimaryField }})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If now empty set it to the current time.
|
||||
if now.IsZero() {
|
||||
now = time.Now()
|
||||
}
|
||||
|
||||
// Always store the time as UTC.
|
||||
now = now.UTC()
|
||||
|
||||
// Postgres truncates times to milliseconds when storing. We and do the same
|
||||
// here so the value we return is consistent with what we store.
|
||||
now = now.Truncate(time.Millisecond)
|
||||
|
||||
// Build the update SQL statement.
|
||||
query := sqlbuilder.NewUpdateBuilder()
|
||||
query.Update({{ FormatCamelLower $.Model.Name }}TableName)
|
||||
|
||||
var fields []string
|
||||
|
||||
{{ range $mk, $mf := $.Model.Fields }}{{ $uf := (index $updateFields $mf.FieldName) }}{{ if and ($uf.FieldName) (ne $uf.FieldName $.Model.PrimaryField) }}
|
||||
{{ $optional := (FieldTagHasOption $uf "validate" "omitempty") }}{{ $isUuid := (FieldTagHasOption $uf "validate" "uuid") }}
|
||||
if req.{{ $uf.FieldName }} != nil {
|
||||
{{ if and ($optional) ($isUuid) }}
|
||||
if *req.{{ $uf.FieldName }} != "" {
|
||||
fields = append(fields, query.Assign("{{ $uf.ColumnName }}", req.{{ $uf.FieldName }}))
|
||||
} else {
|
||||
fields = append(fields, query.Assign("{{ $uf.ColumnName }}", nil))
|
||||
}
|
||||
{{ else }}
|
||||
fields = append(fields, query.Assign("{{ $uf.ColumnName }}", req.{{ $uf.FieldName }}))
|
||||
{{ end }}
|
||||
}
|
||||
{{ end }}{{ end }}
|
||||
|
||||
// If there's nothing to update we can quit early.
|
||||
if len(fields) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
{{ $hasUpdatedAt := (StringListHasValue $.Model.ColumnNames "updated_at") }}{{ if $hasUpdatedAt }}
|
||||
// Append the updated_at field
|
||||
fields = append(fields, query.Assign("updated_at", now))
|
||||
{{ end }}
|
||||
|
||||
query.Set(fields...)
|
||||
query.Where(query.Equal("{{ $.Model.PrimaryColumn }}", req.{{ $.Model.PrimaryField }}))
|
||||
|
||||
// Execute the query with the provided context.
|
||||
sql, args := query.Build()
|
||||
sql = dbConn.Rebind(sql)
|
||||
_, err = dbConn.ExecContext(ctx, sql, args...)
|
||||
if err != nil {
|
||||
err = errors.Wrapf(err, "query - %s", query.String())
|
||||
err = errors.WithMessagef(err, "update {{ FormatCamelLowerTitle $.Model.Name }} %s failed", req.{{ $.Model.PrimaryField }})
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
{{ end }}
|
||||
{{ define "Archive"}}
|
||||
// Archive soft deleted the {{ FormatCamelLowerTitle $.Model.Name }} from the database.
|
||||
func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, {{ FormatCamelLower $.Model.PrimaryField }} {{ $.Model.PrimaryType }}, now time.Time) error {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.{{ FormatCamelLowerUnderscore $.Model.Name }}.Archive")
|
||||
defer span.Finish()
|
||||
|
||||
// Defines the struct to apply validation
|
||||
req := struct {
|
||||
{{ $.Model.PrimaryField }} {{ $.Model.PrimaryType }} `validate:"required,uuid"`
|
||||
}{
|
||||
{{ $.Model.PrimaryField }}: {{ FormatCamelLower $.Model.PrimaryField }},
|
||||
}
|
||||
|
||||
// Validate the request.
|
||||
err := validator.New().Struct(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Ensure the claims can modify the {{ FormatCamelLowerTitle $.Model.Name }} specified in the request.
|
||||
err = CanModify{{ $.Model.Name }}(ctx, claims, dbConn, req.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If now empty set it to the current time.
|
||||
if now.IsZero() {
|
||||
now = time.Now()
|
||||
}
|
||||
|
||||
// Always store the time as UTC.
|
||||
now = now.UTC()
|
||||
|
||||
// Postgres truncates times to milliseconds when storing. We and do the same
|
||||
// here so the value we return is consistent with what we store.
|
||||
now = now.Truncate(time.Millisecond)
|
||||
|
||||
// Build the update SQL statement.
|
||||
query := sqlbuilder.NewUpdateBuilder()
|
||||
query.Update({{ FormatCamelLower $.Model.Name }}TableName)
|
||||
query.Set(
|
||||
query.Assign("archived_at", now),
|
||||
)
|
||||
query.Where(query.Equal("{{ $.Model.PrimaryColumn }}", req.{{ $.Model.PrimaryField }}))
|
||||
|
||||
// Execute the query with the provided context.
|
||||
sql, args := query.Build()
|
||||
sql = dbConn.Rebind(sql)
|
||||
_, err = dbConn.ExecContext(ctx, sql, args...)
|
||||
if err != nil {
|
||||
err = errors.Wrapf(err, "query - %s", query.String())
|
||||
err = errors.WithMessagef(err, "archive {{ FormatCamelLowerTitle $.Model.Name }} %s failed", req.{{ $.Model.PrimaryField }})
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
{{ end }}
|
||||
{{ define "Delete"}}
|
||||
// Delete removes an {{ FormatCamelLowerTitle $.Model.Name }} from the database.
|
||||
func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, {{ FormatCamelLower $.Model.PrimaryField }} {{ $.Model.PrimaryType }}) error {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.{{ FormatCamelLowerUnderscore $.Model.Name }}.Delete")
|
||||
defer span.Finish()
|
||||
|
||||
// Defines the struct to apply validation
|
||||
req := struct {
|
||||
{{ $.Model.PrimaryField }} {{ $.Model.PrimaryType }} `validate:"required,uuid"`
|
||||
}{
|
||||
{{ $.Model.PrimaryField }}: {{ FormatCamelLower $.Model.PrimaryField }},
|
||||
}
|
||||
|
||||
// Validate the request.
|
||||
err := validator.New().Struct(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Ensure the claims can modify the {{ FormatCamelLowerTitle $.Model.Name }} specified in the request.
|
||||
err = CanModify{{ $.Model.Name }}(ctx, claims, dbConn, req.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Build the delete SQL statement.
|
||||
query := sqlbuilder.NewDeleteBuilder()
|
||||
query.DeleteFrom({{ FormatCamelLower $.Model.Name }}TableName)
|
||||
query.Where(query.Equal("{{ $.Model.PrimaryColumn }}", req.{{ $.Model.PrimaryField }}))
|
||||
|
||||
// Execute the query with the provided context.
|
||||
sql, args := query.Build()
|
||||
sql = dbConn.Rebind(sql)
|
||||
_, err = dbConn.ExecContext(ctx, sql, args...)
|
||||
if err != nil {
|
||||
err = errors.Wrapf(err, "query - %s", query.String())
|
||||
err = errors.WithMessagef(err, "delete {{ FormatCamelLowerTitle $.Model.Name }} %s failed", req.{{ $.Model.PrimaryField }})
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
{{ end }}
|
@ -0,0 +1,79 @@
|
||||
{{ define "CreateRequest"}}
|
||||
// {{ FormatCamel $.Model.Name }}CreateRequest contains information needed to create a new {{ FormatCamel $.Model.Name }}.
|
||||
type {{ FormatCamel $.Model.Name }}CreateRequest struct {
|
||||
{{ range $fk, $f := .Model.Fields }}{{ if and ($f.ApiCreate) (ne $f.FieldName $.Model.PrimaryField) }}{{ $optional := (FieldTagHasOption $f "validate" "omitempty") }}
|
||||
{{ $f.FieldName }} {{ if and ($optional) (not $f.FieldIsPtr) }}*{{ end }}{{ $f.FieldType }} `json:"{{ $f.ColumnName }}" {{ FieldTag $f "validate" }}`
|
||||
{{ end }}{{ end }}
|
||||
}
|
||||
{{ end }}
|
||||
{{ define "UpdateRequest"}}
|
||||
// {{ FormatCamel $.Model.Name }}UpdateRequest defines what information may be provided to modify an existing
|
||||
// {{ FormatCamel $.Model.Name }}. All fields are optional so clients can send just the fields they want
|
||||
// changed. It uses pointer fields so we can differentiate between a field that
|
||||
// was not provided and a field that was provided as explicitly blank.
|
||||
type {{ FormatCamel $.Model.Name }}UpdateRequest struct {
|
||||
{{ range $fk, $f := .Model.Fields }}{{ if $f.ApiUpdate }}
|
||||
{{ $f.FieldName }} {{ if and (ne $f.FieldName $.Model.PrimaryField) (not $f.FieldIsPtr) }}*{{ end }}{{ $f.FieldType }} `json:"{{ $f.ColumnName }}" {{ if ne $f.FieldName $.Model.PrimaryField }}{{ FieldTagReplaceOrPrepend $f "validate" "required" "omitempty" }}{{ else }}{{ FieldTagReplaceOrPrepend $f "validate" "omitempty" "required" }}{{ end }}`
|
||||
{{ end }}{{ end }}
|
||||
}
|
||||
{{ end }}
|
||||
{{ define "FindRequest"}}
|
||||
// {{ FormatCamel $.Model.Name }}FindRequest defines the possible options to search for {{ FormatCamelPluralTitleLower $.Model.Name }}. By default
|
||||
// archived {{ FormatCamelLowerTitle $.Model.Name }} will be excluded from response.
|
||||
type {{ FormatCamel $.Model.Name }}FindRequest struct {
|
||||
Where *string
|
||||
Args []interface{}
|
||||
Order []string
|
||||
Limit *uint
|
||||
Offset *uint
|
||||
{{ $hasArchived := (StringListHasValue $.Model.ColumnNames "archived_at") }}{{ if $hasArchived }}IncludedArchived bool{{ end }}
|
||||
}
|
||||
{{ end }}
|
||||
{{ define "Enums"}}
|
||||
{{ range $fk, $f := .Model.Fields }}{{ if $f.DbColumn }}{{ if $f.DbColumn.IsEnum }}
|
||||
// {{ $f.FieldType }} represents the {{ $f.ColumnName }} of {{ FormatCamelLowerTitle $.Model.Name }}.
|
||||
type {{ $f.FieldType }} string
|
||||
|
||||
// {{ $f.FieldType }} values define the {{ $f.ColumnName }} field of {{ FormatCamelLowerTitle $.Model.Name }}.
|
||||
const (
|
||||
{{ range $evk, $ev := $f.DbColumn.EnumValues }}
|
||||
// {{ $f.FieldType }}_{{ FormatCamel $ev }} defines the {{ $f.ColumnName }} of {{ $ev }} for {{ FormatCamelLowerTitle $.Model.Name }}.
|
||||
{{ $f.FieldType }}_{{ FormatCamel $ev }}{{ $f.FieldType }} = "{{ $ev }}"
|
||||
{{ end }}
|
||||
)
|
||||
|
||||
// {{ $f.FieldType }}_Values provides list of valid {{ $f.FieldType }} values.
|
||||
var {{ $f.FieldType }}_Values = []{{ $f.FieldType }}{
|
||||
{{ range $evk, $ev := $f.DbColumn.EnumValues }}
|
||||
{{ $f.FieldType }}_{{ FormatCamel $ev }},
|
||||
{{ end }}
|
||||
}
|
||||
|
||||
// Scan supports reading the {{ $f.FieldType }} value from the database.
|
||||
func (s *{{ $f.FieldType }}) Scan(value interface{}) error {
|
||||
asBytes, ok := value.([]byte)
|
||||
if !ok {
|
||||
return errors.New("Scan source is not []byte")
|
||||
}
|
||||
*s = {{ $f.FieldType }}(string(asBytes))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value converts the {{ $f.FieldType }} value to be stored in the database.
|
||||
func (s {{ $f.FieldType }}) Value() (driver.Value, error) {
|
||||
v := validator.New()
|
||||
|
||||
errs := v.Var(s, "required,oneof={{ JoinStrings $f.DbColumn.EnumValues " " }}")
|
||||
if errs != nil {
|
||||
return nil, errs
|
||||
}
|
||||
|
||||
return string(s), nil
|
||||
}
|
||||
|
||||
// String converts the {{ $f.FieldType }} value to a string.
|
||||
func (s {{ $f.FieldType }}) String() string {
|
||||
return string(s)
|
||||
}
|
||||
{{ end }}{{ end }}{{ end }}
|
||||
{{ end }}
|
Reference in New Issue
Block a user