2019-06-24 01:30:18 -08:00
{{ define "imports"}}
import (
"context"
"database/sql"
"time"
"{{ $.GoSrcPath }}/internal/platform/auth"
"github.com/huandu/go-sqlbuilder"
"github.com/jmoiron/sqlx"
"github.com/pborman/uuid"
"github.com/pkg/errors"
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
"gopkg.in/go-playground/validator.v9"
)
{{ end }}
{{ define "Globals"}}
// {{ FormatCamelLower $.Model.Name }}TableName holds the name of the database table that stores {{ $.Model.Name }} rows.
const {{ FormatCamelLower $.Model.Name }}TableName = "{{ $.Model.TableName }}"

// Sentinel errors shared by the generated {{ $.Model.Name }} functions.
var (
	// ErrNotFound is returned when a lookup matches no row; it abstracts the
	// postgres not found error.
	ErrNotFound = errors.New("Entity not found")

	// ErrInvalidID is returned when an ID is not in a valid form.
	ErrInvalidID = errors.New("ID is not in its proper form")

	// ErrForbidden is returned when the caller attempts an action that our
	// access control policies do not permit.
	ErrForbidden = errors.New("Attempted action is not allowed")
)
{{ end }}
{{ define "Helpers"}}
// {{ FormatCamelLower $.Model.Name }}MapColumns is the comma-separated list of columns needed for
// mapRowsTo{{ $.Model.Name }}.
var {{ FormatCamelLower $.Model.Name }}MapColumns = "{{ JoinStrings $.Model.ColumnNames "," }}"

// mapRowsTo{{ $.Model.Name }} scans the current row of rows into a new {{ $.Model.Name }}.
//
// Scan is positional, so the result set is expected to expose the columns of
// {{ FormatCamelLower $.Model.Name }}MapColumns in the same order as the model's fields.
// A scan failure is returned with a stack trace attached and a nil model.
func mapRowsTo{{ $.Model.Name }}(rows *sql.Rows) (*{{ $.Model.Name }}, error) {
	var m {{ $.Model.Name }}
	if err := rows.Scan({{ PrefixAndJoinStrings $.Model.FieldNames "&m." "," }}); err != nil {
		return nil, errors.WithStack(err)
	}
	return &m, nil
}
{{ end }}
{{ define "ACL"}}
{{ $hasAccountID := (StringListHasValue $.Model.ColumnNames "account_id") }}
// CanRead{{ $.Model.Name }} determines if claims has the authority to access the specified {{ FormatCamelLowerTitle $.Model.Name}} by {{ $.Model.PrimaryColumn }}.
//
// It returns nil when access is allowed, ErrForbidden (with a stack trace) when
// the claims audience does not own the row, or the wrapped query error.
func CanRead{{ $.Model.Name }}(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, {{ FormatCamelLower $.Model.PrimaryField }} string) error {
{{ if $hasAccountID }}
	// Claims without an audience are treated as internal requests and are not
	// restricted. Otherwise the row must belong to the audience's account.
	if claims.Audience != "" {
		// select {{ $.Model.PrimaryColumn }} from {{ $.Model.TableName }} where account_id = [claims.Audience] and {{ $.Model.PrimaryColumn }} = [{{ FormatCamelLower $.Model.PrimaryField }}]
		query := sqlbuilder.NewSelectBuilder().Select("{{ $.Model.PrimaryColumn }}").From({{ FormatCamelLower $.Model.Name }}TableName)
		query.Where(query.And(
			query.Equal("account_id", claims.Audience),
			// Filter on the column name, not the Go struct field name.
			query.Equal("{{ $.Model.PrimaryColumn }}", {{ FormatCamelLower $.Model.PrimaryField }}),
		))

		queryStr, args := query.Build()
		queryStr = dbConn.Rebind(queryStr)

		// matched stays empty when no row is returned (sql.ErrNoRows).
		var matched string
		err := dbConn.QueryRowContext(ctx, queryStr, args...).Scan(&matched)
		if err != nil && err != sql.ErrNoRows {
			return errors.Wrapf(err, "query - %s", query.String())
		}

		// When there is no {{ $.Model.PrimaryColumn }} returned, then the current claim user does not have access
		// to the specified {{ FormatCamelLowerTitle $.Model.Name }}.
		if matched == "" {
			return errors.WithStack(ErrForbidden)
		}
	}
{{ else }}
	// TODO: Unable to auto generate sql statement, update accordingly.
	panic("Not implemented!")
{{ end }}
	return nil
}

// CanModify{{ $.Model.Name }} determines if claims has the authority to modify the specified {{ FormatCamelLowerTitle $.Model.Name}} by {{ $.Model.PrimaryColumn }}.
//
// The claims must first pass CanRead{{ $.Model.Name }} and must additionally carry the
// admin role; otherwise ErrForbidden is returned.
func CanModify{{ $.Model.Name }}(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, {{ FormatCamelLower $.Model.PrimaryField }} string) error {
	if err := CanRead{{ $.Model.Name }}(ctx, claims, dbConn, {{ FormatCamelLower $.Model.PrimaryField }}); err != nil {
		return err
	}

	// Only admin users can modify {{ FormatCamelPluralTitleLower $.Model.Name }} they have access to.
	if !claims.HasRole(auth.RoleAdmin) {
		return errors.WithStack(ErrForbidden)
	}

	return nil
}

// applyClaimsSelect applies a condition to the provided query to enforce ACL based on the claims provided.
// 1. No claims, request is internal, no ACL applied
{{- if $hasAccountID }}
// 2. Otherwise results are restricted to rows whose account_id equals the claims audience
{{- end }}
func applyClaimsSelect(ctx context.Context, claims auth.Claims, query *sqlbuilder.SelectBuilder) error {
	// Claims are empty, don't apply any ACL
	if claims.Audience == "" {
		return nil
	}
{{ if $hasAccountID }}
	query.Where(query.Equal("account_id", claims.Audience))
{{ end }}
	return nil
}
{{ end }}
{{ define "Find"}}
{{ $hasArchived := (StringListHasValue $.Model.ColumnNames "archived_at") }}
// selectQuery constructs a base select query for {{ $.Model.Name }} that selects
// {{ FormatCamelLower $.Model.Name }}MapColumns from {{ FormatCamelLower $.Model.Name }}TableName.
func selectQuery() *sqlbuilder.SelectBuilder {
	query := sqlbuilder.NewSelectBuilder()
	query.Select({{ FormatCamelLower $.Model.Name }}MapColumns)
	query.From({{ FormatCamelLower $.Model.Name }}TableName)
	return query
}

// findRequestQuery generates the select query for the given find request.
// It applies the optional where clause, ordering, limit and offset from req
// and returns the builder together with req.Args unchanged.
// TODO: Need to figure out why can't parse the args when appending the where
// to the query.
func findRequestQuery(req {{ $.Model.Name }}FindRequest) (*sqlbuilder.SelectBuilder, []interface{}) {
	query := selectQuery()

	if req.Where != nil {
		query.Where(query.And(*req.Where))
	}
	if len(req.Order) > 0 {
		query.OrderBy(req.Order...)
	}
	if req.Limit != nil {
		query.Limit(int(*req.Limit))
	}
	if req.Offset != nil {
		query.Offset(int(*req.Offset))
	}

	return query, req.Args
}

// Find gets all the {{ FormatCamelPluralTitleLower $.Model.Name }} from the database based on the request params.
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req {{ $.Model.Name }}FindRequest) ([]*{{ $.Model.Name }}, error) {
	query, args := findRequestQuery(req)
	return find(ctx, claims, dbConn, query, args{{ if $hasArchived }}, req.IncludedArchived {{ end }})
}

// find is the internal method for getting all the {{ FormatCamelPluralTitleLower $.Model.Name }} from the database using a select query.
//
// The ACL condition from applyClaimsSelect is added to query before it is
// executed. args are placed ahead of the arguments produced by the query
// builder. On success the returned slice is non-nil, possibly empty.
func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}{{ if $hasArchived }}, includedArchived bool{{ end }}) ([]*{{ $.Model.Name }}, error) {
	span, ctx := tracer.StartSpanFromContext(ctx, "internal.{{ FormatCamelLowerUnderscore $.Model.Name }}.Find")
	defer span.Finish()

	query.Select({{ FormatCamelLower $.Model.Name }}MapColumns)
	query.From({{ FormatCamelLower $.Model.Name }}TableName)
{{ if $hasArchived }}
	if !includedArchived {
		query.Where(query.IsNull("archived_at"))
	}
{{ end }}
	// Check to see if a sub query needs to be applied for the claims.
	if err := applyClaimsSelect(ctx, claims, query); err != nil {
		return nil, err
	}

	queryStr, queryArgs := query.Build()
	queryStr = dbConn.Rebind(queryStr)
	args = append(args, queryArgs...)

	// Fetch all entries from the db.
	rows, err := dbConn.QueryContext(ctx, queryStr, args...)
	if err != nil {
		err = errors.Wrapf(err, "query - %s", query.String())
		err = errors.WithMessage(err, "find {{ FormatCamelPluralTitleLower $.Model.Name }} failed")
		return nil, err
	}
	// Release the underlying connection on every return path.
	defer rows.Close()

	// Iterate over each row.
	resp := []*{{ $.Model.Name }}{}
	for rows.Next() {
		u, err := mapRowsTo{{ $.Model.Name }}(rows)
		if err != nil {
			err = errors.Wrapf(err, "query - %s", query.String())
			return nil, err
		}
		resp = append(resp, u)
	}

	// rows.Next returns false on both exhaustion and failure; surface the latter.
	if err := rows.Err(); err != nil {
		err = errors.Wrapf(err, "query - %s", query.String())
		err = errors.WithMessage(err, "find {{ FormatCamelPluralTitleLower $.Model.Name }} failed")
		return nil, err
	}

	return resp, nil
}

// Read gets the specified {{ FormatCamelLowerTitle $.Model.Name }} from the database.
// It returns ErrNotFound (annotated with the requested key) when no row matches
// or the claims are not allowed to see it.
func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, {{ FormatCamelLower $.Model.PrimaryField }} {{ $.Model.PrimaryType }}{{ if $hasArchived }}, includedArchived bool{{ end }}) (*{{ $.Model.Name }}, error) {
	span, ctx := tracer.StartSpanFromContext(ctx, "internal.{{ FormatCamelLowerUnderscore $.Model.Name }}.Read")
	defer span.Finish()

	// Filter base select query by {{ FormatCamelLower $.Model.PrimaryField }}
	query := selectQuery()
	query.Where(query.Equal("{{ $.Model.PrimaryColumn }}", {{ FormatCamelLower $.Model.PrimaryField }}))

	res, err := find(ctx, claims, dbConn, query, []interface{}{} {{ if $hasArchived }}, includedArchived{{ end }})
	if err != nil {
		return nil, err
	}
	if len(res) == 0 {
		err = errors.WithMessagef(ErrNotFound, "{{ FormatCamelLowerTitle $.Model.Name }} %s not found", {{ FormatCamelLower $.Model.PrimaryField }})
		return nil, err
	}

	return res[0], nil
}
{{ end }}
{{ define "Create"}}
{{ $hasAccountID := (StringListHasValue $.Model.ColumnNames "account_id") }}
{{ $reqName := (Concat $.Model.Name "CreateRequest") }}
{{ $createFields := (index $.StructFields $reqName) }}
{{ $reqHasAccountID := false }}{{ $reqAccountID := (index $createFields "AccountID") }}{{ if $reqAccountID }}{{ $reqHasAccountID = true }}{{ end }}
// Create inserts a new {{ FormatCamelLowerTitle $.Model.Name }} into the database.
//
// When claims carry an audience the caller must have the admin role{{ if $reqHasAccountID }} and
// req.AccountID, if set, must equal that audience (it defaults to it when empty){{ end }}.
// The request is validated before the insert. A zero now is replaced by the
// current time; the timestamp is stored as UTC truncated to milliseconds.
// It returns the inserted {{ $.Model.Name }}, or a nil model and the error.
func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req {{ $reqName }}, now time.Time) (*{{ $.Model.Name }}, error) {
	span, ctx := tracer.StartSpanFromContext(ctx, "internal.{{ FormatCamelLowerUnderscore $.Model.Name }}.Create")
	defer span.Finish()

	if claims.Audience != "" {
		// Only admin users can create {{ FormatCamelPluralTitleLower $.Model.Name }}.
		if !claims.HasRole(auth.RoleAdmin) {
			return nil, errors.WithStack(ErrForbidden)
		}
{{ if $reqHasAccountID }}
		if req.AccountID != "" {
			// Request accountId must match claims.
			if req.AccountID != claims.Audience {
				return nil, errors.WithStack(ErrForbidden)
			}
		} else {
			// Set the accountId from claims.
			req.AccountID = claims.Audience
		}
{{ end }}
	}

	// Validate the request.
	v := validator.New()
	if err := v.Struct(req); err != nil {
		return nil, err
	}

	// If now empty set it to the current time.
	if now.IsZero() {
		now = time.Now()
	}

	// Always store the time as UTC.
	now = now.UTC()

	// Postgres truncates times to milliseconds when storing. We do the same
	// here so the value we return is consistent with what we store.
	now = now.Truncate(time.Millisecond)

	m := {{ $.Model.Name }}{
		{{ range $mk, $mf := $.Model.Fields }}{{ $cf := (index $createFields $mf.FieldName) }}
		{{ if eq $mf.FieldName $.Model.PrimaryField }}{{ $isUuid := (FieldTagHasOption $mf "validate" "uuid") }}{{ $mf.FieldName }}: {{ if $isUuid }}uuid.NewRandom().String(){{ else }}req.{{ $mf.FieldName }}{{ end }},
		{{ else if or (eq $mf.FieldName "CreatedAt") (eq $mf.FieldName "UpdatedAt") }}{{ $mf.FieldName }}: now,
		{{ else if $cf }}{{ $required := (FieldTagHasOption $cf "validate" "required") }}{{ if $required }}{{ $cf.FieldName }}: req.{{ $cf.FieldName }},{{ else if ne $cf.DefaultValue "" }}{{ $cf.FieldName }}: {{ $cf.DefaultValue }},{{ end }}
		{{ end }}{{ end }}
	}
{{ if and (not $reqHasAccountID) ($hasAccountID) }}
	{{/* NOTE(review): account_id is only part of the insert columns below when the request struct declares the field - confirm whether it should be persisted in this branch. */}}
	// Set the accountId from claims. The request type has no AccountID field
	// in this case, so the value is written to the model directly.
	if claims.Audience != "" && m.AccountID == "" {
		m.AccountID = claims.Audience
	}
{{ end }}
	{{ range $fk, $f := $createFields }}{{ $required := (FieldTagHasOption $f "validate" "required") }}{{ if not $required }}
	if req.{{ $f.FieldName }} != nil {
		{{ if eq $f.FieldType "sql.NullString" }}
		m.{{ $f.FieldName }} = sql.NullString{String: *req.{{ $f.FieldName }}, Valid: true}
		{{ else if eq $f.FieldType "*sql.NullString" }}
		m.{{ $f.FieldName }} = &sql.NullString{String: *req.{{ $f.FieldName }}, Valid: true}
		{{ else }}
		m.{{ $f.FieldName }} = *req.{{ $f.FieldName }}
		{{ end }}
	}
	{{ end }}{{ end }}

	// Build the insert SQL statement.
	query := sqlbuilder.NewInsertBuilder()
	query.InsertInto({{ FormatCamelLower $.Model.Name }}TableName)
	query.Cols(
		{{ range $mk, $mf := $.Model.Fields }}{{ $cf := (index $createFields $mf.FieldName) }}{{ if or (eq $mf.FieldName $.Model.PrimaryField) ($cf) (eq $mf.FieldName "CreatedAt") (eq $mf.FieldName "UpdatedAt") }}"{{ $mf.ColumnName }}",
		{{ end }}{{ end }}
	)
	query.Values(
		{{ range $mk, $mf := $.Model.Fields }}{{ $cf := (index $createFields $mf.FieldName) }}{{ if or (eq $mf.FieldName $.Model.PrimaryField) ($cf) (eq $mf.FieldName "CreatedAt") (eq $mf.FieldName "UpdatedAt") }}m.{{ $mf.FieldName }},
		{{ end }}{{ end }}
	)

	// Execute the query with the provided context. The statement is held in
	// stmt rather than sql so the database/sql package is not shadowed.
	stmt, args := query.Build()
	stmt = dbConn.Rebind(stmt)
	if _, err := dbConn.ExecContext(ctx, stmt, args...); err != nil {
		err = errors.Wrapf(err, "query - %s", query.String())
		err = errors.WithMessage(err, "create {{ FormatCamelLowerTitle $.Model.Name }} failed")
		return nil, err
	}

	return &m, nil
}
{{ end }}
{{ define "Update"}}
{{ $reqName := (Concat $.Model.Name "UpdateRequest") }}
{{ $updateFields := (index $.StructFields $reqName) }}
// Update replaces an {{ FormatCamelLowerTitle $.Model.Name }} in the database.
//
// The request is validated first, then the claims must pass
// CanModify{{ $.Model.Name }}. Only request fields that are non-nil are written; when
// none are set the call returns nil without touching the database. A zero now
// is replaced by the current time; the timestamp is UTC truncated to milliseconds.
func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req {{ $reqName }}, now time.Time) error {
	span, ctx := tracer.StartSpanFromContext(ctx, "internal.{{ FormatCamelLowerUnderscore $.Model.Name }}.Update")
	defer span.Finish()

	// Validate the request.
	v := validator.New()
	err := v.Struct(req)
	if err != nil {
		return err
	}

	// Ensure the claims can modify the {{ FormatCamelLowerTitle $.Model.Name }} specified in the request.
	err = CanModify{{ $.Model.Name }}(ctx, claims, dbConn, req.{{ $.Model.PrimaryField }})
	if err != nil {
		return err
	}

	// If now empty set it to the current time.
	if now.IsZero() {
		now = time.Now()
	}

	// Always store the time as UTC.
	now = now.UTC()

	// Postgres truncates times to milliseconds when storing. We do the same
	// here so the value we return is consistent with what we store.
	now = now.Truncate(time.Millisecond)

	// Build the update SQL statement.
	query := sqlbuilder.NewUpdateBuilder()
	query.Update({{ FormatCamelLower $.Model.Name }}TableName)

	var fields []string
	{{ range $mk, $mf := $.Model.Fields }}{{ $uf := (index $updateFields $mf.FieldName) }}{{ if and ($uf.FieldName) (ne $uf.FieldName $.Model.PrimaryField) }}
	{{ $optional := (FieldTagHasOption $uf "validate" "omitempty") }}{{ $isUuid := (FieldTagHasOption $uf "validate" "uuid") }}
	if req.{{ $uf.FieldName }} != nil {
		{{ if and ($optional) ($isUuid) }}
		// An empty optional uuid clears the column instead of storing "".
		if *req.{{ $uf.FieldName }} != "" {
			fields = append(fields, query.Assign("{{ $uf.ColumnName }}", req.{{ $uf.FieldName }}))
		} else {
			fields = append(fields, query.Assign("{{ $uf.ColumnName }}", nil))
		}
		{{ else }}
		fields = append(fields, query.Assign("{{ $uf.ColumnName }}", req.{{ $uf.FieldName }}))
		{{ end }}
	}
	{{ end }}{{ end }}

	// If there's nothing to update we can quit early.
	if len(fields) == 0 {
		return nil
	}

	{{ $hasUpdatedAt := (StringListHasValue $.Model.ColumnNames "updated_at") }}{{ if $hasUpdatedAt }}
	// Append the updated_at field
	fields = append(fields, query.Assign("updated_at", now))
	{{ end }}

	query.Set(fields...)
	query.Where(query.Equal("{{ $.Model.PrimaryColumn }}", req.{{ $.Model.PrimaryField }}))

	// Execute the query with the provided context. The statement is held in
	// stmt rather than sql so the database/sql package is not shadowed.
	stmt, args := query.Build()
	stmt = dbConn.Rebind(stmt)
	_, err = dbConn.ExecContext(ctx, stmt, args...)
	if err != nil {
		err = errors.Wrapf(err, "query - %s", query.String())
		err = errors.WithMessagef(err, "update {{ FormatCamelLowerTitle $.Model.Name }} %s failed", req.{{ $.Model.PrimaryField }})
		return err
	}

	return nil
}
{{ end }}
{{ define "Archive"}}
// Archive soft deletes the {{ FormatCamelLowerTitle $.Model.Name }} from the database by setting archived_at.
//
// The key must be a non-empty uuid and the claims must pass CanModify{{ $.Model.Name }}.
// A zero now is replaced by the current time; the timestamp is stored as UTC
// truncated to milliseconds.
func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, {{ FormatCamelLower $.Model.PrimaryField }} {{ $.Model.PrimaryType }}, now time.Time) error {
	span, ctx := tracer.StartSpanFromContext(ctx, "internal.{{ FormatCamelLowerUnderscore $.Model.Name }}.Archive")
	defer span.Finish()

	// Defines the struct to apply validation
	req := struct {
		{{ $.Model.PrimaryField }} {{ $.Model.PrimaryType }} `validate:"required,uuid"`
	}{
		{{ $.Model.PrimaryField }}: {{ FormatCamelLower $.Model.PrimaryField }},
	}

	// Validate the request.
	err := validator.New().Struct(req)
	if err != nil {
		return err
	}

	// Ensure the claims can modify the {{ FormatCamelLowerTitle $.Model.Name }} specified in the request.
	// The field is referenced through the model's primary field name so this
	// also compiles for models whose key is not called ID.
	err = CanModify{{ $.Model.Name }}(ctx, claims, dbConn, req.{{ $.Model.PrimaryField }})
	if err != nil {
		return err
	}

	// If now empty set it to the current time.
	if now.IsZero() {
		now = time.Now()
	}

	// Always store the time as UTC.
	now = now.UTC()

	// Postgres truncates times to milliseconds when storing. We do the same
	// here so the value we return is consistent with what we store.
	now = now.Truncate(time.Millisecond)

	// Build the update SQL statement.
	query := sqlbuilder.NewUpdateBuilder()
	query.Update({{ FormatCamelLower $.Model.Name }}TableName)
	query.Set(
		query.Assign("archived_at", now),
	)
	query.Where(query.Equal("{{ $.Model.PrimaryColumn }}", req.{{ $.Model.PrimaryField }}))

	// Execute the query with the provided context. The statement is held in
	// stmt rather than sql so the database/sql package is not shadowed.
	stmt, args := query.Build()
	stmt = dbConn.Rebind(stmt)
	_, err = dbConn.ExecContext(ctx, stmt, args...)
	if err != nil {
		err = errors.Wrapf(err, "query - %s", query.String())
		err = errors.WithMessagef(err, "archive {{ FormatCamelLowerTitle $.Model.Name }} %s failed", req.{{ $.Model.PrimaryField }})
		return err
	}

	return nil
}
{{ end }}
{{ define "Delete"}}
// Delete removes an {{ FormatCamelLowerTitle $.Model.Name }} from the database.
//
// The key must be a non-empty uuid and the claims must pass CanModify{{ $.Model.Name }}
// before the row is deleted.
func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, {{ FormatCamelLower $.Model.PrimaryField }} {{ $.Model.PrimaryType }}) error {
	span, ctx := tracer.StartSpanFromContext(ctx, "internal.{{ FormatCamelLowerUnderscore $.Model.Name }}.Delete")
	defer span.Finish()

	// Defines the struct to apply validation
	req := struct {
		{{ $.Model.PrimaryField }} {{ $.Model.PrimaryType }} `validate:"required,uuid"`
	}{
		{{ $.Model.PrimaryField }}: {{ FormatCamelLower $.Model.PrimaryField }},
	}

	// Validate the request.
	err := validator.New().Struct(req)
	if err != nil {
		return err
	}

	// Ensure the claims can modify the {{ FormatCamelLowerTitle $.Model.Name }} specified in the request.
	// The field is referenced through the model's primary field name so this
	// also compiles for models whose key is not called ID.
	err = CanModify{{ $.Model.Name }}(ctx, claims, dbConn, req.{{ $.Model.PrimaryField }})
	if err != nil {
		return err
	}

	// Build the delete SQL statement.
	query := sqlbuilder.NewDeleteBuilder()
	query.DeleteFrom({{ FormatCamelLower $.Model.Name }}TableName)
	query.Where(query.Equal("{{ $.Model.PrimaryColumn }}", req.{{ $.Model.PrimaryField }}))

	// Execute the query with the provided context. The statement is held in
	// stmt rather than sql so the database/sql package is not shadowed.
	stmt, args := query.Build()
	stmt = dbConn.Rebind(stmt)
	_, err = dbConn.ExecContext(ctx, stmt, args...)
	if err != nil {
		err = errors.Wrapf(err, "query - %s", query.String())
		err = errors.WithMessagef(err, "delete {{ FormatCamelLowerTitle $.Model.Name }} %s failed", req.{{ $.Model.PrimaryField }})
		return err
	}

	return nil
}
{{ end }}