You've already forked golang-saas-starter-kit
mirror of
https://github.com/raseels-repos/golang-saas-starter-kit.git
synced 2025-06-15 00:15:15 +02:00
Completed enough code gen for project ATM
This commit is contained in:
@ -16,7 +16,7 @@ Truss provides code generation to reduce copy/pasting.
|
||||
go build .
|
||||
```
|
||||
|
||||
### Usage
|
||||
### Configuration
|
||||
```bash
|
||||
./truss -h
|
||||
|
||||
@ -29,5 +29,40 @@ Usage of ./truss
|
||||
--db_driver string <postgres>
|
||||
--db_timezone string <utc>
|
||||
--db_disabletls bool <false>
|
||||
```
|
||||
|
||||
## Commands:
|
||||
|
||||
## dbtable2crud
|
||||
|
||||
Used to bootstrap a new business logic package with basic CRUD.
|
||||
|
||||
**Usage**
|
||||
```bash
|
||||
./truss dbtable2crud -table=projects -file=../../internal/project/models.go -model=Project [-dbtable=TABLE] [-templateDir=DIR] [-projectPath=DIR] [-saveChanges=false]
|
||||
```
|
||||
|
||||
**Example**
|
||||
1. Define a new database table in `internal/schema/migrations.go`
|
||||
|
||||
|
||||
2. Create a new file for the base model at `internal/projects/models.go`. Only the following struct needs to be included. All the other times will be generated.
|
||||
```go
|
||||
// Project represents a workflow.
|
||||
type Project struct {
|
||||
ID string `json:"id" validate:"required,uuid"`
|
||||
AccountID string `json:"account_id" validate:"required,uuid" truss:"api-create"`
|
||||
Name string `json:"name" validate:"required"`
|
||||
Status ProjectStatus `json:"status" validate:"omitempty,oneof=active disabled"`
|
||||
CreatedAt time.Time `json:"created_at" truss:"api-read"`
|
||||
UpdatedAt time.Time `json:"updated_at" truss:"api-read"`
|
||||
ArchivedAt pq.NullTime `json:"archived_at" truss:"api-hide"`
|
||||
}
|
||||
```
|
||||
|
||||
3. Run `dbtable2crud`
|
||||
```bash
|
||||
./truss dbtable2crud -table=projects -file=../../internal/project/models.go -model=Project -save=true
|
||||
```
|
||||
|
||||
|
||||
|
@ -303,7 +303,57 @@ func updateModelCrudFile(db *sqlx.DB, log *log.Logger, dbName, dbTable, template
|
||||
continue
|
||||
}
|
||||
|
||||
if crudDoc.HasType(obj.Name, obj.Type) {
|
||||
if obj.Name == "" && (obj.Type == goparse.GoObjectType_Var || obj.Type == goparse.GoObjectType_Const) {
|
||||
var curDocObj *goparse.GoObject
|
||||
for _, subObj := range obj.Objects().List() {
|
||||
for _, do := range crudDoc.Objects().List() {
|
||||
if do.Name == "" && (do.Type == goparse.GoObjectType_Var || do.Type == goparse.GoObjectType_Const) {
|
||||
for _, subDocObj := range do.Objects().List() {
|
||||
if subDocObj.String() == subObj.String() && subObj.Type != goparse.GoObjectType_LineBreak {
|
||||
curDocObj = do
|
||||
break
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if curDocObj != nil {
|
||||
for _, subObj := range obj.Objects().List() {
|
||||
var hasSubObj bool
|
||||
for _, subDocObj := range curDocObj.Objects().List() {
|
||||
if subDocObj.String() == subObj.String() {
|
||||
hasSubObj = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hasSubObj {
|
||||
curDocObj.Objects().Add(subObj)
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to add object %s %s for %s", obj.Type, obj.Name, baseModel.Name)
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Append comments and line breaks before adding the object
|
||||
for _, c := range objHeaders {
|
||||
err := crudDoc.Objects().Add(c)
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to add object %s %s for %s", c.Type, c.Name, baseModel.Name)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
err := crudDoc.Objects().Add(obj)
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to add object %s %s for %s", obj.Type, obj.Name, baseModel.Name)
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else if crudDoc.HasType(obj.Name, obj.Type) {
|
||||
cur := crudDoc.Objects().Get(obj.Name, obj.Type)
|
||||
|
||||
newObjs := []*goparse.GoObject{}
|
||||
|
@ -36,7 +36,7 @@ func loadTemplateObjects(log *log.Logger, model *modelDef, templateDir, filename
|
||||
tempFilePath := filepath.Join(templateDir, filename)
|
||||
dat, err := ioutil.ReadFile(tempFilePath)
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to read template file %s", tempFilePath)
|
||||
err = errors.WithMessagef(err, "Failed to read template file %s", tempFilePath)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -46,17 +46,17 @@ func loadTemplateObjects(log *log.Logger, model *modelDef, templateDir, filename
|
||||
"Concat": func(vals ...string) string {
|
||||
return strings.Join(vals, "")
|
||||
},
|
||||
"JoinStrings": func(vals []string, sep string) string {
|
||||
"JoinStrings": func(vals []string, sep string ) string {
|
||||
return strings.Join(vals, sep)
|
||||
},
|
||||
"PrefixAndJoinStrings": func(vals []string, pre, sep string) string {
|
||||
"PrefixAndJoinStrings": func(vals []string, pre, sep string ) string {
|
||||
l := []string{}
|
||||
for _, v := range vals {
|
||||
l = append(l, pre+v)
|
||||
l = append(l, pre + v)
|
||||
}
|
||||
return strings.Join(l, sep)
|
||||
},
|
||||
"FmtAndJoinStrings": func(vals []string, fmtStr, sep string) string {
|
||||
"FmtAndJoinStrings": func(vals []string, fmtStr, sep string ) string {
|
||||
l := []string{}
|
||||
for _, v := range vals {
|
||||
l = append(l, fmt.Sprintf(fmtStr, v))
|
||||
@ -74,35 +74,35 @@ func loadTemplateObjects(log *log.Logger, model *modelDef, templateDir, filename
|
||||
return "id"
|
||||
}
|
||||
return FormatCamelLower(name)
|
||||
},
|
||||
} ,
|
||||
"FormatCamelLowerTitle": func(name string) string {
|
||||
return FormatCamelLowerTitle(name)
|
||||
},
|
||||
} ,
|
||||
"FormatCamelPluralTitle": func(name string) string {
|
||||
return FormatCamelPluralTitle(name)
|
||||
},
|
||||
} ,
|
||||
"FormatCamelPluralTitleLower": func(name string) string {
|
||||
return FormatCamelPluralTitleLower(name)
|
||||
},
|
||||
} ,
|
||||
"FormatCamelPluralCamel": func(name string) string {
|
||||
return FormatCamelPluralCamel(name)
|
||||
},
|
||||
} ,
|
||||
"FormatCamelPluralLower": func(name string) string {
|
||||
return FormatCamelPluralLower(name)
|
||||
},
|
||||
} ,
|
||||
"FormatCamelPluralUnderscore": func(name string) string {
|
||||
return FormatCamelPluralUnderscore(name)
|
||||
},
|
||||
} ,
|
||||
"FormatCamelPluralLowerUnderscore": func(name string) string {
|
||||
return FormatCamelPluralLowerUnderscore(name)
|
||||
},
|
||||
} ,
|
||||
"FormatCamelUnderscore": func(name string) string {
|
||||
return FormatCamelUnderscore(name)
|
||||
},
|
||||
} ,
|
||||
"FormatCamelLowerUnderscore": func(name string) string {
|
||||
return FormatCamelLowerUnderscore(name)
|
||||
},
|
||||
"FieldTagHasOption": func(f modelField, tagName, optName string) bool {
|
||||
} ,
|
||||
"FieldTagHasOption": func(f modelField, tagName, optName string ) bool {
|
||||
if f.Tags == nil {
|
||||
return false
|
||||
}
|
||||
@ -143,7 +143,7 @@ func loadTemplateObjects(log *log.Logger, model *modelDef, templateDir, filename
|
||||
}
|
||||
}
|
||||
} else if !ft.HasOption(newVal) {
|
||||
if ft.Name == "" {
|
||||
if ft.Name == ""{
|
||||
ft.Name = newVal
|
||||
} else {
|
||||
ft.Options = append(ft.Options, newVal)
|
||||
@ -165,7 +165,7 @@ func loadTemplateObjects(log *log.Logger, model *modelDef, templateDir, filename
|
||||
// Load the template file using the text/template package.
|
||||
tmpl, err := baseTmpl.Parse(string(dat))
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to parse template file %s", tempFilePath)
|
||||
err = errors.WithMessagef(err, "Failed to parse template file %s", tempFilePath)
|
||||
log.Printf("loadTemplateObjects : %v\n%v", err, string(dat))
|
||||
return nil, err
|
||||
}
|
||||
@ -189,18 +189,18 @@ func loadTemplateObjects(log *log.Logger, model *modelDef, templateDir, filename
|
||||
// Executed the defined template with the given data.
|
||||
var tpl bytes.Buffer
|
||||
if err := tmpl.Lookup(tmplName).Execute(&tpl, tmptData); err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to execute %s from template file %s", tmplName, tempFilePath)
|
||||
err = errors.WithMessagef(err, "Failed to execute %s from template file %s", tmplName, tempFilePath)
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// Format the source code to ensure its valid and code to parsed consistently.
|
||||
codeBytes, err := format.Source(tpl.Bytes())
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to format source for %s in template file %s", tmplName, filename)
|
||||
err = errors.WithMessagef(err, "Failed to format source for %s in template file %s", tmplName, filename)
|
||||
|
||||
dl := []string{}
|
||||
for idx, l := range strings.Split(tpl.String(), "\n") {
|
||||
dl = append(dl, fmt.Sprintf("%d -> ", idx)+l)
|
||||
dl = append(dl, fmt.Sprintf("%d -> ", idx) + l)
|
||||
}
|
||||
|
||||
log.Printf("loadTemplateObjects : %v\n%v", err, strings.Join(dl, "\n"))
|
||||
@ -216,7 +216,7 @@ func loadTemplateObjects(log *log.Logger, model *modelDef, templateDir, filename
|
||||
// Parse the code lines into a set of objects.
|
||||
objs, err := goparse.ParseLines(codeLines, 0)
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to parse %s in template file %s", tmplName, filename)
|
||||
err = errors.WithMessagef(err, "Failed to parse %s in template file %s", tmplName, filename)
|
||||
log.Printf("loadTemplateObjects : %v\n%v", err, codeStr)
|
||||
return resp, err
|
||||
}
|
||||
@ -316,7 +316,7 @@ func templateFileOrderedNames(localPath string, names []string) (resp []string,
|
||||
idx := 0
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
if !strings.HasPrefix(scanner.Text(), "{{") || !strings.Contains(scanner.Text(), "define ") {
|
||||
if !strings.HasPrefix(scanner.Text(), "{{") || !strings.Contains(scanner.Text(), "define ") {
|
||||
continue
|
||||
}
|
||||
|
||||
|
@ -88,6 +88,11 @@ func ParseLines(lines []string, depth int) (objs *GoObjects, err error) {
|
||||
|
||||
ld := lineDepth(l)
|
||||
|
||||
|
||||
//fmt.Println("l", l)
|
||||
//fmt.Println("> Depth", ld, "???", depth)
|
||||
|
||||
|
||||
if ld == depth {
|
||||
if strings.HasPrefix(ls, "/*") {
|
||||
multiLine = true
|
||||
@ -108,6 +113,11 @@ func ParseLines(lines []string, depth int) (objs *GoObjects, err error) {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
//fmt.Println("> multiLine", multiLine)
|
||||
//fmt.Println("> multiComment", multiComment)
|
||||
//fmt.Println("> muiliVar", muiliVar)
|
||||
|
||||
objLines = append(objLines, l)
|
||||
|
||||
if multiComment {
|
||||
@ -131,6 +141,8 @@ func ParseLines(lines []string, depth int) (objs *GoObjects, err error) {
|
||||
}
|
||||
}
|
||||
|
||||
//fmt.Println(" > objLines", objLines)
|
||||
|
||||
obj, err := ParseGoObject(objLines, depth)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
@ -197,8 +209,22 @@ func ParseGoObject(lines []string, depth int) (obj *GoObject, err error) {
|
||||
|
||||
if strings.HasPrefix(firstStrip, "var") {
|
||||
obj.Type = GoObjectType_Var
|
||||
|
||||
if !strings.HasSuffix(firstStrip, "(") {
|
||||
if strings.HasPrefix(firstStrip, "var ") {
|
||||
firstStrip = strings.TrimSpace(strings.Replace(firstStrip, "var ", "", 1))
|
||||
}
|
||||
obj.Name = strings.Split(firstStrip, " ")[0]
|
||||
}
|
||||
} else if strings.HasPrefix(firstStrip, "const") {
|
||||
obj.Type = GoObjectType_Const
|
||||
|
||||
if !strings.HasSuffix(firstStrip, "(") {
|
||||
if strings.HasPrefix(firstStrip, "const ") {
|
||||
firstStrip = strings.TrimSpace(strings.Replace(firstStrip, "const ", "", 1))
|
||||
}
|
||||
obj.Name = strings.Split(firstStrip, " ")[0]
|
||||
}
|
||||
} else if strings.HasPrefix(firstStrip, "func") {
|
||||
obj.Type = GoObjectType_Func
|
||||
|
||||
|
@ -64,7 +64,8 @@ func TestNewDocImports(t *testing.T) {
|
||||
func TestParseLines1(t *testing.T) {
|
||||
g := gomega.NewGomegaWithT(t)
|
||||
|
||||
code := `func testCreate(t *testing.T, ctx context.Context, sess *datamodels.Session) *datamodels.Model {
|
||||
codeTests := []string{
|
||||
`func testCreate(t *testing.T, ctx context.Context, sess *datamodels.Session) *datamodels.Model {
|
||||
g := gomega.NewGomegaWithT(t)
|
||||
obj := datamodels.MockModelNew()
|
||||
resp, err := ModelCreate(ctx, DB, &obj)
|
||||
@ -76,15 +77,30 @@ func TestParseLines1(t *testing.T) {
|
||||
g.Expect(resp.Status).Should(gomega.Equal(datamodels.{{ .Name }}Status_Active))
|
||||
return resp
|
||||
}
|
||||
`
|
||||
lines := strings.Split(code, "\n")
|
||||
|
||||
objs, err := ParseLines(lines, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("got error %v", err)
|
||||
`,
|
||||
`var (
|
||||
// ErrNotFound abstracts the postgres not found error.
|
||||
ErrNotFound = errors.New("Entity not found")
|
||||
// ErrInvalidID occurs when an ID is not in a valid form.
|
||||
ErrInvalidID = errors.New("ID is not in its proper form")
|
||||
// ErrForbidden occurs when a user tries to do something that is forbidden to them according to our access control policies.
|
||||
ErrForbidden = errors.New("Attempted action is not allowed")
|
||||
)
|
||||
`,
|
||||
}
|
||||
|
||||
g.Expect(objs.Lines()).Should(gomega.Equal(lines))
|
||||
for _, code := range codeTests {
|
||||
lines := strings.Split(code, "\n")
|
||||
|
||||
objs, err := ParseLines(lines, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("got error %v", err)
|
||||
}
|
||||
|
||||
g.Expect(objs.Lines()).Should(gomega.Equal(lines))
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
func TestParseLines2(t *testing.T) {
|
||||
|
@ -124,8 +124,8 @@ func main() {
|
||||
app.Commands = []cli.Command{
|
||||
{
|
||||
Name: "dbtable2crud",
|
||||
Aliases: []string{"dbtable2crud"},
|
||||
Usage: "dbtable2crud -table=projects -file=../../internal/project/models.go -model=Project -saveChanges=false",
|
||||
Aliases: []string{},
|
||||
Usage: "-table=projects -file=../../internal/project/models.go -model=Project [-dbtable=TABLE] [-templateDir=DIR] [-projectPath=DIR] [-saveChanges=false] ",
|
||||
Flags: []cli.Flag{
|
||||
cli.StringFlag{Name: "dbtable, table"},
|
||||
cli.StringFlag{Name: "modelFile, modelfile, file"},
|
||||
|
@ -46,15 +46,15 @@ func mapRowsTo{{ $.Model.Name }}(rows *sql.Rows) (*{{ $.Model.Name }}, error) {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
return &a, nil
|
||||
return &m, nil
|
||||
}
|
||||
{{ end }}
|
||||
{{ define "ACL"}}
|
||||
{{ $hasAccountId := (StringListHasValue $.Model.ColumnNames "account_id") }}
|
||||
{{ $hasAccountID := (StringListHasValue $.Model.ColumnNames "account_id") }}
|
||||
// CanRead{{ $.Model.Name }} determines if claims has the authority to access the specified {{ FormatCamelLowerTitle $.Model.Name}} by {{ $.Model.PrimaryColumn }}.
|
||||
func CanRead{{ $.Model.Name }}(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, {{ FormatCamelLower $.Model.PrimaryField }} string) error {
|
||||
|
||||
{{ if $hasAccountId }}
|
||||
{{ if $hasAccountID }}
|
||||
// If the request has claims from a specific {{ FormatCamelLower $.Model.Name }}, ensure that the claims
|
||||
// has the correct access to the {{ FormatCamelLower $.Model.Name }}.
|
||||
if claims.Audience != "" {
|
||||
@ -90,7 +90,7 @@ func CanRead{{ $.Model.Name }}(ctx context.Context, claims auth.Claims, dbConn *
|
||||
|
||||
// CanModify{{ $.Model.Name }} determines if claims has the authority to modify the specified {{ FormatCamelLowerTitle $.Model.Name}} by {{ $.Model.PrimaryColumn }}.
|
||||
func CanModify{{ $.Model.Name }}(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, {{ FormatCamelLower $.Model.PrimaryField }} string) error {
|
||||
err = CanRead{{ $.Model.Name }}(ctx, claims, dbConn, {{ FormatCamelLower $.Model.PrimaryField }})
|
||||
err := CanRead{{ $.Model.Name }}(ctx, claims, dbConn, {{ FormatCamelLower $.Model.PrimaryField }})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -105,7 +105,7 @@ func CanModify{{ $.Model.Name }}(ctx context.Context, claims auth.Claims, dbConn
|
||||
|
||||
// applyClaimsSelect applies a sub-query to the provided query to enforce ACL based on the claims provided.
|
||||
// 1. No claims, request is internal, no ACL applied
|
||||
{{ if $hasAccountId }}
|
||||
{{ if $hasAccountID }}
|
||||
// 2. All role types can access their user ID
|
||||
{{ end }}
|
||||
func applyClaimsSelect(ctx context.Context, claims auth.Claims, query *sqlbuilder.SelectBuilder) error {
|
||||
@ -114,7 +114,7 @@ func applyClaimsSelect(ctx context.Context, claims auth.Claims, query *sqlbuilde
|
||||
return nil
|
||||
}
|
||||
|
||||
{{ if $hasAccountId }}
|
||||
{{ if $hasAccountID }}
|
||||
query.Where(query.Equal("account_id", claims.Audience))
|
||||
{{ end }}
|
||||
|
||||
@ -226,9 +226,10 @@ func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, {{ FormatCam
|
||||
}
|
||||
{{ end }}
|
||||
{{ define "Create"}}
|
||||
{{ $hasAccountId := (StringListHasValue $.Model.ColumnNames "account_id") }}
|
||||
{{ $hasAccountID := (StringListHasValue $.Model.ColumnNames "account_id") }}
|
||||
{{ $reqName := (Concat $.Model.Name "CreateRequest") }}
|
||||
{{ $createFields := (index $.StructFields $reqName) }}
|
||||
{{ $reqHasAccountID := false }}{{ $reqAccountID := (index $createFields "AccountID") }}{{ if $reqAccountID }}{{ $reqHasAccountID = true }}{{ end }}
|
||||
// Create inserts a new {{ FormatCamelLowerTitle $.Model.Name }} into the database.
|
||||
func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req {{ $reqName }}, now time.Time) (*{{ $.Model.Name }}, error) {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.{{ FormatCamelLowerUnderscore $.Model.Name }}.Create")
|
||||
@ -237,18 +238,18 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req {{ $re
|
||||
if claims.Audience != "" {
|
||||
// Admin users can update {{ FormatCamelPluralTitleLower $.Model.Name }} they have access to.
|
||||
if !claims.HasRole(auth.RoleAdmin) {
|
||||
return errors.WithStack(ErrForbidden)
|
||||
return nil, errors.WithStack(ErrForbidden)
|
||||
}
|
||||
|
||||
{{ if $hasAccountId }}
|
||||
if req.AccountId != "" {
|
||||
{{ if $reqHasAccountID }}
|
||||
if req.AccountID != "" {
|
||||
// Request accountId must match claims.
|
||||
if req.AccountId != claims.Audience {
|
||||
return errors.WithStack(ErrForbidden)
|
||||
if req.AccountID != claims.Audience {
|
||||
return nil, errors.WithStack(ErrForbidden)
|
||||
}
|
||||
} else {
|
||||
// Set the accountId from claims.
|
||||
req.AccountId = claims.Audience
|
||||
req.AccountID = claims.Audience
|
||||
}
|
||||
{{ end }}
|
||||
}
|
||||
@ -256,7 +257,7 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req {{ $re
|
||||
v := validator.New()
|
||||
|
||||
// Validate the request.
|
||||
err = v.Struct(req)
|
||||
err := v.Struct(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -281,6 +282,13 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req {{ $re
|
||||
{{ end }}{{ end }}
|
||||
}
|
||||
|
||||
{{ if and (not $reqHasAccountID) ($hasAccountID) }}
|
||||
// Set the accountId from claims.
|
||||
if claims.Audience != "" && m.AccountID == "" {
|
||||
req.AccountID = claims.Audience
|
||||
}
|
||||
{{ end }}
|
||||
|
||||
{{ range $fk, $f := $createFields }}{{ $required := (FieldTagHasOption $f "validate" "required") }}{{ if not $required }}
|
||||
if req.{{ $f.FieldName }} != nil {
|
||||
{{ if eq $f.FieldType "sql.NullString" }}
|
||||
@ -315,7 +323,7 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req {{ $re
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &a, nil
|
||||
return &m, nil
|
||||
}
|
||||
{{ end }}
|
||||
{{ define "Update"}}
|
||||
|
@ -1,19 +1,11 @@
|
||||
{{ define "imports"}}
|
||||
import (
|
||||
"github.com/lib/pq"
|
||||
"math/rand"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"{{ $.GoSrcPath }}/internal/platform/auth"
|
||||
"{{ $.GoSrcPath }}/internal/platform/tests"
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/huandu/go-sqlbuilder"
|
||||
"github.com/pborman/uuid"
|
||||
"github.com/pkg/errors"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
{{ end }}
|
||||
{{ define "Globals"}}
|
||||
@ -33,7 +25,7 @@ func testMain(m *testing.M) int {
|
||||
{{ define "TestFindRequestQuery"}}
|
||||
// TestFindRequestQuery validates findRequestQuery
|
||||
func TestFindRequestQuery(t *testing.T) {
|
||||
where := "name = ? or address1 = ?"
|
||||
where := "field1 = ? or field2 = ?"
|
||||
var (
|
||||
limit uint = 12
|
||||
offset uint = 34
|
||||
@ -52,7 +44,7 @@ func TestFindRequestQuery(t *testing.T) {
|
||||
Limit: &limit,
|
||||
Offset: &offset,
|
||||
}
|
||||
expected := "SELECT " + accountMapColumns + " FROM " + accountTableName + " WHERE (name = ? or address1 = ?) ORDER BY id asc, created_at desc LIMIT 12 OFFSET 34"
|
||||
expected := "SELECT " + {{ FormatCamelLower $.Model.Name }}MapColumns + " FROM " + {{ FormatCamelLower $.Model.Name }}TableName + " WHERE (field1 = ? or field2 = ?) ORDER BY id asc, created_at desc LIMIT 12 OFFSET 34"
|
||||
|
||||
res, args := findRequestQuery(req)
|
||||
|
||||
@ -63,4 +55,77 @@ func TestFindRequestQuery(t *testing.T) {
|
||||
t.Fatalf("\t%s\tExpected result query to match. Diff:\n%s", tests.Failed, diff)
|
||||
}
|
||||
}
|
||||
{{ end }}
|
||||
{{ define "TestApplyClaimsSelect"}}
|
||||
// TestApplyClaimsSelect applyClaimsSelect
|
||||
func TestApplyClaimsSelect(t *testing.T) {
|
||||
var claimTests = []struct {
|
||||
name string
|
||||
claims auth.Claims
|
||||
expectedSql string
|
||||
error error
|
||||
}{
|
||||
{"EmptyClaims",
|
||||
auth.Claims{},
|
||||
"SELECT " + {{ FormatCamelLower $.Model.Name }}MapColumns + " FROM " + {{ FormatCamelLower $.Model.Name }}TableName,
|
||||
nil,
|
||||
},
|
||||
{"RoleAccount",
|
||||
auth.Claims{
|
||||
Roles: []string{auth.RoleAdmin},
|
||||
StandardClaims: jwt.StandardClaims{
|
||||
Subject: "user1",
|
||||
Audience: "acc1",
|
||||
},
|
||||
},
|
||||
"SELECT " + {{ FormatCamelLower $.Model.Name }}MapColumns + " FROM " + {{ FormatCamelLower $.Model.Name }}TableName + " WHERE account_id = 'acc1'",
|
||||
nil,
|
||||
},
|
||||
{"RoleAdmin",
|
||||
auth.Claims{
|
||||
Roles: []string{auth.RoleAdmin},
|
||||
StandardClaims: jwt.StandardClaims{
|
||||
Subject: "user1",
|
||||
Audience: "acc1",
|
||||
},
|
||||
},
|
||||
"SELECT " + {{ FormatCamelLower $.Model.Name }}MapColumns + " FROM " + {{ FormatCamelLower $.Model.Name }}TableName + " WHERE account_id = 'acc1'",
|
||||
nil,
|
||||
},
|
||||
}
|
||||
|
||||
t.Log("Given the need to validate ACLs are enforced by claims to a select query.")
|
||||
{
|
||||
for i, tt := range claimTests {
|
||||
t.Logf("\tTest: %d\tWhen running test: %s", i, tt.name)
|
||||
{
|
||||
ctx := tests.Context()
|
||||
|
||||
query := selectQuery()
|
||||
|
||||
err := applyClaimsSelect(ctx, tt.claims, query)
|
||||
if err != tt.error {
|
||||
t.Logf("\t\tGot : %+v", err)
|
||||
t.Logf("\t\tWant: %+v", tt.error)
|
||||
t.Fatalf("\t%s\tapplyClaimsSelect failed.", tests.Failed)
|
||||
}
|
||||
|
||||
sql, args := query.Build()
|
||||
|
||||
// Use mysql flavor so placeholders will get replaced for comparison.
|
||||
sql, err = sqlbuilder.MySQL.Interpolate(sql, args)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tapplyClaimsSelect failed.", tests.Failed)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(sql, tt.expectedSql); diff != "" {
|
||||
t.Fatalf("\t%s\tExpected result query to match. Diff:\n%s", tests.Failed, diff)
|
||||
}
|
||||
|
||||
t.Logf("\t%s\tapplyClaimsSelect ok.", tests.Success)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
{{ end }}
|
@ -38,7 +38,7 @@ type {{ $f.FieldType }} string
|
||||
const (
|
||||
{{ range $evk, $ev := $f.DbColumn.EnumValues }}
|
||||
// {{ $f.FieldType }}_{{ FormatCamel $ev }} defines the {{ $f.ColumnName }} of {{ $ev }} for {{ FormatCamelLowerTitle $.Model.Name }}.
|
||||
{{ $f.FieldType }}_{{ FormatCamel $ev }}{{ $f.FieldType }} = "{{ $ev }}"
|
||||
{{ $f.FieldType }}_{{ FormatCamel $ev }} {{ $f.FieldType }} = "{{ $ev }}"
|
||||
{{ end }}
|
||||
)
|
||||
|
||||
|
Reference in New Issue
Block a user