diff --git a/example-project/internal/account/account_test.go b/example-project/internal/account/account_test.go index eb04452..9b98ae9 100644 --- a/example-project/internal/account/account_test.go +++ b/example-project/internal/account/account_test.go @@ -63,8 +63,8 @@ func TestFindRequestQuery(t *testing.T) { } } -// TestApplyClaimsSelectvalidates applyClaimsSelect -func TestApplyClaimsSelectvalidates(t *testing.T) { +// TestApplyClaimsSelect validates applyClaimsSelect +func TestApplyClaimsSelect(t *testing.T) { var claimTests = []struct { name string claims auth.Claims diff --git a/example-project/internal/project/models.go b/example-project/internal/project/models.go index db501c3..e11ec8e 100644 --- a/example-project/internal/project/models.go +++ b/example-project/internal/project/models.go @@ -2,45 +2,42 @@ package project import ( "database/sql/driver" - "time" - "github.com/lib/pq" "github.com/pkg/errors" "gopkg.in/go-playground/validator.v9" + "time" ) // Project represents a workflow. type Project struct { - ID string `json:"id" validate:"required,uuid"` - AccountID string `json:"account_id" validate:"required,uuid" truss:"api-create"` - Name string `json:"name" validate:"required"` - Status ProjectStatus `json:"status" validate:"omitempty,oneof=active disabled"` - CreatedAt time.Time `json:"created_at" truss:"api-read"` - UpdatedAt time.Time `json:"updated_at" truss:"api-read"` - ArchivedAt pq.NullTime `json:"archived_at" truss:"api-hide"` + ID string `json:"id" validate:"required,uuid"` + AccountID string `json:"account_id" validate:"required,uuid" truss:"api-create"` + Name string `json:"name" validate:"required"` + Status ProjectStatus `json:"status" validate:"omitempty,oneof=active disabled"` + CreatedAt time.Time `json:"created_at" truss:"api-read"` + UpdatedAt time.Time `json:"updated_at" truss:"api-read"` + ArchivedAt pq.NullTime `json:"archived_at" truss:"api-hide"` } -// CreateProjectRequest contains information needed to create a new Project. +// ProjectCreateRequest contains information needed to create a new Project. type ProjectCreateRequest struct { - AccountID string `json:"account_id" validate:"required,uuid"` - Name string `json:"name" validate:"required"` - Status *ProjectStatus `json:"status" validate:"omitempty,oneof=active disabled"` + AccountID string `json:"account_id" validate:"required,uuid"` + Name string `json:"name" validate:"required"` + Status *ProjectStatus `json:"status" validate:"omitempty,oneof=active disabled"` } -// UpdateProjectRequest defines what information may be provided to modify an existing +// ProjectUpdateRequest defines what information may be provided to modify an existing // Project. All fields are optional so clients can send just the fields they want // changed. It uses pointer fields so we can differentiate between a field that -// was not provided and a field that was provided as explicitly blank. Normally -// we do not want to use pointers to basic types but we make exceptions around -// marshalling/unmarshalling. +// was not provided and a field that was provided as explicitly blank. type ProjectUpdateRequest struct { - ID string `validate:"required,uuid"` - Name *string `json:"name" validate:"omitempty"` - Status *ProjectStatus `json:"status" validate:"omitempty,oneof=active pending disabled"` + ID string `json:"id" validate:"required,uuid"` + Name *string `json:"name" validate:"omitempty"` + Status *ProjectStatus `json:"status" validate:"omitempty,oneof=active disabled"` } // ProjectFindRequest defines the possible options to search for projects. By default -// archived projects will be excluded from response. +// archived project will be excluded from response. type ProjectFindRequest struct { Where *string Args []interface{} @@ -50,20 +47,21 @@ type ProjectFindRequest struct { IncludedArchived bool } -// ProjectStatus represents the status of an project. +// ProjectStatus represents the status of project. type ProjectStatus string -// ProjectStatus values define the status field of a user project. +// ProjectStatus values define the status field of project. const ( - // ProjectStatus_Active defines the state when a user can access an project. + + // ProjectStatus_Active defines the status of active for project. ProjectStatus_Active ProjectStatus = "active" - // ProjectStatus_Disabled defines the state when a user has been disabled from - // accessing an project. + // ProjectStatus_Disabled defines the status of disabled for project. ProjectStatus_Disabled ProjectStatus = "disabled" ) // ProjectStatus_Values provides list of valid ProjectStatus values. var ProjectStatus_Values = []ProjectStatus{ + ProjectStatus_Active, ProjectStatus_Disabled, } @@ -74,6 +72,7 @@ func (s *ProjectStatus) Scan(value interface{}) error { if !ok { return errors.New("Scan source is not []byte") } + *s = ProjectStatus(string(asBytes)) return nil } @@ -81,7 +80,6 @@ func (s *ProjectStatus) Scan(value interface{}) error { // Value converts the ProjectStatus value to be stored in the database. func (s ProjectStatus) Value() (driver.Value, error) { v := validator.New() - errs := v.Var(s, "required,oneof=active disabled") if errs != nil { return nil, errs diff --git a/example-project/internal/project/project.go b/example-project/internal/project/project.go new file mode 100644 index 0000000..809f5d8 --- /dev/null +++ b/example-project/internal/project/project.go @@ -0,0 +1,447 @@ +package project + +import ( + "context" + "database/sql" + "geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth" + "github.com/huandu/go-sqlbuilder" + "github.com/jmoiron/sqlx" + "github.com/pborman/uuid" + "github.com/pkg/errors" + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer" + "gopkg.in/go-playground/validator.v9" + "time" +) + +const ( + // The database table for Project + projectTableName = "projects" +) + +var ( + // ErrNotFound abstracts the postgres not found error. + ErrNotFound = errors.New("Entity not found") + // ErrForbidden occurs when a user tries to do something that is forbidden to them according to our access control policies. + ErrForbidden = errors.New("Attempted action is not allowed") + // ErrInvalidID occurs when an ID is not in a valid form. + ErrInvalidID = errors.New("ID is not in its proper form") +) + +// projectMapColumns is the list of columns needed for mapRowsToProject +var projectMapColumns = "id,account_id,name,status,created_at,updated_at,archived_at" + +// mapRowsToProject takes the SQL rows and maps it to the Project struct +// with the columns defined by projectMapColumns +func mapRowsToProject(rows *sql.Rows) (*Project, error) { + var ( + m Project + err error + ) + + err = rows.Scan(&m.ID, &m.AccountID, &m.Name, &m.Status, &m.CreatedAt, &m.UpdatedAt, &m.ArchivedAt) + if err != nil { + return nil, errors.WithStack(err) + } + + return &m, nil +} + +// CanReadProject determines if claims has the authority to access the specified project by id. +func CanReadProject(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id string) error { + + // If the request has claims from a specific project, ensure that the claims + // has the correct access to the project. + if claims.Audience != "" { + // select id from projects where account_id = [accountID] + query := sqlbuilder.NewSelectBuilder().Select("id").From(projectTableName) + query.Where(query.And( + query.Equal("account_id", claims.Audience), + query.Equal("ID", id), + )) + + queryStr, args := query.Build() + queryStr = dbConn.Rebind(queryStr) + var id string + err := dbConn.QueryRowContext(ctx, queryStr, args...).Scan(&id) + if err != nil && err != sql.ErrNoRows { + err = errors.Wrapf(err, "query - %s", query.String()) + return err + } + + // When there is no id returned, then the current claim user does not have access + // to the specified project. + if id == "" { + return errors.WithStack(ErrForbidden) + } + + } + + return nil +} + +// CanModifyProject determines if claims has the authority to modify the specified project by id. +func CanModifyProject(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id string) error { + err := CanReadProject(ctx, claims, dbConn, id) + if err != nil { + return err + } + + // Admin users can update projects they have access to. + if !claims.HasRole(auth.RoleAdmin) { + return errors.WithStack(ErrForbidden) + } + + return nil +} + +// applyClaimsSelect applies a sub-query to the provided query to enforce ACL based on the claims provided. +// 1. No claims, request is internal, no ACL applied + +// 2. All role types can access their user ID + +func applyClaimsSelect(ctx context.Context, claims auth.Claims, query *sqlbuilder.SelectBuilder) error { + // Claims are empty, don't apply any ACL + if claims.Audience == "" { + return nil + } + + query.Where(query.Equal("account_id", claims.Audience)) + return nil +} + +// selectQuery constructs a base select query for Project +func selectQuery() *sqlbuilder.SelectBuilder { + query := sqlbuilder.NewSelectBuilder() + query.Select(projectMapColumns) + query.From(projectTableName) + return query +} + +// findRequestQuery generates the select query for the given find request. +// TODO: Need to figure out why can't parse the args when appending the where +// to the query. +func findRequestQuery(req ProjectFindRequest) (*sqlbuilder.SelectBuilder, []interface{}) { + query := selectQuery() + if req.Where != nil { + query.Where(query.And(*req.Where)) + } + + if len(req.Order) > 0 { + query.OrderBy(req.Order...) + } + + if req.Limit != nil { + query.Limit(int(*req.Limit)) + } + + if req.Offset != nil { + query.Offset(int(*req.Offset)) + } + + return query, req.Args +} + +// Find gets all the projects from the database based on the request params. +func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req ProjectFindRequest) ([]*Project, error) { + query, args := findRequestQuery(req) + return find(ctx, claims, dbConn, query, args, req.IncludedArchived) +} + +// find internal method for getting all the projects from the database using a select query. +func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}, includedArchived bool) ([]*Project, error) { + span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Find") + defer span.Finish() + query.Select(projectMapColumns) + query.From(projectTableName) + if !includedArchived { + query.Where(query.IsNull("archived_at")) + } + + // Check to see if a sub query needs to be applied for the claims. + err := applyClaimsSelect(ctx, claims, query) + if err != nil { + return nil, err + } + + queryStr, queryArgs := query.Build() + queryStr = dbConn.Rebind(queryStr) + args = append(args, queryArgs...) + // Fetch all entries from the db. + rows, err := dbConn.QueryContext(ctx, queryStr, args...) + if err != nil { + err = errors.Wrapf(err, "query - %s", query.String()) + err = errors.WithMessage(err, "find projects failed") + return nil, err + } + + // Iterate over each row. + resp := []*Project{} + for rows.Next() { + u, err := mapRowsToProject(rows) + if err != nil { + err = errors.Wrapf(err, "query - %s", query.String()) + return nil, err + } + + resp = append(resp, u) + } + + return resp, nil +} + +// Read gets the specified project from the database. +func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id string, includedArchived bool) (*Project, error) { + span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Read") + defer span.Finish() + // Filter base select query by id + query := selectQuery() + query.Where(query.Equal("id", id)) + res, err := find(ctx, claims, dbConn, query, []interface{}{}, includedArchived) + if err != nil { + return nil, err + } else if res == nil || len(res) == 0 { + err = errors.WithMessagef(ErrNotFound, "project %s not found", id) + return nil, err + } + + u := res[0] + return u, nil +} + +// Create inserts a new project into the database. +func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req ProjectCreateRequest, now time.Time) (*Project, error) { + span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Create") + defer span.Finish() + if claims.Audience != "" { + // Admin users can update projects they have access to. + if !claims.HasRole(auth.RoleAdmin) { + return nil, errors.WithStack(ErrForbidden) + } + + if req.AccountID != "" { + // Request accountId must match claims. + if req.AccountID != claims.Audience { + return nil, errors.WithStack(ErrForbidden) + } + + } else { + // Set the accountId from claims. + req.AccountID = claims.Audience + } + + } + + v := validator.New() + // Validate the request. + err := v.Struct(req) + if err != nil { + return nil, err + } + + // If now empty set it to the current time. + if now.IsZero() { + now = time.Now() + } + + // Always store the time as UTC. + now = now.UTC() + // Postgres truncates times to milliseconds when storing. We and do the same + // here so the value we return is consistent with what we store. + now = now.Truncate(time.Millisecond) + m := Project{ + ID: uuid.NewRandom().String(), + AccountID: req.AccountID, + Name: req.Name, + Status: ProjectStatus_Active, + CreatedAt: now, + UpdatedAt: now, + } + + if req.Status != nil { + m.Status = *req.Status + } + + // Build the insert SQL statement. + query := sqlbuilder.NewInsertBuilder() + query.InsertInto(projectTableName) + query.Cols( + "id", + "account_id", + "name", + "status", + "created_at", + "updated_at", + "archived_at", + ) + + query.Values( + m.ID, + m.AccountID, + m.Name, + m.Status, + m.CreatedAt, + m.UpdatedAt, + m.ArchivedAt, + ) + + // Execute the query with the provided context. + sql, args := query.Build() + sql = dbConn.Rebind(sql) + _, err = dbConn.ExecContext(ctx, sql, args...) + if err != nil { + err = errors.Wrapf(err, "query - %s", query.String()) + err = errors.WithMessage(err, "create project failed") + return nil, err + } + + return &m, nil +} + +// Update replaces an project in the database. +func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req ProjectUpdateRequest, now time.Time) error { + span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Update") + defer span.Finish() + v := validator.New() + // Validate the request. + err := v.Struct(req) + if err != nil { + return err + } + + // Ensure the claims can modify the project specified in the request. + err = CanModifyProject(ctx, claims, dbConn, req.ID) + if err != nil { + return err + } + + // If now empty set it to the current time. + if now.IsZero() { + now = time.Now() + } + + // Always store the time as UTC. + now = now.UTC() + // Postgres truncates times to milliseconds when storing. We and do the same + // here so the value we return is consistent with what we store. + now = now.Truncate(time.Millisecond) + // Build the update SQL statement. + query := sqlbuilder.NewUpdateBuilder() + query.Update(projectTableName) + var fields []string + if req.Name != nil { + fields = append(fields, query.Assign("name", req.Name)) + } + + if req.Status != nil { + fields = append(fields, query.Assign("status", req.Status)) + } + + // If there's nothing to update we can quit early. + if len(fields) == 0 { + return nil + } + + // Append the updated_at field + fields = append(fields, query.Assign("updated_at", now)) + query.Set(fields...) + query.Where(query.Equal("id", req.ID)) + // Execute the query with the provided context. + sql, args := query.Build() + sql = dbConn.Rebind(sql) + _, err = dbConn.ExecContext(ctx, sql, args...) + if err != nil { + err = errors.Wrapf(err, "query - %s", query.String()) + err = errors.WithMessagef(err, "update project %s failed", req.ID) + return err + } + + return nil +} + +// Archive soft deleted the project from the database. +func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id string, now time.Time) error { + span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Archive") + defer span.Finish() + // Defines the struct to apply validation + req := struct { + ID string `validate:"required,uuid"` + }{} + // Validate the request. + err := validator.New().Struct(req) + if err != nil { + return err + } + + // Ensure the claims can modify the project specified in the request. + err = CanModifyProject(ctx, claims, dbConn, req.ID) + if err != nil { + return err + } + + // If now empty set it to the current time. + if now.IsZero() { + now = time.Now() + } + + // Always store the time as UTC. + now = now.UTC() + // Postgres truncates times to milliseconds when storing. We and do the same + // here so the value we return is consistent with what we store. + now = now.Truncate(time.Millisecond) + // Build the update SQL statement. + query := sqlbuilder.NewUpdateBuilder() + query.Update(projectTableName) + query.Set( + query.Assign("archived_at", now), + ) + + query.Where(query.Equal("id", req.ID)) + // Execute the query with the provided context. + sql, args := query.Build() + sql = dbConn.Rebind(sql) + _, err = dbConn.ExecContext(ctx, sql, args...) + if err != nil { + err = errors.Wrapf(err, "query - %s", query.String()) + err = errors.WithMessagef(err, "archive project %s failed", req.ID) + return err + } + + return nil +} + +// Delete removes an project from the database. +func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, id string) error { + span, ctx := tracer.StartSpanFromContext(ctx, "internal.project.Delete") + defer span.Finish() + // Defines the struct to apply validation + req := struct { + ID string `validate:"required,uuid"` + }{} + // Validate the request. + err := validator.New().Struct(req) + if err != nil { + return err + } + + // Ensure the claims can modify the project specified in the request. + err = CanModifyProject(ctx, claims, dbConn, req.ID) + if err != nil { + return err + } + + // Build the delete SQL statement. + query := sqlbuilder.NewDeleteBuilder() + query.DeleteFrom(projectTableName) + query.Where(query.Equal("id", req.ID)) + // Execute the query with the provided context. + sql, args := query.Build() + sql = dbConn.Rebind(sql) + _, err = dbConn.ExecContext(ctx, sql, args...) + if err != nil { + err = errors.Wrapf(err, "query - %s", query.String()) + err = errors.WithMessagef(err, "delete project %s failed", req.ID) + return err + } + + return nil +} diff --git a/example-project/internal/project/project_test.go b/example-project/internal/project/project_test.go new file mode 100644 index 0000000..af96da4 --- /dev/null +++ b/example-project/internal/project/project_test.go @@ -0,0 +1,102 @@ +package project + +import ( + "geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/auth" + "geeks-accelerator/oss/saas-starter-kit/example-project/internal/platform/tests" + "github.com/google/go-cmp/cmp" + "github.com/huandu/go-sqlbuilder" + "os" + "testing" +) + +var test *tests.Test + +// TestMain is the entry point for testing. +func TestMain(m *testing.M) { + os.Exit(testMain(m)) +} + +func testMain(m *testing.M) int { + test = tests.New() + defer test.TearDown() + return m.Run() +} + +// TestFindRequestQuery validates findRequestQuery +func TestFindRequestQuery(t *testing.T) { + where := "field1 = ? or field2 = ?" + var ( + limit uint = 12 + offset uint = 34 + ) + + req := ProjectFindRequest{ + Where: &where, + Args: []interface{}{ + "lee brown", + "103 East Main St.", + }, + + Order: []string{ + "id asc", + "created_at desc", + }, + + Limit: &limit, + Offset: &offset, + } + + expected := "SELECT " + projectMapColumns + " FROM " + projectTableName + " WHERE (field1 = ? or field2 = ?) ORDER BY id asc, created_at desc LIMIT 12 OFFSET 34" + res, args := findRequestQuery(req) + if diff := cmp.Diff(res.String(), expected); diff != "" { + t.Fatalf("\t%s\tExpected result query to match. Diff:\n%s", tests.Failed, diff) + } + + if diff := cmp.Diff(args, req.Args); diff != "" { + t.Fatalf("\t%s\tExpected result query to match. Diff:\n%s", tests.Failed, diff) + } + +} + +// TestApplyClaimsSelect applyClaimsSelect +func TestApplyClaimsSelect(t *testing.T) { + var claimTests = []struct { + name string + claims auth.Claims + expectedSql string + error error + }{} + t.Log("Given the need to validate ACLs are enforced by claims to a select query.") + { + for i, tt := range claimTests { + t.Logf("\tTest: %d\tWhen running test: %s", i, tt.name) + { + ctx := tests.Context() + query := selectQuery() + err := applyClaimsSelect(ctx, tt.claims, query) + if err != tt.error { + t.Logf("\t\tGot : %+v", err) + t.Logf("\t\tWant: %+v", tt.error) + t.Fatalf("\t%s\tapplyClaimsSelect failed.", tests.Failed) + } + + sql, args := query.Build() + // Use mysql flavor so placeholders will get replaced for comparison. + sql, err = sqlbuilder.MySQL.Interpolate(sql, args) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tapplyClaimsSelect failed.", tests.Failed) + } + + if diff := cmp.Diff(sql, tt.expectedSql); diff != "" { + t.Fatalf("\t%s\tExpected result query to match. Diff:\n%s", tests.Failed, diff) + } + + t.Logf("\t%s\tapplyClaimsSelect ok.", tests.Success) + } + + } + + } + +} diff --git a/example-project/tools/truss/README.md b/example-project/tools/truss/README.md index 3886e3b..46c5522 100644 --- a/example-project/tools/truss/README.md +++ b/example-project/tools/truss/README.md @@ -16,7 +16,7 @@ Truss provides code generation to reduce copy/pasting. go build . ``` -### Usage +### Configuration ```bash ./truss -h @@ -29,5 +29,40 @@ Usage of ./truss --db_driver string --db_timezone string --db_disabletls bool +``` + +## Commands: + +## dbtable2crud + +Used to bootstrap a new business logic package with basic CRUD. + +**Usage** +```bash +./truss dbtable2crud -table=projects -file=../../internal/project/models.go -model=Project [-dbtable=TABLE] [-templateDir=DIR] [-projectPath=DIR] [-saveChanges=false] +``` + +**Example** +1. Define a new database table in `internal/schema/migrations.go` + + +2. Create a new file for the base model at `internal/projects/models.go`. Only the following struct needs to be included. All the other times will be generated. +```go +// Project represents a workflow. +type Project struct { + ID string `json:"id" validate:"required,uuid"` + AccountID string `json:"account_id" validate:"required,uuid" truss:"api-create"` + Name string `json:"name" validate:"required"` + Status ProjectStatus `json:"status" validate:"omitempty,oneof=active disabled"` + CreatedAt time.Time `json:"created_at" truss:"api-read"` + UpdatedAt time.Time `json:"updated_at" truss:"api-read"` + ArchivedAt pq.NullTime `json:"archived_at" truss:"api-hide"` +} ``` +3. Run `dbtable2crud` +```bash +./truss dbtable2crud -table=projects -file=../../internal/project/models.go -model=Project -save=true +``` + + diff --git a/example-project/tools/truss/cmd/dbtable2crud/dbtable2crud.go b/example-project/tools/truss/cmd/dbtable2crud/dbtable2crud.go index 2c6da6f..0e1adcd 100644 --- a/example-project/tools/truss/cmd/dbtable2crud/dbtable2crud.go +++ b/example-project/tools/truss/cmd/dbtable2crud/dbtable2crud.go @@ -303,7 +303,57 @@ func updateModelCrudFile(db *sqlx.DB, log *log.Logger, dbName, dbTable, template continue } - if crudDoc.HasType(obj.Name, obj.Type) { + if obj.Name == "" && (obj.Type == goparse.GoObjectType_Var || obj.Type == goparse.GoObjectType_Const) { + var curDocObj *goparse.GoObject + for _, subObj := range obj.Objects().List() { + for _, do := range crudDoc.Objects().List() { + if do.Name == "" && (do.Type == goparse.GoObjectType_Var || do.Type == goparse.GoObjectType_Const) { + for _, subDocObj := range do.Objects().List() { + if subDocObj.String() == subObj.String() && subObj.Type != goparse.GoObjectType_LineBreak { + curDocObj = do + break + } + + } + } + } + } + + if curDocObj != nil { + for _, subObj := range obj.Objects().List() { + var hasSubObj bool + for _, subDocObj := range curDocObj.Objects().List() { + if subDocObj.String() == subObj.String() { + hasSubObj = true + break + } + } + + if !hasSubObj { + curDocObj.Objects().Add(subObj) + if err != nil { + err = errors.WithMessagef(err, "Failed to add object %s %s for %s", obj.Type, obj.Name, baseModel.Name) + return err + } + } + } + } else { + // Append comments and line breaks before adding the object + for _, c := range objHeaders { + err := crudDoc.Objects().Add(c) + if err != nil { + err = errors.WithMessagef(err, "Failed to add object %s %s for %s", c.Type, c.Name, baseModel.Name) + return err + } + } + + err := crudDoc.Objects().Add(obj) + if err != nil { + err = errors.WithMessagef(err, "Failed to add object %s %s for %s", obj.Type, obj.Name, baseModel.Name) + return err + } + } + } else if crudDoc.HasType(obj.Name, obj.Type) { cur := crudDoc.Objects().Get(obj.Name, obj.Type) newObjs := []*goparse.GoObject{} diff --git a/example-project/tools/truss/cmd/dbtable2crud/templates.go b/example-project/tools/truss/cmd/dbtable2crud/templates.go index 611c9fe..89720f6 100644 --- a/example-project/tools/truss/cmd/dbtable2crud/templates.go +++ b/example-project/tools/truss/cmd/dbtable2crud/templates.go @@ -36,7 +36,7 @@ func loadTemplateObjects(log *log.Logger, model *modelDef, templateDir, filename tempFilePath := filepath.Join(templateDir, filename) dat, err := ioutil.ReadFile(tempFilePath) if err != nil { - err = errors.WithMessagef(err, "Failed to read template file %s", tempFilePath) + err = errors.WithMessagef(err, "Failed to read template file %s", tempFilePath) return nil, err } @@ -46,17 +46,17 @@ func loadTemplateObjects(log *log.Logger, model *modelDef, templateDir, filename "Concat": func(vals ...string) string { return strings.Join(vals, "") }, - "JoinStrings": func(vals []string, sep string) string { + "JoinStrings": func(vals []string, sep string ) string { return strings.Join(vals, sep) }, - "PrefixAndJoinStrings": func(vals []string, pre, sep string) string { + "PrefixAndJoinStrings": func(vals []string, pre, sep string ) string { l := []string{} for _, v := range vals { - l = append(l, pre+v) + l = append(l, pre + v) } return strings.Join(l, sep) }, - "FmtAndJoinStrings": func(vals []string, fmtStr, sep string) string { + "FmtAndJoinStrings": func(vals []string, fmtStr, sep string ) string { l := []string{} for _, v := range vals { l = append(l, fmt.Sprintf(fmtStr, v)) @@ -74,35 +74,35 @@ func loadTemplateObjects(log *log.Logger, model *modelDef, templateDir, filename return "id" } return FormatCamelLower(name) - }, + } , "FormatCamelLowerTitle": func(name string) string { return FormatCamelLowerTitle(name) - }, + } , "FormatCamelPluralTitle": func(name string) string { return FormatCamelPluralTitle(name) - }, + } , "FormatCamelPluralTitleLower": func(name string) string { return FormatCamelPluralTitleLower(name) - }, + } , "FormatCamelPluralCamel": func(name string) string { return FormatCamelPluralCamel(name) - }, + } , "FormatCamelPluralLower": func(name string) string { return FormatCamelPluralLower(name) - }, + } , "FormatCamelPluralUnderscore": func(name string) string { return FormatCamelPluralUnderscore(name) - }, + } , "FormatCamelPluralLowerUnderscore": func(name string) string { return FormatCamelPluralLowerUnderscore(name) - }, + } , "FormatCamelUnderscore": func(name string) string { return FormatCamelUnderscore(name) - }, + } , "FormatCamelLowerUnderscore": func(name string) string { return FormatCamelLowerUnderscore(name) - }, - "FieldTagHasOption": func(f modelField, tagName, optName string) bool { + } , + "FieldTagHasOption": func(f modelField, tagName, optName string ) bool { if f.Tags == nil { return false } @@ -143,7 +143,7 @@ func loadTemplateObjects(log *log.Logger, model *modelDef, templateDir, filename } } } else if !ft.HasOption(newVal) { - if ft.Name == "" { + if ft.Name == ""{ ft.Name = newVal } else { ft.Options = append(ft.Options, newVal) @@ -165,7 +165,7 @@ func loadTemplateObjects(log *log.Logger, model *modelDef, templateDir, filename // Load the template file using the text/template package. tmpl, err := baseTmpl.Parse(string(dat)) if err != nil { - err = errors.WithMessagef(err, "Failed to parse template file %s", tempFilePath) + err = errors.WithMessagef(err, "Failed to parse template file %s", tempFilePath) log.Printf("loadTemplateObjects : %v\n%v", err, string(dat)) return nil, err } @@ -189,18 +189,18 @@ func loadTemplateObjects(log *log.Logger, model *modelDef, templateDir, filename // Executed the defined template with the given data. var tpl bytes.Buffer if err := tmpl.Lookup(tmplName).Execute(&tpl, tmptData); err != nil { - err = errors.WithMessagef(err, "Failed to execute %s from template file %s", tmplName, tempFilePath) + err = errors.WithMessagef(err, "Failed to execute %s from template file %s", tmplName, tempFilePath) return resp, err } // Format the source code to ensure its valid and code to parsed consistently. codeBytes, err := format.Source(tpl.Bytes()) if err != nil { - err = errors.WithMessagef(err, "Failed to format source for %s in template file %s", tmplName, filename) + err = errors.WithMessagef(err, "Failed to format source for %s in template file %s", tmplName, filename) dl := []string{} for idx, l := range strings.Split(tpl.String(), "\n") { - dl = append(dl, fmt.Sprintf("%d -> ", idx)+l) + dl = append(dl, fmt.Sprintf("%d -> ", idx) + l) } log.Printf("loadTemplateObjects : %v\n%v", err, strings.Join(dl, "\n")) @@ -216,7 +216,7 @@ func loadTemplateObjects(log *log.Logger, model *modelDef, templateDir, filename // Parse the code lines into a set of objects. objs, err := goparse.ParseLines(codeLines, 0) if err != nil { - err = errors.WithMessagef(err, "Failed to parse %s in template file %s", tmplName, filename) + err = errors.WithMessagef(err, "Failed to parse %s in template file %s", tmplName, filename) log.Printf("loadTemplateObjects : %v\n%v", err, codeStr) return resp, err } @@ -316,7 +316,7 @@ func templateFileOrderedNames(localPath string, names []string) (resp []string, idx := 0 scanner := bufio.NewScanner(file) for scanner.Scan() { - if !strings.HasPrefix(scanner.Text(), "{{") || !strings.Contains(scanner.Text(), "define ") { + if !strings.HasPrefix(scanner.Text(), "{{") || !strings.Contains(scanner.Text(), "define ") { continue } diff --git a/example-project/tools/truss/internal/goparse/goparse.go b/example-project/tools/truss/internal/goparse/goparse.go index 72df56f..aa0454c 100644 --- a/example-project/tools/truss/internal/goparse/goparse.go +++ b/example-project/tools/truss/internal/goparse/goparse.go @@ -88,6 +88,11 @@ func ParseLines(lines []string, depth int) (objs *GoObjects, err error) { ld := lineDepth(l) + + //fmt.Println("l", l) + //fmt.Println("> Depth", ld, "???", depth) + + if ld == depth { if strings.HasPrefix(ls, "/*") { multiLine = true @@ -108,6 +113,11 @@ func ParseLines(lines []string, depth int) (objs *GoObjects, err error) { } } + + //fmt.Println("> multiLine", multiLine) + //fmt.Println("> multiComment", multiComment) + //fmt.Println("> muiliVar", muiliVar) + objLines = append(objLines, l) if multiComment { @@ -131,6 +141,8 @@ func ParseLines(lines []string, depth int) (objs *GoObjects, err error) { } } + //fmt.Println(" > objLines", objLines) + obj, err := ParseGoObject(objLines, depth) if err != nil { log.Println(err) @@ -197,8 +209,22 @@ func ParseGoObject(lines []string, depth int) (obj *GoObject, err error) { if strings.HasPrefix(firstStrip, "var") { obj.Type = GoObjectType_Var + + if !strings.HasSuffix(firstStrip, "(") { + if strings.HasPrefix(firstStrip, "var ") { + firstStrip = strings.TrimSpace(strings.Replace(firstStrip, "var ", "", 1)) + } + obj.Name = strings.Split(firstStrip, " ")[0] + } } else if strings.HasPrefix(firstStrip, "const") { obj.Type = GoObjectType_Const + + if !strings.HasSuffix(firstStrip, "(") { + if strings.HasPrefix(firstStrip, "const ") { + firstStrip = strings.TrimSpace(strings.Replace(firstStrip, "const ", "", 1)) + } + obj.Name = strings.Split(firstStrip, " ")[0] + } } else if strings.HasPrefix(firstStrip, "func") { obj.Type = GoObjectType_Func diff --git a/example-project/tools/truss/internal/goparse/goparse_test.go b/example-project/tools/truss/internal/goparse/goparse_test.go index 0e28315..8f9b62b 100644 --- a/example-project/tools/truss/internal/goparse/goparse_test.go +++ b/example-project/tools/truss/internal/goparse/goparse_test.go @@ -64,7 +64,8 @@ func TestNewDocImports(t *testing.T) { func TestParseLines1(t *testing.T) { g := gomega.NewGomegaWithT(t) - code := `func testCreate(t *testing.T, ctx context.Context, sess *datamodels.Session) *datamodels.Model { + codeTests := []string{ + `func testCreate(t *testing.T, ctx context.Context, sess *datamodels.Session) *datamodels.Model { g := gomega.NewGomegaWithT(t) obj := datamodels.MockModelNew() resp, err := ModelCreate(ctx, DB, &obj) @@ -76,15 +77,30 @@ func TestParseLines1(t *testing.T) { g.Expect(resp.Status).Should(gomega.Equal(datamodels.{{ .Name }}Status_Active)) return resp } -` - lines := strings.Split(code, "\n") - - objs, err := ParseLines(lines, 0) - if err != nil { - t.Fatalf("got error %v", err) +`, + `var ( + // ErrNotFound abstracts the postgres not found error. + ErrNotFound = errors.New("Entity not found") + // ErrInvalidID occurs when an ID is not in a valid form. + ErrInvalidID = errors.New("ID is not in its proper form") + // ErrForbidden occurs when a user tries to do something that is forbidden to them according to our access control policies. + ErrForbidden = errors.New("Attempted action is not allowed") +) +`, } - g.Expect(objs.Lines()).Should(gomega.Equal(lines)) + for _, code := range codeTests { + lines := strings.Split(code, "\n") + + objs, err := ParseLines(lines, 0) + if err != nil { + t.Fatalf("got error %v", err) + } + + g.Expect(objs.Lines()).Should(gomega.Equal(lines)) + } + + } func TestParseLines2(t *testing.T) { diff --git a/example-project/tools/truss/main.go b/example-project/tools/truss/main.go index 441c797..16c61f2 100644 --- a/example-project/tools/truss/main.go +++ b/example-project/tools/truss/main.go @@ -124,8 +124,8 @@ func main() { app.Commands = []cli.Command{ { Name: "dbtable2crud", - Aliases: []string{"dbtable2crud"}, - Usage: "dbtable2crud -table=projects -file=../../internal/project/models.go -model=Project -saveChanges=false", + Aliases: []string{}, + Usage: "-table=projects -file=../../internal/project/models.go -model=Project [-dbtable=TABLE] [-templateDir=DIR] [-projectPath=DIR] [-saveChanges=false] ", Flags: []cli.Flag{ cli.StringFlag{Name: "dbtable, table"}, cli.StringFlag{Name: "modelFile, modelfile, file"}, diff --git a/example-project/tools/truss/templates/dbtable2crud/model_crud.tmpl b/example-project/tools/truss/templates/dbtable2crud/model_crud.tmpl index f290291..5972bc3 100644 --- a/example-project/tools/truss/templates/dbtable2crud/model_crud.tmpl +++ b/example-project/tools/truss/templates/dbtable2crud/model_crud.tmpl @@ -46,15 +46,15 @@ func mapRowsTo{{ $.Model.Name }}(rows *sql.Rows) (*{{ $.Model.Name }}, error) { return nil, errors.WithStack(err) } - return &a, nil + return &m, nil } {{ end }} {{ define "ACL"}} -{{ $hasAccountId := (StringListHasValue $.Model.ColumnNames "account_id") }} +{{ $hasAccountID := (StringListHasValue $.Model.ColumnNames "account_id") }} // CanRead{{ $.Model.Name }} determines if claims has the authority to access the specified {{ FormatCamelLowerTitle $.Model.Name}} by {{ $.Model.PrimaryColumn }}. func CanRead{{ $.Model.Name }}(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, {{ FormatCamelLower $.Model.PrimaryField }} string) error { - {{ if $hasAccountId }} + {{ if $hasAccountID }} // If the request has claims from a specific {{ FormatCamelLower $.Model.Name }}, ensure that the claims // has the correct access to the {{ FormatCamelLower $.Model.Name }}. if claims.Audience != "" { @@ -90,7 +90,7 @@ func CanRead{{ $.Model.Name }}(ctx context.Context, claims auth.Claims, dbConn * // CanModify{{ $.Model.Name }} determines if claims has the authority to modify the specified {{ FormatCamelLowerTitle $.Model.Name}} by {{ $.Model.PrimaryColumn }}. func CanModify{{ $.Model.Name }}(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, {{ FormatCamelLower $.Model.PrimaryField }} string) error { - err = CanRead{{ $.Model.Name }}(ctx, claims, dbConn, {{ FormatCamelLower $.Model.PrimaryField }}) + err := CanRead{{ $.Model.Name }}(ctx, claims, dbConn, {{ FormatCamelLower $.Model.PrimaryField }}) if err != nil { return err } @@ -105,7 +105,7 @@ func CanModify{{ $.Model.Name }}(ctx context.Context, claims auth.Claims, dbConn // applyClaimsSelect applies a sub-query to the provided query to enforce ACL based on the claims provided. // 1. No claims, request is internal, no ACL applied -{{ if $hasAccountId }} +{{ if $hasAccountID }} // 2. All role types can access their user ID {{ end }} func applyClaimsSelect(ctx context.Context, claims auth.Claims, query *sqlbuilder.SelectBuilder) error { @@ -114,7 +114,7 @@ func applyClaimsSelect(ctx context.Context, claims auth.Claims, query *sqlbuilde return nil } - {{ if $hasAccountId }} + {{ if $hasAccountID }} query.Where(query.Equal("account_id", claims.Audience)) {{ end }} @@ -226,9 +226,10 @@ func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, {{ FormatCam } {{ end }} {{ define "Create"}} -{{ $hasAccountId := (StringListHasValue $.Model.ColumnNames "account_id") }} +{{ $hasAccountID := (StringListHasValue $.Model.ColumnNames "account_id") }} {{ $reqName := (Concat $.Model.Name "CreateRequest") }} {{ $createFields := (index $.StructFields $reqName) }} +{{ $reqHasAccountID := false }}{{ $reqAccountID := (index $createFields "AccountID") }}{{ if $reqAccountID }}{{ $reqHasAccountID = true }}{{ end }} // Create inserts a new {{ FormatCamelLowerTitle $.Model.Name }} into the database. func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req {{ $reqName }}, now time.Time) (*{{ $.Model.Name }}, error) { span, ctx := tracer.StartSpanFromContext(ctx, "internal.{{ FormatCamelLowerUnderscore $.Model.Name }}.Create") @@ -237,18 +238,18 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req {{ $re if claims.Audience != "" { // Admin users can update {{ FormatCamelPluralTitleLower $.Model.Name }} they have access to. if !claims.HasRole(auth.RoleAdmin) { - return errors.WithStack(ErrForbidden) + return nil, errors.WithStack(ErrForbidden) } - {{ if $hasAccountId }} - if req.AccountId != "" { + {{ if $reqHasAccountID }} + if req.AccountID != "" { // Request accountId must match claims. - if req.AccountId != claims.Audience { - return errors.WithStack(ErrForbidden) + if req.AccountID != claims.Audience { + return nil, errors.WithStack(ErrForbidden) } } else { // Set the accountId from claims. - req.AccountId = claims.Audience + req.AccountID = claims.Audience } {{ end }} } @@ -256,7 +257,7 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req {{ $re v := validator.New() // Validate the request. - err = v.Struct(req) + err := v.Struct(req) if err != nil { return nil, err } @@ -281,6 +282,13 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req {{ $re {{ end }}{{ end }} } + {{ if and (not $reqHasAccountID) ($hasAccountID) }} + // Set the accountId from claims. + if claims.Audience != "" && m.AccountID == "" { + req.AccountID = claims.Audience + } + {{ end }} + {{ range $fk, $f := $createFields }}{{ $required := (FieldTagHasOption $f "validate" "required") }}{{ if not $required }} if req.{{ $f.FieldName }} != nil { {{ if eq $f.FieldType "sql.NullString" }} @@ -315,7 +323,7 @@ func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req {{ $re return nil, err } - return &a, nil + return &m, nil } {{ end }} {{ define "Update"}} diff --git a/example-project/tools/truss/templates/dbtable2crud/model_crud_test.tmpl b/example-project/tools/truss/templates/dbtable2crud/model_crud_test.tmpl index d3ced6a..65477d1 100644 --- a/example-project/tools/truss/templates/dbtable2crud/model_crud_test.tmpl +++ b/example-project/tools/truss/templates/dbtable2crud/model_crud_test.tmpl @@ -1,19 +1,11 @@ {{ define "imports"}} import ( - "github.com/lib/pq" - "math/rand" - "os" - "strings" - "testing" - "time" - "{{ $.GoSrcPath }}/internal/platform/auth" "{{ $.GoSrcPath }}/internal/platform/tests" - "github.com/dgrijalva/jwt-go" "github.com/google/go-cmp/cmp" "github.com/huandu/go-sqlbuilder" - "github.com/pborman/uuid" - "github.com/pkg/errors" + "os" + "testing" ) {{ end }} {{ define "Globals"}} @@ -33,7 +25,7 @@ func testMain(m *testing.M) int { {{ define "TestFindRequestQuery"}} // TestFindRequestQuery validates findRequestQuery func TestFindRequestQuery(t *testing.T) { - where := "name = ? or address1 = ?" + where := "field1 = ? or field2 = ?" var ( limit uint = 12 offset uint = 34 @@ -52,7 +44,7 @@ func TestFindRequestQuery(t *testing.T) { Limit: &limit, Offset: &offset, } - expected := "SELECT " + accountMapColumns + " FROM " + accountTableName + " WHERE (name = ? or address1 = ?) ORDER BY id asc, created_at desc LIMIT 12 OFFSET 34" + expected := "SELECT " + {{ FormatCamelLower $.Model.Name }}MapColumns + " FROM " + {{ FormatCamelLower $.Model.Name }}TableName + " WHERE (field1 = ? or field2 = ?) ORDER BY id asc, created_at desc LIMIT 12 OFFSET 34" res, args := findRequestQuery(req) @@ -63,4 +55,77 @@ func TestFindRequestQuery(t *testing.T) { t.Fatalf("\t%s\tExpected result query to match. Diff:\n%s", tests.Failed, diff) } } +{{ end }} +{{ define "TestApplyClaimsSelect"}} +// TestApplyClaimsSelect applyClaimsSelect +func TestApplyClaimsSelect(t *testing.T) { + var claimTests = []struct { + name string + claims auth.Claims + expectedSql string + error error + }{ + {"EmptyClaims", + auth.Claims{}, + "SELECT " + {{ FormatCamelLower $.Model.Name }}MapColumns + " FROM " + {{ FormatCamelLower $.Model.Name }}TableName, + nil, + }, + {"RoleAccount", + auth.Claims{ + Roles: []string{auth.RoleAdmin}, + StandardClaims: jwt.StandardClaims{ + Subject: "user1", + Audience: "acc1", + }, + }, + "SELECT " + {{ FormatCamelLower $.Model.Name }}MapColumns + " FROM " + {{ FormatCamelLower $.Model.Name }}TableName + " WHERE account_id = 'acc1'", + nil, + }, + {"RoleAdmin", + auth.Claims{ + Roles: []string{auth.RoleAdmin}, + StandardClaims: jwt.StandardClaims{ + Subject: "user1", + Audience: "acc1", + }, + }, + "SELECT " + {{ FormatCamelLower $.Model.Name }}MapColumns + " FROM " + {{ FormatCamelLower $.Model.Name }}TableName + " WHERE account_id = 'acc1'", + nil, + }, + } + + t.Log("Given the need to validate ACLs are enforced by claims to a select query.") + { + for i, tt := range claimTests { + t.Logf("\tTest: %d\tWhen running test: %s", i, tt.name) + { + ctx := tests.Context() + + query := selectQuery() + + err := applyClaimsSelect(ctx, tt.claims, query) + if err != tt.error { + t.Logf("\t\tGot : %+v", err) + t.Logf("\t\tWant: %+v", tt.error) + t.Fatalf("\t%s\tapplyClaimsSelect failed.", tests.Failed) + } + + sql, args := query.Build() + + // Use mysql flavor so placeholders will get replaced for comparison. + sql, err = sqlbuilder.MySQL.Interpolate(sql, args) + if err != nil { + t.Log("\t\tGot :", err) + t.Fatalf("\t%s\tapplyClaimsSelect failed.", tests.Failed) + } + + if diff := cmp.Diff(sql, tt.expectedSql); diff != "" { + t.Fatalf("\t%s\tExpected result query to match. Diff:\n%s", tests.Failed, diff) + } + + t.Logf("\t%s\tapplyClaimsSelect ok.", tests.Success) + } + } + } +} {{ end }} \ No newline at end of file diff --git a/example-project/tools/truss/templates/dbtable2crud/models.tmpl b/example-project/tools/truss/templates/dbtable2crud/models.tmpl index fa3940e..1829d08 100644 --- a/example-project/tools/truss/templates/dbtable2crud/models.tmpl +++ b/example-project/tools/truss/templates/dbtable2crud/models.tmpl @@ -38,7 +38,7 @@ type {{ $f.FieldType }} string const ( {{ range $evk, $ev := $f.DbColumn.EnumValues }} // {{ $f.FieldType }}_{{ FormatCamel $ev }} defines the {{ $f.ColumnName }} of {{ $ev }} for {{ FormatCamelLowerTitle $.Model.Name }}. - {{ $f.FieldType }}_{{ FormatCamel $ev }}{{ $f.FieldType }} = "{{ $ev }}" + {{ $f.FieldType }}_{{ FormatCamel $ev }} {{ $f.FieldType }} = "{{ $ev }}" {{ end }} )