From c16ab0e3f9580112a6e21a88672f78c436219bac Mon Sep 17 00:00:00 2001 From: Lee Brown Date: Mon, 12 Aug 2019 12:37:23 -0800 Subject: [PATCH] Remove depreciated tool truss --- tools/truss/.gitignore | 1 - tools/truss/README.md | 68 --- tools/truss/cmd/dbtable2crud/db.go | 149 ----- tools/truss/cmd/dbtable2crud/dbtable2crud.go | 431 --------------- tools/truss/cmd/dbtable2crud/models.go | 229 -------- tools/truss/cmd/dbtable2crud/templates.go | 345 ------------ tools/truss/internal/goparse/doc.go | 301 ----------- tools/truss/internal/goparse/doc_object.go | 458 ---------------- tools/truss/internal/goparse/goparse.go | 352 ------------ tools/truss/internal/goparse/goparse_test.go | 201 ------- tools/truss/main.go | 228 -------- tools/truss/makefile | 10 - tools/truss/sample.env | 22 - .../templates/dbtable2crud/model_crud.tmpl | 510 ------------------ .../dbtable2crud/model_crud_test.tmpl | 131 ----- .../truss/templates/dbtable2crud/models.tmpl | 80 --- 16 files changed, 3516 deletions(-) delete mode 100644 tools/truss/.gitignore delete mode 100644 tools/truss/README.md delete mode 100644 tools/truss/cmd/dbtable2crud/db.go delete mode 100644 tools/truss/cmd/dbtable2crud/dbtable2crud.go delete mode 100644 tools/truss/cmd/dbtable2crud/models.go delete mode 100644 tools/truss/cmd/dbtable2crud/templates.go delete mode 100644 tools/truss/internal/goparse/doc.go delete mode 100644 tools/truss/internal/goparse/doc_object.go delete mode 100644 tools/truss/internal/goparse/goparse.go delete mode 100644 tools/truss/internal/goparse/goparse_test.go delete mode 100644 tools/truss/main.go delete mode 100644 tools/truss/makefile delete mode 100644 tools/truss/sample.env delete mode 100644 tools/truss/templates/dbtable2crud/model_crud.tmpl delete mode 100644 tools/truss/templates/dbtable2crud/model_crud_test.tmpl delete mode 100644 tools/truss/templates/dbtable2crud/models.tmpl diff --git a/tools/truss/.gitignore b/tools/truss/.gitignore deleted file mode 100644 index 0a931f8..0000000 --- a/tools/truss/.gitignore +++ /dev/null @@ -1 +0,0 @@ -truss diff --git a/tools/truss/README.md b/tools/truss/README.md deleted file mode 100644 index 46c5522..0000000 --- a/tools/truss/README.md +++ /dev/null @@ -1,68 +0,0 @@ -# SaaS Truss - -Copyright 2019, Geeks Accelerator -accelerator@geeksinthewoods.com.com - - -## Description - -Truss provides code generation to reduce copy/pasting. - - -## Local Installation - -### Build -```bash -go build . -``` - -### Configuration -```bash -./truss -h - -Usage of ./truss ---cmd string ---db_host string <127.0.0.1:5433> ---db_user string ---db_pass string ---db_database string ---db_driver string ---db_timezone string ---db_disabletls bool -``` - -## Commands: - -## dbtable2crud - -Used to bootstrap a new business logic package with basic CRUD. - -**Usage** -```bash -./truss dbtable2crud -table=projects -file=../../internal/project/models.go -model=Project [-dbtable=TABLE] [-templateDir=DIR] [-projectPath=DIR] [-saveChanges=false] -``` - -**Example** -1. Define a new database table in `internal/schema/migrations.go` - - -2. Create a new file for the base model at `internal/projects/models.go`. Only the following struct needs to be included. All the other times will be generated. -```go -// Project represents a workflow. -type Project struct { - ID string `json:"id" validate:"required,uuid"` - AccountID string `json:"account_id" validate:"required,uuid" truss:"api-create"` - Name string `json:"name" validate:"required"` - Status ProjectStatus `json:"status" validate:"omitempty,oneof=active disabled"` - CreatedAt time.Time `json:"created_at" truss:"api-read"` - UpdatedAt time.Time `json:"updated_at" truss:"api-read"` - ArchivedAt pq.NullTime `json:"archived_at" truss:"api-hide"` -} -``` - -3. Run `dbtable2crud` -```bash -./truss dbtable2crud -table=projects -file=../../internal/project/models.go -model=Project -save=true -``` - - diff --git a/tools/truss/cmd/dbtable2crud/db.go b/tools/truss/cmd/dbtable2crud/db.go deleted file mode 100644 index a14fb39..0000000 --- a/tools/truss/cmd/dbtable2crud/db.go +++ /dev/null @@ -1,149 +0,0 @@ -package dbtable2crud - -import ( - "fmt" - "strings" - - "github.com/jmoiron/sqlx" - "github.com/lib/pq" - "github.com/pkg/errors" -) - -type psqlColumn struct { - Table string - Column string - ColumnId int64 - NotNull bool - DataTypeFull string - DataTypeName string - DataTypeLength *int - NumericPrecision *int - NumericScale *int - IsPrimaryKey bool - PrimaryKeyName *string - IsUniqueKey bool - UniqueKeyName *string - IsForeignKey bool - ForeignKeyName *string - ForeignKeyColumnId pq.Int64Array - ForeignKeyTable *string - ForeignKeyLocalColumnId pq.Int64Array - DefaultFull *string - DefaultValue *string - IsEnum bool - EnumTypeId *string - EnumValues []string -} - -// descTable lists all the columns for a table. -func descTable(db *sqlx.DB, dbName, dbTable string) ([]psqlColumn, error) { - - queryStr := fmt.Sprintf(`SELECT - c.relname as table, - f.attname as column, - f.attnum as columnId, - f.attnotnull as not_null, - pg_catalog.format_type(f.atttypid,f.atttypmod) AS data_type_full, - t.typname AS data_type_name, - CASE WHEN f.atttypmod >= 0 AND t.typname <> 'numeric'THEN (f.atttypmod - 4) --first 4 bytes are for storing actual length of data - END AS data_type_length, - CASE WHEN t.typname = 'numeric' THEN (((f.atttypmod - 4) >> 16) & 65535) - END AS numeric_precision, - CASE WHEN t.typname = 'numeric' THEN ((f.atttypmod - 4)& 65535 ) - END AS numeric_scale, - CASE WHEN p.contype = 'p' THEN true ELSE false - END AS is_primary_key, - CASE WHEN p.contype = 'p' THEN p.conname - END AS primary_key_name, - CASE WHEN p.contype = 'u' THEN true ELSE false - END AS is_unique_key, - CASE WHEN p.contype = 'u' THEN p.conname - END AS unique_key_name, - CASE WHEN p.contype = 'f' THEN true ELSE false - END AS is_foreign_key, - CASE WHEN p.contype = 'f' THEN p.conname - END AS foreignkey_name, - CASE WHEN p.contype = 'f' THEN p.confkey - END AS foreign_key_columnid, - CASE WHEN p.contype = 'f' THEN g.relname - END AS foreign_key_table, - CASE WHEN p.contype = 'f' THEN p.conkey - END AS foreign_key_local_column_id, - CASE WHEN f.atthasdef = 't' THEN d.adsrc - END AS default_value, - CASE WHEN t.typtype = 'e' THEN true ELSE false - END AS is_enum, - CASE WHEN t.typtype = 'e' THEN t.oid - END AS enum_type_id - FROM pg_attribute f - JOIN pg_class c ON c.oid = f.attrelid - JOIN pg_type t ON t.oid = f.atttypid - LEFT JOIN pg_attrdef d ON d.adrelid = c.oid AND d.adnum = f.attnum - LEFT JOIN pg_namespace n ON n.oid = c.relnamespace - LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey) - LEFT JOIN pg_class AS g ON p.confrelid = g.oid - WHERE c.relkind = 'r'::char - AND f.attisdropped = false - AND c.relname = '%s' - AND f.attnum > 0 - ORDER BY f.attnum - ;`, dbTable) // AND n.nspname = '%s' - - rows, err := db.Query(queryStr) - if err != nil { - err = errors.Wrapf(err, "query - %s", queryStr) - return nil, err - } - - // iterate over each row - var resp []psqlColumn - for rows.Next() { - var c psqlColumn - err = rows.Scan(&c.Table, &c.Column, &c.ColumnId, &c.NotNull, &c.DataTypeFull, &c.DataTypeName, &c.DataTypeLength, &c.NumericPrecision, &c.NumericScale, &c.IsPrimaryKey, &c.PrimaryKeyName, &c.IsUniqueKey, &c.UniqueKeyName, &c.IsForeignKey, &c.ForeignKeyName, &c.ForeignKeyColumnId, &c.ForeignKeyTable, &c.ForeignKeyLocalColumnId, &c.DefaultFull, &c.IsEnum, &c.EnumTypeId) - if err != nil { - err = errors.Wrapf(err, "query - %s", queryStr) - return nil, err - } - - if c.DefaultFull != nil { - defaultValue := *c.DefaultFull - - // "'active'::project_status_t" - defaultValue = strings.Split(defaultValue, "::")[0] - c.DefaultValue = &defaultValue - } - - resp = append(resp, c) - } - - for colIdx, dbCol := range resp { - if !dbCol.IsEnum { - continue - } - - queryStr := fmt.Sprintf(`SELECT e.enumlabel - FROM pg_enum AS e - WHERE e.enumtypid = '%s' - ORDER BY e.enumsortorder`, *dbCol.EnumTypeId) - - rows, err := db.Query(queryStr) - if err != nil { - err = errors.Wrapf(err, "query - %s", queryStr) - return nil, err - } - - for rows.Next() { - var v string - err = rows.Scan(&v) - if err != nil { - err = errors.Wrapf(err, "query - %s", queryStr) - return nil, err - } - dbCol.EnumValues = append(dbCol.EnumValues, v) - } - - resp[colIdx] = dbCol - } - - return resp, nil -} diff --git a/tools/truss/cmd/dbtable2crud/dbtable2crud.go b/tools/truss/cmd/dbtable2crud/dbtable2crud.go deleted file mode 100644 index 1e68774..0000000 --- a/tools/truss/cmd/dbtable2crud/dbtable2crud.go +++ /dev/null @@ -1,431 +0,0 @@ -package dbtable2crud - -import ( - "fmt" - "log" - "os" - "path/filepath" - "strings" - - "geeks-accelerator/oss/saas-starter-kit/internal/schema" - "geeks-accelerator/oss/saas-starter-kit/tools/truss/internal/goparse" - "github.com/dustin/go-humanize/english" - "github.com/fatih/camelcase" - "github.com/jmoiron/sqlx" - "github.com/pkg/errors" - "github.com/sergi/go-diff/diffmatchpatch" -) - -// Run in the main entry point for the dbtable2crud cmd. -func Run(db *sqlx.DB, log *log.Logger, dbName, dbTable, modelFile, modelName, templateDir, goSrcPath string, saveChanges bool) error { - log.SetPrefix(log.Prefix() + " : dbtable2crud") - - // Ensure the schema is up to date - if err := schema.Migrate(db, log); err != nil { - return err - } - - // When dbTable is empty, lower case the model name - if dbTable == "" { - dbTable = strings.Join(camelcase.Split(modelName), " ") - dbTable = english.PluralWord(2, dbTable, "") - dbTable = strings.Replace(dbTable, " ", "_", -1) - dbTable = strings.ToLower(dbTable) - } - - // Parse the model file and load the specified model struct. - model, err := parseModelFile(db, log, dbName, dbTable, modelFile, modelName) - if err != nil { - return err - } - - // Basic lint of the model struct. - err = validateModel(log, model) - if err != nil { - return err - } - - tmplData := map[string]interface{}{ - "GoSrcPath": goSrcPath, - } - - // Update the model file with new or updated code. - err = updateModel(log, model, modelFile, templateDir, tmplData, saveChanges) - if err != nil { - return err - } - - // Update the model crud file with new or updated code. - err = updateModelCrud(db, log, dbName, dbTable, modelFile, modelName, templateDir, model, tmplData, saveChanges) - if err != nil { - return err - } - - return nil -} - -// validateModel performs a basic lint of the model struct to ensure -// code gen output is correct. -func validateModel(log *log.Logger, model *modelDef) error { - for _, sf := range model.Fields { - if sf.DbColumn == nil && sf.ColumnName != "-" { - log.Printf("validateStruct : Unable to find struct field for db column %s\n", sf.ColumnName) - } - - var expectedType string - switch sf.FieldName { - case "ID": - expectedType = "string" - case "CreatedAt": - expectedType = "time.Time" - case "UpdatedAt": - expectedType = "time.Time" - case "ArchivedAt": - expectedType = "pq.NullTime" - } - - if expectedType != "" && expectedType != sf.FieldType { - log.Printf("validateStruct : Struct field %s should be of type %s not %s\n", sf.FieldName, expectedType, sf.FieldType) - } - } - - return nil -} - -// updateModel updated the parsed code file with the new code. -func updateModel(log *log.Logger, model *modelDef, modelFile, templateDir string, tmplData map[string]interface{}, saveChanges bool) error { - - // Execute template and parse code to be used to compare against modelFile. - tmplObjs, err := loadTemplateObjects(log, model, templateDir, "models.tmpl", tmplData) - if err != nil { - return err - } - - // Store the current code as a string to produce a diff. - curCode := model.String() - - objHeaders := []*goparse.GoObject{} - - for _, obj := range tmplObjs { - if obj.Type == goparse.GoObjectType_Comment || obj.Type == goparse.GoObjectType_LineBreak { - objHeaders = append(objHeaders, obj) - continue - } - - if model.HasType(obj.Name, obj.Type) { - cur := model.Objects().Get(obj.Name, obj.Type) - - newObjs := []*goparse.GoObject{} - if len(objHeaders) > 0 { - // Remove any comments and linebreaks before the existing object so updates can be added. - removeObjs := []*goparse.GoObject{} - for idx := cur.Index - 1; idx > 0; idx-- { - prevObj := model.Objects().List()[idx] - if prevObj.Type == goparse.GoObjectType_Comment || prevObj.Type == goparse.GoObjectType_LineBreak { - removeObjs = append(removeObjs, prevObj) - } else { - break - } - } - - if len(removeObjs) > 0 { - err := model.Objects().Remove(removeObjs...) - if err != nil { - err = errors.WithMessagef(err, "Failed to update object %s %s for %s", obj.Type, obj.Name, model.Name) - return err - } - - // Make sure the current index is correct. - cur = model.Objects().Get(obj.Name, obj.Type) - } - - // Append comments and line breaks before adding the object - for _, c := range objHeaders { - newObjs = append(newObjs, c) - } - } - - newObjs = append(newObjs, obj) - - // Do the object replacement. - err := model.Objects().Replace(cur, newObjs...) - if err != nil { - err = errors.WithMessagef(err, "Failed to update object %s %s for %s", obj.Type, obj.Name, model.Name) - return err - } - } else { - // Append comments and line breaks before adding the object - for _, c := range objHeaders { - err := model.Objects().Add(c) - if err != nil { - err = errors.WithMessagef(err, "Failed to add object %s %s for %s", c.Type, c.Name, model.Name) - return err - } - } - - err := model.Objects().Add(obj) - if err != nil { - err = errors.WithMessagef(err, "Failed to add object %s %s for %s", obj.Type, obj.Name, model.Name) - return err - } - } - - objHeaders = []*goparse.GoObject{} - } - - // Set some flags to determine additional imports and need to be added. - var hasEnum bool - var hasPq bool - for _, f := range model.Fields { - if f.DbColumn != nil && f.DbColumn.IsEnum { - hasEnum = true - } - if strings.HasPrefix(strings.Trim(f.FieldType, "*"), "pq.") { - hasPq = true - } - } - - reqImports := []string{} - if hasEnum { - reqImports = append(reqImports, "database/sql/driver") - reqImports = append(reqImports, "gopkg.in/go-playground/validator.v9") - reqImports = append(reqImports, "github.com/pkg/errors") - } - - if hasPq { - reqImports = append(reqImports, "github.com/lib/pq") - } - - for _, in := range reqImports { - err := model.AddImport(goparse.GoImport{Name: in}) - if err != nil { - err = errors.WithMessagef(err, "Failed to add import %s for %s", in, model.Name) - return err - } - } - - if saveChanges { - err = model.Save(modelFile) - if err != nil { - err = errors.WithMessagef(err, "Failed to save changes for %s to %s", model.Name, modelFile) - return err - } - } else { - // Produce a diff after the updates have been applied. - dmp := diffmatchpatch.New() - diffs := dmp.DiffMain(curCode, model.String(), true) - fmt.Println(dmp.DiffPrettyText(diffs)) - } - - return nil -} - -// updateModelCrud updated the parsed code file with the new code. -func updateModelCrud(db *sqlx.DB, log *log.Logger, dbName, dbTable, modelFile, modelName, templateDir string, baseModel *modelDef, tmplData map[string]interface{}, saveChanges bool) error { - - // Load all the updated struct fields from the base model file. - structFields := make(map[string]map[string]modelField) - for _, obj := range baseModel.GoDocument.Objects().List() { - if obj.Type != goparse.GoObjectType_Struct || obj.Name == baseModel.Name { - continue - } - - objFields, err := parseModelFields(baseModel.GoDocument, obj.Name, baseModel) - if err != nil { - return err - } - - structFields[obj.Name] = make(map[string]modelField) - for _, f := range objFields { - structFields[obj.Name][f.FieldName] = f - } - } - - // Append the struct fields to be used for template execution. - if tmplData == nil { - tmplData = make(map[string]interface{}) - } - tmplData["StructFields"] = structFields - - // Get the dir to store crud methods and test files. - modelDir := filepath.Dir(modelFile) - - // Process the CRUD hanlders template and write to file. - crudFilePath := filepath.Join(modelDir, FormatCamelLowerUnderscore(baseModel.Name)+".go") - crudTmplFile := "model_crud.tmpl" - err := updateModelCrudFile(db, log, dbName, dbTable, templateDir, crudFilePath, crudTmplFile, baseModel, tmplData, saveChanges) - if err != nil { - return err - } - - // Process the CRUD test template and write to file. - testFilePath := filepath.Join(modelDir, FormatCamelLowerUnderscore(baseModel.Name)+"_test.go") - testTmplFile := "model_crud_test.tmpl" - err = updateModelCrudFile(db, log, dbName, dbTable, templateDir, testFilePath, testTmplFile, baseModel, tmplData, saveChanges) - if err != nil { - return err - } - - return nil -} - -// updateModelCrudFile processes the input file. -func updateModelCrudFile(db *sqlx.DB, log *log.Logger, dbName, dbTable, templateDir, crudFilePath, tmplFile string, baseModel *modelDef, tmplData map[string]interface{}, saveChanges bool) error { - - // Execute template and parse code to be used to compare against modelFile. - tmplObjs, err := loadTemplateObjects(log, baseModel, templateDir, tmplFile, tmplData) - if err != nil { - return err - } - - var crudDoc *goparse.GoDocument - if _, err := os.Stat(crudFilePath); os.IsNotExist(err) { - crudDoc, err = goparse.NewGoDocument(baseModel.Package) - if err != nil { - return err - } - } else { - // Parse the supplied model file. - crudDoc, err = goparse.ParseFile(log, crudFilePath) - if err != nil { - return err - } - } - - // Store the current code as a string to produce a diff. - curCode := crudDoc.String() - - objHeaders := []*goparse.GoObject{} - - for _, obj := range tmplObjs { - if obj.Type == goparse.GoObjectType_Comment || obj.Type == goparse.GoObjectType_LineBreak { - objHeaders = append(objHeaders, obj) - continue - } - - if obj.Name == "" && (obj.Type == goparse.GoObjectType_Var || obj.Type == goparse.GoObjectType_Const) { - var curDocObj *goparse.GoObject - for _, subObj := range obj.Objects().List() { - for _, do := range crudDoc.Objects().List() { - if do.Name == "" && (do.Type == goparse.GoObjectType_Var || do.Type == goparse.GoObjectType_Const) { - for _, subDocObj := range do.Objects().List() { - if subDocObj.String() == subObj.String() && subObj.Type != goparse.GoObjectType_LineBreak { - curDocObj = do - break - } - - } - } - } - } - - if curDocObj != nil { - for _, subObj := range obj.Objects().List() { - var hasSubObj bool - for _, subDocObj := range curDocObj.Objects().List() { - if subDocObj.String() == subObj.String() { - hasSubObj = true - break - } - } - - if !hasSubObj { - curDocObj.Objects().Add(subObj) - if err != nil { - err = errors.WithMessagef(err, "Failed to add object %s %s for %s", obj.Type, obj.Name, baseModel.Name) - return err - } - } - } - } else { - // Append comments and line breaks before adding the object - for _, c := range objHeaders { - err := crudDoc.Objects().Add(c) - if err != nil { - err = errors.WithMessagef(err, "Failed to add object %s %s for %s", c.Type, c.Name, baseModel.Name) - return err - } - } - - err := crudDoc.Objects().Add(obj) - if err != nil { - err = errors.WithMessagef(err, "Failed to add object %s %s for %s", obj.Type, obj.Name, baseModel.Name) - return err - } - } - } else if crudDoc.HasType(obj.Name, obj.Type) { - cur := crudDoc.Objects().Get(obj.Name, obj.Type) - - newObjs := []*goparse.GoObject{} - if len(objHeaders) > 0 { - // Remove any comments and linebreaks before the existing object so updates can be added. - removeObjs := []*goparse.GoObject{} - for idx := cur.Index - 1; idx > 0; idx-- { - prevObj := crudDoc.Objects().List()[idx] - if prevObj.Type == goparse.GoObjectType_Comment || prevObj.Type == goparse.GoObjectType_LineBreak { - removeObjs = append(removeObjs, prevObj) - } else { - break - } - } - - if len(removeObjs) > 0 { - err := crudDoc.Objects().Remove(removeObjs...) - if err != nil { - err = errors.WithMessagef(err, "Failed to update object %s %s for %s", obj.Type, obj.Name, baseModel.Name) - return err - } - - // Make sure the current index is correct. - cur = crudDoc.Objects().Get(obj.Name, obj.Type) - } - - // Append comments and line breaks before adding the object - for _, c := range objHeaders { - newObjs = append(newObjs, c) - } - } - - newObjs = append(newObjs, obj) - - // Do the object replacement. - err := crudDoc.Objects().Replace(cur, newObjs...) - if err != nil { - err = errors.WithMessagef(err, "Failed to update object %s %s for %s", obj.Type, obj.Name, baseModel.Name) - return err - } - } else { - // Append comments and line breaks before adding the object - for _, c := range objHeaders { - err := crudDoc.Objects().Add(c) - if err != nil { - err = errors.WithMessagef(err, "Failed to add object %s %s for %s", c.Type, c.Name, baseModel.Name) - return err - } - } - - err := crudDoc.Objects().Add(obj) - if err != nil { - err = errors.WithMessagef(err, "Failed to add object %s %s for %s", obj.Type, obj.Name, baseModel.Name) - return err - } - } - - objHeaders = []*goparse.GoObject{} - } - - if saveChanges { - err = crudDoc.Save(crudFilePath) - if err != nil { - err = errors.WithMessagef(err, "Failed to save changes for %s to %s", baseModel.Name, crudFilePath) - return err - } - } else { - // Produce a diff after the updates have been applied. - dmp := diffmatchpatch.New() - diffs := dmp.DiffMain(curCode, crudDoc.String(), true) - fmt.Println(dmp.DiffPrettyText(diffs)) - } - - return nil -} diff --git a/tools/truss/cmd/dbtable2crud/models.go b/tools/truss/cmd/dbtable2crud/models.go deleted file mode 100644 index 5cc4076..0000000 --- a/tools/truss/cmd/dbtable2crud/models.go +++ /dev/null @@ -1,229 +0,0 @@ -package dbtable2crud - -import ( - "log" - "strings" - - "geeks-accelerator/oss/saas-starter-kit/tools/truss/internal/goparse" - "github.com/fatih/structtag" - "github.com/jmoiron/sqlx" - "github.com/pkg/errors" -) - -// modelDef defines info about the struct and associated db table. -type modelDef struct { - *goparse.GoDocument - Name string - TableName string - PrimaryField string - PrimaryColumn string - PrimaryType string - Fields []modelField - FieldNames []string - ColumnNames []string -} - -// modelField defines a struct field and associated db column. -type modelField struct { - ColumnName string - DbColumn *psqlColumn - FieldName string - FieldType string - FieldIsPtr bool - Tags *structtag.Tags - ApiHide bool - ApiRead bool - ApiCreate bool - ApiUpdate bool - DefaultValue string -} - -// parseModelFile parses the entire model file and then loads the specified model struct. -func parseModelFile(db *sqlx.DB, log *log.Logger, dbName, dbTable, modelFile, modelName string) (*modelDef, error) { - - // Parse the supplied model file. - doc, err := goparse.ParseFile(log, modelFile) - if err != nil { - return nil, err - } - - // Init new modelDef. - model := &modelDef{ - GoDocument: doc, - Name: modelName, - TableName: dbTable, - } - - // Append the field the the model def. - model.Fields, err = parseModelFields(doc, modelName, nil) - if err != nil { - return nil, err - } - - for _, sf := range model.Fields { - model.FieldNames = append(model.FieldNames, sf.FieldName) - model.ColumnNames = append(model.ColumnNames, sf.ColumnName) - } - - // Query the database for a table definition. - dbCols, err := descTable(db, dbName, dbTable) - if err != nil { - return model, err - } - - // Loop over all the database table columns and append to the associated - // struct field. Don't force all database table columns to be defined in the - // in the struct. - for _, dbCol := range dbCols { - for idx, sf := range model.Fields { - if sf.ColumnName != dbCol.Column { - continue - } - - if dbCol.IsPrimaryKey { - model.PrimaryColumn = sf.ColumnName - model.PrimaryField = sf.FieldName - model.PrimaryType = sf.FieldType - } - - if dbCol.DefaultValue != nil { - sf.DefaultValue = *dbCol.DefaultValue - - if dbCol.IsEnum { - sf.DefaultValue = strings.Trim(sf.DefaultValue, "'") - sf.DefaultValue = sf.FieldType + "_" + FormatCamel(sf.DefaultValue) - } else if strings.HasPrefix(sf.DefaultValue, "'") { - sf.DefaultValue = strings.Trim(sf.DefaultValue, "'") - sf.DefaultValue = "\"" + sf.DefaultValue + "\"" - } - } - - c := dbCol - sf.DbColumn = &c - model.Fields[idx] = sf - } - } - - // Print out the model for debugging. - //modelJSON, err := json.MarshalIndent(model, "", " ") - //if err != nil { - // return model, errors.WithStack(err ) - //} - //log.Printf(string(modelJSON)) - - return model, nil -} - -// parseModelFields parses the fields from a struct. -func parseModelFields(doc *goparse.GoDocument, modelName string, baseModel *modelDef) ([]modelField, error) { - - // Ensure the model file has a struct with the model name supplied. - if !doc.HasType(modelName, goparse.GoObjectType_Struct) { - err := errors.Errorf("Struct with the name %s does not exist", modelName) - return nil, err - } - - // Load the struct from parsed go file. - docModel := doc.Get(modelName, goparse.GoObjectType_Struct) - - // Loop over all the objects contained between the struct definition start and end. - // This should be a list of variables defined for model. - resp := []modelField{} - for _, l := range docModel.Objects().List() { - - // Skip all lines that are not a var. - if l.Type != goparse.GoObjectType_Line { - log.Printf("parseModelFile : Model %s has line that is %s, not type line, skipping - %s\n", modelName, l.Type, l.String()) - continue - } - - // Extract the var name, type and defined tags from the line. - sv, err := goparse.ParseStructProp(l) - if err != nil { - return nil, err - } - - // Init new modelField for the struct var. - sf := modelField{ - FieldName: sv.Name, - FieldType: sv.Type, - FieldIsPtr: strings.HasPrefix(sv.Type, "*"), - Tags: sv.Tags, - } - - // Extract the column name from the var tags. - if sf.Tags != nil { - // First try to get the column name from the db tag. - dbt, err := sf.Tags.Get("db") - if err != nil && !strings.Contains(err.Error(), "not exist") { - err = errors.WithStack(err) - return nil, err - } else if dbt != nil { - sf.ColumnName = dbt.Name - } - - // Second try to get the column name from the json tag. - if sf.ColumnName == "" { - jt, err := sf.Tags.Get("json") - if err != nil && !strings.Contains(err.Error(), "not exist") { - err = errors.WithStack(err) - return nil, err - } else if jt != nil && jt.Name != "-" { - sf.ColumnName = jt.Name - } - } - - var apiActionsSet bool - tt, err := sf.Tags.Get("truss") - if err != nil && !strings.Contains(err.Error(), "not exist") { - err = errors.WithStack(err) - return nil, err - } else if tt != nil { - if tt.Name == "api-create" || tt.HasOption("api-create") { - sf.ApiCreate = true - apiActionsSet = true - } - if tt.Name == "api-read" || tt.HasOption("api-read") { - sf.ApiRead = true - apiActionsSet = true - } - if tt.Name == "api-update" || tt.HasOption("api-update") { - sf.ApiUpdate = true - apiActionsSet = true - } - if tt.Name == "api-hide" || tt.HasOption("api-hide") { - sf.ApiHide = true - apiActionsSet = true - } - } - - if !apiActionsSet { - sf.ApiCreate = true - sf.ApiRead = true - sf.ApiUpdate = true - } - } - - // Set the column name to the field name if empty and does not equal '-'. - if sf.ColumnName == "" { - sf.ColumnName = sf.FieldName - } - - // If a base model as already been parsed with the db columns, - // append to the current field. - if baseModel != nil { - for _, baseSf := range baseModel.Fields { - if baseSf.ColumnName == sf.ColumnName { - sf.DefaultValue = baseSf.DefaultValue - sf.DbColumn = baseSf.DbColumn - break - } - } - } - - // Append the field the the model def. - resp = append(resp, sf) - } - - return resp, nil -} diff --git a/tools/truss/cmd/dbtable2crud/templates.go b/tools/truss/cmd/dbtable2crud/templates.go deleted file mode 100644 index 5ff3841..0000000 --- a/tools/truss/cmd/dbtable2crud/templates.go +++ /dev/null @@ -1,345 +0,0 @@ -package dbtable2crud - -import ( - "bufio" - "bytes" - "fmt" - "go/format" - "io/ioutil" - "log" - "os" - "path/filepath" - "sort" - "strings" - "text/template" - - "geeks-accelerator/oss/saas-starter-kit/tools/truss/internal/goparse" - "github.com/dustin/go-humanize/english" - "github.com/fatih/camelcase" - "github.com/iancoleman/strcase" - "github.com/pkg/errors" -) - -// loadTemplateObjects executes a template file based on the given model struct and -// returns the parsed go objects. -func loadTemplateObjects(log *log.Logger, model *modelDef, templateDir, filename string, tmptData map[string]interface{}) ([]*goparse.GoObject, error) { - - // Data used to execute all the of defined code sections in the template file. - if tmptData == nil { - tmptData = make(map[string]interface{}) - } - tmptData["Model"] = model - - // geeks-accelerator/oss/saas-starter-kit/example-project - - // Read the template file from the local file system. - tempFilePath := filepath.Join(templateDir, filename) - dat, err := ioutil.ReadFile(tempFilePath) - if err != nil { - err = errors.WithMessagef(err, "Failed to read template file %s", tempFilePath) - return nil, err - } - - // New template with custom functions. - baseTmpl := template.New("base") - baseTmpl.Funcs(template.FuncMap{ - "Concat": func(vals ...string) string { - return strings.Join(vals, "") - }, - "JoinStrings": func(vals []string, sep string) string { - return strings.Join(vals, sep) - }, - "PrefixAndJoinStrings": func(vals []string, pre, sep string) string { - l := []string{} - for _, v := range vals { - l = append(l, pre+v) - } - return strings.Join(l, sep) - }, - "FmtAndJoinStrings": func(vals []string, fmtStr, sep string) string { - l := []string{} - for _, v := range vals { - l = append(l, fmt.Sprintf(fmtStr, v)) - } - return strings.Join(l, sep) - }, - "FormatCamel": func(name string) string { - return FormatCamel(name) - }, - "FormatCamelTitle": func(name string) string { - return FormatCamelTitle(name) - }, - "FormatCamelLower": func(name string) string { - if name == "ID" { - return "id" - } - return FormatCamelLower(name) - }, - "FormatCamelLowerTitle": func(name string) string { - return FormatCamelLowerTitle(name) - }, - "FormatCamelPluralTitle": func(name string) string { - return FormatCamelPluralTitle(name) - }, - "FormatCamelPluralTitleLower": func(name string) string { - return FormatCamelPluralTitleLower(name) - }, - "FormatCamelPluralCamel": func(name string) string { - return FormatCamelPluralCamel(name) - }, - "FormatCamelPluralLower": func(name string) string { - return FormatCamelPluralLower(name) - }, - "FormatCamelPluralUnderscore": func(name string) string { - return FormatCamelPluralUnderscore(name) - }, - "FormatCamelPluralLowerUnderscore": func(name string) string { - return FormatCamelPluralLowerUnderscore(name) - }, - "FormatCamelUnderscore": func(name string) string { - return FormatCamelUnderscore(name) - }, - "FormatCamelLowerUnderscore": func(name string) string { - return FormatCamelLowerUnderscore(name) - }, - "FieldTagHasOption": func(f modelField, tagName, optName string) bool { - if f.Tags == nil { - return false - } - ft, err := f.Tags.Get(tagName) - if ft == nil || err != nil { - return false - } - if ft.Name == optName || ft.HasOption(optName) { - return true - } - return false - }, - "FieldTag": func(f modelField, tagName string) string { - if f.Tags == nil { - return "" - } - ft, err := f.Tags.Get(tagName) - if ft == nil || err != nil { - return "" - } - return ft.String() - }, - "FieldTagReplaceOrPrepend": func(f modelField, tagName, oldVal, newVal string) string { - if f.Tags == nil { - return "" - } - ft, err := f.Tags.Get(tagName) - if ft == nil || err != nil { - return "" - } - - if ft.Name == oldVal || ft.Name == newVal { - ft.Name = newVal - } else if ft.HasOption(oldVal) { - for idx, val := range ft.Options { - if val == oldVal { - ft.Options[idx] = newVal - } - } - } else if !ft.HasOption(newVal) { - if ft.Name == "" { - ft.Name = newVal - } else { - ft.Options = append(ft.Options, newVal) - } - } - - return ft.String() - }, - "StringListHasValue": func(list []string, val string) bool { - for _, v := range list { - if v == val { - return true - } - } - return false - }, - }) - - // Load the template file using the text/template package. - tmpl, err := baseTmpl.Parse(string(dat)) - if err != nil { - err = errors.WithMessagef(err, "Failed to parse template file %s", tempFilePath) - log.Printf("loadTemplateObjects : %v\n%v", err, string(dat)) - return nil, err - } - - // Generate a list of template names defined in the template file. - tmplNames := []string{} - for _, defTmpl := range tmpl.Templates() { - tmplNames = append(tmplNames, defTmpl.Name()) - } - - // Stupid hack to return template names the in order they are defined in the file. - tmplNames, err = templateFileOrderedNames(tempFilePath, tmplNames) - if err != nil { - return nil, err - } - - // Loop over all the defined templates, execute using the defined data, parse the - // formatted code and append the parsed go objects to the result list. - var resp []*goparse.GoObject - for _, tmplName := range tmplNames { - // Executed the defined template with the given data. - var tpl bytes.Buffer - if err := tmpl.Lookup(tmplName).Execute(&tpl, tmptData); err != nil { - err = errors.WithMessagef(err, "Failed to execute %s from template file %s", tmplName, tempFilePath) - return resp, err - } - - // Format the source code to ensure its valid and code to parsed consistently. - codeBytes, err := format.Source(tpl.Bytes()) - if err != nil { - err = errors.WithMessagef(err, "Failed to format source for %s in template file %s", tmplName, filename) - - dl := []string{} - for idx, l := range strings.Split(tpl.String(), "\n") { - dl = append(dl, fmt.Sprintf("%d -> ", idx)+l) - } - - log.Printf("loadTemplateObjects : %v\n%v", err, strings.Join(dl, "\n")) - return resp, err - } - - // Remove extra white space from the code. - codeStr := strings.TrimSpace(string(codeBytes)) - - // Split the code into a list of strings. - codeLines := strings.Split(codeStr, "\n") - - // Parse the code lines into a set of objects. - objs, err := goparse.ParseLines(codeLines, 0) - if err != nil { - err = errors.WithMessagef(err, "Failed to parse %s in template file %s", tmplName, filename) - log.Printf("loadTemplateObjects : %v\n%v", err, codeStr) - return resp, err - } - - // Append the parsed objects to the return result list. - for _, obj := range objs.List() { - if obj.Name == "" && obj.Type != goparse.GoObjectType_Import && obj.Type != goparse.GoObjectType_Var && obj.Type != goparse.GoObjectType_Const && obj.Type != goparse.GoObjectType_Comment && obj.Type != goparse.GoObjectType_LineBreak { - // All objects should have a name except for multiline var/const declarations and comments. - err = errors.Errorf("Failed to parse name with type %s from lines: %v", obj.Type, obj.Lines()) - return resp, err - } else if string(obj.Type) == "" { - err = errors.Errorf("Failed to parse type for %s from lines: %v", obj.Name, obj.Lines()) - return resp, err - } - - resp = append(resp, obj) - } - } - - return resp, nil -} - -// FormatCamel formats Valdez mountain to ValdezMountain -func FormatCamel(name string) string { - return strcase.ToCamel(name) -} - -// FormatCamelLower formats ValdezMountain to valdezmountain -func FormatCamelLower(name string) string { - return strcase.ToLowerCamel(FormatCamel(name)) -} - -// FormatCamelTitle formats ValdezMountain to Valdez Mountain -func FormatCamelTitle(name string) string { - return strings.Join(camelcase.Split(name), " ") -} - -// FormatCamelLowerTitle formats ValdezMountain to valdez mountain -func FormatCamelLowerTitle(name string) string { - return strings.ToLower(FormatCamelTitle(name)) -} - -// FormatCamelPluralTitle formats ValdezMountain to Valdez Mountains -func FormatCamelPluralTitle(name string) string { - pts := camelcase.Split(name) - lastIdx := len(pts) - 1 - pts[lastIdx] = english.PluralWord(2, pts[lastIdx], "") - return strings.Join(pts, " ") -} - -// FormatCamelPluralTitleLower formats ValdezMountain to valdez mountains -func FormatCamelPluralTitleLower(name string) string { - return strings.ToLower(FormatCamelPluralTitle(name)) -} - -// FormatCamelPluralCamel formats ValdezMountain to ValdezMountains -func FormatCamelPluralCamel(name string) string { - return strcase.ToCamel(FormatCamelPluralTitle(name)) -} - -// FormatCamelPluralLower formats ValdezMountain to valdezmountains -func FormatCamelPluralLower(name string) string { - return strcase.ToLowerCamel(FormatCamelPluralTitle(name)) -} - -// FormatCamelPluralUnderscore formats ValdezMountain to Valdez_Mountains -func FormatCamelPluralUnderscore(name string) string { - return strings.Replace(FormatCamelPluralTitle(name), " ", "_", -1) -} - -// FormatCamelPluralLowerUnderscore formats ValdezMountain to valdez_mountains -func FormatCamelPluralLowerUnderscore(name string) string { - return strings.ToLower(FormatCamelPluralUnderscore(name)) -} - -// FormatCamelUnderscore formats ValdezMountain to Valdez_Mountain -func FormatCamelUnderscore(name string) string { - return strings.Replace(FormatCamelTitle(name), " ", "_", -1) -} - -// FormatCamelLowerUnderscore formats ValdezMountain to valdez_mountain -func FormatCamelLowerUnderscore(name string) string { - return strings.ToLower(FormatCamelUnderscore(name)) -} - -// templateFileOrderedNames returns the template names the in order they are defined in the file. -func templateFileOrderedNames(localPath string, names []string) (resp []string, err error) { - file, err := os.Open(localPath) - if err != nil { - return resp, errors.WithStack(err) - } - defer file.Close() - - idxList := []int{} - idxNames := make(map[int]string) - - idx := 0 - scanner := bufio.NewScanner(file) - for scanner.Scan() { - if !strings.HasPrefix(scanner.Text(), "{{") || !strings.Contains(scanner.Text(), "define ") { - continue - } - - for _, name := range names { - if strings.Contains(scanner.Text(), "\""+name+"\"") { - idxList = append(idxList, idx) - idxNames[idx] = name - break - } - } - - idx = idx + 1 - } - - if err := scanner.Err(); err != nil { - return resp, errors.WithStack(err) - } - - sort.Ints(idxList) - - for _, idx := range idxList { - resp = append(resp, idxNames[idx]) - } - - return resp, nil -} diff --git a/tools/truss/internal/goparse/doc.go b/tools/truss/internal/goparse/doc.go deleted file mode 100644 index e7e5810..0000000 --- a/tools/truss/internal/goparse/doc.go +++ /dev/null @@ -1,301 +0,0 @@ -package goparse - -import ( - "fmt" - "go/format" - "io/ioutil" - "strings" - - "github.com/pkg/errors" -) - -// GoDocument defines a single go code file. -type GoDocument struct { - *GoObjects - Package string - imports GoImports -} - -// GoImport defines a single import line with optional alias. -type GoImport struct { - Name string - Alias string -} - -// GoImports holds a list of import lines. -type GoImports []GoImport - -// NewGoDocument creates a new GoDocument with the package line set. -func NewGoDocument(packageName string) (doc *GoDocument, err error) { - doc = &GoDocument{ - GoObjects: &GoObjects{ - list: []*GoObject{}, - }, - } - err = doc.SetPackage(packageName) - return doc, err -} - -// Objects returns a list of root GoObject. -func (doc *GoDocument) Objects() *GoObjects { - if doc.GoObjects == nil { - doc.GoObjects = &GoObjects{ - list: []*GoObject{}, - } - } - - return doc.GoObjects -} - -// NewObjectPackage returns a new GoObject with a single package definition line. -func NewObjectPackage(packageName string) *GoObject { - lines := []string{ - fmt.Sprintf("package %s", packageName), - "", - } - - obj, _ := ParseGoObject(lines, 0) - - return obj -} - -// SetPackage appends sets the package line for the code file. -func (doc *GoDocument) SetPackage(packageName string) error { - - var existing *GoObject - for _, obj := range doc.Objects().List() { - if obj.Type == GoObjectType_Package { - existing = obj - break - } - } - - new := NewObjectPackage(packageName) - - var err error - if existing != nil { - err = doc.Objects().Replace(existing, new) - } else if len(doc.Objects().List()) > 0 { - - // Insert after any existing comments or line breaks. - var insertPos int - //for idx, obj := range doc.Objects().List() { - // switch obj.Type { - // case GoObjectType_Comment, GoObjectType_LineBreak: - // insertPos = idx - // default: - // break - // } - //} - - err = doc.Objects().Insert(insertPos, new) - } else { - err = doc.Objects().Add(new) - } - - return err -} - -// AddObject appends a new GoObject to the doc root object list. -func (doc *GoDocument) AddObject(newObj *GoObject) error { - return doc.Objects().Add(newObj) -} - -// InsertObject inserts a new GoObject at the desired position to the doc root object list. -func (doc *GoDocument) InsertObject(pos int, newObj *GoObject) error { - return doc.Objects().Insert(pos, newObj) -} - -// Imports returns the GoDocument imports. -func (doc *GoDocument) Imports() (GoImports, error) { - // If the doc imports are empty, try to load them from the root objects. - if len(doc.imports) == 0 { - for _, obj := range doc.Objects().List() { - if obj.Type != GoObjectType_Import { - continue - } - - res, err := ParseImportObject(obj) - if err != nil { - return doc.imports, err - } - - // Combine all the imports into a single definition. - for _, n := range res { - doc.imports = append(doc.imports, n) - } - } - } - - return doc.imports, nil -} - -// Lines returns all the code lines. -func (doc *GoDocument) Lines() []string { - l := []string{} - - for _, ol := range doc.Objects().Lines() { - l = append(l, ol) - } - return l -} - -// String returns a single value for all the code lines. -func (doc *GoDocument) String() string { - return strings.Join(doc.Lines(), "\n") -} - -// Print writes all the code lines to standard out. -func (doc *GoDocument) Print() { - for _, l := range doc.Lines() { - fmt.Println(l) - } -} - -// Save renders all the code lines for the document, formats the code -// and then saves it to the supplied file path. -func (doc *GoDocument) Save(localpath string) error { - res, err := format.Source([]byte(doc.String())) - if err != nil { - err = errors.WithMessage(err, "Failed formatted source code") - return err - } - - err = ioutil.WriteFile(localpath, res, 0644) - if err != nil { - err = errors.WithMessagef(err, "Failed write source code to file %s", localpath) - return err - } - - return nil -} - -// AddImport checks for any duplicate imports by name and adds it if not. -func (doc *GoDocument) AddImport(impt GoImport) error { - impt.Name = strings.Trim(impt.Name, "\"") - - // Get a list of current imports for the document. - impts, err := doc.Imports() - if err != nil { - return err - } - - // If the document has as the import, don't add it. - if impts.Has(impt.Name) { - return nil - } - - // Loop through all the document root objects for an object of type import. - // If one exists, append the import to the existing list. - for _, obj := range doc.Objects().List() { - if obj.Type != GoObjectType_Import || len(obj.Lines()) == 1 { - continue - } - obj.subLines = append(obj.subLines, impt.String()) - obj.goObjects.list = append(obj.goObjects.list, impt.Object()) - - doc.imports = append(doc.imports, impt) - - return nil - } - - // Document does not have an existing import object, so need to create one and - // then append to the document. - newObj := NewObjectImports(impt) - - // Insert after any package, any existing comments or line breaks should be included. - var insertPos int - for idx, obj := range doc.Objects().List() { - switch obj.Type { - case GoObjectType_Package, GoObjectType_Comment, GoObjectType_LineBreak: - insertPos = idx - default: - break - } - } - - // Insert the new import object. - err = doc.InsertObject(insertPos, newObj) - if err != nil { - return err - } - - return nil -} - -// NewObjectImports returns a new GoObject with a single import definition. -func NewObjectImports(impt GoImport) *GoObject { - lines := []string{ - "import (", - impt.String(), - ")", - "", - } - - obj, _ := ParseGoObject(lines, 0) - children, err := ParseLines(obj.subLines, 1) - if err != nil { - return nil - } - for _, child := range children.List() { - obj.Objects().Add(child) - } - - return obj -} - -// Has checks to see if an import exists by name or alias. -func (impts GoImports) Has(name string) bool { - for _, impt := range impts { - if name == impt.Name || (impt.Alias != "" && name == impt.Alias) { - return true - } - } - return false -} - -// Line formats an import as a string. -func (impt GoImport) String() string { - var imptLine string - if impt.Alias != "" { - imptLine = fmt.Sprintf("\t%s \"%s\"", impt.Alias, impt.Name) - } else { - imptLine = fmt.Sprintf("\t\"%s\"", impt.Name) - } - return imptLine -} - -// Object returns the first GoObject for an import. -func (impt GoImport) Object() *GoObject { - imptObj := NewObjectImports(impt) - - return imptObj.Objects().List()[0] -} - -// ParseImportObject extracts all the import definitions. -func ParseImportObject(obj *GoObject) (resp GoImports, err error) { - if obj.Type != GoObjectType_Import { - return resp, errors.Errorf("Invalid type %s", string(obj.Type)) - } - - for _, l := range obj.Lines() { - if !strings.Contains(l, "\"") { - continue - } - l = strings.TrimSpace(l) - - pts := strings.Split(l, "\"") - - var impt GoImport - if strings.HasPrefix(l, "\"") { - impt.Name = pts[1] - } else { - impt.Alias = strings.TrimSpace(pts[0]) - impt.Name = pts[1] - } - - resp = append(resp, impt) - } - - return resp, nil -} diff --git a/tools/truss/internal/goparse/doc_object.go b/tools/truss/internal/goparse/doc_object.go deleted file mode 100644 index 5b20efe..0000000 --- a/tools/truss/internal/goparse/doc_object.go +++ /dev/null @@ -1,458 +0,0 @@ -package goparse - -import ( - "log" - "strings" - - "github.com/fatih/structtag" - "github.com/pkg/errors" -) - -// GoEmptyLine defined a GoObject for a code line break. -var GoEmptyLine = GoObject{ - Type: GoObjectType_LineBreak, - goObjects: &GoObjects{ - list: []*GoObject{}, - }, -} - -// GoObjectType defines a set of possible types to group -// parsed code by. -type GoObjectType = string - -var ( - GoObjectType_Package = "package" - GoObjectType_Import = "import" - GoObjectType_Var = "var" - GoObjectType_Const = "const" - GoObjectType_Func = "func" - GoObjectType_Struct = "struct" - GoObjectType_Comment = "comment" - GoObjectType_LineBreak = "linebreak" - GoObjectType_Line = "line" - GoObjectType_Type = "type" -) - -// GoObject defines a section of code with a nested set of children. -type GoObject struct { - Type GoObjectType - Name string - startLines []string - endLines []string - subLines []string - goObjects *GoObjects - Index int -} - -// GoObjects stores a list of GoObject. -type GoObjects struct { - list []*GoObject -} - -// Objects returns the list of *GoObject. -func (obj *GoObject) Objects() *GoObjects { - if obj.goObjects == nil { - obj.goObjects = &GoObjects{ - list: []*GoObject{}, - } - } - return obj.goObjects -} - -// Clone performs a deep copy of the struct. -func (obj *GoObject) Clone() *GoObject { - n := &GoObject{ - Type: obj.Type, - Name: obj.Name, - startLines: obj.startLines, - endLines: obj.endLines, - subLines: obj.subLines, - goObjects: &GoObjects{ - list: []*GoObject{}, - }, - Index: obj.Index, - } - for _, sub := range obj.Objects().List() { - n.Objects().Add(sub.Clone()) - } - return n -} - -// IsComment returns whether an object is of type GoObjectType_Comment. -func (obj *GoObject) IsComment() bool { - if obj.Type != GoObjectType_Comment { - return false - } - return true -} - -// Contains searches all the lines for the object for a matching string. -func (obj *GoObject) Contains(match string) bool { - for _, l := range obj.Lines() { - if strings.Contains(l, match) { - return true - } - } - return false -} - -// UpdateLines parses the new code and replaces the current GoObject. -func (obj *GoObject) UpdateLines(newLines []string) error { - - // Parse the new lines. - objs, err := ParseLines(newLines, 0) - if err != nil { - return err - } - - var newObj *GoObject - for _, obj := range objs.List() { - if obj.Type == GoObjectType_LineBreak { - continue - } - - if newObj == nil { - newObj = obj - } - - // There should only be one resulting parsed object that is - // not of type GoObjectType_LineBreak. - return errors.New("Can only update single blocks of code") - } - - // No new code was parsed, return error. - if newObj == nil { - return errors.New("Failed to render replacement code") - } - - return obj.Update(newObj) -} - -// Update performs a deep copy that overwrites the existing values. -func (obj *GoObject) Update(newObj *GoObject) error { - obj.Type = newObj.Type - obj.Name = newObj.Name - obj.startLines = newObj.startLines - obj.endLines = newObj.endLines - obj.subLines = newObj.subLines - obj.goObjects = newObj.goObjects - return nil -} - -// Lines returns a list of strings for current object and all children. -func (obj *GoObject) Lines() []string { - l := []string{} - - // First include any lines before the sub objects. - for _, sl := range obj.startLines { - l = append(l, sl) - } - - // If there are parsed sub objects include those lines else when - // no sub objects, just use the sub lines. - if len(obj.Objects().List()) > 0 { - for _, sl := range obj.Objects().Lines() { - l = append(l, sl) - } - } else { - for _, sl := range obj.subLines { - l = append(l, sl) - } - } - - // Lastly include any other lines that are after all parsed sub objects. - for _, sl := range obj.endLines { - l = append(l, sl) - } - - return l -} - -// String returns the lines separated by line break. -func (obj *GoObject) String() string { - return strings.Join(obj.Lines(), "\n") -} - -// Lines returns a list of strings for all the list objects. -func (objs *GoObjects) Lines() []string { - l := []string{} - for _, obj := range objs.List() { - for _, oj := range obj.Lines() { - l = append(l, oj) - } - } - return l -} - -// String returns all the lines for the list objects. -func (objs *GoObjects) String() string { - lines := []string{} - for _, obj := range objs.List() { - lines = append(lines, obj.String()) - } - return strings.Join(lines, "\n") -} - -// List returns the list of GoObjects. -func (objs *GoObjects) List() []*GoObject { - return objs.list -} - -// HasFunc searches the current list of objects for a function object by name. -func (objs *GoObjects) HasFunc(name string) bool { - return objs.HasType(name, GoObjectType_Func) -} - -// Get returns the GoObject for the matching name and type. -func (objs *GoObjects) Get(name string, objType GoObjectType) *GoObject { - for _, obj := range objs.list { - if obj.Name == name && (objType == "" || obj.Type == objType) { - return obj - } - } - return nil -} - -// HasType checks is a GoObject exists for the matching name and type. -func (objs *GoObjects) HasType(name string, objType GoObjectType) bool { - for _, obj := range objs.list { - if obj.Name == name && (objType == "" || obj.Type == objType) { - return true - } - } - return false -} - -// HasObject checks to see if the exact code block exists. -func (objs *GoObjects) HasObject(src *GoObject) bool { - if src == nil { - return false - } - - // Generate the code for the supplied object. - srcLines := []string{} - for _, l := range src.Lines() { - // Exclude empty lines. - l = strings.TrimSpace(l) - if l != "" { - srcLines = append(srcLines, l) - } - } - srcStr := strings.Join(srcLines, "\n") - - // Loop over all the objects and match with src code. - for _, obj := range objs.list { - objLines := []string{} - for _, l := range obj.Lines() { - // Exclude empty lines. - l = strings.TrimSpace(l) - if l != "" { - objLines = append(objLines, l) - } - } - objStr := strings.Join(objLines, "\n") - - // Return true if the current object code matches src code. - if srcStr == objStr { - return true - } - } - - return false -} - -// Add appends a new GoObject to the list. -func (objs *GoObjects) Add(newObj *GoObject) error { - newObj.Index = len(objs.list) - objs.list = append(objs.list, newObj) - return nil -} - -// Insert appends a new GoObject at the desired position to the list. -func (objs *GoObjects) Insert(pos int, newObj *GoObject) error { - newList := []*GoObject{} - - var newIdx int - for _, obj := range objs.list { - if obj.Index < pos { - obj.Index = newIdx - newList = append(newList, obj) - } else { - if obj.Index == pos { - newObj.Index = newIdx - newList = append(newList, newObj) - newIdx++ - } - obj.Index = newIdx - newList = append(newList, obj) - } - - newIdx++ - } - - objs.list = newList - - return nil -} - -// Remove deletes a GoObject from the list. -func (objs *GoObjects) Remove(delObjs ...*GoObject) error { - for _, delObj := range delObjs { - oldList := objs.List() - objs.list = []*GoObject{} - - var newIdx int - for _, obj := range oldList { - if obj.Index == delObj.Index { - continue - } - obj.Index = newIdx - objs.list = append(objs.list, obj) - newIdx++ - } - } - - return nil -} - -// Replace updates an existing GoObject while maintaining is same position. -func (objs *GoObjects) Replace(oldObj *GoObject, newObjs ...*GoObject) error { - if oldObj.Index >= len(objs.list) { - return errors.WithStack(errGoObjectNotExist) - } else if len(newObjs) == 0 { - return nil - } - - oldList := objs.List() - objs.list = []*GoObject{} - - var newIdx int - for _, obj := range oldList { - if obj.Index < oldObj.Index { - obj.Index = newIdx - objs.list = append(objs.list, obj) - newIdx++ - } else if obj.Index == oldObj.Index { - for _, newObj := range newObjs { - newObj.Index = newIdx - objs.list = append(objs.list, newObj) - newIdx++ - } - } else { - obj.Index = newIdx - objs.list = append(objs.list, obj) - newIdx++ - } - } - - return nil -} - -// ReplaceFuncByName finds an existing GoObject with type GoObjectType_Func by name -// and then performs a replace with the supplied new GoObject. -func (objs *GoObjects) ReplaceFuncByName(name string, fn *GoObject) error { - return objs.ReplaceTypeByName(name, fn, GoObjectType_Func) -} - -// ReplaceTypeByName finds an existing GoObject with type by name -// and then performs a replace with the supplied new GoObject. -func (objs *GoObjects) ReplaceTypeByName(name string, newObj *GoObject, objType GoObjectType) error { - if newObj.Name == "" { - newObj.Name = name - } - if newObj.Type == "" && objType != "" { - newObj.Type = objType - } - - for _, obj := range objs.list { - if obj.Name == name && (objType == "" || objType == obj.Type) { - return objs.Replace(obj, newObj) - } - } - return errors.WithStack(errGoObjectNotExist) -} - -// Empty determines if all the GoObject in the list are line breaks. -func (objs *GoObjects) Empty() bool { - var hasStuff bool - for _, obj := range objs.List() { - switch obj.Type { - case GoObjectType_LineBreak: - //case GoObjectType_Comment: - //case GoObjectType_Import: - // do nothing - default: - hasStuff = true - } - } - return hasStuff -} - -// Debug prints out the GoObject to logger. -func (obj *GoObject) Debug(log *log.Logger) { - log.Println(obj.Name) - log.Println(" > type:", obj.Type) - log.Println(" > start lines:") - for _, l := range obj.startLines { - log.Println(" ", l) - } - - log.Println(" > sub lines:") - for _, l := range obj.subLines { - log.Println(" ", l) - } - - log.Println(" > end lines:") - for _, l := range obj.endLines { - log.Println(" ", l) - } -} - -// Defines a property of a struct. -type structProp struct { - Name string - Type string - Tags *structtag.Tags -} - -// ParseStructProp extracts the details for a struct property. -func ParseStructProp(obj *GoObject) (structProp, error) { - - if obj.Type != GoObjectType_Line { - return structProp{}, errors.Errorf("Unable to parse object of type %s", obj.Type) - } - - // Remove any white space from the code line. - ls := strings.TrimSpace(strings.Join(obj.Lines(), " ")) - - // Extract the property name and type for the line. - // ie: ID string `json:"id"` - var resp structProp - for _, p := range strings.Split(ls, " ") { - p = strings.TrimSpace(p) - if p == "" { - continue - } - if resp.Name == "" { - resp.Name = p - } else if resp.Type == "" { - resp.Type = p - } else { - break - } - } - - // If the line contains tags, extract and parse them. - if strings.Contains(ls, "`") { - tagStr := strings.Split(ls, "`")[1] - - var err error - resp.Tags, err = structtag.Parse(tagStr) - if err != nil { - err = errors.WithMessagef(err, "Failed to parse struct tag for field %s: %s", resp.Name, tagStr) - return structProp{}, err - } - } - - return resp, nil -} diff --git a/tools/truss/internal/goparse/goparse.go b/tools/truss/internal/goparse/goparse.go deleted file mode 100644 index c7e129a..0000000 --- a/tools/truss/internal/goparse/goparse.go +++ /dev/null @@ -1,352 +0,0 @@ -package goparse - -import ( - "bufio" - "bytes" - "fmt" - "go/format" - "io/ioutil" - "log" - "strings" - "unicode" - - "github.com/pkg/errors" -) - -var ( - errGoParseType = errors.New("Unable to determine type for line") - errGoTypeMissingCodeTemplate = errors.New("No code defined for type") - errGoObjectNotExist = errors.New("GoObject does not exist") -) - -// ParseFile reads a go code file and parses into a easily transformable set of objects. -func ParseFile(log *log.Logger, localPath string) (*GoDocument, error) { - - // Read the code file. - src, err := ioutil.ReadFile(localPath) - if err != nil { - err = errors.WithMessagef(err, "Failed to read file %s", localPath) - return nil, err - } - - // Format the code file source to ensure parse works. - dat, err := format.Source(src) - if err != nil { - err = errors.WithMessagef(err, "Failed to format source for file %s", localPath) - log.Printf("ParseFile : %v\n%v", err, string(src)) - return nil, err - } - - // Loop of the formatted source code and generate a list of code lines. - lines := []string{} - r := bytes.NewReader(dat) - scanner := bufio.NewScanner(r) - for scanner.Scan() { - lines = append(lines, scanner.Text()) - } - - if err := scanner.Err(); err != nil { - err = errors.WithMessagef(err, "Failed read formatted source code for file %s", localPath) - return nil, err - } - - // Parse the code lines into a set of objects. - objs, err := ParseLines(lines, 0) - if err != nil { - log.Println(err) - return nil, err - } - - // Append the resulting objects to the document. - doc := &GoDocument{} - for _, obj := range objs.List() { - if obj.Type == GoObjectType_Package { - doc.Package = obj.Name - } - doc.AddObject(obj) - } - - return doc, nil -} - -// ParseLines takes the list of formatted code lines and returns the GoObjects. -func ParseLines(lines []string, depth int) (objs *GoObjects, err error) { - objs = &GoObjects{ - list: []*GoObject{}, - } - - var ( - multiLine bool - multiComment bool - muiliVar bool - ) - curDepth := -1 - objLines := []string{} - - for idx, l := range lines { - ls := strings.TrimSpace(l) - - ld := lineDepth(l) - - //fmt.Println("l", l) - //fmt.Println("> Depth", ld, "???", depth) - - if ld == depth { - if strings.HasPrefix(ls, "/*") { - multiLine = true - multiComment = true - } else if strings.HasSuffix(ls, "(") || - strings.HasSuffix(ls, "{") { - - if !multiLine { - multiLine = true - } - } else if strings.Contains(ls, "`") { - if !multiLine && strings.Count(ls, "`")%2 != 0 { - if muiliVar { - muiliVar = false - } else { - muiliVar = true - } - } - } - - //fmt.Println("> multiLine", multiLine) - //fmt.Println("> multiComment", multiComment) - //fmt.Println("> muiliVar", muiliVar) - - objLines = append(objLines, l) - - if multiComment { - if strings.HasSuffix(ls, "*/") { - multiComment = false - multiLine = false - } - } else { - if strings.HasPrefix(ls, ")") || - strings.HasPrefix(ls, "}") { - multiLine = false - } - } - - if !multiLine && !muiliVar { - for eidx := idx + 1; eidx < len(lines); eidx++ { - if el := lines[eidx]; strings.TrimSpace(el) == "" { - objLines = append(objLines, el) - } else { - break - } - } - - //fmt.Println(" > objLines", objLines) - - obj, err := ParseGoObject(objLines, depth) - if err != nil { - log.Println(err) - return objs, err - } - err = objs.Add(obj) - if err != nil { - log.Println(err) - return objs, err - } - - objLines = []string{} - } - - } else if (multiLine && ld >= curDepth && ld >= depth && len(objLines) > 0) || muiliVar { - objLines = append(objLines, l) - - if strings.Contains(ls, "`") { - if !multiLine && strings.Count(ls, "`")%2 != 0 { - if muiliVar { - muiliVar = false - } else { - muiliVar = true - } - } - } - } - } - - for _, obj := range objs.List() { - children, err := ParseLines(obj.subLines, depth+1) - if err != nil { - log.Println(err) - return objs, err - } - for _, child := range children.List() { - obj.Objects().Add(child) - } - } - - return objs, nil -} - -// ParseGoObject generates a GoObjected for the given code lines. -func ParseGoObject(lines []string, depth int) (obj *GoObject, err error) { - - // If there are no lines, return a line break. - if len(lines) == 0 { - return &GoEmptyLine, nil - } - - firstLine := lines[0] - firstStrip := strings.TrimSpace(firstLine) - - if len(firstStrip) == 0 { - return &GoEmptyLine, nil - } - - obj = &GoObject{ - goObjects: &GoObjects{ - list: []*GoObject{}, - }, - } - - if strings.HasPrefix(firstStrip, "var") { - obj.Type = GoObjectType_Var - - if !strings.HasSuffix(firstStrip, "(") { - if strings.HasPrefix(firstStrip, "var ") { - firstStrip = strings.TrimSpace(strings.Replace(firstStrip, "var ", "", 1)) - } - obj.Name = strings.Split(firstStrip, " ")[0] - } - } else if strings.HasPrefix(firstStrip, "const") { - obj.Type = GoObjectType_Const - - if !strings.HasSuffix(firstStrip, "(") { - if strings.HasPrefix(firstStrip, "const ") { - firstStrip = strings.TrimSpace(strings.Replace(firstStrip, "const ", "", 1)) - } - obj.Name = strings.Split(firstStrip, " ")[0] - } - } else if strings.HasPrefix(firstStrip, "func") { - obj.Type = GoObjectType_Func - - if strings.HasPrefix(firstStrip, "func (") { - funcLine := strings.TrimLeft(strings.TrimSpace(strings.Replace(firstStrip, "func ", "", 1)), "(") - - var structName string - pts := strings.Split(strings.Split(funcLine, ")")[0], " ") - for i := len(pts) - 1; i >= 0; i-- { - ptVal := strings.TrimSpace(pts[i]) - if ptVal != "" { - structName = ptVal - break - } - } - - var funcName string - pts = strings.Split(strings.Split(funcLine, "(")[0], " ") - for i := len(pts) - 1; i >= 0; i-- { - ptVal := strings.TrimSpace(pts[i]) - if ptVal != "" { - funcName = ptVal - break - } - } - - obj.Name = fmt.Sprintf("%s.%s", structName, funcName) - } else { - obj.Name = strings.Replace(firstStrip, "func ", "", 1) - obj.Name = strings.Split(obj.Name, "(")[0] - } - } else if strings.HasSuffix(firstStrip, "struct {") || strings.HasSuffix(firstStrip, "struct{") { - obj.Type = GoObjectType_Struct - - if strings.HasPrefix(firstStrip, "type ") { - firstStrip = strings.TrimSpace(strings.Replace(firstStrip, "type ", "", 1)) - } - obj.Name = strings.Split(firstStrip, " ")[0] - } else if strings.HasPrefix(firstStrip, "type") { - obj.Type = GoObjectType_Type - firstStrip = strings.TrimSpace(strings.Replace(firstStrip, "type ", "", 1)) - obj.Name = strings.Split(firstStrip, " ")[0] - } else if strings.HasPrefix(firstStrip, "package") { - obj.Name = strings.TrimSpace(strings.Replace(firstStrip, "package ", "", 1)) - - obj.Type = GoObjectType_Package - } else if strings.HasPrefix(firstStrip, "import") { - obj.Type = GoObjectType_Import - } else if strings.HasPrefix(firstStrip, "//") { - obj.Type = GoObjectType_Comment - } else if strings.HasPrefix(firstStrip, "/*") { - obj.Type = GoObjectType_Comment - } else { - if depth > 0 { - obj.Type = GoObjectType_Line - } else { - err = errors.WithStack(errGoParseType) - return obj, err - } - } - - var ( - hasSub bool - muiliVarStart bool - muiliVarSub bool - muiliVarEnd bool - ) - for _, l := range lines { - ld := lineDepth(l) - if (ld == depth && !muiliVarSub) || muiliVarStart || muiliVarEnd { - if hasSub && !muiliVarStart { - if strings.TrimSpace(l) != "" { - obj.endLines = append(obj.endLines, l) - } - - if strings.Count(l, "`")%2 != 0 { - if muiliVarEnd { - muiliVarEnd = false - } else { - muiliVarEnd = true - } - } - } else { - obj.startLines = append(obj.startLines, l) - - if strings.Count(l, "`")%2 != 0 { - if muiliVarStart { - muiliVarStart = false - } else { - muiliVarStart = true - } - } - } - } else if ld > depth || muiliVarSub { - obj.subLines = append(obj.subLines, l) - hasSub = true - - if strings.Count(l, "`")%2 != 0 { - if muiliVarSub { - muiliVarSub = false - } else { - muiliVarSub = true - } - } - } - } - - // add trailing linebreak - if len(obj.endLines) > 0 { - obj.endLines = append(obj.endLines, "") - } - - return obj, err -} - -// lineDepth returns the number of spaces for the given code line -// used to determine the code level for nesting objects. -func lineDepth(l string) int { - depth := len(l) - len(strings.TrimLeftFunc(l, unicode.IsSpace)) - - ls := strings.TrimSpace(l) - if strings.HasPrefix(ls, "}") && strings.Contains(ls, " else ") { - depth = depth + 1 - } else if strings.HasPrefix(ls, "case ") { - depth = depth + 1 - } - return depth -} diff --git a/tools/truss/internal/goparse/goparse_test.go b/tools/truss/internal/goparse/goparse_test.go deleted file mode 100644 index bc87ed7..0000000 --- a/tools/truss/internal/goparse/goparse_test.go +++ /dev/null @@ -1,201 +0,0 @@ -package goparse - -import ( - "log" - "os" - "strings" - "testing" - - "github.com/onsi/gomega" -) - -var logger *log.Logger - -func init() { - logger = log.New(os.Stdout, "", log.LstdFlags|log.Lmicroseconds|log.Lshortfile) -} - -func TestMultilineVar(t *testing.T) { - g := gomega.NewGomegaWithT(t) - - code := `func ContextAllowedAccountIds(ctx context.Context, db *gorm.DB) (resp akdatamodels.Uint32List, err error) { - resp = []uint32{} - accountId := akcontext.ContextAccountId(ctx) - m := datamodels.UserAccount{} - q := fmt.Sprintf("select - distinct account_id - from %s where account_id = ?", m.TableName()) - db = db.Raw(q, accountId) -} -` - code = strings.Replace(code, "\"", "`", -1) - lines := strings.Split(code, "\n") - - objs, err := ParseLines(lines, 0) - if err != nil { - t.Fatalf("got error %v", err) - } - - g.Expect(objs.Lines()).Should(gomega.Equal(lines)) -} - -func TestNewDocImports(t *testing.T) { - g := gomega.NewGomegaWithT(t) - - expected := []string{ - "package goparse", - "", - "import (", - " \"github.com/go/pkg1\"", - " \"github.com/go/pkg2\"", - ")", - "", - } - - doc := &GoDocument{} - doc.SetPackage("goparse") - - doc.AddImport(GoImport{Name: "github.com/go/pkg1"}) - doc.AddImport(GoImport{Name: "github.com/go/pkg2"}) - - g.Expect(doc.Lines()).Should(gomega.Equal(expected)) -} - -func TestParseLines1(t *testing.T) { - g := gomega.NewGomegaWithT(t) - - codeTests := []string{ - `func testCreate(t *testing.T, ctx context.Context, sess *datamodels.Session) *datamodels.Model { - g := gomega.NewGomegaWithT(t) - obj := datamodels.MockModelNew() - resp, err := ModelCreate(ctx, DB, &obj) - if err != nil { - t.Fatalf("got error %v", err) - } - - g.Expect(resp.Name).Should(gomega.Equal(obj.Name)) - g.Expect(resp.Status).Should(gomega.Equal(datamodels.{{ .Name }}Status_Active)) - return resp -} -`, - `var ( - // ErrNotFound abstracts the postgres not found error. - ErrNotFound = errors.New("Entity not found") - // ErrInvalidID occurs when an ID is not in a valid form. - ErrInvalidID = errors.New("ID is not in its proper form") - // ErrForbidden occurs when a user tries to do something that is forbidden to them according to our access control policies. - ErrForbidden = errors.New("Attempted action is not allowed") -) -`, - } - - for _, code := range codeTests { - lines := strings.Split(code, "\n") - - objs, err := ParseLines(lines, 0) - if err != nil { - t.Fatalf("got error %v", err) - } - - g.Expect(objs.Lines()).Should(gomega.Equal(lines)) - } - -} - -func TestParseLines2(t *testing.T) { - code := `func structToMap(s interface{}) (resp map[string]interface{}) { - dat, _ := json.Marshal(s) - _ = json.Unmarshal(dat, &resp) - for k, x := range resp { - switch v := x.(type) { - case time.Time: - if v.IsZero() { - delete(resp, k) - } - - case *time.Time: - if v == nil || v.IsZero() { - delete(resp, k) - } - - case nil: - delete(resp, k) - - } - - } - - return resp -} -` - lines := strings.Split(code, "\n") - - objs, err := ParseLines(lines, 0) - if err != nil { - t.Fatalf("got error %v", err) - } - - testLineTextMatches(t, objs.Lines(), lines) -} - -func TestParseLines3(t *testing.T) { - g := gomega.NewGomegaWithT(t) - - code := `type UserAccountRoleName = string - -const ( - UserAccountRoleName_None UserAccountRoleName = "" - UserAccountRoleName_Admin UserAccountRoleName = "admin" - UserAccountRoleName_User UserAccountRoleName = "user" -) - -type UserAccountRole struct { - Id uint32 ^gorm:"column:id;type:int(10) unsigned AUTO_INCREMENT;primary_key;not null;auto_increment;" truss:"internal:true"^ - CreatedAt time.Time ^gorm:"column:created_at;type:datetime;default:CURRENT_TIMESTAMP;not null;" truss:"internal:true"^ - UpdatedAt time.Time ^gorm:"column:updated_at;type:datetime;" truss:"internal:true"^ - DeletedAt *time.Time ^gorm:"column:deleted_at;type:datetime;" truss:"internal:true"^ - Role UserAccountRoleName ^gorm:"unique_index:user_account_role;column:role;type:enum('admin', 'user')"^ - // belongs to User - User *User ^gorm:"foreignkey:UserId;association_foreignkey:Id;association_autoupdate:false;association_autocreate:false;association_save_reference:false;preload:false;" truss:"internal:true"^ - UserId uint32 ^gorm:"unique_index:user_account_role;"^ - // belongs to Account - Account *Account ^gorm:"foreignkey:AccountId;association_foreignkey:Id;association_autoupdate:false;association_autocreate:false;association_save_reference:false;preload:false;" truss:"internal:true;api_ro:true;"^ - AccountId uint32 ^gorm:"unique_index:user_account_role;" truss:"internal:true;api_ro:true;"^ -} - -func (UserAccountRole) TableName() string { - return "user_account_roles" -} -` - code = strings.Replace(code, "^", "'", -1) - lines := strings.Split(code, "\n") - - objs, err := ParseLines(lines, 0) - if err != nil { - t.Fatalf("got error %v", err) - } - - g.Expect(objs.Lines()).Should(gomega.Equal(lines)) -} - -func testLineTextMatches(t *testing.T, l1, l2 []string) { - g := gomega.NewGomegaWithT(t) - - m1 := []string{} - for _, l := range l1 { - l = strings.TrimSpace(l) - if l != "" { - m1 = append(m1, l) - } - } - - m2 := []string{} - for _, l := range l2 { - l = strings.TrimSpace(l) - if l != "" { - m2 = append(m2, l) - } - } - - g.Expect(m1).Should(gomega.Equal(m2)) -} diff --git a/tools/truss/main.go b/tools/truss/main.go deleted file mode 100644 index e05745f..0000000 --- a/tools/truss/main.go +++ /dev/null @@ -1,228 +0,0 @@ -package main - -import ( - "encoding/json" - "expvar" - "io/ioutil" - "log" - "net/url" - "os" - "path" - "path/filepath" - "strings" - - "geeks-accelerator/oss/saas-starter-kit/tools/truss/cmd/dbtable2crud" - "github.com/kelseyhightower/envconfig" - "github.com/lib/pq" - _ "github.com/lib/pq" - "github.com/pkg/errors" - "github.com/urfave/cli" - sqltrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/database/sql" - sqlxtrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/jmoiron/sqlx" -) - -// build is the git version of this program. It is set using build flags in the makefile. -var build = "develop" - -// service is the name of the program used for logging, tracing and the -// the prefix used for loading env variables -// ie: export TRUSS_ENV=dev -var service = "TRUSS" - -func main() { - // ========================================================================= - // Logging - - log := log.New(os.Stdout, service+" : ", log.LstdFlags|log.Lmicroseconds|log.Lshortfile) - - // ========================================================================= - // Configuration - var cfg struct { - DB struct { - Host string `default:"127.0.0.1:5433" envconfig:"HOST"` - User string `default:"postgres" envconfig:"USER"` - Pass string `default:"postgres" envconfig:"PASS" json:"-"` // don't print - Database string `default:"shared" envconfig:"DATABASE"` - Driver string `default:"postgres" envconfig:"DRIVER"` - Timezone string `default:"utc" envconfig:"TIMEZONE"` - DisableTLS bool `default:"false" envconfig:"DISABLE_TLS"` - } - } - - // For additional details refer to https://github.com/kelseyhightower/envconfig - if err := envconfig.Process(service, &cfg); err != nil { - log.Fatalf("main : Parsing Config : %v", err) - } - - // TODO: can't use flag.Process here since it doesn't support nested arg options - //if err := flag.Process(&cfg); err != nil { - /// if err != flag.ErrHelp { - // log.Fatalf("main : Parsing Command Line : %v", err) - // } - // return // We displayed help. - //} - - // ========================================================================= - // Log App Info - - // Print the build version for our logs. Also expose it under /debug/vars. - expvar.NewString("build").Set(build) - log.Printf("main : Started : Application Initializing version %q", build) - defer log.Println("main : Completed") - - // Print the config for our logs. It's important to any credentials in the config - // that could expose a security risk are excluded from being json encoded by - // applying the tag `json:"-"` to the struct var. - { - cfgJSON, err := json.MarshalIndent(cfg, "", " ") - if err != nil { - log.Fatalf("main : Marshalling Config to JSON : %v", err) - } - log.Printf("main : Config : %v\n", string(cfgJSON)) - } - - // ========================================================================= - // Start Database - var dbUrl url.URL - { - // Query parameters. - var q url.Values = make(map[string][]string) - - // Handle SSL Mode - if cfg.DB.DisableTLS { - q.Set("sslmode", "disable") - } else { - q.Set("sslmode", "require") - } - - q.Set("timezone", cfg.DB.Timezone) - - // Construct url. - dbUrl = url.URL{ - Scheme: cfg.DB.Driver, - User: url.UserPassword(cfg.DB.User, cfg.DB.Pass), - Host: cfg.DB.Host, - Path: cfg.DB.Database, - RawQuery: q.Encode(), - } - } - - // Register informs the sqlxtrace package of the driver that we will be using in our program. - // It uses a default service name, in the below case "postgres.db". To use a custom service - // name use RegisterWithServiceName. - sqltrace.Register(cfg.DB.Driver, &pq.Driver{}, sqltrace.WithServiceName(service)) - masterDb, err := sqlxtrace.Open(cfg.DB.Driver, dbUrl.String()) - if err != nil { - log.Fatalf("main : Register DB : %s : %v", cfg.DB.Driver, err) - } - defer masterDb.Close() - - // ========================================================================= - // Start Truss - - app := cli.NewApp() - app.Commands = []cli.Command{ - { - Name: "dbtable2crud", - Aliases: []string{}, - Usage: "-table=projects -file=../../internal/project/models.go -model=Project [-dbtable=TABLE] [-templateDir=DIR] [-projectPath=DIR] [-saveChanges=false] ", - Flags: []cli.Flag{ - cli.StringFlag{Name: "dbtable, table"}, - cli.StringFlag{Name: "modelFile, modelfile, file"}, - cli.StringFlag{Name: "modelName, modelname, model"}, - cli.StringFlag{Name: "templateDir, templates", Value: "./templates/dbtable2crud"}, - cli.StringFlag{Name: "projectPath"}, - cli.BoolFlag{Name: "saveChanges, save"}, - }, - Action: func(c *cli.Context) error { - dbTable := strings.TrimSpace(c.String("dbtable")) - modelFile := strings.TrimSpace(c.String("modelFile")) - modelName := strings.TrimSpace(c.String("modelName")) - templateDir := strings.TrimSpace(c.String("templateDir")) - projectPath := strings.TrimSpace(c.String("projectPath")) - - pwd, err := os.Getwd() - if err != nil { - return errors.WithMessage(err, "Failed to get current working directory") - } - - if !path.IsAbs(templateDir) { - templateDir = filepath.Join(pwd, templateDir) - } - ok, err := exists(templateDir) - if err != nil { - return errors.WithMessage(err, "Failed to load template directory") - } else if !ok { - return errors.Errorf("Template directory %s does not exist", templateDir) - } - - if modelFile == "" { - return errors.Errorf("Model file path is required") - } - - if !path.IsAbs(modelFile) { - modelFile = filepath.Join(pwd, modelFile) - } - ok, err = exists(modelFile) - if err != nil { - return errors.WithMessage(err, "Failed to load model file") - } else if !ok { - return errors.Errorf("Model file %s does not exist", modelFile) - } - - // Load the project path from go.mod if not set. - if projectPath == "" { - goModFile := filepath.Join(pwd, "../../go.mod") - ok, err = exists(goModFile) - if err != nil { - return errors.WithMessage(err, "Failed to load go.mod for project") - } else if !ok { - return errors.Errorf("Failed to locate project go.mod at %s", goModFile) - } - - b, err := ioutil.ReadFile(goModFile) - if err != nil { - return errors.WithMessagef(err, "Failed to read go.mod at %s", goModFile) - } - - lines := strings.Split(string(b), "\n") - for _, l := range lines { - if strings.HasPrefix(l, "module ") { - projectPath = strings.TrimSpace(strings.Split(l, " ")[1]) - break - } - } - } - - if modelName == "" { - modelName = strings.Split(filepath.Base(modelFile), ".")[0] - modelName = strings.Replace(modelName, "_", " ", -1) - modelName = strings.Replace(modelName, "-", " ", -1) - modelName = strings.Title(modelName) - modelName = strings.Replace(modelName, " ", "", -1) - } - - return dbtable2crud.Run(masterDb, log, cfg.DB.Database, dbTable, modelFile, modelName, templateDir, projectPath, c.Bool("saveChanges")) - }, - }, - } - - err = app.Run(os.Args) - if err != nil { - log.Fatalf("main : Truss : %+v", err) - } - - log.Printf("main : Truss : Completed") -} - -// exists returns a bool as to whether a file path exists. -func exists(path string) (bool, error) { - _, err := os.Stat(path) - if err == nil { - return true, nil - } - if os.IsNotExist(err) { - return false, nil - } - return true, err -} diff --git a/tools/truss/makefile b/tools/truss/makefile deleted file mode 100644 index 0e327ea..0000000 --- a/tools/truss/makefile +++ /dev/null @@ -1,10 +0,0 @@ -SHELL := /bin/bash - -install: - go install . - -build: - go install . - -run: - go build . && ./truss diff --git a/tools/truss/sample.env b/tools/truss/sample.env deleted file mode 100644 index 890ff56..0000000 --- a/tools/truss/sample.env +++ /dev/null @@ -1,22 +0,0 @@ -# Variables to configure Postgres for database migration. -export TRUSS_DB_HOST=127.0.0.1:5433 -export TRUSS_DB_USER=postgres -export TRUSS_DB_PASS=postgres -export TRUSS_DB_DISABLE_TLS=true - -# Variables to configure AWS for service build and deploy. -# Use the same set for AWS credentials for all target envinments. -#AWS_ACCESS_KEY_ID=XXXXXXXXXXXXXX -#AWS_SECRET_ACCESS_KEY=XXXXXXXXXXXXXX -#AWS_REGION=us-west-2 - -# AWS credentials can be prefixed with the target uppercased target envinments. -# This allows credentials unique accounts to be used for each target envinments. -# Default target envinments are: DEV, STAGE, PROD -#DEV_AWS_ACCESS_KEY_ID=XXXXXXXXXXXXXX -#DEV_AWS_SECRET_ACCESS_KEY=XXXXXXXXXXXXXX -#DEV_AWS_REGION=us-west-2 - -# GitLab CI/CD environment variables. These are set by the GitLab when the build -# pipeline is running. These can be optional set for testing/debugging locally. -#CI_COMMIT_REF_NAME=master diff --git a/tools/truss/templates/dbtable2crud/model_crud.tmpl b/tools/truss/templates/dbtable2crud/model_crud.tmpl deleted file mode 100644 index 5972bc3..0000000 --- a/tools/truss/templates/dbtable2crud/model_crud.tmpl +++ /dev/null @@ -1,510 +0,0 @@ -{{ define "imports"}} -import ( - "context" - "database/sql" - "time" - - "{{ $.GoSrcPath }}/internal/platform/auth" - "github.com/huandu/go-sqlbuilder" - "github.com/jmoiron/sqlx" - "github.com/pborman/uuid" - "github.com/pkg/errors" - "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" - "gopkg.in/go-playground/validator.v9" -) -{{ end }} -{{ define "Globals"}} -const ( - // The database table for {{ $.Model.Name }} - {{ FormatCamelLower $.Model.Name }}TableName = "{{ $.Model.TableName }}" -) - -var ( - // ErrNotFound abstracts the postgres not found error. - ErrNotFound = errors.New("Entity not found") - - // ErrInvalidID occurs when an ID is not in a valid form. - ErrInvalidID = errors.New("ID is not in its proper form") - - // ErrForbidden occurs when a user tries to do something that is forbidden to them according to our access control policies. - ErrForbidden = errors.New("Attempted action is not allowed") -) -{{ end }} -{{ define "Helpers"}} -// {{ FormatCamelLower $.Model.Name }}MapColumns is the list of columns needed for mapRowsTo{{ $.Model.Name }} -var {{ FormatCamelLower $.Model.Name }}MapColumns = "{{ JoinStrings $.Model.ColumnNames "," }}" - -// mapRowsTo{{ $.Model.Name }} takes the SQL rows and maps it to the {{ $.Model.Name }} struct -// with the columns defined by {{ FormatCamelLower $.Model.Name }}MapColumns -func mapRowsTo{{ $.Model.Name }}(rows *sql.Rows) (*{{ $.Model.Name }}, error) { - var ( - m {{ $.Model.Name }} - err error - ) - err = rows.Scan({{ PrefixAndJoinStrings $.Model.FieldNames "&m." "," }}) - if err != nil { - return nil, errors.WithStack(err) - } - - return &m, nil -} -{{ end }} -{{ define "ACL"}} -{{ $hasAccountID := (StringListHasValue $.Model.ColumnNames "account_id") }} -// CanRead{{ $.Model.Name }} determines if claims has the authority to access the specified {{ FormatCamelLowerTitle $.Model.Name}} by {{ $.Model.PrimaryColumn }}. -func CanRead{{ $.Model.Name }}(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, {{ FormatCamelLower $.Model.PrimaryField }} string) error { - - {{ if $hasAccountID }} - // If the request has claims from a specific {{ FormatCamelLower $.Model.Name }}, ensure that the claims - // has the correct access to the {{ FormatCamelLower $.Model.Name }}. - if claims.Audience != "" { - // select {{ $.Model.PrimaryColumn }} from {{ $.Model.TableName }} where account_id = [accountID] - query := sqlbuilder.NewSelectBuilder().Select("{{ $.Model.PrimaryColumn }}").From({{ FormatCamelLower $.Model.Name }}TableName) - query.Where(query.And( - query.Equal("account_id", claims.Audience), - query.Equal("{{ $.Model.PrimaryField }}", {{ FormatCamelLower $.Model.PrimaryField }}), - )) - queryStr, args := query.Build() - queryStr = dbConn.Rebind(queryStr) - - var {{ FormatCamelLower $.Model.PrimaryField }} string - err := dbConn.QueryRowContext(ctx, queryStr, args...).Scan(&{{ FormatCamelLower $.Model.PrimaryField }}) - if err != nil && err != sql.ErrNoRows { - err = errors.Wrapf(err, "query - %s", query.String()) - return err - } - - // When there is no {{ $.Model.PrimaryColumn }} returned, then the current claim user does not have access - // to the specified {{ FormatCamelLowerTitle $.Model.Name }}. - if {{ FormatCamelLower $.Model.PrimaryField }} == "" { - return errors.WithStack(ErrForbidden) - } - } - {{ else }} - // TODO: Unable to auto generate sql statement, update accordingly. - panic("Not implemented!") - {{ end }} - - return nil -} - -// CanModify{{ $.Model.Name }} determines if claims has the authority to modify the specified {{ FormatCamelLowerTitle $.Model.Name}} by {{ $.Model.PrimaryColumn }}. -func CanModify{{ $.Model.Name }}(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, {{ FormatCamelLower $.Model.PrimaryField }} string) error { - err := CanRead{{ $.Model.Name }}(ctx, claims, dbConn, {{ FormatCamelLower $.Model.PrimaryField }}) - if err != nil { - return err - } - - // Admin users can update {{ FormatCamelPluralTitleLower $.Model.Name }} they have access to. - if !claims.HasRole(auth.RoleAdmin) { - return errors.WithStack(ErrForbidden) - } - - return nil -} - -// applyClaimsSelect applies a sub-query to the provided query to enforce ACL based on the claims provided. -// 1. No claims, request is internal, no ACL applied -{{ if $hasAccountID }} -// 2. All role types can access their user ID -{{ end }} -func applyClaimsSelect(ctx context.Context, claims auth.Claims, query *sqlbuilder.SelectBuilder) error { - // Claims are empty, don't apply any ACL - if claims.Audience == "" { - return nil - } - - {{ if $hasAccountID }} - query.Where(query.Equal("account_id", claims.Audience)) - {{ end }} - - return nil -} -{{ end }} -{{ define "Find"}} -{{ $hasArchived := (StringListHasValue $.Model.ColumnNames "archived_at") }} - -// selectQuery constructs a base select query for {{ $.Model.Name }} -func selectQuery() *sqlbuilder.SelectBuilder { - query := sqlbuilder.NewSelectBuilder() - query.Select({{ FormatCamelLower $.Model.Name }}MapColumns) - query.From({{ FormatCamelLower $.Model.Name }}TableName) - return query -} - -// findRequestQuery generates the select query for the given find request. -// TODO: Need to figure out why can't parse the args when appending the where -// to the query. -func findRequestQuery(req {{ $.Model.Name }}FindRequest) (*sqlbuilder.SelectBuilder, []interface{}) { - query := selectQuery() - if req.Where != nil { - query.Where(query.And(*req.Where)) - } - if len(req.Order) > 0 { - query.OrderBy(req.Order...) - } - if req.Limit != nil { - query.Limit(int(*req.Limit)) - } - if req.Offset != nil { - query.Offset(int(*req.Offset)) - } - - return query, req.Args -} - -// Find gets all the {{ FormatCamelPluralTitleLower $.Model.Name }} from the database based on the request params. -func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req {{ $.Model.Name }}FindRequest) ([]*{{ $.Model.Name }}, error) { - query, args := findRequestQuery(req) - return find(ctx, claims, dbConn, query, args{{ if $hasArchived }}, req.IncludedArchived {{ end }}) -} - -// find internal method for getting all the {{ FormatCamelPluralTitleLower $.Model.Name }} from the database using a select query. -func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}{{ if $hasArchived }}, includedArchived bool{{ end }}) ([]*{{ $.Model.Name }}, error) { - span, ctx := tracer.StartSpanFromContext(ctx, "internal.{{ FormatCamelLowerUnderscore $.Model.Name }}.Find") - defer span.Finish() - - query.Select({{ FormatCamelLower $.Model.Name }}MapColumns) - query.From({{ FormatCamelLower $.Model.Name }}TableName) - - {{ if $hasArchived }} - if !includedArchived { - query.Where(query.IsNull("archived_at")) - } - {{ end }} - - // Check to see if a sub query needs to be applied for the claims. - err := applyClaimsSelect(ctx, claims, query) - if err != nil { - return nil, err - } - queryStr, queryArgs := query.Build() - queryStr = dbConn.Rebind(queryStr) - args = append(args, queryArgs...) - - // Fetch all entries from the db. - rows, err := dbConn.QueryContext(ctx, queryStr, args...) - if err != nil { - err = errors.Wrapf(err, "query - %s", query.String()) - err = errors.WithMessage(err, "find {{ FormatCamelPluralTitleLower $.Model.Name }} failed") - return nil, err - } - - // Iterate over each row. - resp := []*{{ $.Model.Name }}{} - for rows.Next() { - u, err := mapRowsTo{{ $.Model.Name }}(rows) - if err != nil { - err = errors.Wrapf(err, "query - %s", query.String()) - return nil, err - } - resp = append(resp, u) - } - - return resp, nil -} - -// Read gets the specified {{ FormatCamelLowerTitle $.Model.Name }} from the database. -func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, {{ FormatCamelLower $.Model.PrimaryField }} {{ $.Model.PrimaryType }}{{ if $hasArchived }}, includedArchived bool{{ end }}) (*{{ $.Model.Name }}, error) { - span, ctx := tracer.StartSpanFromContext(ctx, "internal.{{ FormatCamelLowerUnderscore $.Model.Name }}.Read") - defer span.Finish() - - // Filter base select query by {{ FormatCamelLower $.Model.PrimaryField }} - query := selectQuery() - query.Where(query.Equal("{{ $.Model.PrimaryColumn }}", {{ FormatCamelLower $.Model.PrimaryField }})) - - res, err := find(ctx, claims, dbConn, query, []interface{}{} {{ if $hasArchived }}, includedArchived{{ end }}) - if err != nil { - return nil, err - } else if res == nil || len(res) == 0 { - err = errors.WithMessagef(ErrNotFound, "{{ FormatCamelLowerTitle $.Model.Name }} %s not found", id) - return nil, err - } - u := res[0] - - return u, nil -} -{{ end }} -{{ define "Create"}} -{{ $hasAccountID := (StringListHasValue $.Model.ColumnNames "account_id") }} -{{ $reqName := (Concat $.Model.Name "CreateRequest") }} -{{ $createFields := (index $.StructFields $reqName) }} -{{ $reqHasAccountID := false }}{{ $reqAccountID := (index $createFields "AccountID") }}{{ if $reqAccountID }}{{ $reqHasAccountID = true }}{{ end }} -// Create inserts a new {{ FormatCamelLowerTitle $.Model.Name }} into the database. -func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req {{ $reqName }}, now time.Time) (*{{ $.Model.Name }}, error) { - span, ctx := tracer.StartSpanFromContext(ctx, "internal.{{ FormatCamelLowerUnderscore $.Model.Name }}.Create") - defer span.Finish() - - if claims.Audience != "" { - // Admin users can update {{ FormatCamelPluralTitleLower $.Model.Name }} they have access to. - if !claims.HasRole(auth.RoleAdmin) { - return nil, errors.WithStack(ErrForbidden) - } - - {{ if $reqHasAccountID }} - if req.AccountID != "" { - // Request accountId must match claims. - if req.AccountID != claims.Audience { - return nil, errors.WithStack(ErrForbidden) - } - } else { - // Set the accountId from claims. - req.AccountID = claims.Audience - } - {{ end }} - } - - v := validator.New() - - // Validate the request. - err := v.Struct(req) - if err != nil { - return nil, err - } - - // If now empty set it to the current time. - if now.IsZero() { - now = time.Now() - } - - // Always store the time as UTC. - now = now.UTC() - - // Postgres truncates times to milliseconds when storing. We and do the same - // here so the value we return is consistent with what we store. - now = now.Truncate(time.Millisecond) - - m := {{ $.Model.Name }}{ - {{ range $mk, $mf := $.Model.Fields }}{{ $cf := (index $createFields $mf.FieldName) }} - {{ if eq $mf.FieldName $.Model.PrimaryField }}{{ $isUuid := (FieldTagHasOption $mf "validate" "uuid") }}{{ $mf.FieldName }}: {{ if $isUuid }}uuid.NewRandom().String(){{ else }}req.{{ $mf.FieldName }}{{ end }}, - {{ else if or (eq $mf.FieldName "CreatedAt") (eq $mf.FieldName "UpdatedAt") }}{{ $mf.FieldName }}: now, - {{ else if $cf }}{{ $required := (FieldTagHasOption $cf "validate" "required") }}{{ if $required }}{{ $cf.FieldName }}: req.{{ $cf.FieldName }},{{ else if ne $cf.DefaultValue "" }}{{ $cf.FieldName }}: {{ $cf.DefaultValue }},{{ end }} - {{ end }}{{ end }} - } - - {{ if and (not $reqHasAccountID) ($hasAccountID) }} - // Set the accountId from claims. - if claims.Audience != "" && m.AccountID == "" { - req.AccountID = claims.Audience - } - {{ end }} - - {{ range $fk, $f := $createFields }}{{ $required := (FieldTagHasOption $f "validate" "required") }}{{ if not $required }} - if req.{{ $f.FieldName }} != nil { - {{ if eq $f.FieldType "sql.NullString" }} - m.{{ $f.FieldName }} = sql.NullString{String: *req.{{ $f.FieldName }}, Valid: true} - {{ else if eq $f.FieldType "*sql.NullString" }} - m.{{ $f.FieldName }} = &sql.NullString{String: *req.{{ $f.FieldName }}, Valid: true} - {{ else }} - m.{{ $f.FieldName }} = *req.{{ $f.FieldName }} - {{ end }} - } - {{ end }}{{ end }} - - // Build the insert SQL statement. - query := sqlbuilder.NewInsertBuilder() - query.InsertInto({{ FormatCamelLower $.Model.Name }}TableName) - query.Cols( - {{ range $mk, $mf := $.Model.Fields }}{{ $cf := (index $createFields $mf.FieldName) }}{{ if or (eq $mf.FieldName $.Model.PrimaryField) ($cf) (eq $mf.FieldName "CreatedAt") (eq $mf.FieldName "UpdatedAt") }}"{{ $mf.ColumnName }}", - {{ end }}{{ end }} - ) - query.Values( - {{ range $mk, $mf := $.Model.Fields }}{{ $cf := (index $createFields $mf.FieldName) }}{{ if or (eq $mf.FieldName $.Model.PrimaryField) ($cf) (eq $mf.FieldName "CreatedAt") (eq $mf.FieldName "UpdatedAt") }}m.{{ $mf.FieldName }}, - {{ end }}{{ end }} - ) - - // Execute the query with the provided context. - sql, args := query.Build() - sql = dbConn.Rebind(sql) - _, err = dbConn.ExecContext(ctx, sql, args...) - if err != nil { - err = errors.Wrapf(err, "query - %s", query.String()) - err = errors.WithMessage(err, "create {{ FormatCamelLowerTitle $.Model.Name }} failed") - return nil, err - } - - return &m, nil -} -{{ end }} -{{ define "Update"}} -{{ $reqName := (Concat $.Model.Name "UpdateRequest") }} -{{ $updateFields := (index $.StructFields $reqName) }} -// Update replaces an {{ FormatCamelLowerTitle $.Model.Name }} in the database. -func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req {{ $reqName }}, now time.Time) error { - span, ctx := tracer.StartSpanFromContext(ctx, "internal.{{ FormatCamelLowerUnderscore $.Model.Name }}.Update") - defer span.Finish() - - v := validator.New() - - // Validate the request. - err := v.Struct(req) - if err != nil { - return err - } - - // Ensure the claims can modify the {{ FormatCamelLowerTitle $.Model.Name }} specified in the request. - err = CanModify{{ $.Model.Name }}(ctx, claims, dbConn, req.{{ $.Model.PrimaryField }}) - if err != nil { - return err - } - - // If now empty set it to the current time. - if now.IsZero() { - now = time.Now() - } - - // Always store the time as UTC. - now = now.UTC() - - // Postgres truncates times to milliseconds when storing. We and do the same - // here so the value we return is consistent with what we store. - now = now.Truncate(time.Millisecond) - - // Build the update SQL statement. - query := sqlbuilder.NewUpdateBuilder() - query.Update({{ FormatCamelLower $.Model.Name }}TableName) - - var fields []string - - {{ range $mk, $mf := $.Model.Fields }}{{ $uf := (index $updateFields $mf.FieldName) }}{{ if and ($uf.FieldName) (ne $uf.FieldName $.Model.PrimaryField) }} - {{ $optional := (FieldTagHasOption $uf "validate" "omitempty") }}{{ $isUuid := (FieldTagHasOption $uf "validate" "uuid") }} - if req.{{ $uf.FieldName }} != nil { - {{ if and ($optional) ($isUuid) }} - if *req.{{ $uf.FieldName }} != "" { - fields = append(fields, query.Assign("{{ $uf.ColumnName }}", req.{{ $uf.FieldName }})) - } else { - fields = append(fields, query.Assign("{{ $uf.ColumnName }}", nil)) - } - {{ else }} - fields = append(fields, query.Assign("{{ $uf.ColumnName }}", req.{{ $uf.FieldName }})) - {{ end }} - } - {{ end }}{{ end }} - - // If there's nothing to update we can quit early. - if len(fields) == 0 { - return nil - } - - {{ $hasUpdatedAt := (StringListHasValue $.Model.ColumnNames "updated_at") }}{{ if $hasUpdatedAt }} - // Append the updated_at field - fields = append(fields, query.Assign("updated_at", now)) - {{ end }} - - query.Set(fields...) - query.Where(query.Equal("{{ $.Model.PrimaryColumn }}", req.{{ $.Model.PrimaryField }})) - - // Execute the query with the provided context. - sql, args := query.Build() - sql = dbConn.Rebind(sql) - _, err = dbConn.ExecContext(ctx, sql, args...) - if err != nil { - err = errors.Wrapf(err, "query - %s", query.String()) - err = errors.WithMessagef(err, "update {{ FormatCamelLowerTitle $.Model.Name }} %s failed", req.{{ $.Model.PrimaryField }}) - return err - } - - return nil -} -{{ end }} -{{ define "Archive"}} -// Archive soft deleted the {{ FormatCamelLowerTitle $.Model.Name }} from the database. -func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, {{ FormatCamelLower $.Model.PrimaryField }} {{ $.Model.PrimaryType }}, now time.Time) error { - span, ctx := tracer.StartSpanFromContext(ctx, "internal.{{ FormatCamelLowerUnderscore $.Model.Name }}.Archive") - defer span.Finish() - - // Defines the struct to apply validation - req := struct { - {{ $.Model.PrimaryField }} {{ $.Model.PrimaryType }} `validate:"required,uuid"` - }{ - {{ $.Model.PrimaryField }}: {{ FormatCamelLower $.Model.PrimaryField }}, - } - - // Validate the request. - err := validator.New().Struct(req) - if err != nil { - return err - } - - // Ensure the claims can modify the {{ FormatCamelLowerTitle $.Model.Name }} specified in the request. - err = CanModify{{ $.Model.Name }}(ctx, claims, dbConn, req.ID) - if err != nil { - return err - } - - // If now empty set it to the current time. - if now.IsZero() { - now = time.Now() - } - - // Always store the time as UTC. - now = now.UTC() - - // Postgres truncates times to milliseconds when storing. We and do the same - // here so the value we return is consistent with what we store. - now = now.Truncate(time.Millisecond) - - // Build the update SQL statement. - query := sqlbuilder.NewUpdateBuilder() - query.Update({{ FormatCamelLower $.Model.Name }}TableName) - query.Set( - query.Assign("archived_at", now), - ) - query.Where(query.Equal("{{ $.Model.PrimaryColumn }}", req.{{ $.Model.PrimaryField }})) - - // Execute the query with the provided context. - sql, args := query.Build() - sql = dbConn.Rebind(sql) - _, err = dbConn.ExecContext(ctx, sql, args...) - if err != nil { - err = errors.Wrapf(err, "query - %s", query.String()) - err = errors.WithMessagef(err, "archive {{ FormatCamelLowerTitle $.Model.Name }} %s failed", req.{{ $.Model.PrimaryField }}) - return err - } - - return nil -} -{{ end }} -{{ define "Delete"}} -// Delete removes an {{ FormatCamelLowerTitle $.Model.Name }} from the database. -func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, {{ FormatCamelLower $.Model.PrimaryField }} {{ $.Model.PrimaryType }}) error { - span, ctx := tracer.StartSpanFromContext(ctx, "internal.{{ FormatCamelLowerUnderscore $.Model.Name }}.Delete") - defer span.Finish() - - // Defines the struct to apply validation - req := struct { - {{ $.Model.PrimaryField }} {{ $.Model.PrimaryType }} `validate:"required,uuid"` - }{ - {{ $.Model.PrimaryField }}: {{ FormatCamelLower $.Model.PrimaryField }}, - } - - // Validate the request. - err := validator.New().Struct(req) - if err != nil { - return err - } - - // Ensure the claims can modify the {{ FormatCamelLowerTitle $.Model.Name }} specified in the request. - err = CanModify{{ $.Model.Name }}(ctx, claims, dbConn, req.ID) - if err != nil { - return err - } - - // Build the delete SQL statement. - query := sqlbuilder.NewDeleteBuilder() - query.DeleteFrom({{ FormatCamelLower $.Model.Name }}TableName) - query.Where(query.Equal("{{ $.Model.PrimaryColumn }}", req.{{ $.Model.PrimaryField }})) - - // Execute the query with the provided context. - sql, args := query.Build() - sql = dbConn.Rebind(sql) - _, err = dbConn.ExecContext(ctx, sql, args...) - if err != nil { - err = errors.Wrapf(err, "query - %s", query.String()) - err = errors.WithMessagef(err, "delete {{ FormatCamelLowerTitle $.Model.Name }} %s failed", req.{{ $.Model.PrimaryField }}) - return err - } - - return nil -} -{{ end }} diff --git a/tools/truss/templates/dbtable2crud/model_crud_test.tmpl b/tools/truss/templates/dbtable2crud/model_crud_test.tmpl deleted file mode 100644 index 65477d1..0000000 --- a/tools/truss/templates/dbtable2crud/model_crud_test.tmpl +++ /dev/null @@ -1,131 +0,0 @@ -{{ define "imports"}} -import ( - "{{ $.GoSrcPath }}/internal/platform/auth" - "{{ $.GoSrcPath }}/internal/platform/tests" - "github.com/google/go-cmp/cmp" - "github.com/huandu/go-sqlbuilder" - "os" - "testing" -) -{{ end }} -{{ define "Globals"}} -var test *tests.Test - -// TestMain is the entry point for testing. -func TestMain(m *testing.M) { - os.Exit(testMain(m)) -} - -func testMain(m *testing.M) int { - test = tests.New() - defer test.TearDown() - return m.Run() -} -{{ end }} -{{ define "TestFindRequestQuery"}} -// TestFindRequestQuery validates findRequestQuery -func TestFindRequestQuery(t *testing.T) { - where := "field1 = ? or field2 = ?" - var ( - limit uint = 12 - offset uint = 34 - ) - - req := {{ $.Model.Name }}FindRequest{ - Where: &where, - Args: []interface{}{ - "lee brown", - "103 East Main St.", - }, - Order: []string{ - "id asc", - "created_at desc", - }, - Limit: &limit, - Offset: &offset, - } - expected := "SELECT " + {{ FormatCamelLower $.Model.Name }}MapColumns + " FROM " + {{ FormatCamelLower $.Model.Name }}TableName + " WHERE (field1 = ? or field2 = ?) ORDER BY id asc, created_at desc LIMIT 12 OFFSET 34" - - res, args := findRequestQuery(req) - - if diff := cmp.Diff(res.String(), expected); diff != "" { - t.Fatalf("\t%s\tExpected result query to match. Diff:\n%s", tests.Failed, diff) - } - if diff := cmp.Diff(args, req.Args); diff != "" { - t.Fatalf("\t%s\tExpected result query to match. Diff:\n%s", tests.Failed, diff) - } -} -{{ end }} -{{ define "TestApplyClaimsSelect"}} -// TestApplyClaimsSelect applyClaimsSelect -func TestApplyClaimsSelect(t *testing.T) { - var claimTests = []struct { - name string - claims auth.Claims - expectedSql string - error error - }{ - {"EmptyClaims", - auth.Claims{}, - "SELECT " + {{ FormatCamelLower $.Model.Name }}MapColumns + " FROM " + {{ FormatCamelLower $.Model.Name }}TableName, - nil, - }, - {"RoleAccount", - auth.Claims{ - Roles: []string{auth.RoleAdmin}, - StandardClaims: jwt.StandardClaims{ - Subject: "user1", - Audience: "acc1", - }, - }, - "SELECT " + {{ FormatCamelLower $.Model.Name }}MapColumns + " FROM " + {{ FormatCamelLower $.Model.Name }}TableName + " WHERE account_id = 'acc1'", - nil, - }, - {"RoleAdmin", - auth.Claims{ - Roles: []string{auth.RoleAdmin}, - StandardClaims: jwt.StandardClaims{ - Subject: "user1", - Audience: "acc1", - }, - }, - "SELECT " + {{ FormatCamelLower $.Model.Name }}MapColumns + " FROM " + {{ FormatCamelLower $.Model.Name }}TableName + " WHERE account_id = 'acc1'", - nil, - }, - } - - t.Log("Given the need to validate ACLs are enforced by claims to a select query.") - { - for i, tt := range claimTests { - t.Logf("\tTest: %d\tWhen running test: %s", i, tt.name) - { - ctx := tests.Context() - - query := selectQuery() - - err := applyClaimsSelect(ctx, tt.claims, query) - if err != tt.error { - t.Logf("\t\tGot : %+v", err) - t.Logf("\t\tWant: %+v", tt.error) - t.Fatalf("\t%s\tapplyClaimsSelect failed.", tests.Failed) - } - - sql, args := query.Build() - - // Use mysql flavor so placeholders will get replaced for comparison. - sql, err = sqlbuilder.MySQL.Interpolate(sql, args) - if err != nil { - t.Log("\t\tGot :", err) - t.Fatalf("\t%s\tapplyClaimsSelect failed.", tests.Failed) - } - - if diff := cmp.Diff(sql, tt.expectedSql); diff != "" { - t.Fatalf("\t%s\tExpected result query to match. Diff:\n%s", tests.Failed, diff) - } - - t.Logf("\t%s\tapplyClaimsSelect ok.", tests.Success) - } - } - } -} -{{ end }} \ No newline at end of file diff --git a/tools/truss/templates/dbtable2crud/models.tmpl b/tools/truss/templates/dbtable2crud/models.tmpl deleted file mode 100644 index 4d96f5b..0000000 --- a/tools/truss/templates/dbtable2crud/models.tmpl +++ /dev/null @@ -1,80 +0,0 @@ -{{ define "CreateRequest"}} -// {{ FormatCamel $.Model.Name }}CreateRequest contains information needed to create a new {{ FormatCamel $.Model.Name }}. -type {{ FormatCamel $.Model.Name }}CreateRequest struct { - {{ range $fk, $f := .Model.Fields }}{{ if and ($f.ApiCreate) (ne $f.FieldName $.Model.PrimaryField) }}{{ $optional := (FieldTagHasOption $f "validate" "omitempty") }} - {{ $f.FieldName }} {{ if and ($optional) (not $f.FieldIsPtr) }}*{{ end }}{{ $f.FieldType }} `json:"{{ $f.ColumnName }}" {{ FieldTag $f "validate" }}` - {{ end }}{{ end }} -} -{{ end }} -{{ define "UpdateRequest"}} -// {{ FormatCamel $.Model.Name }}UpdateRequest defines what information may be provided to modify an existing -// {{ FormatCamel $.Model.Name }}. All fields are optional so clients can send just the fields they want -// changed. It uses pointer fields so we can differentiate between a field that -// was not provided and a field that was provided as explicitly blank. -type {{ FormatCamel $.Model.Name }}UpdateRequest struct { - {{ range $fk, $f := .Model.Fields }}{{ if $f.ApiUpdate }} - {{ $f.FieldName }} {{ if and (ne $f.FieldName $.Model.PrimaryField) (not $f.FieldIsPtr) }}*{{ end }}{{ $f.FieldType }} `json:"{{ $f.ColumnName }}" {{ if ne $f.FieldName $.Model.PrimaryField }}{{ FieldTagReplaceOrPrepend $f "validate" "required" "omitempty" }}{{ else }}{{ FieldTagReplaceOrPrepend $f "validate" "omitempty" "required" }}{{ end }}` - {{ end }}{{ end }} -} -{{ end }} -{{ define "FindRequest"}} -// {{ FormatCamel $.Model.Name }}FindRequest defines the possible options to search for {{ FormatCamelPluralTitleLower $.Model.Name }}. By default -// archived {{ FormatCamelLowerTitle $.Model.Name }} will be excluded from response. -type {{ FormatCamel $.Model.Name }}FindRequest struct { - Where *string `schema:"where"` - Args []interface{} `schema:"args"` - Order []string `schema:"order"` - Limit *uint `schema:"limit"` - Offset *uint `schema:"offset"` - IncludedArchived bool - {{ $hasArchived := (StringListHasValue $.Model.ColumnNames "archived_at") }}{{ if $hasArchived }}IncludedArchived bool `schema:"included-archived"`{{ end }} -} -{{ end }} -{{ define "Enums"}} -{{ range $fk, $f := .Model.Fields }}{{ if $f.DbColumn }}{{ if $f.DbColumn.IsEnum }} -// {{ $f.FieldType }} represents the {{ $f.ColumnName }} of {{ FormatCamelLowerTitle $.Model.Name }}. -type {{ $f.FieldType }} string - -// {{ $f.FieldType }} values define the {{ $f.ColumnName }} field of {{ FormatCamelLowerTitle $.Model.Name }}. -const ( - {{ range $evk, $ev := $f.DbColumn.EnumValues }} - // {{ $f.FieldType }}_{{ FormatCamel $ev }} defines the {{ $f.ColumnName }} of {{ $ev }} for {{ FormatCamelLowerTitle $.Model.Name }}. - {{ $f.FieldType }}_{{ FormatCamel $ev }} {{ $f.FieldType }} = "{{ $ev }}" - {{ end }} -) - -// {{ $f.FieldType }}_Values provides list of valid {{ $f.FieldType }} values. -var {{ $f.FieldType }}_Values = []{{ $f.FieldType }}{ - {{ range $evk, $ev := $f.DbColumn.EnumValues }} - {{ $f.FieldType }}_{{ FormatCamel $ev }}, - {{ end }} -} - -// Scan supports reading the {{ $f.FieldType }} value from the database. -func (s *{{ $f.FieldType }}) Scan(value interface{}) error { - asBytes, ok := value.([]byte) - if !ok { - return errors.New("Scan source is not []byte") - } - *s = {{ $f.FieldType }}(string(asBytes)) - return nil -} - -// Value converts the {{ $f.FieldType }} value to be stored in the database. -func (s {{ $f.FieldType }}) Value() (driver.Value, error) { - v := validator.New() - - errs := v.Var(s, "required,oneof={{ JoinStrings $f.DbColumn.EnumValues " " }}") - if errs != nil { - return nil, errs - } - - return string(s), nil -} - -// String converts the {{ $f.FieldType }} value to a string. -func (s {{ $f.FieldType }}) String() string { - return string(s) -} -{{ end }}{{ end }}{{ end }} -{{ end }} \ No newline at end of file