1
0
mirror of https://github.com/raseels-repos/golang-saas-starter-kit.git synced 2025-06-15 00:15:15 +02:00

Completed truss code gen for generating model requests and crud.

This commit is contained in:
Lee Brown
2019-06-24 01:30:18 -08:00
parent efaeeb7103
commit bdbe3c587a
25 changed files with 3554 additions and 351 deletions

View File

@ -0,0 +1 @@
truss

View File

@ -0,0 +1,33 @@
# SaaS Truss
Copyright 2019, Geeks Accelerator
accelerator@geeksinthewoods.com.com
## Description
Truss provides code generation to reduce copy/pasting.
## Local Installation
### Build
```bash
go build .
```
### Usage
```bash
./truss -h
Usage of ./truss
--cmd string <dbtable2crud>
--db_host string <127.0.0.1:5433>
--db_user string <postgres>
--db_pass string <postgres>
--db_database string <shared>
--db_driver string <postgres>
--db_timezone string <utc>
--db_disabletls bool <false>
```

View File

@ -0,0 +1,149 @@
package dbtable2crud
import (
"fmt"
"strings"
"github.com/jmoiron/sqlx"
"github.com/lib/pq"
"github.com/pkg/errors"
)
type psqlColumn struct {
Table string
Column string
ColumnId int64
NotNull bool
DataTypeFull string
DataTypeName string
DataTypeLength *int
NumericPrecision *int
NumericScale *int
IsPrimaryKey bool
PrimaryKeyName *string
IsUniqueKey bool
UniqueKeyName *string
IsForeignKey bool
ForeignKeyName *string
ForeignKeyColumnId pq.Int64Array
ForeignKeyTable *string
ForeignKeyLocalColumnId pq.Int64Array
DefaultFull *string
DefaultValue *string
IsEnum bool
EnumTypeId *string
EnumValues []string
}
// descTable lists all the columns for a table.
func descTable(db *sqlx.DB, dbName, dbTable string) ([]psqlColumn, error) {
queryStr := fmt.Sprintf(`SELECT
c.relname as table,
f.attname as column,
f.attnum as columnId,
f.attnotnull as not_null,
pg_catalog.format_type(f.atttypid,f.atttypmod) AS data_type_full,
t.typname AS data_type_name,
CASE WHEN f.atttypmod >= 0 AND t.typname <> 'numeric'THEN (f.atttypmod - 4) --first 4 bytes are for storing actual length of data
END AS data_type_length,
CASE WHEN t.typname = 'numeric' THEN (((f.atttypmod - 4) >> 16) & 65535)
END AS numeric_precision,
CASE WHEN t.typname = 'numeric' THEN ((f.atttypmod - 4)& 65535 )
END AS numeric_scale,
CASE WHEN p.contype = 'p' THEN true ELSE false
END AS is_primary_key,
CASE WHEN p.contype = 'p' THEN p.conname
END AS primary_key_name,
CASE WHEN p.contype = 'u' THEN true ELSE false
END AS is_unique_key,
CASE WHEN p.contype = 'u' THEN p.conname
END AS unique_key_name,
CASE WHEN p.contype = 'f' THEN true ELSE false
END AS is_foreign_key,
CASE WHEN p.contype = 'f' THEN p.conname
END AS foreignkey_name,
CASE WHEN p.contype = 'f' THEN p.confkey
END AS foreign_key_columnid,
CASE WHEN p.contype = 'f' THEN g.relname
END AS foreign_key_table,
CASE WHEN p.contype = 'f' THEN p.conkey
END AS foreign_key_local_column_id,
CASE WHEN f.atthasdef = 't' THEN d.adsrc
END AS default_value,
CASE WHEN t.typtype = 'e' THEN true ELSE false
END AS is_enum,
CASE WHEN t.typtype = 'e' THEN t.oid
END AS enum_type_id
FROM pg_attribute f
JOIN pg_class c ON c.oid = f.attrelid
JOIN pg_type t ON t.oid = f.atttypid
LEFT JOIN pg_attrdef d ON d.adrelid = c.oid AND d.adnum = f.attnum
LEFT JOIN pg_namespace n ON n.oid = c.relnamespace
LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey)
LEFT JOIN pg_class AS g ON p.confrelid = g.oid
WHERE c.relkind = 'r'::char
AND f.attisdropped = false
AND c.relname = '%s'
AND f.attnum > 0
ORDER BY f.attnum
;`, dbTable) // AND n.nspname = '%s'
rows, err := db.Query(queryStr)
if err != nil {
err = errors.Wrapf(err, "query - %s", queryStr)
return nil, err
}
// iterate over each row
var resp []psqlColumn
for rows.Next() {
var c psqlColumn
err = rows.Scan(&c.Table, &c.Column, &c.ColumnId, &c.NotNull, &c.DataTypeFull, &c.DataTypeName, &c.DataTypeLength, &c.NumericPrecision, &c.NumericScale, &c.IsPrimaryKey, &c.PrimaryKeyName, &c.IsUniqueKey, &c.UniqueKeyName, &c.IsForeignKey, &c.ForeignKeyName, &c.ForeignKeyColumnId, &c.ForeignKeyTable, &c.ForeignKeyLocalColumnId, &c.DefaultFull, &c.IsEnum, &c.EnumTypeId)
if err != nil {
err = errors.Wrapf(err, "query - %s", queryStr)
return nil, err
}
if c.DefaultFull != nil {
defaultValue := *c.DefaultFull
// "'active'::project_status_t"
defaultValue = strings.Split(defaultValue, "::")[0]
c.DefaultValue = &defaultValue
}
resp = append(resp, c)
}
for colIdx, dbCol := range resp {
if !dbCol.IsEnum {
continue
}
queryStr := fmt.Sprintf(`SELECT e.enumlabel
FROM pg_enum AS e
WHERE e.enumtypid = '%s'
ORDER BY e.enumsortorder`, *dbCol.EnumTypeId)
rows, err := db.Query(queryStr)
if err != nil {
err = errors.Wrapf(err, "query - %s", queryStr)
return nil, err
}
for rows.Next() {
var v string
err = rows.Scan(&v)
if err != nil {
err = errors.Wrapf(err, "query - %s", queryStr)
return nil, err
}
dbCol.EnumValues = append(dbCol.EnumValues, v)
}
resp[colIdx] = dbCol
}
return resp, nil
}

View File

@ -0,0 +1,378 @@
package dbtable2crud
import (
"fmt"
"log"
"os"
"path/filepath"
"strings"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/schema"
"geeks-accelerator/oss/saas-starter-kit/example-project/tools/truss/internal/goparse"
"github.com/dustin/go-humanize/english"
"github.com/fatih/camelcase"
"github.com/jmoiron/sqlx"
"github.com/pkg/errors"
"github.com/sergi/go-diff/diffmatchpatch"
)
// Run in the main entry point for the dbtable2crud cmd.
func Run(db *sqlx.DB, log *log.Logger, dbName, dbTable, modelFile, modelName, templateDir, goSrcPath string) error {
log.SetPrefix(log.Prefix() + " : dbtable2crud")
// Ensure the schema is up to date
if err := schema.Migrate(db, log); err != nil {
return err
}
// When dbTable is empty, lower case the model name
if dbTable == "" {
dbTable = strings.Join(camelcase.Split(modelName), " ")
dbTable = english.PluralWord(2, dbTable, "")
dbTable = strings.Replace(dbTable, " ", "_", -1)
dbTable = strings.ToLower(dbTable)
}
// Parse the model file and load the specified model struct.
model, err := parseModelFile(db, log, dbName, dbTable, modelFile, modelName)
if err != nil {
return err
}
// Basic lint of the model struct.
err = validateModel(log, model)
if err != nil {
return err
}
tmplData := map[string]interface{}{
"GoSrcPath": goSrcPath,
}
// Update the model file with new or updated code.
err = updateModel(log, model, templateDir, tmplData)
if err != nil {
return err
}
// Update the model crud file with new or updated code.
err = updateModelCrud(db, log, dbName, dbTable, modelFile, modelName, templateDir, model, tmplData)
if err != nil {
return err
}
return nil
}
// validateModel performs a basic lint of the model struct to ensure
// code gen output is correct.
func validateModel(log *log.Logger, model *modelDef) error {
for _, sf := range model.Fields {
if sf.DbColumn == nil && sf.ColumnName != "-" {
log.Printf("validateStruct : Unable to find struct field for db column %s\n", sf.ColumnName)
}
var expectedType string
switch sf.FieldName {
case "ID":
expectedType = "string"
case "CreatedAt":
expectedType = "time.Time"
case "UpdatedAt":
expectedType = "time.Time"
case "ArchivedAt":
expectedType = "pq.NullTime"
}
if expectedType != "" && expectedType != sf.FieldType {
log.Printf("validateStruct : Struct field %s should be of type %s not %s\n", sf.FieldName, expectedType, sf.FieldType)
}
}
return nil
}
// updateModel updated the parsed code file with the new code.
func updateModel(log *log.Logger, model *modelDef, templateDir string, tmplData map[string]interface{}) error {
// Execute template and parse code to be used to compare against modelFile.
tmplObjs, err := loadTemplateObjects(log, model, templateDir, "models.tmpl", tmplData)
if err != nil {
return err
}
// Store the current code as a string to produce a diff.
curCode := model.String()
objHeaders := []*goparse.GoObject{}
for _, obj := range tmplObjs {
if obj.Type == goparse.GoObjectType_Comment || obj.Type == goparse.GoObjectType_LineBreak {
objHeaders = append(objHeaders, obj)
continue
}
if model.HasType(obj.Name, obj.Type) {
cur := model.Objects().Get(obj.Name, obj.Type)
newObjs := []*goparse.GoObject{}
if len(objHeaders) > 0 {
// Remove any comments and linebreaks before the existing object so updates can be added.
removeObjs := []*goparse.GoObject{}
for idx := cur.Index - 1; idx > 0; idx-- {
prevObj := model.Objects().List()[idx]
if prevObj.Type == goparse.GoObjectType_Comment || prevObj.Type == goparse.GoObjectType_LineBreak {
removeObjs = append(removeObjs, prevObj)
} else {
break
}
}
if len(removeObjs) > 0 {
err := model.Objects().Remove(removeObjs...)
if err != nil {
err = errors.WithMessagef(err, "Failed to update object %s %s for %s", obj.Type, obj.Name, model.Name)
return err
}
// Make sure the current index is correct.
cur = model.Objects().Get(obj.Name, obj.Type)
}
// Append comments and line breaks before adding the object
for _, c := range objHeaders {
newObjs = append(newObjs, c)
}
}
newObjs = append(newObjs, obj)
// Do the object replacement.
err := model.Objects().Replace(cur, newObjs...)
if err != nil {
err = errors.WithMessagef(err, "Failed to update object %s %s for %s", obj.Type, obj.Name, model.Name)
return err
}
} else {
// Append comments and line breaks before adding the object
for _, c := range objHeaders {
err := model.Objects().Add(c)
if err != nil {
err = errors.WithMessagef(err, "Failed to add object %s %s for %s", c.Type, c.Name, model.Name)
return err
}
}
err := model.Objects().Add(obj)
if err != nil {
err = errors.WithMessagef(err, "Failed to add object %s %s for %s", obj.Type, obj.Name, model.Name)
return err
}
}
objHeaders = []*goparse.GoObject{}
}
// Set some flags to determine additional imports and need to be added.
var hasEnum bool
var hasPq bool
for _, f := range model.Fields {
if f.DbColumn != nil && f.DbColumn.IsEnum {
hasEnum = true
}
if strings.HasPrefix(strings.Trim(f.FieldType, "*"), "pq.") {
hasPq = true
}
}
reqImports := []string{}
if hasEnum {
reqImports = append(reqImports, "database/sql/driver")
reqImports = append(reqImports, "gopkg.in/go-playground/validator.v9")
reqImports = append(reqImports, "github.com/pkg/errors")
}
if hasPq {
reqImports = append(reqImports, "github.com/lib/pq")
}
for _, in := range reqImports {
err := model.AddImport(goparse.GoImport{Name: in})
if err != nil {
err = errors.WithMessagef(err, "Failed to add import %s for %s", in, model.Name)
return err
}
}
// Produce a diff after the updates have been applied.
dmp := diffmatchpatch.New()
diffs := dmp.DiffMain(curCode, model.String(), true)
fmt.Println(dmp.DiffPrettyText(diffs))
return nil
}
// updateModelCrud updated the parsed code file with the new code.
func updateModelCrud(db *sqlx.DB, log *log.Logger, dbName, dbTable, modelFile, modelName, templateDir string, baseModel *modelDef, tmplData map[string]interface{}) error {
modelDir := filepath.Dir(modelFile)
crudFile := filepath.Join(modelDir, FormatCamelLowerUnderscore(baseModel.Name)+".go")
var crudDoc *goparse.GoDocument
if _, err := os.Stat(crudFile); os.IsNotExist(err) {
crudDoc, err = goparse.NewGoDocument(baseModel.Package)
if err != nil {
return err
}
} else {
// Parse the supplied model file.
crudDoc, err = goparse.ParseFile(log, modelFile)
if err != nil {
return err
}
}
// Load all the updated struct fields from the base model file.
structFields := make(map[string]map[string]modelField)
for _, obj := range baseModel.GoDocument.Objects().List() {
if obj.Type != goparse.GoObjectType_Struct || obj.Name == baseModel.Name {
continue
}
objFields, err := parseModelFields(baseModel.GoDocument, obj.Name, baseModel)
if err != nil {
return err
}
structFields[obj.Name] = make(map[string]modelField)
for _, f := range objFields {
structFields[obj.Name][f.FieldName] = f
}
}
// Append the struct fields to be used for template execution.
if tmplData == nil {
tmplData = make(map[string]interface{})
}
tmplData["StructFields"] = structFields
// Execute template and parse code to be used to compare against modelFile.
tmplObjs, err := loadTemplateObjects(log, baseModel, templateDir, "model_crud.tmpl", tmplData)
if err != nil {
return err
}
// Store the current code as a string to produce a diff.
curCode := crudDoc.String()
objHeaders := []*goparse.GoObject{}
for _, obj := range tmplObjs {
if obj.Type == goparse.GoObjectType_Comment || obj.Type == goparse.GoObjectType_LineBreak {
objHeaders = append(objHeaders, obj)
continue
}
if crudDoc.HasType(obj.Name, obj.Type) {
cur := crudDoc.Objects().Get(obj.Name, obj.Type)
newObjs := []*goparse.GoObject{}
if len(objHeaders) > 0 {
// Remove any comments and linebreaks before the existing object so updates can be added.
removeObjs := []*goparse.GoObject{}
for idx := cur.Index - 1; idx > 0; idx-- {
prevObj := crudDoc.Objects().List()[idx]
if prevObj.Type == goparse.GoObjectType_Comment || prevObj.Type == goparse.GoObjectType_LineBreak {
removeObjs = append(removeObjs, prevObj)
} else {
break
}
}
if len(removeObjs) > 0 {
err := crudDoc.Objects().Remove(removeObjs...)
if err != nil {
err = errors.WithMessagef(err, "Failed to update object %s %s for %s", obj.Type, obj.Name, baseModel.Name)
return err
}
// Make sure the current index is correct.
cur = crudDoc.Objects().Get(obj.Name, obj.Type)
}
// Append comments and line breaks before adding the object
for _, c := range objHeaders {
newObjs = append(newObjs, c)
}
}
newObjs = append(newObjs, obj)
// Do the object replacement.
err := crudDoc.Objects().Replace(cur, newObjs...)
if err != nil {
err = errors.WithMessagef(err, "Failed to update object %s %s for %s", obj.Type, obj.Name, baseModel.Name)
return err
}
} else {
// Append comments and line breaks before adding the object
for _, c := range objHeaders {
err := crudDoc.Objects().Add(c)
if err != nil {
err = errors.WithMessagef(err, "Failed to add object %s %s for %s", c.Type, c.Name, baseModel.Name)
return err
}
}
err := crudDoc.Objects().Add(obj)
if err != nil {
err = errors.WithMessagef(err, "Failed to add object %s %s for %s", obj.Type, obj.Name, baseModel.Name)
return err
}
}
objHeaders = []*goparse.GoObject{}
}
/*
// Set some flags to determine additional imports and need to be added.
var hasEnum bool
var hasPq bool
for _, f := range crudModel.Fields {
if f.DbColumn != nil && f.DbColumn.IsEnum {
hasEnum = true
}
if strings.HasPrefix(strings.Trim(f.FieldType, "*"), "pq.") {
hasPq = true
}
}
reqImports := []string{}
if hasEnum {
reqImports = append(reqImports, "database/sql/driver")
reqImports = append(reqImports, "gopkg.in/go-playground/validator.v9")
reqImports = append(reqImports, "github.com/pkg/errors")
}
if hasPq {
reqImports = append(reqImports, "github.com/lib/pq")
}
for _, in := range reqImports {
err := model.AddImport(goparse.GoImport{Name: in})
if err != nil {
err = errors.WithMessagef(err, "Failed to add import %s for %s", in, crudModel.Name)
return err
}
}
*/
// Produce a diff after the updates have been applied.
dmp := diffmatchpatch.New()
diffs := dmp.DiffMain(curCode, crudDoc.String(), true)
fmt.Println(dmp.DiffPrettyText(diffs))
return nil
}

View File

@ -0,0 +1,229 @@
package dbtable2crud
import (
"log"
"strings"
"geeks-accelerator/oss/saas-starter-kit/example-project/tools/truss/internal/goparse"
"github.com/fatih/structtag"
"github.com/jmoiron/sqlx"
"github.com/pkg/errors"
)
// modelDef defines info about the struct and associated db table.
type modelDef struct {
*goparse.GoDocument
Name string
TableName string
PrimaryField string
PrimaryColumn string
PrimaryType string
Fields []modelField
FieldNames []string
ColumnNames []string
}
// modelField defines a struct field and associated db column.
type modelField struct {
ColumnName string
DbColumn *psqlColumn
FieldName string
FieldType string
FieldIsPtr bool
Tags *structtag.Tags
ApiHide bool
ApiRead bool
ApiCreate bool
ApiUpdate bool
DefaultValue string
}
// parseModelFile parses the entire model file and then loads the specified model struct.
func parseModelFile(db *sqlx.DB, log *log.Logger, dbName, dbTable, modelFile, modelName string) (*modelDef, error) {
// Parse the supplied model file.
doc, err := goparse.ParseFile(log, modelFile)
if err != nil {
return nil, err
}
// Init new modelDef.
model := &modelDef{
GoDocument: doc,
Name: modelName,
TableName: dbTable,
}
// Append the field the the model def.
model.Fields, err = parseModelFields(doc, modelName, nil)
if err != nil {
return nil, err
}
for _, sf := range model.Fields {
model.FieldNames = append(model.FieldNames, sf.FieldName)
model.ColumnNames = append(model.ColumnNames, sf.ColumnName)
}
// Query the database for a table definition.
dbCols, err := descTable(db, dbName, dbTable)
if err != nil {
return model, err
}
// Loop over all the database table columns and append to the associated
// struct field. Don't force all database table columns to be defined in the
// in the struct.
for _, dbCol := range dbCols {
for idx, sf := range model.Fields {
if sf.ColumnName != dbCol.Column {
continue
}
if dbCol.IsPrimaryKey {
model.PrimaryColumn = sf.ColumnName
model.PrimaryField = sf.FieldName
model.PrimaryType = sf.FieldType
}
if dbCol.DefaultValue != nil {
sf.DefaultValue = *dbCol.DefaultValue
if dbCol.IsEnum {
sf.DefaultValue = strings.Trim(sf.DefaultValue, "'")
sf.DefaultValue = sf.FieldType + "_" + FormatCamel(sf.DefaultValue)
} else if strings.HasPrefix(sf.DefaultValue, "'") {
sf.DefaultValue = strings.Trim(sf.DefaultValue, "'")
sf.DefaultValue = "\"" + sf.DefaultValue + "\""
}
}
c := dbCol
sf.DbColumn = &c
model.Fields[idx] = sf
}
}
// Print out the model for debugging.
//modelJSON, err := json.MarshalIndent(model, "", " ")
//if err != nil {
// return model, errors.WithStack(err )
//}
//log.Printf(string(modelJSON))
return model, nil
}
// parseModelFields parses the fields from a struct.
func parseModelFields(doc *goparse.GoDocument, modelName string, baseModel *modelDef) ([]modelField, error) {
// Ensure the model file has a struct with the model name supplied.
if !doc.HasType(modelName, goparse.GoObjectType_Struct) {
err := errors.Errorf("Struct with the name %s does not exist", modelName)
return nil, err
}
// Load the struct from parsed go file.
docModel := doc.Get(modelName, goparse.GoObjectType_Struct)
// Loop over all the objects contained between the struct definition start and end.
// This should be a list of variables defined for model.
resp := []modelField{}
for _, l := range docModel.Objects().List() {
// Skip all lines that are not a var.
if l.Type != goparse.GoObjectType_Line {
log.Printf("parseModelFile : Model %s has line that is %s, not type line, skipping - %s\n", modelName, l.Type, l.String())
continue
}
// Extract the var name, type and defined tags from the line.
sv, err := goparse.ParseStructProp(l)
if err != nil {
return nil, err
}
// Init new modelField for the struct var.
sf := modelField{
FieldName: sv.Name,
FieldType: sv.Type,
FieldIsPtr: strings.HasPrefix(sv.Type, "*"),
Tags: sv.Tags,
}
// Extract the column name from the var tags.
if sf.Tags != nil {
// First try to get the column name from the db tag.
dbt, err := sf.Tags.Get("db")
if err != nil && !strings.Contains(err.Error(), "not exist") {
err = errors.WithStack(err)
return nil, err
} else if dbt != nil {
sf.ColumnName = dbt.Name
}
// Second try to get the column name from the json tag.
if sf.ColumnName == "" {
jt, err := sf.Tags.Get("json")
if err != nil && !strings.Contains(err.Error(), "not exist") {
err = errors.WithStack(err)
return nil, err
} else if jt != nil && jt.Name != "-" {
sf.ColumnName = jt.Name
}
}
var apiActionsSet bool
tt, err := sf.Tags.Get("truss")
if err != nil && !strings.Contains(err.Error(), "not exist") {
err = errors.WithStack(err)
return nil, err
} else if tt != nil {
if tt.Name == "api-create" || tt.HasOption("api-create") {
sf.ApiCreate = true
apiActionsSet = true
}
if tt.Name == "api-read" || tt.HasOption("api-read") {
sf.ApiRead = true
apiActionsSet = true
}
if tt.Name == "api-update" || tt.HasOption("api-update") {
sf.ApiUpdate = true
apiActionsSet = true
}
if tt.Name == "api-hide" || tt.HasOption("api-hide") {
sf.ApiHide = true
apiActionsSet = true
}
}
if !apiActionsSet {
sf.ApiCreate = true
sf.ApiRead = true
sf.ApiUpdate = true
}
}
// Set the column name to the field name if empty and does not equal '-'.
if sf.ColumnName == "" {
sf.ColumnName = sf.FieldName
}
// If a base model as already been parsed with the db columns,
// append to the current field.
if baseModel != nil {
for _, baseSf := range baseModel.Fields {
if baseSf.ColumnName == sf.ColumnName {
sf.DefaultValue = baseSf.DefaultValue
sf.DbColumn = baseSf.DbColumn
break
}
}
}
// Append the field the the model def.
resp = append(resp, sf)
}
return resp, nil
}

View File

@ -0,0 +1,345 @@
package dbtable2crud
import (
"bufio"
"bytes"
"fmt"
"go/format"
"io/ioutil"
"log"
"os"
"path/filepath"
"sort"
"strings"
"text/template"
"geeks-accelerator/oss/saas-starter-kit/example-project/tools/truss/internal/goparse"
"github.com/dustin/go-humanize/english"
"github.com/fatih/camelcase"
"github.com/iancoleman/strcase"
"github.com/pkg/errors"
)
// loadTemplateObjects executes a template file based on the given model struct and
// returns the parsed go objects.
func loadTemplateObjects(log *log.Logger, model *modelDef, templateDir, filename string, tmptData map[string]interface{}) ([]*goparse.GoObject, error) {
// Data used to execute all the of defined code sections in the template file.
if tmptData == nil {
tmptData = make(map[string]interface{})
}
tmptData["Model"] = model
// geeks-accelerator/oss/saas-starter-kit/example-project
// Read the template file from the local file system.
tempFilePath := filepath.Join(templateDir, filename)
dat, err := ioutil.ReadFile(tempFilePath)
if err != nil {
err = errors.WithMessagef(err, "Failed to read template file %s", tempFilePath)
return nil, err
}
// New template with custom functions.
baseTmpl := template.New("base")
baseTmpl.Funcs(template.FuncMap{
"Concat": func(vals ...string) string {
return strings.Join(vals, "")
},
"JoinStrings": func(vals []string, sep string) string {
return strings.Join(vals, sep)
},
"PrefixAndJoinStrings": func(vals []string, pre, sep string) string {
l := []string{}
for _, v := range vals {
l = append(l, pre+v)
}
return strings.Join(l, sep)
},
"FmtAndJoinStrings": func(vals []string, fmtStr, sep string) string {
l := []string{}
for _, v := range vals {
l = append(l, fmt.Sprintf(fmtStr, v))
}
return strings.Join(l, sep)
},
"FormatCamel": func(name string) string {
return FormatCamel(name)
},
"FormatCamelTitle": func(name string) string {
return FormatCamelTitle(name)
},
"FormatCamelLower": func(name string) string {
if name == "ID" {
return "id"
}
return FormatCamelLower(name)
},
"FormatCamelLowerTitle": func(name string) string {
return FormatCamelLowerTitle(name)
},
"FormatCamelPluralTitle": func(name string) string {
return FormatCamelPluralTitle(name)
},
"FormatCamelPluralTitleLower": func(name string) string {
return FormatCamelPluralTitleLower(name)
},
"FormatCamelPluralCamel": func(name string) string {
return FormatCamelPluralCamel(name)
},
"FormatCamelPluralLower": func(name string) string {
return FormatCamelPluralLower(name)
},
"FormatCamelPluralUnderscore": func(name string) string {
return FormatCamelPluralUnderscore(name)
},
"FormatCamelPluralLowerUnderscore": func(name string) string {
return FormatCamelPluralLowerUnderscore(name)
},
"FormatCamelUnderscore": func(name string) string {
return FormatCamelUnderscore(name)
},
"FormatCamelLowerUnderscore": func(name string) string {
return FormatCamelLowerUnderscore(name)
},
"FieldTagHasOption": func(f modelField, tagName, optName string) bool {
if f.Tags == nil {
return false
}
ft, err := f.Tags.Get(tagName)
if ft == nil || err != nil {
return false
}
if ft.Name == optName || ft.HasOption(optName) {
return true
}
return false
},
"FieldTag": func(f modelField, tagName string) string {
if f.Tags == nil {
return ""
}
ft, err := f.Tags.Get(tagName)
if ft == nil || err != nil {
return ""
}
return ft.String()
},
"FieldTagReplaceOrPrepend": func(f modelField, tagName, oldVal, newVal string) string {
if f.Tags == nil {
return ""
}
ft, err := f.Tags.Get(tagName)
if ft == nil || err != nil {
return ""
}
if ft.Name == oldVal || ft.Name == newVal {
ft.Name = newVal
} else if ft.HasOption(oldVal) {
for idx, val := range ft.Options {
if val == oldVal {
ft.Options[idx] = newVal
}
}
} else if !ft.HasOption(newVal) {
if ft.Name == "" {
ft.Name = newVal
} else {
ft.Options = append(ft.Options, newVal)
}
}
return ft.String()
},
"StringListHasValue": func(list []string, val string) bool {
for _, v := range list {
if v == val {
return true
}
}
return false
},
})
// Load the template file using the text/template package.
tmpl, err := baseTmpl.Parse(string(dat))
if err != nil {
err = errors.WithMessagef(err, "Failed to parse template file %s", tempFilePath)
log.Printf("loadTemplateObjects : %v\n%v", err, string(dat))
return nil, err
}
// Generate a list of template names defined in the template file.
tmplNames := []string{}
for _, defTmpl := range tmpl.Templates() {
tmplNames = append(tmplNames, defTmpl.Name())
}
// Stupid hack to return template names the in order they are defined in the file.
tmplNames, err = templateFileOrderedNames(tempFilePath, tmplNames)
if err != nil {
return nil, err
}
// Loop over all the defined templates, execute using the defined data, parse the
// formatted code and append the parsed go objects to the result list.
var resp []*goparse.GoObject
for _, tmplName := range tmplNames {
// Executed the defined template with the given data.
var tpl bytes.Buffer
if err := tmpl.Lookup(tmplName).Execute(&tpl, tmptData); err != nil {
err = errors.WithMessagef(err, "Failed to execute %s from template file %s", tmplName, tempFilePath)
return resp, err
}
// Format the source code to ensure its valid and code to parsed consistently.
codeBytes, err := format.Source(tpl.Bytes())
if err != nil {
err = errors.WithMessagef(err, "Failed to format source for %s in template file %s", tmplName, filename)
dl := []string{}
for idx, l := range strings.Split(tpl.String(), "\n") {
dl = append(dl, fmt.Sprintf("%d -> ", idx)+l)
}
log.Printf("loadTemplateObjects : %v\n%v", err, strings.Join(dl, "\n"))
return resp, err
}
// Remove extra white space from the code.
codeStr := strings.TrimSpace(string(codeBytes))
// Split the code into a list of strings.
codeLines := strings.Split(codeStr, "\n")
// Parse the code lines into a set of objects.
objs, err := goparse.ParseLines(codeLines, 0)
if err != nil {
err = errors.WithMessagef(err, "Failed to parse %s in template file %s", tmplName, filename)
log.Printf("loadTemplateObjects : %v\n%v", err, codeStr)
return resp, err
}
// Append the parsed objects to the return result list.
for _, obj := range objs.List() {
if obj.Name == "" && obj.Type != goparse.GoObjectType_Import && obj.Type != goparse.GoObjectType_Var && obj.Type != goparse.GoObjectType_Const && obj.Type != goparse.GoObjectType_Comment && obj.Type != goparse.GoObjectType_LineBreak {
// All objects should have a name except for multiline var/const declarations and comments.
err = errors.Errorf("Failed to parse name with type %s from lines: %v", obj.Type, obj.Lines())
return resp, err
} else if string(obj.Type) == "" {
err = errors.Errorf("Failed to parse type for %s from lines: %v", obj.Name, obj.Lines())
return resp, err
}
resp = append(resp, obj)
}
}
return resp, nil
}
// FormatCamel formats Valdez mountain to ValdezMountain
func FormatCamel(name string) string {
return strcase.ToCamel(name)
}
// FormatCamelLower formats ValdezMountain to valdezmountain
func FormatCamelLower(name string) string {
return strcase.ToLowerCamel(FormatCamel(name))
}
// FormatCamelTitle formats ValdezMountain to Valdez Mountain
func FormatCamelTitle(name string) string {
return strings.Join(camelcase.Split(name), " ")
}
// FormatCamelLowerTitle formats ValdezMountain to valdez mountain
func FormatCamelLowerTitle(name string) string {
return strings.ToLower(FormatCamelTitle(name))
}
// FormatCamelPluralTitle formats ValdezMountain to Valdez Mountains
func FormatCamelPluralTitle(name string) string {
pts := camelcase.Split(name)
lastIdx := len(pts) - 1
pts[lastIdx] = english.PluralWord(2, pts[lastIdx], "")
return strings.Join(pts, " ")
}
// FormatCamelPluralTitleLower formats ValdezMountain to valdez mountains
func FormatCamelPluralTitleLower(name string) string {
return strings.ToLower(FormatCamelPluralTitle(name))
}
// FormatCamelPluralCamel formats ValdezMountain to ValdezMountains
func FormatCamelPluralCamel(name string) string {
return strcase.ToCamel(FormatCamelPluralTitle(name))
}
// FormatCamelPluralLower formats ValdezMountain to valdezmountains
func FormatCamelPluralLower(name string) string {
return strcase.ToLowerCamel(FormatCamelPluralTitle(name))
}
// FormatCamelPluralUnderscore formats ValdezMountain to Valdez_Mountains
func FormatCamelPluralUnderscore(name string) string {
return strings.Replace(FormatCamelPluralTitle(name), " ", "_", -1)
}
// FormatCamelPluralLowerUnderscore formats ValdezMountain to valdez_mountains
func FormatCamelPluralLowerUnderscore(name string) string {
return strings.ToLower(FormatCamelPluralUnderscore(name))
}
// FormatCamelUnderscore formats ValdezMountain to Valdez_Mountain
func FormatCamelUnderscore(name string) string {
return strings.Replace(FormatCamelTitle(name), " ", "_", -1)
}
// FormatCamelLowerUnderscore formats ValdezMountain to valdez_mountain
func FormatCamelLowerUnderscore(name string) string {
return strings.ToLower(FormatCamelUnderscore(name))
}
// templateFileOrderedNames returns the template names the in order they are defined in the file.
func templateFileOrderedNames(localPath string, names []string) (resp []string, err error) {
file, err := os.Open(localPath)
if err != nil {
return resp, errors.WithStack(err)
}
defer file.Close()
idxList := []int{}
idxNames := make(map[int]string)
idx := 0
scanner := bufio.NewScanner(file)
for scanner.Scan() {
if !strings.HasPrefix(scanner.Text(), "{{") || !strings.Contains(scanner.Text(), "define ") {
continue
}
for _, name := range names {
if strings.Contains(scanner.Text(), "\""+name+"\"") {
idxList = append(idxList, idx)
idxNames[idx] = name
break
}
}
idx = idx + 1
}
if err := scanner.Err(); err != nil {
return resp, errors.WithStack(err)
}
sort.Ints(idxList)
for _, idx := range idxList {
resp = append(resp, idxNames[idx])
}
return resp, nil
}

View File

@ -0,0 +1,301 @@
package goparse
import (
"fmt"
"go/format"
"io/ioutil"
"strings"
"github.com/pkg/errors"
)
// GoDocument defines a single go code file.
type GoDocument struct {
*GoObjects
Package string
imports GoImports
}
// GoImport defines a single import line with optional alias.
type GoImport struct {
Name string
Alias string
}
// GoImports holds a list of import lines.
type GoImports []GoImport
// NewGoDocument creates a new GoDocument with the package line set.
func NewGoDocument(packageName string) (doc *GoDocument, err error) {
doc = &GoDocument{
GoObjects: &GoObjects{
list: []*GoObject{},
},
}
err = doc.SetPackage(packageName)
return doc, err
}
// Objects returns a list of root GoObject.
func (doc *GoDocument) Objects() *GoObjects {
if doc.GoObjects == nil {
doc.GoObjects = &GoObjects{
list: []*GoObject{},
}
}
return doc.GoObjects
}
// NewObjectPackage returns a new GoObject with a single package definition line.
func NewObjectPackage(packageName string) *GoObject {
lines := []string{
fmt.Sprintf("package %s", packageName),
"",
}
obj, _ := ParseGoObject(lines, 0)
return obj
}
// SetPackage appends sets the package line for the code file.
func (doc *GoDocument) SetPackage(packageName string) error {
var existing *GoObject
for _, obj := range doc.Objects().List() {
if obj.Type == GoObjectType_Package {
existing = obj
break
}
}
new := NewObjectPackage(packageName)
var err error
if existing != nil {
err = doc.Objects().Replace(existing, new)
} else if len(doc.Objects().List()) > 0 {
// Insert after any existing comments or line breaks.
var insertPos int
//for idx, obj := range doc.Objects().List() {
// switch obj.Type {
// case GoObjectType_Comment, GoObjectType_LineBreak:
// insertPos = idx
// default:
// break
// }
//}
err = doc.Objects().Insert(insertPos, new)
} else {
err = doc.Objects().Add(new)
}
return err
}
// AddObject appends a new GoObject to the doc root object list.
func (doc *GoDocument) AddObject(newObj *GoObject) error {
return doc.Objects().Add(newObj)
}
// InsertObject inserts a new GoObject at the desired position to the doc root object list.
func (doc *GoDocument) InsertObject(pos int, newObj *GoObject) error {
return doc.Objects().Insert(pos, newObj)
}
// Imports returns the GoDocument imports.
func (doc *GoDocument) Imports() (GoImports, error) {
// If the doc imports are empty, try to load them from the root objects.
if len(doc.imports) == 0 {
for _, obj := range doc.Objects().List() {
if obj.Type != GoObjectType_Import {
continue
}
res, err := ParseImportObject(obj)
if err != nil {
return doc.imports, err
}
// Combine all the imports into a single definition.
for _, n := range res {
doc.imports = append(doc.imports, n)
}
}
}
return doc.imports, nil
}
// Lines returns all the code lines.
func (doc *GoDocument) Lines() []string {
l := []string{}
for _, ol := range doc.Objects().Lines() {
l = append(l, ol)
}
return l
}
// String returns a single value for all the code lines.
func (doc *GoDocument) String() string {
return strings.Join(doc.Lines(), "\n")
}
// Print writes all the code lines to standard out.
func (doc *GoDocument) Print() {
for _, l := range doc.Lines() {
fmt.Println(l)
}
}
// Save renders all the code lines for the document, formats the code
// and then saves it to the supplied file path.
func (doc *GoDocument) Save(localpath string) error {
res, err := format.Source([]byte(doc.String()))
if err != nil {
err = errors.WithMessage(err, "Failed formatted source code")
return err
}
err = ioutil.WriteFile(localpath, res, 0644)
if err != nil {
err = errors.WithMessagef(err, "Failed write source code to file %s", localpath)
return err
}
return nil
}
// AddImport checks for any duplicate imports by name and adds it if not.
func (doc *GoDocument) AddImport(impt GoImport) error {
impt.Name = strings.Trim(impt.Name, "\"")
// Get a list of current imports for the document.
impts, err := doc.Imports()
if err != nil {
return err
}
// If the document has as the import, don't add it.
if impts.Has(impt.Name) {
return nil
}
// Loop through all the document root objects for an object of type import.
// If one exists, append the import to the existing list.
for _, obj := range doc.Objects().List() {
if obj.Type != GoObjectType_Import || len(obj.Lines()) == 1 {
continue
}
obj.subLines = append(obj.subLines, impt.String())
obj.goObjects.list = append(obj.goObjects.list, impt.Object())
doc.imports = append(doc.imports, impt)
return nil
}
// Document does not have an existing import object, so need to create one and
// then append to the document.
newObj := NewObjectImports(impt)
// Insert after any package, any existing comments or line breaks should be included.
var insertPos int
for idx, obj := range doc.Objects().List() {
switch obj.Type {
case GoObjectType_Package, GoObjectType_Comment, GoObjectType_LineBreak:
insertPos = idx
default:
break
}
}
// Insert the new import object.
err = doc.InsertObject(insertPos, newObj)
if err != nil {
return err
}
return nil
}
// NewObjectImports returns a new GoObject with a single import definition.
func NewObjectImports(impt GoImport) *GoObject {
lines := []string{
"import (",
impt.String(),
")",
"",
}
obj, _ := ParseGoObject(lines, 0)
children, err := ParseLines(obj.subLines, 1)
if err != nil {
return nil
}
for _, child := range children.List() {
obj.Objects().Add(child)
}
return obj
}
// Has checks to see if an import exists by name or alias.
func (impts GoImports) Has(name string) bool {
for _, impt := range impts {
if name == impt.Name || (impt.Alias != "" && name == impt.Alias) {
return true
}
}
return false
}
// Line formats an import as a string.
func (impt GoImport) String() string {
var imptLine string
if impt.Alias != "" {
imptLine = fmt.Sprintf("\t%s \"%s\"", impt.Alias, impt.Name)
} else {
imptLine = fmt.Sprintf("\t\"%s\"", impt.Name)
}
return imptLine
}
// Object returns the first GoObject for an import.
func (impt GoImport) Object() *GoObject {
imptObj := NewObjectImports(impt)
return imptObj.Objects().List()[0]
}
// ParseImportObject extracts all the import definitions.
func ParseImportObject(obj *GoObject) (resp GoImports, err error) {
if obj.Type != GoObjectType_Import {
return resp, errors.Errorf("Invalid type %s", string(obj.Type))
}
for _, l := range obj.Lines() {
if !strings.Contains(l, "\"") {
continue
}
l = strings.TrimSpace(l)
pts := strings.Split(l, "\"")
var impt GoImport
if strings.HasPrefix(l, "\"") {
impt.Name = pts[1]
} else {
impt.Alias = strings.TrimSpace(pts[0])
impt.Name = pts[1]
}
resp = append(resp, impt)
}
return resp, nil
}

View File

@ -0,0 +1,458 @@
package goparse
import (
"log"
"strings"
"github.com/fatih/structtag"
"github.com/pkg/errors"
)
// GoEmptyLine defined a GoObject for a code line break.
var GoEmptyLine = GoObject{
Type: GoObjectType_LineBreak,
goObjects: &GoObjects{
list: []*GoObject{},
},
}
// GoObjectType defines a set of possible types to group
// parsed code by.
type GoObjectType = string
var (
GoObjectType_Package = "package"
GoObjectType_Import = "import"
GoObjectType_Var = "var"
GoObjectType_Const = "const"
GoObjectType_Func = "func"
GoObjectType_Struct = "struct"
GoObjectType_Comment = "comment"
GoObjectType_LineBreak = "linebreak"
GoObjectType_Line = "line"
GoObjectType_Type = "type"
)
// GoObject defines a section of code with a nested set of children.
type GoObject struct {
Type GoObjectType
Name string
startLines []string
endLines []string
subLines []string
goObjects *GoObjects
Index int
}
// GoObjects stores a list of GoObject.
type GoObjects struct {
list []*GoObject
}
// Objects returns the list of *GoObject.
func (obj *GoObject) Objects() *GoObjects {
if obj.goObjects == nil {
obj.goObjects = &GoObjects{
list: []*GoObject{},
}
}
return obj.goObjects
}
// Clone performs a deep copy of the struct.
func (obj *GoObject) Clone() *GoObject {
n := &GoObject{
Type: obj.Type,
Name: obj.Name,
startLines: obj.startLines,
endLines: obj.endLines,
subLines: obj.subLines,
goObjects: &GoObjects{
list: []*GoObject{},
},
Index: obj.Index,
}
for _, sub := range obj.Objects().List() {
n.Objects().Add(sub.Clone())
}
return n
}
// IsComment returns whether an object is of type GoObjectType_Comment.
func (obj *GoObject) IsComment() bool {
if obj.Type != GoObjectType_Comment {
return false
}
return true
}
// Contains searches all the lines for the object for a matching string.
func (obj *GoObject) Contains(match string) bool {
for _, l := range obj.Lines() {
if strings.Contains(l, match) {
return true
}
}
return false
}
// UpdateLines parses the new code and replaces the current GoObject.
func (obj *GoObject) UpdateLines(newLines []string) error {
// Parse the new lines.
objs, err := ParseLines(newLines, 0)
if err != nil {
return err
}
var newObj *GoObject
for _, obj := range objs.List() {
if obj.Type == GoObjectType_LineBreak {
continue
}
if newObj == nil {
newObj = obj
}
// There should only be one resulting parsed object that is
// not of type GoObjectType_LineBreak.
return errors.New("Can only update single blocks of code")
}
// No new code was parsed, return error.
if newObj == nil {
return errors.New("Failed to render replacement code")
}
return obj.Update(newObj)
}
// Update performs a deep copy that overwrites the existing values.
func (obj *GoObject) Update(newObj *GoObject) error {
obj.Type = newObj.Type
obj.Name = newObj.Name
obj.startLines = newObj.startLines
obj.endLines = newObj.endLines
obj.subLines = newObj.subLines
obj.goObjects = newObj.goObjects
return nil
}
// Lines returns a list of strings for current object and all children.
func (obj *GoObject) Lines() []string {
l := []string{}
// First include any lines before the sub objects.
for _, sl := range obj.startLines {
l = append(l, sl)
}
// If there are parsed sub objects include those lines else when
// no sub objects, just use the sub lines.
if len(obj.Objects().List()) > 0 {
for _, sl := range obj.Objects().Lines() {
l = append(l, sl)
}
} else {
for _, sl := range obj.subLines {
l = append(l, sl)
}
}
// Lastly include any other lines that are after all parsed sub objects.
for _, sl := range obj.endLines {
l = append(l, sl)
}
return l
}
// String returns the lines separated by line break.
func (obj *GoObject) String() string {
return strings.Join(obj.Lines(), "\n")
}
// Lines returns a list of strings for all the list objects.
func (objs *GoObjects) Lines() []string {
l := []string{}
for _, obj := range objs.List() {
for _, oj := range obj.Lines() {
l = append(l, oj)
}
}
return l
}
// String returns all the lines for the list objects.
func (objs *GoObjects) String() string {
lines := []string{}
for _, obj := range objs.List() {
lines = append(lines, obj.String())
}
return strings.Join(lines, "\n")
}
// List returns the list of GoObjects.
func (objs *GoObjects) List() []*GoObject {
return objs.list
}
// HasFunc searches the current list of objects for a function object by name.
func (objs *GoObjects) HasFunc(name string) bool {
return objs.HasType(name, GoObjectType_Func)
}
// Get returns the GoObject for the matching name and type.
func (objs *GoObjects) Get(name string, objType GoObjectType) *GoObject {
for _, obj := range objs.list {
if obj.Name == name && (objType == "" || obj.Type == objType) {
return obj
}
}
return nil
}
// HasType checks is a GoObject exists for the matching name and type.
func (objs *GoObjects) HasType(name string, objType GoObjectType) bool {
for _, obj := range objs.list {
if obj.Name == name && (objType == "" || obj.Type == objType) {
return true
}
}
return false
}
// HasObject checks to see if the exact code block exists.
func (objs *GoObjects) HasObject(src *GoObject) bool {
if src == nil {
return false
}
// Generate the code for the supplied object.
srcLines := []string{}
for _, l := range src.Lines() {
// Exclude empty lines.
l = strings.TrimSpace(l)
if l != "" {
srcLines = append(srcLines, l)
}
}
srcStr := strings.Join(srcLines, "\n")
// Loop over all the objects and match with src code.
for _, obj := range objs.list {
objLines := []string{}
for _, l := range obj.Lines() {
// Exclude empty lines.
l = strings.TrimSpace(l)
if l != "" {
objLines = append(objLines, l)
}
}
objStr := strings.Join(objLines, "\n")
// Return true if the current object code matches src code.
if srcStr == objStr {
return true
}
}
return false
}
// Add appends a new GoObject to the list.
func (objs *GoObjects) Add(newObj *GoObject) error {
newObj.Index = len(objs.list)
objs.list = append(objs.list, newObj)
return nil
}
// Insert appends a new GoObject at the desired position to the list.
func (objs *GoObjects) Insert(pos int, newObj *GoObject) error {
newList := []*GoObject{}
var newIdx int
for _, obj := range objs.list {
if obj.Index < pos {
obj.Index = newIdx
newList = append(newList, obj)
} else {
if obj.Index == pos {
newObj.Index = newIdx
newList = append(newList, newObj)
newIdx++
}
obj.Index = newIdx
newList = append(newList, obj)
}
newIdx++
}
objs.list = newList
return nil
}
// Remove deletes a GoObject from the list.
func (objs *GoObjects) Remove(delObjs ...*GoObject) error {
for _, delObj := range delObjs {
oldList := objs.List()
objs.list = []*GoObject{}
var newIdx int
for _, obj := range oldList {
if obj.Index == delObj.Index {
continue
}
obj.Index = newIdx
objs.list = append(objs.list, obj)
newIdx++
}
}
return nil
}
// Replace updates an existing GoObject while maintaining is same position.
func (objs *GoObjects) Replace(oldObj *GoObject, newObjs ...*GoObject) error {
if oldObj.Index >= len(objs.list) {
return errors.WithStack(errGoObjectNotExist)
} else if len(newObjs) == 0 {
return nil
}
oldList := objs.List()
objs.list = []*GoObject{}
var newIdx int
for _, obj := range oldList {
if obj.Index < oldObj.Index {
obj.Index = newIdx
objs.list = append(objs.list, obj)
newIdx++
} else if obj.Index == oldObj.Index {
for _, newObj := range newObjs {
newObj.Index = newIdx
objs.list = append(objs.list, newObj)
newIdx++
}
} else {
obj.Index = newIdx
objs.list = append(objs.list, obj)
newIdx++
}
}
return nil
}
// ReplaceFuncByName finds an existing GoObject with type GoObjectType_Func by name
// and then performs a replace with the supplied new GoObject.
func (objs *GoObjects) ReplaceFuncByName(name string, fn *GoObject) error {
return objs.ReplaceTypeByName(name, fn, GoObjectType_Func)
}
// ReplaceTypeByName finds an existing GoObject with type by name
// and then performs a replace with the supplied new GoObject.
func (objs *GoObjects) ReplaceTypeByName(name string, newObj *GoObject, objType GoObjectType) error {
if newObj.Name == "" {
newObj.Name = name
}
if newObj.Type == "" && objType != "" {
newObj.Type = objType
}
for _, obj := range objs.list {
if obj.Name == name && (objType == "" || objType == obj.Type) {
return objs.Replace(obj, newObj)
}
}
return errors.WithStack(errGoObjectNotExist)
}
// Empty determines if all the GoObject in the list are line breaks.
func (objs *GoObjects) Empty() bool {
var hasStuff bool
for _, obj := range objs.List() {
switch obj.Type {
case GoObjectType_LineBreak:
//case GoObjectType_Comment:
//case GoObjectType_Import:
// do nothing
default:
hasStuff = true
}
}
return hasStuff
}
// Debug prints out the GoObject to logger.
func (obj *GoObject) Debug(log *log.Logger) {
log.Println(obj.Name)
log.Println(" > type:", obj.Type)
log.Println(" > start lines:")
for _, l := range obj.startLines {
log.Println(" ", l)
}
log.Println(" > sub lines:")
for _, l := range obj.subLines {
log.Println(" ", l)
}
log.Println(" > end lines:")
for _, l := range obj.endLines {
log.Println(" ", l)
}
}
// Defines a property of a struct.
type structProp struct {
Name string
Type string
Tags *structtag.Tags
}
// ParseStructProp extracts the details for a struct property.
func ParseStructProp(obj *GoObject) (structProp, error) {
if obj.Type != GoObjectType_Line {
return structProp{}, errors.Errorf("Unable to parse object of type %s", obj.Type)
}
// Remove any white space from the code line.
ls := strings.TrimSpace(strings.Join(obj.Lines(), " "))
// Extract the property name and type for the line.
// ie: ID string `json:"id"`
var resp structProp
for _, p := range strings.Split(ls, " ") {
p = strings.TrimSpace(p)
if p == "" {
continue
}
if resp.Name == "" {
resp.Name = p
} else if resp.Type == "" {
resp.Type = p
} else {
break
}
}
// If the line contains tags, extract and parse them.
if strings.Contains(ls, "`") {
tagStr := strings.Split(ls, "`")[1]
var err error
resp.Tags, err = structtag.Parse(tagStr)
if err != nil {
err = errors.WithMessagef(err, "Failed to parse struct tag for field %s: %s", resp.Name, tagStr)
return structProp{}, err
}
}
return resp, nil
}

View File

@ -0,0 +1,329 @@
package goparse
import (
"bufio"
"bytes"
"fmt"
"go/format"
"io/ioutil"
"log"
"strings"
"unicode"
"github.com/pkg/errors"
)
var (
errGoParseType = errors.New("Unable to determine type for line")
errGoTypeMissingCodeTemplate = errors.New("No code defined for type")
errGoObjectNotExist = errors.New("GoObject does not exist")
)
// ParseFile reads a go code file and parses into a easily transformable set of objects.
func ParseFile(log *log.Logger, localPath string) (*GoDocument, error) {
// Read the code file.
src, err := ioutil.ReadFile(localPath)
if err != nil {
err = errors.WithMessagef(err, "Failed to read file %s", localPath)
return nil, err
}
// Format the code file source to ensure parse works.
dat, err := format.Source(src)
if err != nil {
err = errors.WithMessagef(err, "Failed to format source for file %s", localPath)
log.Printf("ParseFile : %v\n%v", err, string(src))
return nil, err
}
// Loop of the formatted source code and generate a list of code lines.
lines := []string{}
r := bytes.NewReader(dat)
scanner := bufio.NewScanner(r)
for scanner.Scan() {
lines = append(lines, scanner.Text())
}
if err := scanner.Err(); err != nil {
err = errors.WithMessagef(err, "Failed read formatted source code for file %s", localPath)
return nil, err
}
// Parse the code lines into a set of objects.
objs, err := ParseLines(lines, 0)
if err != nil {
log.Println(err)
return nil, err
}
// Append the resulting objects to the document.
doc := &GoDocument{}
for _, obj := range objs.List() {
if obj.Type == GoObjectType_Package {
doc.Package = obj.Name
}
doc.AddObject(obj)
}
return doc, nil
}
// ParseLines takes the list of formatted code lines and returns the GoObjects.
func ParseLines(lines []string, depth int) (objs *GoObjects, err error) {
objs = &GoObjects{
list: []*GoObject{},
}
var (
multiLine bool
multiComment bool
muiliVar bool
)
curDepth := -1
objLines := []string{}
for idx, l := range lines {
ls := strings.TrimSpace(l)
ld := lineDepth(l)
if ld == depth {
if strings.HasPrefix(ls, "/*") {
multiLine = true
multiComment = true
} else if strings.HasSuffix(ls, "(") ||
strings.HasSuffix(ls, "{") {
if !multiLine {
multiLine = true
}
} else if strings.Contains(ls, "`") {
if !multiLine && strings.Count(ls, "`")%2 != 0 {
if muiliVar {
muiliVar = false
} else {
muiliVar = true
}
}
}
objLines = append(objLines, l)
if multiComment {
if strings.HasSuffix(ls, "*/") {
multiComment = false
multiLine = false
}
} else {
if strings.HasPrefix(ls, ")") ||
strings.HasPrefix(ls, "}") {
multiLine = false
}
}
if !multiLine && !muiliVar {
for eidx := idx + 1; eidx < len(lines); eidx++ {
if el := lines[eidx]; strings.TrimSpace(el) == "" {
objLines = append(objLines, el)
} else {
break
}
}
obj, err := ParseGoObject(objLines, depth)
if err != nil {
log.Println(err)
return objs, err
}
err = objs.Add(obj)
if err != nil {
log.Println(err)
return objs, err
}
objLines = []string{}
}
} else if (multiLine && ld >= curDepth && ld >= depth && len(objLines) > 0) || muiliVar {
objLines = append(objLines, l)
if strings.Contains(ls, "`") {
if !multiLine && strings.Count(ls, "`")%2 != 0 {
if muiliVar {
muiliVar = false
} else {
muiliVar = true
}
}
}
}
}
for _, obj := range objs.List() {
children, err := ParseLines(obj.subLines, depth+1)
if err != nil {
log.Println(err)
return objs, err
}
for _, child := range children.List() {
obj.Objects().Add(child)
}
}
return objs, nil
}
// ParseGoObject generates a GoObjected for the given code lines.
func ParseGoObject(lines []string, depth int) (obj *GoObject, err error) {
// If there are no lines, return a line break.
if len(lines) == 0 {
return &GoEmptyLine, nil
}
firstLine := lines[0]
firstStrip := strings.TrimSpace(firstLine)
if len(firstStrip) == 0 {
return &GoEmptyLine, nil
}
obj = &GoObject{
goObjects: &GoObjects{
list: []*GoObject{},
},
}
if strings.HasPrefix(firstStrip, "var") {
obj.Type = GoObjectType_Var
} else if strings.HasPrefix(firstStrip, "const") {
obj.Type = GoObjectType_Const
} else if strings.HasPrefix(firstStrip, "func") {
obj.Type = GoObjectType_Func
if strings.HasPrefix(firstStrip, "func (") {
funcLine := strings.TrimLeft(strings.TrimSpace(strings.TrimLeft(firstStrip, "func ")), "(")
var structName string
pts := strings.Split(strings.Split(funcLine, ")")[0], " ")
for i := len(pts) - 1; i >= 0; i-- {
ptVal := strings.TrimSpace(pts[i])
if ptVal != "" {
structName = ptVal
break
}
}
var funcName string
pts = strings.Split(strings.Split(funcLine, "(")[0], " ")
for i := len(pts) - 1; i >= 0; i-- {
ptVal := strings.TrimSpace(pts[i])
if ptVal != "" {
funcName = ptVal
break
}
}
obj.Name = fmt.Sprintf("%s.%s", structName, funcName)
} else {
obj.Name = strings.TrimLeft(firstStrip, "func ")
obj.Name = strings.Split(obj.Name, "(")[0]
}
} else if strings.HasSuffix(firstStrip, "struct {") || strings.HasSuffix(firstStrip, "struct{") {
obj.Type = GoObjectType_Struct
if strings.HasPrefix(firstStrip, "type ") {
firstStrip = strings.TrimSpace(strings.Replace(firstStrip, "type ", "", 1))
}
obj.Name = strings.Split(firstStrip, " ")[0]
} else if strings.HasPrefix(firstStrip, "type") {
obj.Type = GoObjectType_Type
firstStrip = strings.TrimSpace(strings.Replace(firstStrip, "type ", "", 1))
obj.Name = strings.Split(firstStrip, " ")[0]
} else if strings.HasPrefix(firstStrip, "package") {
obj.Name = strings.TrimSpace(strings.TrimLeft(firstStrip, "package "))
obj.Type = GoObjectType_Package
} else if strings.HasPrefix(firstStrip, "import") {
obj.Type = GoObjectType_Import
} else if strings.HasPrefix(firstStrip, "//") {
obj.Type = GoObjectType_Comment
} else if strings.HasPrefix(firstStrip, "/*") {
obj.Type = GoObjectType_Comment
} else {
if depth > 0 {
obj.Type = GoObjectType_Line
} else {
err = errors.WithStack(errGoParseType)
return obj, err
}
}
var (
hasSub bool
muiliVarStart bool
muiliVarSub bool
muiliVarEnd bool
)
for _, l := range lines {
ld := lineDepth(l)
if (ld == depth && !muiliVarSub) || muiliVarStart || muiliVarEnd {
if hasSub && !muiliVarStart {
if strings.TrimSpace(l) != "" {
obj.endLines = append(obj.endLines, l)
}
if strings.Count(l, "`")%2 != 0 {
if muiliVarEnd {
muiliVarEnd = false
} else {
muiliVarEnd = true
}
}
} else {
obj.startLines = append(obj.startLines, l)
if strings.Count(l, "`")%2 != 0 {
if muiliVarStart {
muiliVarStart = false
} else {
muiliVarStart = true
}
}
}
} else if ld > depth || muiliVarSub {
obj.subLines = append(obj.subLines, l)
hasSub = true
if strings.Count(l, "`")%2 != 0 {
if muiliVarSub {
muiliVarSub = false
} else {
muiliVarSub = true
}
}
}
}
// add trailing linebreak
if len(obj.endLines) > 0 {
obj.endLines = append(obj.endLines, "")
}
return obj, err
}
// lineDepth returns the number of spaces for the given code line
// used to determine the code level for nesting objects.
func lineDepth(l string) int {
depth := len(l) - len(strings.TrimLeftFunc(l, unicode.IsSpace))
ls := strings.TrimSpace(l)
if strings.HasPrefix(ls, "}") && strings.Contains(ls, " else ") {
depth = depth + 1
} else if strings.HasPrefix(ls, "case ") {
depth = depth + 1
}
return depth
}

View File

@ -0,0 +1,195 @@
package goparse
import (
"log"
"os"
"strings"
"testing"
"github.com/onsi/gomega"
)
var logger *log.Logger
func init() {
logger = log.New(os.Stdout, "", log.LstdFlags|log.Lmicroseconds|log.Lshortfile)
}
func TestParseFileModel1(t *testing.T) {
_, err := ParseFile(logger, "test_gofile_model1.txt")
if err != nil {
t.Fatalf("got error %v", err)
}
}
func TestMultilineVar(t *testing.T) {
g := gomega.NewGomegaWithT(t)
code := `func ContextAllowedAccountIds(ctx context.Context, db *gorm.DB) (resp akdatamodels.Uint32List, err error) {
resp = []uint32{}
accountId := akcontext.ContextAccountId(ctx)
m := datamodels.UserAccount{}
q := fmt.Sprintf("select
distinct account_id
from %s where account_id = ?", m.TableName())
db = db.Raw(q, accountId)
}
`
code = strings.Replace(code, "\"", "`", -1)
lines := strings.Split(code, "\n")
objs, err := ParseLines(lines, 0)
if err != nil {
t.Fatalf("got error %v", err)
}
g.Expect(objs.Lines()).Should(gomega.Equal(lines))
}
func TestNewDocImports(t *testing.T) {
g := gomega.NewGomegaWithT(t)
expected := []string{
"package goparse",
"",
"import (",
" \"github.com/go/pkg1\"",
" \"github.com/go/pkg2\"",
")",
"",
}
doc := &GoDocument{}
doc.SetPackage("goparse")
doc.AddImport(GoImport{Name: "github.com/go/pkg1"})
doc.AddImport(GoImport{Name: "github.com/go/pkg2"})
g.Expect(doc.Lines()).Should(gomega.Equal(expected))
}
func TestParseLines1(t *testing.T) {
g := gomega.NewGomegaWithT(t)
code := `func testCreate(t *testing.T, ctx context.Context, sess *datamodels.Session) *datamodels.Model {
g := gomega.NewGomegaWithT(t)
obj := datamodels.MockModelNew()
resp, err := ModelCreate(ctx, DB, &obj)
if err != nil {
t.Fatalf("got error %v", err)
}
g.Expect(resp.Name).Should(gomega.Equal(obj.Name))
g.Expect(resp.Status).Should(gomega.Equal(datamodels.{{ .Name }}Status_Active))
return resp
}
`
lines := strings.Split(code, "\n")
objs, err := ParseLines(lines, 0)
if err != nil {
t.Fatalf("got error %v", err)
}
g.Expect(objs.Lines()).Should(gomega.Equal(lines))
}
func TestParseLines2(t *testing.T) {
code := `func structToMap(s interface{}) (resp map[string]interface{}) {
dat, _ := json.Marshal(s)
_ = json.Unmarshal(dat, &resp)
for k, x := range resp {
switch v := x.(type) {
case time.Time:
if v.IsZero() {
delete(resp, k)
}
case *time.Time:
if v == nil || v.IsZero() {
delete(resp, k)
}
case nil:
delete(resp, k)
}
}
return resp
}
`
lines := strings.Split(code, "\n")
objs, err := ParseLines(lines, 0)
if err != nil {
t.Fatalf("got error %v", err)
}
testLineTextMatches(t, objs.Lines(), lines)
}
func TestParseLines3(t *testing.T) {
g := gomega.NewGomegaWithT(t)
code := `type UserAccountRoleName = string
const (
UserAccountRoleName_None UserAccountRoleName = ""
UserAccountRoleName_Admin UserAccountRoleName = "admin"
UserAccountRoleName_User UserAccountRoleName = "user"
)
type UserAccountRole struct {
Id uint32 ^gorm:"column:id;type:int(10) unsigned AUTO_INCREMENT;primary_key;not null;auto_increment;" truss:"internal:true"^
CreatedAt time.Time ^gorm:"column:created_at;type:datetime;default:CURRENT_TIMESTAMP;not null;" truss:"internal:true"^
UpdatedAt time.Time ^gorm:"column:updated_at;type:datetime;" truss:"internal:true"^
DeletedAt *time.Time ^gorm:"column:deleted_at;type:datetime;" truss:"internal:true"^
Role UserAccountRoleName ^gorm:"unique_index:user_account_role;column:role;type:enum('admin', 'user')"^
// belongs to User
User *User ^gorm:"foreignkey:UserId;association_foreignkey:Id;association_autoupdate:false;association_autocreate:false;association_save_reference:false;preload:false;" truss:"internal:true"^
UserId uint32 ^gorm:"unique_index:user_account_role;"^
// belongs to Account
Account *Account ^gorm:"foreignkey:AccountId;association_foreignkey:Id;association_autoupdate:false;association_autocreate:false;association_save_reference:false;preload:false;" truss:"internal:true;api_ro:true;"^
AccountId uint32 ^gorm:"unique_index:user_account_role;" truss:"internal:true;api_ro:true;"^
}
func (UserAccountRole) TableName() string {
return "user_account_roles"
}
`
code = strings.Replace(code, "^", "'", -1)
lines := strings.Split(code, "\n")
objs, err := ParseLines(lines, 0)
if err != nil {
t.Fatalf("got error %v", err)
}
g.Expect(objs.Lines()).Should(gomega.Equal(lines))
}
func testLineTextMatches(t *testing.T, l1, l2 []string) {
g := gomega.NewGomegaWithT(t)
m1 := []string{}
for _, l := range l1 {
l = strings.TrimSpace(l)
if l != "" {
m1 = append(m1, l)
}
}
m2 := []string{}
for _, l := range l2 {
l = strings.TrimSpace(l)
if l != "" {
m2 = append(m2, l)
}
}
g.Expect(m1).Should(gomega.Equal(m2))
}

View File

@ -0,0 +1,126 @@
package account
import (
"database/sql"
"database/sql/driver"
"time"
"github.com/lib/pq"
"github.com/pkg/errors"
"gopkg.in/go-playground/validator.v9"
)
// Account represents someone with access to our system.
type Account struct {
ID string `json:"id"`
Name string `json:"name"`
Address1 string `json:"address1"`
Address2 string `json:"address2"`
City string `json:"city"`
Region string `json:"region"`
Country string `json:"country"`
Zipcode string `json:"zipcode"`
Status AccountStatus `json:"status"`
Timezone string `json:"timezone"`
SignupUserID sql.NullString `json:"signup_user_id"`
BillingUserID sql.NullString `json:"billing_user_id"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
ArchivedAt pq.NullTime `json:"archived_at"`
}
// CreateAccountRequest contains information needed to create a new Account.
type CreateAccountRequest struct {
Name string `json:"name" validate:"required,unique"`
Address1 string `json:"address1" validate:"required"`
Address2 string `json:"address2" validate:"omitempty"`
City string `json:"city" validate:"required"`
Region string `json:"region" validate:"required"`
Country string `json:"country" validate:"required"`
Zipcode string `json:"zipcode" validate:"required"`
Status *AccountStatus `json:"status" validate:"omitempty,oneof=active pending disabled"`
Timezone *string `json:"timezone" validate:"omitempty"`
SignupUserID *string `json:"signup_user_id" validate:"omitempty,uuid"`
BillingUserID *string `json:"billing_user_id" validate:"omitempty,uuid"`
}
// UpdateAccountRequest defines what information may be provided to modify an existing
// Account. All fields are optional so clients can send just the fields they want
// changed. It uses pointer fields so we can differentiate between a field that
// was not provided and a field that was provided as explicitly blank. Normally
// we do not want to use pointers to basic types but we make exceptions around
// marshalling/unmarshalling.
type UpdateAccountRequest struct {
ID string `validate:"required,uuid"`
Name *string `json:"name" validate:"omitempty,unique"`
Address1 *string `json:"address1" validate:"omitempty"`
Address2 *string `json:"address2" validate:"omitempty"`
City *string `json:"city" validate:"omitempty"`
Region *string `json:"region" validate:"omitempty"`
Country *string `json:"country" validate:"omitempty"`
Zipcode *string `json:"zipcode" validate:"omitempty"`
Status *AccountStatus `json:"status" validate:"omitempty,oneof=active pending disabled"`
Timezone *string `json:"timezone" validate:"omitempty"`
SignupUserID *string `json:"signup_user_id" validate:"omitempty,uuid"`
BillingUserID *string `json:"billing_user_id" validate:"omitempty,uuid"`
}
// AccountFindRequest defines the possible options to search for accounts. By default
// archived accounts will be excluded from response.
type AccountFindRequest struct {
Where *string
Args []interface{}
Order []string
Limit *uint
Offset *uint
IncludedArchived bool
}
// AccountStatus represents the status of an account.
type AccountStatus string
// AccountStatus values define the status field of a user account.
const (
// AccountStatus_Active defines the state when a user can access an account.
AccountStatus_Active AccountStatus = "active"
// AccountStatus_Pending defined the state when an account was created but
// not activated.
AccountStatus_Pending AccountStatus = "pending"
// AccountStatus_Disabled defines the state when a user has been disabled from
// accessing an account.
AccountStatus_Disabled AccountStatus = "disabled"
)
// AccountStatus_Values provides list of valid AccountStatus values.
var AccountStatus_Values = []AccountStatus{
AccountStatus_Active,
AccountStatus_Pending,
AccountStatus_Disabled,
}
// Scan supports reading the AccountStatus value from the database.
func (s *AccountStatus) Scan(value interface{}) error {
asBytes, ok := value.([]byte)
if !ok {
return errors.New("Scan source is not []byte")
}
*s = AccountStatus(string(asBytes))
return nil
}
// Value converts the AccountStatus value to be stored in the database.
func (s AccountStatus) Value() (driver.Value, error) {
v := validator.New()
errs := v.Var(s, "required,oneof=active invited disabled")
if errs != nil {
return nil, errs
}
return string(s), nil
}
// String converts the AccountStatus value to a string.
func (s AccountStatus) String() string {
return string(s)
}

View File

@ -0,0 +1,227 @@
package main
import (
"encoding/json"
"expvar"
"io/ioutil"
"log"
"net/url"
"os"
"path"
"path/filepath"
"strings"
"geeks-accelerator/oss/saas-starter-kit/example-project/tools/truss/cmd/dbtable2crud"
"github.com/kelseyhightower/envconfig"
"github.com/lib/pq"
_ "github.com/lib/pq"
"github.com/pkg/errors"
"github.com/urfave/cli"
sqltrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/database/sql"
sqlxtrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/jmoiron/sqlx"
)
// build is the git version of this program. It is set using build flags in the makefile.
var build = "develop"
// service is the name of the program used for logging, tracing and the
// the prefix used for loading env variables
// ie: export TRUSS_ENV=dev
var service = "TRUSS"
func main() {
// =========================================================================
// Logging
log := log.New(os.Stdout, service+" : ", log.LstdFlags|log.Lmicroseconds|log.Lshortfile)
// =========================================================================
// Configuration
var cfg struct {
DB struct {
Host string `default:"127.0.0.1:5433" envconfig:"HOST"`
User string `default:"postgres" envconfig:"USER"`
Pass string `default:"postgres" envconfig:"PASS" json:"-"` // don't print
Database string `default:"shared" envconfig:"DATABASE"`
Driver string `default:"postgres" envconfig:"DRIVER"`
Timezone string `default:"utc" envconfig:"TIMEZONE"`
DisableTLS bool `default:"false" envconfig:"DISABLE_TLS"`
}
}
// For additional details refer to https://github.com/kelseyhightower/envconfig
if err := envconfig.Process(service, &cfg); err != nil {
log.Fatalf("main : Parsing Config : %v", err)
}
// TODO: can't use flag.Process here since it doesn't support nested arg options
//if err := flag.Process(&cfg); err != nil {
/// if err != flag.ErrHelp {
// log.Fatalf("main : Parsing Command Line : %v", err)
// }
// return // We displayed help.
//}
// =========================================================================
// Log App Info
// Print the build version for our logs. Also expose it under /debug/vars.
expvar.NewString("build").Set(build)
log.Printf("main : Started : Application Initializing version %q", build)
defer log.Println("main : Completed")
// Print the config for our logs. It's important to any credentials in the config
// that could expose a security risk are excluded from being json encoded by
// applying the tag `json:"-"` to the struct var.
{
cfgJSON, err := json.MarshalIndent(cfg, "", " ")
if err != nil {
log.Fatalf("main : Marshalling Config to JSON : %v", err)
}
log.Printf("main : Config : %v\n", string(cfgJSON))
}
// =========================================================================
// Start Database
var dbUrl url.URL
{
// Query parameters.
var q url.Values = make(map[string][]string)
// Handle SSL Mode
if cfg.DB.DisableTLS {
q.Set("sslmode", "disable")
} else {
q.Set("sslmode", "require")
}
q.Set("timezone", cfg.DB.Timezone)
// Construct url.
dbUrl = url.URL{
Scheme: cfg.DB.Driver,
User: url.UserPassword(cfg.DB.User, cfg.DB.Pass),
Host: cfg.DB.Host,
Path: cfg.DB.Database,
RawQuery: q.Encode(),
}
}
// Register informs the sqlxtrace package of the driver that we will be using in our program.
// It uses a default service name, in the below case "postgres.db". To use a custom service
// name use RegisterWithServiceName.
sqltrace.Register(cfg.DB.Driver, &pq.Driver{}, sqltrace.WithServiceName(service))
masterDb, err := sqlxtrace.Open(cfg.DB.Driver, dbUrl.String())
if err != nil {
log.Fatalf("main : Register DB : %s : %v", cfg.DB.Driver, err)
}
defer masterDb.Close()
// =========================================================================
// Start Truss
app := cli.NewApp()
app.Commands = []cli.Command{
{
Name: "dbtable2crud",
Aliases: []string{"dbtable2crud"},
Usage: "dbtable2crud -table=projects -file=../../internal/project/models.go -model=Project",
Flags: []cli.Flag{
cli.StringFlag{Name: "dbtable, table"},
cli.StringFlag{Name: "modelFile, modelfile, file"},
cli.StringFlag{Name: "modelName, modelname, model"},
cli.StringFlag{Name: "templateDir, templates", Value: "./templates/dbtable2crud"},
cli.StringFlag{Name: "projectPath", Value: ""},
},
Action: func(c *cli.Context) error {
dbTable := strings.TrimSpace(c.String("dbtable"))
modelFile := strings.TrimSpace(c.String("modelFile"))
modelName := strings.TrimSpace(c.String("modelName"))
templateDir := strings.TrimSpace(c.String("templateDir"))
projectPath := strings.TrimSpace(c.String("projectPath"))
pwd, err := os.Getwd()
if err != nil {
return errors.WithMessage(err, "Failed to get current working directory")
}
if !path.IsAbs(templateDir) {
templateDir = filepath.Join(pwd, templateDir)
}
ok, err := exists(templateDir)
if err != nil {
return errors.WithMessage(err, "Failed to load template directory")
} else if !ok {
return errors.Errorf("Template directory %s does not exist", templateDir)
}
if modelFile == "" {
return errors.Errorf("Model file path is required")
}
if !path.IsAbs(modelFile) {
modelFile = filepath.Join(pwd, modelFile)
}
ok, err = exists(modelFile)
if err != nil {
return errors.WithMessage(err, "Failed to load model file")
} else if !ok {
return errors.Errorf("Model file %s does not exist", modelFile)
}
// Load the project path from go.mod if not set.
if projectPath == "" {
goModFile := filepath.Join(pwd, "../../go.mod")
ok, err = exists(goModFile)
if err != nil {
return errors.WithMessage(err, "Failed to load go.mod for project")
} else if !ok {
return errors.Errorf("Failed to locate project go.mod at %s", goModFile)
}
b, err := ioutil.ReadFile(goModFile)
if err != nil {
return errors.WithMessagef(err, "Failed to read go.mod at %s", goModFile)
}
lines := strings.Split(string(b), "\n")
for _, l := range lines {
if strings.HasPrefix(l, "module ") {
projectPath = strings.TrimSpace(strings.Split(l, " ")[1])
break
}
}
}
if modelName == "" {
modelName = strings.Split(filepath.Base(modelFile), ".")[0]
modelName = strings.Replace(modelName, "_", " ", -1)
modelName = strings.Replace(modelName, "-", " ", -1)
modelName = strings.Title(modelName)
modelName = strings.Replace(modelName, " ", "", -1)
}
return dbtable2crud.Run(masterDb, log, cfg.DB.Database, dbTable, modelFile, modelName, templateDir, projectPath)
},
},
}
err = app.Run(os.Args)
if err != nil {
log.Fatalf("main : Truss : %+v", err)
}
log.Printf("main : Truss : Completed")
}
// exists returns a bool as to whether a file path exists.
func exists(path string) (bool, error) {
_, err := os.Stat(path)
if err == nil {
return true, nil
}
if os.IsNotExist(err) {
return false, nil
}
return true, err
}

View File

@ -0,0 +1,4 @@
export TRUSS_DB_HOST=127.0.0.1:5433
export TRUSS_DB_USER=postgres
export TRUSS_DB_PASS=postgres
export TRUSS_DB_DISABLE_TLS=true

View File

@ -0,0 +1,503 @@
{{ define "imports"}}
import (
"context"
"database/sql"
"time"
"{{ $.GoSrcPath }}/internal/platform/auth"
"github.com/huandu/go-sqlbuilder"
"github.com/jmoiron/sqlx"
"github.com/pborman/uuid"
"github.com/pkg/errors"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
"gopkg.in/go-playground/validator.v9"
)
{{ end }}
{{ define "Globals"}}
const (
// The database table for {{ $.Model.Name }}
{{ FormatCamelLower $.Model.Name }}TableName = "{{ $.Model.TableName }}"
)
var (
// ErrNotFound abstracts the postgres not found error.
ErrNotFound = errors.New("Entity not found")
// ErrInvalidID occurs when an ID is not in a valid form.
ErrInvalidID = errors.New("ID is not in its proper form")
// ErrForbidden occurs when a user tries to do something that is forbidden to them according to our access control policies.
ErrForbidden = errors.New("Attempted action is not allowed")
)
{{ end }}
{{ define "Helpers"}}
// {{ FormatCamelLower $.Model.Name }}MapColumns is the list of columns needed for mapRowsTo{{ $.Model.Name }}
var {{ FormatCamelLower $.Model.Name }}MapColumns = "{{ JoinStrings $.Model.ColumnNames "," }}"
// mapRowsTo{{ $.Model.Name }} takes the SQL rows and maps it to the {{ $.Model.Name }} struct
// with the columns defined by {{ FormatCamelLower $.Model.Name }}MapColumns
func mapRowsTo{{ $.Model.Name }}(rows *sql.Rows) (*{{ $.Model.Name }}, error) {
var (
m {{ $.Model.Name }}
err error
)
err = rows.Scan({{ PrefixAndJoinStrings $.Model.FieldNames "&m." "," }})
if err != nil {
return nil, errors.WithStack(err)
}
return &a, nil
}
{{ end }}
{{ define "ACL"}}
{{ $hasAccountId := (StringListHasValue $.Model.ColumnNames "account_id") }}
// CanRead{{ $.Model.Name }} determines if claims has the authority to access the specified {{ FormatCamelLowerTitle $.Model.Name}} by {{ $.Model.PrimaryColumn }}.
func CanRead{{ $.Model.Name }}(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, {{ FormatCamelLower $.Model.PrimaryField }} string) error {
{{ if $hasAccountId }}
// If the request has claims from a specific {{ FormatCamelLower $.Model.Name }}, ensure that the claims
// has the correct access to the {{ FormatCamelLower $.Model.Name }}.
if claims.Audience != "" {
// select {{ $.Model.PrimaryColumn }} from {{ $.Model.TableName }} where account_id = [accountID]
query := sqlbuilder.NewSelectBuilder().Select("{{ $.Model.PrimaryColumn }}").From({{ FormatCamelLower $.Model.Name }}TableName)
query.Where(query.And(
query.Equal("account_id", claims.Audience),
query.Equal("{{ $.Model.PrimaryField }}", {{ FormatCamelLower $.Model.PrimaryField }}),
))
queryStr, args := query.Build()
queryStr = dbConn.Rebind(queryStr)
var {{ FormatCamelLower $.Model.PrimaryField }} string
err := dbConn.QueryRowContext(ctx, queryStr, args...).Scan(&{{ FormatCamelLower $.Model.PrimaryField }})
if err != nil && err != sql.ErrNoRows {
err = errors.Wrapf(err, "query - %s", query.String())
return err
}
// When there is no {{ $.Model.PrimaryColumn }} returned, then the current claim user does not have access
// to the specified {{ FormatCamelLowerTitle $.Model.Name }}.
if {{ FormatCamelLower $.Model.PrimaryField }} == "" {
return errors.WithStack(ErrForbidden)
}
}
{{ else }}
// TODO: Unable to auto generate sql statement, update accordingly.
panic("Not implemented!")
{{ end }}
return nil
}
// CanModify{{ $.Model.Name }} determines if claims has the authority to modify the specified {{ FormatCamelLowerTitle $.Model.Name}} by {{ $.Model.PrimaryColumn }}.
func CanModify{{ $.Model.Name }}(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, {{ FormatCamelLower $.Model.PrimaryField }} string) error {
err = CanRead{{ $.Model.Name }}(ctx, claims, dbConn, {{ FormatCamelLower $.Model.PrimaryField }})
if err != nil {
return err
}
// Admin users can update {{ FormatCamelPluralTitleLower $.Model.Name }} they have access to.
if !claims.HasRole(auth.RoleAdmin) {
return errors.WithStack(ErrForbidden)
}
return nil
}
// applyClaimsSelect applies a sub-query to the provided query to enforce ACL based on the claims provided.
// 1. No claims, request is internal, no ACL applied
{{ if $hasAccountId }}
// 2. All role types can access their user ID
{{ end }}
func applyClaimsSelect(ctx context.Context, claims auth.Claims, query *sqlbuilder.SelectBuilder) error {
// Claims are empty, don't apply any ACL
if claims.Audience == "" {
return nil
}
{{ if $hasAccountId }}
query.Where(query.Equal("account_id", claims.Audience))
{{ end }}
return nil
}
{{ end }}
{{ define "Find"}}
{{ $hasArchived := (StringListHasValue $.Model.ColumnNames "archived_at") }}
// selectQuery constructs a base select query for {{ $.Model.Name }}
func selectQuery() *sqlbuilder.SelectBuilder {
query := sqlbuilder.NewSelectBuilder()
query.Select({{ FormatCamelLower $.Model.Name }}MapColumns)
query.From({{ FormatCamelLower $.Model.Name }}TableName)
return query
}
// findRequestQuery generates the select query for the given find request.
// TODO: Need to figure out why can't parse the args when appending the where
// to the query.
func findRequestQuery(req {{ $.Model.Name }}FindRequest) (*sqlbuilder.SelectBuilder, []interface{}) {
query := selectQuery()
if req.Where != nil {
query.Where(query.And(*req.Where))
}
if len(req.Order) > 0 {
query.OrderBy(req.Order...)
}
if req.Limit != nil {
query.Limit(int(*req.Limit))
}
if req.Offset != nil {
query.Offset(int(*req.Offset))
}
return query, req.Args
}
// Find gets all the {{ FormatCamelPluralTitleLower $.Model.Name }} from the database based on the request params.
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req {{ $.Model.Name }}FindRequest) ([]*{{ $.Model.Name }}, error) {
query, args := findRequestQuery(req)
return find(ctx, claims, dbConn, query, args{{ if $hasArchived }}, req.IncludedArchived {{ end }})
}
// find internal method for getting all the {{ FormatCamelPluralTitleLower $.Model.Name }} from the database using a select query.
func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}{{ if $hasArchived }}, includedArchived bool{{ end }}) ([]*{{ $.Model.Name }}, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.{{ FormatCamelLowerUnderscore $.Model.Name }}.Find")
defer span.Finish()
query.Select({{ FormatCamelLower $.Model.Name }}MapColumns)
query.From({{ FormatCamelLower $.Model.Name }}TableName)
{{ if $hasArchived }}
if !includedArchived {
query.Where(query.IsNull("archived_at"))
}
{{ end }}
// Check to see if a sub query needs to be applied for the claims.
err := applyClaimsSelect(ctx, claims, query)
if err != nil {
return nil, err
}
queryStr, queryArgs := query.Build()
queryStr = dbConn.Rebind(queryStr)
args = append(args, queryArgs...)
// Fetch all entries from the db.
rows, err := dbConn.QueryContext(ctx, queryStr, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessage(err, "find {{ FormatCamelPluralTitleLower $.Model.Name }} failed")
return nil, err
}
// Iterate over each row.
resp := []*{{ $.Model.Name }}{}
for rows.Next() {
u, err := mapRowsTo{{ $.Model.Name }}(rows)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
return nil, err
}
resp = append(resp, u)
}
return resp, nil
}
// Read gets the specified {{ FormatCamelLowerTitle $.Model.Name }} from the database.
func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, {{ FormatCamelLower $.Model.PrimaryField }} {{ $.Model.PrimaryType }}{{ if $hasArchived }}, includedArchived bool{{ end }}) (*{{ $.Model.Name }}, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.{{ FormatCamelLowerUnderscore $.Model.Name }}.Read")
defer span.Finish()
// Filter base select query by {{ FormatCamelLower $.Model.PrimaryField }}
query := selectQuery()
query.Where(query.Equal("{{ $.Model.PrimaryColumn }}", {{ FormatCamelLower $.Model.PrimaryField }}))
res, err := find(ctx, claims, dbConn, query, []interface{}{} {{ if $hasArchived }}, includedArchived{{ end }})
if err != nil {
return nil, err
} else if res == nil || len(res) == 0 {
err = errors.WithMessagef(ErrNotFound, "{{ FormatCamelLowerTitle $.Model.Name }} %s not found", id)
return nil, err
}
u := res[0]
return u, nil
}
{{ end }}
{{ define "Create"}}
{{ $hasAccountId := (StringListHasValue $.Model.ColumnNames "account_id") }}
{{ $reqName := (Concat $.Model.Name "CreateRequest") }}
{{ $createFields := (index $.StructFields $reqName) }}
// Create inserts a new {{ FormatCamelLowerTitle $.Model.Name }} into the database.
func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req {{ $reqName }}, now time.Time) (*{{ $.Model.Name }}, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.{{ FormatCamelLowerUnderscore $.Model.Name }}.Create")
defer span.Finish()
if claims.Audience != "" {
// Admin users can update {{ FormatCamelPluralTitleLower $.Model.Name }} they have access to.
if !claims.HasRole(auth.RoleAdmin) {
return errors.WithStack(ErrForbidden)
}
{{ if $hasAccountId }}
if req.AccountId != "" {
// Request accountId must match claims.
if req.AccountId != claims.Audience {
return errors.WithStack(ErrForbidden)
}
} else {
// Set the accountId from claims.
req.AccountId = claims.Audience
}
{{ end }}
}
v := validator.New()
// Validate the request.
err = v.Struct(req)
if err != nil {
return nil, err
}
// If now empty set it to the current time.
if now.IsZero() {
now = time.Now()
}
// Always store the time as UTC.
now = now.UTC()
// Postgres truncates times to milliseconds when storing. We and do the same
// here so the value we return is consistent with what we store.
now = now.Truncate(time.Millisecond)
m := {{ $.Model.Name }}{
{{ range $mk, $mf := $.Model.Fields }}{{ $cf := (index $createFields $mf.FieldName) }}
{{ if eq $mf.FieldName $.Model.PrimaryField }}{{ $isUuid := (FieldTagHasOption $mf "validate" "uuid") }}{{ $mf.FieldName }}: {{ if $isUuid }}uuid.NewRandom().String(){{ else }}req.{{ $mf.FieldName }}{{ end }},
{{ else if or (eq $mf.FieldName "CreatedAt") (eq $mf.FieldName "UpdatedAt") }}{{ $mf.FieldName }}: now,
{{ else if $cf }}{{ $required := (FieldTagHasOption $cf "validate" "required") }}{{ if $required }}{{ $cf.FieldName }}: req.{{ $cf.FieldName }},{{ else if ne $cf.DefaultValue "" }}{{ $cf.FieldName }}: {{ $cf.DefaultValue }},{{ end }}
{{ end }}{{ end }}
}
{{ range $fk, $f := $createFields }}{{ $required := (FieldTagHasOption $f "validate" "required") }}{{ if not $required }}
if req.{{ $f.FieldName }} != nil {
{{ if eq $f.FieldType "sql.NullString" }}
m.{{ $f.FieldName }} = sql.NullString{String: *req.{{ $f.FieldName }}, Valid: true}
{{ else if eq $f.FieldType "*sql.NullString" }}
m.{{ $f.FieldName }} = &sql.NullString{String: *req.{{ $f.FieldName }}, Valid: true}
{{ else }}
m.{{ $f.FieldName }} = *req.{{ $f.FieldName }}
{{ end }}
}
{{ end }}{{ end }}
// Build the insert SQL statement.
query := sqlbuilder.NewInsertBuilder()
query.InsertInto({{ FormatCamelLower $.Model.Name }}TableName)
query.Cols(
{{ range $mk, $mf := $.Model.Fields }}{{ $cf := (index $createFields $mf.FieldName) }}{{ if or (eq $mf.FieldName $.Model.PrimaryField) ($cf) (eq $mf.FieldName "CreatedAt") (eq $mf.FieldName "UpdatedAt") }}"{{ $mf.ColumnName }}",
{{ end }}{{ end }}
)
query.Values(
{{ range $mk, $mf := $.Model.Fields }}{{ $cf := (index $createFields $mf.FieldName) }}{{ if or (eq $mf.FieldName $.Model.PrimaryField) ($cf) (eq $mf.FieldName "CreatedAt") (eq $mf.FieldName "UpdatedAt") }}m.{{ $mf.FieldName }},
{{ end }}{{ end }}
)
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessage(err, "create {{ FormatCamelLowerTitle $.Model.Name }} failed")
return nil, err
}
return &a, nil
}
{{ end }}
{{ define "Update"}}
{{ $reqName := (Concat $.Model.Name "UpdateRequest") }}
{{ $updateFields := (index $.StructFields $reqName) }}
// Update replaces an {{ FormatCamelLowerTitle $.Model.Name }} in the database.
func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req {{ $reqName }}, now time.Time) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.{{ FormatCamelLowerUnderscore $.Model.Name }}.Update")
defer span.Finish()
v := validator.New()
// Validate the request.
err := v.Struct(req)
if err != nil {
return err
}
// Ensure the claims can modify the {{ FormatCamelLowerTitle $.Model.Name }} specified in the request.
err = CanModify{{ $.Model.Name }}(ctx, claims, dbConn, req.{{ $.Model.PrimaryField }})
if err != nil {
return err
}
// If now empty set it to the current time.
if now.IsZero() {
now = time.Now()
}
// Always store the time as UTC.
now = now.UTC()
// Postgres truncates times to milliseconds when storing. We and do the same
// here so the value we return is consistent with what we store.
now = now.Truncate(time.Millisecond)
// Build the update SQL statement.
query := sqlbuilder.NewUpdateBuilder()
query.Update({{ FormatCamelLower $.Model.Name }}TableName)
var fields []string
{{ range $mk, $mf := $.Model.Fields }}{{ $uf := (index $updateFields $mf.FieldName) }}{{ if and ($uf.FieldName) (ne $uf.FieldName $.Model.PrimaryField) }}
{{ $optional := (FieldTagHasOption $uf "validate" "omitempty") }}{{ $isUuid := (FieldTagHasOption $uf "validate" "uuid") }}
if req.{{ $uf.FieldName }} != nil {
{{ if and ($optional) ($isUuid) }}
if *req.{{ $uf.FieldName }} != "" {
fields = append(fields, query.Assign("{{ $uf.ColumnName }}", req.{{ $uf.FieldName }}))
} else {
fields = append(fields, query.Assign("{{ $uf.ColumnName }}", nil))
}
{{ else }}
fields = append(fields, query.Assign("{{ $uf.ColumnName }}", req.{{ $uf.FieldName }}))
{{ end }}
}
{{ end }}{{ end }}
// If there's nothing to update we can quit early.
if len(fields) == 0 {
return nil
}
{{ $hasUpdatedAt := (StringListHasValue $.Model.ColumnNames "updated_at") }}{{ if $hasUpdatedAt }}
// Append the updated_at field
fields = append(fields, query.Assign("updated_at", now))
{{ end }}
query.Set(fields...)
query.Where(query.Equal("{{ $.Model.PrimaryColumn }}", req.{{ $.Model.PrimaryField }}))
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "update {{ FormatCamelLowerTitle $.Model.Name }} %s failed", req.{{ $.Model.PrimaryField }})
return err
}
return nil
}
{{ end }}
{{ define "Archive"}}
// Archive soft deleted the {{ FormatCamelLowerTitle $.Model.Name }} from the database.
func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, {{ FormatCamelLower $.Model.PrimaryField }} {{ $.Model.PrimaryType }}, now time.Time) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.{{ FormatCamelLowerUnderscore $.Model.Name }}.Archive")
defer span.Finish()
// Defines the struct to apply validation
req := struct {
{{ $.Model.PrimaryField }} {{ $.Model.PrimaryType }} `validate:"required,uuid"`
}{
{{ $.Model.PrimaryField }}: {{ FormatCamelLower $.Model.PrimaryField }},
}
// Validate the request.
err := validator.New().Struct(req)
if err != nil {
return err
}
// Ensure the claims can modify the {{ FormatCamelLowerTitle $.Model.Name }} specified in the request.
err = CanModify{{ $.Model.Name }}(ctx, claims, dbConn, req.ID)
if err != nil {
return err
}
// If now empty set it to the current time.
if now.IsZero() {
now = time.Now()
}
// Always store the time as UTC.
now = now.UTC()
// Postgres truncates times to milliseconds when storing. We and do the same
// here so the value we return is consistent with what we store.
now = now.Truncate(time.Millisecond)
// Build the update SQL statement.
query := sqlbuilder.NewUpdateBuilder()
query.Update({{ FormatCamelLower $.Model.Name }}TableName)
query.Set(
query.Assign("archived_at", now),
)
query.Where(query.Equal("{{ $.Model.PrimaryColumn }}", req.{{ $.Model.PrimaryField }}))
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "archive {{ FormatCamelLowerTitle $.Model.Name }} %s failed", req.{{ $.Model.PrimaryField }})
return err
}
return nil
}
{{ end }}
{{ define "Delete"}}
// Delete removes an {{ FormatCamelLowerTitle $.Model.Name }} from the database.
func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, {{ FormatCamelLower $.Model.PrimaryField }} {{ $.Model.PrimaryType }}) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.{{ FormatCamelLowerUnderscore $.Model.Name }}.Delete")
defer span.Finish()
// Defines the struct to apply validation
req := struct {
{{ $.Model.PrimaryField }} {{ $.Model.PrimaryType }} `validate:"required,uuid"`
}{
{{ $.Model.PrimaryField }}: {{ FormatCamelLower $.Model.PrimaryField }},
}
// Validate the request.
err := validator.New().Struct(req)
if err != nil {
return err
}
// Ensure the claims can modify the {{ FormatCamelLowerTitle $.Model.Name }} specified in the request.
err = CanModify{{ $.Model.Name }}(ctx, claims, dbConn, req.ID)
if err != nil {
return err
}
// Build the delete SQL statement.
query := sqlbuilder.NewDeleteBuilder()
query.DeleteFrom({{ FormatCamelLower $.Model.Name }}TableName)
query.Where(query.Equal("{{ $.Model.PrimaryColumn }}", req.{{ $.Model.PrimaryField }}))
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "delete {{ FormatCamelLowerTitle $.Model.Name }} %s failed", req.{{ $.Model.PrimaryField }})
return err
}
return nil
}
{{ end }}

View File

@ -0,0 +1,79 @@
{{ define "CreateRequest"}}
// {{ FormatCamel $.Model.Name }}CreateRequest contains information needed to create a new {{ FormatCamel $.Model.Name }}.
type {{ FormatCamel $.Model.Name }}CreateRequest struct {
{{ range $fk, $f := .Model.Fields }}{{ if and ($f.ApiCreate) (ne $f.FieldName $.Model.PrimaryField) }}{{ $optional := (FieldTagHasOption $f "validate" "omitempty") }}
{{ $f.FieldName }} {{ if and ($optional) (not $f.FieldIsPtr) }}*{{ end }}{{ $f.FieldType }} `json:"{{ $f.ColumnName }}" {{ FieldTag $f "validate" }}`
{{ end }}{{ end }}
}
{{ end }}
{{ define "UpdateRequest"}}
// {{ FormatCamel $.Model.Name }}UpdateRequest defines what information may be provided to modify an existing
// {{ FormatCamel $.Model.Name }}. All fields are optional so clients can send just the fields they want
// changed. It uses pointer fields so we can differentiate between a field that
// was not provided and a field that was provided as explicitly blank.
type {{ FormatCamel $.Model.Name }}UpdateRequest struct {
{{ range $fk, $f := .Model.Fields }}{{ if $f.ApiUpdate }}
{{ $f.FieldName }} {{ if and (ne $f.FieldName $.Model.PrimaryField) (not $f.FieldIsPtr) }}*{{ end }}{{ $f.FieldType }} `json:"{{ $f.ColumnName }}" {{ if ne $f.FieldName $.Model.PrimaryField }}{{ FieldTagReplaceOrPrepend $f "validate" "required" "omitempty" }}{{ else }}{{ FieldTagReplaceOrPrepend $f "validate" "omitempty" "required" }}{{ end }}`
{{ end }}{{ end }}
}
{{ end }}
{{ define "FindRequest"}}
// {{ FormatCamel $.Model.Name }}FindRequest defines the possible options to search for {{ FormatCamelPluralTitleLower $.Model.Name }}. By default
// archived {{ FormatCamelLowerTitle $.Model.Name }} will be excluded from response.
type {{ FormatCamel $.Model.Name }}FindRequest struct {
Where *string
Args []interface{}
Order []string
Limit *uint
Offset *uint
{{ $hasArchived := (StringListHasValue $.Model.ColumnNames "archived_at") }}{{ if $hasArchived }}IncludedArchived bool{{ end }}
}
{{ end }}
{{ define "Enums"}}
{{ range $fk, $f := .Model.Fields }}{{ if $f.DbColumn }}{{ if $f.DbColumn.IsEnum }}
// {{ $f.FieldType }} represents the {{ $f.ColumnName }} of {{ FormatCamelLowerTitle $.Model.Name }}.
type {{ $f.FieldType }} string
// {{ $f.FieldType }} values define the {{ $f.ColumnName }} field of {{ FormatCamelLowerTitle $.Model.Name }}.
const (
{{ range $evk, $ev := $f.DbColumn.EnumValues }}
// {{ $f.FieldType }}_{{ FormatCamel $ev }} defines the {{ $f.ColumnName }} of {{ $ev }} for {{ FormatCamelLowerTitle $.Model.Name }}.
{{ $f.FieldType }}_{{ FormatCamel $ev }}{{ $f.FieldType }} = "{{ $ev }}"
{{ end }}
)
// {{ $f.FieldType }}_Values provides list of valid {{ $f.FieldType }} values.
var {{ $f.FieldType }}_Values = []{{ $f.FieldType }}{
{{ range $evk, $ev := $f.DbColumn.EnumValues }}
{{ $f.FieldType }}_{{ FormatCamel $ev }},
{{ end }}
}
// Scan supports reading the {{ $f.FieldType }} value from the database.
func (s *{{ $f.FieldType }}) Scan(value interface{}) error {
asBytes, ok := value.([]byte)
if !ok {
return errors.New("Scan source is not []byte")
}
*s = {{ $f.FieldType }}(string(asBytes))
return nil
}
// Value converts the {{ $f.FieldType }} value to be stored in the database.
func (s {{ $f.FieldType }}) Value() (driver.Value, error) {
v := validator.New()
errs := v.Var(s, "required,oneof={{ JoinStrings $f.DbColumn.EnumValues " " }}")
if errs != nil {
return nil, errs
}
return string(s), nil
}
// String converts the {{ $f.FieldType }} value to a string.
func (s {{ $f.FieldType }}) String() string {
return string(s)
}
{{ end }}{{ end }}{{ end }}
{{ end }}