You've already forked golang-saas-starter-kit
mirror of
https://github.com/raseels-repos/golang-saas-starter-kit.git
synced 2025-07-03 00:58:13 +02:00
Merge branch 'master' of gitlab.com:geeks-accelerator/oss/saas-starter-kit into issue8/datadog-lambda-func
This commit is contained in:
140
README.md
140
README.md
@ -4,9 +4,7 @@ Copyright 2019, Geeks Accelerator
|
||||
twins@geeksaccelerator.com
|
||||
|
||||
|
||||
## Description
|
||||
|
||||
The SaaS Starter Kit is a set of libraries for building scalable software-as-a-service (SaaS) applications while
|
||||
The SaaS Starter Kit is a set of libraries for building scalable software-as-a-service (SaaS) applications that helps
|
||||
preventing both misuse and fraud. The goal of this project is to provide a proven starting point for new
|
||||
projects that reduces the repetitive tasks in getting a new project launched to production that can easily be scaled
|
||||
and ready to onboard enterprise clients. It uses minimal dependencies, implements idiomatic code and follows Golang
|
||||
@ -18,9 +16,8 @@ This project should not be considered a web framework. It is a starter toolkit t
|
||||
to handle some of the common challenges for developing SaaS using Golang. Coding is a discovery process and with that,
|
||||
it leaves you in control of your project’s architecture and development.
|
||||
|
||||
SaaS product offerings typically provide two main components: an API and a web application. Both facilitate delivering a
|
||||
valuable software based product to clients ideally from a single code base on a recurring basis delivered over the
|
||||
internet.
|
||||
SaaS product offerings generally provide a web-based service using a subscription model. They typically provide at
|
||||
least two main components: a REST API and a web application.
|
||||
|
||||
To see screen captures of the web app and auto-generated API documentation, check out this Google Slides deck:
|
||||
https://docs.google.com/presentation/d/1WGYqMZ-YUOaNxlZBfU4srpN8i86MU0ppWWSBb3pkejM/edit#slide=id.p
|
||||
@ -30,27 +27,55 @@ https://docs.google.com/presentation/d/1WGYqMZ-YUOaNxlZBfU4srpN8i86MU0ppWWSBb3pk
|
||||
[](https://docs.google.com/presentation/d/1WGYqMZ-YUOaNxlZBfU4srpN8i86MU0ppWWSBb3pkejM/edit#slide=id.p)
|
||||
|
||||
|
||||
There are five areas of expertise that an engineer or her engineering team must do for a project to grow and scale.
|
||||
Based on our experience, a few core decisions were made for each of these areas that help you focus initially on writing
|
||||
the business logic.
|
||||
1. Micro level - Since SaaS requires transactions, project implements Postgres. Implementation facilitates the data
|
||||
semantics that define the data being captured and their relationships.
|
||||
2. Macro level - The project architecture and design, provides basic project structure and foundation for development.
|
||||
3. Business logic - Defines an example Golang package that helps illustrate where value generating activities should
|
||||
reside and how the code will be delivered to clients.
|
||||
4. Deployment and Operations - Integrates with GitLab for CI/CD and AWS for serverless deployments with AWS Fargate.
|
||||
5. Observability - Implements Datadog to facilitate exposing metrics, logs and request tracing that ensure stable and
|
||||
responsive service for clients.
|
||||
## Motivation
|
||||
|
||||
When getting started building SaaS, we believe that is important for both the frontend web experience and the backend
|
||||
business logic (business value) be developed in the same codebase - using the same language for the frontend and backend
|
||||
development in the same single repository. We believe this for two main reasons:
|
||||
1. Lower barrier for and accelerate onboarding of new engineers developing the SaaS by making it easy for them
|
||||
to load a complete mental model of the codebase.
|
||||
2. Minimize potential bottlenecks and eliminate complexities of coordinating development across repositories, with
|
||||
potentially different teams responsible for the different repositories.
|
||||
|
||||
Once the SaaS product has gained market traction and the core set of functionality has been identified to achieve
|
||||
product-market fit, the functionality could be re-written with a language that would improve user experience or
|
||||
further increase efficiency. Two good examples of this would be:
|
||||
1. Developing an iPhone or Android app. The front end web application provided by this project is responsive
|
||||
to support mobile devices. However, there may be a point that developing native would provide an enhanced experience.
|
||||
2. The backend business logic has a set of methods that handle small data transformations on a massive scale. If the code
|
||||
for this is relatively small and can easily be rewritten, it might make sense to rewrite this directly in C or Rust.
|
||||
This is a very rare case as GoLang is already a preformat language.
|
||||
|
||||
There are five areas of expertise that an engineer or engineering team must do for a project to grow and scale.
|
||||
Based on our experience, a few core decisions were made for each of these areas that help you focus initially on
|
||||
building the business logic.
|
||||
1. Micro level - The semantics that cover how data is defined, the relationships and how the data is being captured. This
|
||||
project tries to minimize the connection between packages on the same horizontally later. Data models should not be part
|
||||
of feature functionality. Hopefully these micro level decisions help prevent cases where 30K lines of code rely on a
|
||||
single data model which makes simple one line changes potentially high risk.
|
||||
2. Macro level - The architecture and its design provides basic project structure and the foundation for development.
|
||||
This project provides a good set of examples that demonstrate where different types of code can reside.
|
||||
3. Business logic - The code for the business logic facilitates value generating activities for the business. This
|
||||
project provides an example Golang package that helps illustrate the implementation of business logic and how it can be
|
||||
delivered to clients.
|
||||
4. Deployment and Operations - Get the code to production! This sometimes can be a challenging task as it requires
|
||||
a knowledge of a completely different expertise - DevOps. This project provides a complete continuous build pipeline that
|
||||
will push the code to production with minimal effort using serverless deployments to AWS Fargate with GitLab CI/CD.
|
||||
5. Observability - Ensure the code is running as expected in a remote environment. This project implements Datadog to
|
||||
facilitate exposing metrics, logs and request tracing to obversabe and validate your services are stable and responsive
|
||||
for your clients (hopefully paying clients).
|
||||
|
||||
|
||||
## Description
|
||||
|
||||
The example project is a complete starter kit for building SasS with GoLang. It provides two example services:
|
||||
* Web App - Responsive web application to provide service to clients. Includes user signup and user authentication for
|
||||
direct client interaction via their web connected devices.
|
||||
direct client interaction via their web browsers.
|
||||
* Web API - REST API with JWT authentication that renders results as JSON. This allows clients and other third-pary companies to develop deep
|
||||
integrations with the project.
|
||||
|
||||
And these tools:
|
||||
* Schema - Initializing of Postgres database and handles schema migration.
|
||||
The example project also provides these tools:
|
||||
* Schema - Creating, initializing tables of Postgres database and handles schema migration.
|
||||
* Dev Ops - Deploying project to AWS with GitLab CI/CD.
|
||||
|
||||
It contains the following features:
|
||||
@ -58,17 +83,17 @@ It contains the following features:
|
||||
* Auto-documented REST API.
|
||||
* Middleware integration.
|
||||
* Database support using Postgres.
|
||||
* Key value store using Redis
|
||||
* Cache and key value store using Redis
|
||||
* CRUD based pattern.
|
||||
* Role-based access control (RBAC).
|
||||
* Account signup and user management.
|
||||
* Distributed logging and tracing.
|
||||
* Integration with Datadog for enterprise-level observability.
|
||||
* Testing patterns.
|
||||
* Use of Docker, Docker Compose, and Makefiles.
|
||||
* Build, deploy and run application using Docker, Docker Compose, and Makefiles.
|
||||
* Vendoring dependencies with Modules, requires Go 1.12 or higher.
|
||||
* Continuous deployment pipeline.
|
||||
* Serverless deployments.
|
||||
* Serverless deployments with AWS ECS Fargate.
|
||||
* CLI with boilerplate templates to reduce repetitive copy/pasting.
|
||||
* Integration with GitLab for enterprise-level CI/CD.
|
||||
|
||||
@ -81,9 +106,8 @@ Accordingly, the project architecture is illustrated with the following diagram.
|
||||
With SaaS, a client subscribes to an online service you provide them. The example project provides functionality for
|
||||
clients to subscribe and then once subscribed they can interact with your software service.
|
||||
|
||||
The initial contributors to this project are building this saas-starter-kit based on their years of experience building
|
||||
building enterprise B2B SaaS. Particularily, this saas-starter-kit is based on their most recent experience building the
|
||||
B2B SaaS for [standard operating procedure software](https://keeni.space) (100% Golang). Reference the Keeni.Space website,
|
||||
The initial contributors to this project are building this saas-starter-kit based on their years of experience building enterprise B2B SaaS. Particularily, this saas-starter-kit is based on their most recent experience building the
|
||||
B2B SaaS for [standard operating procedure software](https://keeni.space) (written entirely in Golang). Please refer to the Keeni.Space website,
|
||||
its [SOP software pricing](https://keeni.space/pricing) and its signup process. The SaaS web app is then available at
|
||||
[app.keeni.space](https://app.keeni.space). They plan on leveraging this experience and build it into a simplified set
|
||||
example services for both a web API and a web app for SaaS businesses.
|
||||
@ -147,8 +171,8 @@ need to be using Go >= 1.11.
|
||||
You should now be able to clone the project.
|
||||
|
||||
```bash
|
||||
git clone git@gitlab.com:geeks-accelerator/oss/saas-starter-kit.git
|
||||
cd saas-starter-kit/
|
||||
$ git clone git@gitlab.com:geeks-accelerator/oss/saas-starter-kit.git
|
||||
$ cd saas-starter-kit/
|
||||
```
|
||||
|
||||
If you have Go Modules enabled, you should be able compile the project locally. If you have Go Modulels disabled, see
|
||||
@ -162,13 +186,13 @@ This project is using Go Module support for vendoring dependencies.
|
||||
We are using the `tidy` command to maintain the dependencies and make sure the project can create reproducible builds.
|
||||
|
||||
```bash
|
||||
GO111MODULE=on go mod tidy
|
||||
$ GO111MODULE=on go mod tidy
|
||||
```
|
||||
|
||||
It is recommended to use at least Go 1.12 and enable go modules.
|
||||
|
||||
```bash
|
||||
echo "export GO111MODULE=on" >> ~/.bash_profile
|
||||
$ echo "export GO111MODULE=on" >> ~/.bash_profile
|
||||
```
|
||||
|
||||
|
||||
@ -184,8 +208,8 @@ https://docs.docker.com/install/
|
||||
There is a `docker-compose` file that knows how to build and run all the services. Each service has its own a
|
||||
`dockerfile`.
|
||||
|
||||
When you run `docker-compose up` it will run all the services including the main.go file for each Go service. The
|
||||
services the project will run are:
|
||||
Before using `docker-compose`, you need to copy `sample.env_docker_compose` to `.env_docker_compose` that docker will use. When you run `docker-compose up` it will run all the services including the main.go file for each Go service. The
|
||||
following services will run:
|
||||
- web-api
|
||||
- web-app
|
||||
- postgres
|
||||
@ -198,7 +222,8 @@ Use the `docker-compose.yaml` to run all of the services, including the 3rd part
|
||||
command, Docker will download the required images for the 3rd party services.
|
||||
|
||||
```bash
|
||||
$ docker-compose up --build
|
||||
$ cp sample.env_docker_compose .env_docker_compose
|
||||
$ docker-compose up
|
||||
```
|
||||
|
||||
Default configuration is set which should be valid for most systems.
|
||||
@ -208,7 +233,7 @@ Use the `docker-compose.yaml` file to configure the services differently using e
|
||||
#### How we run the project
|
||||
|
||||
We like to run the project where the services run in the background of our CLI. This can be done by using the -d with
|
||||
the `docker-compose up` command:
|
||||
the `docker-compose up --build` command:
|
||||
```bash
|
||||
$ docker-compose up --build -d
|
||||
```
|
||||
@ -226,7 +251,7 @@ $ docker-compose logs -f
|
||||
|
||||
### Stopping the project
|
||||
|
||||
You can hit <ctrl>C in the terminal window running `docker-compose up`.
|
||||
You can hit `ctrl-C` in the terminal window that ran `docker-compose up`.
|
||||
|
||||
Once that shutdown sequence is complete, it is important to run the `docker-compose down` command.
|
||||
|
||||
@ -251,7 +276,7 @@ services again with 'docker-compose up'.
|
||||
To restart a specific service, first use `docker ps` to see the list of services running.
|
||||
|
||||
```bash
|
||||
docker ps
|
||||
$ docker ps
|
||||
CONTAINER ID IMAGE COMMAND NAMES
|
||||
35043164fd0d example-project/web-api:latest "/gosrv" saas-starter-kit_web-api_1
|
||||
d34c8fc27f3b example-project/web-app:latest "/gosrv" saas-starter-kit_web-app_1
|
||||
@ -259,23 +284,23 @@ fd844456243e postgres:11-alpine "docker-entrypoint.s…"
|
||||
dda16bfbb8b5 redis:latest "redis-server --appe…" saas-starter-kit_redis_1
|
||||
```
|
||||
|
||||
Then use `docker-compose down` for a specific service. In the command include the name of the container for the service
|
||||
to shut down. In the example command, we will shut down down the web-api service so we can start it again.
|
||||
Then use `docker-compose stop` for a specific service. In the command including the name of service in `docker-compose.yaml` file for the service
|
||||
to shut down. In the example command, we will shut down the web-api service so we can start it again.
|
||||
|
||||
```bash
|
||||
docker-compose down saas-starter-kit_web-api_1
|
||||
$ docker-compose stop web-app
|
||||
```
|
||||
|
||||
If you are not in the directory for the service you want to restart navigate to it. We will go to the directory for the
|
||||
If you are not in the directory for the service you want to restart then navigate to it. We will go to the directory for the
|
||||
web-api.
|
||||
|
||||
```bash
|
||||
cd cmd/web-api/
|
||||
$ cd cmd/web-api/
|
||||
```
|
||||
|
||||
Then you can start the service again by running main.go
|
||||
```bash
|
||||
go run main.go
|
||||
$ go run main.go
|
||||
```
|
||||
|
||||
|
||||
@ -283,7 +308,7 @@ go run main.go
|
||||
|
||||
By default the project will compile and run without AWS configs or other third-party dependencies.
|
||||
|
||||
As you use start utilizing AWS services in this project and/or ready for deployment, you will need to start specifying
|
||||
As you start utilizing AWS services in this project and/or ready for deployment, you will need to start specifying
|
||||
AWS configs in a docker-compose file. You can also set credentials for other dependencies in the new docker-compose file
|
||||
too.
|
||||
|
||||
@ -291,7 +316,7 @@ The sample docker-compose file is not loaded since it is named sample, which all
|
||||
configs.
|
||||
|
||||
To set AWS configs and credentials for other third-party dependencies, you need to create a copy of the sample
|
||||
docker-compose file without "sample" prepending the file name.
|
||||
environment docker-compose file without "sample" prepending the file name.
|
||||
|
||||
Navigate to the root of the project. Copy `sample.env_docker_compose` to `.env_docker_compose`.
|
||||
|
||||
@ -311,20 +336,20 @@ $ DD_API_KEY=
|
||||
```
|
||||
|
||||
In your new copy of the example docker-compose file ".env_docker_compose", set the AWS configs by updating the following
|
||||
environmental variables: AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, and AWS_REGION. Remember to remove the $ before the
|
||||
environment variables: AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, and AWS_REGION. Remember to remove the $ before the
|
||||
variable name.
|
||||
|
||||
As noted in the Local Installation section, the project is integrated with Datadog for observability. You can specify
|
||||
the API key for you Datadog account by setting the environmental variable: DD_API_KEY.
|
||||
the API key for your Datadog account by setting the environment variable: DD_API_KEY.
|
||||
|
||||
|
||||
## Web API
|
||||
[cmd/web-api](https://gitlab.com/geeks-accelerator/oss/saas-starter-kit/tree/master/cmd/web-api)
|
||||
|
||||
REST API available to clients for supporting deeper integrations. This API is also a foundation for third-party
|
||||
REST API is available to clients for supporting deeper integrations. This API is also a foundation for third-party
|
||||
integrations. The API implements JWT authentication that renders results as JSON to clients.
|
||||
|
||||
Once the web-app service is running it will be available on port 3001.
|
||||
Once the web-api service is running it will be available on port 3001.
|
||||
http://127.0.0.1:3001/
|
||||
|
||||
This web-api service is not directly used by the web-app service to prevent locking the functionally required for
|
||||
@ -334,7 +359,7 @@ client expectations.
|
||||
|
||||
The web-app will have its own internal API, similar to this external web-api service, but not exposed for third-party
|
||||
integrations. It is believed that in the beginning, having to define an additional API for internal purposes is worth
|
||||
the additional effort as the internal API can handle more flexible updates.
|
||||
for the additional effort as the internal API can handle more flexible updates.
|
||||
|
||||
For more details on this service, read [web-api readme](https://gitlab.com/geeks-accelerator/oss/saas-starter-kit/blob/master/cmd/web-api/README.md)
|
||||
|
||||
@ -418,7 +443,11 @@ shared=# \dt
|
||||
|
||||
## Deployment
|
||||
|
||||
This project includes a complete build pipeline that relies on AWS and GitLab. The `.gitlab-ci.yaml` file includes the following build
|
||||
This project includes a complete build pipeline that relies on AWS and GitLab. The presentation "[SaaS Starter Kit - Setup GitLab CI / CD](https://docs.google.com/presentation/d/1sRFQwipziZlxBtN7xuF-ol8vtUqD55l_4GE-4_ns-qM/edit#slide=id.p)"
|
||||
has been made available on Google Docs that provides a step by step guide to setting up a build pipeline using your own
|
||||
AWS and GitLab accounts.
|
||||
|
||||
The `.gitlab-ci.yaml` file includes the following build
|
||||
stages:
|
||||
```yaml
|
||||
stages:
|
||||
@ -442,10 +471,19 @@ additional configuration. You can customizing any of the configuration in the co
|
||||
the saas-starter-kit, keeping the deployment in GoLang limits the scope of additional technologies required to get your
|
||||
project successfully up and running. If you understand Golang, then you will be a master at devops with this tool.
|
||||
|
||||
Refer to the [README](https://gitlab.com/geeks-accelerator/oss/saas-starter-kit/blob/master/tools/devops/README.md) for setup details.
|
||||
Refer to the [README](https://gitlab.com/geeks-accelerator/oss/saas-starter-kit/blob/master/tools/devops/README.md) for
|
||||
setup details.
|
||||
|
||||
|
||||
## Development Notes
|
||||
|
||||
### Country / Region / Postal Code Support
|
||||
|
||||
This project uses [geonames.org](https://www.geonames.org/) to populate database tables for countries, postal codes and
|
||||
timezones that help facilitate standardizing user input. To keep the schema script quick for `dev`, the postal codes for
|
||||
only country code `US` are loaded. This can be changed as needed in
|
||||
[geonames.go](https://gitlab.com/geeks-accelerator/oss/saas-starter-kit/blob/master/internal/geonames/geonames.go#L30).
|
||||
|
||||
### Datadog
|
||||
|
||||
Datadog has a custom init script to support setting multiple expvar urls for monitoring. The docker-compose file then
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext"
|
||||
"github.com/huandu/go-sqlbuilder"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/pkg/errors"
|
||||
@ -24,9 +25,12 @@ const (
|
||||
geonamesTableName = "geonames"
|
||||
)
|
||||
|
||||
var (
|
||||
// List of country codes that will geonames will be downloaded for.
|
||||
ValidGeonameCountries = []string{
|
||||
// List of country codes that will geonames will be downloaded for.
|
||||
func ValidGeonameCountries(ctx context.Context) []string {
|
||||
if webcontext.ContextEnv(ctx) == webcontext.Env_Dev {
|
||||
return []string{"US"}
|
||||
}
|
||||
return []string{
|
||||
"AD", "AR", "AS", "AT", "AU", "AX", "BD", "BE", "BG", "BM",
|
||||
"BR", "BY", "CA", "CH", "CO", "CR", "CZ", "DE", "DK", "DO",
|
||||
"DZ", "ES", "FI", "FO", "FR", "GB", "GF", "GG", "GL", "GP",
|
||||
@ -36,7 +40,7 @@ var (
|
||||
"PK", "PL", "PM", "PR", "PT", "RE", "RO", "RU", "SE", "SI",
|
||||
"SJ", "SK", "SM", "TH", "TR", "UA", "US", "UY", "VA", "VI",
|
||||
"WF", "YT", "ZA"}
|
||||
)
|
||||
}
|
||||
|
||||
// FindGeonames ....
|
||||
func FindGeonames(ctx context.Context, dbConn *sqlx.DB, orderBy, where string, args ...interface{}) ([]*Geoname, error) {
|
||||
@ -194,7 +198,7 @@ func LoadGeonames(ctx context.Context, rr chan<- interface{}, countries ...strin
|
||||
defer close(rr)
|
||||
|
||||
if len(countries) == 0 {
|
||||
countries = ValidGeonameCountries
|
||||
countries = ValidGeonameCountries(ctx)
|
||||
}
|
||||
|
||||
for _, country := range countries {
|
||||
|
@ -107,7 +107,11 @@ func authenticateSession(authenticator *auth.Authenticator, required bool) web.M
|
||||
|
||||
claims, err := authenticator.ParseClaims(tknStr)
|
||||
if err != nil {
|
||||
return weberror.NewError(ctx, err, http.StatusUnauthorized)
|
||||
if required {
|
||||
return weberror.NewError(ctx, err, http.StatusUnauthorized)
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Add claims to the context so they can be retrieved later.
|
||||
|
@ -1,6 +1,7 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
@ -8,7 +9,7 @@ import (
|
||||
|
||||
// initSchema runs before any migrations are executed. This happens when no other migrations
|
||||
// have previously been executed.
|
||||
func initSchema(db *sqlx.DB, log *log.Logger, isUnittest bool) func(*sqlx.DB) error {
|
||||
func initSchema(ctx context.Context, db *sqlx.DB, log *log.Logger, isUnittest bool) func(*sqlx.DB) error {
|
||||
f := func(db *sqlx.DB) error {
|
||||
return nil
|
||||
}
|
||||
|
@ -18,7 +18,7 @@ import (
|
||||
|
||||
// migrationList returns a list of migrations to be executed. If the id of the
|
||||
// migration already exists in the migrations table it will be skipped.
|
||||
func migrationList(db *sqlx.DB, log *log.Logger, isUnittest bool) []*sqlxmigrate.Migration {
|
||||
func migrationList(ctx context.Context, db *sqlx.DB, log *log.Logger, isUnittest bool) []*sqlxmigrate.Migration {
|
||||
return []*sqlxmigrate.Migration{
|
||||
// Create table users.
|
||||
{
|
||||
@ -213,7 +213,7 @@ func migrationList(db *sqlx.DB, log *log.Logger, isUnittest bool) []*sqlxmigrate
|
||||
},
|
||||
// Load new geonames table.
|
||||
{
|
||||
ID: "20190731-02b",
|
||||
ID: "20190731-02h",
|
||||
Migrate: func(tx *sql.Tx) error {
|
||||
|
||||
schemas := []string{
|
||||
@ -253,7 +253,7 @@ func migrationList(db *sqlx.DB, log *log.Logger, isUnittest bool) []*sqlxmigrate
|
||||
|
||||
} else {
|
||||
resChan := make(chan interface{})
|
||||
go geonames.LoadGeonames(context.Background(), resChan)
|
||||
go geonames.LoadGeonames(ctx, resChan)
|
||||
|
||||
for r := range resChan {
|
||||
switch v := r.(type) {
|
||||
@ -288,7 +288,7 @@ func migrationList(db *sqlx.DB, log *log.Logger, isUnittest bool) []*sqlxmigrate
|
||||
},
|
||||
// Load new countries table.
|
||||
{
|
||||
ID: "20190731-02d",
|
||||
ID: "20190731-02f",
|
||||
Migrate: func(tx *sql.Tx) error {
|
||||
|
||||
schemas := []string{
|
||||
@ -489,7 +489,7 @@ func migrationList(db *sqlx.DB, log *log.Logger, isUnittest bool) []*sqlxmigrate
|
||||
},
|
||||
// Load new country_timezones table.
|
||||
{
|
||||
ID: "20190731-03d",
|
||||
ID: "20190731-03e",
|
||||
Migrate: func(tx *sql.Tx) error {
|
||||
|
||||
queries := []string{
|
||||
|
@ -1,21 +1,22 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
|
||||
"github.com/geeks-accelerator/sqlxmigrate"
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
func Migrate(masterDb *sqlx.DB, log *log.Logger, isUnittest bool) error {
|
||||
func Migrate(ctx context.Context, masterDb *sqlx.DB, log *log.Logger, isUnittest bool) error {
|
||||
// Load list of Schema migrations and init new sqlxmigrate client
|
||||
migrations := migrationList(masterDb, log, isUnittest)
|
||||
migrations := migrationList(ctx, masterDb, log, isUnittest)
|
||||
m := sqlxmigrate.New(masterDb, sqlxmigrate.DefaultOptions, migrations)
|
||||
m.SetLogger(log)
|
||||
|
||||
// Append any schema that need to be applied if this is a fresh migration
|
||||
// ie. the migrations database table does not exist.
|
||||
m.InitSchema(initSchema(masterDb, log, isUnittest))
|
||||
m.InitSchema(initSchema(ctx, masterDb, log, isUnittest))
|
||||
|
||||
// Execute the migrations
|
||||
return m.Migrate()
|
||||
|
@ -105,10 +105,10 @@ instance will be a dedicated host since we need it always up and running, thus i
|
||||
Advanced Details: none
|
||||
```
|
||||
|
||||
4. Add Storage. Increase the volume size for the root device to 100 GiB.
|
||||
4. Add Storage. Increase the volume size for the root device to 30 GiB.
|
||||
```
|
||||
Volume Type | Device | Size (GiB) | Volume Type
|
||||
Root | /dev/xvda | 100 | General Purpose SSD (gp2)
|
||||
Root | /dev/xvda | 30 | General Purpose SSD (gp2)
|
||||
```
|
||||
|
||||
5. Add Tags.
|
||||
@ -127,7 +127,7 @@ instance will be a dedicated host since we need it always up and running, thus i
|
||||
|
||||
7. Review and Launch instance. Select an existing key pair or create a new one. This will be used to SSH into the
|
||||
instance for additional configuration.
|
||||
|
||||
|
||||
8. Update the security group to reference itself. The instances need to be able to communicate between each other.
|
||||
|
||||
Navigate to edit the security group and add the following two rules where `SECURITY_GROUP_ID` is replaced with the
|
||||
|
@ -66,7 +66,7 @@ type ServiceDeployFlags struct {
|
||||
DockerFile string `validate:"omitempty" example:"./cmd/web-api/Dockerfile"`
|
||||
EnableLambdaVPC bool `validate:"omitempty" example:"false"`
|
||||
EnableEcsElb bool `validate:"omitempty" example:"false"`
|
||||
IsLambda bool `validate:"omitempty" example:"false"`
|
||||
IsLambda bool `validate:"omitempty" example:"false"`
|
||||
|
||||
StaticFilesS3Enable bool `validate:"omitempty" example:"false"`
|
||||
StaticFilesImgResizeEnable bool `validate:"omitempty" example:"false"`
|
||||
@ -130,6 +130,8 @@ type deployEcsServiceRequest struct {
|
||||
VpcPublic *ec2.CreateVpcInput
|
||||
VpcPublicSubnets []*ec2.CreateSubnetInput
|
||||
|
||||
EnableLambdaVPC bool `validate:"omitempty"`
|
||||
IsLambda bool `validate:"omitempty"`
|
||||
RecreateService bool `validate:"omitempty"`
|
||||
|
||||
SDNamepsace *servicediscovery.CreatePrivateDnsNamespaceInput
|
||||
@ -189,7 +191,7 @@ func NewServiceDeployRequest(log *log.Logger, flags ServiceDeployFlags) (*servic
|
||||
S3BucketPrivateName: flags.S3BucketPrivateName,
|
||||
S3BucketPublicName: flags.S3BucketPublicName,
|
||||
|
||||
IsLambda: flags.IsLambda,
|
||||
IsLambda: flags.IsLambda,
|
||||
EnableLambdaVPC: flags.EnableLambdaVPC,
|
||||
EnableEcsElb: flags.EnableEcsElb,
|
||||
RecreateService: flags.RecreateService,
|
||||
@ -436,7 +438,7 @@ func NewServiceDeployRequest(log *log.Logger, flags ServiceDeployFlags) (*servic
|
||||
|
||||
if req.IsLambda {
|
||||
|
||||
} else {
|
||||
} else {
|
||||
|
||||
}
|
||||
|
||||
@ -856,12 +858,30 @@ func NewServiceDeployRequest(log *log.Logger, flags ServiceDeployFlags) (*servic
|
||||
log.Printf("\t%s\tDefaults set.", tests.Success)
|
||||
}
|
||||
|
||||
r, err := regexp.Compile(`^(\d+)`)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
// Workaround for domains that start with a numeric value like 8north.com
|
||||
// Validation fails with error: failed on the 'fqdn' tag
|
||||
origServiceHostPrimary := req.ServiceHostPrimary
|
||||
matches := r.FindAllString(req.ServiceHostPrimary, -1)
|
||||
if len(matches) > 0 {
|
||||
for _, m := range matches {
|
||||
req.ServiceHostPrimary = strings.Replace(req.ServiceHostPrimary, m, "X", -1)
|
||||
}
|
||||
}
|
||||
|
||||
log.Println("\tValidate request.")
|
||||
errs := validator.New().Struct(req)
|
||||
if errs != nil {
|
||||
return nil, errs
|
||||
}
|
||||
|
||||
// Reset the primary domain after validation is completed.
|
||||
req.ServiceHostPrimary = origServiceHostPrimary
|
||||
|
||||
log.Printf("\t%s\tNew request generated.", tests.Success)
|
||||
}
|
||||
|
||||
|
@ -1,11 +1,14 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"expvar"
|
||||
"geeks-accelerator/oss/saas-starter-kit/internal/platform/web/webcontext"
|
||||
"log"
|
||||
"net/url"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/internal/platform/flag"
|
||||
"geeks-accelerator/oss/saas-starter-kit/internal/schema"
|
||||
@ -125,8 +128,16 @@ func main() {
|
||||
// =========================================================================
|
||||
// Start Migrations
|
||||
|
||||
// Set the context with the required values to
|
||||
// process the request.
|
||||
v := webcontext.Values{
|
||||
Now: time.Now(),
|
||||
Env: cfg.Env,
|
||||
}
|
||||
ctx := context.WithValue(context.Background(), webcontext.KeyValues, &v)
|
||||
|
||||
// Execute the migrations
|
||||
if err = schema.Migrate(masterDb, log, false); err != nil {
|
||||
if err = schema.Migrate(ctx, masterDb, log, false); err != nil {
|
||||
log.Fatalf("main : Migrate : %v", err)
|
||||
}
|
||||
log.Printf("main : Migrate : Completed")
|
||||
|
1
tools/truss/.gitignore
vendored
1
tools/truss/.gitignore
vendored
@ -1 +0,0 @@
|
||||
truss
|
@ -1,68 +0,0 @@
|
||||
# SaaS Truss
|
||||
|
||||
Copyright 2019, Geeks Accelerator
|
||||
accelerator@geeksinthewoods.com.com
|
||||
|
||||
|
||||
## Description
|
||||
|
||||
Truss provides code generation to reduce copy/pasting.
|
||||
|
||||
|
||||
## Local Installation
|
||||
|
||||
### Build
|
||||
```bash
|
||||
go build .
|
||||
```
|
||||
|
||||
### Configuration
|
||||
```bash
|
||||
./truss -h
|
||||
|
||||
Usage of ./truss
|
||||
--cmd string <dbtable2crud>
|
||||
--db_host string <127.0.0.1:5433>
|
||||
--db_user string <postgres>
|
||||
--db_pass string <postgres>
|
||||
--db_database string <shared>
|
||||
--db_driver string <postgres>
|
||||
--db_timezone string <utc>
|
||||
--db_disabletls bool <false>
|
||||
```
|
||||
|
||||
## Commands:
|
||||
|
||||
## dbtable2crud
|
||||
|
||||
Used to bootstrap a new business logic package with basic CRUD.
|
||||
|
||||
**Usage**
|
||||
```bash
|
||||
./truss dbtable2crud -table=projects -file=../../internal/project/models.go -model=Project [-dbtable=TABLE] [-templateDir=DIR] [-projectPath=DIR] [-saveChanges=false]
|
||||
```
|
||||
|
||||
**Example**
|
||||
1. Define a new database table in `internal/schema/migrations.go`
|
||||
|
||||
|
||||
2. Create a new file for the base model at `internal/projects/models.go`. Only the following struct needs to be included. All the other times will be generated.
|
||||
```go
|
||||
// Project represents a workflow.
|
||||
type Project struct {
|
||||
ID string `json:"id" validate:"required,uuid"`
|
||||
AccountID string `json:"account_id" validate:"required,uuid" truss:"api-create"`
|
||||
Name string `json:"name" validate:"required"`
|
||||
Status ProjectStatus `json:"status" validate:"omitempty,oneof=active disabled"`
|
||||
CreatedAt time.Time `json:"created_at" truss:"api-read"`
|
||||
UpdatedAt time.Time `json:"updated_at" truss:"api-read"`
|
||||
ArchivedAt pq.NullTime `json:"archived_at" truss:"api-hide"`
|
||||
}
|
||||
```
|
||||
|
||||
3. Run `dbtable2crud`
|
||||
```bash
|
||||
./truss dbtable2crud -table=projects -file=../../internal/project/models.go -model=Project -save=true
|
||||
```
|
||||
|
||||
|
@ -1,149 +0,0 @@
|
||||
package dbtable2crud
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/lib/pq"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type psqlColumn struct {
|
||||
Table string
|
||||
Column string
|
||||
ColumnId int64
|
||||
NotNull bool
|
||||
DataTypeFull string
|
||||
DataTypeName string
|
||||
DataTypeLength *int
|
||||
NumericPrecision *int
|
||||
NumericScale *int
|
||||
IsPrimaryKey bool
|
||||
PrimaryKeyName *string
|
||||
IsUniqueKey bool
|
||||
UniqueKeyName *string
|
||||
IsForeignKey bool
|
||||
ForeignKeyName *string
|
||||
ForeignKeyColumnId pq.Int64Array
|
||||
ForeignKeyTable *string
|
||||
ForeignKeyLocalColumnId pq.Int64Array
|
||||
DefaultFull *string
|
||||
DefaultValue *string
|
||||
IsEnum bool
|
||||
EnumTypeId *string
|
||||
EnumValues []string
|
||||
}
|
||||
|
||||
// descTable lists all the columns for a table.
|
||||
func descTable(db *sqlx.DB, dbName, dbTable string) ([]psqlColumn, error) {
|
||||
|
||||
queryStr := fmt.Sprintf(`SELECT
|
||||
c.relname as table,
|
||||
f.attname as column,
|
||||
f.attnum as columnId,
|
||||
f.attnotnull as not_null,
|
||||
pg_catalog.format_type(f.atttypid,f.atttypmod) AS data_type_full,
|
||||
t.typname AS data_type_name,
|
||||
CASE WHEN f.atttypmod >= 0 AND t.typname <> 'numeric'THEN (f.atttypmod - 4) --first 4 bytes are for storing actual length of data
|
||||
END AS data_type_length,
|
||||
CASE WHEN t.typname = 'numeric' THEN (((f.atttypmod - 4) >> 16) & 65535)
|
||||
END AS numeric_precision,
|
||||
CASE WHEN t.typname = 'numeric' THEN ((f.atttypmod - 4)& 65535 )
|
||||
END AS numeric_scale,
|
||||
CASE WHEN p.contype = 'p' THEN true ELSE false
|
||||
END AS is_primary_key,
|
||||
CASE WHEN p.contype = 'p' THEN p.conname
|
||||
END AS primary_key_name,
|
||||
CASE WHEN p.contype = 'u' THEN true ELSE false
|
||||
END AS is_unique_key,
|
||||
CASE WHEN p.contype = 'u' THEN p.conname
|
||||
END AS unique_key_name,
|
||||
CASE WHEN p.contype = 'f' THEN true ELSE false
|
||||
END AS is_foreign_key,
|
||||
CASE WHEN p.contype = 'f' THEN p.conname
|
||||
END AS foreignkey_name,
|
||||
CASE WHEN p.contype = 'f' THEN p.confkey
|
||||
END AS foreign_key_columnid,
|
||||
CASE WHEN p.contype = 'f' THEN g.relname
|
||||
END AS foreign_key_table,
|
||||
CASE WHEN p.contype = 'f' THEN p.conkey
|
||||
END AS foreign_key_local_column_id,
|
||||
CASE WHEN f.atthasdef = 't' THEN d.adsrc
|
||||
END AS default_value,
|
||||
CASE WHEN t.typtype = 'e' THEN true ELSE false
|
||||
END AS is_enum,
|
||||
CASE WHEN t.typtype = 'e' THEN t.oid
|
||||
END AS enum_type_id
|
||||
FROM pg_attribute f
|
||||
JOIN pg_class c ON c.oid = f.attrelid
|
||||
JOIN pg_type t ON t.oid = f.atttypid
|
||||
LEFT JOIN pg_attrdef d ON d.adrelid = c.oid AND d.adnum = f.attnum
|
||||
LEFT JOIN pg_namespace n ON n.oid = c.relnamespace
|
||||
LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey)
|
||||
LEFT JOIN pg_class AS g ON p.confrelid = g.oid
|
||||
WHERE c.relkind = 'r'::char
|
||||
AND f.attisdropped = false
|
||||
AND c.relname = '%s'
|
||||
AND f.attnum > 0
|
||||
ORDER BY f.attnum
|
||||
;`, dbTable) // AND n.nspname = '%s'
|
||||
|
||||
rows, err := db.Query(queryStr)
|
||||
if err != nil {
|
||||
err = errors.Wrapf(err, "query - %s", queryStr)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// iterate over each row
|
||||
var resp []psqlColumn
|
||||
for rows.Next() {
|
||||
var c psqlColumn
|
||||
err = rows.Scan(&c.Table, &c.Column, &c.ColumnId, &c.NotNull, &c.DataTypeFull, &c.DataTypeName, &c.DataTypeLength, &c.NumericPrecision, &c.NumericScale, &c.IsPrimaryKey, &c.PrimaryKeyName, &c.IsUniqueKey, &c.UniqueKeyName, &c.IsForeignKey, &c.ForeignKeyName, &c.ForeignKeyColumnId, &c.ForeignKeyTable, &c.ForeignKeyLocalColumnId, &c.DefaultFull, &c.IsEnum, &c.EnumTypeId)
|
||||
if err != nil {
|
||||
err = errors.Wrapf(err, "query - %s", queryStr)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if c.DefaultFull != nil {
|
||||
defaultValue := *c.DefaultFull
|
||||
|
||||
// "'active'::project_status_t"
|
||||
defaultValue = strings.Split(defaultValue, "::")[0]
|
||||
c.DefaultValue = &defaultValue
|
||||
}
|
||||
|
||||
resp = append(resp, c)
|
||||
}
|
||||
|
||||
for colIdx, dbCol := range resp {
|
||||
if !dbCol.IsEnum {
|
||||
continue
|
||||
}
|
||||
|
||||
queryStr := fmt.Sprintf(`SELECT e.enumlabel
|
||||
FROM pg_enum AS e
|
||||
WHERE e.enumtypid = '%s'
|
||||
ORDER BY e.enumsortorder`, *dbCol.EnumTypeId)
|
||||
|
||||
rows, err := db.Query(queryStr)
|
||||
if err != nil {
|
||||
err = errors.Wrapf(err, "query - %s", queryStr)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
var v string
|
||||
err = rows.Scan(&v)
|
||||
if err != nil {
|
||||
err = errors.Wrapf(err, "query - %s", queryStr)
|
||||
return nil, err
|
||||
}
|
||||
dbCol.EnumValues = append(dbCol.EnumValues, v)
|
||||
}
|
||||
|
||||
resp[colIdx] = dbCol
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
@ -1,431 +0,0 @@
|
||||
package dbtable2crud
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/internal/schema"
|
||||
"geeks-accelerator/oss/saas-starter-kit/tools/truss/internal/goparse"
|
||||
"github.com/dustin/go-humanize/english"
|
||||
"github.com/fatih/camelcase"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sergi/go-diff/diffmatchpatch"
|
||||
)
|
||||
|
||||
// Run in the main entry point for the dbtable2crud cmd.
|
||||
func Run(db *sqlx.DB, log *log.Logger, dbName, dbTable, modelFile, modelName, templateDir, goSrcPath string, saveChanges bool) error {
|
||||
log.SetPrefix(log.Prefix() + " : dbtable2crud")
|
||||
|
||||
// Ensure the schema is up to date
|
||||
if err := schema.Migrate(db, log); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// When dbTable is empty, lower case the model name
|
||||
if dbTable == "" {
|
||||
dbTable = strings.Join(camelcase.Split(modelName), " ")
|
||||
dbTable = english.PluralWord(2, dbTable, "")
|
||||
dbTable = strings.Replace(dbTable, " ", "_", -1)
|
||||
dbTable = strings.ToLower(dbTable)
|
||||
}
|
||||
|
||||
// Parse the model file and load the specified model struct.
|
||||
model, err := parseModelFile(db, log, dbName, dbTable, modelFile, modelName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Basic lint of the model struct.
|
||||
err = validateModel(log, model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tmplData := map[string]interface{}{
|
||||
"GoSrcPath": goSrcPath,
|
||||
}
|
||||
|
||||
// Update the model file with new or updated code.
|
||||
err = updateModel(log, model, modelFile, templateDir, tmplData, saveChanges)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update the model crud file with new or updated code.
|
||||
err = updateModelCrud(db, log, dbName, dbTable, modelFile, modelName, templateDir, model, tmplData, saveChanges)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateModel performs a basic lint of the model struct to ensure
|
||||
// code gen output is correct.
|
||||
func validateModel(log *log.Logger, model *modelDef) error {
|
||||
for _, sf := range model.Fields {
|
||||
if sf.DbColumn == nil && sf.ColumnName != "-" {
|
||||
log.Printf("validateStruct : Unable to find struct field for db column %s\n", sf.ColumnName)
|
||||
}
|
||||
|
||||
var expectedType string
|
||||
switch sf.FieldName {
|
||||
case "ID":
|
||||
expectedType = "string"
|
||||
case "CreatedAt":
|
||||
expectedType = "time.Time"
|
||||
case "UpdatedAt":
|
||||
expectedType = "time.Time"
|
||||
case "ArchivedAt":
|
||||
expectedType = "pq.NullTime"
|
||||
}
|
||||
|
||||
if expectedType != "" && expectedType != sf.FieldType {
|
||||
log.Printf("validateStruct : Struct field %s should be of type %s not %s\n", sf.FieldName, expectedType, sf.FieldType)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateModel updated the parsed code file with the new code.
|
||||
func updateModel(log *log.Logger, model *modelDef, modelFile, templateDir string, tmplData map[string]interface{}, saveChanges bool) error {
|
||||
|
||||
// Execute template and parse code to be used to compare against modelFile.
|
||||
tmplObjs, err := loadTemplateObjects(log, model, templateDir, "models.tmpl", tmplData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Store the current code as a string to produce a diff.
|
||||
curCode := model.String()
|
||||
|
||||
objHeaders := []*goparse.GoObject{}
|
||||
|
||||
for _, obj := range tmplObjs {
|
||||
if obj.Type == goparse.GoObjectType_Comment || obj.Type == goparse.GoObjectType_LineBreak {
|
||||
objHeaders = append(objHeaders, obj)
|
||||
continue
|
||||
}
|
||||
|
||||
if model.HasType(obj.Name, obj.Type) {
|
||||
cur := model.Objects().Get(obj.Name, obj.Type)
|
||||
|
||||
newObjs := []*goparse.GoObject{}
|
||||
if len(objHeaders) > 0 {
|
||||
// Remove any comments and linebreaks before the existing object so updates can be added.
|
||||
removeObjs := []*goparse.GoObject{}
|
||||
for idx := cur.Index - 1; idx > 0; idx-- {
|
||||
prevObj := model.Objects().List()[idx]
|
||||
if prevObj.Type == goparse.GoObjectType_Comment || prevObj.Type == goparse.GoObjectType_LineBreak {
|
||||
removeObjs = append(removeObjs, prevObj)
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(removeObjs) > 0 {
|
||||
err := model.Objects().Remove(removeObjs...)
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to update object %s %s for %s", obj.Type, obj.Name, model.Name)
|
||||
return err
|
||||
}
|
||||
|
||||
// Make sure the current index is correct.
|
||||
cur = model.Objects().Get(obj.Name, obj.Type)
|
||||
}
|
||||
|
||||
// Append comments and line breaks before adding the object
|
||||
for _, c := range objHeaders {
|
||||
newObjs = append(newObjs, c)
|
||||
}
|
||||
}
|
||||
|
||||
newObjs = append(newObjs, obj)
|
||||
|
||||
// Do the object replacement.
|
||||
err := model.Objects().Replace(cur, newObjs...)
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to update object %s %s for %s", obj.Type, obj.Name, model.Name)
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// Append comments and line breaks before adding the object
|
||||
for _, c := range objHeaders {
|
||||
err := model.Objects().Add(c)
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to add object %s %s for %s", c.Type, c.Name, model.Name)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
err := model.Objects().Add(obj)
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to add object %s %s for %s", obj.Type, obj.Name, model.Name)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
objHeaders = []*goparse.GoObject{}
|
||||
}
|
||||
|
||||
// Set some flags to determine additional imports and need to be added.
|
||||
var hasEnum bool
|
||||
var hasPq bool
|
||||
for _, f := range model.Fields {
|
||||
if f.DbColumn != nil && f.DbColumn.IsEnum {
|
||||
hasEnum = true
|
||||
}
|
||||
if strings.HasPrefix(strings.Trim(f.FieldType, "*"), "pq.") {
|
||||
hasPq = true
|
||||
}
|
||||
}
|
||||
|
||||
reqImports := []string{}
|
||||
if hasEnum {
|
||||
reqImports = append(reqImports, "database/sql/driver")
|
||||
reqImports = append(reqImports, "gopkg.in/go-playground/validator.v9")
|
||||
reqImports = append(reqImports, "github.com/pkg/errors")
|
||||
}
|
||||
|
||||
if hasPq {
|
||||
reqImports = append(reqImports, "github.com/lib/pq")
|
||||
}
|
||||
|
||||
for _, in := range reqImports {
|
||||
err := model.AddImport(goparse.GoImport{Name: in})
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to add import %s for %s", in, model.Name)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if saveChanges {
|
||||
err = model.Save(modelFile)
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to save changes for %s to %s", model.Name, modelFile)
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// Produce a diff after the updates have been applied.
|
||||
dmp := diffmatchpatch.New()
|
||||
diffs := dmp.DiffMain(curCode, model.String(), true)
|
||||
fmt.Println(dmp.DiffPrettyText(diffs))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateModelCrud updated the parsed code file with the new code.
|
||||
func updateModelCrud(db *sqlx.DB, log *log.Logger, dbName, dbTable, modelFile, modelName, templateDir string, baseModel *modelDef, tmplData map[string]interface{}, saveChanges bool) error {
|
||||
|
||||
// Load all the updated struct fields from the base model file.
|
||||
structFields := make(map[string]map[string]modelField)
|
||||
for _, obj := range baseModel.GoDocument.Objects().List() {
|
||||
if obj.Type != goparse.GoObjectType_Struct || obj.Name == baseModel.Name {
|
||||
continue
|
||||
}
|
||||
|
||||
objFields, err := parseModelFields(baseModel.GoDocument, obj.Name, baseModel)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
structFields[obj.Name] = make(map[string]modelField)
|
||||
for _, f := range objFields {
|
||||
structFields[obj.Name][f.FieldName] = f
|
||||
}
|
||||
}
|
||||
|
||||
// Append the struct fields to be used for template execution.
|
||||
if tmplData == nil {
|
||||
tmplData = make(map[string]interface{})
|
||||
}
|
||||
tmplData["StructFields"] = structFields
|
||||
|
||||
// Get the dir to store crud methods and test files.
|
||||
modelDir := filepath.Dir(modelFile)
|
||||
|
||||
// Process the CRUD hanlders template and write to file.
|
||||
crudFilePath := filepath.Join(modelDir, FormatCamelLowerUnderscore(baseModel.Name)+".go")
|
||||
crudTmplFile := "model_crud.tmpl"
|
||||
err := updateModelCrudFile(db, log, dbName, dbTable, templateDir, crudFilePath, crudTmplFile, baseModel, tmplData, saveChanges)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Process the CRUD test template and write to file.
|
||||
testFilePath := filepath.Join(modelDir, FormatCamelLowerUnderscore(baseModel.Name)+"_test.go")
|
||||
testTmplFile := "model_crud_test.tmpl"
|
||||
err = updateModelCrudFile(db, log, dbName, dbTable, templateDir, testFilePath, testTmplFile, baseModel, tmplData, saveChanges)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateModelCrudFile processes the input file.
|
||||
func updateModelCrudFile(db *sqlx.DB, log *log.Logger, dbName, dbTable, templateDir, crudFilePath, tmplFile string, baseModel *modelDef, tmplData map[string]interface{}, saveChanges bool) error {
|
||||
|
||||
// Execute template and parse code to be used to compare against modelFile.
|
||||
tmplObjs, err := loadTemplateObjects(log, baseModel, templateDir, tmplFile, tmplData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var crudDoc *goparse.GoDocument
|
||||
if _, err := os.Stat(crudFilePath); os.IsNotExist(err) {
|
||||
crudDoc, err = goparse.NewGoDocument(baseModel.Package)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// Parse the supplied model file.
|
||||
crudDoc, err = goparse.ParseFile(log, crudFilePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Store the current code as a string to produce a diff.
|
||||
curCode := crudDoc.String()
|
||||
|
||||
objHeaders := []*goparse.GoObject{}
|
||||
|
||||
for _, obj := range tmplObjs {
|
||||
if obj.Type == goparse.GoObjectType_Comment || obj.Type == goparse.GoObjectType_LineBreak {
|
||||
objHeaders = append(objHeaders, obj)
|
||||
continue
|
||||
}
|
||||
|
||||
if obj.Name == "" && (obj.Type == goparse.GoObjectType_Var || obj.Type == goparse.GoObjectType_Const) {
|
||||
var curDocObj *goparse.GoObject
|
||||
for _, subObj := range obj.Objects().List() {
|
||||
for _, do := range crudDoc.Objects().List() {
|
||||
if do.Name == "" && (do.Type == goparse.GoObjectType_Var || do.Type == goparse.GoObjectType_Const) {
|
||||
for _, subDocObj := range do.Objects().List() {
|
||||
if subDocObj.String() == subObj.String() && subObj.Type != goparse.GoObjectType_LineBreak {
|
||||
curDocObj = do
|
||||
break
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if curDocObj != nil {
|
||||
for _, subObj := range obj.Objects().List() {
|
||||
var hasSubObj bool
|
||||
for _, subDocObj := range curDocObj.Objects().List() {
|
||||
if subDocObj.String() == subObj.String() {
|
||||
hasSubObj = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hasSubObj {
|
||||
curDocObj.Objects().Add(subObj)
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to add object %s %s for %s", obj.Type, obj.Name, baseModel.Name)
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Append comments and line breaks before adding the object
|
||||
for _, c := range objHeaders {
|
||||
err := crudDoc.Objects().Add(c)
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to add object %s %s for %s", c.Type, c.Name, baseModel.Name)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
err := crudDoc.Objects().Add(obj)
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to add object %s %s for %s", obj.Type, obj.Name, baseModel.Name)
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else if crudDoc.HasType(obj.Name, obj.Type) {
|
||||
cur := crudDoc.Objects().Get(obj.Name, obj.Type)
|
||||
|
||||
newObjs := []*goparse.GoObject{}
|
||||
if len(objHeaders) > 0 {
|
||||
// Remove any comments and linebreaks before the existing object so updates can be added.
|
||||
removeObjs := []*goparse.GoObject{}
|
||||
for idx := cur.Index - 1; idx > 0; idx-- {
|
||||
prevObj := crudDoc.Objects().List()[idx]
|
||||
if prevObj.Type == goparse.GoObjectType_Comment || prevObj.Type == goparse.GoObjectType_LineBreak {
|
||||
removeObjs = append(removeObjs, prevObj)
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(removeObjs) > 0 {
|
||||
err := crudDoc.Objects().Remove(removeObjs...)
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to update object %s %s for %s", obj.Type, obj.Name, baseModel.Name)
|
||||
return err
|
||||
}
|
||||
|
||||
// Make sure the current index is correct.
|
||||
cur = crudDoc.Objects().Get(obj.Name, obj.Type)
|
||||
}
|
||||
|
||||
// Append comments and line breaks before adding the object
|
||||
for _, c := range objHeaders {
|
||||
newObjs = append(newObjs, c)
|
||||
}
|
||||
}
|
||||
|
||||
newObjs = append(newObjs, obj)
|
||||
|
||||
// Do the object replacement.
|
||||
err := crudDoc.Objects().Replace(cur, newObjs...)
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to update object %s %s for %s", obj.Type, obj.Name, baseModel.Name)
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// Append comments and line breaks before adding the object
|
||||
for _, c := range objHeaders {
|
||||
err := crudDoc.Objects().Add(c)
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to add object %s %s for %s", c.Type, c.Name, baseModel.Name)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
err := crudDoc.Objects().Add(obj)
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to add object %s %s for %s", obj.Type, obj.Name, baseModel.Name)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
objHeaders = []*goparse.GoObject{}
|
||||
}
|
||||
|
||||
if saveChanges {
|
||||
err = crudDoc.Save(crudFilePath)
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to save changes for %s to %s", baseModel.Name, crudFilePath)
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// Produce a diff after the updates have been applied.
|
||||
dmp := diffmatchpatch.New()
|
||||
diffs := dmp.DiffMain(curCode, crudDoc.String(), true)
|
||||
fmt.Println(dmp.DiffPrettyText(diffs))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
@ -1,229 +0,0 @@
|
||||
package dbtable2crud
|
||||
|
||||
import (
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/tools/truss/internal/goparse"
|
||||
"github.com/fatih/structtag"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// modelDef defines info about the struct and associated db table.
|
||||
type modelDef struct {
|
||||
*goparse.GoDocument
|
||||
Name string
|
||||
TableName string
|
||||
PrimaryField string
|
||||
PrimaryColumn string
|
||||
PrimaryType string
|
||||
Fields []modelField
|
||||
FieldNames []string
|
||||
ColumnNames []string
|
||||
}
|
||||
|
||||
// modelField defines a struct field and associated db column.
|
||||
type modelField struct {
|
||||
ColumnName string
|
||||
DbColumn *psqlColumn
|
||||
FieldName string
|
||||
FieldType string
|
||||
FieldIsPtr bool
|
||||
Tags *structtag.Tags
|
||||
ApiHide bool
|
||||
ApiRead bool
|
||||
ApiCreate bool
|
||||
ApiUpdate bool
|
||||
DefaultValue string
|
||||
}
|
||||
|
||||
// parseModelFile parses the entire model file and then loads the specified model struct.
|
||||
func parseModelFile(db *sqlx.DB, log *log.Logger, dbName, dbTable, modelFile, modelName string) (*modelDef, error) {
|
||||
|
||||
// Parse the supplied model file.
|
||||
doc, err := goparse.ParseFile(log, modelFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Init new modelDef.
|
||||
model := &modelDef{
|
||||
GoDocument: doc,
|
||||
Name: modelName,
|
||||
TableName: dbTable,
|
||||
}
|
||||
|
||||
// Append the field the the model def.
|
||||
model.Fields, err = parseModelFields(doc, modelName, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, sf := range model.Fields {
|
||||
model.FieldNames = append(model.FieldNames, sf.FieldName)
|
||||
model.ColumnNames = append(model.ColumnNames, sf.ColumnName)
|
||||
}
|
||||
|
||||
// Query the database for a table definition.
|
||||
dbCols, err := descTable(db, dbName, dbTable)
|
||||
if err != nil {
|
||||
return model, err
|
||||
}
|
||||
|
||||
// Loop over all the database table columns and append to the associated
|
||||
// struct field. Don't force all database table columns to be defined in the
|
||||
// in the struct.
|
||||
for _, dbCol := range dbCols {
|
||||
for idx, sf := range model.Fields {
|
||||
if sf.ColumnName != dbCol.Column {
|
||||
continue
|
||||
}
|
||||
|
||||
if dbCol.IsPrimaryKey {
|
||||
model.PrimaryColumn = sf.ColumnName
|
||||
model.PrimaryField = sf.FieldName
|
||||
model.PrimaryType = sf.FieldType
|
||||
}
|
||||
|
||||
if dbCol.DefaultValue != nil {
|
||||
sf.DefaultValue = *dbCol.DefaultValue
|
||||
|
||||
if dbCol.IsEnum {
|
||||
sf.DefaultValue = strings.Trim(sf.DefaultValue, "'")
|
||||
sf.DefaultValue = sf.FieldType + "_" + FormatCamel(sf.DefaultValue)
|
||||
} else if strings.HasPrefix(sf.DefaultValue, "'") {
|
||||
sf.DefaultValue = strings.Trim(sf.DefaultValue, "'")
|
||||
sf.DefaultValue = "\"" + sf.DefaultValue + "\""
|
||||
}
|
||||
}
|
||||
|
||||
c := dbCol
|
||||
sf.DbColumn = &c
|
||||
model.Fields[idx] = sf
|
||||
}
|
||||
}
|
||||
|
||||
// Print out the model for debugging.
|
||||
//modelJSON, err := json.MarshalIndent(model, "", " ")
|
||||
//if err != nil {
|
||||
// return model, errors.WithStack(err )
|
||||
//}
|
||||
//log.Printf(string(modelJSON))
|
||||
|
||||
return model, nil
|
||||
}
|
||||
|
||||
// parseModelFields parses the fields from a struct.
|
||||
func parseModelFields(doc *goparse.GoDocument, modelName string, baseModel *modelDef) ([]modelField, error) {
|
||||
|
||||
// Ensure the model file has a struct with the model name supplied.
|
||||
if !doc.HasType(modelName, goparse.GoObjectType_Struct) {
|
||||
err := errors.Errorf("Struct with the name %s does not exist", modelName)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Load the struct from parsed go file.
|
||||
docModel := doc.Get(modelName, goparse.GoObjectType_Struct)
|
||||
|
||||
// Loop over all the objects contained between the struct definition start and end.
|
||||
// This should be a list of variables defined for model.
|
||||
resp := []modelField{}
|
||||
for _, l := range docModel.Objects().List() {
|
||||
|
||||
// Skip all lines that are not a var.
|
||||
if l.Type != goparse.GoObjectType_Line {
|
||||
log.Printf("parseModelFile : Model %s has line that is %s, not type line, skipping - %s\n", modelName, l.Type, l.String())
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract the var name, type and defined tags from the line.
|
||||
sv, err := goparse.ParseStructProp(l)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Init new modelField for the struct var.
|
||||
sf := modelField{
|
||||
FieldName: sv.Name,
|
||||
FieldType: sv.Type,
|
||||
FieldIsPtr: strings.HasPrefix(sv.Type, "*"),
|
||||
Tags: sv.Tags,
|
||||
}
|
||||
|
||||
// Extract the column name from the var tags.
|
||||
if sf.Tags != nil {
|
||||
// First try to get the column name from the db tag.
|
||||
dbt, err := sf.Tags.Get("db")
|
||||
if err != nil && !strings.Contains(err.Error(), "not exist") {
|
||||
err = errors.WithStack(err)
|
||||
return nil, err
|
||||
} else if dbt != nil {
|
||||
sf.ColumnName = dbt.Name
|
||||
}
|
||||
|
||||
// Second try to get the column name from the json tag.
|
||||
if sf.ColumnName == "" {
|
||||
jt, err := sf.Tags.Get("json")
|
||||
if err != nil && !strings.Contains(err.Error(), "not exist") {
|
||||
err = errors.WithStack(err)
|
||||
return nil, err
|
||||
} else if jt != nil && jt.Name != "-" {
|
||||
sf.ColumnName = jt.Name
|
||||
}
|
||||
}
|
||||
|
||||
var apiActionsSet bool
|
||||
tt, err := sf.Tags.Get("truss")
|
||||
if err != nil && !strings.Contains(err.Error(), "not exist") {
|
||||
err = errors.WithStack(err)
|
||||
return nil, err
|
||||
} else if tt != nil {
|
||||
if tt.Name == "api-create" || tt.HasOption("api-create") {
|
||||
sf.ApiCreate = true
|
||||
apiActionsSet = true
|
||||
}
|
||||
if tt.Name == "api-read" || tt.HasOption("api-read") {
|
||||
sf.ApiRead = true
|
||||
apiActionsSet = true
|
||||
}
|
||||
if tt.Name == "api-update" || tt.HasOption("api-update") {
|
||||
sf.ApiUpdate = true
|
||||
apiActionsSet = true
|
||||
}
|
||||
if tt.Name == "api-hide" || tt.HasOption("api-hide") {
|
||||
sf.ApiHide = true
|
||||
apiActionsSet = true
|
||||
}
|
||||
}
|
||||
|
||||
if !apiActionsSet {
|
||||
sf.ApiCreate = true
|
||||
sf.ApiRead = true
|
||||
sf.ApiUpdate = true
|
||||
}
|
||||
}
|
||||
|
||||
// Set the column name to the field name if empty and does not equal '-'.
|
||||
if sf.ColumnName == "" {
|
||||
sf.ColumnName = sf.FieldName
|
||||
}
|
||||
|
||||
// If a base model as already been parsed with the db columns,
|
||||
// append to the current field.
|
||||
if baseModel != nil {
|
||||
for _, baseSf := range baseModel.Fields {
|
||||
if baseSf.ColumnName == sf.ColumnName {
|
||||
sf.DefaultValue = baseSf.DefaultValue
|
||||
sf.DbColumn = baseSf.DbColumn
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Append the field the the model def.
|
||||
resp = append(resp, sf)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
@ -1,345 +0,0 @@
|
||||
package dbtable2crud
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"go/format"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/tools/truss/internal/goparse"
|
||||
"github.com/dustin/go-humanize/english"
|
||||
"github.com/fatih/camelcase"
|
||||
"github.com/iancoleman/strcase"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// loadTemplateObjects executes a template file based on the given model struct and
|
||||
// returns the parsed go objects.
|
||||
func loadTemplateObjects(log *log.Logger, model *modelDef, templateDir, filename string, tmptData map[string]interface{}) ([]*goparse.GoObject, error) {
|
||||
|
||||
// Data used to execute all the of defined code sections in the template file.
|
||||
if tmptData == nil {
|
||||
tmptData = make(map[string]interface{})
|
||||
}
|
||||
tmptData["Model"] = model
|
||||
|
||||
// geeks-accelerator/oss/saas-starter-kit/example-project
|
||||
|
||||
// Read the template file from the local file system.
|
||||
tempFilePath := filepath.Join(templateDir, filename)
|
||||
dat, err := ioutil.ReadFile(tempFilePath)
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to read template file %s", tempFilePath)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// New template with custom functions.
|
||||
baseTmpl := template.New("base")
|
||||
baseTmpl.Funcs(template.FuncMap{
|
||||
"Concat": func(vals ...string) string {
|
||||
return strings.Join(vals, "")
|
||||
},
|
||||
"JoinStrings": func(vals []string, sep string) string {
|
||||
return strings.Join(vals, sep)
|
||||
},
|
||||
"PrefixAndJoinStrings": func(vals []string, pre, sep string) string {
|
||||
l := []string{}
|
||||
for _, v := range vals {
|
||||
l = append(l, pre+v)
|
||||
}
|
||||
return strings.Join(l, sep)
|
||||
},
|
||||
"FmtAndJoinStrings": func(vals []string, fmtStr, sep string) string {
|
||||
l := []string{}
|
||||
for _, v := range vals {
|
||||
l = append(l, fmt.Sprintf(fmtStr, v))
|
||||
}
|
||||
return strings.Join(l, sep)
|
||||
},
|
||||
"FormatCamel": func(name string) string {
|
||||
return FormatCamel(name)
|
||||
},
|
||||
"FormatCamelTitle": func(name string) string {
|
||||
return FormatCamelTitle(name)
|
||||
},
|
||||
"FormatCamelLower": func(name string) string {
|
||||
if name == "ID" {
|
||||
return "id"
|
||||
}
|
||||
return FormatCamelLower(name)
|
||||
},
|
||||
"FormatCamelLowerTitle": func(name string) string {
|
||||
return FormatCamelLowerTitle(name)
|
||||
},
|
||||
"FormatCamelPluralTitle": func(name string) string {
|
||||
return FormatCamelPluralTitle(name)
|
||||
},
|
||||
"FormatCamelPluralTitleLower": func(name string) string {
|
||||
return FormatCamelPluralTitleLower(name)
|
||||
},
|
||||
"FormatCamelPluralCamel": func(name string) string {
|
||||
return FormatCamelPluralCamel(name)
|
||||
},
|
||||
"FormatCamelPluralLower": func(name string) string {
|
||||
return FormatCamelPluralLower(name)
|
||||
},
|
||||
"FormatCamelPluralUnderscore": func(name string) string {
|
||||
return FormatCamelPluralUnderscore(name)
|
||||
},
|
||||
"FormatCamelPluralLowerUnderscore": func(name string) string {
|
||||
return FormatCamelPluralLowerUnderscore(name)
|
||||
},
|
||||
"FormatCamelUnderscore": func(name string) string {
|
||||
return FormatCamelUnderscore(name)
|
||||
},
|
||||
"FormatCamelLowerUnderscore": func(name string) string {
|
||||
return FormatCamelLowerUnderscore(name)
|
||||
},
|
||||
"FieldTagHasOption": func(f modelField, tagName, optName string) bool {
|
||||
if f.Tags == nil {
|
||||
return false
|
||||
}
|
||||
ft, err := f.Tags.Get(tagName)
|
||||
if ft == nil || err != nil {
|
||||
return false
|
||||
}
|
||||
if ft.Name == optName || ft.HasOption(optName) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
},
|
||||
"FieldTag": func(f modelField, tagName string) string {
|
||||
if f.Tags == nil {
|
||||
return ""
|
||||
}
|
||||
ft, err := f.Tags.Get(tagName)
|
||||
if ft == nil || err != nil {
|
||||
return ""
|
||||
}
|
||||
return ft.String()
|
||||
},
|
||||
"FieldTagReplaceOrPrepend": func(f modelField, tagName, oldVal, newVal string) string {
|
||||
if f.Tags == nil {
|
||||
return ""
|
||||
}
|
||||
ft, err := f.Tags.Get(tagName)
|
||||
if ft == nil || err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
if ft.Name == oldVal || ft.Name == newVal {
|
||||
ft.Name = newVal
|
||||
} else if ft.HasOption(oldVal) {
|
||||
for idx, val := range ft.Options {
|
||||
if val == oldVal {
|
||||
ft.Options[idx] = newVal
|
||||
}
|
||||
}
|
||||
} else if !ft.HasOption(newVal) {
|
||||
if ft.Name == "" {
|
||||
ft.Name = newVal
|
||||
} else {
|
||||
ft.Options = append(ft.Options, newVal)
|
||||
}
|
||||
}
|
||||
|
||||
return ft.String()
|
||||
},
|
||||
"StringListHasValue": func(list []string, val string) bool {
|
||||
for _, v := range list {
|
||||
if v == val {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
},
|
||||
})
|
||||
|
||||
// Load the template file using the text/template package.
|
||||
tmpl, err := baseTmpl.Parse(string(dat))
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to parse template file %s", tempFilePath)
|
||||
log.Printf("loadTemplateObjects : %v\n%v", err, string(dat))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Generate a list of template names defined in the template file.
|
||||
tmplNames := []string{}
|
||||
for _, defTmpl := range tmpl.Templates() {
|
||||
tmplNames = append(tmplNames, defTmpl.Name())
|
||||
}
|
||||
|
||||
// Stupid hack to return template names the in order they are defined in the file.
|
||||
tmplNames, err = templateFileOrderedNames(tempFilePath, tmplNames)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Loop over all the defined templates, execute using the defined data, parse the
|
||||
// formatted code and append the parsed go objects to the result list.
|
||||
var resp []*goparse.GoObject
|
||||
for _, tmplName := range tmplNames {
|
||||
// Executed the defined template with the given data.
|
||||
var tpl bytes.Buffer
|
||||
if err := tmpl.Lookup(tmplName).Execute(&tpl, tmptData); err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to execute %s from template file %s", tmplName, tempFilePath)
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// Format the source code to ensure its valid and code to parsed consistently.
|
||||
codeBytes, err := format.Source(tpl.Bytes())
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to format source for %s in template file %s", tmplName, filename)
|
||||
|
||||
dl := []string{}
|
||||
for idx, l := range strings.Split(tpl.String(), "\n") {
|
||||
dl = append(dl, fmt.Sprintf("%d -> ", idx)+l)
|
||||
}
|
||||
|
||||
log.Printf("loadTemplateObjects : %v\n%v", err, strings.Join(dl, "\n"))
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// Remove extra white space from the code.
|
||||
codeStr := strings.TrimSpace(string(codeBytes))
|
||||
|
||||
// Split the code into a list of strings.
|
||||
codeLines := strings.Split(codeStr, "\n")
|
||||
|
||||
// Parse the code lines into a set of objects.
|
||||
objs, err := goparse.ParseLines(codeLines, 0)
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to parse %s in template file %s", tmplName, filename)
|
||||
log.Printf("loadTemplateObjects : %v\n%v", err, codeStr)
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// Append the parsed objects to the return result list.
|
||||
for _, obj := range objs.List() {
|
||||
if obj.Name == "" && obj.Type != goparse.GoObjectType_Import && obj.Type != goparse.GoObjectType_Var && obj.Type != goparse.GoObjectType_Const && obj.Type != goparse.GoObjectType_Comment && obj.Type != goparse.GoObjectType_LineBreak {
|
||||
// All objects should have a name except for multiline var/const declarations and comments.
|
||||
err = errors.Errorf("Failed to parse name with type %s from lines: %v", obj.Type, obj.Lines())
|
||||
return resp, err
|
||||
} else if string(obj.Type) == "" {
|
||||
err = errors.Errorf("Failed to parse type for %s from lines: %v", obj.Name, obj.Lines())
|
||||
return resp, err
|
||||
}
|
||||
|
||||
resp = append(resp, obj)
|
||||
}
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// FormatCamel formats Valdez mountain to ValdezMountain
|
||||
func FormatCamel(name string) string {
|
||||
return strcase.ToCamel(name)
|
||||
}
|
||||
|
||||
// FormatCamelLower formats ValdezMountain to valdezmountain
|
||||
func FormatCamelLower(name string) string {
|
||||
return strcase.ToLowerCamel(FormatCamel(name))
|
||||
}
|
||||
|
||||
// FormatCamelTitle formats ValdezMountain to Valdez Mountain
|
||||
func FormatCamelTitle(name string) string {
|
||||
return strings.Join(camelcase.Split(name), " ")
|
||||
}
|
||||
|
||||
// FormatCamelLowerTitle formats ValdezMountain to valdez mountain
|
||||
func FormatCamelLowerTitle(name string) string {
|
||||
return strings.ToLower(FormatCamelTitle(name))
|
||||
}
|
||||
|
||||
// FormatCamelPluralTitle formats ValdezMountain to Valdez Mountains
|
||||
func FormatCamelPluralTitle(name string) string {
|
||||
pts := camelcase.Split(name)
|
||||
lastIdx := len(pts) - 1
|
||||
pts[lastIdx] = english.PluralWord(2, pts[lastIdx], "")
|
||||
return strings.Join(pts, " ")
|
||||
}
|
||||
|
||||
// FormatCamelPluralTitleLower formats ValdezMountain to valdez mountains
|
||||
func FormatCamelPluralTitleLower(name string) string {
|
||||
return strings.ToLower(FormatCamelPluralTitle(name))
|
||||
}
|
||||
|
||||
// FormatCamelPluralCamel formats ValdezMountain to ValdezMountains
|
||||
func FormatCamelPluralCamel(name string) string {
|
||||
return strcase.ToCamel(FormatCamelPluralTitle(name))
|
||||
}
|
||||
|
||||
// FormatCamelPluralLower formats ValdezMountain to valdezmountains
|
||||
func FormatCamelPluralLower(name string) string {
|
||||
return strcase.ToLowerCamel(FormatCamelPluralTitle(name))
|
||||
}
|
||||
|
||||
// FormatCamelPluralUnderscore formats ValdezMountain to Valdez_Mountains
|
||||
func FormatCamelPluralUnderscore(name string) string {
|
||||
return strings.Replace(FormatCamelPluralTitle(name), " ", "_", -1)
|
||||
}
|
||||
|
||||
// FormatCamelPluralLowerUnderscore formats ValdezMountain to valdez_mountains
|
||||
func FormatCamelPluralLowerUnderscore(name string) string {
|
||||
return strings.ToLower(FormatCamelPluralUnderscore(name))
|
||||
}
|
||||
|
||||
// FormatCamelUnderscore formats ValdezMountain to Valdez_Mountain
|
||||
func FormatCamelUnderscore(name string) string {
|
||||
return strings.Replace(FormatCamelTitle(name), " ", "_", -1)
|
||||
}
|
||||
|
||||
// FormatCamelLowerUnderscore formats ValdezMountain to valdez_mountain
|
||||
func FormatCamelLowerUnderscore(name string) string {
|
||||
return strings.ToLower(FormatCamelUnderscore(name))
|
||||
}
|
||||
|
||||
// templateFileOrderedNames returns the template names the in order they are defined in the file.
|
||||
func templateFileOrderedNames(localPath string, names []string) (resp []string, err error) {
|
||||
file, err := os.Open(localPath)
|
||||
if err != nil {
|
||||
return resp, errors.WithStack(err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
idxList := []int{}
|
||||
idxNames := make(map[int]string)
|
||||
|
||||
idx := 0
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
if !strings.HasPrefix(scanner.Text(), "{{") || !strings.Contains(scanner.Text(), "define ") {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, name := range names {
|
||||
if strings.Contains(scanner.Text(), "\""+name+"\"") {
|
||||
idxList = append(idxList, idx)
|
||||
idxNames[idx] = name
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
idx = idx + 1
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return resp, errors.WithStack(err)
|
||||
}
|
||||
|
||||
sort.Ints(idxList)
|
||||
|
||||
for _, idx := range idxList {
|
||||
resp = append(resp, idxNames[idx])
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
@ -1,301 +0,0 @@
|
||||
package goparse
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"go/format"
|
||||
"io/ioutil"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// GoDocument defines a single go code file.
|
||||
type GoDocument struct {
|
||||
*GoObjects
|
||||
Package string
|
||||
imports GoImports
|
||||
}
|
||||
|
||||
// GoImport defines a single import line with optional alias.
|
||||
type GoImport struct {
|
||||
Name string
|
||||
Alias string
|
||||
}
|
||||
|
||||
// GoImports holds a list of import lines.
|
||||
type GoImports []GoImport
|
||||
|
||||
// NewGoDocument creates a new GoDocument with the package line set.
|
||||
func NewGoDocument(packageName string) (doc *GoDocument, err error) {
|
||||
doc = &GoDocument{
|
||||
GoObjects: &GoObjects{
|
||||
list: []*GoObject{},
|
||||
},
|
||||
}
|
||||
err = doc.SetPackage(packageName)
|
||||
return doc, err
|
||||
}
|
||||
|
||||
// Objects returns a list of root GoObject.
|
||||
func (doc *GoDocument) Objects() *GoObjects {
|
||||
if doc.GoObjects == nil {
|
||||
doc.GoObjects = &GoObjects{
|
||||
list: []*GoObject{},
|
||||
}
|
||||
}
|
||||
|
||||
return doc.GoObjects
|
||||
}
|
||||
|
||||
// NewObjectPackage returns a new GoObject with a single package definition line.
|
||||
func NewObjectPackage(packageName string) *GoObject {
|
||||
lines := []string{
|
||||
fmt.Sprintf("package %s", packageName),
|
||||
"",
|
||||
}
|
||||
|
||||
obj, _ := ParseGoObject(lines, 0)
|
||||
|
||||
return obj
|
||||
}
|
||||
|
||||
// SetPackage appends sets the package line for the code file.
|
||||
func (doc *GoDocument) SetPackage(packageName string) error {
|
||||
|
||||
var existing *GoObject
|
||||
for _, obj := range doc.Objects().List() {
|
||||
if obj.Type == GoObjectType_Package {
|
||||
existing = obj
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
new := NewObjectPackage(packageName)
|
||||
|
||||
var err error
|
||||
if existing != nil {
|
||||
err = doc.Objects().Replace(existing, new)
|
||||
} else if len(doc.Objects().List()) > 0 {
|
||||
|
||||
// Insert after any existing comments or line breaks.
|
||||
var insertPos int
|
||||
//for idx, obj := range doc.Objects().List() {
|
||||
// switch obj.Type {
|
||||
// case GoObjectType_Comment, GoObjectType_LineBreak:
|
||||
// insertPos = idx
|
||||
// default:
|
||||
// break
|
||||
// }
|
||||
//}
|
||||
|
||||
err = doc.Objects().Insert(insertPos, new)
|
||||
} else {
|
||||
err = doc.Objects().Add(new)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// AddObject appends a new GoObject to the doc root object list.
|
||||
func (doc *GoDocument) AddObject(newObj *GoObject) error {
|
||||
return doc.Objects().Add(newObj)
|
||||
}
|
||||
|
||||
// InsertObject inserts a new GoObject at the desired position to the doc root object list.
|
||||
func (doc *GoDocument) InsertObject(pos int, newObj *GoObject) error {
|
||||
return doc.Objects().Insert(pos, newObj)
|
||||
}
|
||||
|
||||
// Imports returns the GoDocument imports.
|
||||
func (doc *GoDocument) Imports() (GoImports, error) {
|
||||
// If the doc imports are empty, try to load them from the root objects.
|
||||
if len(doc.imports) == 0 {
|
||||
for _, obj := range doc.Objects().List() {
|
||||
if obj.Type != GoObjectType_Import {
|
||||
continue
|
||||
}
|
||||
|
||||
res, err := ParseImportObject(obj)
|
||||
if err != nil {
|
||||
return doc.imports, err
|
||||
}
|
||||
|
||||
// Combine all the imports into a single definition.
|
||||
for _, n := range res {
|
||||
doc.imports = append(doc.imports, n)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return doc.imports, nil
|
||||
}
|
||||
|
||||
// Lines returns all the code lines.
|
||||
func (doc *GoDocument) Lines() []string {
|
||||
l := []string{}
|
||||
|
||||
for _, ol := range doc.Objects().Lines() {
|
||||
l = append(l, ol)
|
||||
}
|
||||
return l
|
||||
}
|
||||
|
||||
// String returns a single value for all the code lines.
|
||||
func (doc *GoDocument) String() string {
|
||||
return strings.Join(doc.Lines(), "\n")
|
||||
}
|
||||
|
||||
// Print writes all the code lines to standard out.
|
||||
func (doc *GoDocument) Print() {
|
||||
for _, l := range doc.Lines() {
|
||||
fmt.Println(l)
|
||||
}
|
||||
}
|
||||
|
||||
// Save renders all the code lines for the document, formats the code
|
||||
// and then saves it to the supplied file path.
|
||||
func (doc *GoDocument) Save(localpath string) error {
|
||||
res, err := format.Source([]byte(doc.String()))
|
||||
if err != nil {
|
||||
err = errors.WithMessage(err, "Failed formatted source code")
|
||||
return err
|
||||
}
|
||||
|
||||
err = ioutil.WriteFile(localpath, res, 0644)
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed write source code to file %s", localpath)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddImport checks for any duplicate imports by name and adds it if not.
|
||||
func (doc *GoDocument) AddImport(impt GoImport) error {
|
||||
impt.Name = strings.Trim(impt.Name, "\"")
|
||||
|
||||
// Get a list of current imports for the document.
|
||||
impts, err := doc.Imports()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If the document has as the import, don't add it.
|
||||
if impts.Has(impt.Name) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Loop through all the document root objects for an object of type import.
|
||||
// If one exists, append the import to the existing list.
|
||||
for _, obj := range doc.Objects().List() {
|
||||
if obj.Type != GoObjectType_Import || len(obj.Lines()) == 1 {
|
||||
continue
|
||||
}
|
||||
obj.subLines = append(obj.subLines, impt.String())
|
||||
obj.goObjects.list = append(obj.goObjects.list, impt.Object())
|
||||
|
||||
doc.imports = append(doc.imports, impt)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Document does not have an existing import object, so need to create one and
|
||||
// then append to the document.
|
||||
newObj := NewObjectImports(impt)
|
||||
|
||||
// Insert after any package, any existing comments or line breaks should be included.
|
||||
var insertPos int
|
||||
for idx, obj := range doc.Objects().List() {
|
||||
switch obj.Type {
|
||||
case GoObjectType_Package, GoObjectType_Comment, GoObjectType_LineBreak:
|
||||
insertPos = idx
|
||||
default:
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Insert the new import object.
|
||||
err = doc.InsertObject(insertPos, newObj)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewObjectImports returns a new GoObject with a single import definition.
|
||||
func NewObjectImports(impt GoImport) *GoObject {
|
||||
lines := []string{
|
||||
"import (",
|
||||
impt.String(),
|
||||
")",
|
||||
"",
|
||||
}
|
||||
|
||||
obj, _ := ParseGoObject(lines, 0)
|
||||
children, err := ParseLines(obj.subLines, 1)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
for _, child := range children.List() {
|
||||
obj.Objects().Add(child)
|
||||
}
|
||||
|
||||
return obj
|
||||
}
|
||||
|
||||
// Has checks to see if an import exists by name or alias.
|
||||
func (impts GoImports) Has(name string) bool {
|
||||
for _, impt := range impts {
|
||||
if name == impt.Name || (impt.Alias != "" && name == impt.Alias) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Line formats an import as a string.
|
||||
func (impt GoImport) String() string {
|
||||
var imptLine string
|
||||
if impt.Alias != "" {
|
||||
imptLine = fmt.Sprintf("\t%s \"%s\"", impt.Alias, impt.Name)
|
||||
} else {
|
||||
imptLine = fmt.Sprintf("\t\"%s\"", impt.Name)
|
||||
}
|
||||
return imptLine
|
||||
}
|
||||
|
||||
// Object returns the first GoObject for an import.
|
||||
func (impt GoImport) Object() *GoObject {
|
||||
imptObj := NewObjectImports(impt)
|
||||
|
||||
return imptObj.Objects().List()[0]
|
||||
}
|
||||
|
||||
// ParseImportObject extracts all the import definitions.
|
||||
func ParseImportObject(obj *GoObject) (resp GoImports, err error) {
|
||||
if obj.Type != GoObjectType_Import {
|
||||
return resp, errors.Errorf("Invalid type %s", string(obj.Type))
|
||||
}
|
||||
|
||||
for _, l := range obj.Lines() {
|
||||
if !strings.Contains(l, "\"") {
|
||||
continue
|
||||
}
|
||||
l = strings.TrimSpace(l)
|
||||
|
||||
pts := strings.Split(l, "\"")
|
||||
|
||||
var impt GoImport
|
||||
if strings.HasPrefix(l, "\"") {
|
||||
impt.Name = pts[1]
|
||||
} else {
|
||||
impt.Alias = strings.TrimSpace(pts[0])
|
||||
impt.Name = pts[1]
|
||||
}
|
||||
|
||||
resp = append(resp, impt)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
@ -1,458 +0,0 @@
|
||||
package goparse
|
||||
|
||||
import (
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"github.com/fatih/structtag"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// GoEmptyLine defined a GoObject for a code line break.
|
||||
var GoEmptyLine = GoObject{
|
||||
Type: GoObjectType_LineBreak,
|
||||
goObjects: &GoObjects{
|
||||
list: []*GoObject{},
|
||||
},
|
||||
}
|
||||
|
||||
// GoObjectType defines a set of possible types to group
|
||||
// parsed code by.
|
||||
type GoObjectType = string
|
||||
|
||||
var (
|
||||
GoObjectType_Package = "package"
|
||||
GoObjectType_Import = "import"
|
||||
GoObjectType_Var = "var"
|
||||
GoObjectType_Const = "const"
|
||||
GoObjectType_Func = "func"
|
||||
GoObjectType_Struct = "struct"
|
||||
GoObjectType_Comment = "comment"
|
||||
GoObjectType_LineBreak = "linebreak"
|
||||
GoObjectType_Line = "line"
|
||||
GoObjectType_Type = "type"
|
||||
)
|
||||
|
||||
// GoObject defines a section of code with a nested set of children.
|
||||
type GoObject struct {
|
||||
Type GoObjectType
|
||||
Name string
|
||||
startLines []string
|
||||
endLines []string
|
||||
subLines []string
|
||||
goObjects *GoObjects
|
||||
Index int
|
||||
}
|
||||
|
||||
// GoObjects stores a list of GoObject.
|
||||
type GoObjects struct {
|
||||
list []*GoObject
|
||||
}
|
||||
|
||||
// Objects returns the list of *GoObject.
|
||||
func (obj *GoObject) Objects() *GoObjects {
|
||||
if obj.goObjects == nil {
|
||||
obj.goObjects = &GoObjects{
|
||||
list: []*GoObject{},
|
||||
}
|
||||
}
|
||||
return obj.goObjects
|
||||
}
|
||||
|
||||
// Clone performs a deep copy of the struct.
|
||||
func (obj *GoObject) Clone() *GoObject {
|
||||
n := &GoObject{
|
||||
Type: obj.Type,
|
||||
Name: obj.Name,
|
||||
startLines: obj.startLines,
|
||||
endLines: obj.endLines,
|
||||
subLines: obj.subLines,
|
||||
goObjects: &GoObjects{
|
||||
list: []*GoObject{},
|
||||
},
|
||||
Index: obj.Index,
|
||||
}
|
||||
for _, sub := range obj.Objects().List() {
|
||||
n.Objects().Add(sub.Clone())
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// IsComment returns whether an object is of type GoObjectType_Comment.
|
||||
func (obj *GoObject) IsComment() bool {
|
||||
if obj.Type != GoObjectType_Comment {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Contains searches all the lines for the object for a matching string.
|
||||
func (obj *GoObject) Contains(match string) bool {
|
||||
for _, l := range obj.Lines() {
|
||||
if strings.Contains(l, match) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// UpdateLines parses the new code and replaces the current GoObject.
|
||||
func (obj *GoObject) UpdateLines(newLines []string) error {
|
||||
|
||||
// Parse the new lines.
|
||||
objs, err := ParseLines(newLines, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var newObj *GoObject
|
||||
for _, obj := range objs.List() {
|
||||
if obj.Type == GoObjectType_LineBreak {
|
||||
continue
|
||||
}
|
||||
|
||||
if newObj == nil {
|
||||
newObj = obj
|
||||
}
|
||||
|
||||
// There should only be one resulting parsed object that is
|
||||
// not of type GoObjectType_LineBreak.
|
||||
return errors.New("Can only update single blocks of code")
|
||||
}
|
||||
|
||||
// No new code was parsed, return error.
|
||||
if newObj == nil {
|
||||
return errors.New("Failed to render replacement code")
|
||||
}
|
||||
|
||||
return obj.Update(newObj)
|
||||
}
|
||||
|
||||
// Update performs a deep copy that overwrites the existing values.
|
||||
func (obj *GoObject) Update(newObj *GoObject) error {
|
||||
obj.Type = newObj.Type
|
||||
obj.Name = newObj.Name
|
||||
obj.startLines = newObj.startLines
|
||||
obj.endLines = newObj.endLines
|
||||
obj.subLines = newObj.subLines
|
||||
obj.goObjects = newObj.goObjects
|
||||
return nil
|
||||
}
|
||||
|
||||
// Lines returns a list of strings for current object and all children.
|
||||
func (obj *GoObject) Lines() []string {
|
||||
l := []string{}
|
||||
|
||||
// First include any lines before the sub objects.
|
||||
for _, sl := range obj.startLines {
|
||||
l = append(l, sl)
|
||||
}
|
||||
|
||||
// If there are parsed sub objects include those lines else when
|
||||
// no sub objects, just use the sub lines.
|
||||
if len(obj.Objects().List()) > 0 {
|
||||
for _, sl := range obj.Objects().Lines() {
|
||||
l = append(l, sl)
|
||||
}
|
||||
} else {
|
||||
for _, sl := range obj.subLines {
|
||||
l = append(l, sl)
|
||||
}
|
||||
}
|
||||
|
||||
// Lastly include any other lines that are after all parsed sub objects.
|
||||
for _, sl := range obj.endLines {
|
||||
l = append(l, sl)
|
||||
}
|
||||
|
||||
return l
|
||||
}
|
||||
|
||||
// String returns the lines separated by line break.
|
||||
func (obj *GoObject) String() string {
|
||||
return strings.Join(obj.Lines(), "\n")
|
||||
}
|
||||
|
||||
// Lines returns a list of strings for all the list objects.
|
||||
func (objs *GoObjects) Lines() []string {
|
||||
l := []string{}
|
||||
for _, obj := range objs.List() {
|
||||
for _, oj := range obj.Lines() {
|
||||
l = append(l, oj)
|
||||
}
|
||||
}
|
||||
return l
|
||||
}
|
||||
|
||||
// String returns all the lines for the list objects.
|
||||
func (objs *GoObjects) String() string {
|
||||
lines := []string{}
|
||||
for _, obj := range objs.List() {
|
||||
lines = append(lines, obj.String())
|
||||
}
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
// List returns the list of GoObjects.
|
||||
func (objs *GoObjects) List() []*GoObject {
|
||||
return objs.list
|
||||
}
|
||||
|
||||
// HasFunc searches the current list of objects for a function object by name.
|
||||
func (objs *GoObjects) HasFunc(name string) bool {
|
||||
return objs.HasType(name, GoObjectType_Func)
|
||||
}
|
||||
|
||||
// Get returns the GoObject for the matching name and type.
|
||||
func (objs *GoObjects) Get(name string, objType GoObjectType) *GoObject {
|
||||
for _, obj := range objs.list {
|
||||
if obj.Name == name && (objType == "" || obj.Type == objType) {
|
||||
return obj
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// HasType checks is a GoObject exists for the matching name and type.
|
||||
func (objs *GoObjects) HasType(name string, objType GoObjectType) bool {
|
||||
for _, obj := range objs.list {
|
||||
if obj.Name == name && (objType == "" || obj.Type == objType) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// HasObject checks to see if the exact code block exists.
|
||||
func (objs *GoObjects) HasObject(src *GoObject) bool {
|
||||
if src == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Generate the code for the supplied object.
|
||||
srcLines := []string{}
|
||||
for _, l := range src.Lines() {
|
||||
// Exclude empty lines.
|
||||
l = strings.TrimSpace(l)
|
||||
if l != "" {
|
||||
srcLines = append(srcLines, l)
|
||||
}
|
||||
}
|
||||
srcStr := strings.Join(srcLines, "\n")
|
||||
|
||||
// Loop over all the objects and match with src code.
|
||||
for _, obj := range objs.list {
|
||||
objLines := []string{}
|
||||
for _, l := range obj.Lines() {
|
||||
// Exclude empty lines.
|
||||
l = strings.TrimSpace(l)
|
||||
if l != "" {
|
||||
objLines = append(objLines, l)
|
||||
}
|
||||
}
|
||||
objStr := strings.Join(objLines, "\n")
|
||||
|
||||
// Return true if the current object code matches src code.
|
||||
if srcStr == objStr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Add appends a new GoObject to the list.
|
||||
func (objs *GoObjects) Add(newObj *GoObject) error {
|
||||
newObj.Index = len(objs.list)
|
||||
objs.list = append(objs.list, newObj)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Insert appends a new GoObject at the desired position to the list.
|
||||
func (objs *GoObjects) Insert(pos int, newObj *GoObject) error {
|
||||
newList := []*GoObject{}
|
||||
|
||||
var newIdx int
|
||||
for _, obj := range objs.list {
|
||||
if obj.Index < pos {
|
||||
obj.Index = newIdx
|
||||
newList = append(newList, obj)
|
||||
} else {
|
||||
if obj.Index == pos {
|
||||
newObj.Index = newIdx
|
||||
newList = append(newList, newObj)
|
||||
newIdx++
|
||||
}
|
||||
obj.Index = newIdx
|
||||
newList = append(newList, obj)
|
||||
}
|
||||
|
||||
newIdx++
|
||||
}
|
||||
|
||||
objs.list = newList
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Remove deletes a GoObject from the list.
|
||||
func (objs *GoObjects) Remove(delObjs ...*GoObject) error {
|
||||
for _, delObj := range delObjs {
|
||||
oldList := objs.List()
|
||||
objs.list = []*GoObject{}
|
||||
|
||||
var newIdx int
|
||||
for _, obj := range oldList {
|
||||
if obj.Index == delObj.Index {
|
||||
continue
|
||||
}
|
||||
obj.Index = newIdx
|
||||
objs.list = append(objs.list, obj)
|
||||
newIdx++
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Replace updates an existing GoObject while maintaining is same position.
|
||||
func (objs *GoObjects) Replace(oldObj *GoObject, newObjs ...*GoObject) error {
|
||||
if oldObj.Index >= len(objs.list) {
|
||||
return errors.WithStack(errGoObjectNotExist)
|
||||
} else if len(newObjs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
oldList := objs.List()
|
||||
objs.list = []*GoObject{}
|
||||
|
||||
var newIdx int
|
||||
for _, obj := range oldList {
|
||||
if obj.Index < oldObj.Index {
|
||||
obj.Index = newIdx
|
||||
objs.list = append(objs.list, obj)
|
||||
newIdx++
|
||||
} else if obj.Index == oldObj.Index {
|
||||
for _, newObj := range newObjs {
|
||||
newObj.Index = newIdx
|
||||
objs.list = append(objs.list, newObj)
|
||||
newIdx++
|
||||
}
|
||||
} else {
|
||||
obj.Index = newIdx
|
||||
objs.list = append(objs.list, obj)
|
||||
newIdx++
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReplaceFuncByName finds an existing GoObject with type GoObjectType_Func by name
|
||||
// and then performs a replace with the supplied new GoObject.
|
||||
func (objs *GoObjects) ReplaceFuncByName(name string, fn *GoObject) error {
|
||||
return objs.ReplaceTypeByName(name, fn, GoObjectType_Func)
|
||||
}
|
||||
|
||||
// ReplaceTypeByName finds an existing GoObject with type by name
|
||||
// and then performs a replace with the supplied new GoObject.
|
||||
func (objs *GoObjects) ReplaceTypeByName(name string, newObj *GoObject, objType GoObjectType) error {
|
||||
if newObj.Name == "" {
|
||||
newObj.Name = name
|
||||
}
|
||||
if newObj.Type == "" && objType != "" {
|
||||
newObj.Type = objType
|
||||
}
|
||||
|
||||
for _, obj := range objs.list {
|
||||
if obj.Name == name && (objType == "" || objType == obj.Type) {
|
||||
return objs.Replace(obj, newObj)
|
||||
}
|
||||
}
|
||||
return errors.WithStack(errGoObjectNotExist)
|
||||
}
|
||||
|
||||
// Empty determines if all the GoObject in the list are line breaks.
|
||||
func (objs *GoObjects) Empty() bool {
|
||||
var hasStuff bool
|
||||
for _, obj := range objs.List() {
|
||||
switch obj.Type {
|
||||
case GoObjectType_LineBreak:
|
||||
//case GoObjectType_Comment:
|
||||
//case GoObjectType_Import:
|
||||
// do nothing
|
||||
default:
|
||||
hasStuff = true
|
||||
}
|
||||
}
|
||||
return hasStuff
|
||||
}
|
||||
|
||||
// Debug prints out the GoObject to logger.
|
||||
func (obj *GoObject) Debug(log *log.Logger) {
|
||||
log.Println(obj.Name)
|
||||
log.Println(" > type:", obj.Type)
|
||||
log.Println(" > start lines:")
|
||||
for _, l := range obj.startLines {
|
||||
log.Println(" ", l)
|
||||
}
|
||||
|
||||
log.Println(" > sub lines:")
|
||||
for _, l := range obj.subLines {
|
||||
log.Println(" ", l)
|
||||
}
|
||||
|
||||
log.Println(" > end lines:")
|
||||
for _, l := range obj.endLines {
|
||||
log.Println(" ", l)
|
||||
}
|
||||
}
|
||||
|
||||
// Defines a property of a struct.
|
||||
type structProp struct {
|
||||
Name string
|
||||
Type string
|
||||
Tags *structtag.Tags
|
||||
}
|
||||
|
||||
// ParseStructProp extracts the details for a struct property.
|
||||
func ParseStructProp(obj *GoObject) (structProp, error) {
|
||||
|
||||
if obj.Type != GoObjectType_Line {
|
||||
return structProp{}, errors.Errorf("Unable to parse object of type %s", obj.Type)
|
||||
}
|
||||
|
||||
// Remove any white space from the code line.
|
||||
ls := strings.TrimSpace(strings.Join(obj.Lines(), " "))
|
||||
|
||||
// Extract the property name and type for the line.
|
||||
// ie: ID string `json:"id"`
|
||||
var resp structProp
|
||||
for _, p := range strings.Split(ls, " ") {
|
||||
p = strings.TrimSpace(p)
|
||||
if p == "" {
|
||||
continue
|
||||
}
|
||||
if resp.Name == "" {
|
||||
resp.Name = p
|
||||
} else if resp.Type == "" {
|
||||
resp.Type = p
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// If the line contains tags, extract and parse them.
|
||||
if strings.Contains(ls, "`") {
|
||||
tagStr := strings.Split(ls, "`")[1]
|
||||
|
||||
var err error
|
||||
resp.Tags, err = structtag.Parse(tagStr)
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to parse struct tag for field %s: %s", resp.Name, tagStr)
|
||||
return structProp{}, err
|
||||
}
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
@ -1,352 +0,0 @@
|
||||
package goparse
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"go/format"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
var (
|
||||
errGoParseType = errors.New("Unable to determine type for line")
|
||||
errGoTypeMissingCodeTemplate = errors.New("No code defined for type")
|
||||
errGoObjectNotExist = errors.New("GoObject does not exist")
|
||||
)
|
||||
|
||||
// ParseFile reads a go code file and parses into a easily transformable set of objects.
|
||||
func ParseFile(log *log.Logger, localPath string) (*GoDocument, error) {
|
||||
|
||||
// Read the code file.
|
||||
src, err := ioutil.ReadFile(localPath)
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to read file %s", localPath)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Format the code file source to ensure parse works.
|
||||
dat, err := format.Source(src)
|
||||
if err != nil {
|
||||
err = errors.WithMessagef(err, "Failed to format source for file %s", localPath)
|
||||
log.Printf("ParseFile : %v\n%v", err, string(src))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Loop of the formatted source code and generate a list of code lines.
|
||||
lines := []string{}
|
||||
r := bytes.NewReader(dat)
|
||||
scanner := bufio.NewScanner(r)
|
||||
for scanner.Scan() {
|
||||
lines = append(lines, scanner.Text())
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
err = errors.WithMessagef(err, "Failed read formatted source code for file %s", localPath)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse the code lines into a set of objects.
|
||||
objs, err := ParseLines(lines, 0)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Append the resulting objects to the document.
|
||||
doc := &GoDocument{}
|
||||
for _, obj := range objs.List() {
|
||||
if obj.Type == GoObjectType_Package {
|
||||
doc.Package = obj.Name
|
||||
}
|
||||
doc.AddObject(obj)
|
||||
}
|
||||
|
||||
return doc, nil
|
||||
}
|
||||
|
||||
// ParseLines takes the list of formatted code lines and returns the GoObjects.
|
||||
func ParseLines(lines []string, depth int) (objs *GoObjects, err error) {
|
||||
objs = &GoObjects{
|
||||
list: []*GoObject{},
|
||||
}
|
||||
|
||||
var (
|
||||
multiLine bool
|
||||
multiComment bool
|
||||
muiliVar bool
|
||||
)
|
||||
curDepth := -1
|
||||
objLines := []string{}
|
||||
|
||||
for idx, l := range lines {
|
||||
ls := strings.TrimSpace(l)
|
||||
|
||||
ld := lineDepth(l)
|
||||
|
||||
//fmt.Println("l", l)
|
||||
//fmt.Println("> Depth", ld, "???", depth)
|
||||
|
||||
if ld == depth {
|
||||
if strings.HasPrefix(ls, "/*") {
|
||||
multiLine = true
|
||||
multiComment = true
|
||||
} else if strings.HasSuffix(ls, "(") ||
|
||||
strings.HasSuffix(ls, "{") {
|
||||
|
||||
if !multiLine {
|
||||
multiLine = true
|
||||
}
|
||||
} else if strings.Contains(ls, "`") {
|
||||
if !multiLine && strings.Count(ls, "`")%2 != 0 {
|
||||
if muiliVar {
|
||||
muiliVar = false
|
||||
} else {
|
||||
muiliVar = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//fmt.Println("> multiLine", multiLine)
|
||||
//fmt.Println("> multiComment", multiComment)
|
||||
//fmt.Println("> muiliVar", muiliVar)
|
||||
|
||||
objLines = append(objLines, l)
|
||||
|
||||
if multiComment {
|
||||
if strings.HasSuffix(ls, "*/") {
|
||||
multiComment = false
|
||||
multiLine = false
|
||||
}
|
||||
} else {
|
||||
if strings.HasPrefix(ls, ")") ||
|
||||
strings.HasPrefix(ls, "}") {
|
||||
multiLine = false
|
||||
}
|
||||
}
|
||||
|
||||
if !multiLine && !muiliVar {
|
||||
for eidx := idx + 1; eidx < len(lines); eidx++ {
|
||||
if el := lines[eidx]; strings.TrimSpace(el) == "" {
|
||||
objLines = append(objLines, el)
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
//fmt.Println(" > objLines", objLines)
|
||||
|
||||
obj, err := ParseGoObject(objLines, depth)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return objs, err
|
||||
}
|
||||
err = objs.Add(obj)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return objs, err
|
||||
}
|
||||
|
||||
objLines = []string{}
|
||||
}
|
||||
|
||||
} else if (multiLine && ld >= curDepth && ld >= depth && len(objLines) > 0) || muiliVar {
|
||||
objLines = append(objLines, l)
|
||||
|
||||
if strings.Contains(ls, "`") {
|
||||
if !multiLine && strings.Count(ls, "`")%2 != 0 {
|
||||
if muiliVar {
|
||||
muiliVar = false
|
||||
} else {
|
||||
muiliVar = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, obj := range objs.List() {
|
||||
children, err := ParseLines(obj.subLines, depth+1)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return objs, err
|
||||
}
|
||||
for _, child := range children.List() {
|
||||
obj.Objects().Add(child)
|
||||
}
|
||||
}
|
||||
|
||||
return objs, nil
|
||||
}
|
||||
|
||||
// ParseGoObject generates a GoObjected for the given code lines.
|
||||
func ParseGoObject(lines []string, depth int) (obj *GoObject, err error) {
|
||||
|
||||
// If there are no lines, return a line break.
|
||||
if len(lines) == 0 {
|
||||
return &GoEmptyLine, nil
|
||||
}
|
||||
|
||||
firstLine := lines[0]
|
||||
firstStrip := strings.TrimSpace(firstLine)
|
||||
|
||||
if len(firstStrip) == 0 {
|
||||
return &GoEmptyLine, nil
|
||||
}
|
||||
|
||||
obj = &GoObject{
|
||||
goObjects: &GoObjects{
|
||||
list: []*GoObject{},
|
||||
},
|
||||
}
|
||||
|
||||
if strings.HasPrefix(firstStrip, "var") {
|
||||
obj.Type = GoObjectType_Var
|
||||
|
||||
if !strings.HasSuffix(firstStrip, "(") {
|
||||
if strings.HasPrefix(firstStrip, "var ") {
|
||||
firstStrip = strings.TrimSpace(strings.Replace(firstStrip, "var ", "", 1))
|
||||
}
|
||||
obj.Name = strings.Split(firstStrip, " ")[0]
|
||||
}
|
||||
} else if strings.HasPrefix(firstStrip, "const") {
|
||||
obj.Type = GoObjectType_Const
|
||||
|
||||
if !strings.HasSuffix(firstStrip, "(") {
|
||||
if strings.HasPrefix(firstStrip, "const ") {
|
||||
firstStrip = strings.TrimSpace(strings.Replace(firstStrip, "const ", "", 1))
|
||||
}
|
||||
obj.Name = strings.Split(firstStrip, " ")[0]
|
||||
}
|
||||
} else if strings.HasPrefix(firstStrip, "func") {
|
||||
obj.Type = GoObjectType_Func
|
||||
|
||||
if strings.HasPrefix(firstStrip, "func (") {
|
||||
funcLine := strings.TrimLeft(strings.TrimSpace(strings.Replace(firstStrip, "func ", "", 1)), "(")
|
||||
|
||||
var structName string
|
||||
pts := strings.Split(strings.Split(funcLine, ")")[0], " ")
|
||||
for i := len(pts) - 1; i >= 0; i-- {
|
||||
ptVal := strings.TrimSpace(pts[i])
|
||||
if ptVal != "" {
|
||||
structName = ptVal
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
var funcName string
|
||||
pts = strings.Split(strings.Split(funcLine, "(")[0], " ")
|
||||
for i := len(pts) - 1; i >= 0; i-- {
|
||||
ptVal := strings.TrimSpace(pts[i])
|
||||
if ptVal != "" {
|
||||
funcName = ptVal
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
obj.Name = fmt.Sprintf("%s.%s", structName, funcName)
|
||||
} else {
|
||||
obj.Name = strings.Replace(firstStrip, "func ", "", 1)
|
||||
obj.Name = strings.Split(obj.Name, "(")[0]
|
||||
}
|
||||
} else if strings.HasSuffix(firstStrip, "struct {") || strings.HasSuffix(firstStrip, "struct{") {
|
||||
obj.Type = GoObjectType_Struct
|
||||
|
||||
if strings.HasPrefix(firstStrip, "type ") {
|
||||
firstStrip = strings.TrimSpace(strings.Replace(firstStrip, "type ", "", 1))
|
||||
}
|
||||
obj.Name = strings.Split(firstStrip, " ")[0]
|
||||
} else if strings.HasPrefix(firstStrip, "type") {
|
||||
obj.Type = GoObjectType_Type
|
||||
firstStrip = strings.TrimSpace(strings.Replace(firstStrip, "type ", "", 1))
|
||||
obj.Name = strings.Split(firstStrip, " ")[0]
|
||||
} else if strings.HasPrefix(firstStrip, "package") {
|
||||
obj.Name = strings.TrimSpace(strings.Replace(firstStrip, "package ", "", 1))
|
||||
|
||||
obj.Type = GoObjectType_Package
|
||||
} else if strings.HasPrefix(firstStrip, "import") {
|
||||
obj.Type = GoObjectType_Import
|
||||
} else if strings.HasPrefix(firstStrip, "//") {
|
||||
obj.Type = GoObjectType_Comment
|
||||
} else if strings.HasPrefix(firstStrip, "/*") {
|
||||
obj.Type = GoObjectType_Comment
|
||||
} else {
|
||||
if depth > 0 {
|
||||
obj.Type = GoObjectType_Line
|
||||
} else {
|
||||
err = errors.WithStack(errGoParseType)
|
||||
return obj, err
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
hasSub bool
|
||||
muiliVarStart bool
|
||||
muiliVarSub bool
|
||||
muiliVarEnd bool
|
||||
)
|
||||
for _, l := range lines {
|
||||
ld := lineDepth(l)
|
||||
if (ld == depth && !muiliVarSub) || muiliVarStart || muiliVarEnd {
|
||||
if hasSub && !muiliVarStart {
|
||||
if strings.TrimSpace(l) != "" {
|
||||
obj.endLines = append(obj.endLines, l)
|
||||
}
|
||||
|
||||
if strings.Count(l, "`")%2 != 0 {
|
||||
if muiliVarEnd {
|
||||
muiliVarEnd = false
|
||||
} else {
|
||||
muiliVarEnd = true
|
||||
}
|
||||
}
|
||||
} else {
|
||||
obj.startLines = append(obj.startLines, l)
|
||||
|
||||
if strings.Count(l, "`")%2 != 0 {
|
||||
if muiliVarStart {
|
||||
muiliVarStart = false
|
||||
} else {
|
||||
muiliVarStart = true
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if ld > depth || muiliVarSub {
|
||||
obj.subLines = append(obj.subLines, l)
|
||||
hasSub = true
|
||||
|
||||
if strings.Count(l, "`")%2 != 0 {
|
||||
if muiliVarSub {
|
||||
muiliVarSub = false
|
||||
} else {
|
||||
muiliVarSub = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// add trailing linebreak
|
||||
if len(obj.endLines) > 0 {
|
||||
obj.endLines = append(obj.endLines, "")
|
||||
}
|
||||
|
||||
return obj, err
|
||||
}
|
||||
|
||||
// lineDepth returns the number of spaces for the given code line
|
||||
// used to determine the code level for nesting objects.
|
||||
func lineDepth(l string) int {
|
||||
depth := len(l) - len(strings.TrimLeftFunc(l, unicode.IsSpace))
|
||||
|
||||
ls := strings.TrimSpace(l)
|
||||
if strings.HasPrefix(ls, "}") && strings.Contains(ls, " else ") {
|
||||
depth = depth + 1
|
||||
} else if strings.HasPrefix(ls, "case ") {
|
||||
depth = depth + 1
|
||||
}
|
||||
return depth
|
||||
}
|
@ -1,201 +0,0 @@
|
||||
package goparse
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var logger *log.Logger
|
||||
|
||||
func init() {
|
||||
logger = log.New(os.Stdout, "", log.LstdFlags|log.Lmicroseconds|log.Lshortfile)
|
||||
}
|
||||
|
||||
func TestMultilineVar(t *testing.T) {
|
||||
g := gomega.NewGomegaWithT(t)
|
||||
|
||||
code := `func ContextAllowedAccountIds(ctx context.Context, db *gorm.DB) (resp akdatamodels.Uint32List, err error) {
|
||||
resp = []uint32{}
|
||||
accountId := akcontext.ContextAccountId(ctx)
|
||||
m := datamodels.UserAccount{}
|
||||
q := fmt.Sprintf("select
|
||||
distinct account_id
|
||||
from %s where account_id = ?", m.TableName())
|
||||
db = db.Raw(q, accountId)
|
||||
}
|
||||
`
|
||||
code = strings.Replace(code, "\"", "`", -1)
|
||||
lines := strings.Split(code, "\n")
|
||||
|
||||
objs, err := ParseLines(lines, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("got error %v", err)
|
||||
}
|
||||
|
||||
g.Expect(objs.Lines()).Should(gomega.Equal(lines))
|
||||
}
|
||||
|
||||
func TestNewDocImports(t *testing.T) {
|
||||
g := gomega.NewGomegaWithT(t)
|
||||
|
||||
expected := []string{
|
||||
"package goparse",
|
||||
"",
|
||||
"import (",
|
||||
" \"github.com/go/pkg1\"",
|
||||
" \"github.com/go/pkg2\"",
|
||||
")",
|
||||
"",
|
||||
}
|
||||
|
||||
doc := &GoDocument{}
|
||||
doc.SetPackage("goparse")
|
||||
|
||||
doc.AddImport(GoImport{Name: "github.com/go/pkg1"})
|
||||
doc.AddImport(GoImport{Name: "github.com/go/pkg2"})
|
||||
|
||||
g.Expect(doc.Lines()).Should(gomega.Equal(expected))
|
||||
}
|
||||
|
||||
func TestParseLines1(t *testing.T) {
|
||||
g := gomega.NewGomegaWithT(t)
|
||||
|
||||
codeTests := []string{
|
||||
`func testCreate(t *testing.T, ctx context.Context, sess *datamodels.Session) *datamodels.Model {
|
||||
g := gomega.NewGomegaWithT(t)
|
||||
obj := datamodels.MockModelNew()
|
||||
resp, err := ModelCreate(ctx, DB, &obj)
|
||||
if err != nil {
|
||||
t.Fatalf("got error %v", err)
|
||||
}
|
||||
|
||||
g.Expect(resp.Name).Should(gomega.Equal(obj.Name))
|
||||
g.Expect(resp.Status).Should(gomega.Equal(datamodels.{{ .Name }}Status_Active))
|
||||
return resp
|
||||
}
|
||||
`,
|
||||
`var (
|
||||
// ErrNotFound abstracts the postgres not found error.
|
||||
ErrNotFound = errors.New("Entity not found")
|
||||
// ErrInvalidID occurs when an ID is not in a valid form.
|
||||
ErrInvalidID = errors.New("ID is not in its proper form")
|
||||
// ErrForbidden occurs when a user tries to do something that is forbidden to them according to our access control policies.
|
||||
ErrForbidden = errors.New("Attempted action is not allowed")
|
||||
)
|
||||
`,
|
||||
}
|
||||
|
||||
for _, code := range codeTests {
|
||||
lines := strings.Split(code, "\n")
|
||||
|
||||
objs, err := ParseLines(lines, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("got error %v", err)
|
||||
}
|
||||
|
||||
g.Expect(objs.Lines()).Should(gomega.Equal(lines))
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestParseLines2(t *testing.T) {
|
||||
code := `func structToMap(s interface{}) (resp map[string]interface{}) {
|
||||
dat, _ := json.Marshal(s)
|
||||
_ = json.Unmarshal(dat, &resp)
|
||||
for k, x := range resp {
|
||||
switch v := x.(type) {
|
||||
case time.Time:
|
||||
if v.IsZero() {
|
||||
delete(resp, k)
|
||||
}
|
||||
|
||||
case *time.Time:
|
||||
if v == nil || v.IsZero() {
|
||||
delete(resp, k)
|
||||
}
|
||||
|
||||
case nil:
|
||||
delete(resp, k)
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
`
|
||||
lines := strings.Split(code, "\n")
|
||||
|
||||
objs, err := ParseLines(lines, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("got error %v", err)
|
||||
}
|
||||
|
||||
testLineTextMatches(t, objs.Lines(), lines)
|
||||
}
|
||||
|
||||
func TestParseLines3(t *testing.T) {
|
||||
g := gomega.NewGomegaWithT(t)
|
||||
|
||||
code := `type UserAccountRoleName = string
|
||||
|
||||
const (
|
||||
UserAccountRoleName_None UserAccountRoleName = ""
|
||||
UserAccountRoleName_Admin UserAccountRoleName = "admin"
|
||||
UserAccountRoleName_User UserAccountRoleName = "user"
|
||||
)
|
||||
|
||||
type UserAccountRole struct {
|
||||
Id uint32 ^gorm:"column:id;type:int(10) unsigned AUTO_INCREMENT;primary_key;not null;auto_increment;" truss:"internal:true"^
|
||||
CreatedAt time.Time ^gorm:"column:created_at;type:datetime;default:CURRENT_TIMESTAMP;not null;" truss:"internal:true"^
|
||||
UpdatedAt time.Time ^gorm:"column:updated_at;type:datetime;" truss:"internal:true"^
|
||||
DeletedAt *time.Time ^gorm:"column:deleted_at;type:datetime;" truss:"internal:true"^
|
||||
Role UserAccountRoleName ^gorm:"unique_index:user_account_role;column:role;type:enum('admin', 'user')"^
|
||||
// belongs to User
|
||||
User *User ^gorm:"foreignkey:UserId;association_foreignkey:Id;association_autoupdate:false;association_autocreate:false;association_save_reference:false;preload:false;" truss:"internal:true"^
|
||||
UserId uint32 ^gorm:"unique_index:user_account_role;"^
|
||||
// belongs to Account
|
||||
Account *Account ^gorm:"foreignkey:AccountId;association_foreignkey:Id;association_autoupdate:false;association_autocreate:false;association_save_reference:false;preload:false;" truss:"internal:true;api_ro:true;"^
|
||||
AccountId uint32 ^gorm:"unique_index:user_account_role;" truss:"internal:true;api_ro:true;"^
|
||||
}
|
||||
|
||||
func (UserAccountRole) TableName() string {
|
||||
return "user_account_roles"
|
||||
}
|
||||
`
|
||||
code = strings.Replace(code, "^", "'", -1)
|
||||
lines := strings.Split(code, "\n")
|
||||
|
||||
objs, err := ParseLines(lines, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("got error %v", err)
|
||||
}
|
||||
|
||||
g.Expect(objs.Lines()).Should(gomega.Equal(lines))
|
||||
}
|
||||
|
||||
func testLineTextMatches(t *testing.T, l1, l2 []string) {
|
||||
g := gomega.NewGomegaWithT(t)
|
||||
|
||||
m1 := []string{}
|
||||
for _, l := range l1 {
|
||||
l = strings.TrimSpace(l)
|
||||
if l != "" {
|
||||
m1 = append(m1, l)
|
||||
}
|
||||
}
|
||||
|
||||
m2 := []string{}
|
||||
for _, l := range l2 {
|
||||
l = strings.TrimSpace(l)
|
||||
if l != "" {
|
||||
m2 = append(m2, l)
|
||||
}
|
||||
}
|
||||
|
||||
g.Expect(m1).Should(gomega.Equal(m2))
|
||||
}
|
@ -1,228 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"expvar"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"geeks-accelerator/oss/saas-starter-kit/tools/truss/cmd/dbtable2crud"
|
||||
"github.com/kelseyhightower/envconfig"
|
||||
"github.com/lib/pq"
|
||||
_ "github.com/lib/pq"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/urfave/cli"
|
||||
sqltrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/database/sql"
|
||||
sqlxtrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
// build is the git version of this program. It is set using build flags in the makefile.
|
||||
var build = "develop"
|
||||
|
||||
// service is the name of the program used for logging, tracing and the
|
||||
// the prefix used for loading env variables
|
||||
// ie: export TRUSS_ENV=dev
|
||||
var service = "TRUSS"
|
||||
|
||||
func main() {
|
||||
// =========================================================================
|
||||
// Logging
|
||||
|
||||
log := log.New(os.Stdout, service+" : ", log.LstdFlags|log.Lmicroseconds|log.Lshortfile)
|
||||
|
||||
// =========================================================================
|
||||
// Configuration
|
||||
var cfg struct {
|
||||
DB struct {
|
||||
Host string `default:"127.0.0.1:5433" envconfig:"HOST"`
|
||||
User string `default:"postgres" envconfig:"USER"`
|
||||
Pass string `default:"postgres" envconfig:"PASS" json:"-"` // don't print
|
||||
Database string `default:"shared" envconfig:"DATABASE"`
|
||||
Driver string `default:"postgres" envconfig:"DRIVER"`
|
||||
Timezone string `default:"utc" envconfig:"TIMEZONE"`
|
||||
DisableTLS bool `default:"false" envconfig:"DISABLE_TLS"`
|
||||
}
|
||||
}
|
||||
|
||||
// For additional details refer to https://github.com/kelseyhightower/envconfig
|
||||
if err := envconfig.Process(service, &cfg); err != nil {
|
||||
log.Fatalf("main : Parsing Config : %v", err)
|
||||
}
|
||||
|
||||
// TODO: can't use flag.Process here since it doesn't support nested arg options
|
||||
//if err := flag.Process(&cfg); err != nil {
|
||||
/// if err != flag.ErrHelp {
|
||||
// log.Fatalf("main : Parsing Command Line : %v", err)
|
||||
// }
|
||||
// return // We displayed help.
|
||||
//}
|
||||
|
||||
// =========================================================================
|
||||
// Log App Info
|
||||
|
||||
// Print the build version for our logs. Also expose it under /debug/vars.
|
||||
expvar.NewString("build").Set(build)
|
||||
log.Printf("main : Started : Application Initializing version %q", build)
|
||||
defer log.Println("main : Completed")
|
||||
|
||||
// Print the config for our logs. It's important to any credentials in the config
|
||||
// that could expose a security risk are excluded from being json encoded by
|
||||
// applying the tag `json:"-"` to the struct var.
|
||||
{
|
||||
cfgJSON, err := json.MarshalIndent(cfg, "", " ")
|
||||
if err != nil {
|
||||
log.Fatalf("main : Marshalling Config to JSON : %v", err)
|
||||
}
|
||||
log.Printf("main : Config : %v\n", string(cfgJSON))
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Start Database
|
||||
var dbUrl url.URL
|
||||
{
|
||||
// Query parameters.
|
||||
var q url.Values = make(map[string][]string)
|
||||
|
||||
// Handle SSL Mode
|
||||
if cfg.DB.DisableTLS {
|
||||
q.Set("sslmode", "disable")
|
||||
} else {
|
||||
q.Set("sslmode", "require")
|
||||
}
|
||||
|
||||
q.Set("timezone", cfg.DB.Timezone)
|
||||
|
||||
// Construct url.
|
||||
dbUrl = url.URL{
|
||||
Scheme: cfg.DB.Driver,
|
||||
User: url.UserPassword(cfg.DB.User, cfg.DB.Pass),
|
||||
Host: cfg.DB.Host,
|
||||
Path: cfg.DB.Database,
|
||||
RawQuery: q.Encode(),
|
||||
}
|
||||
}
|
||||
|
||||
// Register informs the sqlxtrace package of the driver that we will be using in our program.
|
||||
// It uses a default service name, in the below case "postgres.db". To use a custom service
|
||||
// name use RegisterWithServiceName.
|
||||
sqltrace.Register(cfg.DB.Driver, &pq.Driver{}, sqltrace.WithServiceName(service))
|
||||
masterDb, err := sqlxtrace.Open(cfg.DB.Driver, dbUrl.String())
|
||||
if err != nil {
|
||||
log.Fatalf("main : Register DB : %s : %v", cfg.DB.Driver, err)
|
||||
}
|
||||
defer masterDb.Close()
|
||||
|
||||
// =========================================================================
|
||||
// Start Truss
|
||||
|
||||
app := cli.NewApp()
|
||||
app.Commands = []cli.Command{
|
||||
{
|
||||
Name: "dbtable2crud",
|
||||
Aliases: []string{},
|
||||
Usage: "-table=projects -file=../../internal/project/models.go -model=Project [-dbtable=TABLE] [-templateDir=DIR] [-projectPath=DIR] [-saveChanges=false] ",
|
||||
Flags: []cli.Flag{
|
||||
cli.StringFlag{Name: "dbtable, table"},
|
||||
cli.StringFlag{Name: "modelFile, modelfile, file"},
|
||||
cli.StringFlag{Name: "modelName, modelname, model"},
|
||||
cli.StringFlag{Name: "templateDir, templates", Value: "./templates/dbtable2crud"},
|
||||
cli.StringFlag{Name: "projectPath"},
|
||||
cli.BoolFlag{Name: "saveChanges, save"},
|
||||
},
|
||||
Action: func(c *cli.Context) error {
|
||||
dbTable := strings.TrimSpace(c.String("dbtable"))
|
||||
modelFile := strings.TrimSpace(c.String("modelFile"))
|
||||
modelName := strings.TrimSpace(c.String("modelName"))
|
||||
templateDir := strings.TrimSpace(c.String("templateDir"))
|
||||
projectPath := strings.TrimSpace(c.String("projectPath"))
|
||||
|
||||
pwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return errors.WithMessage(err, "Failed to get current working directory")
|
||||
}
|
||||
|
||||
if !path.IsAbs(templateDir) {
|
||||
templateDir = filepath.Join(pwd, templateDir)
|
||||
}
|
||||
ok, err := exists(templateDir)
|
||||
if err != nil {
|
||||
return errors.WithMessage(err, "Failed to load template directory")
|
||||
} else if !ok {
|
||||
return errors.Errorf("Template directory %s does not exist", templateDir)
|
||||
}
|
||||
|
||||
if modelFile == "" {
|
||||
return errors.Errorf("Model file path is required")
|
||||
}
|
||||
|
||||
if !path.IsAbs(modelFile) {
|
||||
modelFile = filepath.Join(pwd, modelFile)
|
||||
}
|
||||
ok, err = exists(modelFile)
|
||||
if err != nil {
|
||||
return errors.WithMessage(err, "Failed to load model file")
|
||||
} else if !ok {
|
||||
return errors.Errorf("Model file %s does not exist", modelFile)
|
||||
}
|
||||
|
||||
// Load the project path from go.mod if not set.
|
||||
if projectPath == "" {
|
||||
goModFile := filepath.Join(pwd, "../../go.mod")
|
||||
ok, err = exists(goModFile)
|
||||
if err != nil {
|
||||
return errors.WithMessage(err, "Failed to load go.mod for project")
|
||||
} else if !ok {
|
||||
return errors.Errorf("Failed to locate project go.mod at %s", goModFile)
|
||||
}
|
||||
|
||||
b, err := ioutil.ReadFile(goModFile)
|
||||
if err != nil {
|
||||
return errors.WithMessagef(err, "Failed to read go.mod at %s", goModFile)
|
||||
}
|
||||
|
||||
lines := strings.Split(string(b), "\n")
|
||||
for _, l := range lines {
|
||||
if strings.HasPrefix(l, "module ") {
|
||||
projectPath = strings.TrimSpace(strings.Split(l, " ")[1])
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if modelName == "" {
|
||||
modelName = strings.Split(filepath.Base(modelFile), ".")[0]
|
||||
modelName = strings.Replace(modelName, "_", " ", -1)
|
||||
modelName = strings.Replace(modelName, "-", " ", -1)
|
||||
modelName = strings.Title(modelName)
|
||||
modelName = strings.Replace(modelName, " ", "", -1)
|
||||
}
|
||||
|
||||
return dbtable2crud.Run(masterDb, log, cfg.DB.Database, dbTable, modelFile, modelName, templateDir, projectPath, c.Bool("saveChanges"))
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err = app.Run(os.Args)
|
||||
if err != nil {
|
||||
log.Fatalf("main : Truss : %+v", err)
|
||||
}
|
||||
|
||||
log.Printf("main : Truss : Completed")
|
||||
}
|
||||
|
||||
// exists returns a bool as to whether a file path exists.
|
||||
func exists(path string) (bool, error) {
|
||||
_, err := os.Stat(path)
|
||||
if err == nil {
|
||||
return true, nil
|
||||
}
|
||||
if os.IsNotExist(err) {
|
||||
return false, nil
|
||||
}
|
||||
return true, err
|
||||
}
|
@ -1,10 +0,0 @@
|
||||
SHELL := /bin/bash
|
||||
|
||||
install:
|
||||
go install .
|
||||
|
||||
build:
|
||||
go install .
|
||||
|
||||
run:
|
||||
go build . && ./truss
|
@ -1,22 +0,0 @@
|
||||
# Variables to configure Postgres for database migration.
|
||||
export TRUSS_DB_HOST=127.0.0.1:5433
|
||||
export TRUSS_DB_USER=postgres
|
||||
export TRUSS_DB_PASS=postgres
|
||||
export TRUSS_DB_DISABLE_TLS=true
|
||||
|
||||
# Variables to configure AWS for service build and deploy.
|
||||
# Use the same set for AWS credentials for all target envinments.
|
||||
#AWS_ACCESS_KEY_ID=XXXXXXXXXXXXXX
|
||||
#AWS_SECRET_ACCESS_KEY=XXXXXXXXXXXXXX
|
||||
#AWS_REGION=us-west-2
|
||||
|
||||
# AWS credentials can be prefixed with the target uppercased target envinments.
|
||||
# This allows credentials unique accounts to be used for each target envinments.
|
||||
# Default target envinments are: DEV, STAGE, PROD
|
||||
#DEV_AWS_ACCESS_KEY_ID=XXXXXXXXXXXXXX
|
||||
#DEV_AWS_SECRET_ACCESS_KEY=XXXXXXXXXXXXXX
|
||||
#DEV_AWS_REGION=us-west-2
|
||||
|
||||
# GitLab CI/CD environment variables. These are set by the GitLab when the build
|
||||
# pipeline is running. These can be optional set for testing/debugging locally.
|
||||
#CI_COMMIT_REF_NAME=master
|
@ -1,510 +0,0 @@
|
||||
{{ define "imports"}}
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"{{ $.GoSrcPath }}/internal/platform/auth"
|
||||
"github.com/huandu/go-sqlbuilder"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/pborman/uuid"
|
||||
"github.com/pkg/errors"
|
||||
"gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
|
||||
"gopkg.in/go-playground/validator.v9"
|
||||
)
|
||||
{{ end }}
|
||||
{{ define "Globals"}}
|
||||
const (
|
||||
// The database table for {{ $.Model.Name }}
|
||||
{{ FormatCamelLower $.Model.Name }}TableName = "{{ $.Model.TableName }}"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrNotFound abstracts the postgres not found error.
|
||||
ErrNotFound = errors.New("Entity not found")
|
||||
|
||||
// ErrInvalidID occurs when an ID is not in a valid form.
|
||||
ErrInvalidID = errors.New("ID is not in its proper form")
|
||||
|
||||
// ErrForbidden occurs when a user tries to do something that is forbidden to them according to our access control policies.
|
||||
ErrForbidden = errors.New("Attempted action is not allowed")
|
||||
)
|
||||
{{ end }}
|
||||
{{ define "Helpers"}}
|
||||
// {{ FormatCamelLower $.Model.Name }}MapColumns is the list of columns needed for mapRowsTo{{ $.Model.Name }}
|
||||
var {{ FormatCamelLower $.Model.Name }}MapColumns = "{{ JoinStrings $.Model.ColumnNames "," }}"
|
||||
|
||||
// mapRowsTo{{ $.Model.Name }} takes the SQL rows and maps it to the {{ $.Model.Name }} struct
|
||||
// with the columns defined by {{ FormatCamelLower $.Model.Name }}MapColumns
|
||||
func mapRowsTo{{ $.Model.Name }}(rows *sql.Rows) (*{{ $.Model.Name }}, error) {
|
||||
var (
|
||||
m {{ $.Model.Name }}
|
||||
err error
|
||||
)
|
||||
err = rows.Scan({{ PrefixAndJoinStrings $.Model.FieldNames "&m." "," }})
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
return &m, nil
|
||||
}
|
||||
{{ end }}
|
||||
{{ define "ACL"}}
|
||||
{{ $hasAccountID := (StringListHasValue $.Model.ColumnNames "account_id") }}
|
||||
// CanRead{{ $.Model.Name }} determines if claims has the authority to access the specified {{ FormatCamelLowerTitle $.Model.Name}} by {{ $.Model.PrimaryColumn }}.
|
||||
func CanRead{{ $.Model.Name }}(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, {{ FormatCamelLower $.Model.PrimaryField }} string) error {
|
||||
|
||||
{{ if $hasAccountID }}
|
||||
// If the request has claims from a specific {{ FormatCamelLower $.Model.Name }}, ensure that the claims
|
||||
// has the correct access to the {{ FormatCamelLower $.Model.Name }}.
|
||||
if claims.Audience != "" {
|
||||
// select {{ $.Model.PrimaryColumn }} from {{ $.Model.TableName }} where account_id = [accountID]
|
||||
query := sqlbuilder.NewSelectBuilder().Select("{{ $.Model.PrimaryColumn }}").From({{ FormatCamelLower $.Model.Name }}TableName)
|
||||
query.Where(query.And(
|
||||
query.Equal("account_id", claims.Audience),
|
||||
query.Equal("{{ $.Model.PrimaryField }}", {{ FormatCamelLower $.Model.PrimaryField }}),
|
||||
))
|
||||
queryStr, args := query.Build()
|
||||
queryStr = dbConn.Rebind(queryStr)
|
||||
|
||||
var {{ FormatCamelLower $.Model.PrimaryField }} string
|
||||
err := dbConn.QueryRowContext(ctx, queryStr, args...).Scan(&{{ FormatCamelLower $.Model.PrimaryField }})
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
err = errors.Wrapf(err, "query - %s", query.String())
|
||||
return err
|
||||
}
|
||||
|
||||
// When there is no {{ $.Model.PrimaryColumn }} returned, then the current claim user does not have access
|
||||
// to the specified {{ FormatCamelLowerTitle $.Model.Name }}.
|
||||
if {{ FormatCamelLower $.Model.PrimaryField }} == "" {
|
||||
return errors.WithStack(ErrForbidden)
|
||||
}
|
||||
}
|
||||
{{ else }}
|
||||
// TODO: Unable to auto generate sql statement, update accordingly.
|
||||
panic("Not implemented!")
|
||||
{{ end }}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CanModify{{ $.Model.Name }} determines if claims has the authority to modify the specified {{ FormatCamelLowerTitle $.Model.Name}} by {{ $.Model.PrimaryColumn }}.
|
||||
func CanModify{{ $.Model.Name }}(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, {{ FormatCamelLower $.Model.PrimaryField }} string) error {
|
||||
err := CanRead{{ $.Model.Name }}(ctx, claims, dbConn, {{ FormatCamelLower $.Model.PrimaryField }})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Admin users can update {{ FormatCamelPluralTitleLower $.Model.Name }} they have access to.
|
||||
if !claims.HasRole(auth.RoleAdmin) {
|
||||
return errors.WithStack(ErrForbidden)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyClaimsSelect applies a sub-query to the provided query to enforce ACL based on the claims provided.
|
||||
// 1. No claims, request is internal, no ACL applied
|
||||
{{ if $hasAccountID }}
|
||||
// 2. All role types can access their user ID
|
||||
{{ end }}
|
||||
func applyClaimsSelect(ctx context.Context, claims auth.Claims, query *sqlbuilder.SelectBuilder) error {
|
||||
// Claims are empty, don't apply any ACL
|
||||
if claims.Audience == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
{{ if $hasAccountID }}
|
||||
query.Where(query.Equal("account_id", claims.Audience))
|
||||
{{ end }}
|
||||
|
||||
return nil
|
||||
}
|
||||
{{ end }}
|
||||
{{ define "Find"}}
|
||||
{{ $hasArchived := (StringListHasValue $.Model.ColumnNames "archived_at") }}
|
||||
|
||||
// selectQuery constructs a base select query for {{ $.Model.Name }}
|
||||
func selectQuery() *sqlbuilder.SelectBuilder {
|
||||
query := sqlbuilder.NewSelectBuilder()
|
||||
query.Select({{ FormatCamelLower $.Model.Name }}MapColumns)
|
||||
query.From({{ FormatCamelLower $.Model.Name }}TableName)
|
||||
return query
|
||||
}
|
||||
|
||||
// findRequestQuery generates the select query for the given find request.
|
||||
// TODO: Need to figure out why can't parse the args when appending the where
|
||||
// to the query.
|
||||
func findRequestQuery(req {{ $.Model.Name }}FindRequest) (*sqlbuilder.SelectBuilder, []interface{}) {
|
||||
query := selectQuery()
|
||||
if req.Where != nil {
|
||||
query.Where(query.And(*req.Where))
|
||||
}
|
||||
if len(req.Order) > 0 {
|
||||
query.OrderBy(req.Order...)
|
||||
}
|
||||
if req.Limit != nil {
|
||||
query.Limit(int(*req.Limit))
|
||||
}
|
||||
if req.Offset != nil {
|
||||
query.Offset(int(*req.Offset))
|
||||
}
|
||||
|
||||
return query, req.Args
|
||||
}
|
||||
|
||||
// Find gets all the {{ FormatCamelPluralTitleLower $.Model.Name }} from the database based on the request params.
|
||||
func Find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req {{ $.Model.Name }}FindRequest) ([]*{{ $.Model.Name }}, error) {
|
||||
query, args := findRequestQuery(req)
|
||||
return find(ctx, claims, dbConn, query, args{{ if $hasArchived }}, req.IncludedArchived {{ end }})
|
||||
}
|
||||
|
||||
// find internal method for getting all the {{ FormatCamelPluralTitleLower $.Model.Name }} from the database using a select query.
|
||||
func find(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, query *sqlbuilder.SelectBuilder, args []interface{}{{ if $hasArchived }}, includedArchived bool{{ end }}) ([]*{{ $.Model.Name }}, error) {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.{{ FormatCamelLowerUnderscore $.Model.Name }}.Find")
|
||||
defer span.Finish()
|
||||
|
||||
query.Select({{ FormatCamelLower $.Model.Name }}MapColumns)
|
||||
query.From({{ FormatCamelLower $.Model.Name }}TableName)
|
||||
|
||||
{{ if $hasArchived }}
|
||||
if !includedArchived {
|
||||
query.Where(query.IsNull("archived_at"))
|
||||
}
|
||||
{{ end }}
|
||||
|
||||
// Check to see if a sub query needs to be applied for the claims.
|
||||
err := applyClaimsSelect(ctx, claims, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
queryStr, queryArgs := query.Build()
|
||||
queryStr = dbConn.Rebind(queryStr)
|
||||
args = append(args, queryArgs...)
|
||||
|
||||
// Fetch all entries from the db.
|
||||
rows, err := dbConn.QueryContext(ctx, queryStr, args...)
|
||||
if err != nil {
|
||||
err = errors.Wrapf(err, "query - %s", query.String())
|
||||
err = errors.WithMessage(err, "find {{ FormatCamelPluralTitleLower $.Model.Name }} failed")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Iterate over each row.
|
||||
resp := []*{{ $.Model.Name }}{}
|
||||
for rows.Next() {
|
||||
u, err := mapRowsTo{{ $.Model.Name }}(rows)
|
||||
if err != nil {
|
||||
err = errors.Wrapf(err, "query - %s", query.String())
|
||||
return nil, err
|
||||
}
|
||||
resp = append(resp, u)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// Read gets the specified {{ FormatCamelLowerTitle $.Model.Name }} from the database.
|
||||
func Read(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, {{ FormatCamelLower $.Model.PrimaryField }} {{ $.Model.PrimaryType }}{{ if $hasArchived }}, includedArchived bool{{ end }}) (*{{ $.Model.Name }}, error) {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.{{ FormatCamelLowerUnderscore $.Model.Name }}.Read")
|
||||
defer span.Finish()
|
||||
|
||||
// Filter base select query by {{ FormatCamelLower $.Model.PrimaryField }}
|
||||
query := selectQuery()
|
||||
query.Where(query.Equal("{{ $.Model.PrimaryColumn }}", {{ FormatCamelLower $.Model.PrimaryField }}))
|
||||
|
||||
res, err := find(ctx, claims, dbConn, query, []interface{}{} {{ if $hasArchived }}, includedArchived{{ end }})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if res == nil || len(res) == 0 {
|
||||
err = errors.WithMessagef(ErrNotFound, "{{ FormatCamelLowerTitle $.Model.Name }} %s not found", id)
|
||||
return nil, err
|
||||
}
|
||||
u := res[0]
|
||||
|
||||
return u, nil
|
||||
}
|
||||
{{ end }}
|
||||
{{ define "Create"}}
|
||||
{{ $hasAccountID := (StringListHasValue $.Model.ColumnNames "account_id") }}
|
||||
{{ $reqName := (Concat $.Model.Name "CreateRequest") }}
|
||||
{{ $createFields := (index $.StructFields $reqName) }}
|
||||
{{ $reqHasAccountID := false }}{{ $reqAccountID := (index $createFields "AccountID") }}{{ if $reqAccountID }}{{ $reqHasAccountID = true }}{{ end }}
|
||||
// Create inserts a new {{ FormatCamelLowerTitle $.Model.Name }} into the database.
|
||||
func Create(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req {{ $reqName }}, now time.Time) (*{{ $.Model.Name }}, error) {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.{{ FormatCamelLowerUnderscore $.Model.Name }}.Create")
|
||||
defer span.Finish()
|
||||
|
||||
if claims.Audience != "" {
|
||||
// Admin users can update {{ FormatCamelPluralTitleLower $.Model.Name }} they have access to.
|
||||
if !claims.HasRole(auth.RoleAdmin) {
|
||||
return nil, errors.WithStack(ErrForbidden)
|
||||
}
|
||||
|
||||
{{ if $reqHasAccountID }}
|
||||
if req.AccountID != "" {
|
||||
// Request accountId must match claims.
|
||||
if req.AccountID != claims.Audience {
|
||||
return nil, errors.WithStack(ErrForbidden)
|
||||
}
|
||||
} else {
|
||||
// Set the accountId from claims.
|
||||
req.AccountID = claims.Audience
|
||||
}
|
||||
{{ end }}
|
||||
}
|
||||
|
||||
v := validator.New()
|
||||
|
||||
// Validate the request.
|
||||
err := v.Struct(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If now empty set it to the current time.
|
||||
if now.IsZero() {
|
||||
now = time.Now()
|
||||
}
|
||||
|
||||
// Always store the time as UTC.
|
||||
now = now.UTC()
|
||||
|
||||
// Postgres truncates times to milliseconds when storing. We and do the same
|
||||
// here so the value we return is consistent with what we store.
|
||||
now = now.Truncate(time.Millisecond)
|
||||
|
||||
m := {{ $.Model.Name }}{
|
||||
{{ range $mk, $mf := $.Model.Fields }}{{ $cf := (index $createFields $mf.FieldName) }}
|
||||
{{ if eq $mf.FieldName $.Model.PrimaryField }}{{ $isUuid := (FieldTagHasOption $mf "validate" "uuid") }}{{ $mf.FieldName }}: {{ if $isUuid }}uuid.NewRandom().String(){{ else }}req.{{ $mf.FieldName }}{{ end }},
|
||||
{{ else if or (eq $mf.FieldName "CreatedAt") (eq $mf.FieldName "UpdatedAt") }}{{ $mf.FieldName }}: now,
|
||||
{{ else if $cf }}{{ $required := (FieldTagHasOption $cf "validate" "required") }}{{ if $required }}{{ $cf.FieldName }}: req.{{ $cf.FieldName }},{{ else if ne $cf.DefaultValue "" }}{{ $cf.FieldName }}: {{ $cf.DefaultValue }},{{ end }}
|
||||
{{ end }}{{ end }}
|
||||
}
|
||||
|
||||
{{ if and (not $reqHasAccountID) ($hasAccountID) }}
|
||||
// Set the accountId from claims.
|
||||
if claims.Audience != "" && m.AccountID == "" {
|
||||
req.AccountID = claims.Audience
|
||||
}
|
||||
{{ end }}
|
||||
|
||||
{{ range $fk, $f := $createFields }}{{ $required := (FieldTagHasOption $f "validate" "required") }}{{ if not $required }}
|
||||
if req.{{ $f.FieldName }} != nil {
|
||||
{{ if eq $f.FieldType "sql.NullString" }}
|
||||
m.{{ $f.FieldName }} = sql.NullString{String: *req.{{ $f.FieldName }}, Valid: true}
|
||||
{{ else if eq $f.FieldType "*sql.NullString" }}
|
||||
m.{{ $f.FieldName }} = &sql.NullString{String: *req.{{ $f.FieldName }}, Valid: true}
|
||||
{{ else }}
|
||||
m.{{ $f.FieldName }} = *req.{{ $f.FieldName }}
|
||||
{{ end }}
|
||||
}
|
||||
{{ end }}{{ end }}
|
||||
|
||||
// Build the insert SQL statement.
|
||||
query := sqlbuilder.NewInsertBuilder()
|
||||
query.InsertInto({{ FormatCamelLower $.Model.Name }}TableName)
|
||||
query.Cols(
|
||||
{{ range $mk, $mf := $.Model.Fields }}{{ $cf := (index $createFields $mf.FieldName) }}{{ if or (eq $mf.FieldName $.Model.PrimaryField) ($cf) (eq $mf.FieldName "CreatedAt") (eq $mf.FieldName "UpdatedAt") }}"{{ $mf.ColumnName }}",
|
||||
{{ end }}{{ end }}
|
||||
)
|
||||
query.Values(
|
||||
{{ range $mk, $mf := $.Model.Fields }}{{ $cf := (index $createFields $mf.FieldName) }}{{ if or (eq $mf.FieldName $.Model.PrimaryField) ($cf) (eq $mf.FieldName "CreatedAt") (eq $mf.FieldName "UpdatedAt") }}m.{{ $mf.FieldName }},
|
||||
{{ end }}{{ end }}
|
||||
)
|
||||
|
||||
// Execute the query with the provided context.
|
||||
sql, args := query.Build()
|
||||
sql = dbConn.Rebind(sql)
|
||||
_, err = dbConn.ExecContext(ctx, sql, args...)
|
||||
if err != nil {
|
||||
err = errors.Wrapf(err, "query - %s", query.String())
|
||||
err = errors.WithMessage(err, "create {{ FormatCamelLowerTitle $.Model.Name }} failed")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &m, nil
|
||||
}
|
||||
{{ end }}
|
||||
{{ define "Update"}}
|
||||
{{ $reqName := (Concat $.Model.Name "UpdateRequest") }}
|
||||
{{ $updateFields := (index $.StructFields $reqName) }}
|
||||
// Update replaces an {{ FormatCamelLowerTitle $.Model.Name }} in the database.
|
||||
func Update(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, req {{ $reqName }}, now time.Time) error {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.{{ FormatCamelLowerUnderscore $.Model.Name }}.Update")
|
||||
defer span.Finish()
|
||||
|
||||
v := validator.New()
|
||||
|
||||
// Validate the request.
|
||||
err := v.Struct(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Ensure the claims can modify the {{ FormatCamelLowerTitle $.Model.Name }} specified in the request.
|
||||
err = CanModify{{ $.Model.Name }}(ctx, claims, dbConn, req.{{ $.Model.PrimaryField }})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If now empty set it to the current time.
|
||||
if now.IsZero() {
|
||||
now = time.Now()
|
||||
}
|
||||
|
||||
// Always store the time as UTC.
|
||||
now = now.UTC()
|
||||
|
||||
// Postgres truncates times to milliseconds when storing. We and do the same
|
||||
// here so the value we return is consistent with what we store.
|
||||
now = now.Truncate(time.Millisecond)
|
||||
|
||||
// Build the update SQL statement.
|
||||
query := sqlbuilder.NewUpdateBuilder()
|
||||
query.Update({{ FormatCamelLower $.Model.Name }}TableName)
|
||||
|
||||
var fields []string
|
||||
|
||||
{{ range $mk, $mf := $.Model.Fields }}{{ $uf := (index $updateFields $mf.FieldName) }}{{ if and ($uf.FieldName) (ne $uf.FieldName $.Model.PrimaryField) }}
|
||||
{{ $optional := (FieldTagHasOption $uf "validate" "omitempty") }}{{ $isUuid := (FieldTagHasOption $uf "validate" "uuid") }}
|
||||
if req.{{ $uf.FieldName }} != nil {
|
||||
{{ if and ($optional) ($isUuid) }}
|
||||
if *req.{{ $uf.FieldName }} != "" {
|
||||
fields = append(fields, query.Assign("{{ $uf.ColumnName }}", req.{{ $uf.FieldName }}))
|
||||
} else {
|
||||
fields = append(fields, query.Assign("{{ $uf.ColumnName }}", nil))
|
||||
}
|
||||
{{ else }}
|
||||
fields = append(fields, query.Assign("{{ $uf.ColumnName }}", req.{{ $uf.FieldName }}))
|
||||
{{ end }}
|
||||
}
|
||||
{{ end }}{{ end }}
|
||||
|
||||
// If there's nothing to update we can quit early.
|
||||
if len(fields) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
{{ $hasUpdatedAt := (StringListHasValue $.Model.ColumnNames "updated_at") }}{{ if $hasUpdatedAt }}
|
||||
// Append the updated_at field
|
||||
fields = append(fields, query.Assign("updated_at", now))
|
||||
{{ end }}
|
||||
|
||||
query.Set(fields...)
|
||||
query.Where(query.Equal("{{ $.Model.PrimaryColumn }}", req.{{ $.Model.PrimaryField }}))
|
||||
|
||||
// Execute the query with the provided context.
|
||||
sql, args := query.Build()
|
||||
sql = dbConn.Rebind(sql)
|
||||
_, err = dbConn.ExecContext(ctx, sql, args...)
|
||||
if err != nil {
|
||||
err = errors.Wrapf(err, "query - %s", query.String())
|
||||
err = errors.WithMessagef(err, "update {{ FormatCamelLowerTitle $.Model.Name }} %s failed", req.{{ $.Model.PrimaryField }})
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
{{ end }}
|
||||
{{ define "Archive"}}
|
||||
// Archive soft deleted the {{ FormatCamelLowerTitle $.Model.Name }} from the database.
|
||||
func Archive(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, {{ FormatCamelLower $.Model.PrimaryField }} {{ $.Model.PrimaryType }}, now time.Time) error {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.{{ FormatCamelLowerUnderscore $.Model.Name }}.Archive")
|
||||
defer span.Finish()
|
||||
|
||||
// Defines the struct to apply validation
|
||||
req := struct {
|
||||
{{ $.Model.PrimaryField }} {{ $.Model.PrimaryType }} `validate:"required,uuid"`
|
||||
}{
|
||||
{{ $.Model.PrimaryField }}: {{ FormatCamelLower $.Model.PrimaryField }},
|
||||
}
|
||||
|
||||
// Validate the request.
|
||||
err := validator.New().Struct(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Ensure the claims can modify the {{ FormatCamelLowerTitle $.Model.Name }} specified in the request.
|
||||
err = CanModify{{ $.Model.Name }}(ctx, claims, dbConn, req.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If now empty set it to the current time.
|
||||
if now.IsZero() {
|
||||
now = time.Now()
|
||||
}
|
||||
|
||||
// Always store the time as UTC.
|
||||
now = now.UTC()
|
||||
|
||||
// Postgres truncates times to milliseconds when storing. We and do the same
|
||||
// here so the value we return is consistent with what we store.
|
||||
now = now.Truncate(time.Millisecond)
|
||||
|
||||
// Build the update SQL statement.
|
||||
query := sqlbuilder.NewUpdateBuilder()
|
||||
query.Update({{ FormatCamelLower $.Model.Name }}TableName)
|
||||
query.Set(
|
||||
query.Assign("archived_at", now),
|
||||
)
|
||||
query.Where(query.Equal("{{ $.Model.PrimaryColumn }}", req.{{ $.Model.PrimaryField }}))
|
||||
|
||||
// Execute the query with the provided context.
|
||||
sql, args := query.Build()
|
||||
sql = dbConn.Rebind(sql)
|
||||
_, err = dbConn.ExecContext(ctx, sql, args...)
|
||||
if err != nil {
|
||||
err = errors.Wrapf(err, "query - %s", query.String())
|
||||
err = errors.WithMessagef(err, "archive {{ FormatCamelLowerTitle $.Model.Name }} %s failed", req.{{ $.Model.PrimaryField }})
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
{{ end }}
|
||||
{{ define "Delete"}}
|
||||
// Delete removes an {{ FormatCamelLowerTitle $.Model.Name }} from the database.
|
||||
func Delete(ctx context.Context, claims auth.Claims, dbConn *sqlx.DB, {{ FormatCamelLower $.Model.PrimaryField }} {{ $.Model.PrimaryType }}) error {
|
||||
span, ctx := tracer.StartSpanFromContext(ctx, "internal.{{ FormatCamelLowerUnderscore $.Model.Name }}.Delete")
|
||||
defer span.Finish()
|
||||
|
||||
// Defines the struct to apply validation
|
||||
req := struct {
|
||||
{{ $.Model.PrimaryField }} {{ $.Model.PrimaryType }} `validate:"required,uuid"`
|
||||
}{
|
||||
{{ $.Model.PrimaryField }}: {{ FormatCamelLower $.Model.PrimaryField }},
|
||||
}
|
||||
|
||||
// Validate the request.
|
||||
err := validator.New().Struct(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Ensure the claims can modify the {{ FormatCamelLowerTitle $.Model.Name }} specified in the request.
|
||||
err = CanModify{{ $.Model.Name }}(ctx, claims, dbConn, req.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Build the delete SQL statement.
|
||||
query := sqlbuilder.NewDeleteBuilder()
|
||||
query.DeleteFrom({{ FormatCamelLower $.Model.Name }}TableName)
|
||||
query.Where(query.Equal("{{ $.Model.PrimaryColumn }}", req.{{ $.Model.PrimaryField }}))
|
||||
|
||||
// Execute the query with the provided context.
|
||||
sql, args := query.Build()
|
||||
sql = dbConn.Rebind(sql)
|
||||
_, err = dbConn.ExecContext(ctx, sql, args...)
|
||||
if err != nil {
|
||||
err = errors.Wrapf(err, "query - %s", query.String())
|
||||
err = errors.WithMessagef(err, "delete {{ FormatCamelLowerTitle $.Model.Name }} %s failed", req.{{ $.Model.PrimaryField }})
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
{{ end }}
|
@ -1,131 +0,0 @@
|
||||
{{ define "imports"}}
|
||||
import (
|
||||
"{{ $.GoSrcPath }}/internal/platform/auth"
|
||||
"{{ $.GoSrcPath }}/internal/platform/tests"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/huandu/go-sqlbuilder"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
{{ end }}
|
||||
{{ define "Globals"}}
|
||||
var test *tests.Test
|
||||
|
||||
// TestMain is the entry point for testing.
|
||||
func TestMain(m *testing.M) {
|
||||
os.Exit(testMain(m))
|
||||
}
|
||||
|
||||
func testMain(m *testing.M) int {
|
||||
test = tests.New()
|
||||
defer test.TearDown()
|
||||
return m.Run()
|
||||
}
|
||||
{{ end }}
|
||||
{{ define "TestFindRequestQuery"}}
|
||||
// TestFindRequestQuery validates findRequestQuery
|
||||
func TestFindRequestQuery(t *testing.T) {
|
||||
where := "field1 = ? or field2 = ?"
|
||||
var (
|
||||
limit uint = 12
|
||||
offset uint = 34
|
||||
)
|
||||
|
||||
req := {{ $.Model.Name }}FindRequest{
|
||||
Where: &where,
|
||||
Args: []interface{}{
|
||||
"lee brown",
|
||||
"103 East Main St.",
|
||||
},
|
||||
Order: []string{
|
||||
"id asc",
|
||||
"created_at desc",
|
||||
},
|
||||
Limit: &limit,
|
||||
Offset: &offset,
|
||||
}
|
||||
expected := "SELECT " + {{ FormatCamelLower $.Model.Name }}MapColumns + " FROM " + {{ FormatCamelLower $.Model.Name }}TableName + " WHERE (field1 = ? or field2 = ?) ORDER BY id asc, created_at desc LIMIT 12 OFFSET 34"
|
||||
|
||||
res, args := findRequestQuery(req)
|
||||
|
||||
if diff := cmp.Diff(res.String(), expected); diff != "" {
|
||||
t.Fatalf("\t%s\tExpected result query to match. Diff:\n%s", tests.Failed, diff)
|
||||
}
|
||||
if diff := cmp.Diff(args, req.Args); diff != "" {
|
||||
t.Fatalf("\t%s\tExpected result query to match. Diff:\n%s", tests.Failed, diff)
|
||||
}
|
||||
}
|
||||
{{ end }}
|
||||
{{ define "TestApplyClaimsSelect"}}
|
||||
// TestApplyClaimsSelect applyClaimsSelect
|
||||
func TestApplyClaimsSelect(t *testing.T) {
|
||||
var claimTests = []struct {
|
||||
name string
|
||||
claims auth.Claims
|
||||
expectedSql string
|
||||
error error
|
||||
}{
|
||||
{"EmptyClaims",
|
||||
auth.Claims{},
|
||||
"SELECT " + {{ FormatCamelLower $.Model.Name }}MapColumns + " FROM " + {{ FormatCamelLower $.Model.Name }}TableName,
|
||||
nil,
|
||||
},
|
||||
{"RoleAccount",
|
||||
auth.Claims{
|
||||
Roles: []string{auth.RoleAdmin},
|
||||
StandardClaims: jwt.StandardClaims{
|
||||
Subject: "user1",
|
||||
Audience: "acc1",
|
||||
},
|
||||
},
|
||||
"SELECT " + {{ FormatCamelLower $.Model.Name }}MapColumns + " FROM " + {{ FormatCamelLower $.Model.Name }}TableName + " WHERE account_id = 'acc1'",
|
||||
nil,
|
||||
},
|
||||
{"RoleAdmin",
|
||||
auth.Claims{
|
||||
Roles: []string{auth.RoleAdmin},
|
||||
StandardClaims: jwt.StandardClaims{
|
||||
Subject: "user1",
|
||||
Audience: "acc1",
|
||||
},
|
||||
},
|
||||
"SELECT " + {{ FormatCamelLower $.Model.Name }}MapColumns + " FROM " + {{ FormatCamelLower $.Model.Name }}TableName + " WHERE account_id = 'acc1'",
|
||||
nil,
|
||||
},
|
||||
}
|
||||
|
||||
t.Log("Given the need to validate ACLs are enforced by claims to a select query.")
|
||||
{
|
||||
for i, tt := range claimTests {
|
||||
t.Logf("\tTest: %d\tWhen running test: %s", i, tt.name)
|
||||
{
|
||||
ctx := tests.Context()
|
||||
|
||||
query := selectQuery()
|
||||
|
||||
err := applyClaimsSelect(ctx, tt.claims, query)
|
||||
if err != tt.error {
|
||||
t.Logf("\t\tGot : %+v", err)
|
||||
t.Logf("\t\tWant: %+v", tt.error)
|
||||
t.Fatalf("\t%s\tapplyClaimsSelect failed.", tests.Failed)
|
||||
}
|
||||
|
||||
sql, args := query.Build()
|
||||
|
||||
// Use mysql flavor so placeholders will get replaced for comparison.
|
||||
sql, err = sqlbuilder.MySQL.Interpolate(sql, args)
|
||||
if err != nil {
|
||||
t.Log("\t\tGot :", err)
|
||||
t.Fatalf("\t%s\tapplyClaimsSelect failed.", tests.Failed)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(sql, tt.expectedSql); diff != "" {
|
||||
t.Fatalf("\t%s\tExpected result query to match. Diff:\n%s", tests.Failed, diff)
|
||||
}
|
||||
|
||||
t.Logf("\t%s\tapplyClaimsSelect ok.", tests.Success)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
{{ end }}
|
@ -1,80 +0,0 @@
|
||||
{{ define "CreateRequest"}}
|
||||
// {{ FormatCamel $.Model.Name }}CreateRequest contains information needed to create a new {{ FormatCamel $.Model.Name }}.
|
||||
type {{ FormatCamel $.Model.Name }}CreateRequest struct {
|
||||
{{ range $fk, $f := .Model.Fields }}{{ if and ($f.ApiCreate) (ne $f.FieldName $.Model.PrimaryField) }}{{ $optional := (FieldTagHasOption $f "validate" "omitempty") }}
|
||||
{{ $f.FieldName }} {{ if and ($optional) (not $f.FieldIsPtr) }}*{{ end }}{{ $f.FieldType }} `json:"{{ $f.ColumnName }}" {{ FieldTag $f "validate" }}`
|
||||
{{ end }}{{ end }}
|
||||
}
|
||||
{{ end }}
|
||||
{{ define "UpdateRequest"}}
|
||||
// {{ FormatCamel $.Model.Name }}UpdateRequest defines what information may be provided to modify an existing
|
||||
// {{ FormatCamel $.Model.Name }}. All fields are optional so clients can send just the fields they want
|
||||
// changed. It uses pointer fields so we can differentiate between a field that
|
||||
// was not provided and a field that was provided as explicitly blank.
|
||||
type {{ FormatCamel $.Model.Name }}UpdateRequest struct {
|
||||
{{ range $fk, $f := .Model.Fields }}{{ if $f.ApiUpdate }}
|
||||
{{ $f.FieldName }} {{ if and (ne $f.FieldName $.Model.PrimaryField) (not $f.FieldIsPtr) }}*{{ end }}{{ $f.FieldType }} `json:"{{ $f.ColumnName }}" {{ if ne $f.FieldName $.Model.PrimaryField }}{{ FieldTagReplaceOrPrepend $f "validate" "required" "omitempty" }}{{ else }}{{ FieldTagReplaceOrPrepend $f "validate" "omitempty" "required" }}{{ end }}`
|
||||
{{ end }}{{ end }}
|
||||
}
|
||||
{{ end }}
|
||||
{{ define "FindRequest"}}
|
||||
// {{ FormatCamel $.Model.Name }}FindRequest defines the possible options to search for {{ FormatCamelPluralTitleLower $.Model.Name }}. By default
|
||||
// archived {{ FormatCamelLowerTitle $.Model.Name }} will be excluded from response.
|
||||
type {{ FormatCamel $.Model.Name }}FindRequest struct {
|
||||
Where *string `schema:"where"`
|
||||
Args []interface{} `schema:"args"`
|
||||
Order []string `schema:"order"`
|
||||
Limit *uint `schema:"limit"`
|
||||
Offset *uint `schema:"offset"`
|
||||
IncludedArchived bool
|
||||
{{ $hasArchived := (StringListHasValue $.Model.ColumnNames "archived_at") }}{{ if $hasArchived }}IncludedArchived bool `schema:"included-archived"`{{ end }}
|
||||
}
|
||||
{{ end }}
|
||||
{{ define "Enums"}}
|
||||
{{ range $fk, $f := .Model.Fields }}{{ if $f.DbColumn }}{{ if $f.DbColumn.IsEnum }}
|
||||
// {{ $f.FieldType }} represents the {{ $f.ColumnName }} of {{ FormatCamelLowerTitle $.Model.Name }}.
|
||||
type {{ $f.FieldType }} string
|
||||
|
||||
// {{ $f.FieldType }} values define the {{ $f.ColumnName }} field of {{ FormatCamelLowerTitle $.Model.Name }}.
|
||||
const (
|
||||
{{ range $evk, $ev := $f.DbColumn.EnumValues }}
|
||||
// {{ $f.FieldType }}_{{ FormatCamel $ev }} defines the {{ $f.ColumnName }} of {{ $ev }} for {{ FormatCamelLowerTitle $.Model.Name }}.
|
||||
{{ $f.FieldType }}_{{ FormatCamel $ev }} {{ $f.FieldType }} = "{{ $ev }}"
|
||||
{{ end }}
|
||||
)
|
||||
|
||||
// {{ $f.FieldType }}_Values provides list of valid {{ $f.FieldType }} values.
|
||||
var {{ $f.FieldType }}_Values = []{{ $f.FieldType }}{
|
||||
{{ range $evk, $ev := $f.DbColumn.EnumValues }}
|
||||
{{ $f.FieldType }}_{{ FormatCamel $ev }},
|
||||
{{ end }}
|
||||
}
|
||||
|
||||
// Scan supports reading the {{ $f.FieldType }} value from the database.
|
||||
func (s *{{ $f.FieldType }}) Scan(value interface{}) error {
|
||||
asBytes, ok := value.([]byte)
|
||||
if !ok {
|
||||
return errors.New("Scan source is not []byte")
|
||||
}
|
||||
*s = {{ $f.FieldType }}(string(asBytes))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value converts the {{ $f.FieldType }} value to be stored in the database.
|
||||
func (s {{ $f.FieldType }}) Value() (driver.Value, error) {
|
||||
v := validator.New()
|
||||
|
||||
errs := v.Var(s, "required,oneof={{ JoinStrings $f.DbColumn.EnumValues " " }}")
|
||||
if errs != nil {
|
||||
return nil, errs
|
||||
}
|
||||
|
||||
return string(s), nil
|
||||
}
|
||||
|
||||
// String converts the {{ $f.FieldType }} value to a string.
|
||||
func (s {{ $f.FieldType }}) String() string {
|
||||
return string(s)
|
||||
}
|
||||
{{ end }}{{ end }}{{ end }}
|
||||
{{ end }}
|
Reference in New Issue
Block a user