1
0
mirror of https://github.com/raseels-repos/golang-saas-starter-kit.git synced 2025-06-15 00:15:15 +02:00

Completed enough code gen for project ATM

This commit is contained in:
Lee Brown
2019-06-24 04:26:48 -08:00
parent 7b5c2a5807
commit 07e86cfd52
13 changed files with 839 additions and 92 deletions

View File

@ -63,8 +63,8 @@ func TestFindRequestQuery(t *testing.T) {
} }
} }
// TestApplyClaimsSelectvalidates applyClaimsSelect // TestApplyClaimsSelect validates applyClaimsSelect
func TestApplyClaimsSelectvalidates(t *testing.T) { func TestApplyClaimsSelect(t *testing.T) {
var claimTests = []struct { var claimTests = []struct {
name string name string
claims auth.Claims claims auth.Claims

View File

@ -2,45 +2,42 @@ package project
import ( import (
"database/sql/driver" "database/sql/driver"
"time"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/pkg/errors" "github.com/pkg/errors"
"gopkg.in/go-playground/validator.v9" "gopkg.in/go-playground/validator.v9"
"time"
) )
// Project represents a workflow. // Project represents a workflow.
type Project struct { type Project struct {
ID string `json:"id" validate:"required,uuid"` ID string `json:"id" validate:"required,uuid"`
AccountID string `json:"account_id" validate:"required,uuid" truss:"api-create"` AccountID string `json:"account_id" validate:"required,uuid" truss:"api-create"`
Name string `json:"name" validate:"required"` Name string `json:"name" validate:"required"`
Status ProjectStatus `json:"status" validate:"omitempty,oneof=active disabled"` Status ProjectStatus `json:"status" validate:"omitempty,oneof=active disabled"`
CreatedAt time.Time `json:"created_at" truss:"api-read"` CreatedAt time.Time `json:"created_at" truss:"api-read"`
UpdatedAt time.Time `json:"updated_at" truss:"api-read"` UpdatedAt time.Time `json:"updated_at" truss:"api-read"`
ArchivedAt pq.NullTime `json:"archived_at" truss:"api-hide"` ArchivedAt pq.NullTime `json:"archived_at" truss:"api-hide"`
} }
// CreateProjectRequest contains information needed to create a new Project. // ProjectCreateRequest contains information needed to create a new Project.
type ProjectCreateRequest struct { type ProjectCreateRequest struct {
AccountID string `json:"account_id" validate:"required,uuid"` AccountID string `json:"account_id" validate:"required,uuid"`
Name string `json:"name" validate:"required"` Name string `json:"name" validate:"required"`
Status *ProjectStatus `json:"status" validate:"omitempty,oneof=active disabled"` Status *ProjectStatus `json:"status" validate:"omitempty,oneof=active disabled"`
} }
// UpdateProjectRequest defines what information may be provided to modify an existing // ProjectUpdateRequest defines what information may be provided to modify an existing
// Project. All fields are optional so clients can send just the fields they want // Project. All fields are optional so clients can send just the fields they want
// changed. It uses pointer fields so we can differentiate between a field that // changed. It uses pointer fields so we can differentiate between a field that
// was not provided and a field that was provided as explicitly blank. Normally // was not provided and a field that was provided as explicitly blank.
// we do not want to use pointers to basic types but we make exceptions around
// marshalling/unmarshalling.
type ProjectUpdateRequest struct { type ProjectUpdateRequest struct {
ID string `validate:"required,uuid"` ID string `json:"id" validate:"required,uuid"`
Name *string `json:"name" validate:"omitempty"` Name *string `json:"name" validate:"omitempty"`
Status *ProjectStatus `json:"status" validate:"omitempty,oneof=active pending disabled"` Status *ProjectStatus `json:"status" validate:"omitempty,oneof=active disabled"`
} }
// ProjectFindRequest defines the possible options to search for projects. By default // ProjectFindRequest defines the possible options to search for projects. By default
// archived projects will be excluded from response. // archived project will be excluded from response.
type ProjectFindRequest struct { type ProjectFindRequest struct {
Where *string Where *string
Args []interface{} Args []interface{}
@ -50,20 +47,21 @@ type ProjectFindRequest struct {
IncludedArchived bool IncludedArchived bool
} }
// ProjectStatus represents the status of an project. // ProjectStatus represents the status of project.
type ProjectStatus string type ProjectStatus string
// ProjectStatus values define the status field of a user project. // ProjectStatus values define the status field of project.
const ( const (
// ProjectStatus_Active defines the state when a user can access an project.
// ProjectStatus_Active defines the status of active for project.
ProjectStatus_Active ProjectStatus = "active" ProjectStatus_Active ProjectStatus = "active"
// ProjectStatus_Disabled defines the state when a user has been disabled from // ProjectStatus_Disabled defines the status of disabled for project.
// accessing an project.
ProjectStatus_Disabled ProjectStatus = "disabled" ProjectStatus_Disabled ProjectStatus = "disabled"
) )
// ProjectStatus_Values provides list of valid ProjectStatus values. // ProjectStatus_Values provides list of valid ProjectStatus values.
var ProjectStatus_Values = []ProjectStatus{ var ProjectStatus_Values = []ProjectStatus{
ProjectStatus_Active, ProjectStatus_Active,
ProjectStatus_Disabled, ProjectStatus_Disabled,
} }
@ -74,6 +72,7 @@ func (s *ProjectStatus) Scan(value interface{}) error {
if !ok { if !ok {
return errors.New("Scan source is not []byte") return errors.New("Scan source is not []byte")
} }
*s = ProjectStatus(string(asBytes)) *s = ProjectStatus(string(asBytes))
return nil return nil
} }
@ -81,7 +80,6 @@ func (s *ProjectStatus) Scan(value interface{}) error {
// Value converts the ProjectStatus value to be stored in the database. // Value converts the ProjectStatus value to be stored in the database.
func (s ProjectStatus) Value() (driver.Value, error) { func (s ProjectStatus) Value() (driver.Value, error) {
v := validator.New() v := validator.New()
errs := v.Var(s, "required,oneof=active disabled") errs := v.Var(s, "required,oneof=active disabled")
if errs != nil { if errs != nil {
return nil, errs return nil, errs

View File

@ -0,0 +1,447 @@
package project
import (
"context"
"database/sql"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth"
"github.com/huandu/go-sqlbuilder"
"github.com/jmoiron/sqlx"
"github.com/pborman/uuid"
"github.com/pkg/errors"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
"gopkg.in/go-playground/validator.v9"
"time"
)
const (
// The database table for Project
projectTableName = "projects"
)
var (
// ErrNotFound abstracts the postgres not found error.
ErrNotFound = errors.New("Entity not found")
// ErrForbidden occurs when a user tries to do something that is forbidden to them according to our access control policies.
ErrForbidden = errors.New("Attempted action is not allowed")
// ErrInvalidID occurs when an ID is not in a valid form.
ErrInvalidID = errors.New("ID is not in its proper form")
)
// projectMapColumns is the list of columns needed for mapRowsToProject
var projectMapColumns = "id,account_id,name,status,created_at,updated_at,archived_at"
// mapRowsToProject takes the SQL rows and maps it to the Project struct
// with the columns defined by projectMapColumns
func mapRowsToProject(rows *sql.Rows) (*Project, error) {
var (
m Project
err error
)
err = rows.Scan(&m.ID, &m.AccountID, &m.Name, &m.Status, &m.CreatedAt, &m.UpdatedAt, &m.ArchivedAt)
if err != nil {
return nil, errors.WithStack(err)
}
return &m, nil
}
// CanReadProject determines if claims has the authority to access the specified project by id.
func CanReadProject(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id string) error {
// If the request has claims from a specific project, ensure that the claims
// has the correct access to the project.
if claims.Audience != "" {
// select id from projects where account_id = [accountID]
query := sqlbuilder.NewSelectBuilder().Select("id").From(projectTableName)
query.Where(query.And(
query.Equal("account_id", claims.Audience),
query.Equal("ID", id),
))
queryStr, args := query.Build()
queryStr = dbConn.Rebind(queryStr)
var id string
err := dbConn.QueryRowContext(ctx, queryStr, args...).Scan(&id)
if err != nil && err != sql.ErrNoRows {
err = errors.Wrapf(err, "query - %s", query.String())
return err
}
// When there is no id returned, then the current claim user does not have access
// to the specified project.
if id == "" {
return errors.WithStack(ErrForbidden)
}
}
return nil
}
// CanModifyProject determines if claims has the authority to modify the specified project by id.
func CanModifyProject(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id string) error {
err := CanReadProject(ctx, claims, dbConn, id)
if err != nil {
return err
}
// Admin users can update projects they have access to.
if !claims.HasRole(auth.RoleAdmin) {
return errors.WithStack(ErrForbidden)
}
return nil
}
// applyClaimsSelect applies a sub-query to the provided query to enforce ACL based on the claims provided.
// 1. No claims, request is internal, no ACL applied
// 2. All role types can access their user ID
func applyClaimsSelect(ctx context.Context, claims auth.Claims, query *sqlbuilder.SelectBuilder) error {
// Claims are empty, don't apply any ACL
if claims.Audience == "" {
return nil
}
query.Where(query.Equal("account_id", claims.Audience))
return nil
}
// selectQuery constructs a base select query for Project
func selectQuery() *sqlbuilder.SelectBuilder {
query := sqlbuilder.NewSelectBuilder()
query.Select(projectMapColumns)
query.From(projectTableName)
return query
}
// findRequestQuery generates the select query for the given find request.
// TODO: Need to figure out why can't parse the args when appending the where
// to the query.
func findRequestQuery(req ProjectFindRequest) (*sqlbuilder.SelectBuilder, []interface{}) {
query := selectQuery()
if req.Where != nil {
query.Where(query.And(*req.Where))
}
if len(req.Order) > 0 {
query.OrderBy(req.Order...)
}
if req.Limit != nil {
query.Limit(int(*req.Limit))
}
if req.Offset != nil {
query.Offset(int(*req.Offset))
}
return query, req.Args
}
// Find gets all the projects from the database based on the request params.
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req ProjectFindRequest) ([]*Project, error) {
query, args := findRequestQuery(req)
return find(ctx, claims, dbConn, query, args, req.IncludedArchived)
}
// find internal method for getting all the projects from the database using a select query.
func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) ([]*Project, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Find")
defer span.Finish()
query.Select(projectMapColumns)
query.From(projectTableName)
if !includedArchived {
query.Where(query.IsNull("archived_at"))
}
// Check to see if a sub query needs to be applied for the claims.
err := applyClaimsSelect(ctx, claims, query)
if err != nil {
return nil, err
}
queryStr, queryArgs := query.Build()
queryStr = dbConn.Rebind(queryStr)
args = append(args, queryArgs...)
// Fetch all entries from the db.
rows, err := dbConn.QueryContext(ctx, queryStr, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessage(err, "find projects failed")
return nil, err
}
// Iterate over each row.
resp := []*Project{}
for rows.Next() {
u, err := mapRowsToProject(rows)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
return nil, err
}
resp = append(resp, u)
}
return resp, nil
}
// Read gets the specified project from the database.
func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id string, includedArchived bool) (*Project, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Read")
defer span.Finish()
// Filter base select query by id
query := selectQuery()
query.Where(query.Equal("id", id))
res, err := find(ctx, claims, dbConn, query, []interface{}{}, includedArchived)
if err != nil {
return nil, err
} else if res == nil || len(res) == 0 {
err = errors.WithMessagef(ErrNotFound, "project %s not found", id)
return nil, err
}
u := res[0]
return u, nil
}
// Create inserts a new project into the database.
func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req ProjectCreateRequest, now time.Time) (*Project, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Create")
defer span.Finish()
if claims.Audience != "" {
// Admin users can update projects they have access to.
if !claims.HasRole(auth.RoleAdmin) {
return nil, errors.WithStack(ErrForbidden)
}
if req.AccountID != "" {
// Request accountId must match claims.
if req.AccountID != claims.Audience {
return nil, errors.WithStack(ErrForbidden)
}
} else {
// Set the accountId from claims.
req.AccountID = claims.Audience
}
}
v := validator.New()
// Validate the request.
err := v.Struct(req)
if err != nil {
return nil, err
}
// If now empty set it to the current time.
if now.IsZero() {
now = time.Now()
}
// Always store the time as UTC.
now = now.UTC()
// Postgres truncates times to milliseconds when storing. We and do the same
// here so the value we return is consistent with what we store.
now = now.Truncate(time.Millisecond)
m := Project{
ID: uuid.NewRandom().String(),
AccountID: req.AccountID,
Name: req.Name,
Status: ProjectStatus_Active,
CreatedAt: now,
UpdatedAt: now,
}
if req.Status != nil {
m.Status = *req.Status
}
// Build the insert SQL statement.
query := sqlbuilder.NewInsertBuilder()
query.InsertInto(projectTableName)
query.Cols(
"id",
"account_id",
"name",
"status",
"created_at",
"updated_at",
"archived_at",
)
query.Values(
m.ID,
m.AccountID,
m.Name,
m.Status,
m.CreatedAt,
m.UpdatedAt,
m.ArchivedAt,
)
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessage(err, "create project failed")
return nil, err
}
return &m, nil
}
// Update replaces an project in the database.
func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req ProjectUpdateRequest, now time.Time) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Update")
defer span.Finish()
v := validator.New()
// Validate the request.
err := v.Struct(req)
if err != nil {
return err
}
// Ensure the claims can modify the project specified in the request.
err = CanModifyProject(ctx, claims, dbConn, req.ID)
if err != nil {
return err
}
// If now empty set it to the current time.
if now.IsZero() {
now = time.Now()
}
// Always store the time as UTC.
now = now.UTC()
// Postgres truncates times to milliseconds when storing. We and do the same
// here so the value we return is consistent with what we store.
now = now.Truncate(time.Millisecond)
// Build the update SQL statement.
query := sqlbuilder.NewUpdateBuilder()
query.Update(projectTableName)
var fields []string
if req.Name != nil {
fields = append(fields, query.Assign("name", req.Name))
}
if req.Status != nil {
fields = append(fields, query.Assign("status", req.Status))
}
// If there's nothing to update we can quit early.
if len(fields) == 0 {
return nil
}
// Append the updated_at field
fields = append(fields, query.Assign("updated_at", now))
query.Set(fields...)
query.Where(query.Equal("id", req.ID))
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "update project %s failed", req.ID)
return err
}
return nil
}
// Archive soft deleted the project from the database.
func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id string, now time.Time) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Archive")
defer span.Finish()
// Defines the struct to apply validation
req := struct {
ID string `validate:"required,uuid"`
}{}
// Validate the request.
err := validator.New().Struct(req)
if err != nil {
return err
}
// Ensure the claims can modify the project specified in the request.
err = CanModifyProject(ctx, claims, dbConn, req.ID)
if err != nil {
return err
}
// If now empty set it to the current time.
if now.IsZero() {
now = time.Now()
}
// Always store the time as UTC.
now = now.UTC()
// Postgres truncates times to milliseconds when storing. We and do the same
// here so the value we return is consistent with what we store.
now = now.Truncate(time.Millisecond)
// Build the update SQL statement.
query := sqlbuilder.NewUpdateBuilder()
query.Update(projectTableName)
query.Set(
query.Assign("archived_at", now),
)
query.Where(query.Equal("id", req.ID))
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "archive project %s failed", req.ID)
return err
}
return nil
}
// Delete removes an project from the database.
func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id string) error {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Delete")
defer span.Finish()
// Defines the struct to apply validation
req := struct {
ID string `validate:"required,uuid"`
}{}
// Validate the request.
err := validator.New().Struct(req)
if err != nil {
return err
}
// Ensure the claims can modify the project specified in the request.
err = CanModifyProject(ctx, claims, dbConn, req.ID)
if err != nil {
return err
}
// Build the delete SQL statement.
query := sqlbuilder.NewDeleteBuilder()
query.DeleteFrom(projectTableName)
query.Where(query.Equal("id", req.ID))
// Execute the query with the provided context.
sql, args := query.Build()
sql = dbConn.Rebind(sql)
_, err = dbConn.ExecContext(ctx, sql, args...)
if err != nil {
err = errors.Wrapf(err, "query - %s", query.String())
err = errors.WithMessagef(err, "delete project %s failed", req.ID)
return err
}
return nil
}

View File

@ -0,0 +1,102 @@
package project
import (
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth"
"geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/tests"
"github.com/google/go-cmp/cmp"
"github.com/huandu/go-sqlbuilder"
"os"
"testing"
)
var test *tests.Test
// TestMain is the entry point for testing.
func TestMain(m *testing.M) {
os.Exit(testMain(m))
}
func testMain(m *testing.M) int {
test = tests.New()
defer test.TearDown()
return m.Run()
}
// TestFindRequestQuery validates findRequestQuery
func TestFindRequestQuery(t *testing.T) {
where := "field1 = ? or field2 = ?"
var (
limit uint = 12
offset uint = 34
)
req := ProjectFindRequest{
Where: &where,
Args: []interface{}{
"lee brown",
"103 East Main St.",
},
Order: []string{
"id asc",
"created_at desc",
},
Limit: &limit,
Offset: &offset,
}
expected := "SELECT " + projectMapColumns + " FROM " + projectTableName + " WHERE (field1 = ? or field2 = ?) ORDER BY id asc, created_at desc LIMIT 12 OFFSET 34"
res, args := findRequestQuery(req)
if diff := cmp.Diff(res.String(), expected); diff != "" {
t.Fatalf("\t%s\tExpected result query to match. Diff:\n%s", tests.Failed, diff)
}
if diff := cmp.Diff(args, req.Args); diff != "" {
t.Fatalf("\t%s\tExpected result query to match. Diff:\n%s", tests.Failed, diff)
}
}
// TestApplyClaimsSelect applyClaimsSelect
func TestApplyClaimsSelect(t *testing.T) {
var claimTests = []struct {
name string
claims auth.Claims
expectedSql string
error error
}{}
t.Log("Given the need to validate ACLs are enforced by claims to a select query.")
{
for i, tt := range claimTests {
t.Logf("\tTest: %d\tWhen running test: %s", i, tt.name)
{
ctx := tests.Context()
query := selectQuery()
err := applyClaimsSelect(ctx, tt.claims, query)
if err != tt.error {
t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.error)
t.Fatalf("\t%s\tapplyClaimsSelect failed.", tests.Failed)
}
sql, args := query.Build()
// Use mysql flavor so placeholders will get replaced for comparison.
sql, err = sqlbuilder.MySQL.Interpolate(sql, args)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tapplyClaimsSelect failed.", tests.Failed)
}
if diff := cmp.Diff(sql, tt.expectedSql); diff != "" {
t.Fatalf("\t%s\tExpected result query to match. Diff:\n%s", tests.Failed, diff)
}
t.Logf("\t%s\tapplyClaimsSelect ok.", tests.Success)
}
}
}
}

View File

@ -16,7 +16,7 @@ Truss provides code generation to reduce copy/pasting.
go build . go build .
``` ```
### Usage ### Configuration
```bash ```bash
./truss -h ./truss -h
@ -29,5 +29,40 @@ Usage of ./truss
--db_driver string <postgres> --db_driver string <postgres>
--db_timezone string <utc> --db_timezone string <utc>
--db_disabletls bool <false> --db_disabletls bool <false>
```
## Commands:
## dbtable2crud
Used to bootstrap a new business logic package with basic CRUD.
**Usage**
```bash
./truss dbtable2crud -table=projects -file=../../internal/project/models.go -model=Project [-dbtable=TABLE] [-templateDir=DIR] [-projectPath=DIR] [-saveChanges=false]
```
**Example**
1. Define a new database table in `internal/schema/migrations.go`
2. Create a new file for the base model at `internal/projects/models.go`. Only the following struct needs to be included. All the other times will be generated.
```go
// Project represents a workflow.
type Project struct {
ID string `json:"id" validate:"required,uuid"`
AccountID string `json:"account_id" validate:"required,uuid" truss:"api-create"`
Name string `json:"name" validate:"required"`
Status ProjectStatus `json:"status" validate:"omitempty,oneof=active disabled"`
CreatedAt time.Time `json:"created_at" truss:"api-read"`
UpdatedAt time.Time `json:"updated_at" truss:"api-read"`
ArchivedAt pq.NullTime `json:"archived_at" truss:"api-hide"`
}
``` ```
3. Run `dbtable2crud`
```bash
./truss dbtable2crud -table=projects -file=../../internal/project/models.go -model=Project -save=true
```

View File

@ -303,7 +303,57 @@ func updateModelCrudFile(db *sqlx.DB, log *log.Logger, dbName, dbTable, template
continue continue
} }
if crudDoc.HasType(obj.Name, obj.Type) { if obj.Name == "" && (obj.Type == goparse.GoObjectType_Var || obj.Type == goparse.GoObjectType_Const) {
var curDocObj *goparse.GoObject
for _, subObj := range obj.Objects().List() {
for _, do := range crudDoc.Objects().List() {
if do.Name == "" && (do.Type == goparse.GoObjectType_Var || do.Type == goparse.GoObjectType_Const) {
for _, subDocObj := range do.Objects().List() {
if subDocObj.String() == subObj.String() && subObj.Type != goparse.GoObjectType_LineBreak {
curDocObj = do
break
}
}
}
}
}
if curDocObj != nil {
for _, subObj := range obj.Objects().List() {
var hasSubObj bool
for _, subDocObj := range curDocObj.Objects().List() {
if subDocObj.String() == subObj.String() {
hasSubObj = true
break
}
}
if !hasSubObj {
curDocObj.Objects().Add(subObj)
if err != nil {
err = errors.WithMessagef(err, "Failed to add object %s %s for %s", obj.Type, obj.Name, baseModel.Name)
return err
}
}
}
} else {
// Append comments and line breaks before adding the object
for _, c := range objHeaders {
err := crudDoc.Objects().Add(c)
if err != nil {
err = errors.WithMessagef(err, "Failed to add object %s %s for %s", c.Type, c.Name, baseModel.Name)
return err
}
}
err := crudDoc.Objects().Add(obj)
if err != nil {
err = errors.WithMessagef(err, "Failed to add object %s %s for %s", obj.Type, obj.Name, baseModel.Name)
return err
}
}
} else if crudDoc.HasType(obj.Name, obj.Type) {
cur := crudDoc.Objects().Get(obj.Name, obj.Type) cur := crudDoc.Objects().Get(obj.Name, obj.Type)
newObjs := []*goparse.GoObject{} newObjs := []*goparse.GoObject{}

View File

@ -36,7 +36,7 @@ func loadTemplateObjects(log *log.Logger, model *modelDef, templateDir, filename
tempFilePath := filepath.Join(templateDir, filename) tempFilePath := filepath.Join(templateDir, filename)
dat, err := ioutil.ReadFile(tempFilePath) dat, err := ioutil.ReadFile(tempFilePath)
if err != nil { if err != nil {
err = errors.WithMessagef(err, "Failed to read template file %s", tempFilePath) err = errors.WithMessagef(err, "Failed to read template file %s", tempFilePath)
return nil, err return nil, err
} }
@ -46,17 +46,17 @@ func loadTemplateObjects(log *log.Logger, model *modelDef, templateDir, filename
"Concat": func(vals ...string) string { "Concat": func(vals ...string) string {
return strings.Join(vals, "") return strings.Join(vals, "")
}, },
"JoinStrings": func(vals []string, sep string) string { "JoinStrings": func(vals []string, sep string ) string {
return strings.Join(vals, sep) return strings.Join(vals, sep)
}, },
"PrefixAndJoinStrings": func(vals []string, pre, sep string) string { "PrefixAndJoinStrings": func(vals []string, pre, sep string ) string {
l := []string{} l := []string{}
for _, v := range vals { for _, v := range vals {
l = append(l, pre+v) l = append(l, pre + v)
} }
return strings.Join(l, sep) return strings.Join(l, sep)
}, },
"FmtAndJoinStrings": func(vals []string, fmtStr, sep string) string { "FmtAndJoinStrings": func(vals []string, fmtStr, sep string ) string {
l := []string{} l := []string{}
for _, v := range vals { for _, v := range vals {
l = append(l, fmt.Sprintf(fmtStr, v)) l = append(l, fmt.Sprintf(fmtStr, v))
@ -74,35 +74,35 @@ func loadTemplateObjects(log *log.Logger, model *modelDef, templateDir, filename
return "id" return "id"
} }
return FormatCamelLower(name) return FormatCamelLower(name)
}, } ,
"FormatCamelLowerTitle": func(name string) string { "FormatCamelLowerTitle": func(name string) string {
return FormatCamelLowerTitle(name) return FormatCamelLowerTitle(name)
}, } ,
"FormatCamelPluralTitle": func(name string) string { "FormatCamelPluralTitle": func(name string) string {
return FormatCamelPluralTitle(name) return FormatCamelPluralTitle(name)
}, } ,
"FormatCamelPluralTitleLower": func(name string) string { "FormatCamelPluralTitleLower": func(name string) string {
return FormatCamelPluralTitleLower(name) return FormatCamelPluralTitleLower(name)
}, } ,
"FormatCamelPluralCamel": func(name string) string { "FormatCamelPluralCamel": func(name string) string {
return FormatCamelPluralCamel(name) return FormatCamelPluralCamel(name)
}, } ,
"FormatCamelPluralLower": func(name string) string { "FormatCamelPluralLower": func(name string) string {
return FormatCamelPluralLower(name) return FormatCamelPluralLower(name)
}, } ,
"FormatCamelPluralUnderscore": func(name string) string { "FormatCamelPluralUnderscore": func(name string) string {
return FormatCamelPluralUnderscore(name) return FormatCamelPluralUnderscore(name)
}, } ,
"FormatCamelPluralLowerUnderscore": func(name string) string { "FormatCamelPluralLowerUnderscore": func(name string) string {
return FormatCamelPluralLowerUnderscore(name) return FormatCamelPluralLowerUnderscore(name)
}, } ,
"FormatCamelUnderscore": func(name string) string { "FormatCamelUnderscore": func(name string) string {
return FormatCamelUnderscore(name) return FormatCamelUnderscore(name)
}, } ,
"FormatCamelLowerUnderscore": func(name string) string { "FormatCamelLowerUnderscore": func(name string) string {
return FormatCamelLowerUnderscore(name) return FormatCamelLowerUnderscore(name)
}, } ,
"FieldTagHasOption": func(f modelField, tagName, optName string) bool { "FieldTagHasOption": func(f modelField, tagName, optName string ) bool {
if f.Tags == nil { if f.Tags == nil {
return false return false
} }
@ -143,7 +143,7 @@ func loadTemplateObjects(log *log.Logger, model *modelDef, templateDir, filename
} }
} }
} else if !ft.HasOption(newVal) { } else if !ft.HasOption(newVal) {
if ft.Name == "" { if ft.Name == ""{
ft.Name = newVal ft.Name = newVal
} else { } else {
ft.Options = append(ft.Options, newVal) ft.Options = append(ft.Options, newVal)
@ -165,7 +165,7 @@ func loadTemplateObjects(log *log.Logger, model *modelDef, templateDir, filename
// Load the template file using the text/template package. // Load the template file using the text/template package.
tmpl, err := baseTmpl.Parse(string(dat)) tmpl, err := baseTmpl.Parse(string(dat))
if err != nil { if err != nil {
err = errors.WithMessagef(err, "Failed to parse template file %s", tempFilePath) err = errors.WithMessagef(err, "Failed to parse template file %s", tempFilePath)
log.Printf("loadTemplateObjects : %v\n%v", err, string(dat)) log.Printf("loadTemplateObjects : %v\n%v", err, string(dat))
return nil, err return nil, err
} }
@ -189,18 +189,18 @@ func loadTemplateObjects(log *log.Logger, model *modelDef, templateDir, filename
// Executed the defined template with the given data. // Executed the defined template with the given data.
var tpl bytes.Buffer var tpl bytes.Buffer
if err := tmpl.Lookup(tmplName).Execute(&tpl, tmptData); err != nil { if err := tmpl.Lookup(tmplName).Execute(&tpl, tmptData); err != nil {
err = errors.WithMessagef(err, "Failed to execute %s from template file %s", tmplName, tempFilePath) err = errors.WithMessagef(err, "Failed to execute %s from template file %s", tmplName, tempFilePath)
return resp, err return resp, err
} }
// Format the source code to ensure its valid and code to parsed consistently. // Format the source code to ensure its valid and code to parsed consistently.
codeBytes, err := format.Source(tpl.Bytes()) codeBytes, err := format.Source(tpl.Bytes())
if err != nil { if err != nil {
err = errors.WithMessagef(err, "Failed to format source for %s in template file %s", tmplName, filename) err = errors.WithMessagef(err, "Failed to format source for %s in template file %s", tmplName, filename)
dl := []string{} dl := []string{}
for idx, l := range strings.Split(tpl.String(), "\n") { for idx, l := range strings.Split(tpl.String(), "\n") {
dl = append(dl, fmt.Sprintf("%d -> ", idx)+l) dl = append(dl, fmt.Sprintf("%d -> ", idx) + l)
} }
log.Printf("loadTemplateObjects : %v\n%v", err, strings.Join(dl, "\n")) log.Printf("loadTemplateObjects : %v\n%v", err, strings.Join(dl, "\n"))
@ -216,7 +216,7 @@ func loadTemplateObjects(log *log.Logger, model *modelDef, templateDir, filename
// Parse the code lines into a set of objects. // Parse the code lines into a set of objects.
objs, err := goparse.ParseLines(codeLines, 0) objs, err := goparse.ParseLines(codeLines, 0)
if err != nil { if err != nil {
err = errors.WithMessagef(err, "Failed to parse %s in template file %s", tmplName, filename) err = errors.WithMessagef(err, "Failed to parse %s in template file %s", tmplName, filename)
log.Printf("loadTemplateObjects : %v\n%v", err, codeStr) log.Printf("loadTemplateObjects : %v\n%v", err, codeStr)
return resp, err return resp, err
} }
@ -316,7 +316,7 @@ func templateFileOrderedNames(localPath string, names []string) (resp []string,
idx := 0 idx := 0
scanner := bufio.NewScanner(file) scanner := bufio.NewScanner(file)
for scanner.Scan() { for scanner.Scan() {
if !strings.HasPrefix(scanner.Text(), "{{") || !strings.Contains(scanner.Text(), "define ") { if !strings.HasPrefix(scanner.Text(), "{{") || !strings.Contains(scanner.Text(), "define ") {
continue continue
} }

View File

@ -88,6 +88,11 @@ func ParseLines(lines []string, depth int) (objs *GoObjects, err error) {
ld := lineDepth(l) ld := lineDepth(l)
//fmt.Println("l", l)
//fmt.Println("> Depth", ld, "???", depth)
if ld == depth { if ld == depth {
if strings.HasPrefix(ls, "/*") { if strings.HasPrefix(ls, "/*") {
multiLine = true multiLine = true
@ -108,6 +113,11 @@ func ParseLines(lines []string, depth int) (objs *GoObjects, err error) {
} }
} }
//fmt.Println("> multiLine", multiLine)
//fmt.Println("> multiComment", multiComment)
//fmt.Println("> muiliVar", muiliVar)
objLines = append(objLines, l) objLines = append(objLines, l)
if multiComment { if multiComment {
@ -131,6 +141,8 @@ func ParseLines(lines []string, depth int) (objs *GoObjects, err error) {
} }
} }
//fmt.Println(" > objLines", objLines)
obj, err := ParseGoObject(objLines, depth) obj, err := ParseGoObject(objLines, depth)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
@ -197,8 +209,22 @@ func ParseGoObject(lines []string, depth int) (obj *GoObject, err error) {
if strings.HasPrefix(firstStrip, "var") { if strings.HasPrefix(firstStrip, "var") {
obj.Type = GoObjectType_Var obj.Type = GoObjectType_Var
if !strings.HasSuffix(firstStrip, "(") {
if strings.HasPrefix(firstStrip, "var ") {
firstStrip = strings.TrimSpace(strings.Replace(firstStrip, "var ", "", 1))
}
obj.Name = strings.Split(firstStrip, " ")[0]
}
} else if strings.HasPrefix(firstStrip, "const") { } else if strings.HasPrefix(firstStrip, "const") {
obj.Type = GoObjectType_Const obj.Type = GoObjectType_Const
if !strings.HasSuffix(firstStrip, "(") {
if strings.HasPrefix(firstStrip, "const ") {
firstStrip = strings.TrimSpace(strings.Replace(firstStrip, "const ", "", 1))
}
obj.Name = strings.Split(firstStrip, " ")[0]
}
} else if strings.HasPrefix(firstStrip, "func") { } else if strings.HasPrefix(firstStrip, "func") {
obj.Type = GoObjectType_Func obj.Type = GoObjectType_Func

View File

@ -64,7 +64,8 @@ func TestNewDocImports(t *testing.T) {
func TestParseLines1(t *testing.T) { func TestParseLines1(t *testing.T) {
g := gomega.NewGomegaWithT(t) g := gomega.NewGomegaWithT(t)
code := `func testCreate(t *testing.T, ctx context.Context, sess *datamodels.Session) *datamodels.Model { codeTests := []string{
`func testCreate(t *testing.T, ctx context.Context, sess *datamodels.Session) *datamodels.Model {
g := gomega.NewGomegaWithT(t) g := gomega.NewGomegaWithT(t)
obj := datamodels.MockModelNew() obj := datamodels.MockModelNew()
resp, err := ModelCreate(ctx, DB, &obj) resp, err := ModelCreate(ctx, DB, &obj)
@ -76,15 +77,30 @@ func TestParseLines1(t *testing.T) {
g.Expect(resp.Status).Should(gomega.Equal(datamodels.{{ .Name }}Status_Active)) g.Expect(resp.Status).Should(gomega.Equal(datamodels.{{ .Name }}Status_Active))
return resp return resp
} }
` `,
lines := strings.Split(code, "\n") `var (
// ErrNotFound abstracts the postgres not found error.
objs, err := ParseLines(lines, 0) ErrNotFound = errors.New("Entity not found")
if err != nil { // ErrInvalidID occurs when an ID is not in a valid form.
t.Fatalf("got error %v", err) ErrInvalidID = errors.New("ID is not in its proper form")
// ErrForbidden occurs when a user tries to do something that is forbidden to them according to our access control policies.
ErrForbidden = errors.New("Attempted action is not allowed")
)
`,
} }
g.Expect(objs.Lines()).Should(gomega.Equal(lines)) for _, code := range codeTests {
lines := strings.Split(code, "\n")
objs, err := ParseLines(lines, 0)
if err != nil {
t.Fatalf("got error %v", err)
}
g.Expect(objs.Lines()).Should(gomega.Equal(lines))
}
} }
func TestParseLines2(t *testing.T) { func TestParseLines2(t *testing.T) {

View File

@ -124,8 +124,8 @@ func main() {
app.Commands = []cli.Command{ app.Commands = []cli.Command{
{ {
Name: "dbtable2crud", Name: "dbtable2crud",
Aliases: []string{"dbtable2crud"}, Aliases: []string{},
Usage: "dbtable2crud -table=projects -file=../../internal/project/models.go -model=Project -saveChanges=false", Usage: "-table=projects -file=../../internal/project/models.go -model=Project [-dbtable=TABLE] [-templateDir=DIR] [-projectPath=DIR] [-saveChanges=false] ",
Flags: []cli.Flag{ Flags: []cli.Flag{
cli.StringFlag{Name: "dbtable, table"}, cli.StringFlag{Name: "dbtable, table"},
cli.StringFlag{Name: "modelFile, modelfile, file"}, cli.StringFlag{Name: "modelFile, modelfile, file"},

View File

@ -46,15 +46,15 @@ func mapRowsTo{{ $.Model.Name }}(rows *sql.Rows) (*{{ $.Model.Name }}, error) {
return nil, errors.WithStack(err) return nil, errors.WithStack(err)
} }
return &a, nil return &m, nil
} }
{{ end }} {{ end }}
{{ define "ACL"}} {{ define "ACL"}}
{{ $hasAccountId := (StringListHasValue $.Model.ColumnNames "account_id") }} {{ $hasAccountID := (StringListHasValue $.Model.ColumnNames "account_id") }}
// CanRead{{ $.Model.Name }} determines if claims has the authority to access the specified {{ FormatCamelLowerTitle $.Model.Name}} by {{ $.Model.PrimaryColumn }}. // CanRead{{ $.Model.Name }} determines if claims has the authority to access the specified {{ FormatCamelLowerTitle $.Model.Name}} by {{ $.Model.PrimaryColumn }}.
func CanRead{{ $.Model.Name }}(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, {{ FormatCamelLower $.Model.PrimaryField }} string) error { func CanRead{{ $.Model.Name }}(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, {{ FormatCamelLower $.Model.PrimaryField }} string) error {
{{ if $hasAccountId }} {{ if $hasAccountID }}
// If the request has claims from a specific {{ FormatCamelLower $.Model.Name }}, ensure that the claims // If the request has claims from a specific {{ FormatCamelLower $.Model.Name }}, ensure that the claims
// has the correct access to the {{ FormatCamelLower $.Model.Name }}. // has the correct access to the {{ FormatCamelLower $.Model.Name }}.
if claims.Audience != "" { if claims.Audience != "" {
@ -90,7 +90,7 @@ func CanRead{{ $.Model.Name }}(ctx context.Context, claims auth.Claims, dbConn *
// CanModify{{ $.Model.Name }} determines if claims has the authority to modify the specified {{ FormatCamelLowerTitle $.Model.Name}} by {{ $.Model.PrimaryColumn }}. // CanModify{{ $.Model.Name }} determines if claims has the authority to modify the specified {{ FormatCamelLowerTitle $.Model.Name}} by {{ $.Model.PrimaryColumn }}.
func CanModify{{ $.Model.Name }}(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, {{ FormatCamelLower $.Model.PrimaryField }} string) error { func CanModify{{ $.Model.Name }}(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, {{ FormatCamelLower $.Model.PrimaryField }} string) error {
err = CanRead{{ $.Model.Name }}(ctx, claims, dbConn, {{ FormatCamelLower $.Model.PrimaryField }}) err := CanRead{{ $.Model.Name }}(ctx, claims, dbConn, {{ FormatCamelLower $.Model.PrimaryField }})
if err != nil { if err != nil {
return err return err
} }
@ -105,7 +105,7 @@ func CanModify{{ $.Model.Name }}(ctx context.Context, claims auth.Claims, dbConn
// applyClaimsSelect applies a sub-query to the provided query to enforce ACL based on the claims provided. // applyClaimsSelect applies a sub-query to the provided query to enforce ACL based on the claims provided.
// 1. No claims, request is internal, no ACL applied // 1. No claims, request is internal, no ACL applied
{{ if $hasAccountId }} {{ if $hasAccountID }}
// 2. All role types can access their user ID // 2. All role types can access their user ID
{{ end }} {{ end }}
func applyClaimsSelect(ctx context.Context, claims auth.Claims, query *sqlbuilder.SelectBuilder) error { func applyClaimsSelect(ctx context.Context, claims auth.Claims, query *sqlbuilder.SelectBuilder) error {
@ -114,7 +114,7 @@ func applyClaimsSelect(ctx context.Context, claims auth.Claims, query *sqlbuilde
return nil return nil
} }
{{ if $hasAccountId }} {{ if $hasAccountID }}
query.Where(query.Equal("account_id", claims.Audience)) query.Where(query.Equal("account_id", claims.Audience))
{{ end }} {{ end }}
@ -226,9 +226,10 @@ func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, {{ FormatCam
} }
{{ end }} {{ end }}
{{ define "Create"}} {{ define "Create"}}
{{ $hasAccountId := (StringListHasValue $.Model.ColumnNames "account_id") }} {{ $hasAccountID := (StringListHasValue $.Model.ColumnNames "account_id") }}
{{ $reqName := (Concat $.Model.Name "CreateRequest") }} {{ $reqName := (Concat $.Model.Name "CreateRequest") }}
{{ $createFields := (index $.StructFields $reqName) }} {{ $createFields := (index $.StructFields $reqName) }}
{{ $reqHasAccountID := false }}{{ $reqAccountID := (index $createFields "AccountID") }}{{ if $reqAccountID }}{{ $reqHasAccountID = true }}{{ end }}
// Create inserts a new {{ FormatCamelLowerTitle $.Model.Name }} into the database. // Create inserts a new {{ FormatCamelLowerTitle $.Model.Name }} into the database.
func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req {{ $reqName }}, now time.Time) (*{{ $.Model.Name }}, error) { func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req {{ $reqName }}, now time.Time) (*{{ $.Model.Name }}, error) {
span, ctx := tracer.StartSpanFromContext(ctx, "internal.{{ FormatCamelLowerUnderscore $.Model.Name }}.Create") span, ctx := tracer.StartSpanFromContext(ctx, "internal.{{ FormatCamelLowerUnderscore $.Model.Name }}.Create")
@ -237,18 +238,18 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req {{ $re
if claims.Audience != "" { if claims.Audience != "" {
// Admin users can update {{ FormatCamelPluralTitleLower $.Model.Name }} they have access to. // Admin users can update {{ FormatCamelPluralTitleLower $.Model.Name }} they have access to.
if !claims.HasRole(auth.RoleAdmin) { if !claims.HasRole(auth.RoleAdmin) {
return errors.WithStack(ErrForbidden) return nil, errors.WithStack(ErrForbidden)
} }
{{ if $hasAccountId }} {{ if $reqHasAccountID }}
if req.AccountId != "" { if req.AccountID != "" {
// Request accountId must match claims. // Request accountId must match claims.
if req.AccountId != claims.Audience { if req.AccountID != claims.Audience {
return errors.WithStack(ErrForbidden) return nil, errors.WithStack(ErrForbidden)
} }
} else { } else {
// Set the accountId from claims. // Set the accountId from claims.
req.AccountId = claims.Audience req.AccountID = claims.Audience
} }
{{ end }} {{ end }}
} }
@ -256,7 +257,7 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req {{ $re
v := validator.New() v := validator.New()
// Validate the request. // Validate the request.
err = v.Struct(req) err := v.Struct(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -281,6 +282,13 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req {{ $re
{{ end }}{{ end }} {{ end }}{{ end }}
} }
{{ if and (not $reqHasAccountID) ($hasAccountID) }}
// Set the accountId from claims.
if claims.Audience != "" && m.AccountID == "" {
req.AccountID = claims.Audience
}
{{ end }}
{{ range $fk, $f := $createFields }}{{ $required := (FieldTagHasOption $f "validate" "required") }}{{ if not $required }} {{ range $fk, $f := $createFields }}{{ $required := (FieldTagHasOption $f "validate" "required") }}{{ if not $required }}
if req.{{ $f.FieldName }} != nil { if req.{{ $f.FieldName }} != nil {
{{ if eq $f.FieldType "sql.NullString" }} {{ if eq $f.FieldType "sql.NullString" }}
@ -315,7 +323,7 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req {{ $re
return nil, err return nil, err
} }
return &a, nil return &m, nil
} }
{{ end }} {{ end }}
{{ define "Update"}} {{ define "Update"}}

View File

@ -1,19 +1,11 @@
{{ define "imports"}} {{ define "imports"}}
import ( import (
"github.com/lib/pq"
"math/rand"
"os"
"strings"
"testing"
"time"
"{{ $.GoSrcPath }}/internal/platform/auth" "{{ $.GoSrcPath }}/internal/platform/auth"
"{{ $.GoSrcPath }}/internal/platform/tests" "{{ $.GoSrcPath }}/internal/platform/tests"
"github.com/dgrijalva/jwt-go"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/huandu/go-sqlbuilder" "github.com/huandu/go-sqlbuilder"
"github.com/pborman/uuid" "os"
"github.com/pkg/errors" "testing"
) )
{{ end }} {{ end }}
{{ define "Globals"}} {{ define "Globals"}}
@ -33,7 +25,7 @@ func testMain(m *testing.M) int {
{{ define "TestFindRequestQuery"}} {{ define "TestFindRequestQuery"}}
// TestFindRequestQuery validates findRequestQuery // TestFindRequestQuery validates findRequestQuery
func TestFindRequestQuery(t *testing.T) { func TestFindRequestQuery(t *testing.T) {
where := "name = ? or address1 = ?" where := "field1 = ? or field2 = ?"
var ( var (
limit uint = 12 limit uint = 12
offset uint = 34 offset uint = 34
@ -52,7 +44,7 @@ func TestFindRequestQuery(t *testing.T) {
Limit: &limit, Limit: &limit,
Offset: &offset, Offset: &offset,
} }
expected := "SELECT " + accountMapColumns + " FROM " + accountTableName + " WHERE (name = ? or address1 = ?) ORDER BY id asc, created_at desc LIMIT 12 OFFSET 34" expected := "SELECT " + {{ FormatCamelLower $.Model.Name }}MapColumns + " FROM " + {{ FormatCamelLower $.Model.Name }}TableName + " WHERE (field1 = ? or field2 = ?) ORDER BY id asc, created_at desc LIMIT 12 OFFSET 34"
res, args := findRequestQuery(req) res, args := findRequestQuery(req)
@ -63,4 +55,77 @@ func TestFindRequestQuery(t *testing.T) {
t.Fatalf("\t%s\tExpected result query to match. Diff:\n%s", tests.Failed, diff) t.Fatalf("\t%s\tExpected result query to match. Diff:\n%s", tests.Failed, diff)
} }
} }
{{ end }}
{{ define "TestApplyClaimsSelect"}}
// TestApplyClaimsSelect applyClaimsSelect
func TestApplyClaimsSelect(t *testing.T) {
var claimTests = []struct {
name string
claims auth.Claims
expectedSql string
error error
}{
{"EmptyClaims",
auth.Claims{},
"SELECT " + {{ FormatCamelLower $.Model.Name }}MapColumns + " FROM " + {{ FormatCamelLower $.Model.Name }}TableName,
nil,
},
{"RoleAccount",
auth.Claims{
Roles: []string{auth.RoleAdmin},
StandardClaims: jwt.StandardClaims{
Subject: "user1",
Audience: "acc1",
},
},
"SELECT " + {{ FormatCamelLower $.Model.Name }}MapColumns + " FROM " + {{ FormatCamelLower $.Model.Name }}TableName + " WHERE account_id = 'acc1'",
nil,
},
{"RoleAdmin",
auth.Claims{
Roles: []string{auth.RoleAdmin},
StandardClaims: jwt.StandardClaims{
Subject: "user1",
Audience: "acc1",
},
},
"SELECT " + {{ FormatCamelLower $.Model.Name }}MapColumns + " FROM " + {{ FormatCamelLower $.Model.Name }}TableName + " WHERE account_id = 'acc1'",
nil,
},
}
t.Log("Given the need to validate ACLs are enforced by claims to a select query.")
{
for i, tt := range claimTests {
t.Logf("\tTest: %d\tWhen running test: %s", i, tt.name)
{
ctx := tests.Context()
query := selectQuery()
err := applyClaimsSelect(ctx, tt.claims, query)
if err != tt.error {
t.Logf("\t\tGot : %+v", err)
t.Logf("\t\tWant: %+v", tt.error)
t.Fatalf("\t%s\tapplyClaimsSelect failed.", tests.Failed)
}
sql, args := query.Build()
// Use mysql flavor so placeholders will get replaced for comparison.
sql, err = sqlbuilder.MySQL.Interpolate(sql, args)
if err != nil {
t.Log("\t\tGot :", err)
t.Fatalf("\t%s\tapplyClaimsSelect failed.", tests.Failed)
}
if diff := cmp.Diff(sql, tt.expectedSql); diff != "" {
t.Fatalf("\t%s\tExpected result query to match. Diff:\n%s", tests.Failed, diff)
}
t.Logf("\t%s\tapplyClaimsSelect ok.", tests.Success)
}
}
}
}
{{ end }} {{ end }}

View File

@ -38,7 +38,7 @@ type {{ $f.FieldType }} string
const ( const (
{{ range $evk, $ev := $f.DbColumn.EnumValues }} {{ range $evk, $ev := $f.DbColumn.EnumValues }}
// {{ $f.FieldType }}_{{ FormatCamel $ev }} defines the {{ $f.ColumnName }} of {{ $ev }} for {{ FormatCamelLowerTitle $.Model.Name }}. // {{ $f.FieldType }}_{{ FormatCamel $ev }} defines the {{ $f.ColumnName }} of {{ $ev }} for {{ FormatCamelLowerTitle $.Model.Name }}.
{{ $f.FieldType }}_{{ FormatCamel $ev }}{{ $f.FieldType }} = "{{ $ev }}" {{ $f.FieldType }}_{{ FormatCamel $ev }} {{ $f.FieldType }} = "{{ $ev }}"
{{ end }} {{ end }}
) )