{{ define "imports"}}
import (
	"context"
	"database/sql"
	"time"

	"{{ $.GoSrcPath }}/internal/platform/auth"
	"github.com/huandu/go-sqlbuilder"
	"github.com/jmoiron/sqlx"
	"github.com/pborman/uuid"
	"github.com/pkg/errors"
	"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
	"gopkg.in/go-playground/validator.v9"
)
{{ end }}

{{ define "Globals"}}
const (
	// {{ FormatCamelLower $.Model.Name }}TableName is the name of the database table backing {{ $.Model.Name }}.
	{{ FormatCamelLower $.Model.Name }}TableName = "{{ $.Model.TableName }}"
)

var (
	// ErrNotFound abstracts the postgres not found error.
	ErrNotFound = errors.New("Entity not found")

	// ErrInvalidID occurs when an ID is not in a valid form.
	ErrInvalidID = errors.New("ID is not in its proper form")

	// ErrForbidden occurs when a user tries to do something that is forbidden
	// to them according to our access control policies.
	ErrForbidden = errors.New("Attempted action is not allowed")
)
{{ end }}

{{ define "Helpers"}}
// {{ FormatCamelLower $.Model.Name }}MapColumns is the comma separated column list that
// mapRowsTo{{ $.Model.Name }} expects every selected row to provide.
var {{ FormatCamelLower $.Model.Name }}MapColumns = "{{ JoinStrings $.Model.ColumnNames "," }}"

// mapRowsTo{{ $.Model.Name }} scans the current row of rows into a new {{ $.Model.Name }}.
// The row must have been selected with the columns listed in
// {{ FormatCamelLower $.Model.Name }}MapColumns. A scan failure is returned with a stack
// trace attached.
func mapRowsTo{{ $.Model.Name }}(rows *sql.Rows) (*{{ $.Model.Name }}, error) {
	var m {{ $.Model.Name }}
	if err := rows.Scan({{ PrefixAndJoinStrings $.Model.FieldNames "&m." "," }}); err != nil {
		return nil, errors.WithStack(err)
	}

	return &m, nil
}
{{ end }}

{{ define "ACL"}}
{{ $hasAccountID := (StringListHasValue $.Model.ColumnNames "account_id") }}
// CanRead{{ $.Model.Name }} determines if claims has the authority to access the specified {{ FormatCamelLowerTitle $.Model.Name}} by {{ $.Model.PrimaryColumn }}.
func CanRead{{ $.Model.Name }}(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, {{ FormatCamelLower $.Model.PrimaryField }} string) error {
	{{ if $hasAccountID }}
	// Claims with an empty Audience skip the check entirely. Otherwise the
	// {{ FormatCamelLower $.Model.Name }} must belong to the account named by the audience.
	if claims.Audience != "" {
		// select {{ $.Model.PrimaryColumn }} from {{ $.Model.TableName }} where account_id = [accountID] and {{ $.Model.PrimaryColumn }} = [{{ FormatCamelLower $.Model.PrimaryField }}]
		query := sqlbuilder.NewSelectBuilder().Select("{{ $.Model.PrimaryColumn }}").From({{ FormatCamelLower $.Model.Name }}TableName)
		query.Where(query.And(
			query.Equal("account_id", claims.Audience),
			// Filter on the database column name, not the Go struct field name.
			query.Equal("{{ $.Model.PrimaryColumn }}", {{ FormatCamelLower $.Model.PrimaryField }}),
		))
		queryStr, args := query.Build()
		queryStr = dbConn.Rebind(queryStr)

		// matched receives the {{ $.Model.PrimaryColumn }} of the row, and stays empty when
		// no row satisfies both conditions.
		var matched string
		err := dbConn.QueryRowContext(ctx, queryStr, args...).Scan(&matched)
		if err != nil && err != sql.ErrNoRows {
			err = errors.Wrapf(err, "query - %s", query.String())
			return err
		}

		// When there is no {{ $.Model.PrimaryColumn }} returned, then the current claim user does not have access
		// to the specified {{ FormatCamelLowerTitle $.Model.Name }}.
		if matched == "" {
			return errors.WithStack(ErrForbidden)
		}
	}
	{{ else }}
	// TODO: Unable to auto generate sql statement, update accordingly.
	panic("Not implemented!")
	{{ end }}

	return nil
}

// CanModify{{ $.Model.Name }} determines if claims has the authority to modify the specified {{ FormatCamelLowerTitle $.Model.Name}} by {{ $.Model.PrimaryColumn }}.
func CanModify{{ $.Model.Name }}(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, {{ FormatCamelLower $.Model.PrimaryField }} string) error {
	// Being able to read the {{ FormatCamelLowerTitle $.Model.Name }} is a prerequisite for modifying it.
	if err := CanRead{{ $.Model.Name }}(ctx, claims, dbConn, {{ FormatCamelLower $.Model.PrimaryField }}); err != nil {
		return err
	}

	// Admin users can update {{ FormatCamelPluralTitleLower $.Model.Name }} they have access to.
	if !claims.HasRole(auth.RoleAdmin) {
		return errors.WithStack(ErrForbidden)
	}

	return nil
}

// applyClaimsSelect adds where conditions to the provided query to enforce ACL based on the claims provided.
// 1. No claims, request is internal, no ACL applied
{{- if $hasAccountID }}
// 2. Claims with an audience only match rows whose account_id equals that audience
{{- end }}
func applyClaimsSelect(ctx context.Context, claims auth.Claims, query *sqlbuilder.SelectBuilder) error {
	// Claims are empty, don't apply any ACL
	if claims.Audience == "" {
		return nil
	}

	{{ if $hasAccountID }}
	query.Where(query.Equal("account_id", claims.Audience))
	{{ end }}

	return nil
}
{{ end }}

{{ define "Find"}}
{{ $hasArchived := (StringListHasValue $.Model.ColumnNames "archived_at") }}
// selectQuery constructs a base select query for {{ $.Model.Name }}
func selectQuery() *sqlbuilder.SelectBuilder {
	query := sqlbuilder.NewSelectBuilder()
	query.Select({{ FormatCamelLower $.Model.Name }}MapColumns)
	query.From({{ FormatCamelLower $.Model.Name }}TableName)
	return query
}

// findRequestQuery generates the select query for the given find request and
// returns it together with the request's own args, which are passed through
// untouched.
// TODO: Need to figure out why can't parse the args when appending the where
// to the query.
func findRequestQuery(req {{ $.Model.Name }}FindRequest) (*sqlbuilder.SelectBuilder, []interface{}) {
	query := selectQuery()

	if req.Where != nil {
		query.Where(query.And(*req.Where))
	}
	if len(req.Order) > 0 {
		query.OrderBy(req.Order...)
	}
	if req.Limit != nil {
		query.Limit(int(*req.Limit))
	}
	if req.Offset != nil {
		query.Offset(int(*req.Offset))
	}

	return query, req.Args
}

// Find gets all the {{ FormatCamelPluralTitleLower $.Model.Name }} from the database based on the request params.
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req {{ $.Model.Name }}FindRequest) ([]*{{ $.Model.Name }}, error) {
	query, args := findRequestQuery(req)
	return find(ctx, claims, dbConn, query, args{{ if $hasArchived }}, req.IncludedArchived {{ end }})
}

// find internal method for getting all the {{ FormatCamelPluralTitleLower $.Model.Name }} from the database using a select query.
// The ACL conditions derived from claims are added to query before it runs.
// args are placed ahead of the arguments produced by the query builder.
// On success the returned slice is never nil, even when no rows match.
func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}{{ if $hasArchived }}, includedArchived bool{{ end }}) ([]*{{ $.Model.Name }}, error) {
	span, ctx := tracer.StartSpanFromContext(ctx, "internal.{{ FormatCamelLowerUnderscore $.Model.Name }}.Find")
	defer span.Finish()

	query.Select({{ FormatCamelLower $.Model.Name }}MapColumns)
	query.From({{ FormatCamelLower $.Model.Name }}TableName)

	{{ if $hasArchived }}
	if !includedArchived {
		query.Where(query.IsNull("archived_at"))
	}
	{{ end }}

	// Check to see if a sub query needs to be applied for the claims.
	err := applyClaimsSelect(ctx, claims, query)
	if err != nil {
		return nil, err
	}
	queryStr, queryArgs := query.Build()
	queryStr = dbConn.Rebind(queryStr)
	args = append(args, queryArgs...)

	// Fetch all entries from the db.
	rows, err := dbConn.QueryContext(ctx, queryStr, args...)
	if err != nil {
		err = errors.Wrapf(err, "query - %s", query.String())
		err = errors.WithMessage(err, "find {{ FormatCamelPluralTitleLower $.Model.Name }} failed")
		return nil, err
	}
	// Release the underlying connection on every return path, including the
	// early return on a scan error below.
	defer rows.Close()

	// Iterate over each row.
	resp := []*{{ $.Model.Name }}{}
	for rows.Next() {
		u, err := mapRowsTo{{ $.Model.Name }}(rows)
		if err != nil {
			err = errors.Wrapf(err, "query - %s", query.String())
			return nil, err
		}
		resp = append(resp, u)
	}

	// rows.Next reports false both when the rows are exhausted and when
	// iteration failed, so the error has to be checked explicitly.
	if err := rows.Err(); err != nil {
		err = errors.Wrapf(err, "query - %s", query.String())
		err = errors.WithMessage(err, "find {{ FormatCamelPluralTitleLower $.Model.Name }} failed")
		return nil, err
	}

	return resp, nil
}

// Read gets the specified {{ FormatCamelLowerTitle $.Model.Name }} from the database.
// ErrNotFound, annotated with the requested {{ FormatCamelLower $.Model.PrimaryField }}, is returned when no
// row is visible to claims.
func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, {{ FormatCamelLower $.Model.PrimaryField }} {{ $.Model.PrimaryType }}{{ if $hasArchived }}, includedArchived bool{{ end }}) (*{{ $.Model.Name }}, error) {
	span, ctx := tracer.StartSpanFromContext(ctx, "internal.{{ FormatCamelLowerUnderscore $.Model.Name }}.Read")
	defer span.Finish()

	// Filter base select query by {{ FormatCamelLower $.Model.PrimaryField }}
	query := selectQuery()
	query.Where(query.Equal("{{ $.Model.PrimaryColumn }}", {{ FormatCamelLower $.Model.PrimaryField }}))

	res, err := find(ctx, claims, dbConn, query, []interface{}{} {{ if $hasArchived }}, includedArchived{{ end }})
	if err != nil {
		return nil, err
	}
	if len(res) == 0 {
		// Use the templated parameter name so the generated code compiles for
		// any primary key name, and %v so it formats for any primary key type.
		err = errors.WithMessagef(ErrNotFound, "{{ FormatCamelLowerTitle $.Model.Name }} %v not found", {{ FormatCamelLower $.Model.PrimaryField }})
		return nil, err
	}

	return res[0], nil
}
{{ end }}

{{ define "Create"}}
{{ $hasAccountID := (StringListHasValue $.Model.ColumnNames "account_id") }}
{{ $reqName := (Concat $.Model.Name "CreateRequest") }}
{{ $createFields := (index $.StructFields $reqName) }}
{{ $reqHasAccountID := false }}{{ $reqAccountID := (index $createFields "AccountID") }}{{ if $reqAccountID }}{{ $reqHasAccountID = true }}{{ end }}
// Create inserts a new {{ FormatCamelLowerTitle $.Model.Name }} into the database.
func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req {{ $reqName }}, now time.Time) (*{{ $.Model.Name }}, error) {
	span, ctx := tracer.StartSpanFromContext(ctx, "internal.{{ FormatCamelLowerUnderscore $.Model.Name }}.Create")
	defer span.Finish()

	// Claims with an empty Audience skip the ACL checks.
	if claims.Audience != "" {
		// Only admin users can create {{ FormatCamelPluralTitleLower $.Model.Name }}.
		if !claims.HasRole(auth.RoleAdmin) {
			return nil, errors.WithStack(ErrForbidden)
		}

		{{ if $reqHasAccountID }}
		if req.AccountID == "" {
			// Set the accountId from claims.
			req.AccountID = claims.Audience
		} else if req.AccountID != claims.Audience {
			// Request accountId must match claims.
			return nil, errors.WithStack(ErrForbidden)
		}
		{{ end }}
	}

	v := validator.New()

	// Validate the request.
	err := v.Struct(req)
	if err != nil {
		return nil, err
	}

	// If now empty set it to the current time.
	if now.IsZero() {
		now = time.Now()
	}

	// Always store the time as UTC.
	now = now.UTC()

	// Postgres truncates times to milliseconds when storing. We do the same
	// here so the value we return is consistent with what we store.
	now = now.Truncate(time.Millisecond)

	m := {{ $.Model.Name }}{
		{{ range $mk, $mf := $.Model.Fields }}{{ $cf := (index $createFields $mf.FieldName) }}
		{{ if eq $mf.FieldName $.Model.PrimaryField }}{{ $isUuid := (FieldTagHasOption $mf "validate" "uuid") }}{{ $mf.FieldName }}: {{ if $isUuid }}uuid.NewRandom().String(){{ else }}req.{{ $mf.FieldName }}{{ end }},
		{{ else if or (eq $mf.FieldName "CreatedAt") (eq $mf.FieldName "UpdatedAt") }}{{ $mf.FieldName }}: now,
		{{ else if $cf }}{{ $required := (FieldTagHasOption $cf "validate" "required") }}{{ if $required }}{{ $cf.FieldName }}: req.{{ $cf.FieldName }},{{ else if ne $cf.DefaultValue "" }}{{ $cf.FieldName }}: {{ $cf.DefaultValue }},{{ end }}
		{{ end }}{{ end }}
	}

	{{ if and (not $reqHasAccountID) ($hasAccountID) }}
	// Set the accountId from claims.
	// The request has no AccountID field in this branch, so the value is
	// written to the model instead.
	// NOTE(review): the insert column list below only contains request fields,
	// the primary key and the timestamps, so account_id looks like it is not
	// persisted in this branch - confirm and extend Cols/Values if required.
	if claims.Audience != "" && m.AccountID == "" {
		m.AccountID = claims.Audience
	}
	{{ end }}

	// Copy the optional request fields that were provided over to the model.
	{{ range $fk, $f := $createFields }}{{ $required := (FieldTagHasOption $f "validate" "required") }}{{ if not $required }}
	if req.{{ $f.FieldName }} != nil {
		{{ if eq $f.FieldType "sql.NullString" }}
		m.{{ $f.FieldName }} = sql.NullString{String: *req.{{ $f.FieldName }}, Valid: true}
		{{ else if eq $f.FieldType "*sql.NullString" }}
		m.{{ $f.FieldName }} = &sql.NullString{String: *req.{{ $f.FieldName }}, Valid: true}
		{{ else }}
		m.{{ $f.FieldName }} = *req.{{ $f.FieldName }}
		{{ end }}
	}
	{{ end }}{{ end }}

	// Build the insert SQL statement.
	query := sqlbuilder.NewInsertBuilder()
	query.InsertInto({{ FormatCamelLower $.Model.Name }}TableName)
	query.Cols(
		{{ range $mk, $mf := $.Model.Fields }}{{ $cf := (index $createFields $mf.FieldName) }}{{ if or (eq $mf.FieldName $.Model.PrimaryField) ($cf) (eq $mf.FieldName "CreatedAt") (eq $mf.FieldName "UpdatedAt") }}"{{ $mf.ColumnName }}",
		{{ end }}{{ end }}
	)
	query.Values(
		{{ range $mk, $mf := $.Model.Fields }}{{ $cf := (index $createFields $mf.FieldName) }}{{ if or (eq $mf.FieldName $.Model.PrimaryField) ($cf) (eq $mf.FieldName "CreatedAt") (eq $mf.FieldName "UpdatedAt") }}m.{{ $mf.FieldName }},
		{{ end }}{{ end }}
	)

	// Execute the query with the provided context. The statement is not named
	// sql so that the database/sql package stays usable in this function.
	stmt, args := query.Build()
	stmt = dbConn.Rebind(stmt)
	_, err = dbConn.ExecContext(ctx, stmt, args...)
	if err != nil {
		err = errors.Wrapf(err, "query - %s", query.String())
		err = errors.WithMessage(err, "create {{ FormatCamelLowerTitle $.Model.Name }} failed")
		return nil, err
	}

	return &m, nil
}
{{ end }}

{{ define "Update"}}
{{ $reqName := (Concat $.Model.Name "UpdateRequest") }}
{{ $updateFields := (index $.StructFields $reqName) }}
// Update replaces an {{ FormatCamelLowerTitle $.Model.Name }} in the database.
func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req {{ $reqName }}, now time.Time) error {
	span, ctx := tracer.StartSpanFromContext(ctx, "internal.{{ FormatCamelLowerUnderscore $.Model.Name }}.Update")
	defer span.Finish()

	v := validator.New()

	// Validate the request.
	err := v.Struct(req)
	if err != nil {
		return err
	}

	// Ensure the claims can modify the {{ FormatCamelLowerTitle $.Model.Name }} specified in the request.
	err = CanModify{{ $.Model.Name }}(ctx, claims, dbConn, req.{{ $.Model.PrimaryField }})
	if err != nil {
		return err
	}

	// If now empty set it to the current time.
	if now.IsZero() {
		now = time.Now()
	}

	// Always store the time as UTC.
	now = now.UTC()

	// Postgres truncates times to milliseconds when storing. We do the same
	// here so the value we return is consistent with what we store.
	now = now.Truncate(time.Millisecond)

	// Build the update SQL statement.
	query := sqlbuilder.NewUpdateBuilder()
	query.Update({{ FormatCamelLower $.Model.Name }}TableName)

	// Only request fields that are non-nil are assigned. Optional uuid fields
	// set to an empty string are written as NULL.
	var fields []string
	{{ range $mk, $mf := $.Model.Fields }}{{ $uf := (index $updateFields $mf.FieldName) }}{{ if and ($uf.FieldName) (ne $uf.FieldName $.Model.PrimaryField) }}
	{{ $optional := (FieldTagHasOption $uf "validate" "omitempty") }}{{ $isUuid := (FieldTagHasOption $uf "validate" "uuid") }}
	if req.{{ $uf.FieldName }} != nil {
		{{ if and ($optional) ($isUuid) }}
		if *req.{{ $uf.FieldName }} != "" {
			fields = append(fields, query.Assign("{{ $uf.ColumnName }}", req.{{ $uf.FieldName }}))
		} else {
			fields = append(fields, query.Assign("{{ $uf.ColumnName }}", nil))
		}
		{{ else }}
		fields = append(fields, query.Assign("{{ $uf.ColumnName }}", req.{{ $uf.FieldName }}))
		{{ end }}
	}
	{{ end }}{{ end }}

	// If there's nothing to update we can quit early.
	if len(fields) == 0 {
		return nil
	}

	{{ $hasUpdatedAt := (StringListHasValue $.Model.ColumnNames "updated_at") }}{{ if $hasUpdatedAt }}
	// Append the updated_at field
	fields = append(fields, query.Assign("updated_at", now))
	{{ end }}

	query.Set(fields...)
	query.Where(query.Equal("{{ $.Model.PrimaryColumn }}", req.{{ $.Model.PrimaryField }}))

	// Execute the query with the provided context. The statement is not named
	// sql to avoid shadowing the database/sql package.
	stmt, args := query.Build()
	stmt = dbConn.Rebind(stmt)
	_, err = dbConn.ExecContext(ctx, stmt, args...)
	if err != nil {
		err = errors.Wrapf(err, "query - %s", query.String())
		err = errors.WithMessagef(err, "update {{ FormatCamelLowerTitle $.Model.Name }} %s failed", req.{{ $.Model.PrimaryField }})
		return err
	}

	return nil
}
{{ end }}

{{ define "Archive"}}
// Archive soft deletes the {{ FormatCamelLowerTitle $.Model.Name }} by setting its archived_at column to now.
// A zero now is replaced with the current time.
func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, {{ FormatCamelLower $.Model.PrimaryField }} {{ $.Model.PrimaryType }}, now time.Time) error {
	span, ctx := tracer.StartSpanFromContext(ctx, "internal.{{ FormatCamelLowerUnderscore $.Model.Name }}.Archive")
	defer span.Finish()

	// Defines the struct to apply validation
	req := struct {
		{{ $.Model.PrimaryField }} {{ $.Model.PrimaryType }} `validate:"required,uuid"`
	}{
		{{ $.Model.PrimaryField }}: {{ FormatCamelLower $.Model.PrimaryField }},
	}

	// Validate the request.
	err := validator.New().Struct(req)
	if err != nil {
		return err
	}

	// Ensure the claims can modify the {{ FormatCamelLowerTitle $.Model.Name }} specified in the request.
	// The field is referenced through the template so the generated code
	// compiles for primary keys that are not named ID.
	err = CanModify{{ $.Model.Name }}(ctx, claims, dbConn, req.{{ $.Model.PrimaryField }})
	if err != nil {
		return err
	}

	// If now empty set it to the current time.
	if now.IsZero() {
		now = time.Now()
	}

	// Always store the time as UTC.
	now = now.UTC()

	// Postgres truncates times to milliseconds when storing. We do the same
	// here so the value we return is consistent with what we store.
	now = now.Truncate(time.Millisecond)

	// Build the update SQL statement.
	query := sqlbuilder.NewUpdateBuilder()
	query.Update({{ FormatCamelLower $.Model.Name }}TableName)
	query.Set(
		query.Assign("archived_at", now),
	)
	query.Where(query.Equal("{{ $.Model.PrimaryColumn }}", req.{{ $.Model.PrimaryField }}))

	// Execute the query with the provided context.
	stmt, args := query.Build()
	stmt = dbConn.Rebind(stmt)
	_, err = dbConn.ExecContext(ctx, stmt, args...)
	if err != nil {
		err = errors.Wrapf(err, "query - %s", query.String())
		err = errors.WithMessagef(err, "archive {{ FormatCamelLowerTitle $.Model.Name }} %s failed", req.{{ $.Model.PrimaryField }})
		return err
	}

	return nil
}
{{ end }}

{{ define "Delete"}}
// Delete removes an {{ FormatCamelLowerTitle $.Model.Name }} from the database.
func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, {{ FormatCamelLower $.Model.PrimaryField }} {{ $.Model.PrimaryType }}) error {
	span, ctx := tracer.StartSpanFromContext(ctx, "internal.{{ FormatCamelLowerUnderscore $.Model.Name }}.Delete")
	defer span.Finish()

	// Defines the struct to apply validation
	req := struct {
		{{ $.Model.PrimaryField }} {{ $.Model.PrimaryType }} `validate:"required,uuid"`
	}{
		{{ $.Model.PrimaryField }}: {{ FormatCamelLower $.Model.PrimaryField }},
	}

	// Validate the request.
	err := validator.New().Struct(req)
	if err != nil {
		return err
	}

	// Ensure the claims can modify the {{ FormatCamelLowerTitle $.Model.Name }} specified in the request.
	// The field is referenced through the template so the generated code
	// compiles for primary keys that are not named ID.
	err = CanModify{{ $.Model.Name }}(ctx, claims, dbConn, req.{{ $.Model.PrimaryField }})
	if err != nil {
		return err
	}

	// Build the delete SQL statement.
	query := sqlbuilder.NewDeleteBuilder()
	query.DeleteFrom({{ FormatCamelLower $.Model.Name }}TableName)
	query.Where(query.Equal("{{ $.Model.PrimaryColumn }}", req.{{ $.Model.PrimaryField }}))

	// Execute the query with the provided context.
	stmt, args := query.Build()
	stmt = dbConn.Rebind(stmt)
	_, err = dbConn.ExecContext(ctx, stmt, args...)
	if err != nil {
		err = errors.Wrapf(err, "query - %s", query.String())
		err = errors.WithMessagef(err, "delete {{ FormatCamelLowerTitle $.Model.Name }} %s failed", req.{{ $.Model.PrimaryField }})
		return err
	}

	return nil
}
{{ end }}