1
0
mirror of https://github.com/raseels-repos/golang-saas-starter-kit.git synced 2025-06-15 00:15:15 +02:00

Completed enough code gen for project ATM

This commit is contained in:
Lee Brown
2019-06-24 04:26:48 -08:00
parent 7b5c2a5807
commit 07e86cfd52
13 changed files with 839 additions and 92 deletions

View File

@ -16,7 +16,7 @@ Truss provides code generation to reduce copy/pasting.
go build .
```
### Usage
### Configuration
```bash
./truss -h
@ -29,5 +29,40 @@ Usage of ./truss
--db_driver string <postgres>
--db_timezone string <utc>
--db_disabletls bool <false>
```
## Commands:
## dbtable2crud
Used to bootstrap a new business logic package with basic CRUD.
**Usage**
```bash
./truss dbtable2crud -table=projects -file=../../internal/project/models.go -model=Project [-dbtable=TABLE] [-templateDir=DIR] [-projectPath=DIR] [-saveChanges=false]
```
**Example**
1. Define a new database table in `internal/schema/migrations.go`
2. Create a new file for the base model at `internal/projects/models.go`. Only the following struct needs to be included. All the other times will be generated.
```go
// Project represents a workflow.
type Project struct {
ID string `json:"id" validate:"required,uuid"`
AccountID string `json:"account_id" validate:"required,uuid" truss:"api-create"`
Name string `json:"name" validate:"required"`
Status ProjectStatus `json:"status" validate:"omitempty,oneof=active disabled"`
CreatedAt time.Time `json:"created_at" truss:"api-read"`
UpdatedAt time.Time `json:"updated_at" truss:"api-read"`
ArchivedAt pq.NullTime `json:"archived_at" truss:"api-hide"`
}
```
3. Run `dbtable2crud`
```bash
./truss dbtable2crud -table=projects -file=../../internal/project/models.go -model=Project -save=true
```

View File

@ -303,7 +303,57 @@ func updateModelCrudFile(db *sqlx.DB, log *log.Logger, dbName, dbTable, template
continue
}
if crudDoc.HasType(obj.Name, obj.Type) {
if obj.Name == "" && (obj.Type == goparse.GoObjectType_Var || obj.Type == goparse.GoObjectType_Const) {
var curDocObj *goparse.GoObject
for _, subObj := range obj.Objects().List() {
for _, do := range crudDoc.Objects().List() {
if do.Name == "" && (do.Type == goparse.GoObjectType_Var || do.Type == goparse.GoObjectType_Const) {
for _, subDocObj := range do.Objects().List() {
if subDocObj.String() == subObj.String() && subObj.Type != goparse.GoObjectType_LineBreak {
curDocObj = do
break
}
}
}
}
}
if curDocObj != nil {
for _, subObj := range obj.Objects().List() {
var hasSubObj bool
for _, subDocObj := range curDocObj.Objects().List() {
if subDocObj.String() == subObj.String() {
hasSubObj = true
break
}
}
if !hasSubObj {
curDocObj.Objects().Add(subObj)
if err != nil {
err = errors.WithMessagef(err, "Failed to add object %s %s for %s", obj.Type, obj.Name, baseModel.Name)
return err
}
}
}
} else {
// Append comments and line breaks before adding the object
for _, c := range objHeaders {
err := crudDoc.Objects().Add(c)
if err != nil {
err = errors.WithMessagef(err, "Failed to add object %s %s for %s", c.Type, c.Name, baseModel.Name)
return err
}
}
err := crudDoc.Objects().Add(obj)
if err != nil {
err = errors.WithMessagef(err, "Failed to add object %s %s for %s", obj.Type, obj.Name, baseModel.Name)
return err
}
}
} else if crudDoc.HasType(obj.Name, obj.Type) {
cur := crudDoc.Objects().Get(obj.Name, obj.Type)
newObjs := []*goparse.GoObject{}

View File

@ -36,7 +36,7 @@ func loadTemplateObjects(log *log.Logger, model *modelDef, templateDir, filename
tempFilePath := filepath.Join(templateDir, filename)
dat, err := ioutil.ReadFile(tempFilePath)
if err != nil {
err = errors.WithMessagef(err, "Failed to read template file %s", tempFilePath)
err = errors.WithMessagef(err, "Failed to read template file %s", tempFilePath)
return nil, err
}
@ -46,17 +46,17 @@ func loadTemplateObjects(log *log.Logger, model *modelDef, templateDir, filename
"Concat": func(vals ...string) string {
return strings.Join(vals, "")
},
"JoinStrings": func(vals []string, sep string) string {
"JoinStrings": func(vals []string, sep string ) string {
return strings.Join(vals, sep)
},
"PrefixAndJoinStrings": func(vals []string, pre, sep string) string {
"PrefixAndJoinStrings": func(vals []string, pre, sep string ) string {
l := []string{}
for _, v := range vals {
l = append(l, pre+v)
l = append(l, pre + v)
}
return strings.Join(l, sep)
},
"FmtAndJoinStrings": func(vals []string, fmtStr, sep string) string {
"FmtAndJoinStrings": func(vals []string, fmtStr, sep string ) string {
l := []string{}
for _, v := range vals {
l = append(l, fmt.Sprintf(fmtStr, v))
@ -74,35 +74,35 @@ func loadTemplateObjects(log *log.Logger, model *modelDef, templateDir, filename
return "id"
}
return FormatCamelLower(name)
},
} ,
"FormatCamelLowerTitle": func(name string) string {
return FormatCamelLowerTitle(name)
},
} ,
"FormatCamelPluralTitle": func(name string) string {
return FormatCamelPluralTitle(name)
},
} ,
"FormatCamelPluralTitleLower": func(name string) string {
return FormatCamelPluralTitleLower(name)
},
} ,
"FormatCamelPluralCamel": func(name string) string {
return FormatCamelPluralCamel(name)
},
} ,
"FormatCamelPluralLower": func(name string) string {
return FormatCamelPluralLower(name)
},
} ,
"FormatCamelPluralUnderscore": func(name string) string {
return FormatCamelPluralUnderscore(name)
},
} ,
"FormatCamelPluralLowerUnderscore": func(name string) string {
return FormatCamelPluralLowerUnderscore(name)
},
} ,
"FormatCamelUnderscore": func(name string) string {
return FormatCamelUnderscore(name)
},
} ,
"FormatCamelLowerUnderscore": func(name string) string {
return FormatCamelLowerUnderscore(name)
},
"FieldTagHasOption": func(f modelField, tagName, optName string) bool {
} ,
"FieldTagHasOption": func(f modelField, tagName, optName string ) bool {
if f.Tags == nil {
return false
}
@ -143,7 +143,7 @@ func loadTemplateObjects(log *log.Logger, model *modelDef, templateDir, filename
}
}
} else if !ft.HasOption(newVal) {
if ft.Name == "" {
if ft.Name == ""{
ft.Name = newVal
} else {
ft.Options = append(ft.Options, newVal)
@ -165,7 +165,7 @@ func loadTemplateObjects(log *log.Logger, model *modelDef, templateDir, filename
// Load the template file using the text/template package.
tmpl, err := baseTmpl.Parse(string(dat))
if err != nil {
err = errors.WithMessagef(err, "Failed to parse template file %s", tempFilePath)
err = errors.WithMessagef(err, "Failed to parse template file %s", tempFilePath)
log.Printf("loadTemplateObjects : %v\n%v", err, string(dat))
return nil, err
}
@ -189,18 +189,18 @@ func loadTemplateObjects(log *log.Logger, model *modelDef, templateDir, filename
// Executed the defined template with the given data.
var tpl bytes.Buffer
if err := tmpl.Lookup(tmplName).Execute(&tpl, tmptData); err != nil {
err = errors.WithMessagef(err, "Failed to execute %s from template file %s", tmplName, tempFilePath)
err = errors.WithMessagef(err, "Failed to execute %s from template file %s", tmplName, tempFilePath)
return resp, err
}
// Format the source code to ensure its valid and code to parsed consistently.
codeBytes, err := format.Source(tpl.Bytes())
if err != nil {
err = errors.WithMessagef(err, "Failed to format source for %s in template file %s", tmplName, filename)
err = errors.WithMessagef(err, "Failed to format source for %s in template file %s", tmplName, filename)
dl := []string{}
for idx, l := range strings.Split(tpl.String(), "\n") {
dl = append(dl, fmt.Sprintf("%d -> ", idx)+l)
dl = append(dl, fmt.Sprintf("%d -> ", idx) + l)
}
log.Printf("loadTemplateObjects : %v\n%v", err, strings.Join(dl, "\n"))
@ -216,7 +216,7 @@ func loadTemplateObjects(log *log.Logger, model *modelDef, templateDir, filename
// Parse the code lines into a set of objects.
objs, err := goparse.ParseLines(codeLines, 0)
if err != nil {
err = errors.WithMessagef(err, "Failed to parse %s in template file %s", tmplName, filename)
err = errors.WithMessagef(err, "Failed to parse %s in template file %s", tmplName, filename)
log.Printf("loadTemplateObjects : %v\n%v", err, codeStr)
return resp, err
}
@ -316,7 +316,7 @@ func templateFileOrderedNames(localPath string, names []string) (resp []string,
idx := 0
scanner := bufio.NewScanner(file)
for scanner.Scan() {
if !strings.HasPrefix(scanner.Text(), "{{") || !strings.Contains(scanner.Text(), "define ") {
if !strings.HasPrefix(scanner.Text(), "{{") || !strings.Contains(scanner.Text(), "define ") {
continue
}

View File

@ -88,6 +88,11 @@ func ParseLines(lines []string, depth int) (objs *GoObjects, err error) {
ld := lineDepth(l)
//fmt.Println("l", l)
//fmt.Println("> Depth", ld, "???", depth)
if ld == depth {
if strings.HasPrefix(ls, "/*") {
multiLine = true
@ -108,6 +113,11 @@ func ParseLines(lines []string, depth int) (objs *GoObjects, err error) {
}
}
//fmt.Println("> multiLine", multiLine)
//fmt.Println("> multiComment", multiComment)
//fmt.Println("> muiliVar", muiliVar)
objLines = append(objLines, l)
if multiComment {
@ -131,6 +141,8 @@ func ParseLines(lines []string, depth int) (objs *GoObjects, err error) {
}
}
//fmt.Println(" > objLines", objLines)
obj, err := ParseGoObject(objLines, depth)
if err != nil {
log.Println(err)
@ -197,8 +209,22 @@ func ParseGoObject(lines []string, depth int) (obj *GoObject, err error) {
if strings.HasPrefix(firstStrip, "var") {
obj.Type = GoObjectType_Var
if !strings.HasSuffix(firstStrip, "(") {
if strings.HasPrefix(firstStrip, "var ") {
firstStrip = strings.TrimSpace(strings.Replace(firstStrip, "var ", "", 1))
}
obj.Name = strings.Split(firstStrip, " ")[0]
}
} else if strings.HasPrefix(firstStrip, "const") {
obj.Type = GoObjectType_Const
if !strings.HasSuffix(firstStrip, "(") {
if strings.HasPrefix(firstStrip, "const ") {
firstStrip = strings.TrimSpace(strings.Replace(firstStrip, "const ", "", 1))
}
obj.Name = strings.Split(firstStrip, " ")[0]
}
} else if strings.HasPrefix(firstStrip, "func") {
obj.Type = GoObjectType_Func

View File

@ -64,7 +64,8 @@ func TestNewDocImports(t *testing.T) {
func TestParseLines1(t *testing.T) {
g := gomega.NewGomegaWithT(t)
code := `func testCreate(t *testing.T, ctx context.Context, sess *datamodels.Session) *datamodels.Model {
codeTests := []string{
`func testCreate(t *testing.T, ctx context.Context, sess *datamodels.Session) *datamodels.Model {
g := gomega.NewGomegaWithT(t)
obj := datamodels.MockModelNew()
resp, err := ModelCreate(ctx, DB, &obj)
@ -76,15 +77,30 @@ func TestParseLines1(t *testing.T) {
g.Expect(resp.Status).Should(gomega.Equal(datamodels.{{ .Name }}Status_Active))
return resp
}
`
lines := strings.Split(code, "\n")
objs, err := ParseLines(lines, 0)
if err != nil {
t.Fatalf("got error %v", err)
`,
`var (
// ErrNotFound abstracts the postgres not found error.
ErrNotFound = errors.New("Entity not found")
// ErrInvalidID occurs when an ID is not in a valid form.
ErrInvalidID = errors.New("ID is not in its proper form")
// ErrForbidden occurs when a user tries to do something that is forbidden to them according to our access control policies.
ErrForbidden = errors.New("Attempted action is not allowed")
)
`,
}
g.Expect(objs.Lines()).Should(gomega.Equal(lines))
for _, code := range codeTests {
lines := strings.Split(code, "\n")
objs, err := ParseLines(lines, 0)
if err != nil {
t.Fatalf("got error %v", err)
}
g.Expect(objs.Lines()).Should(gomega.Equal(lines))
}
}
func TestParseLines2(t *testing.T) {

View File

@ -124,8 +124,8 @@ func main() {
app.Commands = []cli.Command{
{
Name: "dbtable2crud",
Aliases: []string{"dbtable2crud"},
Usage: "dbtable2crud -table=projects -file=../../internal/project/models.go -model=Project -saveChanges=false",
Aliases: []string{},
Usage: "-table=projects -file=../../internal/project/models.go -model=Project [-dbtable=TABLE] [-templateDir=DIR] [-projectPath=DIR] [-saveChanges=false] ",
Flags: []cli.Flag{
cli.StringFlag{Name: "dbtable, table"},
cli.StringFlag{Name: "modelFile, modelfile, file"},

View File

@ -46,15 +46,15 @@ func mapRowsTo{{ $.Model.Name }}(rows *sql.Rows) (*{{ $.Model.Name }}, error) {
return nil, errors.WithStack(err)
}
return &a, nil
return &m, nil
}
{{ end }}
{{ define "ACL"}}
{{ $hasAccountId := (StringListHasValue $.Model.ColumnNames "account_id") }}
{{ $hasAccountID := (StringListHasValue $.Model.ColumnNames "account_id") }}
// CanRead{{ $.Model.Name }} determines if claims has the authority to access the specified {{ FormatCamelLowerTitle $.Model.Name}} by {{ $.Model.PrimaryColumn }}.
func CanRead{{ $.Model.Name }}(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, {{ FormatCamelLower $.Model.PrimaryField }} string) error {
{{ if $hasAccountId }}
{{ if $hasAccountID }}
// If the request has claims from a specific {{ FormatCamelLower $.Model.Name }}, ensure that the claims
// has the correct access to the {{ FormatCamelLower $.Model.Name }}.
if claims.Audience != "" {
@ -90,7 +90,7 @@ func CanRead{{ $.Model.Name }}(ctx context.Context, claims auth.Claims, dbConn *
// CanModify{{ $.Model.Name }} determines if claims has the authority to modify the specified {{ FormatCamelLowerTitle $.Model.Name}} by {{ $.Model.PrimaryColumn }}.
func CanModify{{ $.Model.Name }}(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, {{ FormatCamelLower $.Model.PrimaryField }} string) error {
err = CanRead{{ $.Model.Name }}(ctx, claims, dbConn, {{ FormatCamelLower $.Model.PrimaryField }})
err := CanRead{{ $.Model.Name }}(ctx, claims, dbConn, {{ FormatCamelLower $.Model.PrimaryField }})
if err != nil {
return err
}
@ -105,7 +105,7 @@ func CanModify{{ $.Model.Name }}(ctx context.Context, claims auth.Claims, dbConn
// applyClaimsSelect applies a sub-query to the provided query to enforce ACL based on the claims provided.
// 1. No claims, request is internal, no ACL applied
{{ if $hasAccountId }}
{{ if $hasAccountID }}
// 2. All role types can access their user ID
{{ end }}
func applyClaimsSelect(ctx context.Context, claims auth.Claims, query *sqlbuilder.SelectBuilder) error {
@ -114,7 +114,7 @@ func applyClaimsSelect(ctx context.Context, claims auth.Claims, query *sqlbuilde
return nil
}
{{ if $hasAccountId }}
{{ if $hasAccountID }}
query.Where(query.Equal("account_id", claims.Audience))
{{ end }}
@ -226,9 +226,10 @@ func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, {{ FormatCam
}
{{ end }}
{{ define "Create"}}
{{ $hasAccountId := (StringListHasValue $.Model.ColumnNames "account_id") }}
{{ $hasAccountID := (StringListHasValue $.Model.ColumnNames "account_id") }}
{{ $reqName := (Concat $.Model.Name "CreateRequest") }}
{{ $createFields := (index $.StructFields $reqName) }}
{{ $reqHasAccountID := false }}{{ $reqAccountID := (index $createFields "AccountID") }}{{ if $reqAccountID }}{{ $reqHasAccountID = true }}{{ end }}
// Create inserts a new {{ FormatCamelLowerTitle $.Model.Name }} into the database.
func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req {{ $reqName }}, now time.Time) (*{{ $.Model.Name }}, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.{{ FormatCamelLowerUnderscore $.Model.Name }}.Create")
@ -237,18 +238,18 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req {{ $re
if claims.Audience != "" {
// Admin users can update {{ FormatCamelPluralTitleLower $.Model.Name }} they have access to.
if !claims.HasRole(auth.RoleAdmin) {
return errors.WithStack(ErrForbidden)
return nil, errors.WithStack(ErrForbidden)
}
{{ if $hasAccountId }}
if req.AccountId != "" {
{{ if $reqHasAccountID }}
if req.AccountID != "" {
// Request accountId must match claims.
if req.AccountId != claims.Audience {
return errors.WithStack(ErrForbidden)
if req.AccountID != claims.Audience {
return nil, errors.WithStack(ErrForbidden)
}
} else {
// Set the accountId from claims.
req.AccountId = claims.Audience
req.AccountID = claims.Audience
}
{{ end }}
}
@ -256,7 +257,7 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req {{ $re
v := validator.New()
// Validate the request.
err = v.Struct(req)
err := v.Struct(req)
if err != nil {
return nil, err
}
@ -281,6 +282,13 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req {{ $re
{{ end }}{{ end }}
}
{{ if and (not $reqHasAccountID) ($hasAccountID) }}
// Set the accountId from claims.
if claims.Audience != "" && m.AccountID == "" {
req.AccountID = claims.Audience
}
{{ end }}
{{ range $fk, $f := $createFields }}{{ $required := (FieldTagHasOption $f "validate" "required") }}{{ if not $required }}
if req.{{ $f.FieldName }} != nil {
{{ if eq $f.FieldType "sql.NullString" }}
@ -315,7 +323,7 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req {{ $re
return nil, err
}
return &a, nil
return &m, nil
}
{{ end }}
{{ define "Update"}}

View File

@ -1,19 +1,11 @@
{{ define "imports"}}
import (
"github.com/lib/pq"
"math/rand"
"os"
"strings"
"testing"
"time"
"{{ $.GoSrcPath }}/internal/platform/auth"
"{{ $.GoSrcPath }}/internal/platform/tests"
"github.com/dgrijalva/jwt-go"
"github.com/google/go-cmp/cmp"
"github.com/huandu/go-sqlbuilder"
"github.com/pborman/uuid"
"github.com/pkg/errors"
"os"
"testing"
)
{{ end }}
{{ define "Globals"}}
@ -33,7 +25,7 @@ func testMain(m *testing.M) int {
{{ define "TestFindRequestQuery"}}
// TestFindRequestQuery validates findRequestQuery
func TestFindRequestQuery(t *testing.T) {
where := "name = ? or address1 = ?"
where := "field1 = ? or field2 = ?"
var (
limit uint = 12
offset uint = 34
@ -52,7 +44,7 @@ func TestFindRequestQuery(t *testing.T) {
Limit: &limit,
Offset: &offset,
}
expected := "SELECT " + accountMapColumns + " FROM " + accountTableName + " WHERE (name = ? or address1 = ?) ORDER BY id asc, created_at desc LIMIT 12 OFFSET 34"
expected := "SELECT " + {{ FormatCamelLower $.Model.Name }}MapColumns + " FROM " + {{ FormatCamelLower $.Model.Name }}TableName + " WHERE (field1 = ? or field2 = ?) ORDER BY id asc, created_at desc LIMIT 12 OFFSET 34"
res, args := findRequestQuery(req)
@ -63,4 +55,77 @@ func TestFindRequestQuery(t *testing.T) {
t.Fatalf("\t%s\tExpected result query to match. Diff:\n%s", tests.Failed, diff)
}
}
{{ end }}
{{ define "TestApplyClaimsSelect"}}
// TestApplyClaimsSelect applyClaimsSelect
func TestApplyClaimsSelect(t *testing.T) {
var claimTests = []struct {
name string
claims auth.Claims
expectedSql string
error error
}{
{"EmptyClaims",
auth.Claims{},
"SELECT " + {{ FormatCamelLower $.Model.Name }}MapColumns + " FROM " + {{ FormatCamelLower $.Model.Name }}TableName,
nil,
},
{"RoleAccount",
auth.Claims{
Roles: []string{auth.RoleAdmin},
StandardClaims: jwt.StandardClaims{
Subject: "user1",
Audience: "acc1",
},
},
"SELECT " + {{ FormatCamelLower $.Model.Name }}MapColumns + " FROM " + {{ FormatCamelLower $.Model.Name }}TableName + " WHERE account_id = 'acc1'",
nil,
},
{"RoleAdmin",
auth.Claims{
Roles: []string{auth.RoleAdmin},
StandardClaims: jwt.StandardClaims{
Subject: "user1",
Audience: "acc1",
},
},
"SELECT " + {{ FormatCamelLower $.Model.Name }}MapColumns + " FROM " + {{ FormatCamelLower $.Model.Name }}TableName + " WHERE account_id = 'acc1'",
nil,
},
}
t.Log("Given the need to validate ACLs are enforced by claims to a select query.")
{
for i, tt := range claimTests {
t.Logf("\tTest: %d\tWhen running test: %s", i, tt.name)
{
ctx := tests.Context()
query := selectQuery()
err := applyClaimsSelect(ctx, tt.claims, query)
if err != tt.error {
t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.error)
t.Fatalf("\t%s\tapplyClaimsSelect failed.", tests.Failed)
}
sql, args := query.Build()
// Use mysql flavor so placeholders will get replaced for comparison.
sql, err = sqlbuilder.MySQL.Interpolate(sql, args)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tapplyClaimsSelect failed.", tests.Failed)
}
if diff := cmp.Diff(sql, tt.expectedSql); diff != "" {
t.Fatalf("\t%s\tExpected result query to match. Diff:\n%s", tests.Failed, diff)
}
t.Logf("\t%s\tapplyClaimsSelect ok.", tests.Success)
}
}
}
}
{{ end }}

View File

@ -38,7 +38,7 @@ type {{ $f.FieldType }} string
const (
{{ range $evk, $ev := $f.DbColumn.EnumValues }}
// {{ $f.FieldType }}_{{ FormatCamel $ev }} defines the {{ $f.ColumnName }} of {{ $ev }} for {{ FormatCamelLowerTitle $.Model.Name }}.
{{ $f.FieldType }}_{{ FormatCamel $ev }}{{ $f.FieldType }} = "{{ $ev }}"
{{ $f.FieldType }}_{{ FormatCamel $ev }} {{ $f.FieldType }} = "{{ $ev }}"
{{ end }}
)